# Study the correlation between the SSP Parameters

Generate simulation data to study the correlation between the parameters using the `fors2tostellarpopsynthesis` package

- Author Sylvie Dagoret-Campagne
- Affiliation : IJCLab/IN2P3/CNRS
- Organisation : LSST-DESC
- creation date : 2023-12-05
- last update : 2023-12-06


| computer | processor | kernel              |    date     |
| --- | --- | --- | --- |
| CC       | CPU       | conda_jax0325_py310 | 2023-11-10  |
| macbookpro | CPU | conda_jaxcpu_dsps_py310 | 2023-11-10  | 


libraries 
=========

jax
---

- jaxlib-0.3.25+cuda11.cudnn82
- jaxopt
- optax
- corner
- arviz
- numpyro
- graphviz

sps
---

- fsps
- prospect
- dsps
- diffstar
- diffmah



(conda_jax0325_py310) 
`/pbs/throng/lsst/users/dagoret/desc/StellarPopulationSynthesis>pip list | grep` 

| lib | version |
|--- | --- | 
|jax  |                         0.4.20 |
|jaxlib |                       0.4.20 |
|jaxopt  |                      0.8.2 |



## examples

- jaxcosmo : https://github.com/DifferentiableUniverseInitiative/jax-cosmo-paper/blob/master/notebooks/VectorizedNumPyro.ipynb
- on atmosphere : https://github.com/sylvielsstfr/FitDiffAtmo/blob/main/docs/notebooks/fitdiffatmo/test_numpyro_orderedict_diffatmemul_5params_P_pwv_oz_tau_beta.ipynb

## Import

### import external packages

In [None]:
import h5py
import pandas as pd
import numpy as np
import os
import re
import pickle 
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.colors as colors
import matplotlib.cm as cmx
import collections
from collections import OrderedDict
import re
import matplotlib.gridspec as gridspec
from sklearn.gaussian_process import GaussianProcessRegressor, kernels
import copy

In [None]:
import jax
import jax.numpy as jnp
from jax import vmap
import jaxopt
import optax
jax.config.update("jax_enable_x64", True)
from interpax import interp1d

from jax.lax import fori_loop
from jax.lax import select,cond
from jax.lax import concatenate

In [None]:
import numpyro
from numpyro import optim
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, HMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal, AutoMultivariateNormal
from numpyro.infer.reparam import NeuTraReparam
from numpyro.handlers import seed, trace, condition


import corner
import arviz as az

### import internal packages

In [None]:
from fors2tostellarpopsynthesis.parameters  import SSPParametersFit,paramslist_to_dict

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (SSP_DATA,mean_spectrum,mean_mags,mean_sfr,ssp_spectrum_fromparam)

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(PARAM_SIMLAW_NODUST,PARAM_SIMLAW_WITHDUST,
                            PARAM_NAMES,PARAM_VAL,PARAM_MIN,PARAM_MAX,PARAM_SIGMA)

from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(galaxymodel_nodust_av,galaxymodel_nodust,galaxymodel_withdust_av,galaxymodel_withdust)

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_util import plot_params_kde

## Configuration

### matplotlib configuration

In [None]:
# Global matplotlib defaults for every figure of the notebook:
# wide figures, extra-large axis labels/titles/ticks, 12 pt legend and body text.
plt.rcParams.update(
    {
        "figure.figsize": (12, 6),
        "axes.labelsize": "xx-large",
        "axes.titlesize": "xx-large",
        "xtick.labelsize": "xx-large",
        "ytick.labelsize": "xx-large",
        "legend.fontsize": 12,
        "font.size": 12,
    }
)

### Steering MCMC

In [None]:
## Steering MCMC
# Sampler configuration used in this notebook.
# A larger configuration kept for reference:
#   NUM_SAMPLES = 5_000, N_CHAINS = 4, NUM_WARMUP = 1_000
NUM_SAMPLES, N_CHAINS, NUM_WARMUP = 1_000, 4, 500

# Same settings gathered in a Series so they can be stored next to the samples.
df_mcmc = pd.Series(
    {
        "num_samples": NUM_SAMPLES,
        "n_chains": N_CHAINS,
        "num_warmup": NUM_WARMUP,
    }
)


In [None]:
# Print a banner, then display the sampler settings as the cell output.
print("=========== Start MCMC  ============= :")
df_mcmc

### Selection on what to simulate and output

