# Study the Spectrum and vary many parameters

Study the correlation between the parameters using the `fors2tostellarpopsynthesis` package

- Author Sylvie Dagoret-Campagne
- Affiliation : IJCLab/IN2P3/CNRS
- Organisation : LSST-DESC
- creation date : 2023-12-04
- last update : 2023-12-04


| computer | processor | kernel              |    date     |
| --- | --- | --- | --- |
| CC       | CPU       | conda_jax0325_py310 | 2023-11-10  |
| macbookpro | CPU | conda_jaxcpu_dsps_py310 | 2023-11-10  | 


libraries 
=========

jax
---

- jaxlib-0.3.25+cuda11.cudnn82
- jaxopt
- optax
- corner
- arviz
- numpyro
- graphviz

sps
---

- fsps
- prospect
- dsps
- diffstar
- diffmah



(conda_jax0325_py310) 
`/pbs/throng/lsst/users/dagoret/desc/StellarPopulationSynthesis>pip list | grep` 

| lib | version |
|--- | --- | 
|jax  |                         0.4.20 |
|jaxlib |                       0.4.20 |
|jaxopt  |                      0.8.2 |



## Import

### import external packages

In [None]:
import h5py
import pandas as pd
import numpy as np
import os
import re
import pickle 
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.colors as colors
import matplotlib.cm as cmx
import collections
from collections import OrderedDict
import re
import matplotlib.gridspec as gridspec
from sklearn.gaussian_process import GaussianProcessRegressor, kernels

In [None]:
import jax
import jax.numpy as jnp
from jax import vmap
import jaxopt
import optax

from interpax import interp1d

In [None]:
from jax import random

In [None]:
import numpyro
from numpyro import optim
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, HMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal, AutoMultivariateNormal
from numpyro.infer.reparam import NeuTraReparam
from numpyro.handlers import seed, trace, condition
import corner
import arviz as az

### import internal packages

In [None]:
from fors2tostellarpopsynthesis.parameters  import SSPParametersFit,paramslist_to_dict

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (lik_spec,lik_mag,lik_comb,
get_infos_spec,get_infos_mag,get_infos_comb)

from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (SSP_DATA,mean_spectrum,mean_mags,mean_sfr,ssp_spectrum_fromparam)

## Configuration

### Jax

In [None]:
# Switch JAX to 64-bit types before any array is created in this notebook.
jax.config.update("jax_enable_x64", True)
# Root PRNG key built from seed 0; it is split further down to obtain one
# subkey per simulated parameter.
key = random.PRNGKey(0)
# Bare expression: displayed as the cell output.
key

### matplotlib configuration

In [None]:
# Global matplotlib defaults for every figure in this notebook:
# wide figures and large fonts for labels, titles, ticks and legends.
plt.rcParams.update({
    "figure.figsize": (12, 6),
    "axes.labelsize": 'xx-large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
    'legend.fontsize': 16,
    'font.size': 15,
})

## Fit parameters

In [None]:
# Parameter container from fors2tostellarpopsynthesis, built with its default
# constructor arguments. Its INIT_PARAMS, PARAMS_MIN, PARAMS_MAX and
# PARAM_NAMES_FLAT attributes drive the simulation below.
p = SSPParametersFit()

In [None]:
# List the attribute names available on the parameter container (cell output).
dir(p)

In [None]:
# Display DICT_PARAMS_true, the parameter dictionary passed to
# ssp_spectrum_fromparam for the reference spectrum.
p.DICT_PARAMS_true

In [None]:
# Display INIT_PARAMS, the parameter vector that is later replicated once per
# simulated galaxy.
p.INIT_PARAMS

In [None]:
# Display the flat sequence of parameter names; a name's position in it is the
# index used into INIT_PARAMS / PARAMS_MIN / PARAMS_MAX in the simulation loop.
p.PARAM_NAMES_FLAT

In [None]:
# Display the lower bounds (used as `minval` of the uniform draws below).
p.PARAMS_MIN

In [None]:
# Display the upper bounds (used as `maxval` of the uniform draws below).
p.PARAMS_MAX

In [None]:
# Evaluate the spectrum for the DICT_PARAMS_true parameter set.
# NOTE(review): the second argument (0 here, z_obs in later cells) is
# presumably the redshift, and the three outputs presumably wavelengths,
# rest-frame spectrum and dust-attenuated spectrum — confirm against
# ssp_spectrum_fromparam.
wls,spec_rest,spec_rest_att = ssp_spectrum_fromparam(p.DICT_PARAMS_true,0)

## Simulation of parameters

### Selected parameters

