# Plot spectra from the SSP parameters

Generate simulation data to study the correlation between the parameters, using the `fors2tostellarpopsynthesis` package.

- Author Sylvie Dagoret-Campagne
- Affiliation : IJCLab/IN2P3/CNRS
- Organisation : LSST-DESC
- creation date : 2023-12-05
- last update : 2023-12-08


| computer | processor | kernel              |    date     |
| --- | --- | --- | --- |
| CC       | CPU       | conda_jax0325_py310 | 2023-11-10  |
| macbookpro | CPU | conda_jaxcpu_dsps_py310 | 2023-11-10  | 


libraries 
=========

jax
---

- jaxlib-0.3.25+cuda11.cudnn82
- jaxopt
- optax
- corner
- arviz
- numpyro
- graphviz

sps
---

- fsps
- prospect
- dsps
- diffstar
- diffmah



(conda_jax0325_py310) 
`/pbs/throng/lsst/users/dagoret/desc/StellarPopulationSynthesis>pip list | grep` 

| lib | version |
|--- | --- | 
|jax  |                         0.4.20 |
|jaxlib |                       0.4.20 |
|jaxopt  |                      0.8.2 |



## examples

- jaxcosmo : https://github.com/DifferentiableUniverseInitiative/jax-cosmo-paper/blob/master/notebooks/VectorizedNumPyro.ipynb
- on atmosphere : https://github.com/sylvielsstfr/FitDiffAtmo/blob/main/docs/notebooks/fitdiffatmo/test_numpyro_orderedict_diffatmemul_5params_P_pwv_oz_tau_beta.ipynb

## Import

### import external packages

In [None]:
import h5py
import pandas as pd
import numpy as np
import os
import re
import pickle 
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.colors as colors
import matplotlib.cm as cmx
#import seaborn as sns
#sns.set_theme(style='white')
import collections
from collections import OrderedDict
import re
import matplotlib.gridspec as gridspec
from sklearn.gaussian_process import GaussianProcessRegressor, kernels
import itertools

In [None]:
import jax
import jax.numpy as jnp
from jax import vmap
import jaxopt
import optax
jax.config.update("jax_enable_x64", True)
from interpax import interp1d

from jax.lax import fori_loop
from jax.lax import select,cond
from jax.lax import concatenate

In [None]:
import numpyro
from numpyro import optim
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, HMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal, AutoMultivariateNormal
from numpyro.infer.reparam import NeuTraReparam
from numpyro.handlers import seed, trace, condition


import corner
import arviz as az

In [None]:
import pickle

### import internal packages

In [None]:
from fors2tostellarpopsynthesis.parameters  import SSPParametersFit,paramslist_to_dict

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (SSP_DATA,mean_spectrum,mean_mags,mean_sfr,ssp_spectrum_fromparam)

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(PARAM_SIMLAW_NODUST,PARAM_SIMLAW_WITHDUST,
                            PARAM_NAMES,PARAM_VAL,PARAM_MIN,PARAM_MAX,PARAM_SIGMA)

from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(galaxymodel_nodust_av,galaxymodel_nodust,galaxymodel_withdust_av,galaxymodel_withdust)

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_util import plot_params_kde,calc_ratio

## Configuration

In [None]:
# Wavelengths (Angstrom) of the hydrogen lines drawn on the spectra, grouped by
# series, together with one colour and one threshold wavelength per series.
Lyman_lines = [1220., 1030., 973., 950., 938., 930.]
Balmer_lines = [6562.791, 4861.351, 4340.4721, 4101.740, 3970.072, 3889.0641, 3835.3971]
Paschen_lines = [8750., 12820., 10938.0, 10050., 9546.2, 9229.7, 9015.3, 8862.89, 8750.46, 8665.02]
Brackett_lines = [40522.79, 26258.71, 21661.178, 19440., 18179.21]
Pfund_lines = [74599.0, 46537.8, 37405.76, 32969.8, 30400.]
all_Hydrogen_lines = [Lyman_lines, Balmer_lines, Paschen_lines, Brackett_lines, Pfund_lines]
Color_lines = ["purple", "blue", "green", "red", "grey"]