In [None]:
# Switches selecting which MCMC runs are executed: without dust and/or with dust.
FLAG_NODUST, FLAG_DUST = True, True

## Defining MCMC output files
# The file names encode the sampler configuration (chains, warm-up steps, samples).
# Only HDF output is produced; pickle and CSV outputs are not enabled.
fileout_nodust_hdf = f"DSPS_nodust_mcmc_params_{N_CHAINS}_{NUM_WARMUP}_{NUM_SAMPLES}.hdf"
fileout_dust_hdf = f"DSPS_dust_mcmc_params_{N_CHAINS}_{NUM_WARMUP}_{NUM_SAMPLES}.hdf"

### Observation

In [None]:
# Select observation
# - Z_OBS        : redshift of the observation
# - SIGMAREL_OBS : relative error on the flux at all wavelengths (here 10%)
# - SIGMA_OBS    : absolute error; this value will be overwritten after
#                  recalculating the absolute error for each wavelength
#                  (provided as an array)
Z_OBS, SIGMAREL_OBS, SIGMA_OBS = 0.5, 0.1, 1e-11

# Observation settings gathered in a Series so they can be stored with the samples.
df_info = pd.Series(
    {
        "z_obs": Z_OBS,
        "sigmarel_obs": SIGMAREL_OBS,
    }
)

In [None]:
# Print a banner, then display the observation settings as the cell output.
print("=========== Start Observations  ============= :")
df_info

## Fit parameters

In [None]:
# Parameter container; its DICT_PARAMS_true attribute supplies the default true values.
p = SSPParametersFit()

In [None]:
# Select the true parameter values: start from a private copy of the defaults
# (so the container `p` is left untouched) and override some of them.
dict_sel_params_true = copy.deepcopy(p.DICT_PARAMS_true)
dict_sel_params_true.update(
    {
        "MAH_lgmO": 10.0,
        "MAH_logtc": 0.8,
        "MAH_early_index": 3.0,
        "MAH_late_index": 0.5,
        "AV": 0.5,
        "UV_BUMP": 2.5,
        "PLAW_SLOPE": -0.1,
    }
)

# True values as a list, in the iteration order of the dictionary.
list_sel_params_true = [*dict_sel_params_true.values()]

# Summary table with one row per parameter: bounds, initial value, sigma and
# selected true value, rounded to 3 decimals.
df_params = pd.DataFrame(
    {
        "name": PARAM_NAMES,
        "min": PARAM_MIN,
        "val": PARAM_VAL,
        "max": PARAM_MAX,
        "sig": PARAM_SIGMA,
        "true": list_sel_params_true,
    }
).round(decimals=3)



In [None]:
# Print a banner, then display the parameter table as the cell output.
print("=========== DSPS Parameters to fit ============= :")
df_params

## True value spectra

In [None]:
# Generate the spectra from the selected true parameter values at redshift Z_OBS.
# The call provides:
# - wlsall               : the wavelength array
# - spec_rest_noatt_true : the supposedly measured spectrum without dust, from the true parameters
# - spec_rest_att_true   : the supposedly measured spectrum with dust, from the true parameters
wlsall,spec_rest_noatt_true,spec_rest_att_true = ssp_spectrum_fromparam(dict_sel_params_true,Z_OBS)

In [None]:
# The problem with the above is that the parameters are drawn randomly.
# Thus redefine the errors properly: the absolute error at each wavelength is the
# relative error SIGMAREL_OBS times the corresponding true spectrum.
sigmanodust_obs_true = SIGMAREL_OBS*spec_rest_noatt_true
sigmadust_obs_true = SIGMAREL_OBS*spec_rest_att_true

In [None]:
# Two-panel figure: the true spectra with their error bars (top) and the
# absolute errors alone (bottom), both on log-log axes.
fig = plt.figure(figsize=(10, 6))
gs = fig.add_gridspec(2, 1, height_ratios=[3., 1.5])
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1])

# --- top panel: spectra ---
ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.plot(wlsall, spec_rest_noatt_true, linewidth=2, color="k", label="True spec no dust")
ax1.plot(wlsall, spec_rest_att_true, linewidth=2, color='r', label="True spec with dust")

ax1.errorbar(wlsall, spec_rest_noatt_true, yerr=sigmanodust_obs_true,
             fmt='o', ms=0.5, linewidth=2, capsize=0, c='k', label="no dust")
ax1.errorbar(wlsall, spec_rest_att_true, yerr=sigmadust_obs_true,
             fmt='o', ms=0.5, linewidth=2, capsize=0, c='r', label="with dust")
