# Fit DSPS SED models with Fors2 Spectra with JAXOPT and Photometry all combined


- author Sylvie Dagoret-Campagne
- affiliation : IJCLab/IN2P3/CNRS
- CC: kernel conda_jax0325_py310
- creation date : 2023-11-16
- update : 2023-11-16 : add interpax interpolation and fitted sfr, and filter object and pytree
- update : 2023-11-18 : really finish combined fit
- update : 2023-11-21 : correct bug on UV_bump


| computer | processor | kernel              |    date     |
| --- | --- | --- | --- |
| CC       | CPU       | conda_jax0325_py310 | 2023-11-10  |
| macbookpro | CPU | conda_jaxcpu_dsps_py310 | 2023-11-10  | 


libraries 
=========

jax
---

- jaxlib-0.3.25+cuda11.cudnn82
- jaxopt
- optax
- corner
- arviz
- numpyro
- interpax

sps
---

- fsps
- prospect
- dsps
- diffstar
- diffmah
- sedpy or astro-sedpy

plot
----

- matplotlib
- seaborn

(conda_jax0325_py310) 
`/pbs/throng/lsst/users/dagoret/desc/StellarPopulationSynthesis>pip list | grep` 

| lib | version |
|--- | --- | 
|jax  |                         0.4.20 |
|jaxlib |                       0.4.20 |
|jaxopt  |                      0.8.2 |



https://github.com/sylvielsstfr/Fors2ToStellarPopSynthesis/blob/main/examples/examples_jaxtutos/jaxtuto_jec2022/JAX-Optim-regression-piecewise.ipynb

In [None]:
import jax
import jax.numpy as jnp
jax.devices()

In [None]:
from sedpy import observate

In [None]:

import numpy as np
import scipy as sc

import jax
import jax.numpy as jnp
import jax.scipy as jsc

from jax import grad, jit, vmap
from jax import jacfwd, jacrev, hessian

import jaxopt
import optax

jax.config.update("jax_enable_x64", True)
import corner
import arviz as az

import copy
from interpax import interp1d

In [None]:
import itertools

In [None]:
import h5py
import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.colors as colors
import matplotlib.cm as cmx
from collections import OrderedDict
import re
import matplotlib.gridspec as gridspec
from sklearn.gaussian_process import GaussianProcessRegressor, kernels

In [None]:
plt.rcParams["figure.figsize"] = (12,6)
plt.rcParams["axes.labelsize"] = 'xx-large'
plt.rcParams['axes.titlesize'] = 'xx-large'
plt.rcParams['xtick.labelsize']= 'xx-large'
plt.rcParams['ytick.labelsize']= 'xx-large'

In [None]:
from diffstar.defaults import DEFAULT_MAH_PARAMS
from diffstar.defaults import DEFAULT_MS_PARAMS
from diffstar.defaults import DEFAULT_Q_PARAMS

In [None]:
from diffstar import sfh_singlegal
from dsps.cosmology import age_at_z, DEFAULT_COSMOLOGY

In [None]:
from dsps import load_ssp_templates

In [None]:
from dsps import calc_rest_sed_sfh_table_lognormal_mdf
from dsps import calc_rest_sed_sfh_table_met_table

In [None]:
from dsps import calc_rest_mag, calc_obs_mag

In [None]:
from dsps.dust.att_curves import  sbl18_k_lambda, RV_C00,_frac_transmission_from_k_lambda

### Very important: add the lib directory to the path in order to load the data

In [None]:
import sys
sys.path.append("../../lib")

In [None]:
from fit_params_fors2 import U_FNU,U_FL,ConvertFlambda_to_Fnu,flux_norm,ordered_keys,Fors2DataAcess

In [None]:
lambda_red = 6231
lambda_width = 50
lambda_sel_min = lambda_red-lambda_width /2.
lambda_sel_max = lambda_red+lambda_width /2.

# Filters

### Build Tables for filters

In [None]:
class FilterInfo():
    """Registry of the photometric filters used in the fit.

    The constructor loads, through ``observate.Filter``, the filters of three
    surveys (2 GALEX, 4 SDSS, 5 VISTA/VIRCAM) and stores them both per survey
    (``filters_<survey>``, ``all_filt_<survey>``, ``N_<survey>``,
    ``all_colors_<survey>``) and in flat, index-aligned lists
    (``filters_indexlist``, ``filters_surveylist``, ``filters_namelist``,
    ``filters_transmissionlist``, ``filters_transmissionnormlist``,
    ``filters_colorlist``).

    Each filter has a normalisation divisor applied to its transmission curve
    by the accessors: 100 for the GALEX filters and the first VIRCAM filter,
    1 for all others.
    NOTE(review): presumably those curves are stored in percent - TODO confirm.
    """

    def __init__(self):
        self.filters_galex = np.array(["galex_FUV","galex_NUV"])
        self.filters_sdss = np.array(["sdss_u0","sdss_g0","sdss_r0","sdss_i0"])
        self.filters_vircam = np.array(["vista_vircam_Z","vista_vircam_Y","vista_vircam_J","vista_vircam_H","vista_vircam_Ks"])

        # Galex filters and their plotting colors
        self.all_filt_galex = self._load_filters(self.filters_galex)
        self.N_galex = len(self.all_filt_galex)
        self.all_colors_galex = self._make_colors(mpl.cm.PuBu, self.N_galex)

        # SDSS filters (for KIDS survey) and their plotting colors
        self.all_filt_sdss = self._load_filters(self.filters_sdss)
        self.N_sdss = len(self.all_filt_sdss)
        self.all_colors_sdss = self._make_colors(mpl.cm.Reds, self.N_sdss)

        # VIRCAM filters and their plotting colors
        self.all_filt_vircam = self._load_filters(self.filters_vircam)
        self.N_vircam = len(self.all_filt_vircam)
        self.all_colors_vircam = self._make_colors(mpl.cm.Wistia, self.N_vircam)

        # Flat tables: entry i of every list describes the same filter.
        self.filters_indexlist = []
        self.filters_surveylist = []
        self.filters_namelist = []
        self.filters_transmissionlist = []
        self.filters_transmissionnormlist = []
        self.filters_colorlist = []

        self._append_survey("galex", self.filters_galex, self.all_filt_galex,
                            self.all_colors_galex, [100.0] * self.N_galex)
        self._append_survey("sdss", self.filters_sdss, self.all_filt_sdss,
                            self.all_colors_sdss, [1.0] * self.N_sdss)
        # only the first VIRCAM filter is divided by 100
        vircam_norms = [100.0 if index == 0 else 1.0 for index in range(self.N_vircam)]
        self._append_survey("vircam", self.filters_vircam, self.all_filt_vircam,
                            self.all_colors_vircam, vircam_norms)

    @staticmethod
    def _load_filters(filter_names):
        """Return the list of ``observate.Filter`` objects built from the given names."""
        return [observate.Filter(filtname) for filtname in filter_names]

    @staticmethod
    def _make_colors(cmap, nfilters):
        """Return ``nfilters + 1`` RGBA colors sampled from ``cmap`` (normalised on [0, nfilters])."""
        cNorm = colors.Normalize(vmin=0, vmax=nfilters)
        scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cmap)
        return scalarMap.to_rgba(np.arange(nfilters + 1), alpha=1)

    def _append_survey(self, survey, names, filts, palette, norms):
        """Append the filters of one survey to the flat per-filter lists."""
        for index, (name, filt, norm) in enumerate(zip(names, filts, norms)):
            # the global index is the running count of filters registered so far
            self.filters_indexlist.append(len(self.filters_indexlist))
            self.filters_surveylist.append(survey)
            self.filters_namelist.append(name)
            self.filters_transmissionlist.append(filt)
            self.filters_transmissionnormlist.append(norm)
            # palette entry 0 is skipped, as in the original color assignment
            self.filters_colorlist.append(palette[index + 1])

    def _normalized_transmission(self, index):
        """Return the transmission curve of filter ``index`` divided by its normalisation."""
        the_filt = self.filters_transmissionlist[index]
        return the_filt.transmission / self.filters_transmissionnormlist[index]

    def get_pytree(self):
        """Return the filters as a nested dictionary.

        :return: ``{index: {"name", "wlmean", "wls", "transm"}}`` where ``wls``
                 and ``transm`` (normalised transmission) are ``jnp`` arrays
        :rtype: dict
        """
        the_dict = {}

        for index in self.filters_indexlist:
            the_filt = self.filters_transmissionlist[index]
            the_dict[index] = {
                "name": self.filters_namelist[index],
                "wlmean": the_filt.wave_mean,
                "wls": jnp.array(the_filt.wavelength),
                "transm": jnp.array(self._normalized_transmission(index)),
            }

        return the_dict

    def get_2lists(self):
        """Return two index-aligned lists: wavelength arrays and normalised transmissions.

        :return: ``(list_of_wavelength_arrays, list_of_transmission_arrays)``
        :rtype: tuple of two lists
        """
        the_list1 = [self.filters_transmissionlist[index].wavelength
                     for index in self.filters_indexlist]
        the_list2 = [self._normalized_transmission(index)
                     for index in self.filters_indexlist]
        return the_list1,the_list2

    def get_3lists(self):
        """Return three index-aligned lists: wavelengths, normalised transmissions, names.

        :return: ``(list_of_wavelength_arrays, list_of_transmission_arrays, list_of_names)``
        :rtype: tuple of three lists
        """
        the_list1, the_list2 = self.get_2lists()
        the_list3 = [self.filters_namelist[index] for index in self.filters_indexlist]
        return the_list1,the_list2,the_list3

    def plot_transmissions(self,ax = None):
        """Plot every normalised transmission curve, labelled with the filter name.

        :param ax: matplotlib axes to draw on; a new 12x6 figure is created when None
        """
        # identity test: "ax == None" would be evaluated element-wise on array-like inputs
        if ax is None:
            fig,ax = plt.subplots(1,1,figsize=(12,6))

        for index in self.filters_indexlist:
            the_name = self.filters_namelist[index]
            the_filt = self.filters_transmissionlist[index]
            the_color = self.filters_colorlist[index]
            ax.plot(the_filt.wavelength,self._normalized_transmission(index),color=the_color)

            # alternate the label height between even and odd filter indexes
            text_height = 0.7 if index%2 == 0 else 0.75
            ax.text(the_filt.wave_mean, text_height, the_name,horizontalalignment='center',verticalalignment='center',color=the_color,fontweight="bold")

        ax.grid()
        ax.set_title("Transmission")
        # raw string: "\l" and "\A" are not valid escape sequences in a normal string
        ax.set_xlabel(r"$\lambda (\AA)$")
        ax.set_xlim(0.,25000.)

    def dump(self):
        """Print the index, survey and name lists."""
        print("filters_indexlist   : \t ", self.filters_indexlist)
        print("filters_surveylist  : \t ", self.filters_surveylist)
        print("filters__namelist   : \t ", self.filters_namelist)
                

