# Quickly generate a large number of parameter sets (and thus spectra), compute D4000, and plot the correlation of D4000 with the varied parameter



- Author Sylvie Dagoret-Campagne
- Affiliation : IJCLab/IN2P3/CNRS
- Organisation : LSST-DESC
- creation date : 2023-12-09
- last update : 2023-12-10


| computer | processor | kernel              |    date     |
| --- | --- | --- | --- |
| CC       | CPU       | conda_jax0325_py310 | 2023-11-10  |
| macbookpro | CPU | conda_jaxcpu_dsps_py310 | 2023-11-10  | 


libraries 
=========

jax
---

- jaxlib-0.3.25+cuda11.cudnn82
- jaxopt
- optax
- corner
- arviz
- numpyro
- graphviz

sps
---

- fsps
- prospect
- dsps
- diffstar
- diffmah



(conda_jax0325_py310) 
`/pbs/throng/lsst/users/dagoret/desc/StellarPopulationSynthesis>pip list | grep` 

| lib | version |
|--- | --- | 
|jax  |                         0.4.20 |
|jaxlib |                       0.4.20 |
|jaxopt  |                      0.8.2 |



## examples

- jaxcosmo : https://github.com/DifferentiableUniverseInitiative/jax-cosmo-paper/blob/master/notebooks/VectorizedNumPyro.ipynb
- on atmosphere : https://github.com/sylvielsstfr/FitDiffAtmo/blob/main/docs/notebooks/fitdiffatmo/test_numpyro_orderedict_diffatmemul_5params_P_pwv_oz_tau_beta.ipynb

## Import

### import external packages

In [None]:
import h5py
import pandas as pd
import numpy as np
import os
import re
import pickle 
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.colors as colors
import matplotlib.cm as cmx

from mpl_toolkits.axes_grid1.inset_locator import inset_axes, zoomed_inset_axes

#import seaborn as sns
#sns.set_theme(style='white')
import collections
from collections import OrderedDict
import re
import matplotlib.gridspec as gridspec
from sklearn.gaussian_process import GaussianProcessRegressor, kernels
import itertools

In [None]:
import jax
import jax.numpy as jnp
from jax import vmap
import jaxopt
import optax
jax.config.update("jax_enable_x64", True)
from interpax import interp1d

from jax.lax import fori_loop
from jax.lax import select,cond
from jax.lax import concatenate

In [None]:
import numpyro
from numpyro import optim
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, HMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal, AutoMultivariateNormal
from numpyro.infer.reparam import NeuTraReparam
from numpyro.handlers import seed, trace, condition

import corner
import arviz as az
import arviz.labels as azl

In [None]:
from dsps.cosmology import age_at_z, DEFAULT_COSMOLOGY

### import internal packages

In [None]:
from fors2tostellarpopsynthesis.parameters  import SSPParametersFit,paramslist_to_dict

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (SSP_DATA,mean_spectrum,mean_mags,mean_sfr,ssp_spectrum_fromparam)

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(PARAM_SIMLAW_NODUST,PARAM_SIMLAW_WITHDUST,
                            PARAM_NAMES,PARAM_VAL,PARAM_MIN,PARAM_MAX,PARAM_SIGMA)

from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(galaxymodel_nodust_av,galaxymodel_nodust,galaxymodel_withdust_av,galaxymodel_withdust)

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_util import plot_params_kde,calc_ratio

In [None]:
from fors2tostellarpopsynthesis.fors2starlightio import flux_norm

## Configuration

In [None]:
# Hydrogen line wavelengths in Angstrom, grouped by series.
Lyman_lines = [1220., 1030., 973., 950., 938., 930.]
Balmer_lines = [6562.791, 4861.351, 4340.4721, 4101.740, 3970.072, 3889.0641, 3835.3971]
Paschen_lines = [8750., 12820., 10938.0, 10050., 9546.2, 9229.7, 9015.3, 8862.89, 8750.46, 8665.02]
Brackett_lines = [40522.79, 26258.71, 21661.178, 19440., 18179.21]
Pfund_lines = [74599.0, 46537.8, 37405.76, 32969.8, 30400.]
# Series order (Lyman, Balmer, Paschen, Brackett, Pfund) is shared by
# all_Hydrogen_lines, Color_lines and all_Hydrogen_thres.
all_Hydrogen_lines = [Lyman_lines, Balmer_lines, Paschen_lines, Brackett_lines, Pfund_lines]
Color_lines = ["purple", "blue", "green", "red", "grey"]

