# Fit DSPS SED models with Fors2 Spectra with JAXOPT


- author : Sylvie Dagoret-Campagne
- affiliation : IJCLab/IN2P3/CNRS
- CC: kernel conda_jax0325_py310
- creation date : 2023-11-10
- update : 2023-11-14 : add interpax interpolation and fitted sfr
- update : 2023-11-21 : correct bug on UV-BUMP


| computer | processor | kernel              |    date     |
| --- | --- | --- | --- |
| CC       | CPU       | conda_jax0325_py310 | 2023-11-10  |
| macbookpro | CPU | conda_jaxcpu_dsps_py310 | 2023-11-10  | 


libraries 
=========

jax
---

- jaxlib-0.3.25+cuda11.cudnn82
- jaxopt
- optax
- corner
- arviz
- numpyro
- interpax

sps
---

- fsps
- prospect
- dsps
- diffstar
- diffmah
- sedpy or astro-sedpy

plot
----

- matplotlib
- seaborn



(conda_jax0325_py310) 
`/pbs/throng/lsst/users/dagoret/desc/StellarPopulationSynthesis>pip list | grep` 

| lib | version |
|--- | --- | 
|jax  |                         0.4.20 |
|jaxlib |                       0.4.20 |
|jaxopt  |                      0.8.2 |



https://github.com/sylvielsstfr/Fors2ToStellarPopSynthesis/blob/main/examples/examples_jaxtutos/jaxtuto_jec2022/JAX-Optim-regression-piecewise.ipynb

In [None]:
import jax
import jax.numpy as jnp
jax.devices()

In [None]:

import numpy as np
import scipy as sc

import jax
import jax.numpy as jnp
import jax.scipy as jsc

from jax import grad, jit, vmap
from jax import jacfwd, jacrev, hessian

import jaxopt
import optax

jax.config.update("jax_enable_x64", True)
import corner
import arviz as az

import copy

from interpax import interp1d

In [None]:
import itertools

In [None]:
import h5py
import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import matplotlib.pyplot as plt
%matplotlib inline
from collections import OrderedDict
import re
import matplotlib.gridspec as gridspec
from sklearn.gaussian_process import GaussianProcessRegressor, kernels

In [None]:
# Default figure size and large fonts so that the figures stay readable.
plt.rcParams.update({
    "figure.figsize": (12, 6),
    "axes.labelsize": 'xx-large',
    "axes.titlesize": 'xx-large',
    "xtick.labelsize": 'xx-large',
    "ytick.labelsize": 'xx-large',
})

In [None]:
from diffstar.defaults import DEFAULT_MAH_PARAMS
from diffstar.defaults import DEFAULT_MS_PARAMS
from diffstar.defaults import DEFAULT_Q_PARAMS

In [None]:
from diffstar import sfh_singlegal
from dsps.cosmology import age_at_z, DEFAULT_COSMOLOGY

In [None]:
from dsps import load_ssp_templates

In [None]:
from dsps import calc_rest_sed_sfh_table_lognormal_mdf
from dsps import calc_rest_sed_sfh_table_met_table

In [None]:
from dsps.dust.att_curves import  sbl18_k_lambda, RV_C00,_frac_transmission_from_k_lambda

### Very important: add the library path needed to load the data

In [None]:
import sys
sys.path.append("../../lib")

In [None]:
from fit_params_fors2 import U_FNU,U_FL,ConvertFlambda_to_Fnu,flux_norm,ordered_keys,Fors2DataAcess

In [None]:
# Wavelength window of width lambda_width centred on lambda_red
# (presumably in Angstrom, like the other wavelengths of this notebook).
lambda_red = 6231
lambda_width = 50
lambda_sel_min, lambda_sel_max = (lambda_red - lambda_width / 2.,
                                  lambda_red + lambda_width / 2.)

# Read Fors2 / Galex and Kids

In [None]:
input_file_h5  = '../../data/fors2sl/FORS2spectraGalexKidsPhotom.hdf5'

In [None]:
fors2 = Fors2DataAcess(input_file_h5)

In [None]:
list_of_keys = fors2.get_list_of_groupkeys()
list_of_attributes = fors2.get_list_subgroup_keys()

## Spectra must be sorted by their number (taken from the spectrum name)

In [None]:
list_of_keys = np.array(list_of_keys)

In [None]:
list_of_keysnum = [ int(re.findall("SPEC(.*)",specname)[0]) for specname in  list_of_keys ]

In [None]:
sorted_indexes = np.argsort(list_of_keysnum)

In [None]:
list_of_keys = list_of_keys[sorted_indexes]

In [None]:
df_info = pd.DataFrame(columns=list_of_attributes)
all_df = []

### Read each spectrum Fors2 as (wl,fnu)

In [None]:
# Read every FORS2 group: its attributes become row idx of df_info and its
# spectrum (wl, fnu) is appended to all_df, so all_df[idx] goes with df_info.loc[idx].
for idx,key in enumerate(list_of_keys):
    attrs = fors2.getattribdata_fromgroup(key)
    spectr = fors2.getspectrum_fromgroup(key)
    # NOTE(review): assumes attrs.values() comes in the same order as the
    # df_info columns (list_of_attributes); this is not checked here.
    df_info.loc[idx] = [*attrs.values()] # hope the order of attributes is kept
    df = pd.DataFrame({"wl":spectr["wl"],"fnu":spectr["fnu"]})
    all_df.append(df)
    