In [None]:
ps = FilterInfo()

In [None]:
ps.plot_transmissions()
ps.dump()

# transform the FilterInfo object into a pytree

In [None]:
pt_filters = ps.get_pytree()

In [None]:
# jax.tree_leaves is a deprecated top-level alias (removed in later JAX releases);
# the supported entry point is jax.tree_util.tree_leaves.
leaves = jax.tree_util.tree_leaves(pt_filters)

In [None]:
#leaves

# Read Fors2 / Galex and Kids

In [None]:
input_file_h5  = '../../data/fors2sl/FORS2spectraGalexKidsPhotom.hdf5'

In [None]:
fors2 = Fors2DataAcess(input_file_h5)

In [None]:
list_of_keys = fors2.get_list_of_groupkeys()
list_of_attributes = fors2.get_list_subgroup_keys()

## Must sort spectra name

In [None]:
list_of_keys = np.array(list_of_keys)

In [None]:
list_of_keysnum = [ int(re.findall("SPEC(.*)",specname)[0]) for specname in  list_of_keys ]

In [None]:
sorted_indexes = np.argsort(list_of_keysnum)

In [None]:
list_of_keys = list_of_keys[sorted_indexes]

In [None]:
df_info = pd.DataFrame(columns=list_of_attributes)
all_df = []

### Read each spectrum Fors2 as (wl,fnu)

In [None]:
for idx,key in enumerate(list_of_keys):
    attrs = fors2.getattribdata_fromgroup(key)
    spectr = fors2.getspectrum_fromgroup(key)
    df_info.loc[idx] = [*attrs.values()] # hope the order of attributes is kept
    df = pd.DataFrame({"wl":spectr["wl"],"fnu":spectr["fnu"]})
    all_df.append(df)
    

In [None]:
df_info.reset_index(drop=True, inplace=True) 

In [None]:
df_info = df_info[ordered_keys]

In [None]:
df_info

# Select good match with galex

In [None]:
df_info.hist("asep_galex",bins=100,color="b")
plt.axvline(5,c="k")

In [None]:
df_info.hist("asep_kids",bins=100,color='b')

## Select  Those spectra having GALEX

In [None]:
df = df_info[df_info["asep_galex"] <= 5]

In [None]:
df.index

## Remove NaN

- remove those rows with no FUV

In [None]:
#df = df.dropna()

## Remove rows with Rmag = 0

In [None]:
#df = df[df["Rmag"] > 0]

# Plot Spectra

https://en.wikipedia.org/wiki/Photometric_system

In [None]:
lambda_FUV = 1528.
lambda_NUV = 2271.
lambda_U = 3650.
lambda_B = 4450.
lambda_G = 4640.
lambda_R = 5580.
lambda_I = 8060.
lambda_Z = 9000.
lambda_Y = 10200.
lambda_J = 12200.
lambda_H = 16300.
lambda_K = 21900.
lambda_L = 34500.

WL = [lambda_FUV, lambda_NUV, lambda_B, lambda_G, lambda_R ,lambda_I, lambda_Z, lambda_Y, lambda_J, lambda_H, lambda_K ]
FilterTag = ['FUV','NUV','B','G','R','I','Z','Y','J','H','Ks']

In [None]:
def PlotFilterTag(ax,fluxlist):
    """Write the filter tag next to each finite photometric point.

    :param ax: matplotlib axes on which the tags are written
    :param fluxlist: array of fluxes, index-aligned with the module-level
                     lists ``WL`` (x positions) and ``FilterTag`` (labels)

    Each tag is placed one fifth of the mean finite flux below its point,
    or above it when that position would be negative.
    """
    finite_fluxes = fluxlist[np.isfinite(fluxlist)]
    offset = np.mean(finite_fluxes)/5

    for position, flux in enumerate(fluxlist):
        # points without a valid flux get no tag
        if not np.isfinite(flux):
            continue

        ypos = flux - offset
        if ypos < 0:
            # would fall below zero: move the tag above the point instead
            ypos = ypos + 2*offset
        ax.text(WL[position],ypos, FilterTag[position],fontsize=12,color="g",weight='bold',ha='center', va='bottom')
            

In [None]:
df

In [None]:
# Keep the original df_info row label in "index0" (it is used later to look up
# the spectrum in all_df), then renumber the rows.
# assign() builds a new frame instead of writing a column into a filtered
# slice of df_info, which avoids pandas' SettingWithCopyWarning.
df = df.assign(index0=df.index).reset_index()

In [None]:
df['index0']

In [None]:
df.columns

# Plots

## Fit Gaussian process to remove abs lines

In [None]:
# Gaussian process used to model the smooth continuum of a spectrum.
# RBF kernel: first argument is the initial length scale, second the
# (min, max) bounds of the length scale explored during the fit.
# NOTE(review): the initial length scale (0.5) lies outside the bounds
# (8000, 10000) - confirm this is intended.
kernel = kernels.RBF(0.5, (8000, 10000.0))
# random_state fixed for reproducibility
gp = GaussianProcessRegressor(kernel=kernel ,random_state=0)

### Select Spectrum

In [None]:
#selected_spectrum_number = 411
selected_spectrum_number = 560
selected_spectrum_tag = f"SPEC{selected_spectrum_number}"
dict_normalisation_factor = {}
dict_normalisation_factor[411] = 160.
dict_normalisation_factor[560] = 12.

In [None]:
# Loop on the rows of the selection; only the row whose "num" equals
# selected_spectrum_number is processed. Variables set inside the loop
# (mags, magserr, z, background, the_wl, the_fnu, the_label_data,
# title_data, ...) are reused by later cells.
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
for idx,row in df.iterrows():

    print(idx," == ",row["name"],row["num"],row["index0"])
    idx0 = row["index0"] # index for the spectrum
    specname  = row["name"]
    specnum = row["num"]
    # NOTE: skip/stop logic relies on rows being ordered by increasing spectrum number
    if specnum < selected_spectrum_number:
        continue
    if specnum > selected_spectrum_number:
        break

    # photometry in the order FUV, NUV, u, g, r, i, Z, Y, J, H, Ks
    mags = np.array([ row["fuv_mag"], row["nuv_mag"], row['MAG_GAAP_u'], row['MAG_GAAP_g'], row['MAG_GAAP_r'], row['MAG_GAAP_i'], row['MAG_GAAP_Z'], row['MAG_GAAP_Y'],
            row['MAG_GAAP_J'], row['MAG_GAAP_H'],row['MAG_GAAP_Ks'] ])

    magserr = np.array([ row["fuv_magerr"], row["nuv_magerr"], row['MAGERR_GAAP_u'], row['MAGERR_GAAP_g'], row['MAGERR_GAAP_r'], row['MAGERR_GAAP_i'], row['MAGERR_GAAP_Z'], row['MAGERR_GAAP_Y'],
            row['MAGERR_GAAP_J'], row['MAGERR_GAAP_H'],row['MAGERR_GAAP_Ks'] ])

    # magnitudes -> fluxes (10**(-0.4 m)) and first-order error propagation
    mfluxes = 10**(-0.4*mags)
    mfluxeserr = 0.4*np.log(10)*magserr*mfluxes

    fluxes =  [ row["fuv_flux"], row["nuv_flux"], row['FLUX_GAAP_u'], row['FLUX_GAAP_g'], row['FLUX_GAAP_r'], row['FLUX_GAAP_i'], row['FLUX_GAAP_Z'], row['FLUX_GAAP_Y'],
            row['FLUX_GAAP_J'], row['FLUX_GAAP_H'],row['FLUX_GAAP_Ks'] ]

    # fixed: the Y-band entry used FLUX_GAAP_Y (the flux) instead of its error
    fluxeserr =  [ row["fuv_fluxerr"], row["nuv_fluxerr"], row['FLUXERR_GAAP_u'], row['FLUXERR_GAAP_g'], row['FLUXERR_GAAP_r'], row['FLUXERR_GAAP_i'], row['FLUXERR_GAAP_Z'], row['FLUXERR_GAAP_Y'],
            row['FLUXERR_GAAP_J'], row['FLUXERR_GAAP_H'],row['FLUXERR_GAAP_Ks'] ]

    ##############################################################
    # The top image
    ######################################
    fig = plt.figure(figsize=(10,14),tight_layout=True)
    gs = gridspec.GridSpec(3, 1,height_ratios=[3, 1, 3])


    #top image : original flux
    ax= fig.add_subplot(gs[0, 0])

    z = row["redshift"]
    asep_fg = row['asep_galex']
    rmag = row["Rmag"]
    asep_fk = row['asep_kids']


    the_lines = row['lines']
    print("the_lines = ",the_lines)
    all_elements  = the_lines.decode().split(",")


    the_label_data = f" FORS2 : z = {z:.3f} , Rmag = {rmag:.1f} mag,  angular sep (arcsec) f-g : {asep_fg:.3f}, f-k : {asep_fk:.3f}"
    the_wl = all_df[idx0]["wl"].values
    # rescale the spectrum with the flux corresponding to mags[4] (MAG_GAAP_r)
    the_fnu = all_df[idx0]["fnu"].values*10**(-0.4*mags[4])
    ax.plot(the_wl,the_fnu,'b-',label=the_label_data)

    X = the_wl
    Y = the_fnu

    gp.fit(X[:, None], Y)
    xfit = np.linspace(X.min(),X.max())
    yfit, yfit_err = gp.predict(xfit[:, None], return_std=True)
    ax.plot(xfit, yfit, '-', color='cyan')
    ax.fill_between(xfit, yfit -  yfit_err, yfit +  yfit_err, color='gray', alpha=0.3)

    ax2 = ax.twinx()
    ax2.errorbar(WL,mfluxes,yerr=mfluxeserr,fmt='o',color="r",ecolor="r",ms=5,label='Galex (UV) + Kids (optics) +Viking (IR)')
    PlotFilterTag(ax2,mfluxes)
    #ax2.errorbar(WL, fluxes, yerr=fluxes, xerr=None, fmt='o', color="g",ecolor="g")
    ax2.legend(loc="lower right")

    # raw strings: "\l" and "\A" are not valid escape sequences
    ax.set_xlabel(r"$\lambda  (\AA)$ ")
    title_data = f"FORS2 : {idx}): {idx0} name = {specname}"
    ax.set_title(title_data)
    ax.set_ylabel("flux (maggies)")
    ax.legend(loc="upper right")
    ax.grid()

    the_max1 = np.max(the_fnu)
    goodmags = mfluxes[np.isfinite(mfluxes)]
    the_max2 = np.max(goodmags)
    the_max = 1.1*max(the_max1,the_max2)

    ax.set_ylim(0,the_max)
    ax2.set_ylim(0,the_max)

    ax.set_xlim(1000.,25000.)
    ##################
    # Middle image : Fit a gaussian process and compute the residuals
    ##################
    ax3 = fig.add_subplot(gs[1, 0])
    # fixed: predict(return_std=True) returns a (mean, std) tuple; subtracting
    # the tuple from Y only gave the right residual through broadcasting and
    # left DeltaEY equal to Y - std instead of the predicted std.
    Ypred, DeltaEY = gp.predict(X[:, None], return_std=True)
    DeltaY = Y - Ypred
    ax3.plot(X,DeltaY,'b')
    ax3.set_xlim(1000.,25000.)
    ax3.grid()
    ax3.set_xlabel(r"$\lambda  (\AA)$ ")

    # robust residual scale; points beyond 8 times this scale are flagged as lines
    background = np.sqrt(np.median(DeltaY**2))
    indexes_toremove = np.where(np.abs(DeltaY)> 8 * background)[0]

    for index in indexes_toremove:
        ax3.axvline(X[index],color='k')

    #########################
    # Bottom image : resulting image after removing emission lines
    #########################

    ax4 = fig.add_subplot(gs[2, 0])

    Xclean = np.delete(X,indexes_toremove)
    Yclean  = np.delete(Y,indexes_toremove)

    ax4.plot(Xclean, Yclean,'b-',label=the_label_data)
    ax4.set_ylim(0,the_max)
    ax4.set_xlim(1000.,25000.)
    ax4.grid()
    ax4.set_xlabel(r"$\lambda  (\AA)$ ")

    ax5 = ax4.twinx()
    ax5.errorbar(WL,mfluxes,yerr=mfluxeserr,fmt='o',color="r",ecolor="r",ms=5,label='Galex (UV) + Kids (optics) +Viking (IR)')
    # the original spectrum is drawn before legend() so that its label is listed
    ax5.plot(the_wl,the_fnu,color="grey",lw=0.25,label="original")
    ax5.legend(loc="lower right")
    ax5.set_ylim(0,the_max)
    PlotFilterTag(ax5,mfluxes)

    ax4.set_ylabel("flux (maggies)")
    ax4.legend(loc="upper right")

    # result is not used below; call kept as in the original notebook
    spectr = fors2.getspectrumcleanedemissionlines_fromgroup(specname,gp)


    plt.show()