# Raw string: "\l" is not a valid Python escape sequence, and the backslash must
# reach matplotlib's mathtext unchanged.
# NOTE(review): the label states nm — confirm against the unit of wlsall.
ax1.set_xlabel(r"$\lambda$ (nm)")
ax1.set_ylabel("DSPS SED ")
ax1.legend()

# Show six decades below twice the maximum of the no-dust spectrum.
ymax = jnp.max(spec_rest_noatt_true)*2.
ymin = ymax/1e6
ax1.set_ylim(ymin, ymax)
ax1.set_xlim(1e2, 1e6)
ax1.grid()
ax1.set_title("True spectra")

# --- bottom panel: errors ---
ax2.set_yscale('log')
ax2.set_xscale('log')

ax2.plot(wlsall, sigmanodust_obs_true, 'k-')
ax2.plot(wlsall, sigmadust_obs_true, 'r-')
ax2.set_xlabel(r"$\lambda$ (nm)")
ax2.set_ylabel("DSPS SED error")

# Show six decades below five times the maximum no-dust error.
ymax = jnp.max(sigmanodust_obs_true)*5.
ymin = ymax/1e6
ax2.set_ylim(ymin, ymax)
ax2.set_xlim(1e2, 1e6)
ax2.grid()

plt.tight_layout()

## calculate the array of errors on the spectrum by using the average models in numpyro

In [None]:
# Calculate the array of errors on the spectrum by using the average no-dust
# model in numpyro, evaluated under a fixed seed.
# The goal is to obtain sigmanodust_obs, an array meant to replace the scalar sigma_obs
# (sigmadust_obs, its with-dust counterpart, comes from galaxymodel_withdust_av).
with seed(rng_seed=42):
    spec_nodust, sigmanodust_obs = galaxymodel_nodust_av(
        wlsall,
        Fobs=None,
        initparamval=PARAM_VAL,
        minparamval=PARAM_MIN,
        maxparamval=PARAM_MAX,
        sigmaparamval=PARAM_SIGMA,
        paramnames=PARAM_NAMES,
        z_obs=Z_OBS,
        sigmarel=SIGMAREL_OBS,
    )


In [None]:
# Evaluate the average with-dust model under a fixed seed to obtain the
# spectrum spec_withdust and the error array sigmadust_obs.
with seed(rng_seed=42):
    spec_withdust, sigmadust_obs = galaxymodel_withdust_av(
        wlsall,
        Fobs=None,
        initparamval=PARAM_VAL,
        minparamval=PARAM_MIN,
        maxparamval=PARAM_MAX,
        sigmaparamval=PARAM_SIGMA,
        paramnames=PARAM_NAMES,
        z_obs=Z_OBS,
        sigmarel=SIGMAREL_OBS,
    )


In [None]:
# Two-panel figure: the spectra returned by the average models with their error
# bars (top) and the error arrays alone (bottom), both on log-log axes.
fig = plt.figure(figsize=(10, 5))
gs = fig.add_gridspec(2, 1, height_ratios=[3., 1.])
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1])

# --- top panel: spectra ---
ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.plot(wlsall, spec_nodust, linewidth=2, color="k", label="sim spec no dust")
ax1.plot(wlsall, spec_withdust, linewidth=2, color='r', label="sim spec with dust")

ax1.errorbar(wlsall, spec_nodust, yerr=sigmanodust_obs,
             fmt='o', ms=0.5, linewidth=2, capsize=0, c='k', label="no dust")
ax1.errorbar(wlsall, spec_withdust, yerr=sigmadust_obs,
             fmt='o', ms=0.5, linewidth=2, capsize=0, c='r', label="with dust")
# Raw string: "\l" is not a valid Python escape sequence, and the backslash must
# reach matplotlib's mathtext unchanged.
ax1.set_xlabel(r"$\lambda$ (nm)")
ax1.set_ylabel("DSPS SED ")
ax1.set_title("Sim spectra")
ax1.legend()

# Show six decades below twice the maximum of the no-dust spectrum.
ymax = jnp.max(spec_nodust)*2.
ymin = ymax/1e6
ax1.set_ylim(ymin, ymax)
ax1.set_xlim(1e2, 1e6)
ax1.grid()