In [None]:
df_info.reset_index(drop=True, inplace=True) 

In [None]:
df_info = df_info[ordered_keys]

In [None]:
df_info

# Select good match with galex

In [None]:
df_info.hist("asep_galex",bins=100,color="b")
plt.axvline(5,c="k")

In [None]:
df_info.hist("asep_kids",bins=100,color='b')

## Select those spectra having a GALEX match

In [None]:
# Keep the spectra whose GALEX counterpart lies within 5 arcsec.
# An explicit copy is taken because columns are added to df later
# (df['index0'] = ...); writing into a slice of df_info would raise
# pandas' SettingWithCopyWarning and has ambiguous view/copy semantics.
df = df_info[df_info["asep_galex"] <= 5].copy()

In [None]:
df.index

## Remove NaN

- remove those rows with no FUV

In [None]:
#df = df.dropna()

## Remove rows with Rmag = 0

In [None]:
#df = df[df["Rmag"] > 0]

# Plot Spectra

https://en.wikipedia.org/wiki/Photometric_system

In [None]:
# Approximate central wavelengths (Angstrom) of the photometric bands,
# see https://en.wikipedia.org/wiki/Photometric_system
lambda_FUV = 1528.
lambda_NUV = 2271.
lambda_U = 3650.
lambda_B = 4450.
lambda_G = 4640.
lambda_R = 5580.
lambda_I = 8060.
lambda_Z = 9000.
lambda_Y = 10200.
lambda_J = 12200.
lambda_H = 16300.
lambda_K = 21900.
lambda_L = 34500.

# WL and FilterTag follow the order of the per-galaxy magnitude lists:
# GALEX FUV, NUV, then u, g, r, i, Z, Y, J, H, Ks.
# The third entry is therefore the u band and uses lambda_U.
WL = [lambda_FUV, lambda_NUV, lambda_U, lambda_G, lambda_R ,lambda_I, lambda_Z, lambda_Y, lambda_J, lambda_H, lambda_K ]
FilterTag = ['FUV','NUV','U','G','R','I','Z','Y','J','H','Ks']

In [None]:
def PlotFilterTag(ax, fluxlist, wl=None, tags=None):
    """Write the band name next to each finite photometric point.

    :param ax: matplotlib axis on which the names are written with ``ax.text``.
    :param fluxlist: flux of each band (list or array); non-finite entries
        (NaN / inf, i.e. missing photometry) are skipped.
    :param wl: band wavelengths in the same order as ``fluxlist``;
        defaults to the module-level ``WL``.
    :param tags: band names in the same order as ``fluxlist``;
        defaults to the module-level ``FilterTag``.
    """
    if wl is None:
        wl = WL
    if tags is None:
        tags = FilterTag

    # accept plain lists as well as arrays (boolean masking needs an array)
    fluxes = np.asarray(fluxlist, dtype=float)
    finite = np.isfinite(fluxes)
    if not finite.any():
        # nothing to label; also avoids taking the mean of an empty array
        return

    # vertical offset of the labels: a fifth of the mean finite flux
    dy = np.mean(fluxes[finite]) / 5

    for x, flux, tag, is_finite in zip(wl, fluxes, tags, finite):
        if not is_finite:
            continue
        fl = flux - dy
        if fl < 0:
            # the label would fall below zero: put it above the point instead
            fl += 2 * dy
        ax.text(x, fl, tag, fontsize=12, color="g", weight='bold', ha='center', va='bottom')

In [None]:
df

In [None]:
df['index0'] = df.index
df = df.reset_index()

In [None]:
df['index0']

In [None]:
df.columns

# Plots

# Fit a Gaussian process to remove spectral lines (pixels far from the continuum)

In [None]:
# Gaussian process with an RBF kernel, fitted to the spectrum below to obtain
# a smooth model against which line pixels are flagged.
# NOTE(review): in kernels.RBF(length_scale, length_scale_bounds) the initial
# length scale 0.5 lies outside the bounds (8000, 10000); the hyper-parameter
# optimizer presumably starts from within the bounds instead — confirm intended.
kernel = kernels.RBF(0.5, (8000, 10000.0))
gp = GaussianProcessRegressor(kernel=kernel ,random_state=0)

### Select the spectrum

In [None]:
# Number of the spectrum to analyse (another one studied: 411).
selected_spectrum_number = 560

# Normalisation factor per spectrum number, applied to the data in the
# comparison plot when FLAG_RESCALE_FORPLOT is set.
dict_normalisation_factor = {
    411: 160.,
    560: 12.,
}