In [None]:
mags

In [None]:
magserr 

In [None]:
background 

In [None]:
z

# Model of a galaxy

### Load SED templates

In [None]:

#ssp_data = load_ssp_templates(fn='tempdata_v2.h5')
ssp_data = load_ssp_templates(fn='../examples_dsps_diffstar_diffmah/tempdata.h5')

print(ssp_data._fields)

print('ssp_data : ssp_lgmet.shape = {}'.format(ssp_data.ssp_lgmet.shape))
print('ssp_data : ssp_lg_age_gyr.shape = {}'.format(ssp_data.ssp_lg_age_gyr.shape))
print('ssp_data : ssp_wave.shape = {}'.format(ssp_data.ssp_wave.shape))
print('ssp_data :ssp_flux.shape = {}'.format(ssp_data.ssp_flux.shape))

### calculate age distribution

In [None]:
today_gyr = 13.8 
tarr = np.linspace(0.1, today_gyr, 100)

In [None]:
sfh_gal = sfh_singlegal(tarr, DEFAULT_MAH_PARAMS, DEFAULT_MS_PARAMS, DEFAULT_Q_PARAMS)

In [None]:
fig, ax = plt.subplots(1, 1)
ylim = ax.set_ylim(1e-3, 50)
yscale = ax.set_yscale('log')

__=ax.plot(tarr, sfh_gal, '--', color='k',label='sfh_gal')
ax.set_title("Simulated Star Formation History (SFH)")
xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')
ax.grid()

In [None]:
fig, ax = plt.subplots(1, 1)
ylim = ax.set_ylim(0, 1)
#yscale = ax.set_yscale('log')

__=ax.plot(tarr, sfh_gal, '--', color='k',label='sfh_gal')
ax.set_title("Simulated Star Formation History (SFH)")
xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')
ax.grid()

In [None]:
gal_t_table = tarr

In [None]:
z_obs = z
t_obs = age_at_z(z_obs, *DEFAULT_COSMOLOGY) # age of the universe in Gyr at z_obs
t_obs = t_obs[0] # age_at_z function returns an array, but SED functions accept a float for this argument

In [None]:
sfh_gal = jnp.where(tarr<t_obs, sfh_gal, 0)

In [None]:
fig, ax = plt.subplots(1, 1)
ylim = ax.set_ylim(0, 1)
#yscale = ax.set_yscale('log')

__=ax.plot(tarr, sfh_gal, '--', color='k',label='sfh_gal')
ax.set_title("Simulated Star Formation History (SFH)")
xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')
ax.grid()

In [None]:
gal_sfr_table = sfh_gal
# metallicity
gal_lgmet = -2.0 # log10(Z)
gal_lgmet_scatter = 0.2 # lognormal scatter in the metallicity distribution function


In [None]:
# calculate first SED with only one metallicity
print("gal_lgmet = ", gal_lgmet)

sed_info = calc_rest_sed_sfh_table_lognormal_mdf(
    gal_t_table, gal_sfr_table, gal_lgmet, gal_lgmet_scatter,
    ssp_data.ssp_lgmet, ssp_data.ssp_lg_age_gyr, ssp_data.ssp_flux, t_obs)

In [None]:
Av= 1.0
uv_bump_ampl = 3.0
plaw_slope = -0.25
wave_spec_micron = ssp_data.ssp_wave/10_000

In [None]:
k = sbl18_k_lambda(wave_spec_micron,uv_bump_ampl,plaw_slope)
dsps_flux_ratio = _frac_transmission_from_k_lambda(k,Av)

In [None]:
# Plot the dust attenuation ratio F_dust / F_no-dust versus wavelength.
fig, ax = plt.subplots(1, 1, figsize=(12, 5))
# fig.tight_layout(pad=3.0)

# raw f-string: "\d" in "$\delta$" is not a valid escape sequence
label_dust = rf" Av = {Av}, $\delta$ = {plaw_slope}, uv-bump = {uv_bump_ampl}"
ax.plot(ssp_data.ssp_wave ,dsps_flux_ratio,'b-',label=label_dust)

xlim = ax.set_xlim(900, 1e4)
ylim = ax.set_ylim(1e-4, 2)
__=ax.loglog()

title = ax.set_title(r'${\rm attenuation\ curve\ validation}$')
xlabel = ax.set_xlabel(r'$\lambda\ [\AA]$')
ylabel= ax.set_ylabel(r'$D(\lambda)\equiv F_{\rm dust}/F_{\rm no-dust}$')
ax.legend()
ax.grid()

In [None]:
sed_attenuated = dsps_flux_ratio * sed_info.rest_sed

In [None]:
FLAG_RESCALE_FORPLOT = False

In [None]:
# Compare the model SED (with and without dust) with the FORS2 spectrum
# brought back to the rest frame. X and Y are reused by the next cell.
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.plot(ssp_data.ssp_wave, sed_info.rest_sed,'b:',lw=1,label="no dust")
__=ax.plot(ssp_data.ssp_wave, sed_attenuated,'r-',lw=1,label="with dust")
__=ax.set_ylim(1e-8,1e-5)
# raw strings: "\l" is not a valid escape sequence
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\nu(\lambda)$")

# observed -> rest frame; z_obs is used for both axes (the original mixed z and z_obs)
X = the_wl/(1+z_obs)
Y = the_fnu*(1+z_obs)
if FLAG_RESCALE_FORPLOT :
    # optional per-spectrum rescaling, for display purposes only
    Y = Y*dict_normalisation_factor[selected_spectrum_number]

ax.plot(X,Y,'k-',lw=1,label=the_label_data)
ax.legend()

title = r"Comparison of SED $L_\nu$ with SFH and dust with " + title_data
ax.set_title(title)
ax.grid()

### Extract quantities for the experimental spectra in Fnu

In [None]:
Xspec_data = copy.deepcopy(X)
Yspec_data = copy.deepcopy(Y)
EYspec_data = background

### Check the calculation of magnitudes inside the filters

In [None]:
# Print the observed magnitude of the attenuated model SED in every filter,
# survey by survey (Galex, then SDSS, then VIRCAM).
surveys_to_check = (
    (ps.filters_galex, ps.all_filt_galex),
    (ps.filters_sdss, ps.all_filt_sdss),
    (ps.filters_vircam, ps.all_filt_vircam),
)

for survey_names, survey_filters in surveys_to_check:
    for index, (filtname, the_filt) in enumerate(zip(survey_names, survey_filters)):
        obs_mag = calc_obs_mag(
            ssp_data.ssp_wave,
            sed_attenuated,
            the_filt.wavelength,
            the_filt.transmission,
            z_obs,
            *DEFAULT_COSMOLOGY,
        )
        print(filtname," : " ,f"{obs_mag:.2f} mag")

# Start Optimisation in JaxOpt

### Parameters

In [None]:
def paramslist_to_dict(params_list,param_names):
    """
    Convert the list of parameters into an ordered dictionary
    :param params_list: list of parameter values
    :type params_list: float in an array

    :param param_names: list of parameter names, in the same order as the values
    :type param_names: strings in an array

    :return: dictionary mapping each parameter name to its value
    :rtype: OrderedDict

    :raises ValueError: if the two sequences do not have the same length
    """

    # zip() would silently drop the extra entries of the longer sequence
    if len(params_list) != len(param_names):
        raise ValueError(
            f"got {len(params_list)} parameter values for {len(param_names)} parameter names"
        )

    list_of_tuples = list(zip(param_names,params_list))
    # diagnostic output kept from the original notebook
    print(list_of_tuples)
    dict_params = OrderedDict(list_of_tuples )
    return dict_params
            

#### MAH parameters
model for the mass assembly history of individual and populations of dark matter halos

In [None]:
DEFAULT_MAH_PARAMS

In [None]:
DEFAULT_MAH_PARAMS_MIN = DEFAULT_MAH_PARAMS + np.array([-3., -0.01, -1.5,-0.5])
DEFAULT_MAH_PARAMS_MAX = DEFAULT_MAH_PARAMS + np.array([2., +0.01, +1.5,+0.5])

In [None]:
mah_paramnames = ["MAH_lgmO","MAH_logtc","MAH_early_index","MAH_late_index"]

#### MS parameters

In [None]:
DEFAULT_MS_PARAMS