# Series threshold (limit) wavelengths in Angstrom.
Balmer_thres = 3645.6
Lyman_thres = 911.267
Paschen_thres = 8200.
Brackett_thres = 14580.
# Own name, so that the list of Pfund lines above is not overwritten by a float.
Pfund_thres = 22800.
all_Hydrogen_thres = [Lyman_thres, Balmer_thres, Paschen_thres, Brackett_thres, Pfund_thres]

- D4000
Hereafter the 4000 angstrom break is defined as the ratio between the average flux density $F_\nu$ (in erg s$^{-1}$ cm$^{-2}$ Hz$^{-1}$) between 4050 and 4250 angstrom and that between 3750 and 3950 angstrom (Bruzual 1983).

In [None]:
# D4000 windows (Angstrom): red side and blue side of the 4000 A break.
D4000_red, D4000_blue = [4050., 4250], [3750., 3950.]

# Windows passed to calc_ratio for the Balmer/Lyman ratio: each runs from the
# series threshold to the first listed line of that series.
W_LYMAN = [Lyman_thres, Lyman_lines[0]]
W_BALMER = [Balmer_thres, Balmer_lines[0]]

In [None]:
def plot_hydrogen_lines(ax):
    """Overlay Lyman and Balmer markers on a wavelength axis.

    For the first two series of ``all_Hydrogen_lines`` (Lyman, Balmer), draw a
    solid vertical line at each line wavelength and a dotted one at the series
    threshold, in the series colour from ``Color_lines``. Then shade the
    ``W_LYMAN`` (green) and ``W_BALMER`` (yellow) windows.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on (x axis in Angstrom).
    """
    n_thresholds = len(all_Hydrogen_thres)
    # Only the Lyman and Balmer series are drawn.
    for series_idx, wavelengths in enumerate(all_Hydrogen_lines[:2]):
        series_color = Color_lines[series_idx]
        for wavelength in wavelengths:
            ax.axvline(wavelength, color=series_color)
        if series_idx < n_thresholds:
            ax.axvline(all_Hydrogen_thres[series_idx], color=series_color, linestyle=":")
    ax.axvspan(W_LYMAN[0], W_LYMAN[1], facecolor='green', alpha=0.2)
    ax.axvspan(W_BALMER[0], W_BALMER[1], facecolor='yellow', alpha=0.2)

In [None]:
# Reference wavelength (Angstrom) passed as ``wlcenter`` to flux_norm when the
# spectra are rescaled below; same value as Balmer_thres.
wl0 = 3645.6

In [None]:
#sns.color_palette("hls", 100)

### matplotlib configuration

In [None]:
# Global matplotlib defaults: large figures and axis text, small legend and base font.
plt.rcParams.update({
    "figure.figsize": (12, 12),
    "axes.labelsize": 'xx-large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
    'legend.fontsize': 8,
    'font.size': 8,
})

## Fit parameters

In [None]:
# Parameter container; INIT_PARAMS, PARAMS_MIN/PARAMS_MAX, PARAM_NAMES_FLAT and
# DICT_PARAMS_true are read from it below.
p = SSPParametersFit()

In [None]:
# Flat parameter-name list. This rebinds the PARAM_NAMES imported from
# fitters.fitter_numpyro; it is the key list given to paramslist_to_dict below.
PARAM_NAMES = p.PARAM_NAMES_FLAT

In [None]:
# Log-spaced grid of observed redshifts from Z_MIN to Z_MAX.
Z_MIN, Z_MAX, NZ = 0.1, 5.0, 100
all_redshifts = np.logspace(np.log10(Z_MIN), np.log10(Z_MAX), NZ)
NZ = all_redshifts.size  # number of grid points actually produced

In [None]:
# Colour code for the redshift grid: log10(z) in [log10(Z_MIN), log10(Z_MAX)]
# mapped onto the 'bwr' colormap.
bwr_map = plt.get_cmap('bwr')
reversed_map = bwr_map.reversed()  # not used below
log_z = np.log10(all_redshifts)
cNorm = colors.Normalize(vmin=np.log10(Z_MIN), vmax=np.log10(Z_MAX))
scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=bwr_map)
all_colors = scalarMap.to_rgba(log_z, alpha=1)