In [None]:
# Loop over the GALEX-matched spectra but process only the selected one:
# plot the spectrum with its photometry, model it with the Gaussian process,
# flag the pixels far from that model and plot the cleaned spectrum.
# Several variables set here (z, the_wl, the_fnu, background, title_data,
# the_label_data, ...) are reused by later cells.
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
for idx,row in df.iterrows():

    print(idx," == ",row["name"],row["num"],row["index0"])
    idx0 = row["index0"] # index of the spectrum in all_df
    specname  = row["name"]
    specnum = row["num"]
    # rows follow the sorted spectrum numbers: skip until the selected one, stop after it
    if specnum < selected_spectrum_number:
        continue
    if specnum > selected_spectrum_number:
        break

    # photometry, in the order of WL / FilterTag:
    # GALEX FUV, NUV, then u, g, r, i, Z, Y, J, H, Ks
    mags = [ row["fuv_mag"], row["nuv_mag"], row['MAG_GAAP_u'], row['MAG_GAAP_g'], row['MAG_GAAP_r'], row['MAG_GAAP_i'], row['MAG_GAAP_Z'], row['MAG_GAAP_Y'],
            row['MAG_GAAP_J'], row['MAG_GAAP_H'],row['MAG_GAAP_Ks'] ]

    magserr = [ row["fuv_magerr"], row["nuv_magerr"], row['MAGERR_GAAP_u'], row['MAGERR_GAAP_g'], row['MAGERR_GAAP_r'], row['MAGERR_GAAP_i'], row['MAGERR_GAAP_Z'], row['MAGERR_GAAP_Y'],
            row['MAGERR_GAAP_J'], row['MAGERR_GAAP_H'],row['MAGERR_GAAP_Ks'] ]

    # magnitudes -> fluxes (maggies), errors propagated to first order:
    # sigma_f = 0.4 * ln(10) * sigma_m * f
    mfluxes = np.array([ 10**(-0.4*m) for m in mags ])
    mfluxeserr = np.array([ 0.4*np.log(10)*em*f for f, em in zip(mfluxes, magserr) ])

    fluxes =  [ row["fuv_flux"], row["nuv_flux"], row['FLUX_GAAP_u'], row['FLUX_GAAP_g'], row['FLUX_GAAP_r'], row['FLUX_GAAP_i'], row['FLUX_GAAP_Z'], row['FLUX_GAAP_Y'],
            row['FLUX_GAAP_J'], row['FLUX_GAAP_H'],row['FLUX_GAAP_Ks'] ]

    # catalogue flux errors (the Y-band entry is the Y-band error, FLUXERR_GAAP_Y)
    fluxeserr =  [ row["fuv_fluxerr"], row["nuv_fluxerr"], row['FLUXERR_GAAP_u'], row['FLUXERR_GAAP_g'], row['FLUXERR_GAAP_r'], row['FLUXERR_GAAP_i'], row['FLUXERR_GAAP_Z'], row['FLUXERR_GAAP_Y'],
            row['FLUXERR_GAAP_J'], row['FLUXERR_GAAP_H'],row['FLUXERR_GAAP_Ks'] ]

    ##############################################################
    # The top image : original spectrum and broad-band photometry
    ######################################
    fig = plt.figure(figsize=(10,14),tight_layout=True)
    gs = gridspec.GridSpec(3, 1,height_ratios=[3, 1, 3])

    ax= fig.add_subplot(gs[0, 0])

    z = row["redshift"]
    asep_fg = row['asep_galex']
    rmag = row["Rmag"]
    asep_fk = row['asep_kids']

    the_lines = row['lines']
    print("the_lines = ",the_lines)
    all_elements  = the_lines.decode().split(",")

    the_label_data = f" FORS2 : z = {z:.3f} , Rmag = {rmag:.1f} mag,  angular sep (arcsec) f-g : {asep_fg:.3f}, f-k : {asep_fk:.3f}"
    # spectrum scaled by the r-band flux 10**(-0.4*mags[4])
    the_wl = all_df[idx0]["wl"].values
    the_fnu = all_df[idx0]["fnu"].values*10**(-0.4*mags[4])
    ax.plot(the_wl,the_fnu,'b-',label=the_label_data)

    # fit the Gaussian process on the spectrum and draw its mean +/- 1 std
    X = the_wl
    Y = the_fnu
    gp.fit(X[:, None], Y)
    xfit = np.linspace(X.min(),X.max())
    yfit, yfit_err = gp.predict(xfit[:, None], return_std=True)
    ax.plot(xfit, yfit, '-', color='cyan')
    ax.fill_between(xfit, yfit -  yfit_err, yfit +  yfit_err, color='gray', alpha=0.3)

    ax2 = ax.twinx()
    ax2.errorbar(WL,mfluxes,yerr=mfluxeserr,fmt='o',color="r",ecolor="r",ms=5,label='Galex (UV) + Kids (optics) +Viking (IR)')
    PlotFilterTag(ax2,mfluxes)
    ax2.legend(loc="lower right")

    ax.set_xlabel(r"$\lambda  (\AA)$ ")
    title_data = f"FORS2 : {idx}): {idx0} name = {specname}"
    ax.set_title(title_data)
    ax.set_ylabel("flux (maggies)")
    ax.legend(loc="upper right")
    ax.grid()

    the_max1 = np.max(the_fnu)
    goodmags = mfluxes[np.isfinite(mfluxes)]
    the_max2 = np.max(goodmags)
    the_max = 1.1*max(the_max1,the_max2)

    ax.set_ylim(0,the_max)
    ax2.set_ylim(0,the_max)

    ax.set_xlim(1000.,25000.)
    ##################
    # Middle image : residuals of the spectrum with respect to the Gaussian process
    ##################
    ax3 = fig.add_subplot(gs[1, 0])
    # predict(..., return_std=True) returns a (mean, std) tuple: unpack it
    # before subtracting instead of subtracting the whole tuple from Y
    Ymodel, DeltaEY = gp.predict(X[:, None], return_std=True)
    DeltaY = Y - Ymodel
    ax3.plot(X,DeltaY,'b')
    ax3.set_xlim(1000.,25000.)
    ax3.grid()
    ax3.set_xlabel(r"$\lambda  (\AA)$ ")

    # robust residual scale (root of the median squared residual); pixels
    # deviating by more than 8 times this scale are flagged as lines
    background = np.sqrt(np.median(DeltaY**2))
    indexes_toremove = np.where(np.abs(DeltaY)> 8 * background)[0]

    for index in indexes_toremove:
        ax3.axvline(X[index],color='k')

    #########################
    # Bottom image : spectrum after removing the flagged pixels
    #########################

    ax4 = fig.add_subplot(gs[2, 0])

    Xclean = np.delete(X,indexes_toremove)
    Yclean  = np.delete(Y,indexes_toremove)

    ax4.plot(Xclean, Yclean,'b-',label=the_label_data)
    ax4.set_ylim(0,the_max)
    ax4.set_xlim(1000.,25000.)
    ax4.grid()
    ax4.set_xlabel(r"$\lambda  (\AA)$ ")

    ax5 = ax4.twinx()
    ax5.errorbar(WL,mfluxes,yerr=mfluxeserr,fmt='o',color="r",ecolor="r",ms=5,label='Galex (UV) + Kids (optics) +Viking (IR)')
    ax5.legend(loc="lower right")
    ax5.set_ylim(0,the_max)
    PlotFilterTag(ax5,mfluxes)

    ax4.set_ylabel("flux (maggies)")
    ax4.legend(loc="upper right")

    # NOTE(review): the cleaned spectrum returned here is not used below
    spectr = fors2.getspectrumcleanedemissionlines_fromgroup(specname,gp)

    # overlay the original spectrum; the_wl / the_fnu are unchanged since the
    # top panel, so they are not recomputed
    ax5.plot(the_wl,the_fnu,color="grey",lw=0.25,label="original")

    plt.show()