# --- bottom panel: errors ---
ax2.set_yscale('log')
ax2.set_xscale('log')
# Show six decades below five times the maximum no-dust error.
ymax = jnp.max(sigmanodust_obs)*5.
ymin = ymax/1e6
ax2.plot(wlsall, sigmanodust_obs, 'k-')
ax2.plot(wlsall, sigmadust_obs, 'r-')
ax2.set_xlabel(r"$\lambda$ (nm)")
ax2.set_ylabel("DSPS SED error")

ax2.set_ylim(ymin, ymax)
ax2.set_xlim(1e2, 1e6)
ax2.grid()

plt.tight_layout()

## Bayesian modelling

In [None]:
# Boolean mask over the no-dust simulation laws: True where the law equals "fixed".
# The mask is displayed as the cell output.
condlist_fix = jnp.where(PARAM_SIMLAW_NODUST == "fixed",True,False)
condlist_fix

In [None]:
# Boolean mask over the no-dust simulation laws: True where the law equals "uniform".
# The mask is displayed as the cell output.
condlist_uniform = jnp.where(PARAM_SIMLAW_NODUST == "uniform",True,False)
condlist_uniform

In [None]:
# Render the graphical model of the no-dust galaxy model, with the distributions shown.
# The arguments are passed positionally; the two leading one-element arrays are
# presumably placeholders for the wavelength and observed-flux arguments
# — TODO confirm against the signature of galaxymodel_nodust.
numpyro.render_model(galaxymodel_nodust, model_args=(jnp.array([0.]),jnp.array([1.]),
                                                     PARAM_VAL, 
                                                     PARAM_MIN,
                                                     PARAM_MAX,
                                                     PARAM_SIGMA, 
                                                     PARAM_NAMES,
                                                     Z_OBS,sigmanodust_obs),render_distributions=True)

In [None]:
# Render the graphical model of the with-dust galaxy model, with the distributions shown.
# The arguments are passed positionally; the two leading one-element arrays are
# presumably placeholders for the wavelength and observed-flux arguments
# — TODO confirm against the signature of galaxymodel_withdust.
numpyro.render_model(galaxymodel_withdust, model_args=(jnp.array([0.]),jnp.array([1.]),
                                                     PARAM_VAL, 
                                                     PARAM_MIN,
                                                     PARAM_MAX,
                                                     PARAM_SIGMA, 
                                                     PARAM_NAMES,
                                                     Z_OBS,sigmadust_obs),render_distributions=True)

In [None]:
# Generate the data at the fiducial parameters: pin the sample sites of the
# no-dust model to the selected true values.
fiducial_model = condition(galaxymodel_nodust, dict_sel_params_true)
# Run the conditioned model once under a fixed PRNG key and record its trace
# (if not conditioned, the priors are executed).
trace_data_nodust = trace(
    seed(fiducial_model, jax.random.PRNGKey(42))
).get_trace(
    wlsall,
    initparamval=PARAM_VAL,
    minparamval=PARAM_MIN,
    maxparamval=PARAM_MAX,
    sigmaparamval=PARAM_SIGMA,
    paramnames=PARAM_NAMES,
    z_obs=Z_OBS,
    sigma=sigmanodust_obs_true,
)

In [None]:
# Take the value recorded at site 'F' of the no-dust trace.
# This rebinds spec_nodust, which previously held the output of galaxymodel_nodust_av.
spec_nodust = trace_data_nodust['F']["value"]

In [None]:
# Generate the data at the fiducial parameters: pin the sample sites of the
# with-dust model to the selected true values.
fiducial_model = condition(galaxymodel_withdust, dict_sel_params_true)
# Run the conditioned model once under a fixed PRNG key and record its trace
# (if not conditioned, the priors are executed).
trace_data_withdust = trace(
    seed(fiducial_model, jax.random.PRNGKey(42))
).get_trace(
    wlsall,
    initparamval=PARAM_VAL,
    minparamval=PARAM_MIN,
    maxparamval=PARAM_MAX,
    sigmaparamval=PARAM_SIGMA,
    paramnames=PARAM_NAMES,
    z_obs=Z_OBS,
    sigma=sigmadust_obs_true,
)

In [None]:
#trace_data_withdust

In [None]:
# Take the value recorded at site 'F' of the with-dust trace.
# This rebinds spec_withdust, which previously held the output of galaxymodel_withdust_av.
spec_withdust = trace_data_withdust['F']["value"]