In [None]:
#bwr_map

In [None]:
# Histogram ranges: D4000 values, and log10 of the Balmer/Lyman ratio.
D4000MIN, D4000MAX = 1.0, 2.0
BLMIN, BLMAX = 0.8, 4.

In [None]:
# Redshift scan of the reference parameters (p.DICT_PARAMS_true): plot the
# spectra coloured by log10(z), with inset histograms of the two spectral ratios.
fig = plt.figure(figsize=(10,5))

ax = fig.add_subplot()

ax.set_xscale('log')
ax.set_yscale('log')

all_d4000 = np.zeros(NZ)
all_dBL = np.zeros(NZ)
for idx, z_obs in enumerate(all_redshifts):
    wlsall, spec_rest, spec_rest_att = ssp_spectrum_fromparam(p.DICT_PARAMS_true, z_obs)

    # calc_ratio with its default windows is stored as D4000; with
    # (W_LYMAN, W_BALMER) it gives the Balmer/Lyman ratio.
    all_d4000[idx] = calc_ratio(wlsall, spec_rest)
    all_dBL[idx] = calc_ratio(wlsall, spec_rest, W_LYMAN, W_BALMER)
    label = f"z={z_obs:.2f}"
    ax.plot(wlsall, spec_rest, alpha=0.5, lw=2, color=all_colors[idx], label=label)

# Raw strings for LaTeX: "\A" and "\l" are invalid escape sequences in normal
# string literals (SyntaxWarning since Python 3.12). The rendered text is unchanged.
ax.set_ylabel(r"DSPS SED $F_\nu$")
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_xlim(1e2, 1e5)
plot_hydrogen_lines(ax)
title = r"$F_\nu(\lambda)$ by varying redshift"
ax.set_title(title)
cbar = fig.colorbar(scalarMap, ax=ax)
cbar.ax.set_ylabel(r'$\log_{10} (z)$')

# Inset histograms; positions are figure fractions (left, bottom, width, height).
left, bottom, width, height = [0.16, 0.66, 0.15, 0.2]
ax2 = fig.add_axes([left, bottom, width, height])
ax2.hist(all_d4000, bins=20, range=(D4000MIN, D4000MAX), facecolor="g", alpha=0.2)
ax2.set_xlabel("D4000", fontsize=8)

left, bottom, width, height = [0.16, 0.2, 0.15, 0.2]
ax3 = fig.add_axes([left, bottom, width, height])
# The histogrammed quantity is log10 of the ratio; label it accordingly.
ax3.hist(np.log10(all_dBL), bins=20, range=(BLMIN, BLMAX), facecolor="g", alpha=0.2)
ax3.set_xlabel("log10(Balmer/Lyman)", fontsize=8)

In [None]:
# Same redshift scan, but each spectrum is divided by flux_norm(x, y, wlcenter=wl0)
# before plotting. An inset shows the star formation history returned by mean_sfr.
fig = plt.figure(figsize=(10,5))

ax = fig.add_subplot()

ax.set_xscale('log')
ax.set_yscale('log')

inset_ax = inset_axes(ax,
                      width="30%",   # 30% of the parent axes width
                      height="30%",  # 30% of the parent axes height
                      loc=4, borderpad=3)
inset_ax.set_title("Star Formation History", fontsize=10)
inset_ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$', fontsize=8)
inset_ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$', fontsize=8)

all_d4000 = np.zeros(NZ)
all_dBL = np.zeros(NZ)
for idx, z_obs in enumerate(all_redshifts):
    wlsall, spec_rest, spec_rest_att = ssp_spectrum_fromparam(p.DICT_PARAMS_true, z_obs)
    tarr, sfh_gal = mean_sfr(p.DICT_PARAMS_true, z_obs)

    x = wlsall
    y_nodust = spec_rest
    y_dust = spec_rest_att
    norm_y_nodust = flux_norm(x, y_nodust, wlcenter=wl0)
    norm_y_dust = flux_norm(x, y_dust, wlcenter=wl0)

    # Ratios are computed before the spectra are rescaled.
    all_d4000[idx] = calc_ratio(wlsall, spec_rest)
    all_dBL[idx] = calc_ratio(wlsall, spec_rest, W_LYMAN, W_BALMER)

    y_nodust /= norm_y_nodust
    y_dust /= norm_y_dust

    label = f"z={z_obs:.2f}"
    ax.plot(x, y_nodust, alpha=0.5, lw=2, color=all_colors[idx], label=label)
    inset_ax.plot(tarr, sfh_gal, color=all_colors[idx])