In [None]:
background 

In [None]:
z

# Model of a galaxy

### Load SED templates

In [None]:

# SSP templates used by DSPS (another available file: tempdata_v2.h5)
ssp_data = load_ssp_templates(fn='../examples_dsps_diffstar_diffmah/tempdata.h5')

print(ssp_data._fields)

# shapes of the template arrays
print(f'ssp_data : ssp_lgmet.shape = {ssp_data.ssp_lgmet.shape}')
print(f'ssp_data : ssp_lg_age_gyr.shape = {ssp_data.ssp_lg_age_gyr.shape}')
print(f'ssp_data : ssp_wave.shape = {ssp_data.ssp_wave.shape}')
print(f'ssp_data :ssp_flux.shape = {ssp_data.ssp_flux.shape}')

### calculate age distribution

In [None]:
today_gyr = 13.8 
tarr = np.linspace(0.1, today_gyr, 100)

In [None]:
sfh_gal = sfh_singlegal(tarr, DEFAULT_MAH_PARAMS, DEFAULT_MS_PARAMS, DEFAULT_Q_PARAMS)

In [None]:
fig, ax = plt.subplots(1, 1)
ylim = ax.set_ylim(1e-3, 50)
yscale = ax.set_yscale('log')

__=ax.plot(tarr, sfh_gal, '--', color='k',label='sfh_gal')
ax.set_title("Simulated Star Formation History (SFH)")
xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')
ax.grid()

In [None]:
fig, ax = plt.subplots(1, 1)
ylim = ax.set_ylim(0, 1)
#yscale = ax.set_yscale('log')

__=ax.plot(tarr, sfh_gal, '--', color='k',label='sfh_gal')
ax.set_title("Simulated Star Formation History (SFH)")
xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')
ax.grid()

In [None]:
gal_t_table = tarr

In [None]:
z_obs = z
t_obs = age_at_z(z_obs, *DEFAULT_COSMOLOGY) # age of the universe in Gyr at z_obs
t_obs = t_obs[0] # age_at_z function returns an array, but SED functions accept a float for this argument

In [None]:
sfh_gal = jnp.where(tarr<t_obs, sfh_gal, 0)

In [None]:
fig, ax = plt.subplots(1, 1)
ylim = ax.set_ylim(0, 1)
#yscale = ax.set_yscale('log')

__=ax.plot(tarr, sfh_gal, '--', color='k',label='sfh_gal')
ax.set_title("Simulated Star Formation History (SFH)")
xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')
ax.grid()

