# Study the correlation between the SSP Parameters

Study the correlation between the parameters using the `fors2tostellarpopsynthesis` package

- Author Sylvie Dagoret-Campagne
- Affiliation : IJCLab/IN2P3/CNRS
- Organisation : LSST-DESC
- creation date : 2023-12-04
- last update : 2023-12-05


| computer | processor | kernel              |    date     |
| --- | --- | --- | --- |
| CC       | CPU       | conda_jax0325_py310 | 2023-11-10  |
| macbookpro | CPU | conda_jaxcpu_dsps_py310 | 2023-11-10  | 


libraries 
=========

jax
---

- jaxlib-0.3.25+cuda11.cudnn82
- jaxopt
- optax
- corner
- arviz
- numpyro
- graphviz

sps
---

- fsps
- prospect
- dsps
- diffstar
- diffmah



(conda_jax0325_py310) 
`/pbs/throng/lsst/users/dagoret/desc/StellarPopulationSynthesis>pip list | grep` 

| lib | version |
|--- | --- | 
|jax  |                         0.4.20 |
|jaxlib |                       0.4.20 |
|jaxopt  |                      0.8.2 |



## examples

- jaxcosmo : https://github.com/DifferentiableUniverseInitiative/jax-cosmo-paper/blob/master/notebooks/VectorizedNumPyro.ipynb
- on atmosphere : https://github.com/sylvielsstfr/FitDiffAtmo/blob/main/docs/notebooks/fitdiffatmo/test_numpyro_orderedict_diffatmemul_5params_P_pwv_oz_tau_beta.ipynb

## Import

### import external packages

In [None]:
import h5py
import pandas as pd
import numpy as np
import os
import re
import pickle 
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.colors as colors
import matplotlib.cm as cmx
import collections
from collections import OrderedDict
import re
import matplotlib.gridspec as gridspec
from sklearn.gaussian_process import GaussianProcessRegressor, kernels

In [None]:
import jax
import jax.numpy as jnp
from jax import vmap
import jaxopt
import optax
jax.config.update("jax_enable_x64", True)
from interpax import interp1d

from jax.lax import fori_loop
from jax.lax import select,cond
from jax.lax import concatenate

In [None]:
import numpyro
from numpyro import optim
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, HMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal, AutoMultivariateNormal
from numpyro.infer.reparam import NeuTraReparam
from numpyro.handlers import seed, trace, condition


import corner
import arviz as az

### import internal packages

In [None]:
from fors2tostellarpopsynthesis.parameters  import SSPParametersFit,paramslist_to_dict

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (lik_spec,lik_mag,lik_comb,
get_infos_spec,get_infos_mag,get_infos_comb)

from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (SSP_DATA,mean_spectrum,mean_mags,mean_sfr,ssp_spectrum_fromparam)

## Configuration

### matplotlib configuration

In [None]:
# Global plotting defaults: large figures and extra-large labels/ticks for readability
plt.rcParams.update({
    "figure.figsize": (12, 6),
    "axes.labelsize": 'xx-large',
    "axes.titlesize": 'xx-large',
    "xtick.labelsize": 'xx-large',
    "ytick.labelsize": 'xx-large',
    "legend.fontsize": 16,
    "font.size": 15,
})

### Steering MCMC

In [None]:
## Steering MCMC
# Longer run settings (kept for reference):
#   NUM_SAMPLES = 5_000, N_CHAINS = 4, NUM_WARMUP = 1_000
# Short run settings used here:
NUM_SAMPLES, N_CHAINS, NUM_WARMUP = 500, 2, 100
# Summary of the MCMC configuration, displayed in the next cell
df_mcmc = pd.Series(dict(num_samples=NUM_SAMPLES,
                         n_chains=N_CHAINS,
                         num_warmup=NUM_WARMUP))

In [None]:
print("=========== Start MCMC  ============= :")
# Last expression of the cell: Jupyter displays the MCMC steering summary
df_mcmc

### Selection on what to simulate and output