# Raw strings for LaTeX: "\A" and "\l" are invalid escape sequences in normal
# string literals. The rendered text is unchanged.
ax.set_ylabel(r"DSPS SED $F_\nu$")
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_xlim(1e2, 1e5)
plot_hydrogen_lines(ax)
title = r"Rescaled $F_\nu(\lambda)$ by varying redshift"
ax.set_title(title)
cbar = fig.colorbar(scalarMap, ax=ax)
cbar.ax.set_ylabel(r'$\log_{10} (z)$')

left, bottom, width, height = [0.16, 0.66, 0.15, 0.2]
ax2 = fig.add_axes([left, bottom, width, height])
ax2.hist(all_d4000, bins=20, range=(D4000MIN, D4000MAX), facecolor="g", alpha=0.2)
ax2.set_xlabel("D4000", fontsize=8)

# Reference spectra at the boundary redshifts zmax = 10 and zmin = 0. Their
# ratios are kept in d4000_zmax / dBL_zmax / d4000_zmin / dBL_zmin; the curves
# themselves are not drawn.
zmax = 10.
wlsall, spec_rest, spec_rest_att = ssp_spectrum_fromparam(p.DICT_PARAMS_true, zmax)
d4000_zmax = calc_ratio(wlsall, spec_rest)
dBL_zmax = calc_ratio(wlsall, spec_rest, W_LYMAN, W_BALMER)
x = wlsall
y_nodust = spec_rest
y_dust = spec_rest_att

norm_y_nodust = flux_norm(x, y_nodust, wlcenter=wl0)
norm_y_dust = flux_norm(x, y_dust, wlcenter=wl0)

y_nodust /= norm_y_nodust
y_dust /= norm_y_dust
label = f"z={zmax}"

zmin = 0.
wlsall, spec_rest, spec_rest_att = ssp_spectrum_fromparam(p.DICT_PARAMS_true, zmin)
d4000_zmin = calc_ratio(wlsall, spec_rest)
dBL_zmin = calc_ratio(wlsall, spec_rest, W_LYMAN, W_BALMER)
x = wlsall
y_nodust = spec_rest
y_dust = spec_rest_att

norm_y_nodust = flux_norm(x, y_nodust, wlcenter=wl0)
norm_y_dust = flux_norm(x, y_dust, wlcenter=wl0)
y_nodust /= norm_y_nodust
y_dust /= norm_y_dust
label = f"z={zmin}"



In [None]:
#_, (ax1,ax2) = plt.subplots(2, 1,figsize=(10,10))

In [None]:
# Two panels versus redshift: D4000 (left) and log10(Balmer/Lyman) (right),
# points coloured by log10(z).
fig = plt.figure(figsize=(10,4))

panels = [
    (121, all_d4000, "D4000"),
    (122, np.log10(all_dBL), "log10(Balmer/Lyman)"),
]
for position, values, ylabel in panels:
    ax = fig.add_subplot(position)
    ax.scatter(all_redshifts, values, color=all_colors)
    ax.set_xlabel("redshift")
    ax.set_ylabel(ylabel)
    ax.grid()

plt.suptitle("dependence of color D4000 wrt redshift", fontsize=16)

In [None]:
# Histogram ranges (re-stated): D4000 values, and log10 of the Balmer/Lyman ratio.
D4000MIN, D4000MAX = 1.0, 2.0
BLMIN, BLMAX = 0.8, 4.

### Set No Dust

In [None]:
# "No dust": set the last four INIT_PARAMS entries to (0, 0, 0, 1).
# Per the section heading these are presumably the dust parameters — confirm
# against SSPParametersFit.
p.INIT_PARAMS = p.INIT_PARAMS.at[-4:].set(jnp.array([0., 0., 0., 1.]))