# Series thresholds, in the same order as all_Hydrogen_lines / Color_lines.
Balmer_thres = 3645.6
Lyman_thres = 911.267
Paschen_thres = 8200.
Brackett_thres = 14580.
# Named Pfund_thres (not Pfund_lines) so the Pfund line list above is not
# overwritten by this scalar.
Pfund_thres = 22800.
all_Hydrogen_thres = [Lyman_thres, Balmer_thres, Paschen_thres, Brackett_thres, Pfund_thres]

- D4000
Hereafter the 4000 Å break (D4000) is defined as the ratio of the average flux density, in $\mathrm{erg\,s^{-1}\,cm^{-2}\,Hz^{-1}}$, between 4050 and 4250 Å to the average flux density between 3750 and 3950 Å (Bruzual 1983).

In [None]:
# D4000 windows (Angstrom): mean flux density in the red window divided by the
# mean flux density in the blue window (Bruzual 1983).
D4000_red = [4050., 4250]
D4000_blue = [3750., 3950.]

# Windows for the Balmer/Lyman ratio: each runs from the series threshold to
# the first listed line of that series.
W_LYMAN = [Lyman_thres, Lyman_lines[0]]
W_BALMER = [Balmer_thres, Balmer_lines[0]]

In [None]:
def plot_hydrogen_lines(ax):
    """Overlay the hydrogen series and the Lyman/Balmer windows on an axis.

    For each group in ``all_Hydrogen_lines`` a solid vertical line is drawn at
    every wavelength of the group, in the group's colour from ``Color_lines``.
    When ``all_Hydrogen_thres`` has an entry for that group, a dotted line is
    then drawn at it in the same colour. Finally the ``W_LYMAN`` window is
    shaded blue and the ``W_BALMER`` window red (alpha 0.5).

    Parameters
    ----------
    ax : matplotlib Axes
        Axis drawn on in place.
    """
    threshold_count = len(all_Hydrogen_thres)
    for group_no, wavelengths in enumerate(all_Hydrogen_lines):
        group_color = Color_lines[group_no]
        for wavelength in wavelengths:
            ax.axvline(wavelength, color=group_color)
        if group_no < threshold_count:
            ax.axvline(all_Hydrogen_thres[group_no], color=group_color, linestyle=":")

    for window, shade in ((W_LYMAN, 'blue'), (W_BALMER, 'red')):
        ax.axvspan(window[0], window[1], facecolor=shade, alpha=0.5)

In [None]:
#sns.color_palette("hls", 100)

### matplotlib configuration

In [None]:
# Global plot styling: large square figures, xx-large axis labels, titles and
# tick labels, and a size-10 legend and default font.
plt.rcParams.update({
    "figure.figsize": (12, 12),
    "axes.labelsize": 'xx-large',
    "axes.titlesize": 'xx-large',
    "xtick.labelsize": 'xx-large',
    "ytick.labelsize": 'xx-large',
    "legend.fontsize": 10,
    "font.size": 10,
})

## Fit parameters

In [None]:
# Parameter set object; its PARAM_NAMES_FLAT, INIT_PARAMS and DICT_PARAMS_true
# attributes are used in the following cells.
p = SSPParametersFit()

In [None]:
# Flat list of parameter names from the fit object.
# NOTE: this rebinds the PARAM_NAMES imported from fitter_numpyro in the imports.
PARAM_NAMES = p.PARAM_NAMES_FLAT
PARAM_NAMES

In [None]:
# Overwrite the last four initial values with 0, 0, 0, 1 (one functional
# .at[].set per entry). Presumably these are AV, UVBUMP, PLAW, SCALEF, matching
# the order used when building the per-sample parameter dicts — TODO confirm.
for tail_offset, tail_value in zip((-4, -3, -2, -1), (0., 0., 0., 1.)):
    p.INIT_PARAMS = p.INIT_PARAMS.at[tail_offset].set(tail_value)

In [None]:
# Spectrum for the reference ("true") parameter dict at redshift 0. By their
# names the outputs are the wavelength grid, the rest-frame spectrum and the
# attenuated rest-frame spectrum — TODO confirm against ssp_spectrum_fromparam.
wlsall,spec_rest,spec_rest_att = ssp_spectrum_fromparam(p.DICT_PARAMS_true,0)

## Read MCMC without dust

In [None]:
# Display the flat parameter-name list (notebook output).
PARAM_NAMES