In [None]:
# flags below set to false to generate docs
# (when False, the corresponding MCMC cells below are skipped)
FLAG_NODUST = False
FLAG_DUST = False

## Defining MCMC output files
# File names encode the chain count, warmup length and number of samples.
# Only the HDF names are active; the pickle/csv variants are disabled.

#fileout_nodust_pickle = f"DSPS_nodust_mcmc_params0_{N_CHAINS}_{NUM_WARMUP}_{NUM_SAMPLES}.pickle"
#fileout_nodust_csv = f"DSPS_nodust_mcmc_params0_{N_CHAINS}_{NUM_WARMUP}_{NUM_SAMPLES}.csv"
fileout_nodust_hdf = f"DSPS_nodust_mcmc_params0_{N_CHAINS}_{NUM_WARMUP}_{NUM_SAMPLES}.hdf"

#fileout_dust_pickle = f"DSPS_dust_mcmc_params0_{N_CHAINS}_{NUM_WARMUP}_{NUM_SAMPLES}.pickle"
#fileout_dust_csv = f"DSPS_dust_mcmc_params0_{N_CHAINS}_{NUM_WARMUP}_{NUM_SAMPLES}.csv"
fileout_dust_hdf = f"DSPS_dust_mcmc_params0_{N_CHAINS}_{NUM_WARMUP}_{NUM_SAMPLES}.hdf"

### Observation

In [None]:
# Select observation
# - Z_OBS: redshift of the observation
# - SIGMAREL_OBS: relative error on the flux at all wavelengths (here 50%)
# - SIGMA_OBS: absolute error. This value is superseded: per-wavelength
#   absolute errors are recomputed later from SIGMAREL_OBS (provided as an array)
Z_OBS = 0.5
SIGMAREL_OBS = 0.5
SIGMA_OBS = 1e-11

# Summary of the observation settings, displayed in the next cell
df_info = pd.Series({"z_obs":Z_OBS,"sigmarel_obs":SIGMAREL_OBS})

In [None]:
print("=========== Start Observations  ============= :")
# Last expression of the cell: Jupyter displays the observation settings
df_info

## Fit parameters

In [None]:
# Container for the SSP fit parameters (names, initial values, bounds, true values)
p = SSPParametersFit()

In [None]:
# Display the "true" parameter dictionary used to generate the mock data
p.DICT_PARAMS_true

In [None]:
# Display the initial parameter values
p.INIT_PARAMS

In [None]:
# Initial values of the last four parameters (AV, UV_BUMP, PLAW_SLOPE, SCALEF in the
# models' ordering): the three before last set to 0, the very last one to 1.
p.INIT_PARAMS = p.INIT_PARAMS.at[-4:-1].set(0.).at[-1].set(1.)

In [None]:
# Wavelength grid and true SEDs at Z_OBS, without (spec_rest_true) and with
# (spec_rest_att_true) dust attenuation, computed from the true parameters
wlsall,spec_rest_true,spec_rest_att_true = ssp_spectrum_fromparam(p.DICT_PARAMS_true,Z_OBS)

In [None]:
# The problem is the above is that the parameters are drawn randomly
# Per-wavelength absolute flux errors, replacing the single SIGMA_OBS value:
# a fixed fraction SIGMAREL_OBS of the true spectrum, without and with dust.
sigmanodust_obs_true, sigmadust_obs_true = (
    SIGMAREL_OBS * spec_rest_true,
    SIGMAREL_OBS * spec_rest_att_true,
)

In [None]:
# Prior law per parameter for the no-dust fit (17 entries):
# first parameter uniform, next 12 normal, last 4 (dust parameters and scale factor) fixed
PARAM_SIMLAW_NODUST = np.array(["uniform"] + ["normal"] * 12 + ["fixed"] * 4)

In [None]:
# Prior law per parameter for the fit with dust (17 entries):
# first parameter uniform, next 12 normal, 3 dust parameters uniform, scale factor fixed
PARAM_SIMLAW_WITHDUST = np.array(["uniform"] + ["normal"] * 12 + ["uniform"] * 3 + ["fixed"])