In [None]:
# increase the range of the DSPS parameters
FLAG_INCREASE_RANGE_MAH = True
if FLAG_INCREASE_RANGE_MAH:
    # MAH_lgmO
    p.PARAMS_MIN = p.PARAMS_MIN.at[0].set(8)
    p.PARAMS_MAX = p.PARAMS_MAX.at[1].set(15)
    # MAH_logtc
    p.PARAMS_MIN = p.PARAMS_MIN.at[1].set(0.01)
    p.PARAMS_MAX = p.PARAMS_MAX.at[1].set(0.15)
    # MAH_early_index
    p.PARAMS_MIN = p.PARAMS_MIN.at[2].set(0.1)
    p.PARAMS_MAX = p.PARAMS_MAX.at[2].set(10.)
    # MAH_late_index
    p.PARAMS_MIN = p.PARAMS_MIN.at[3].set(0.1)
    p.PARAMS_MAX = p.PARAMS_MAX.at[3].set(10.)

FLAG_INCREASE_RANGE_MS = True
if FLAG_INCREASE_RANGE_MS:
    # MS_lgmcrit  12
    p.PARAMS_MIN = p.PARAMS_MIN.at[4].set(9.)
    p.PARAMS_MAX = p.PARAMS_MAX.at[4].set(13.)
    # MS_lgy_at_mcrit : -1
    p.PARAMS_MIN = p.PARAMS_MIN.at[5].set(-2.)
    p.PARAMS_MAX = p.PARAMS_MAX.at[5].set(-0.7)
    #MS_indx_lo : 1
    p.PARAMS_MIN = p.PARAMS_MIN.at[6].set(0.7)
    p.PARAMS_MAX = p.PARAMS_MAX.at[6].set(2.)
    #MS_indx_hi : -1
    p.PARAMS_MIN = p.PARAMS_MIN.at[7].set(-2.)
    p.PARAMS_MAX = p.PARAMS_MAX.at[7].set(-0.7)
    #MS_tau_dep : 2
    p.PARAMS_MIN = p.PARAMS_MIN.at[8].set(0.7)
    p.PARAMS_MAX = p.PARAMS_MAX.at[8].set(3.)

FLAG_INCREASE_RANGE_Q = True
if FLAG_INCREASE_RANGE_Q:
    #'Q_lg_qt', 1.0),
    p.PARAMS_MIN = p.PARAMS_MIN.at[9].set(0.5)
    p.PARAMS_MAX = p.PARAMS_MAX.at[9].set(2.)
    #('Q_qlglgdt', -0.50725),
    p.PARAMS_MIN = p.PARAMS_MIN.at[10].set(-2.)
    p.PARAMS_MAX = p.PARAMS_MAX.at[10].set(-0.2)
    # ('Q_lg_drop', -1.01773),
    p.PARAMS_MIN = p.PARAMS_MIN.at[11].set(-2.)
    p.PARAMS_MAX = p.PARAMS_MAX.at[11].set(-0.5)
    #('Q_lg_rejuv', -0.212307),
    p.PARAMS_MIN = p.PARAMS_MIN.at[12].set(-2.)
    p.PARAMS_MAX = p.PARAMS_MAX.at[12].set(-0.1)


In [None]:
# Display the reference parameter dictionary used by the redshift scans above.
p.DICT_PARAMS_true

In [None]:
# Spectrum of the reference parameters at redshift 0 (computed but not plotted here).
wlsall,spec_rest,spec_rest_att = ssp_spectrum_fromparam(p.DICT_PARAMS_true,0)

## Read MCMC without dust

In [None]:
# Display the flat parameter-name list (p.PARAM_NAMES_FLAT).
PARAM_NAMES

In [None]:
# The 13 dust-free parameter names, grouped by prefix (MAH_, MS_, Q_). Their
# positions are used as indices into p.INIT_PARAMS / PARAMS_MIN / PARAMS_MAX.
PARAM_NODUST = np.array([
    'MAH_lgmO',
    'MAH_logtc',
    'MAH_early_index',
    'MAH_late_index',
    'MS_lgmcrit',
    'MS_lgy_at_mcrit',
    'MS_indx_lo',
    'MS_indx_hi',
    'MS_tau_dep',
    'Q_lg_qt',
    'Q_qlglgdt',
    'Q_lg_drop',
    'Q_lg_rejuv',
])

In [None]:
# Ordered mapping from each dust-free parameter name to its display label (the name itself).
PARAM_NODUST_DICT_NAMES = OrderedDict(
    (param_name, f"{param_name}") for param_name in PARAM_NODUST
)