In [None]:
gal_sfr_table = sfh_gal
# metallicity
gal_lgmet = -2.0 # log10(Z)
gal_lgmet_scatter = 0.2 # lognormal scatter in the metallicity distribution function


In [None]:
# calculate first SED with only one metallicity
print("gal_lgmet = ", gal_lgmet)

sed_info = calc_rest_sed_sfh_table_lognormal_mdf(
    gal_t_table, gal_sfr_table, gal_lgmet, gal_lgmet_scatter,
    ssp_data.ssp_lgmet, ssp_data.ssp_lg_age_gyr, ssp_data.ssp_flux, t_obs)

In [None]:
Av= 1.0
uv_bump_ampl = 3.0
plaw_slope = -0.25
wave_spec_micron = ssp_data.ssp_wave/10_000

In [None]:
k = sbl18_k_lambda(wave_spec_micron,uv_bump_ampl,plaw_slope)
dsps_flux_ratio = _frac_transmission_from_k_lambda(k,Av)

In [None]:
# Plot the transmitted fraction F_dust / F_no-dust given by the SBL18
# attenuation curve computed above for the current Av, slope and UV bump.
fig, ax = plt.subplots(1, 1, figsize=(12, 5))

# raw f-string: "\d" is not a valid Python escape sequence
label_dust = rf" Av = {Av}, $\delta$ = {plaw_slope}, uv-bump = {uv_bump_ampl}"
ax.plot(ssp_data.ssp_wave ,dsps_flux_ratio,'b-',label=label_dust)

xlim = ax.set_xlim(900, 1e4)
ylim = ax.set_ylim(1e-4, 2)
__=ax.loglog()

title = ax.set_title(r'${\rm attenuation\ curve\ validation}$')
xlabel = ax.set_xlabel(r'$\lambda\ [\AA]$')
ylabel= ax.set_ylabel(r'$D(\lambda)\equiv F_{\rm dust}/F_{\rm no-dust}$')
ax.legend()
ax.grid()

In [None]:
sed_attenuated = dsps_flux_ratio * sed_info.rest_sed

In [None]:
FLAG_RESCALE_FORPLOT = False

In [None]:
# Compare the DSPS rest-frame SED, with and without dust, to the FORS2
# spectrum moved to the rest frame.
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.plot(ssp_data.ssp_wave, sed_info.rest_sed,'b:',lw=1,label="no dust")
__=ax.plot(ssp_data.ssp_wave, sed_attenuated,'r-',lw=1,label="with dust")
__=ax.set_ylim(1e-7,1e-5)
# raw strings: "\l" and "\A" are not valid Python escape sequences
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\nu(\lambda)$")

# rest frame: wavelengths divided by (1+z), fluxes multiplied by (1+z);
# optionally rescale the data by the per-spectrum normalisation factor
X = the_wl/(1+z_obs)
Y = the_fnu*(1+z)
if FLAG_RESCALE_FORPLOT:
    Y = Y*dict_normalisation_factor[selected_spectrum_number]

ax.plot(X,Y,'k-',lw=1,label=the_label_data)
ax.legend()

title = "Comparison of SED $L_\\nu$ with SFH and dust with " + title_data
ax.set_title(title)
ax.grid()

In [None]:
X = the_wl/(1+z_obs)
Y = the_fnu*(1+z)

# Start Optimisation

### Parameters

In [None]:
def paramslist_to_dict(params_list,param_names):
    """
    Convert a flat list of parameter values into an ordered dictionary.

    The (name, value) pairs are printed before being returned.

    :param params_list: parameter values (any sized sequence, e.g. a jax array)
    :param param_names: parameter names, in the same order as ``params_list``
    :return: OrderedDict mapping each name to its value
    :raises ValueError: if both sequences do not have the same length
        (zip would otherwise silently drop the extra entries)
    """
    if len(params_list) != len(param_names):
        raise ValueError(
            f"got {len(params_list)} parameter values for {len(param_names)} names")

    list_of_tuples = list(zip(param_names,params_list))
    print(list_of_tuples)
    return OrderedDict(list_of_tuples)
            

#### MAH parameters

In [None]:
DEFAULT_MAH_PARAMS

In [None]:
DEFAULT_MAH_PARAMS_MIN = DEFAULT_MAH_PARAMS + np.array([-3., -0.01, -1.5,-0.5])
DEFAULT_MAH_PARAMS_MAX = DEFAULT_MAH_PARAMS + np.array([2., +0.01, +1.5,+0.5])

In [None]:
mah_paramnames = ["MAH_lgmO","MAH_logtc","MAH_early_index","MAH_late_index"]

#### MS parameters

In [None]:
DEFAULT_MS_PARAMS

In [None]:
DEFAULT_MS_PARAMS_MIN = DEFAULT_MS_PARAMS - 0.25*np.ones((5)) 
DEFAULT_MS_PARAMS_MAX = DEFAULT_MS_PARAMS + 0.25*np.ones((5)) 

In [None]:
ms_paramnames = ["MS_lgmcrit", "MS_lgy_at_mcrit", "MS_indx_lo", "MS_indx_hi", "MS_tau_dep"]