In [None]:
# Flat list of parameter names, used as numpyro sample-site names in the models
PARAM_NAMES = np.array(p.PARAM_NAMES_FLAT)

In [None]:
# Prior settings taken from the initial parameter values and their bounds
PARAM_VAL = p.INIT_PARAMS
PARAM_MIN = p.PARAMS_MIN
PARAM_MAX = p.PARAMS_MAX
# Width of the Normal priors: root-mean-square distance of the initial value
# to the two bounds, sqrt(((v - min)^2 + (v - max)^2) / 2)
PARAM_SIGMA = jnp.sqrt(0.5*((PARAM_VAL-PARAM_MIN)**2 + (PARAM_VAL-PARAM_MAX)**2))

In [None]:
# Defaults of the numpyro models below. Tie them to the observation settings
# so the models, the mock data and the fits always share the same values.
z_obs = Z_OBS
sigma_rel = SIGMAREL_OBS
# Default absolute flux error: likelihood width of the models and noise
# level of the simulated (mock) spectra
sigma_obs = 1e-8

## Bayesian modelling

In [None]:
# Boolean mask of the parameters whose no-dust law is "fixed" (displayed below)
condlist_fix = jnp.asarray(PARAM_SIMLAW_NODUST == "fixed")
condlist_fix

In [None]:
# Boolean mask of the parameters whose no-dust law is "uniform" (displayed below)
condlist_uniform = jnp.asarray(PARAM_SIMLAW_NODUST == "uniform")
condlist_uniform

In [None]:
def galaxymodel_nodust(wlsin,Fobs=None,
                       initparamval = PARAM_VAL, minparamval = PARAM_MIN,maxparamval = PARAM_MAX,
                       sigmaparamval = PARAM_SIGMA,paramnames = PARAM_NAMES,z_obs= z_obs, sigma = sigma_obs):
    """
    Numpyro model of a galaxy spectrum without dust attenuation.

    Treatment of the parameters (indexed as in ``paramnames``):

    - index 0 (MAH_lgmO): sampled from ``Uniform(minparamval[0], maxparamval[0])``
    - indices 1-12 (MAH_logtc ... Q_lg_rejuv): sampled from
      ``Normal(initparamval[i], sigmaparamval[i])``
    - indices 13-15 (AV, UV_BUMP, PLAW_SLOPE): not sampled, fixed to ``initparamval[i]``
    - index 16 (SCALEF): not sampled, fixed to 1

    The non-attenuated SED returned by ``ssp_spectrum_fromparam`` is
    interpolated onto ``wlsin`` and compared with ``Fobs`` through an
    independent Gaussian likelihood at each wavelength.

    :param wlsin: wavelengths at which the spectrum is modelled
    :type wlsin: array of float in Angstrom
    :param Fobs: observed fluxes at ``wlsin``; ``None`` to sample from the model
    :param initparamval: central values of the Normal priors / values of the fixed parameters
    :param minparamval: lower bounds used by the Uniform prior
    :param maxparamval: upper bounds used by the Uniform prior
    :param sigmaparamval: widths of the Normal priors
    :param paramnames: parameter names, used as numpyro sample-site names
    :param z_obs: redshift of observation
    :param sigma: flux uncertainty, a scalar or one value per wavelength
    """
    dict_params = OrderedDict()

    # MAH_lgmO (index 0): uniform prior between its bounds
    name = paramnames[0]
    dict_params[name] = numpyro.sample(name, dist.Uniform(minparamval[0], maxparamval[0]))

    # MAH_logtc ... Q_lg_rejuv (indices 1-12): normal priors centred on the initial values.
    # Sites are created in index order, exactly as in the unrolled version.
    for idx in range(1, 13):
        name = paramnames[idx]
        dict_params[name] = numpyro.sample(
            name, dist.Normal(initparamval[idx], sigmaparamval[idx]))

    # AV, UV_BUMP, PLAW_SLOPE (indices 13-15): no dust fit, held at their initial values
    for idx in range(13, 16):
        dict_params[paramnames[idx]] = initparamval[idx]

    # SCALEF (index 16): fixed to 1
    dict_params[paramnames[16]] = 1.

    wls, sed_notattenuated, _ = ssp_spectrum_fromparam(dict_params, z_obs)

    # Interpolate the model SED onto the requested wavelength grid
    Fsim = interp1d(wlsin, wls, sed_notattenuated)

    with numpyro.plate("obs", wlsin.shape[0]):  # the observables are independent
        numpyro.sample('F', dist.Normal(Fsim, sigma), obs=Fobs)