In [None]:
DEFAULT_MS_PARAMS_MIN = DEFAULT_MS_PARAMS - 0.25*np.ones((5)) 
DEFAULT_MS_PARAMS_MAX = DEFAULT_MS_PARAMS + 0.25*np.ones((5)) 

In [None]:
ms_paramnames = ["MS_lgmcrit", "MS_lgy_at_mcrit", "MS_indx_lo", "MS_indx_hi", "MS_tau_dep"]

#### Q parameters
Quenching parameters

In [None]:
DEFAULT_Q_PARAMS

In [None]:
DEFAULT_Q_PARAMS_MIN = DEFAULT_Q_PARAMS - 0.1*np.ones((4,))
DEFAULT_Q_PARAMS_MAX = DEFAULT_Q_PARAMS + 0.1*np.ones((4,))

In [None]:
q_paramnames = ["Q_lg_qt", "Q_qlglgdt", "Q_lg_drop", "Q_lg_rejuv"]

#### Dust parameters

In [None]:
Av= 1
uv_bump_ampl = 2.0
plaw_slope = -0.25

In [None]:
DEFAULT_DUST_PARAMS = [Av, uv_bump_ampl, plaw_slope]

In [None]:
DEFAULT_DUST_PARAMS_MIN = DEFAULT_DUST_PARAMS + np.array([-1.,-1.,-0.1])
DEFAULT_DUST_PARAMS_MAX = DEFAULT_DUST_PARAMS + np.array([2.,1.,0.25])

In [None]:
dust_paramnames = ["Av", "uv_bump", "plaw_slope"]

#### Scaling parameter

In [None]:
scaleF = 1.0

In [None]:
DEFAULT_SCALEF_PARAMS = np.array([scaleF])
#DEFAULT_SCALEF_PARAMS_MIN = DEFAULT_SCALEF_PARAMS + np.array([0.000001])
#DEFAULT_SCALEF_PARAMS_MAX = DEFAULT_SCALEF_PARAMS + np.array([1000000.])

DEFAULT_SCALEF_PARAMS_MIN =  np.array([1.])
DEFAULT_SCALEF_PARAMS_MAX = np.array([1.])

In [None]:
scale_paramnames = ["scaleF"]

#### Combine parameters

In [None]:
defaults_params = [DEFAULT_MAH_PARAMS,DEFAULT_MS_PARAMS,DEFAULT_Q_PARAMS,DEFAULT_DUST_PARAMS, DEFAULT_SCALEF_PARAMS ]

params_min = np.concatenate(([DEFAULT_MAH_PARAMS_MIN,DEFAULT_MS_PARAMS_MIN,DEFAULT_Q_PARAMS_MIN,DEFAULT_DUST_PARAMS_MIN,DEFAULT_SCALEF_PARAMS_MIN]))
params_max = np.concatenate(([DEFAULT_MAH_PARAMS_MAX,DEFAULT_MS_PARAMS_MAX,DEFAULT_Q_PARAMS_MAX,DEFAULT_DUST_PARAMS_MAX,DEFAULT_SCALEF_PARAMS_MAX]))
init_params = np.concatenate(defaults_params)
init_params = jnp.array(init_params)

param_names = [mah_paramnames,ms_paramnames,q_paramnames,dust_paramnames,scale_paramnames]
param_scales = [3, 0.25,0.1,2.]

In [None]:
# Reference ("true") parameter dictionaries: each parameter name is paired
# with the default value at the same position. The "_selected" variants
# only hold the first parameter of their group.
dict_param_mah_true = OrderedDict(zip(mah_paramnames[:4], DEFAULT_MAH_PARAMS))
dict_param_mah_true_selected = OrderedDict(
    [(mah_paramnames[0], DEFAULT_MAH_PARAMS[0])]
)
dict_param_ms_true = OrderedDict(zip(ms_paramnames[:5], DEFAULT_MS_PARAMS))
dict_param_q_true = OrderedDict(zip(q_paramnames[:4], DEFAULT_Q_PARAMS))

dict_param_dust_true = OrderedDict(zip(dust_paramnames[:3], DEFAULT_DUST_PARAMS))
dict_param_dust_true_selected = OrderedDict(
    [(dust_paramnames[0], DEFAULT_DUST_PARAMS[0])]
)

dict_param_scalef_true = OrderedDict(
    [(scale_paramnames[0], DEFAULT_SCALEF_PARAMS[0])]
)

In [None]:
dict_param_dust_true

In [None]:
# Merge all parameter groups into a single ordered dictionary.
# Start from a copy: a plain assignment would alias dict_param_mah_true,
# and the update() calls below would then add every other group to it.
params_true = OrderedDict(dict_param_mah_true)
for dict_param_group in (dict_param_ms_true,
                         dict_param_q_true,
                         dict_param_dust_true,
                         dict_param_scalef_true):
    params_true.update(dict_param_group)

In [None]:
params_true

In [None]:
param_names

In [None]:
param_names_flat = list(itertools.chain(*param_names))
param_names_flat

In [None]:
@jit
def mean_spectrum(wls, params):
    """ Model of spectrum

    Build the dust-attenuated rest-frame SED from the star formation history
    and interpolate it at the requested wavelengths.

    :param wls: wavelengths of the spectrum in rest frame
    :type wls: float

    :param params: model parameters; the MAH_*, MS_*, Q_* entries and
                   "Av", "uv_bump", "plaw_slope" are read
    :type params: dictionary

    :return: the spectrum
    :rtype: float

    The module-level names today_gyr, z_obs, ssp_data and DEFAULT_COSMOLOGY
    are also read inside the function.
    """

    # decode the parameters, group by group
    list_param_mah = [params[key] for key in
                      ("MAH_lgmO", "MAH_logtc", "MAH_early_index", "MAH_late_index")]
    list_param_ms = [params[key] for key in
                     ("MS_lgmcrit", "MS_lgy_at_mcrit", "MS_indx_lo", "MS_indx_hi", "MS_tau_dep")]
    list_param_q = [params[key] for key in
                    ("Q_lg_qt", "Q_qlglgdt", "Q_lg_drop", "Q_lg_rejuv")]
    Av, uv_bump, plaw_slope = params["Av"], params["uv_bump"], params["plaw_slope"]

    # star formation history on a fixed time grid
    gal_t_table = np.linspace(0.1, today_gyr, 100)
    gal_sfr_table = sfh_singlegal(gal_t_table, list_param_mah, list_param_ms, list_param_q)

    # age of the universe in Gyr at z_obs; age_at_z returns an array,
    # but the SED function accepts a float for this argument
    t_obs = age_at_z(z_obs, *DEFAULT_COSMOLOGY)[0]

    # no star formation after the time of observation
    gal_sfr_table = jnp.where(gal_t_table < t_obs, gal_sfr_table, 0)

    # metallicity: log10(Z) and lognormal scatter of the metallicity distribution function
    gal_lgmet = -2.0
    gal_lgmet_scatter = 0.2

    # rest-frame SED without dust
    sed_info = calc_rest_sed_sfh_table_lognormal_mdf(
        gal_t_table, gal_sfr_table, gal_lgmet, gal_lgmet_scatter,
        ssp_data.ssp_lgmet, ssp_data.ssp_lg_age_gyr, ssp_data.ssp_flux, t_obs)

    # dust attenuation; the attenuation curve takes wavelengths in microns
    k = sbl18_k_lambda(ssp_data.ssp_wave/10_000, uv_bump, plaw_slope)
    sed_attenuated = _frac_transmission_from_k_lambda(k, Av) * sed_info.rest_sed

    # interpolate with interpax which is differentiable
    return interp1d(wls, ssp_data.ssp_wave, sed_attenuated, method='cubic')
    

In [None]:
mean_spectrum(Xspec_data,params_true)

In [None]:
@jit
def mean_mags(X, params):
    """ Model of the photometry

    Build the dust-attenuated rest-frame SED from the fitted parameters and
    compute the observed magnitude in each filter at redshift ``z_obs``.

    Relies on names defined elsewhere in the notebook: ``today_gyr``,
    ``z_obs``, ``ssp_data``, ``DEFAULT_COSMOLOGY`` and the dsps/diffstar helpers.

    :param X: pair ``(list_wls_filters, list_transm_filters)``; entry ``i`` of
        each list describes filter ``i`` (wavelength grid and transmission)
    :type X: tuple of two lists of arrays of identical structure

    :param params: model parameters keyed by parameter name
    :type params: dictionary

    :return: predicted magnitude in each filter, in the order of ``X``
    :rtype: jax array
    """

    # star-formation-history parameters, grouped as sfh_singlegal expects them
    list_param_mah = [params["MAH_lgmO"], params["MAH_logtc"],
                      params["MAH_early_index"], params["MAH_late_index"]]
    list_param_ms = [params["MS_lgmcrit"], params["MS_lgy_at_mcrit"],
                     params["MS_indx_lo"], params["MS_indx_hi"], params["MS_tau_dep"]]
    list_param_q = [params["Q_lg_qt"], params["Q_qlglgdt"],
                    params["Q_lg_drop"], params["Q_lg_rejuv"]]

    # dust parameters
    Av = params["Av"]
    uv_bump = params["uv_bump"]
    plaw_slope = params["plaw_slope"]

    # compute SFR on a fixed time grid
    tarr = np.linspace(0.1, today_gyr, 100)
    sfh_gal = sfh_singlegal(tarr, list_param_mah, list_param_ms, list_param_q)

    # metallicity
    gal_lgmet = -2.0 # log10(Z)
    gal_lgmet_scatter = 0.2 # lognormal scatter in the metallicity distribution function

    # need age of universe when the light was emitted
    t_obs = age_at_z(z_obs, *DEFAULT_COSMOLOGY) # age of the universe in Gyr at z_obs
    t_obs = t_obs[0] # age_at_z function returns an array, but SED functions accept a float for this argument

    # clear sfh in future
    sfh_gal = jnp.where(tarr < t_obs, sfh_gal, 0)

    # create the sed object
    sed_info = calc_rest_sed_sfh_table_lognormal_mdf(
        tarr, sfh_gal, gal_lgmet, gal_lgmet_scatter,
        ssp_data.ssp_lgmet, ssp_data.ssp_lg_age_gyr, ssp_data.ssp_flux, t_obs)

    # compute dust attenuation (the attenuation curve is evaluated in micron;
    # ssp_wave is presumably in Angstrom -- TODO confirm)
    wave_spec_micron = ssp_data.ssp_wave/10_000
    k = sbl18_k_lambda(wave_spec_micron, uv_bump, plaw_slope)
    dsps_flux_ratio = _frac_transmission_from_k_lambda(k, Av)

    sed_attenuated = dsps_flux_ratio * sed_info.rest_sed

    # decode the two parallel lists
    list_wls_filters = X[0]
    list_transm_filters = X[1]

    # one magnitude per filter; jax.tree_util.tree_map replaces the deprecated
    # top-level jax.tree_map alias
    mags_predictions = jax.tree_util.tree_map(
        lambda x, y: calc_obs_mag(ssp_data.ssp_wave, sed_attenuated, x, y, z_obs, *DEFAULT_COSMOLOGY),
        list_wls_filters,
        list_transm_filters)

    return jnp.array(mags_predictions)
    