In [None]:
# Names of the parameters that receive random values; every other parameter
# keeps its INIT_PARAMS value. Grouped by name prefix (MAH_, MS_, Q_).
param_to_simulate = (
    ['MAH_logtc', 'MAH_early_index', 'MAH_late_index']
    + ['MS_lgmcrit', 'MS_lgy_at_mcrit', 'MS_indx_lo', 'MS_indx_hi', 'MS_tau_dep']
    + ['Q_lg_qt', 'Q_qlglgdt', 'Q_lg_drop', 'Q_lg_rejuv']
)

In [None]:
# Number of parameters to randomise; it also sets how many PRNG subkeys are
# requested when the key is split.
np_sim = len(param_to_simulate)
# Bare expression: displayed as the cell output.
np_sim

### No dust

In [None]:
# Overwrite the last four entries of INIT_PARAMS: the final entry becomes 1.
# and the three entries before it become 0. The section heading labels this
# the "no dust" configuration.
# NOTE(review): which parameter lives in each trailing slot is not visible
# here — confirm against PARAM_NAMES_FLAT.
for _slot, _value in ((-1, 1.), (-2, 0.), (-3, 0.), (-4, 0.)):
    # JAX arrays are immutable: .at[...].set(...) returns an updated copy.
    p.INIT_PARAMS = p.INIT_PARAMS.at[_slot].set(_value)

### Change the range of parameters

In [None]:
# Opt-in switch: when True, replace the sampling bounds of the three MAH_*
# parameters. It is currently disabled, so the bounds are left untouched.
FLAG_INCREASE_RANGE_MAH = False

if FLAG_INCREASE_RANGE_MAH:
    # flat parameter index -> (new lower bound, new upper bound)
    _mah_bounds = {
        1: (0.01, 0.15),  # MAH_logtc
        2: (0.1, 10.),    # MAH_early_index
        3: (0.1, 10.),    # MAH_late_index
    }
    for _idx, (_low, _high) in _mah_bounds.items():
        # JAX arrays are immutable: .at[...].set(...) returns an updated copy.
        p.PARAMS_MIN = p.PARAMS_MIN.at[_idx].set(_low)
        p.PARAMS_MAX = p.PARAMS_MAX.at[_idx].set(_high)

In [None]:
# Switch controlling whether the sampling bounds of the five MS_* parameters
# are replaced by the values below (enabled).
FLAG_INCREASE_RANGE_MS = True

if FLAG_INCREASE_RANGE_MS:
    # flat parameter index -> (new lower bound, new upper bound)
    _ms_bounds = {
        4: (9., 16.),    # MS_lgmcrit
        5: (-5., -0.1),  # MS_lgy_at_mcrit
        6: (0.1, 5.),    # MS_indx_lo
        7: (-5., -0.1),  # MS_indx_hi
        8: (0.1, 10.),   # MS_tau_dep
    }
    for _idx, (_low, _high) in _ms_bounds.items():
        # JAX arrays are immutable: .at[...].set(...) returns an updated copy.
        p.PARAMS_MIN = p.PARAMS_MIN.at[_idx].set(_low)
        p.PARAMS_MAX = p.PARAMS_MAX.at[_idx].set(_high)

In [None]:
# Switch controlling whether the sampling bounds of the four Q_* parameters
# are replaced by the values below (enabled).
FLAG_INCREASE_RANGE_Q = True

if FLAG_INCREASE_RANGE_Q:
    # flat parameter index -> (new lower bound, new upper bound).
    # The number quoted after each name comes from the original annotation;
    # presumably it is that parameter's default value — confirm.
    _q_bounds = {
        9: (0.5, 3.),     # Q_lg_qt    (1.0)
        10: (-3., -0.5),  # Q_qlglgdt  (-0.50725)
        11: (-3., -0.5),  # Q_lg_drop  (-1.01773)
        12: (-5., -0.05),  # Q_lg_rejuv (-0.212307)
    }
    for _idx, (_low, _high) in _q_bounds.items():
        # JAX arrays are immutable: .at[...].set(...) returns an updated copy.
        p.PARAMS_MIN = p.PARAMS_MIN.at[_idx].set(_low)
        p.PARAMS_MAX = p.PARAMS_MAX.at[_idx].set(_high)

### number of simulations

In [None]:
# Number of simulated galaxies, i.e. the number of rows of all_sim_params and
# the number of curves drawn in each plot.
n_gals = 100

In [None]:
# Repeat INIT_PARAMS n_gals times and fold the result into n_gals rows, so
# each row starts as a copy of the initial parameter vector
# (assumes INIT_PARAMS is 1-D — TODO confirm).
all_sim_params = jnp.tile(p.INIT_PARAMS, n_gals).reshape((n_gals, -1))