In [None]:
def galaxymodel_withdust(wlsin,Fobs=None,
                         initparamval = PARAM_VAL, minparamval = PARAM_MIN,maxparamval = PARAM_MAX,
                         sigmaparamval = PARAM_SIGMA,paramnames = PARAM_NAMES,z_obs= z_obs, sigma = sigma_obs):
    """
    Numpyro model of a galaxy spectrum with dust attenuation.

    Treatment of the parameters (indexed as in ``paramnames``):

    - index 0 (MAH_lgmO): sampled from ``Uniform(minparamval[0], maxparamval[0])``
    - indices 1-12 (MAH_logtc ... Q_lg_rejuv): sampled from
      ``Normal(initparamval[i], sigmaparamval[i])``
    - indices 13-15 (AV, UV_BUMP, PLAW_SLOPE): sampled from
      ``Uniform(minparamval[i], maxparamval[i])``
    - index 16 (SCALEF): not sampled, fixed to 1

    The attenuated SED returned by ``ssp_spectrum_fromparam`` is interpolated
    onto ``wlsin`` and compared with ``Fobs`` through an independent Gaussian
    likelihood at each wavelength.

    :param wlsin: wavelengths at which the spectrum is modelled
    :type wlsin: array of float in Angstrom
    :param Fobs: observed fluxes at ``wlsin``; ``None`` to sample from the model
    :param initparamval: central values of the Normal priors
    :param minparamval: lower bounds used by the Uniform priors
    :param maxparamval: upper bounds used by the Uniform priors
    :param sigmaparamval: widths of the Normal priors
    :param paramnames: parameter names, used as numpyro sample-site names
    :param z_obs: redshift of observation
    :param sigma: flux uncertainty, a scalar or one value per wavelength
    """
    dict_params = OrderedDict()

    # MAH_lgmO (index 0): uniform prior between its bounds
    name = paramnames[0]
    dict_params[name] = numpyro.sample(name, dist.Uniform(minparamval[0], maxparamval[0]))

    # MAH_logtc ... Q_lg_rejuv (indices 1-12): normal priors centred on the initial values.
    # Sites are created in index order, exactly as in the unrolled version.
    for idx in range(1, 13):
        name = paramnames[idx]
        dict_params[name] = numpyro.sample(
            name, dist.Normal(initparamval[idx], sigmaparamval[idx]))

    # AV, UV_BUMP, PLAW_SLOPE (indices 13-15): dust parameters, uniform priors between their bounds
    for idx in range(13, 16):
        name = paramnames[idx]
        dict_params[name] = numpyro.sample(
            name, dist.Uniform(minparamval[idx], maxparamval[idx]))

    # SCALEF (index 16): fixed to 1
    dict_params[paramnames[16]] = 1.0

    wls, _, sed_attenuated = ssp_spectrum_fromparam(dict_params, z_obs)

    # Interpolate the attenuated model SED onto the requested wavelength grid
    Fsim = interp1d(wlsin, wls, sed_attenuated)

    with numpyro.plate("obs", wlsin.shape[0]):  # the observables are independent
        numpyro.sample('F', dist.Normal(Fsim, sigma), obs=Fobs)


In [None]:
# Draw the graphical structure of the no-dust model, using one dummy wavelength
# and one dummy flux; arguments are passed positionally in signature order
numpyro.render_model(galaxymodel_nodust, model_args=(jnp.array([0.]),jnp.array([1.]),
                                                     PARAM_VAL, 
                                                     PARAM_MIN,
                                                     PARAM_MAX,
                                                     PARAM_SIGMA, 
                                                     PARAM_NAMES,
                                                     Z_OBS,sigma_obs),render_distributions=True)