### Select the observed magnitudes and the corresponding filters

In [None]:
print(ps.filters_indexlist) 
print(ps.filters_surveylist)
print(ps.filters_namelist)

In [None]:
index_selected_filters = np.arange(1,11)
index_selected_filters

In [None]:
X = ps.get_2lists()

In [None]:
NF = len(X[0])

In [None]:
# Build parallel lists restricted to the filters listed in index_selected_filters
list_wls_f_sel = []     # wavelength grid of each selected filter (from X[0])
list_trans_f_sel = []   # transmission of each selected filter (from X[1])

list_name_f_sel = []    # name of each selected filter
list_wlmean_f_sel = []  # wave_mean attribute of each selected filter

for index in index_selected_filters:
    list_wls_f_sel.append(X[0][index])
    list_trans_f_sel.append(X[1][index])
    the_filt = ps.filters_transmissionlist[index]
    the_wlmean = the_filt.wave_mean
    list_wlmean_f_sel.append(the_wlmean)
    list_name_f_sel.append(ps.filters_namelist[index])
    
# converted to an array because it is used later as the x-axis of the magnitude points
list_wlmean_f_sel = jnp.array(list_wlmean_f_sel)    

In [None]:
print(list_name_f_sel)
print(list_wlmean_f_sel)

In [None]:
Xf_sel = (list_wls_f_sel,list_trans_f_sel)

In [None]:
# Magnitude predicted in each selected filter for the current global ``sed_attenuated``.
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
mags_predictions = jax.tree_util.tree_map(
    lambda x, y: calc_obs_mag(ssp_data.ssp_wave, sed_attenuated, x, y, z_obs, *DEFAULT_COSMOLOGY),
    list_wls_f_sel,
    list_trans_f_sel)

In [None]:
mags_predictions

In [None]:
predicted_mags = mean_mags(Xf_sel,params_true)
predicted_mags 

In [None]:
data_selected_mags = mags[index_selected_filters]
data_selected_mags

In [None]:
data_selected_magserr = magserr[index_selected_filters]
data_selected_magserr

In [None]:
@jit
def mean_sfr(params):
    """ Model of the SFR

    Evaluate the star formation history on a fixed time grid and set it to
    zero for times later than the age of the universe at ``z_obs``.

    Relies on names defined elsewhere in the notebook: ``today_gyr``,
    ``z_obs``, ``DEFAULT_COSMOLOGY``, ``sfh_singlegal`` and ``age_at_z``.

    :param params: fitted parameters keyed by parameter name; only the
        ``MAH_*``, ``MS_*`` and ``Q_*`` entries are read
    :type params: dictionary

    :return: ``(tarr, sfh_gal)``, the time grid and the star formation rate on it
    :rtype: tuple of two arrays
    """

    # decode the parameters, grouped as sfh_singlegal expects them
    list_param_mah = [params["MAH_lgmO"], params["MAH_logtc"],
                      params["MAH_early_index"], params["MAH_late_index"]]
    list_param_ms = [params["MS_lgmcrit"], params["MS_lgy_at_mcrit"],
                     params["MS_indx_lo"], params["MS_indx_hi"], params["MS_tau_dep"]]
    list_param_q = [params["Q_lg_qt"], params["Q_qlglgdt"],
                    params["Q_lg_drop"], params["Q_lg_rejuv"]]

    # compute SFR
    tarr = np.linspace(0.1, today_gyr, 100)
    sfh_gal = sfh_singlegal(tarr, list_param_mah, list_param_ms, list_param_q)

    # need age of universe when the light was emitted
    t_obs = age_at_z(z_obs, *DEFAULT_COSMOLOGY) # age of the universe in Gyr at z_obs
    t_obs = t_obs[0] # age_at_z function returns an array; a scalar is needed for the comparison

    # clear sfh in future
    sfh_gal = jnp.where(tarr < t_obs, sfh_gal, 0)

    return tarr, sfh_gal
    

In [None]:
@jit
def lik_spec(p,wls,F, sigma_obs):
    """
    Negative log-likelihood of the spectrum for the flat parameter vector ``p``.

    :param p: flat sequence of the 17 fit parameters (16 model parameters followed by ``scaleF``)
    :param wls: wavelengths at which the spectrum model is evaluated
    :param F: measured fluxes at ``wls``
    :param sigma_obs: errors on the measured fluxes

    :return: half the sum of the squared normalised residuals
    """
    # position i of p holds the parameter called ordered_names[i]
    ordered_names = (
        "MAH_lgmO", "MAH_logtc", "MAH_early_index", "MAH_late_index",
        "MS_lgmcrit", "MS_lgy_at_mcrit", "MS_indx_lo", "MS_indx_hi", "MS_tau_dep",
        "Q_lg_qt", "Q_qlglgdt", "Q_lg_drop", "Q_lg_rejuv",
        "Av", "uv_bump", "plaw_slope",
        "scaleF",
    )
    params = {key: p[pos] for pos, key in enumerate(ordered_names)}
    scaleF = params["scaleF"]

    # model minus rescaled data
    resid = mean_spectrum(wls, params) - F*scaleF

    # NOTE(review): the data are multiplied by scaleF but the errors by sqrt(scaleF) -- confirm this is intended
    pulls = resid/(sigma_obs*jnp.sqrt(scaleF))
    return 0.5*jnp.sum(pulls ** 2)


In [None]:
@jit
def lik_mag(p,xf,mags_measured, sigma_mag_obs):
    """
    Negative log-likelihood of the photometry for the flat parameter vector ``p``.

    :param p: flat sequence of the 17 fit parameters (same layout as for the spectrum likelihood)
    :param xf: filter description passed unchanged to ``mean_mags``
    :param mags_measured: measured magnitudes, one per filter
    :param sigma_mag_obs: errors on the measured magnitudes

    :return: half the sum of the squared normalised magnitude residuals
    """
    # position i of p holds the parameter called ordered_names[i];
    # "scaleF" is carried along in the dictionary but is not used for the magnitudes
    ordered_names = (
        "MAH_lgmO", "MAH_logtc", "MAH_early_index", "MAH_late_index",
        "MS_lgmcrit", "MS_lgy_at_mcrit", "MS_indx_lo", "MS_indx_hi", "MS_tau_dep",
        "Q_lg_qt", "Q_qlglgdt", "Q_lg_drop", "Q_lg_rejuv",
        "Av", "uv_bump", "plaw_slope",
        "scaleF",
    )
    params = {key: p[pos] for pos, key in enumerate(ordered_names)}

    predicted_mags = mean_mags(xf, params)
    resid = mags_measured - predicted_mags

    return 0.5*jnp.sum((resid/sigma_mag_obs)** 2)


In [None]:
@jit
def lik_comb(p,xc,datac, sigmac, weight= 0.5):
    """
    Weighted sum of the spectroscopic and photometric negative log-likelihoods.

    Index 0 of ``xc``, ``datac`` and ``sigmac`` refers to the spectrum, index 1
    to the photometry, e.g.

    xc = [Xspec_data, Xf_sel]
    datac = [Yspec_data, mags_measured]
    sigmac = [EYspec_data, data_selected_magserr]

    :param p: flat sequence of fit parameters shared by both likelihoods
    :param weight: weight of the spectroscopic term, must be between 0 and 1;
        the photometric term gets ``1-weight``
    """
    nll_spec = lik_spec(p, xc[0], datac[0], sigmac[0])
    nll_phot = lik_mag(p, xc[1], datac[1], sigmac[1])

    return weight*nll_spec + (1-weight)*nll_phot


In [None]:
def get_infos_spec(res, model, wls,F, eF):
    """Evaluate a spectrum likelihood and its derivatives at the fitted parameters.

    :param res: optimizer result exposing the fitted parameters as ``res.params``
    :param model: callable ``model(params, wls, F, eF)`` returning a scalar
    :param wls, F, eF: data arguments forwarded unchanged to ``model``

    :return: ``(params, value, jacobian, inverse hessian)`` of ``model`` at ``res.params``;
        the inverse Hessian is used as the covariance matrix of the parameters
    """
    best = res.params
    data_args = (wls, F, eF)

    value_at_best = model(best, *data_args)
    jacobian_at_best = jax.jacfwd(model)(best, *data_args)
    hessian_at_best = jax.hessian(model)(best, *data_args)

    return best, value_at_best, jacobian_at_best, jax.scipy.linalg.inv(hessian_at_best)


In [None]:
def get_infos_mag(res, model, xf, mgs, mgse):
    """Evaluate a photometry likelihood and its derivatives at the fitted parameters.

    :param res: optimizer result exposing the fitted parameters as ``res.params``
    :param model: callable ``model(params, xf, mgs, mgse)`` returning a scalar
    :param xf, mgs, mgse: data arguments forwarded unchanged to ``model``

    :return: ``(params, value, jacobian, inverse hessian)`` of ``model`` at ``res.params``;
        the inverse Hessian is used as the covariance matrix of the parameters
    """
    best = res.params
    data_args = (xf, mgs, mgse)

    value_at_best = model(best, *data_args)
    jacobian_at_best = jax.jacfwd(model)(best, *data_args)
    hessian_at_best = jax.hessian(model)(best, *data_args)

    return best, value_at_best, jacobian_at_best, jax.scipy.linalg.inv(hessian_at_best)


In [None]:
def get_infos_comb(res, model, xc, datac, sigmac,weight):
    """Evaluate a combined likelihood and its derivatives at the fitted parameters.

    :param res: optimizer result exposing the fitted parameters as ``res.params``
    :param model: callable ``model(params, xc, datac, sigmac, weight=...)`` returning a scalar
    :param xc, datac, sigmac: data arguments forwarded unchanged to ``model``
    :param weight: forwarded to ``model`` as the keyword argument ``weight``

    :return: ``(params, value, jacobian, inverse hessian)`` of ``model`` at ``res.params``;
        the inverse Hessian is used as the covariance matrix of the parameters
    """
    best = res.params
    data_args = (xc, datac, sigmac)

    value_at_best = model(best, *data_args, weight=weight)
    jacobian_at_best = jax.jacfwd(model)(best, *data_args, weight=weight)
    hessian_at_best = jax.hessian(model)(best, *data_args, weight=weight)

    return best, value_at_best, jacobian_at_best, jax.scipy.linalg.inv(hessian_at_best)