In [None]:
# Sanity check: display the shape of the simulation table (n_gals rows).
all_sim_params.shape

### generate random values for each parameter

In [None]:
# Split the PRNG key into np_sim+1 keys: the first replaces `key`, the other
# np_sim are collected in params_subkeys (one per simulated parameter).
key, *params_subkeys = random.split(key, num=np_sim+1)

In [None]:
# Fill the selected columns of all_sim_params with uniform random draws.
# countsim counts the parameters already drawn and selects the subkey, so
# every simulated parameter uses its own independent PRNG key.
countsim = 0
for index, param_name in enumerate(p.PARAM_NAMES_FLAT):
    # Parameters that are not simulated keep their INIT_PARAMS value.
    if param_name not in param_to_simulate:
        continue

    subkey = params_subkeys[countsim]
    # One value per galaxy, drawn between this parameter's bounds.
    param_simvalues = jax.random.uniform(
        subkey,
        shape=(n_gals,),
        minval=p.PARAMS_MIN[index],
        maxval=p.PARAMS_MAX[index],
    )
    # Column `index` holds this parameter for all galaxies.
    all_sim_params = all_sim_params.at[:, index].set(param_simvalues)
    countsim += 1

## Simulation of spectra

In [None]:
# Value passed as second argument to mean_sfr and ssp_spectrum_fromparam for
# every simulated galaxy (presumably the observation redshift — confirm).
z_obs = 0.5

In [None]:
# Overlay the mean_sfr curve of every simulated galaxy on one log-y axis.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(16, 6))

for igal in range(n_gals):
    # Turn row igal of the simulation table into the name -> value dictionary
    # expected by mean_sfr.
    dict_params = paramslist_to_dict(all_sim_params[igal, :],
                                     p.PARAM_NAMES_FLAT)
    tarr, sfh_gal = mean_sfr(dict_params, z_obs)
    ax.plot(tarr, sfh_gal)

ax.set_yscale('log')
ax.grid()
ax.set_title("DSPS SFR")

In [None]:
# Overlay the spec_rest output of ssp_spectrum_fromparam for every simulated
# galaxy, restricted to 1e2 <= wls <= 1e5, on log-log axes.
fig, ax = plt.subplots(1, 1)

for igal in range(n_gals):
    # Row igal of the simulation table as a name -> value dictionary.
    dict_params = paramslist_to_dict(all_sim_params[igal,:],p.PARAM_NAMES_FLAT)
    wls,spec_rest,spec_rest_att = ssp_spectrum_fromparam(dict_params,z_obs)

    # Indices of the wavelength samples inside the plotted window.
    indexes_spec = jnp.where(jnp.logical_and(wls>=1e2,wls<=1e5))[0]

    ax.plot(wls[indexes_spec],spec_rest[indexes_spec])
ax.set_xscale('log')
ax.set_yscale('log')
ax.grid()
# Raw strings: the LaTeX commands \AA and \, are not valid Python escape
# sequences and trigger Deprecation/SyntaxWarning in ordinary string literals.
# The text handed to matplotlib is unchanged.
ax.set_title(r"DSPS Spectrum $F_\nu(\lambda)$")
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$F_\nu(\lambda) - (AB\, per \,Hz)$")

In [None]:
# Same overlay as the F_nu plot, but each spectrum is multiplied by
# 3e-2 / wls**2 before plotting, and the axes are labelled F_lambda.
# NOTE(review): the 3e-2 factor is presumably the c/lambda^2 conversion from
# F_nu to F_lambda up to a unit normalisation — confirm the constant.
fig, ax = plt.subplots(1, 1)

for igal in range(n_gals):
    # Row igal of the simulation table as a name -> value dictionary.
    dict_params = paramslist_to_dict(all_sim_params[igal,:],p.PARAM_NAMES_FLAT)
    wls,spec_rest,spec_rest_att = ssp_spectrum_fromparam(dict_params,z_obs)

    # Indices of the wavelength samples inside the plotted window.
    indexes_spec = jnp.where(jnp.logical_and(wls>=1e2,wls<=1e5))[0]

    ax.plot(wls[indexes_spec],3e-2*spec_rest[indexes_spec]/wls[indexes_spec]**2)
ax.set_xscale('log')
ax.set_yscale('log')
ax.grid()
# Raw strings: the LaTeX commands \AA and \, are not valid Python escape
# sequences and trigger Deprecation/SyntaxWarning in ordinary string literals.
# The text handed to matplotlib is unchanged.
ax.set_title(r"DSPS Spectrum $F_\lambda(\lambda)$")
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$F_\lambda(\lambda) - (AB\, per \,\AA)$")