# Plot the correlations between the SSP parameters

Generate simulation data to study the correlation between the parameters using the `fors2tostellarpopsynthesis` package

- Author Sylvie Dagoret-Campagne
- Affiliation : IJCLab/IN2P3/CNRS
- Organisation : LSST-DESC
- creation date : 2023-12-05
- last update : 2023-12-06


| computer | processor | kernel              |    date     |
| --- | --- | --- | --- |
| CC       | CPU       | conda_jax0325_py310 | 2023-11-10  |
| macbookpro | CPU | conda_jaxcpu_dsps_py310 | 2023-11-10  | 


libraries 
=========

jax
---

- jaxlib-0.3.25+cuda11.cudnn82
- jaxopt
- optax
- corner
- arviz
- numpyro
- graphviz

sps
---

- fsps
- prospect
- dsps
- diffstar
- diffmah



(conda_jax0325_py310) 
`/pbs/throng/lsst/users/dagoret/desc/StellarPopulationSynthesis>pip list | grep` 

| lib | version |
|--- | --- | 
|jax  |                         0.4.20 |
|jaxlib |                       0.4.20 |
|jaxopt  |                      0.8.2 |



## examples

- jaxcosmo : https://github.com/DifferentiableUniverseInitiative/jax-cosmo-paper/blob/master/notebooks/VectorizedNumPyro.ipynb
- on atmosphere : https://github.com/sylvielsstfr/FitDiffAtmo/blob/main/docs/notebooks/fitdiffatmo/test_numpyro_orderedict_diffatmemul_5params_P_pwv_oz_tau_beta.ipynb

## Import

### import external packages

In [None]:
import h5py
import pandas as pd
import numpy as np
import os
import re
import pickle 
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.colors as colors
import matplotlib.cm as cmx
import collections
from collections import OrderedDict
import re
import matplotlib.gridspec as gridspec
from sklearn.gaussian_process import GaussianProcessRegressor, kernels

In [None]:
import jax
import jax.numpy as jnp
from jax import vmap
import jaxopt
import optax
jax.config.update("jax_enable_x64", True)
from interpax import interp1d

from jax.lax import fori_loop
from jax.lax import select,cond
from jax.lax import concatenate

In [None]:
import numpyro
from numpyro import optim
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, HMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal, AutoMultivariateNormal
from numpyro.infer.reparam import NeuTraReparam
from numpyro.handlers import seed, trace, condition


import corner
import arviz as az

In [None]:
import pickle

### import internal packages

In [None]:
from fors2tostellarpopsynthesis.parameters  import SSPParametersFit,paramslist_to_dict

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (SSP_DATA,mean_spectrum,mean_mags,mean_sfr,ssp_spectrum_fromparam)

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(PARAM_SIMLAW_NODUST,PARAM_SIMLAW_WITHDUST,
                            PARAM_NAMES,PARAM_VAL,PARAM_MIN,PARAM_MAX,PARAM_SIGMA)

from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(galaxymodel_nodust_av,galaxymodel_nodust,galaxymodel_withdust_av,galaxymodel_withdust)

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_util import plot_params_kde

## Configuration

In [None]:
# Hydrogen line wavelengths grouped by spectral series
# (values presumably in Angstrom -- TODO confirm the unit).
Lyman_lines = [
    1220., 1030., 973.,
    950., 938., 930.,
]
Balmer_lines = [
    6562.791, 4861.351, 4340.4721, 4101.740,
    3970.072, 3889.0641, 3835.3971,
]
# NOTE(review): the leading 8750. nearly duplicates 8750.46 further down the
# list; it may be a typo for another line of the series -- confirm.
Paschen_lines = [
    8750., 12820., 10938.0, 10050., 9546.2,
    9229.7, 9015.3, 8862.89, 8750.46, 8665.02,
]
Brackett_lines = [
    40522.79, 26258.71, 21661.178,
    19440., 18179.21,
]
Pfund_lines = [
    74599.0, 46537.8, 37405.76,
    32969.8, 30400.,
]

# One entry per series, in the order Lyman, Balmer, Paschen, Brackett, Pfund.
all_Hydrogen_lines = [
    Lyman_lines,
    Balmer_lines,
    Paschen_lines,
    Brackett_lines,
    Pfund_lines,
]

### matplotlib configuration

In [None]:
# Notebook-wide matplotlib defaults: large square figures and large axis text.
plt.rcParams.update({
    "figure.figsize": (12,12),
    "axes.labelsize": 'xx-large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
    'legend.fontsize': 12,
    'font.size': 12,
})

## Fit parameters

In [None]:
# Parameter container from fors2tostellarpopsynthesis; its DICT_PARAMS_true
# attribute supplies the reference values overplotted on the pair plots below.
p = SSPParametersFit()

In [None]:
#p.INIT_PARAMS = p.INIT_PARAMS.at[-4].set(0.)
#p.INIT_PARAMS = p.INIT_PARAMS.at[-3].set(0.)
#p.INIT_PARAMS = p.INIT_PARAMS.at[-2].set(0.)
#p.INIT_PARAMS = p.INIT_PARAMS.at[-1].set(1.)