In [None]:
# Plot the spectra taken from the fiducial traces, with the true error bars,
# on log-log axes.
fig, ax = plt.subplots(1, 1, figsize=(10, 4))
ax.errorbar(wlsall, spec_nodust, yerr=sigmanodust_obs_true,
            fmt='o', ms=0.5, linewidth=2, capsize=0, c='k', label="no dust")
ax.errorbar(wlsall, spec_withdust, yerr=sigmadust_obs_true,
            fmt='o', ms=0.5, linewidth=2, capsize=0, c='r', label="with dust")
# Raw string: "\l" is not a valid Python escape sequence, and the backslash must
# reach matplotlib's mathtext unchanged.
ax.set_xlabel(r"$\lambda$ (nm)")
ax.set_ylabel("DSPS spectrum")
ax.legend()
ax.set_yscale('log')
ax.set_xscale('log')
# Show six decades below the maximum of the no-dust spectrum.
ymax = jnp.max(spec_nodust)
ymin = ymax/1e6
ax.set_ylim(ymin, ymax)
ax.set_xlim(1e2, 1e6)

ax.grid()
ax.set_title("simulated spectrum")

## MCMC without dust

### Run MCMC

In [None]:
if FLAG_NODUST:
    print(f"===========  MCMC simulation : No Dust , num_samples = {NUM_SAMPLES}, n_chains = {N_CHAINS}, num_warmup = {NUM_WARMUP} ========")
    print(f" >>> output file {fileout_nodust_hdf}")

    # Run NUTS.
    # The seed key is split in four; only the first sub-key (rng_key) is used below.
    rng_key, rng_key0, rng_key1, rng_key2 = jax.random.split(jax.random.PRNGKey(42), 4)

    # NUTS kernel on the no-dust model, with a dense mass matrix and chains
    # initialised with the init_to_median strategy.
    kernel = NUTS(
        galaxymodel_nodust,
        dense_mass=True,
        target_accept_prob=0.9,
        init_strategy=numpyro.infer.init_to_median(),
    )

    # All chains are run with the 'vectorized' chain method.
    mcmc = MCMC(
        kernel,
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
        num_chains=N_CHAINS,
        chain_method='vectorized',
        progress_bar=True,
    )

    # Validation is enabled while sampling, see
    # https://forum.pyro.ai/t/cannot-find-valid-initial-parameters-when-using-nuts-for-simple-gaussian-mixture-model-in-numpyro/2181/2
    with numpyro.validation_enabled():
        # The observed flux is the no-dust spectrum computed from the true parameters.
        # (extra_fields=('potential_energy',) could be passed as well.)
        mcmc.run(
            rng_key,
            wlsin=wlsall,
            Fobs=spec_rest_noatt_true,
            initparamval=PARAM_VAL,
            minparamval=PARAM_MIN,
            maxparamval=PARAM_MAX,
            sigmaparamval=PARAM_SIGMA,
            paramnames=PARAM_NAMES,
            z_obs=Z_OBS,
            sigma=sigmanodust_obs_true,
        )
        mcmc.print_summary()
        samples_nuts = mcmc.get_samples()

### Results of MCMC no dust

In [None]:
# Relative effective sample size (relative efficiency) of the no-dust samples.
# Written as a conditional expression so that the result is the last expression
# of the cell and gets displayed; inside an `if` body the value was computed and
# then silently discarded.
az.ess(samples_nuts, relative=True) if FLAG_NODUST else None

In [None]:
if FLAG_NODUST:
    # Convert the no-dust sampler output to an arviz object and draw the
    # compact trace plot.
    data = az.from_numpyro(mcmc)
    az.plot_trace(data, compact=True)

In [None]:
# Display the arviz object of the no-dust run.
# Written as a conditional expression so that it is the last expression of the
# cell and gets displayed; a bare name inside an `if` body shows nothing.
data if FLAG_NODUST else None

### Pandas dataframe

In [None]:
if FLAG_NODUST:
    # Collect the no-dust posterior samples (output of mcmc.get_samples()) in a DataFrame.
    df_nodust = pd.DataFrame(samples_nuts) 

In [None]:
#df_nodust

In [None]:
if FLAG_NODUST:
    # Persist the no-dust run in one HDF file opened in append mode, one key per
    # table: the posterior samples (compression level 9), then the observation
    # settings, the parameter table and the sampler settings.
    # Pickle and CSV outputs are not enabled.
    df_nodust.to_hdf(fileout_nodust_hdf, key="dsps_mcmc_nodust", mode="a", complevel=9)
    for _hdf_key, _table in (("obs", df_info), ("params", df_params), ("mcmc", df_mcmc)):
        _table.to_hdf(fileout_nodust_hdf, key=_hdf_key, mode="a")