## Choose a parameter

In [None]:
# Observed redshift and the age of the universe at that redshift (Gyr).
# Star formation at cosmic times later than T_OBS cannot be seen at Z_OBS.
Z_OBS = 0.1
T_OBS = age_at_z(Z_OBS, *DEFAULT_COSMOLOGY)[0]

In [None]:
# Parameter to vary; must be one of PARAM_NODUST.
selected_param_name = 'MAH_lgmO'

In [None]:
# Number of random values drawn for the selected parameter.
NSIM = 500

In [None]:
# Position of the selected parameter in PARAM_NODUST. It is also used to index
# p.INIT_PARAMS / PARAMS_MIN / PARAMS_MAX, which assumes the same ordering.
selected_param_idx = np.where(PARAM_NODUST == selected_param_name)[0][0]

In [None]:
# Fixed seed so that the draws are reproducible.
key = jax.random.PRNGKey(42)

In [None]:
# Split off a subkey for the uniform draw below; new_key is not used afterwards.
new_key, subkey = jax.random.split(key)

### Simulate the selected parameter values

In [None]:
# Draw NSIM values of the selected parameter uniformly within its bounds.
low = p.PARAMS_MIN[selected_param_idx]
high = p.PARAMS_MAX[selected_param_idx]
simulated_val = jax.random.uniform(subkey, shape=(NSIM,), minval=low, maxval=high)

### Extract other parameter values and stack them vertically

In [None]:
# INIT_PARAMS without the selected entry, repeated as NSIM identical rows.
other_param_val_left = p.INIT_PARAMS[:selected_param_idx]
other_param_val_right = p.INIT_PARAMS[selected_param_idx + 1:]
other_param_val = jnp.concatenate([other_param_val_left, other_param_val_right])
other_param_val_arr = jnp.tile(other_param_val, (NSIM, 1))

### Insert a new column with the simulated values

In [None]:
#jax.numpy.hstack((simulated_val,other_param_val_arr))
# Put the simulated values back as column `selected_param_idx`. The result has
# one full parameter set per row.
new_param_sim_arr = jnp.insert(
    other_param_val_arr,
    selected_param_idx,
    simulated_val,
    axis=1,
)

In [None]:
#jax.tree_map(lambda x: x*2, list_of_lists)) 

In [None]:
# Working and elegant, need absolutely a list of ...
# One parameter dict per simulated row (iterating a 2-D array yields its rows).
# jax.tree_map is deprecated and removed in recent JAX; the comprehension gives
# the same list of dicts.
list_of_dicts_params = [paramslist_to_dict(row, PARAM_NAMES) for row in new_param_sim_arr]

In [None]:
# comprehension working but not elegant
#list_of_dicts_params = [ paramslist_to_dict(new_param_sim_arr[idx,:],PARAM_NAMES) for idx in  range(new_param_sim_arr.shape[0])]

In [None]:
# Quick check: spectrum of the second simulated parameter set at Z_OBS.
wlsall,spec_rest,spec_rest_att = ssp_spectrum_fromparam(list_of_dicts_params[1],Z_OBS)

In [None]:
# Colour code for the parameter scans: D4000 in [D4000MIN, D4000MAX] mapped
# onto the 'bwr' colormap (this replaces the redshift-based scalarMap).
bwr_map = plt.get_cmap('bwr')
reversed_map = bwr_map.reversed()  # not used below
cNorm = colors.Normalize(vmin=D4000MIN, vmax=D4000MAX)
scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=bwr_map)

## Plot renormalized Spectra and SFR

In [None]:
# Rescaled spectra and SFHs for the NSIM parameter sets at Z_OBS, each curve
# coloured by its D4000, with inset histograms of D4000 and log10(Balmer/Lyman).
fig = plt.figure(figsize=(10,5))

ax = fig.add_subplot()

ax.set_xscale('log')
ax.set_yscale('log')

left, bottom, width, height = [0.45, 0.2, 0.3, 0.3]
ax4 = fig.add_axes([left, bottom, width, height])
ax4.set_title("Star Formation History", fontsize=10)
ax4.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$', fontsize=8)
ax4.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$', fontsize=8)
# Grey band: cosmic times after T_OBS, not observable at Z_OBS
# (13.8 is presumably today's age of the universe in Gyr).
ax4.axvspan(T_OBS, 13.8, color="grey", alpha=0.5)