In [None]:
# Evaluate the SSP spectrum for the reference parameters.
# NOTE(review): the second argument (0) is presumably the redshift, and the
# three outputs presumably wavelengths / spectrum / dust-attenuated spectrum
# -- confirm against fitter_jaxopt. They are not used further in this notebook.
wlsall,spec_rest,spec_rest_att = ssp_spectrum_fromparam(p.DICT_PARAMS_true,0)

In [None]:
# Dump, one per line, the settings imported from fitter_numpyro.
print(
    PARAM_SIMLAW_NODUST,
    PARAM_SIMLAW_WITHDUST,
    PARAM_NAMES,
    PARAM_VAL,
    PARAM_MIN,
    PARAM_MAX,
    PARAM_SIGMA,
    sep="\n",
)

In [None]:
# Observation settings (presumably redshift, relative error and absolute
# error of the simulated observation -- TODO confirm).
z_obs, sigmarel_obs, sigma_obs = 0.5, 0.1, 1e-11

## Read MCMC without dust

In [None]:
PARAM_NAMES

In [None]:
# Names in PARAM_NAMES that are not in the exclusion list; np.setdiff1d
# returns them as a sorted, de-duplicated array.
# NOTE: this value is replaced in the next cell by an explicitly ordered list.
PARAM_NODUST = np.setdiff1d(PARAM_NAMES,['AV', 'UV_BUMP', 'PLAW_SLOPE', 'SCALEF'])

In [None]:
# Explicit (non-alphabetical) ordering of the no-dust parameters; this order
# fixes the row/column order of the pair plot.
PARAM_NODUST = [
    # MAH_* parameters
    'MAH_lgmO',
    'MAH_logtc',
    'MAH_early_index',
    'MAH_late_index',
    # MS_* parameters
    'MS_lgmcrit',
    'MS_lgy_at_mcrit',
    'MS_indx_lo',
    'MS_indx_hi',
    'MS_tau_dep',
    # Q_* parameters
    'Q_lg_qt',
    'Q_qlglgdt',
    'Q_lg_drop',
    'Q_lg_rejuv',
]

In [None]:
# Identity mapping name -> label, in PARAM_NODUST order (used to build the
# arviz labeller).
PARAM_NODUST_DICT = OrderedDict((name, f"{name}") for name in PARAM_NODUST)

### Output file for MCMC without dust

In [None]:
!ls datamcmcparams/

In [None]:
# HDF5 file holding the no-dust MCMC samples read below.
# Alternative input, kept for reference:
#filein_hdf = "datamcmcparams/DSPS_nodust_mcmc_params_wide.hdf"
# Plain string literal: there is no placeholder, so no f-prefix is needed.
filein_hdf = "datamcmcparams/DSPS_nodust_mcmc_params_4_500_1000.hdf"

In [None]:
# Show which top-level keys are stored in the input HDF5 file.
with h5py.File(filein_hdf, mode='r') as f:
    keys = [*f.keys()]
print(keys)

In [None]:
# Load the two tables of the HDF5 file: "info" (presumably run metadata --
# TODO confirm) and the no-dust MCMC samples, whose columns are looked up by
# parameter name further below.
df_info = pd.read_hdf(filein_hdf,key="info")
df = pd.read_hdf(filein_hdf,key="dsps_mcmc_nodust")

In [None]:
df_info

In [None]:
df

In [None]:
# Convert the samples table to {column name: list of values}.
dict_params_nodust = df.to_dict('list')

In [None]:
# Re-key the samples following PARAM_NODUST so that the pair plot panels
# appear in that order.
reordered_samples = OrderedDict(
    (key, dict_params_nodust[key]) for key in PARAM_NODUST
)

In [None]:
#plot_params_kde(reordered_samples, pcut=[0.001,99.999], var_names=PARAM_NODUST)

In [None]:
reordered_samples.keys()

In [None]:
import arviz.labels as azl
# For reference, the no-dust parameter names in alphabetical order:
#'MAH_early_index', 'MAH_late_index', 'MAH_lgmO', 'MAH_logtc',
#       'MS_indx_hi', 'MS_indx_lo', 'MS_lgmcrit', 'MS_lgy_at_mcrit',
#       'MS_tau_dep', 'Q_lg_drop', 'Q_lg_qt', 'Q_lg_rejuv', 'Q_qlglgdt'
# Labeller handed to az.plot_pair; PARAM_NODUST_DICT maps every name to itself.
labeller = azl.MapLabeller(var_name_map=PARAM_NODUST_DICT)


In [None]:
# Raise arviz's subplot cap so the whole parameter grid is drawn.
az.rcParams["plot.max_subplots"] = 200
par_names = PARAM_NODUST
nparams = len(par_names)
par_true = p.DICT_PARAMS_true