#### Q parameters

In [None]:
DEFAULT_Q_PARAMS

In [None]:
DEFAULT_Q_PARAMS_MIN = DEFAULT_Q_PARAMS - 0.1*np.ones((4,))
DEFAULT_Q_PARAMS_MAX = DEFAULT_Q_PARAMS + 0.1*np.ones((4,))

In [None]:
q_paramnames = ["Q_lg_qt", "Q_qlglgdt", "Q_lg_drop", "Q_lg_rejuv"]

#### Dust parameters

In [None]:
Av= 1
uv_bump_ampl = 2.0
plaw_slope = -0.25

In [None]:
DEFAULT_DUST_PARAMS = [Av, uv_bump_ampl, plaw_slope]

In [None]:
DEFAULT_DUST_PARAMS_MIN = DEFAULT_DUST_PARAMS + np.array([-1.,-1.,-0.1])
DEFAULT_DUST_PARAMS_MAX = DEFAULT_DUST_PARAMS + np.array([2.,1.,0.25])

In [None]:
dust_paramnames = ["Av", "uv_bump", "plaw_slope"]

#### Combine parameters

In [None]:
defaults_params = [DEFAULT_MAH_PARAMS,DEFAULT_MS_PARAMS,DEFAULT_Q_PARAMS,DEFAULT_DUST_PARAMS]

params_min = np.concatenate(([DEFAULT_MAH_PARAMS_MIN,DEFAULT_MS_PARAMS_MIN,DEFAULT_Q_PARAMS_MIN,DEFAULT_DUST_PARAMS_MIN]))
params_max = np.concatenate(([DEFAULT_MAH_PARAMS_MAX,DEFAULT_MS_PARAMS_MAX,DEFAULT_Q_PARAMS_MAX,DEFAULT_DUST_PARAMS_MAX]))
init_params = np.concatenate(defaults_params)
init_params = jnp.array(init_params)

param_names = [mah_paramnames,ms_paramnames,q_paramnames,dust_paramnames]
param_scales = [3, 0.25,0.1,2.]

In [None]:
# Reference ("true") parameter values keyed by name: diffstar defaults for
# the MAH, MS and Q parameters and the dust defaults defined above.
# Each names list has the same length as its defaults vector.
dict_param_mah_true = OrderedDict(zip(mah_paramnames, DEFAULT_MAH_PARAMS))
dict_param_mah_true_selected = OrderedDict([(mah_paramnames[0], DEFAULT_MAH_PARAMS[0])])
dict_param_ms_true = OrderedDict(zip(ms_paramnames, DEFAULT_MS_PARAMS))
dict_param_q_true = OrderedDict(zip(q_paramnames, DEFAULT_Q_PARAMS))

dict_param_dust_true = OrderedDict(zip(dust_paramnames, DEFAULT_DUST_PARAMS))
dict_param_dust_true_selected = OrderedDict([(dust_paramnames[0], DEFAULT_DUST_PARAMS[0])])

In [None]:
dict_param_dust_true

In [None]:
# Merge all reference parameters into one ordered dictionary. Start from a
# copy: assigning dict_param_mah_true directly would make the update() calls
# below also add the MS, Q and dust entries to dict_param_mah_true.
params_true = OrderedDict(dict_param_mah_true)
params_true.update(dict_param_ms_true)
params_true.update(dict_param_q_true)
params_true.update(dict_param_dust_true)

In [None]:
params_true

In [None]:
param_names

In [None]:
param_names_flat = list(itertools.chain(*param_names))
param_names_flat