all_d4000 = np.zeros(NSIM)
all_dBL = np.zeros(NSIM)
for idx, params_dict in enumerate(list_of_dicts_params[:NSIM]):
    wlsall, spec_rest, spec_rest_att = ssp_spectrum_fromparam(params_dict, Z_OBS)
    tarr, sfh_gal = mean_sfr(params_dict, Z_OBS)

    # Ratios are computed before the spectra are rescaled.
    d4000 = calc_ratio(wlsall, spec_rest)
    all_d4000[idx] = d4000
    dBL = calc_ratio(wlsall, spec_rest, W_LYMAN, W_BALMER)
    all_dBL[idx] = dBL

    x = wlsall
    y_nodust = spec_rest
    y_dust = spec_rest_att
    norm_y_nodust = flux_norm(x, y_nodust, wlcenter=wl0)
    norm_y_dust = flux_norm(x, y_dust, wlcenter=wl0)
    y_nodust /= norm_y_nodust
    y_dust /= norm_y_dust

    col = scalarMap.to_rgba(d4000, alpha=1)

    ax.plot(x, y_nodust, alpha=0.5, lw=2, color=col)
    ax4.plot(tarr, sfh_gal, color=col)

cbar = fig.colorbar(scalarMap, ax=ax)
cbar.ax.set_ylabel('D4000')

# Raw strings for LaTeX: "\A" and "\l" are invalid escape sequences in normal
# string literals. The rendered text is unchanged.
ax.set_ylabel(r"DSPS SED $F_\nu$")
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_xlim(1e2, 1e5)
plot_hydrogen_lines(ax)
title = rf"$F_\nu(\lambda)$ by varying parameter {selected_param_name} at redshift {Z_OBS:.2f}"
ax.set_title(title)

left, bottom, width, height = [0.15, 0.66, 0.15, 0.2]
ax2 = fig.add_axes([left, bottom, width, height])
ax2.hist(all_d4000, bins=20, range=(D4000MIN, D4000MAX), facecolor="g", alpha=0.2)
ax2.set_xlabel("D4000", fontsize=8)

left, bottom, width, height = [0.15, 0.2, 0.15, 0.2]
ax3 = fig.add_axes([left, bottom, width, height])
# The histogrammed quantity is log10 of the ratio; label it accordingly.
ax3.hist(np.log10(all_dBL), bins=20, range=(BLMIN, BLMAX), facecolor="g", alpha=0.2)
ax3.set_xlabel("log10(Balmer/Lyman)", fontsize=8)

## Loop on parameters

In [None]:
# Settings for the loop over parameters: observed redshift, number of draws
# per parameter, and age of the universe (Gyr) at Z_OBS.
Z_OBS = 0.1
NSIM = 100
T_OBS = age_at_z(Z_OBS, *DEFAULT_COSMOLOGY)[0]