In [None]:
# Columns of the dust-free MCMC table; for each sample they are read in this
# order and form the leading part of the parameter list.
PARAM_NODUST = [
    'MAH_lgmO',
    'MAH_logtc',
    'MAH_early_index',
    'MAH_late_index',
    'MS_lgmcrit',
    'MS_lgy_at_mcrit',
    'MS_indx_lo',
    'MS_indx_hi',
    'MS_tau_dep',
    'Q_lg_qt',
    'Q_qlglgdt',
    'Q_lg_drop',
    'Q_lg_rejuv',
]

In [None]:
#PARAM_NODUST_DICT = OrderedDict()
#for name in PARAM_NODUST:
#    PARAM_NODUST_DICT[name] = f"{name}"   

### Input files: MCMC samples obtained without dust

In [None]:
# The dust-free MCMC samples are stored in three formats that share one path
# stem (only filein_hdf is opened in the cells below). Plain concatenation
# replaces f-strings that had no placeholders.
_nodust_stem = "datamcmcparams/DSPS_nodust_mcmc_params_wide"
filein_pickle = _nodust_stem + ".pickle"
filein_csv = _nodust_stem + ".csv"
filein_hdf = _nodust_stem + ".hdf"

In [None]:
# Print the names of the top-level entries of the HDF5 file (the pandas tables
# read in the next cell are stored under such keys).
with h5py.File(filein_hdf, mode='r') as hdf_file:
    keys = [entry_name for entry_name in hdf_file]
    print(keys)

In [None]:
# Two tables from the same HDF5 file: "info" (holds z_obs, used next) and
# "dsps_mcmc_nodust" (one row per MCMC sample, with the PARAM_NODUST columns).
df_info = pd.read_hdf(filein_hdf,key="info")
df = pd.read_hdf(filein_hdf,key="dsps_mcmc_nodust")

In [None]:
# Redshift read from the "info" table.
z_obs = df_info['z_obs']
# If "info" was stored as a DataFrame, the column lookup returns a Series, not a
# number. ssp_spectrum_fromparam is otherwise called with a scalar redshift (0),
# so reduce the Series to one float and refuse ambiguous tables.
if isinstance(z_obs, pd.Series):
    unique_z = z_obs.unique()
    if len(unique_z) != 1:
        raise ValueError(
            f"expected exactly one z_obs value in 'info', found {list(unique_z)}"
        )
    z_obs = float(unique_z[0])

In [None]:
# Display the MCMC samples table (notebook output).
df

In [None]:
# One full parameter dict per MCMC row: the sampled PARAM_NODUST values, then the
# fixed trailing entries AV=0, UVBUMP=0, PLAW=0, SCALEF=1, mapped onto PARAM_NAMES.
all_paramdicts = [
    paramslist_to_dict(
        [sample[col] for col in PARAM_NODUST] + [0., 0., 0., 1.],  # + AV, UVBUMP, PLAW, SCALEF
        PARAM_NAMES,
    )
    for _, sample in df.iterrows()
]

In [None]:
# Number of MCMC samples converted to parameter dicts (notebook output).
len(all_paramdicts)

In [None]:
# Overlay the rest-frame SED (F_nu, log-log) of every 20th MCMC sample and, for
# the same samples, collect two flux ratios:
#   - calc_ratio with its default windows (presumably the D4000 windows — TODO confirm)
#   - calc_ratio over W_LYMAN / W_BALMER (the "Balmer/Lyman" indicator)
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_xscale('log')
ax.set_yscale('log')

all_D4000 = []
all_DBalmerLyman = []
for d in all_paramdicts[::20]:
    wlsall, spec_rest, spec_rest_att = ssp_spectrum_fromparam(d, z_obs)
    d4000 = calc_ratio(wlsall, spec_rest)
    dBL = calc_ratio(wlsall, spec_rest, W_LYMAN, W_BALMER)
    all_D4000.append(d4000)
    all_DBalmerLyman.append(dBL)
    ax.plot(wlsall, spec_rest, alpha=0.5)

ax.set_ylabel("DSPS SED $F_\\nu$")
ax.set_xlim(1e2, 1e6)
plot_hydrogen_lines(ax)

In [None]:
# Distributions of the two indicators collected for the plotted samples:
# a histogram of D4000, a histogram of log10(Balmer/Lyman), and their scatter.
log_balmer_lyman = np.log10(all_DBalmerLyman)