In [None]:
def mean_spectrum(wls, params):
    """ Model of the rest-frame spectrum of the galaxy.

    Computes the diffstar star formation history from the MAH, MS and Q
    parameters and sets it to zero from the age of the universe at ``z_obs``
    on. It then computes the DSPS rest-frame SED with a lognormal metallicity
    distribution (log10(Z) = -2, scatter 0.2) and applies the SBL18 dust
    attenuation. Finally it interpolates the result at ``wls``.

    :param wls: rest-frame wavelengths at which the model is evaluated, in the
        same units as ``ssp_data.ssp_wave``
    :type wls: array of float
    :param params: parameter values keyed by name (``MAH_*``, ``MS_*``,
        ``Q_*``, ``Av``, ``uv_bump`` and ``plaw_slope`` entries)
    :type params: dict

    :return: the attenuated model SED interpolated (cubic) at ``wls``
    :rtype: jax array

    Uses the notebook globals ``today_gyr``, ``z_obs`` and ``ssp_data``.
    """

    # decode the parameters, in the order expected by sfh_singlegal
    list_param_mah = [params[name] for name in
                      ("MAH_lgmO", "MAH_logtc", "MAH_early_index", "MAH_late_index")]
    list_param_ms = [params[name] for name in
                     ("MS_lgmcrit", "MS_lgy_at_mcrit", "MS_indx_lo", "MS_indx_hi", "MS_tau_dep")]
    list_param_q = [params[name] for name in
                    ("Q_lg_qt", "Q_qlglgdt", "Q_lg_drop", "Q_lg_rejuv")]
    Av = params["Av"]
    uv_bump = params["uv_bump"]
    plaw_slope = params["plaw_slope"]

    # compute SFR on a cosmic-time grid
    tarr = np.linspace(0.1, today_gyr, 100)
    sfh_gal = sfh_singlegal(tarr, list_param_mah, list_param_ms, list_param_q)

    # metallicity
    gal_lgmet = -2.0 # log10(Z)
    gal_lgmet_scatter = 0.2 # lognormal scatter in the metallicity distribution function

    # need age of universe when the light was emitted
    t_obs = age_at_z(z_obs, *DEFAULT_COSMOLOGY) # age of the universe in Gyr at z_obs
    t_obs = t_obs[0] # age_at_z function returns an array, but SED functions accept a float for this argument

    # clear sfh in future
    sfh_gal = jnp.where(tarr<t_obs, sfh_gal, 0)

    # compute SED
    sed_info = calc_rest_sed_sfh_table_lognormal_mdf(
        tarr, sfh_gal, gal_lgmet, gal_lgmet_scatter,
        ssp_data.ssp_lgmet, ssp_data.ssp_lg_age_gyr, ssp_data.ssp_flux, t_obs)

    # compute dust attenuation (the attenuation curve is evaluated in micron)
    wave_spec_micron = ssp_data.ssp_wave/10_000
    k = sbl18_k_lambda(wave_spec_micron,uv_bump,plaw_slope)
    dsps_flux_ratio = _frac_transmission_from_k_lambda(k,Av)
    sed_attenuated = dsps_flux_ratio * sed_info.rest_sed

    # interpolate with interpax, which is differentiable
    return interp1d(wls, ssp_data.ssp_wave, sed_attenuated,method='cubic')
    

In [None]:
def mean_sfr(params):
    """ Model of the star formation history (SFH) of the galaxy.

    Evaluates the diffstar SFH for the MAH, MS and Q parameters on a
    cosmic-time grid. The SFH is set to zero from the age of the universe at
    ``z_obs`` on.

    :param params: fitted parameter dictionary; only the ``MAH_*``, ``MS_*``
        and ``Q_*`` entries are used (dust entries, if present, are ignored)
    :type params: dict

    :return: ``(tarr, sfh_gal)``: the time grid in Gyr (100 points from 0.1
        to ``today_gyr``) and the star formation rate on that grid
    :rtype: tuple

    Uses the notebook globals ``today_gyr`` and ``z_obs``.
    """

    # decode the parameters, in the order expected by sfh_singlegal
    list_param_mah = [params[name] for name in
                      ("MAH_lgmO", "MAH_logtc", "MAH_early_index", "MAH_late_index")]
    list_param_ms = [params[name] for name in
                     ("MS_lgmcrit", "MS_lgy_at_mcrit", "MS_indx_lo", "MS_indx_hi", "MS_tau_dep")]
    list_param_q = [params[name] for name in
                    ("Q_lg_qt", "Q_qlglgdt", "Q_lg_drop", "Q_lg_rejuv")]

    # compute SFR
    tarr = np.linspace(0.1, today_gyr, 100)
    sfh_gal = sfh_singlegal(tarr, list_param_mah, list_param_ms, list_param_q)

    # need age of universe when the light was emitted
    t_obs = age_at_z(z_obs, *DEFAULT_COSMOLOGY) # age of the universe in Gyr at z_obs
    t_obs = t_obs[0] # age_at_z function returns an array, but SED functions accept a float for this argument

    # clear sfh in future
    sfh_gal = jnp.where(tarr<t_obs, sfh_gal, 0)

    return tarr,sfh_gal
    

In [None]:
def lik(p, wls, F, sigma_obs=background):
    """
    Negative log-likelihood (Gaussian, up to a constant) of the data given p.

    :param p: flat parameter vector: 4 MAH, 5 MS, 4 Q and 3 dust parameters
    :param wls: rest-frame wavelengths of the data
    :param F: observed fluxes at ``wls``
    :param sigma_obs: flux uncertainty; defaults to the value of ``background``
        when the function is defined
    :return: 0.5 * sum(((mean_spectrum(wls, params) - F) / sigma_obs) ** 2)
    """
    names = ("MAH_lgmO", "MAH_logtc", "MAH_early_index", "MAH_late_index",
             "MS_lgmcrit", "MS_lgy_at_mcrit", "MS_indx_lo", "MS_indx_hi", "MS_tau_dep",
             "Q_lg_qt", "Q_qlglgdt", "Q_lg_drop", "Q_lg_rejuv",
             "Av", "uv_bump", "plaw_slope")
    params = {name: p[i] for i, name in enumerate(names)}

    chi = (mean_spectrum(wls, params) - F) / sigma_obs
    return 0.5 * jnp.sum(chi ** 2)