In [None]:
# For each dust-free parameter in turn: draw NSIM values uniformly within its
# bounds, with all other parameters fixed to p.INIT_PARAMS. Plot the rescaled
# spectra and SFHs coloured by D4000, then the joint KDE of (parameter, D4000).
for selected_param_idx, selected_param_name in enumerate(PARAM_NODUST):

    # The same seed is used for every parameter, so the same uniform draws are
    # reused, rescaled to each parameter's [min, max] range.
    key = jax.random.PRNGKey(42)
    new_key, subkey = jax.random.split(key)
    simulated_val = jax.random.uniform(subkey, shape=(NSIM,),
                                       minval=p.PARAMS_MIN[selected_param_idx],
                                       maxval=p.PARAMS_MAX[selected_param_idx])

    # Parameter matrix: INIT_PARAMS in every row, except the selected column,
    # which holds the simulated values.
    other_param_val_left = p.INIT_PARAMS[:selected_param_idx]
    other_param_val_right = p.INIT_PARAMS[selected_param_idx+1:]
    other_param_val = jnp.hstack((other_param_val_left, other_param_val_right))
    other_param_val_arr = jnp.tile(other_param_val, NSIM).reshape(NSIM, -1)
    new_param_sim_arr = jnp.insert(other_param_val_arr, selected_param_idx, simulated_val, axis=1)

    list_of_dicts_params = [paramslist_to_dict(row, PARAM_NAMES) for row in new_param_sim_arr]

    #-----------------------------------------------------------------------
    # Fig 1: rescaled spectra + SFH inset, coloured by D4000
    #-----------------------------------------------------------------------
    fig = plt.figure(figsize=(10,5))
    ax = fig.add_subplot()
    ax.set_xscale('log')
    ax.set_yscale('log')

    left, bottom, width, height = [0.42, 0.2, 0.3, 0.3]
    ax4 = fig.add_axes([left, bottom, width, height])
    ax4.set_title("Star Formation History", fontsize=10)
    ax4.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$', fontsize=8)
    ax4.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$', fontsize=8)
    # Grey band: cosmic times after T_OBS, not observable at Z_OBS.
    ax4.axvspan(T_OBS, 13.8, color="grey", alpha=0.5)

    all_d4000 = np.zeros(NSIM)
    all_dBL = np.zeros(NSIM)
    for idx, params_dict in enumerate(list_of_dicts_params):
        wlsall, spec_rest, spec_rest_att = ssp_spectrum_fromparam(params_dict, Z_OBS)
        tarr, sfh_gal = mean_sfr(params_dict, Z_OBS)

        # Ratios are computed before the spectra are rescaled.
        d4000 = calc_ratio(wlsall, spec_rest)
        all_d4000[idx] = d4000
        dBL = calc_ratio(wlsall, spec_rest, W_LYMAN, W_BALMER)
        all_dBL[idx] = dBL

        x = wlsall
        y_nodust = spec_rest
        y_dust = spec_rest_att
        norm_y_nodust = flux_norm(x, y_nodust, wlcenter=wl0)
        norm_y_dust = flux_norm(x, y_dust, wlcenter=wl0)
        y_nodust /= norm_y_nodust
        y_dust /= norm_y_dust

        col = scalarMap.to_rgba(d4000, alpha=1)

        ax.plot(x, y_nodust, alpha=0.5, lw=2, color=col)
        ax4.plot(tarr, sfh_gal, color=col)

    cbar = fig.colorbar(scalarMap, ax=ax)
    cbar.ax.set_ylabel('D4000')

    # Raw strings for LaTeX: "\A" and "\l" are invalid escape sequences in normal
    # string literals. The rendered text is unchanged.
    ax.set_ylabel(r"DSPS SED $F_\nu$")
    ax.set_xlabel(r"$\lambda (\AA)$")
    ax.set_xlim(1e2, 1e5)
    plot_hydrogen_lines(ax)
    title = rf"$F_\nu(\lambda)$ by varying parameter {selected_param_name} at redshift {Z_OBS:.2f}"
    ax.set_title(title)

    left, bottom, width, height = [0.15, 0.66, 0.15, 0.2]
    ax2 = fig.add_axes([left, bottom, width, height])
    ax2.hist(all_d4000, bins=20, range=(D4000MIN, D4000MAX), facecolor="g", alpha=0.2)
    ax2.set_xlabel("D4000", fontsize=8)
    plt.show()

    #-----------------------------------------------------------------------
    # Fig 2: joint KDE of (parameter value, D4000)
    #-----------------------------------------------------------------------
    small_dict = OrderedDict()
    small_dict[selected_param_name] = simulated_val
    small_dict["d4000"] = all_d4000

    fig = plt.figure(figsize=(5,5))
    ax = fig.add_subplot()
    try:
        ax = az.plot_pair(
            small_dict,
            kind="kde",
            marginal_kwargs={"plot_kwargs": {"lw": 3, "c": "blue", "ls": "-"}},
            kde_kwargs={
                "hdi_probs": [0.3, 0.68, 0.9],  # Plot 30%, 68% and 90% HDI contours
                "contour_kwargs": {"colors": None, "cmap": "Blues", "linewidths": 3,
                                   "linestyles": "-"},
                "contourf_kwargs": {"alpha": 0.5},
            },
            point_estimate_kwargs={"lw": 3, "c": "b"}, ax=ax,
        )
    except Exception as inst:
        # Best effort: report the failure for this parameter and go on to the next.
        print(f"plot_pair failed for {selected_param_name}: {type(inst).__name__}: {inst}")
    # Show now, so that each KDE appears right after its spectra figure.
    plt.show()