In [None]:
init_params

#### Fit magnitudes only

In [None]:
# Bounded L-BFGS-B minimisation of the photometric likelihood lik_mag(p, xf, mags_measured, sigma_mag_obs)

lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik_mag, method="L-BFGS-B")

# the keyword arguments after `bounds` are forwarded to lik_mag
res_m = lbfgsb.run(init_params, bounds=(params_min ,params_max ), xf = Xf_sel, mags_measured = data_selected_mags, sigma_mag_obs = data_selected_magserr)
# value, jacobian and inverse Hessian of lik_mag at the fitted parameters
params_m,fun_min_m,jacob_min_m,inv_hessian_min_m = get_infos_mag(res_m, lik_mag,  xf = Xf_sel, mgs = data_selected_mags, mgse = data_selected_magserr)
print("params:",params_m,"\nfun@min:",fun_min_m,"\njacob@min:",jacob_min_m)
# inv_hessian_min_m is computed but not printed


#### Fit spectrum only

In [None]:
Xspec_data

In [None]:
Yspec_data

In [None]:
EYspec_data

In [None]:
# Bounded L-BFGS-B minimisation of the spectroscopic likelihood lik_spec(p, wls, F, sigma_obs)
lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik_spec, method="L-BFGS-B")

# the keyword arguments after `bounds` are forwarded to lik_spec
res_s = lbfgsb.run(init_params, bounds=(params_min ,params_max ), wls=Xspec_data, F=Yspec_data, sigma_obs = EYspec_data)
# value, jacobian and inverse Hessian of lik_spec at the fitted parameters
params_s,fun_min_s,jacob_min_s,inv_hessian_min_s = get_infos_spec(res_s, lik_spec, wls=Xspec_data, F=Yspec_data,eF=EYspec_data)
print("params:",params_s,"\nfun@min:",fun_min_s,"\njacob@min:",jacob_min_s)
# inv_hessian_min_s is computed but not printed


### Comparing results of various fits

In [None]:
dict_params_fitted_s = paramslist_to_dict(params_s,param_names_flat)
dict_params_fitted_m = paramslist_to_dict(params_m,param_names_flat)

In [None]:
dict_params_fitted_nodust_s = copy.deepcopy(dict_params_fitted_s)
dict_params_fitted_nodust_s["Av"] = 0
dict_params_fitted_nodust_m = copy.deepcopy(dict_params_fitted_m)
dict_params_fitted_nodust_m["Av"] = 0

In [None]:
Y_fit_s = mean_spectrum(ssp_data.ssp_wave, dict_params_fitted_s)
Y_fit_nodust_s = mean_spectrum(ssp_data.ssp_wave, dict_params_fitted_nodust_s)
Y_fit_m = mean_spectrum(ssp_data.ssp_wave, dict_params_fitted_m)
Y_fit_nodust_m = mean_spectrum(ssp_data.ssp_wave, dict_params_fitted_nodust_m)

### $L_\nu$

#### Scaling flux from magnitude

In [None]:
def lik_mag_flux(p,wls,flux_mags_measured, flux_sigma_mag_measured,xref,yref):
    """Chi-square-like cost for the scale factor ``p[0]`` applied to measured fluxes.

    The reference curve ``(xref, yref)`` is linearly interpolated at ``wls`` and
    compared with the measured fluxes multiplied by ``p[0]``.

    :return: half the sum of the squared residuals divided by ``flux_sigma_mag_measured``
    """
    scale = p[0]
    model_at_wls = jnp.interp(wls, xref, yref)
    normalised = (flux_mags_measured*scale - model_at_wls)/flux_sigma_mag_measured
    return 0.5*jnp.sum(normalised ** 2)

In [None]:
# Convert the selected magnitudes to (unnormalised) fluxes, F = 10^(-0.4 m), in one vectorised call
data_selected_mags_flux = jnp.power(10.0, -0.4 * jnp.asarray(data_selected_mags))
# First-order error propagation: sigma_F = 0.4 ln(10) F sigma_m
# (the 0.4 ln(10) ~ 0.92 factor was missing)
data_selected_mags_flux_err = 0.4 * jnp.log(10.0) * data_selected_magserr * data_selected_mags_flux
def plot_magnitudes_data(ax):
    """Draw the photometric fluxes with their error bars on the axes ``ax``.

    Reads the module-level ``list_wlmean_f_sel``, ``data_selected_mags_flux``
    and ``data_selected_mags_flux_err`` at call time.
    """
    ax.errorbar(list_wlmean_f_sel , data_selected_mags_flux, yerr=data_selected_mags_flux_err,
                marker='o', color="green",ecolor="green",markersize=8,lw=2)

In [None]:
# Fit a single multiplicative factor p[0] that brings the photometric fluxes
# onto the magnitude-fitted model Y_fit_m evaluated at the filter mean wavelengths
lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik_mag_flux, method="L-BFGS-B")
p0 = [1.0]
# very wide bounds: the factor is only required to be positive
pmin = [0.00000001]
pmax = [100000000.]
res_p = lbfgsb.run(p0, bounds=(pmin,pmax), wls=list_wlmean_f_sel, 
                   flux_mags_measured=data_selected_mags_flux, 
                   flux_sigma_mag_measured = data_selected_mags_flux_err,
                  xref = jnp.array(ssp_data.ssp_wave),
                  yref = Y_fit_m)

In [None]:
print("scaling factor", res_p.params)

In [None]:
data_selected_mags_flux *= res_p.params[0]
data_selected_mags_flux_err *= res_p.params[0]

In [None]:
# L_nu comparison of the spectrum-only and magnitude-only fits with the data
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.set_xlim(1e3,1e6)
__=ax.set_ylim(1e-10,1e-3)
# raw strings: "\l" is not a valid escape sequence in a normal string literal
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\nu(\lambda)$")

ax.plot(ssp_data.ssp_wave,Y_fit_s,'-',color='green',lw=1,label="fitted spectrum model with dust")
ax.plot(ssp_data.ssp_wave,Y_fit_nodust_s,'-',color='red',lw=1,label="fitted spectrum model No dust")

ax.plot(ssp_data.ssp_wave,Y_fit_m,'-.',color='green',lw=1,label="fitted mags model with dust")
ax.plot(ssp_data.ssp_wave,Y_fit_nodust_m,'-.',color='red',lw=1,label="fitted mags spectrum model No dust")

ax.plot(Xspec_data,Yspec_data,'b-',lw=3,label=the_label_data)

ax.errorbar(list_wlmean_f_sel , data_selected_mags_flux, yerr=data_selected_mags_flux_err,
                marker='o', color="black",ecolor="black",markersize=8,lw=2,label="mag data")

title = r"Comparison of SED $L_\nu$ with SFH and dust with " + title_data
ax.set_title(title)
ax.legend(loc="lower right")
ax.grid()

## Flux Ratio to determine the proper scaling

In [None]:
# Ratio of the spectrum-fitted SED to the magnitude-fitted SED; its median over the
# wavelength range covered by the spectrum is used as the rescaling factor of the spectrum
ratio_SEDs_fit = Y_fit_s/Y_fit_m
# keep only wavelengths INSIDE the spectroscopic range
# (the upper bound was tested with ">" which selected wavelengths above the range)
indexes_sel = jnp.where(jnp.logical_and(ssp_data.ssp_wave > Xspec_data.min(), ssp_data.ssp_wave < Xspec_data.max()))
ratio_sel = ratio_SEDs_fit[indexes_sel]
specdata_rescaling_factor = np.median(ratio_sel)

fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.set_xlim(1e3,1e6)
__=ax.set_ylim(1e-3,1e-2)
# raw strings: "\l" is not a valid escape sequence in a normal string literal
ax.set_xlabel(r"$\lambda (\AA)$")
# the plotted quantity is a ratio, not L_nu itself
ax.set_ylabel(r"ratio of $L_\nu(\lambda)$")

ax.plot(ssp_data.ssp_wave,ratio_SEDs_fit,'-',color='r',lw=1,label="spectroscopic flux/ photometric flux")
# orange band: wavelength range covered by the spectrum
ax.axvspan(Xspec_data.min(), Xspec_data.max(), alpha=0.5, color='orange')

# Inset position, in fractions of the figure size ((0,0) is bottom left)
left, bottom, width, height = [0.55, 0.5, 0.3, 0.3]
ax2 = fig.add_axes([left, bottom, width, height])
ax2.hist(ratio_sel,bins=50,facecolor="blue",alpha=0.6)
ax2.axvline(specdata_rescaling_factor,color = "k")
ax2.set_title("ratio of SED")

title = r"ratio of SED $L_\nu$ spectroscopy/photometry " + title_data
ax.set_title(title)
ax.legend(loc="lower right")
ax.grid()

### $L_\lambda$

In [None]:
# Convert per-frequency quantities to per-wavelength ones by multiplying by 3e18/lambda^2
# (3e18 is presumably the speed of light in Angstrom/s, with wavelengths in Angstrom -- TODO confirm)

# models fitted on the spectrum, with and without dust
YL_fit_s = Y_fit_s*3e18/(ssp_data.ssp_wave)**2
YL_fit_nodust_s = Y_fit_nodust_s*3e18/(ssp_data.ssp_wave)**2

# models fitted on the magnitudes, with and without dust
YL_fit_m = Y_fit_m*3e18/(ssp_data.ssp_wave)**2
YL_fit_nodust_m = Y_fit_nodust_m*3e18/(ssp_data.ssp_wave)**2

# spectroscopic data, photometric data and photometric errors
YL = Yspec_data*3e18/Xspec_data**2
YM = data_selected_mags_flux*3e18/list_wlmean_f_sel**2
EYM = data_selected_mags_flux_err*3e18/list_wlmean_f_sel**2

In [None]:
# L_lambda comparison of the spectrum-only and magnitude-only fits with the data
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.set_xlim(1e3,1e6)
__=ax.set_ylim(1e0,1e7)
# raw strings: "\l" is not a valid escape sequence in a normal string literal
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\lambda(\lambda)$")

# green = with dust, red = no dust (the curves were previously attached to the opposite labels)
ax.plot(ssp_data.ssp_wave,YL_fit_s,'-',color='green',lw=1,label="fitted spectrum model with dust")
ax.plot(ssp_data.ssp_wave,YL_fit_nodust_s,'-',color='red',lw=1,label="fitted spectrum model No dust")

ax.plot(ssp_data.ssp_wave,YL_fit_m,'-.',color='green',lw=1,label="fitted mag model with dust")
ax.plot(ssp_data.ssp_wave,YL_fit_nodust_m,'-.',color='red',lw=1,label="fitted mag model No dust")