# KDE pair plot of the no-dust samples with marginals on the diagonal.
ax = az.plot_pair(
    reordered_samples,
    kind="kde",
    labeller=labeller,
    marginals=True,
    point_estimate='median',
    textsize=50,
    marginal_kwargs={"plot_kwargs": {"lw":3, "c":"blue", "ls":"-"}},
    kde_kwargs={
        # 30%, 68% and 90% HDI contours
        "hdi_probs": [0.3, 0.68, 0.9],
        "contour_kwargs":{"colors":None, "cmap":"Blues", "linewidths":3,
                          "linestyles":"-"},
        "contourf_kwargs":{"alpha":0.5},
    },
    point_estimate_kwargs={"lw": 3, "c": "b"},
)

# Lower-triangle panels: red marker at the reference ("true") parameter pair.
for idy, label_y in enumerate(par_names):
    for idx, label_x in enumerate(par_names[:idy]):
        ax[idy, idx].scatter(
            par_true[label_x], par_true[label_y], c="r", s=150, zorder=10
        )

# Diagonal panels: red vertical line at the reference value.
for idx, name in enumerate(par_names):
    ax[idx, idx].axvline(par_true[name], c='r', lw=3)
    

In [None]:
#plot_params_kde(reordered_samples, pcut=[0.001,99.999],var_names=PARAM_NODUST)

## Read MCMC with dust

In [None]:
PARAM_NAMES

In [None]:
# Explicit ordering of the with-dust parameters; this order fixes the
# row/column order of the pair plot.
PARAM_WITHDUST = [
    # MAH_* parameters
    'MAH_lgmO',
    'MAH_logtc',
    'MAH_early_index',
    'MAH_late_index',
    # MS_* parameters
    'MS_lgmcrit',
    'MS_lgy_at_mcrit',
    'MS_indx_lo',
    'MS_indx_hi',
    'MS_tau_dep',
    # Q_* parameters
    'Q_lg_qt',
    'Q_qlglgdt',
    'Q_lg_drop',
    'Q_lg_rejuv',
    # remaining (dust-related) parameters
    'AV',
    'UV_BUMP',
    'PLAW_SLOPE',
]

In [None]:
# Identity mapping name -> label, in PARAM_WITHDUST order (used to build the
# arviz labeller).
PARAM_WITHDUST_DICT = OrderedDict((name, f"{name}") for name in PARAM_WITHDUST)

### Output file for MCMC with dust

In [None]:
# HDF5 file holding the with-dust MCMC samples read below.
# Alternative input, kept for reference:
#filein_hdf = "datamcmcparams/DSPS_dust_mcmc_params.hdf"
# Plain string literal: there is no placeholder, so no f-prefix is needed.
filein_hdf = "datamcmcparams/DSPS_dust_mcmc_params_4_500_1000.hdf"

In [None]:
!ls datamcmcparams/

In [None]:
# Show which top-level keys are stored in the input HDF5 file.
with h5py.File(filein_hdf, mode='r') as f:
    keys = [*f.keys()]
print(keys)

In [None]:
# Load the two tables of the HDF5 file: "info" (presumably run metadata --
# TODO confirm) and the with-dust MCMC samples. Both df_info and df replace
# the objects of the same name loaded earlier in the notebook.
df_info = pd.read_hdf(filein_hdf,key="info")
df = pd.read_hdf(filein_hdf,key="dsps_mcmc_dust")

In [None]:
# Convert the samples table to {column name: list of values}.
dict_params_withdust = df.to_dict('list')

In [None]:
# Re-key the samples following PARAM_WITHDUST so that the pair plot panels
# appear in that order.
reordered_samples = OrderedDict(
    (key, dict_params_withdust[key]) for key in PARAM_WITHDUST
)

In [None]:
# Labeller handed to az.plot_pair; PARAM_WITHDUST_DICT maps every name to
# itself. `azl` is the arviz.labels alias imported earlier in the notebook.
labeller = azl.MapLabeller(var_name_map=PARAM_WITHDUST_DICT)

In [None]:
# Raise arviz's subplot cap so the whole parameter grid is drawn.
az.rcParams["plot.max_subplots"] = 200
par_names = PARAM_WITHDUST
nparams = len(par_names)
par_true = p.DICT_PARAMS_true

# KDE pair plot of the with-dust samples with marginals on the diagonal.
ax = az.plot_pair(
    reordered_samples,
    kind="kde",
    labeller=labeller,
    marginals=True,
    point_estimate='median',
    textsize=50,
    marginal_kwargs={"plot_kwargs": {"lw":3, "c":"blue", "ls":"-"}},
    kde_kwargs={
        # 30%, 68% and 90% HDI contours
        "hdi_probs": [0.3, 0.68, 0.9],
        "contour_kwargs":{"colors":None, "cmap":"Blues", "linewidths":3,
                          "linestyles":"-"},
        "contourf_kwargs":{"alpha":0.5},
    },
    point_estimate_kwargs={"lw": 3, "c": "b"},
)

# Lower-triangle panels: red marker at the reference ("true") parameter pair.
for idy, label_y in enumerate(par_names):
    for idx, label_x in enumerate(par_names[:idy]):
        ax[idy, idx].scatter(
            par_true[label_x], par_true[label_y], c="r", s=150, zorder=10
        )

# Diagonal panels: red vertical line at the reference value.
for idx, name in enumerate(par_names):
    ax[idx, idx].axvline(par_true[name], c='r', lw=3)
    