In [None]:
def get_infos(res, model, wls,F):
    """Summarise an optimisation result at its optimum.

    :param res: solver result; only ``res.params`` is used
    :param model: objective ``model(params, wls, F)`` that was minimised
    :param wls: first data argument passed to ``model``
    :param F: second data argument passed to ``model``
    :return: ``(params, model value, jacobian of model w.r.t. params,
        inverse Hessian)``; the inverse Hessian is used as the covariance
        matrix of the parameters
    """
    best = res.params
    data = (wls, F)
    value = model(best, *data)
    gradient = jax.jacfwd(model)(best, *data)
    covariance = jax.scipy.linalg.inv(jax.hessian(model)(best, *data))
    return best, value, gradient, covariance


In [None]:
lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik, method="L-BFGS-B")

res = lbfgsb.run(init_params, bounds=(params_min ,params_max ), wls=X, F=Y)
params,fun_min,jacob_min,inv_hessian_min = get_infos(res, lik, wls=X, F=Y)
print("params:",params,"\nfun@min:",fun_min,"\njacob@min:",jacob_min,
     "\n invH@min:",inv_hessian_min)


In [None]:
len(params)

In [None]:
dict_params_fitted = paramslist_to_dict(params,param_names_flat)

In [None]:
dict_params_fitted

In [None]:
dict_params_fitted_nodust = copy.deepcopy(dict_params_fitted)

In [None]:
dict_params_fitted_nodust["Av"] = 0

In [None]:
Y_fit = mean_spectrum(ssp_data.ssp_wave, dict_params_fitted)
Y_fit_nodust = mean_spectrum(ssp_data.ssp_wave, dict_params_fitted_nodust)

In [None]:
# Fitted rest-frame SED (L_nu) with dust and with Av set to 0, compared to the data.
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.set_xlim(1e3,1e6)
__=ax.set_ylim(1e-10,1e-6)
# raw strings: "\l" is not a valid Python escape sequence
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\nu(\lambda)$")

ax.plot(ssp_data.ssp_wave,Y_fit,'-',color='green',lw=1,label="fitted model with dust")
ax.plot(ssp_data.ssp_wave,Y_fit_nodust,'-',color='red',lw=1,label="model No dust")
ax.plot(X,Y,'b-',lw=3,label=the_label_data)


title = "Comparison of SED $L_\\nu$ with SFH and dust with " + title_data
ax.set_title(title)
ax.legend()
ax.grid()

In [None]:
YL_fit = Y_fit*3e18/(ssp_data.ssp_wave)**2
YL_fit_nodust = Y_fit_nodust*3e18/(ssp_data.ssp_wave)**2
YL = Y*3e18/X**2

In [None]:
# Fitted rest-frame SED converted to L_lambda, with dust and with Av set to 0,
# compared to the data.
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.set_xlim(1e3,1e6)
__=ax.set_ylim(1e0,1e5)
# raw strings: "\l" is not a valid Python escape sequence
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\lambda(\lambda)$")

# YL_fit is the fitted model with dust, YL_fit_nodust the same model with Av = 0
ax.plot(ssp_data.ssp_wave,YL_fit,'-',color='green',lw=1,label="fitted model with dust")
ax.plot(ssp_data.ssp_wave,YL_fit_nodust,'-',color='red',lw=1,label="model No dust")
ax.plot(X,YL,'b-',lw=3,label=the_label_data)


title = "Comparison of SED $L_\\lambda$ with SFH and dust with " + title_data
ax.set_title(title)
ax.legend()
ax.grid()

In [None]:
# Same L_lambda comparison, zoomed on 1e3-1e5 Angstrom.
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.set_xlim(1e3,1e5)
__=ax.set_ylim(1e1,1e5)
# raw strings: "\l" is not a valid Python escape sequence
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\lambda(\lambda)$")

# YL_fit is the fitted model with dust, YL_fit_nodust the same model with Av = 0
ax.plot(ssp_data.ssp_wave,YL_fit,'-',color='green',lw=1,label="fitted model with dust")
ax.plot(ssp_data.ssp_wave,YL_fit_nodust,'-',color='red',lw=1,label="model No dust")
ax.plot(X,YL,'b-',lw=3,label=the_label_data)


title = "Comparison of SED $L_\\lambda$ with SFH and dust with " + title_data
ax.set_title(title)
ax.legend()
ax.grid()

## Check the SFH

In [None]:
tarr_fit,sfr_fit = mean_sfr(dict_params_fitted)

In [None]:
sfr_max = sfr_fit.max()*10.
sfr_min = sfr_max/1e4
fig, ax = plt.subplots(1, 1)
ylim = ax.set_ylim(sfr_min, sfr_max)
yscale = ax.set_yscale('log')

__=ax.plot(tarr_fit, sfr_fit, '--', color='k',label='sfh_gal')
ax.set_title("Fitted Star Formation History (SFH) " + title_data)
xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')
ax.grid()

In [None]:
sfr_max = sfr_fit.max()*1.2
fig, ax = plt.subplots(1, 1)
ylim = ax.set_ylim(0, sfr_max)
#yscale = ax.set_yscale('log')

__=ax.plot(tarr_fit, sfr_fit, '--', color='k',label='sfh_gal')
ax.set_title("Fitted Star Formation History (SFH) " + title_data)
xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')
ax.grid()