ax.plot(Xspec_data,YL,'b-',lw=3,label=the_label_data)
ax.errorbar(list_wlmean_f_sel , YM, yerr=EYM,
                marker='o', color="black",ecolor="black",markersize=8,lw=2,label="mag data")

title = r"Comparison of SED $L_\lambda$ with SFH and dust with " + title_data
ax.set_title(title)
ax.legend(loc="lower right")
ax.grid()

In [None]:
# Same L_lambda comparison on a narrower wavelength and flux range
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.set_xlim(1e3,1e5)
__=ax.set_ylim(1e1,1e7)
# raw strings: "\l" is not a valid escape sequence in a normal string literal
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\lambda(\lambda)$")

# green = with dust, red = no dust (the curves were previously attached to the opposite labels)
ax.plot(ssp_data.ssp_wave,YL_fit_s,'-',color='green',lw=1,label="fitted spectrum model with dust")
ax.plot(ssp_data.ssp_wave,YL_fit_nodust_s,'-',color='red',lw=1,label="fitted spectrum model No dust")
ax.plot(ssp_data.ssp_wave,YL_fit_m,'-.',color='green',lw=1,label="fitted mag model with dust")
ax.plot(ssp_data.ssp_wave,YL_fit_nodust_m,'-.',color='red',lw=1,label="fitted mag model No dust")

ax.plot(Xspec_data,YL,'b-',lw=3,label=the_label_data)
ax.errorbar(list_wlmean_f_sel , YM, yerr=EYM,
                marker='o', color="black",ecolor="black",markersize=8,lw=2,label="mag data")


title = r"Comparison of SED $L_\lambda$ with SFH and dust with " + title_data
ax.set_title(title)
ax.legend(loc="lower right")
ax.grid()

### Check the SFH


In [None]:
tarr_fit_s,sfr_fit_s = mean_sfr(dict_params_fitted_s)
tarr_fit_m,sfr_fit_m = mean_sfr(dict_params_fitted_m)

In [None]:
# Fitted star formation histories (spectrum fit vs magnitude fit), logarithmic y axis
# upper limit: 10 x the largest fitted SFR of the two fits
sfr_max_s = sfr_fit_s.max()*10.
sfr_min_s = sfr_max_s/1e4  # NOTE(review): computed but not used below

sfr_max_m = sfr_fit_m.max()*10.
sfr_min_m = sfr_max_m/1e4  # NOTE(review): computed but not used below

sfr_max = max(sfr_max_s,sfr_max_m)
fig, ax = plt.subplots(1, 1)


# the lower limit is hard-coded to 1e-5
ylim = ax.set_ylim(1e-5, sfr_max)
yscale = ax.set_yscale('log')

__=ax.plot(tarr_fit_s, sfr_fit_s, '-', color='r',label='SFH fitted with Spectrum')
__=ax.plot(tarr_fit_m, sfr_fit_m, '-', color='b',label='SFH fitted with Magnitudes')

ax.set_title("Fitted Star Formation History (SFH) " + title_data)
xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')
ax.grid()
ax.legend()

In [None]:
# Same fitted star formation histories on a linear y axis
# upper limit: 1.1 x the largest fitted SFR of the two fits
sfr_max_s = sfr_fit_s.max()*1.1
sfr_min_s = sfr_max_s/1e4  # NOTE(review): computed but not used below

sfr_max_m = sfr_fit_m.max()*1.1
sfr_min_m = sfr_max_m/1e4  # NOTE(review): computed but not used below

sfr_max = max(sfr_max_s,sfr_max_m)

fig, ax = plt.subplots(1, 1)
ylim = ax.set_ylim(0, sfr_max)
# linear scale: no call to set_yscale here

__=ax.plot(tarr_fit_s, sfr_fit_s, '-', color='r',label='SFH fitted with Spectrum')
__=ax.plot(tarr_fit_m, sfr_fit_m, '-', color='b',label='SFH fitted with Magnitudes')
ax.set_title("Fitted Star Formation History (SFH) " + title_data)
xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')
ax.grid()

# Combine fits

## Rescale spectroscopic data

In [None]:
# Divide the spectroscopic fluxes and their errors by the rescaling factor.
# NOTE: the variables are overwritten, so re-running this cell applies the factor again.
Yspec_data /= specdata_rescaling_factor
EYspec_data /= specdata_rescaling_factor 

### Combine spectroscopic data with photometric data

In [None]:
Xc = [Xspec_data, Xf_sel]
Yc = [Yspec_data,  data_selected_mags ]
EYc = [EYspec_data, data_selected_magserr]
weight_spec = 0.5

In [None]:
init_params

In [None]:
params_min[-1] = 0.5
params_min

In [None]:
params_max[-1] = 1.5
params_max

In [None]:
# Bounded L-BFGS-B minimisation of the combined likelihood lik_comb(p, xc, datac, sigmac, weight)
lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik_comb, method="L-BFGS-B")

# the keyword arguments after `bounds` are forwarded to lik_comb
res_c = lbfgsb.run(init_params, bounds=(params_min ,params_max ), xc=Xc, datac=Yc,sigmac=EYc,weight=weight_spec)
# value, jacobian and inverse Hessian of lik_comb at the fitted parameters
params_c,fun_min_c,jacob_min_c,inv_hessian_min_c = get_infos_comb(res_c, lik_comb, xc=Xc, datac=Yc,sigmac=EYc,weight=weight_spec)
print("params:",params_c,"\nfun@min:",fun_min_c,"\njacob@min:",jacob_min_c)
# inv_hessian_min_c is computed but not printed


In [None]:
params_cm,fun_min_cm,jacob_min_cm,inv_hessian_min_cm  = get_infos_mag(res_c, lik_mag,  xf = Xf_sel, mgs = data_selected_mags, mgse = data_selected_magserr)
print("params:",params_cm,"\nfun@min:",fun_min_cm,"\njacob@min:",jacob_min_cm)

In [None]:
params_cs,fun_min_cs,jacob_min_cs,inv_hessian_min_cs = get_infos_spec(res_c, lik_spec, wls=Xspec_data, F=Yspec_data,eF=EYspec_data)
print("params:",params_cs,"\nfun@min:",fun_min_cs,"\njacob@min:",jacob_min_cs)

In [None]:
dict_params_fitted_c = paramslist_to_dict(params_c,param_names_flat)
dict_params_fitted_nodust_c = copy.deepcopy(dict_params_fitted_c)
dict_params_fitted_nodust_c["Av"] = 0


In [None]:
Y_fit_c = mean_spectrum(ssp_data.ssp_wave, dict_params_fitted_c)
Y_fit_nodust_c = mean_spectrum(ssp_data.ssp_wave, dict_params_fitted_nodust_c)

In [None]:
# L_nu of the combined spectro-photometric fit compared with the data
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.set_xlim(1e3,1e6)
__=ax.set_ylim(1e-10,1e-3)
# raw strings: "\l" is not a valid escape sequence in a normal string literal
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\nu(\lambda)$")

ax.plot(ssp_data.ssp_wave,Y_fit_c,'-',color='green',lw=1,label="fitted spectrum model with dust")
ax.plot(ssp_data.ssp_wave,Y_fit_nodust_c,'-',color='red',lw=1,label="fitted spectrum model No dust")

ax.plot(Xspec_data,Yspec_data,'b-',lw=3,label=the_label_data)

ax.errorbar(list_wlmean_f_sel , data_selected_mags_flux, yerr=data_selected_mags_flux_err,
                marker='o', color="k",ecolor="black",markersize=8,lw=2,label="mag data",alpha=0.4)

title = r"Combined spectro-photom fit of SED $L_\nu$ with SFH and dust with " + title_data
ax.set_title(title)
ax.legend(loc="lower right")
ax.grid()

In [None]:
# Convert per-frequency quantities to per-wavelength ones by multiplying by 3e18/lambda^2
# (3e18 is presumably the speed of light in Angstrom/s, with wavelengths in Angstrom -- TODO confirm)

# combined-fit models, with and without dust
YL_fit_c = Y_fit_c*3e18/(ssp_data.ssp_wave)**2
YL_fit_nodust_c = Y_fit_nodust_c*3e18/(ssp_data.ssp_wave)**2

# recomputed here because Yspec_data may have been rescaled since the previous conversion
YL = Yspec_data*3e18/Xspec_data**2
YM = data_selected_mags_flux*3e18/list_wlmean_f_sel**2
EYM = data_selected_mags_flux_err*3e18/list_wlmean_f_sel**2

In [None]:
# L_lambda of the combined spectro-photometric fit compared with the data
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.set_xlim(1e3,1e6)
__=ax.set_ylim(1e0,1e7)
# raw strings: "\l" is not a valid escape sequence in a normal string literal
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\lambda(\lambda)$")

# green = with dust, red = no dust (the curves were previously attached to the opposite labels)
ax.plot(ssp_data.ssp_wave,YL_fit_c,'-',color='green',lw=1,label="fitted spectrum model with dust")
ax.plot(ssp_data.ssp_wave,YL_fit_nodust_c,'-',color='red',lw=1,label="fitted spectrum model No dust")


ax.plot(Xspec_data,YL,'b-',lw=3,label=the_label_data)
ax.errorbar(list_wlmean_f_sel , YM, yerr=EYM,
                marker='o', color="black",ecolor="black",markersize=8,lw=2,label="mag data")

title = r"Comparison of SED $L_\lambda$ with SFH and dust with " + title_data
ax.set_title(title)
ax.legend(loc="lower right")
ax.grid()

## Star Formation history

In [None]:
tarr_fit_c,sfr_fit_c = mean_sfr(dict_params_fitted_c)

In [None]:
# Star formation history of the combined fit, logarithmic y axis
# y range: from 10 x the maximum fitted SFR down to four decades below that
sfr_max_c = sfr_fit_c.max()*10.
sfr_min_c = sfr_max_c/1e4

fig, ax = plt.subplots(1, 1)


ylim = ax.set_ylim(sfr_min_c, sfr_max_c)
yscale = ax.set_yscale('log')

__=ax.plot(tarr_fit_c, sfr_fit_c, '-', color='r',label='SFH fitted with Spectrum and photometry')


ax.set_title("Fitted Star Formation History (SFH) " + title_data)
xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')
ax.grid()
ax.legend()

# Comparison with StarLight

In [None]:
from astropy import units as u
from astropy import constants as const
U_FNU = u.def_unit(u'fnu',  u.erg / (u.cm**2 * u.s * u.Hz))
U_FL = u.def_unit(u'fl',  u.erg / (u.cm**2 * u.s * u.AA))