## MCMC With DUST

### run MCMC with Dust

In [None]:
if FLAG_DUST:
    print(f"===========  MCMC simulation : With Dust , num_samples = {NUM_SAMPLES}, n_chains = {N_CHAINS}, num_warmup = {NUM_WARMUP} ========")
    print(f" >>> output file {fileout_dust_hdf}")

    # Run NUTS.
    # The seed key is split in four; only the first sub-key (rng_key) is used below.
    rng_key, rng_key0, rng_key1, rng_key2 = jax.random.split(jax.random.PRNGKey(42), 4)

    # NUTS kernel on the with-dust model, with a dense mass matrix and chains
    # initialised with the init_to_median strategy.
    kernel = NUTS(
        galaxymodel_withdust,
        dense_mass=True,
        target_accept_prob=0.9,
        init_strategy=numpyro.infer.init_to_median(),
    )

    # All chains are run with the 'vectorized' chain method.
    # Note: mcmc and samples_nuts are rebound here, replacing those of any no-dust run.
    mcmc = MCMC(
        kernel,
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
        num_chains=N_CHAINS,
        chain_method='vectorized',
        progress_bar=True,
    )

    # Validation is enabled while sampling, see
    # https://forum.pyro.ai/t/cannot-find-valid-initial-parameters-when-using-nuts-for-simple-gaussian-mixture-model-in-numpyro/2181/2
    with numpyro.validation_enabled():
        # The observed flux is the with-dust spectrum computed from the true parameters.
        # (extra_fields=('potential_energy',) could be passed as well.)
        mcmc.run(
            rng_key,
            wlsin=wlsall,
            Fobs=spec_rest_att_true,
            initparamval=PARAM_VAL,
            minparamval=PARAM_MIN,
            maxparamval=PARAM_MAX,
            sigmaparamval=PARAM_SIGMA,
            paramnames=PARAM_NAMES,
            z_obs=Z_OBS,
            sigma=sigmadust_obs_true,
        )
        mcmc.print_summary()
        samples_nuts = mcmc.get_samples()

### results of MCMC with DUST

In [None]:
# Relative effective sample size (relative efficiency) of the with-dust samples.
# Written as a conditional expression so that the result is the last expression
# of the cell and gets displayed; inside an `if` body the value was computed and
# then silently discarded.
az.ess(samples_nuts, relative=True) if FLAG_DUST else None

In [None]:
if FLAG_DUST:
    # Convert the with-dust sampler output to an arviz object, draw the compact
    # trace plot and tighten the figure layout.
    data = az.from_numpyro(mcmc)
    az.plot_trace(data, compact=True)
    plt.tight_layout()

### pandas dataframe with dust

In [None]:
if FLAG_DUST:
    # Collect the with-dust posterior samples (output of mcmc.get_samples()) in a DataFrame.
    df_dust = pd.DataFrame(samples_nuts) 

In [None]:
#df_dust

In [None]:
if FLAG_DUST:
    # Rebuild the samples table from the current samples_nuts, then persist the
    # with-dust run in one HDF file opened in append mode, one key per table:
    # the posterior samples (compression level 9), then the observation settings,
    # the parameter table and the sampler settings.
    # Pickle and CSV outputs are not enabled.
    df_dust = pd.DataFrame(samples_nuts)
    df_dust.to_hdf(fileout_dust_hdf, key="dsps_mcmc_dust", mode="a", complevel=9)
    for _hdf_key, _table in (("obs", df_info), ("params", df_params), ("mcmc", df_mcmc)):
        _table.to_hdf(fileout_dust_hdf, key=_hdf_key, mode="a")
    


In [None]:
# Display the arviz object of the last MCMC run.
# `data` only exists when at least one run was enabled, so guard the display
# to avoid a NameError when both flags are off.
data if (FLAG_NODUST or FLAG_DUST) else None

In [None]:
# Print the summary of the last MCMC run.
# `mcmc` only exists when at least one run was enabled, so guard the call
# to avoid a NameError when both flags are off.
if FLAG_NODUST or FLAG_DUST:
    mcmc.print_summary()

In [None]:
# Display summary statistics of the with-dust samples.
# `df_dust` only exists when the with-dust run was enabled, so guard the display
# to avoid a NameError when FLAG_DUST is off.
df_dust.describe() if FLAG_DUST else None