In [None]:
# Draw the graphical structure of the model with dust, using one dummy wavelength
# and one dummy flux; arguments are passed positionally in signature order
numpyro.render_model(galaxymodel_withdust, model_args=(jnp.array([0.]),jnp.array([1.]),
                                                     PARAM_VAL, 
                                                     PARAM_MIN,
                                                     PARAM_MAX,
                                                     PARAM_SIGMA, 
                                                     PARAM_NAMES,
                                                     Z_OBS,sigma_obs),render_distributions=True)

In [None]:

# Mock "no dust" data: condition the model on the true parameters so that the
# sample sites take those values instead of being drawn from the priors;
# the seeded trace then samples only the noisy fluxes 'F'.
fiducial_model = condition(galaxymodel_nodust, p.DICT_PARAMS_true)
seeded_fiducial_nodust = seed(fiducial_model, jax.random.PRNGKey(42))
trace_data_nodust = trace(seeded_fiducial_nodust).get_trace(
    wlsall,
    minparamval=PARAM_MIN,
    maxparamval=PARAM_MAX,
    sigmaparamval=PARAM_SIGMA,
    paramnames=PARAM_NAMES,
    z_obs=Z_OBS,
    sigma=sigma_obs,
)

In [None]:
# Simulated no-dust fluxes: value of the 'F' sample site recorded in the trace
spec_nodust = trace_data_nodust['F']["value"]

In [None]:
# Generate the mock data with dust at the fiducial (true) parameters.
# Conditioning fixes the sample sites to the true values; without it the
# priors would be sampled instead.
fiducial_model = condition(galaxymodel_withdust, p.DICT_PARAMS_true)
# Use Z_OBS like the no-dust trace and the MCMC runs, so every simulation
# shares one redshift setting
trace_data_withdust = trace(seed(fiducial_model, jax.random.PRNGKey(42))).get_trace(wlsall,
                       minparamval = PARAM_MIN,
                       maxparamval = PARAM_MAX,
                       sigmaparamval = PARAM_SIGMA,
                       paramnames = PARAM_NAMES,
                       z_obs = Z_OBS,
                       sigma = sigma_obs)

In [None]:
# Simulated fluxes with dust: value of the 'F' sample site recorded in the trace
spec_withdust = trace_data_withdust['F']["value"]

In [None]:
# Compare the simulated spectra without and with dust (error bars: sigma_obs)
fig,ax = plt.subplots(1,1,figsize=(10,3))
ax.errorbar(wlsall,spec_nodust,yerr=sigma_obs,fmt='o',ms=0.5 ,linewidth=2, capsize=0, c='k', label="no dust")
ax.errorbar(wlsall,spec_withdust,yerr=sigma_obs, fmt='o', ms=0.5,linewidth=2, capsize=0, c='r', label="with dust")
# Raw string: "\l" is not a valid Python escape (SyntaxWarning on Python >= 3.12).
# The wavelength grid is the one passed to the models as `wlsin`, documented in Angstrom.
ax.set_xlabel(r"$\lambda$ (Å)")
ax.set_ylabel("DSPS spectrum")
ax.legend()
ax.set_yscale('log')
ax.set_ylim(1e-11,1e-5)
ax.set_xlim(1e2,1e6)
ax.set_xscale('log')
ax.grid();


In [None]:
# Display the true dust-attenuated spectrum
spec_rest_att_true