In [None]:
def ConvertFlambda_to_Fnu(wl, flambda):
    """
    Convert a spectral density per unit wavelength into one per unit frequency.

    Applies Fnu = wl**2/c * Flambda with astropy units and constants so that
    the unit conversion is checked by astropy.

    parameters:
        - wl : wavelength, interpreted in Angstrom
        - flambda : flux density, interpreted in erg/s/cm2/AA
    return
         - fnu : flux density expressed in erg/s/cm2/Hz, returned as the
           dimensionless ratio to one erg/s/cm2/Hz unit
    """
    flambda_with_unit = flambda*U_FL
    wl_with_unit = wl*u.AA

    fnu_with_unit = (flambda_with_unit*wl_with_unit**2/const.c).to(U_FNU)

    # divide by the unit itself to strip the physical unit from the result
    return fnu_with_unit/(1*U_FNU)


In [None]:
class SLDataAcess(object):
    """Read-only access to an HDF5 file of spectra, one HDF5 group per spectrum.

    Each group is expected to hold the datasets "wl" and "fl" and a set of
    attributes; the attribute names are taken from the first group of the file.
    When the file does not exist the object is still created, with ``hf`` set
    to ``None`` and empty key lists, so every accessor returns an empty result.
    """
    def __init__(self,filename):
        # defaults used when the file is missing or contains no group
        self.hf = None
        self.list_of_groupkeys = []
        self.list_of_subgroup_keys = []

        if os.path.isfile(filename):
            self.hf = h5py.File(filename, 'r')
            self.list_of_groupkeys = list(self.hf.keys())
            if self.list_of_groupkeys:
                # the attribute names of the first group are used for all groups
                group = self.hf.get(self.list_of_groupkeys[0])
                self.list_of_subgroup_keys = list(group.attrs.keys())
        else:
            print(f'SLDataAcess : No file {filename}')

    def close_file(self):
        """Close the HDF5 file; does nothing when no file was opened."""
        if self.hf is not None:
            self.hf.close()

    def get_list_of_groupkeys(self):
        """Return the list of group names found in the file."""
        return self.list_of_groupkeys

    def get_list_subgroup_keys(self):
        """Return the list of attribute names read from the first group."""
        return self.list_of_subgroup_keys

    def getattribdata_fromgroup(self,groupname):
        """Return an OrderedDict of the attributes of ``groupname`` (empty if the group is unknown)."""
        attr_dict = OrderedDict()
        if groupname in self.list_of_groupkeys:
            group = self.hf.get(groupname)
            for nameval in self.list_of_subgroup_keys:
                attr_dict[nameval] = group.attrs[nameval]
        else:
            print(f'getattribdata_fromgroup : No group {groupname}')
        return attr_dict

    def getspectrum_fromgroup(self,groupname):
        """Return the spectrum of ``groupname`` as a dict (empty if the group is unknown).

        Keys: "wl" and "fl" as stored in the file, and "fnu", the converted
        per-frequency flux divided by ``flux_norm(wl, fnu)``.
        """
        spec_dict = {}
        if groupname in self.list_of_groupkeys:
            group = self.hf.get(groupname)
            wl = np.array(group.get("wl"))
            fl = np.array(group.get("fl"))
            spec_dict["wl"] = wl
            spec_dict["fl"] = fl

            #convert to fnu and normalise
            fnu = ConvertFlambda_to_Fnu(wl, fl)
            fnorm = flux_norm(wl,fnu)
            spec_dict["fnu"] = fnu/fnorm
        else:
            print(f'getspectrum_fromgroup : No group {groupname}')
        return spec_dict

### Open StarLight file spectra

In [None]:
#filename_StarLightSpectra = "../../data/fors2sl/SLspectra.hdf5"
filename_StarLightSpectra = "../../data/fors2sl/SLspectra_manyPoints.hdf5"

In [None]:
sl = SLDataAcess(filename_StarLightSpectra)

In [None]:
#sl.get_list_of_groupkeys()

### Attributes of this spectrum

In [None]:
sl.getattribdata_fromgroup(selected_spectrum_tag)

### Access to the file

In [None]:
FSL = sl.getspectrum_fromgroup(selected_spectrum_tag)

In [None]:
FSL

In [None]:
sl_wl = FSL['wl']
sl_fl = FSL['fl']

In [None]:
# StarLight spectrum; the orange band marks the wavelength range of the observed spectrum
fig,ax = plt.subplots(1,1,figsize=(15,7))
ax.plot(sl_wl,sl_fl,color="b")
ax.set_yscale('log')
ax.set_ylim(1e-7,1e-2)
ax.set_xlim(0,20000)
ax.axvspan(Xspec_data.min(), Xspec_data.max(),alpha=0.5, color='orange')
the_title = "StarLight spectrum for " + title_data
ax.set_title(the_title)
# raw strings: "\l" is not a valid escape sequence in a normal string literal
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$F_\lambda (\lambda)$")

### Renormalisation of the spectrum

### Interpolate SL spectrum over DSPS wl range

In [None]:
sl_sel_indexes = jnp.where(jnp.logical_and(sl_wl> Xspec_data.min(),sl_wl< Xspec_data.max()))[0]
ssp_sel_indexes = jnp.where(jnp.logical_and(ssp_data.ssp_wave> Xspec_data.min(),ssp_data.ssp_wave< Xspec_data.max()))[0]
sl_wl_cut = sl_wl[sl_sel_indexes]
sl_fl_cut = sl_fl[sl_sel_indexes]

In [None]:
plt.plot(sl_wl_cut,sl_fl_cut)

In [None]:
interpolated_sl = interp1d(ssp_data.ssp_wave[ssp_sel_indexes],sl_wl_cut ,sl_fl_cut,method='cubic')

In [None]:
interpolated_sl

In [None]:
ratio_sl_by_ssp = interpolated_sl/YL_fit_c[ssp_sel_indexes]
ratio_sl_by_ssp_median = np.median(ratio_sl_by_ssp)

In [None]:
plt.plot(ssp_data.ssp_wave[ssp_sel_indexes],ratio_sl_by_ssp)
plt.axhline(ratio_sl_by_ssp_median,color="r")

In [None]:
sl_fl /= ratio_sl_by_ssp_median 

In [None]:
# Rescaled StarLight spectrum over the DSPS combined-fit model;
# the orange band marks the wavelength range of the observed spectrum
fig,ax = plt.subplots(1,1,figsize=(15,7))
ax.plot(ssp_data.ssp_wave,YL_fit_c,color="b",label='DSPS')
ax.plot(sl_wl,sl_fl,color="r",label="StarLight")
ax.set_yscale('log')
ax.set_ylim(1e4,1e7)
ax.set_xlim(0,20000)
ax.axvspan(Xspec_data.min(), Xspec_data.max(),alpha=0.5, color='orange')
the_title = "StarLight & SSP Flambda " + title_data
ax.set_title(the_title)
# raw strings: "\l" is not a valid escape sequence in a normal string literal
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$F_\lambda (\lambda)$")

In [None]:
def lik_renorm_flux(p,wls,flux, flux_sigma,xref,yref):
    """Chi-square-like cost for the scale factor ``p[0]`` applied to a reference curve.

    The reference curve ``(xref, yref)`` is linearly interpolated at ``wls``,
    multiplied by ``p[0]`` and compared with ``flux``.

    :return: half the sum of the squared residuals divided by ``flux_sigma``
    """
    scale = p[0]
    reference_at_wls = jnp.interp(wls, xref, yref)
    normalised = (flux - reference_at_wls*scale)/flux_sigma
    return 0.5*jnp.sum(normalised ** 2)

In [None]:
# Fit a single multiplicative factor p[0] that brings the StarLight spectrum
# (sl_wl, sl_fl) onto the observed spectrum YL at the wavelengths Xspec_data
lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik_renorm_flux, method="L-BFGS-B")
p0 = [1.0]
# the factor is restricted to the interval [0.5, 1.5]
pmin = [0.5]
pmax = [1.5]
# flux_sigma = 1.0: every point gets the same weight
res_sl = lbfgsb.run(p0, bounds=(pmin,pmax), wls=Xspec_data, 
                   flux = YL, 
                   flux_sigma = 1.0,
                  xref = sl_wl,
                  yref = sl_fl)

In [None]:
print("scaling factor", res_sl.params)
sl_scaling_factor = res_sl.params[0]

In [None]:
# L_lambda of the combined fit, the data and the rescaled StarLight spectrum
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.set_xlim(1e3,1e6)
__=ax.set_ylim(1e0,1e7)
# raw strings: "\l" is not a valid escape sequence in a normal string literal
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\lambda(\lambda)$")

# green = with dust, red = no dust (the curves were previously attached to the opposite labels)
ax.plot(ssp_data.ssp_wave,YL_fit_c,'-',color='green',lw=1,label="fitted spectrum model with dust")
ax.plot(ssp_data.ssp_wave,YL_fit_nodust_c,'-',color='red',lw=1,label="fitted spectrum model No dust")


ax.plot(Xspec_data,YL,'b-',lw=3,label=the_label_data)
ax.errorbar(list_wlmean_f_sel , YM, yerr=EYM,
                marker='o', color="black",ecolor="black",markersize=8,lw=2,label="mag data")


ax.plot(sl_wl,sl_fl*sl_scaling_factor,"-",color="grey",label="StarLight",lw=0.5)

title = r"Comparison of SED $L_\lambda$ with SFH and dust with " + title_data
ax.set_title(title)
ax.legend(loc="lower right")
ax.grid()

In [None]:
# Same comparison with StarLight on a narrower wavelength and flux range
fig, ax = plt.subplots(1, 1)
__=ax.loglog()
__=ax.set_xlim(1e3,5e4)
__=ax.set_ylim(1e4,1e7)
# raw strings: "\l" is not a valid escape sequence in a normal string literal
ax.set_xlabel(r"$\lambda (\AA)$")
ax.set_ylabel(r"$L_\lambda(\lambda)$")

# green = with dust, red = no dust (the curves were previously attached to the opposite labels)
ax.plot(ssp_data.ssp_wave,YL_fit_c,'-',color='green',lw=1,label="fitted spectrum model with dust")
ax.plot(ssp_data.ssp_wave,YL_fit_nodust_c,'-',color='red',lw=1,label="fitted spectrum model No dust")


ax.plot(Xspec_data,YL,'b-',lw=3,label=the_label_data)
ax.errorbar(list_wlmean_f_sel , YM, yerr=EYM,
                marker='o', color="black",ecolor="black",markersize=8,lw=2,label="mag data")


ax.plot(sl_wl,sl_fl*sl_scaling_factor,"-",color="grey",label="StarLight",lw=0.5)

title = r"Comparison of SED $L_\lambda$ with SFH and dust with " + title_data
ax.set_title(title)
ax.legend(loc="lower right")
ax.grid()