fig, (ax_d4000, ax_bl, ax_scatter) = plt.subplots(1, 3, figsize=(16, 5))

ax_d4000.hist(all_D4000, bins=50, facecolor="b")
ax_d4000.set_xlabel("D4000")

# The ratio is plotted in log10, so the axis labels must say so.
ax_bl.hist(log_balmer_lyman, bins=50, facecolor="b")
ax_bl.set_xlabel(r"$\log_{10}$(Balmer/Lyman)")

ax_scatter.scatter(all_D4000, log_balmer_lyman, marker='o', alpha=0.5, facecolor="b")
ax_scatter.set_xlabel("D4000")
ax_scatter.set_ylabel(r"$\log_{10}$(Balmer/Lyman)")

fig.suptitle("Indicator for red color")
fig.tight_layout()

In [None]:
# SEDs of every 50th MCMC sample shown as F_lambda, computed as
# 1e-2 * F_nu / wavelength**2.
# NOTE(review): the 1e-2 factor is not explained in this notebook — presumably a
# unit/normalisation choice matched to the y-limits below; confirm.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_xscale('log')
ax.set_yscale('log')

for d in all_paramdicts[::50]:
    wlsall, spec_rest, spec_rest_att = ssp_spectrum_fromparam(d, z_obs)
    f_lambda = 1e-2 * spec_rest / wlsall**2
    ax.plot(wlsall, f_lambda, alpha=0.5)

ax.set_ylabel("DSPS SED $F_\\lambda$")
ax.set_xlim(1e2, 1e5)
ax.set_ylim(1e-20, 1e-14)
plot_hydrogen_lines(ax)

In [None]:
# Log-spaced bin edges for the 2D SED density: 200 edges in wavelength
# (1e2 .. 1e5) and 200 in flux (1e-13 .. 1e-5).
xmin = np.log10(100.)
xmax = np.log10(1e5)
ymin = -13.
ymax = -5.

xbins = np.logspace(xmin, xmax, 200)  # edges from 10**xmin to 10**xmax
ybins = np.logspace(ymin, ymax, 200)  # edges from 10**ymin to 10**ymax

# Interpolate at the geometric centre of each wavelength bin. Sampling on the
# edges themselves puts both of the last two edges into the final histogram bin
# (numpy's last bin is closed), double-counting that column.
xcenters = np.sqrt(xbins[:-1] * xbins[1:])

# Collect every 10th sample's interpolated SED and concatenate once at the end;
# concatenating inside the loop re-copies the growing array each iteration.
sampled_wls = []
sampled_specs = []
for d in all_paramdicts[::10]:
    wlsall, spec_rest, spec_rest_att = ssp_spectrum_fromparam(d, z_obs)
    sampled_wls.append(jnp.asarray(xcenters))
    sampled_specs.append(interp1d(xcenters, wlsall, spec_rest))

all_wlsall = jnp.concatenate(sampled_wls) if sampled_wls else jnp.zeros(0)
all_spec = jnp.concatenate(sampled_specs) if sampled_specs else jnp.zeros(0)


In [None]:
# Raw scatter of all interpolated SED points (very transparent, so dense regions
# stand out), framed by the same log-spaced grid limits as the 2D histogram.
x, y = all_wlsall, all_spec

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
ax.scatter(x, y, color='r', alpha=0.01)

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlim(xbins[0], xbins[-1])
ax.set_ylim(ybins[0], ybins[-1])
plot_hydrogen_lines(ax)

In [None]:
# 2D histogram of the interpolated SED points on the log-spaced grid, shown as
# log10(counts).
x = all_wlsall
y = all_spec

counts, _, _ = np.histogram2d(x, y, bins=(xbins, ybins))

# histogram2d indexes counts as [x_bin, y_bin] while pcolormesh expects
# [y_row, x_column], so transpose. With 200 edges on both axes the shapes match,
# so omitting .T would not raise — it would silently swap the axes.
# np.ma.log10 masks empty bins instead of producing log10(0) = -inf.
log_counts = np.ma.log10(counts.T)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))

pcm = ax.pcolormesh(xbins, ybins, log_counts, cmap='RdBu_r')
fig.colorbar(pcm, ax=ax, label="log10(counts)")

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlim(xbins[0], xbins[-1])
ax.set_ylim(ybins[0], ybins[-1])
plot_hydrogen_lines(ax)