In [None]:
if FLAG_NODUST:
    print(f"===========  MCMC simulation : No Dust , num_samples = {NUM_SAMPLES}, n_chains = {N_CHAINS}, num_warmup = {NUM_WARMUP} ========")
    print(f" >>> output file {fileout_nodust_hdf}")


    # Run NUTS.
    # Fixed seed for reproducibility; only `rng_key` (after the split) is used below.
    rng_key = jax.random.PRNGKey(42)
    rng_key, rng_key0, rng_key1, rng_key2 = jax.random.split(rng_key, 4)


    # Fit the no-dust model: dense mass matrix, high target acceptance and
    # median initialisation (see the forum link below about invalid initial parameters)
    kernel = NUTS(galaxymodel_nodust, dense_mass=True, target_accept_prob=0.9,
              init_strategy=numpyro.infer.init_to_median())

    mcmc = MCMC(kernel, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES,  
            num_chains=N_CHAINS,
            chain_method='vectorized',
            progress_bar=True)
    # see https://forum.pyro.ai/t/cannot-find-valid-initial-parameters-when-using-nuts-for-simple-gaussian-mixture-model-in-numpyro/2181
    # Observed data: the true non-attenuated spectrum with its per-wavelength errors
    with numpyro.validation_enabled():
        mcmc.run(rng_key, wlsin=wlsall, Fobs= spec_rest_true,
                       initparamval = PARAM_VAL,
                       minparamval = PARAM_MIN,
                       maxparamval = PARAM_MAX,
                       sigmaparamval = PARAM_SIGMA,
                       paramnames = PARAM_NAMES,
                       z_obs = Z_OBS,
                       sigma = sigmanodust_obs_true)
                       #extra_fields=('potential_energy',))
        mcmc.print_summary()
        samples_nuts = mcmc.get_samples()

In [None]:
if FLAG_NODUST:
    # Relative effective sample size. An expression inside an `if` block is not
    # echoed by Jupyter, so print it explicitly instead of discarding it.
    print(az.ess(samples_nuts, relative=True))

In [None]:
if FLAG_NODUST:
    # Convert the numpyro run to ArviZ InferenceData and plot the chains per parameter
    data = az.from_numpyro(mcmc)
    az.plot_trace(data, compact=True);

## MCMC With DUST

### run MCMC with Dust

In [None]:
if FLAG_DUST:
    print(f"===========  MCMC simulation : with Dust , num_samples = {NUM_SAMPLES}, n_chains = {N_CHAINS}, num_warmup = {NUM_WARMUP} ========")
    print(f" >>> output file {fileout_dust_hdf}")

    # Run NUTS.
    # Fixed seed for reproducibility; only `rng_key` (after the split) is used below.
    rng_key = jax.random.PRNGKey(42)
    rng_key, rng_key0, rng_key1, rng_key2 = jax.random.split(rng_key, 4)

    # The data (spec_rest_att_true) and errors (sigmadust_obs_true) are dust
    # attenuated, so the kernel must use the model with dust: the no-dust model
    # neither samples AV/UV_BUMP/PLAW_SLOPE nor predicts the attenuated SED.
    kernel = NUTS(galaxymodel_withdust, dense_mass=True, target_accept_prob=0.9,
                  init_strategy=numpyro.infer.init_to_median())

    mcmc = MCMC(kernel, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES,
                num_chains=N_CHAINS,
                chain_method='vectorized',
                progress_bar=True)
    # see https://forum.pyro.ai/t/cannot-find-valid-initial-parameters-when-using-nuts-for-simple-gaussian-mixture-model-in-numpyro/2181
    with numpyro.validation_enabled():
        mcmc.run(rng_key, wlsin=wlsall, Fobs = spec_rest_att_true,
                       initparamval = PARAM_VAL,
                       minparamval = PARAM_MIN,
                       maxparamval = PARAM_MAX,
                       sigmaparamval = PARAM_SIGMA,
                       paramnames = PARAM_NAMES,
                       z_obs = Z_OBS,
                       sigma = sigmadust_obs_true)
                       #extra_fields=('potential_energy',))
        mcmc.print_summary()
        samples_nuts = mcmc.get_samples()

In [None]:
if FLAG_DUST:
    # Relative effective sample size. An expression inside an `if` block is not
    # echoed by Jupyter, so print it explicitly instead of discarding it.
    print(az.ess(samples_nuts, relative=True))

In [None]:
if FLAG_DUST:
    # Convert the numpyro run to ArviZ InferenceData and plot the chains per parameter
    data = az.from_numpyro(mcmc)
    az.plot_trace(data, compact=True);