# Fit Single Fors2 Spectrum and Photometry with DSPS - Short version

Implement this fit using the `fors2tostellarpopsynthesis` package.

- Author Sylvie Dagoret-Campagne
- Affiliation : IJCLab/IN2P3/CNRS
- Organisation : LSST-DESC
- creation date : 2023-11-23
- last update : 2023-11-23



Most functions are inside the package.

## Import

In [None]:
import h5py
import pandas as pd
import numpy as np
import os
import re
import pickle 
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.colors as colors
import matplotlib.cm as cmx
import collections
from collections import OrderedDict
import re
import matplotlib.gridspec as gridspec
from sklearn.gaussian_process import GaussianProcessRegressor, kernels

In [None]:
import jax
import jax.numpy as jnp
from jax import vmap
import jaxopt
import optax
jax.config.update("jax_enable_x64", True)
from interpax import interp1d

In [None]:
# Global matplotlib styling for every figure of the notebook: large labels,
# titles and tick labels, and a wide default figure size.
plt.rcParams.update({
    "figure.figsize": (12, 6),
    "axes.labelsize": 'xx-large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
    'legend.fontsize': 16,
})

### Filters

In [None]:
from fors2tostellarpopsynthesis.filters import FilterInfo

### Fors2 and Starlight

In [None]:
from fors2tostellarpopsynthesis.fors2starlightio import Fors2DataAcess, SLDataAcess,convert_flux_torestframe,gpr

### fitter jaxopt

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (lik_spec,lik_mag,lik_comb,
get_infos_spec,get_infos_mag,get_infos_comb)

from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (SSP_DATA,mean_spectrum,mean_mags,mean_sfr,ssp_spectrum_fromparam)

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_util import (plot_fit_ssp_photometry,
                          plot_fit_ssp_spectrophotometry,
                          plot_fit_ssp_spectrophotometry_sl,
                          plot_fit_ssp_spectroscopy, plot_SFH,
                          rescale_photometry, rescale_spectroscopy,
                          rescale_starlight_inrangefors2)


In [None]:
#from dsps.cosmology import DEFAULT_COSMOLOGY, age_at_z

## Check filters

In [None]:
# Filter set used for the photometric magnitudes; plot every transmission curve.
ps = FilterInfo()
ps.plot_transmissions()

## Fors2 data

In [None]:
# Accessor for the Fors2 spectra (class from fors2tostellarpopsynthesis.fors2starlightio).
fors2 = Fors2DataAcess()

In [None]:
# Overview plot of all Fors2 spectra.
fors2.plot_allspectra()

In [None]:
# Group keys, one per spectrum; presumably of the form "SPEC<number>"
# (compare selected_spectrum_tag further down) — confirm.
fors2_tags = fors2.get_list_of_groupkeys()
len(fors2_tags)

In [None]:
# Names of the sub-group keys available for a spectrum group.
list_of_fors2_attributes = fors2.get_list_subgroup_keys()
print(list_of_fors2_attributes)

## StarLight data

In [None]:
# Accessor for the StarLight model spectra (class from fors2tostellarpopsynthesis.fors2starlightio).
sl = SLDataAcess()

In [None]:
# Overview plot of all StarLight spectra.
sl.plot_allspectra()

In [None]:
# Group keys of the StarLight file; the same tag is used below to fetch both the
# Fors2 spectrum and its StarLight model.
sl_tags = sl.get_list_of_groupkeys()
len(sl_tags)

## Select Spectrum

In [None]:
# Number of the Fors2 spectrum to fit; 560 is kept as a handy alternative.
#selected_spectrum_number = 560
selected_spectrum_number = 411
# Group key used by both the Fors2 and StarLight accessors, e.g. "SPEC411".
selected_spectrum_tag = "SPEC{}".format(selected_spectrum_number)

In [None]:
# Attributes of the selected spectrum group; only its 'redshift' is used here,
# as the observed redshift for the rest-frame conversion and all fits.
fors2_attr =fors2.getattribdata_fromgroup(selected_spectrum_tag)
z_obs = fors2_attr['redshift']

### Get magnitude data

In [None]:
# Photometric magnitudes and their errors for this spectrum, one entry per filter of `ps`;
# entries may be NaN and are removed further down when the filters are selected.
data_mags, data_magserr = fors2.get_photmagnitudes(selected_spectrum_tag)

### Get Fors2 spectrum in rest-frame

In [None]:
spec_obs = fors2.getspectrumcleanedemissionlines_fromgroup(selected_spectrum_tag)

In [None]:
spec_obs

In [None]:
Xs = spec_obs['wl']
Ys = spec_obs['fnu']
EYs = spec_obs['bg']
EYs_med = spec_obs['bg_med']
flmin_obs = spec_obs['fnu']-spec_obs['bg']
flmax_obs = spec_obs['fnu']+spec_obs['bg']

#### Convert to restframe

In [None]:
Xspec_data, Yspec_data = convert_flux_torestframe(Xs,Ys,z_obs)
EYspec_data = EYs*(1+z_obs)
EYspec_data_med = EYs_med*(1+z_obs) 

In [None]:
flmin_rest = Yspec_data - EYspec_data
flmax_rest= Yspec_data + EYspec_data

In [None]:
title_spec = selected_spectrum_tag + f" z= {z_obs:.2f}"
_,axs = plt.subplots(2,1,figsize=(10,6))

ax1=axs[0]
ax1.plot(Xs,Ys,'-b',label="obs frame")
ax1.fill_between(Xs,flmin_obs,flmax_obs, facecolor='lightgrey', edgecolors="None")

ax1.plot(Xspec_data,Yspec_data ,'-r',label="rest frame")
ax1.fill_between(Xspec_data,flmin_rest,flmax_rest, facecolor='lightgrey', edgecolors="None")
ax1.legend()
ax1.set_title(title_spec)
ax1.set_xlabel("$\lambda (\\AA)$")
ax1.grid()

ax2=axs[1]

ax2.plot(Xs,EYs,'-b',label="obs frame")
ax2.axhline(EYs_med,color="b")
ax2.plot(Xspec_data,EYspec_data ,'-r',label="rest frame")
ax2.axhline(EYspec_data_med,color="r")
ax2.set_xlabel("$\lambda (\\AA)$")
ax2.grid()

In [None]:
# Diagnostic plot from the Fors2 accessor for this spectrum; presumably shows the
# spectrum rescaled against its photometry — confirm against the package.
fors2.plot_spectro_photom_rescaling(selected_spectrum_tag)

## Parameters

In [None]:
from fors2tostellarpopsynthesis.parameters import SSPParametersFit,paramslist_to_dict

In [None]:
p = SSPParametersFit()

In [None]:
print(p)

In [None]:
p.DICT_PARAMS_true

In [None]:
init_params = p.INIT_PARAMS
params_min = p.PARAMS_MIN
params_max = p.PARAMS_MAX

# Select filters

In [None]:
# Index, survey and name of every available filter, one list per printed line.
print(ps.filters_indexlist, ps.filters_surveylist, ps.filters_namelist, sep="\n")

### Choose the indices of the filters

In [None]:
#data_mags, data_magserr

In [None]:
np.argwhere(~np.isnan(data_mags)).flatten()

In [None]:
np.argwhere(~np.isnan(data_magserr)).flatten()

In [None]:
NoNaN_mags = np.intersect1d(np.argwhere(~np.isnan(data_mags)).flatten(),np.argwhere(~np.isnan(data_magserr)).flatten())
NoNaN_mags

In [None]:
index_selected_filters = NoNaN_mags
index_selected_filters

In [None]:
XF = ps.get_2lists()
NF = len(XF[0])

In [None]:
list_wls_f_sel = []
list_trans_f_sel = []

list_name_f_sel = []
list_wlmean_f_sel = []

for index in index_selected_filters:
    list_wls_f_sel.append(XF[0][index])
    list_trans_f_sel.append(XF[1][index])
    the_filt = ps.filters_transmissionlist[index]
    the_wlmean = the_filt.wave_mean
    list_wlmean_f_sel.append(the_wlmean)
    list_name_f_sel.append(ps.filters_namelist[index])
    
list_wlmean_f_sel = jnp.array(list_wlmean_f_sel) 

In [None]:
Xf_sel = (list_wls_f_sel,list_trans_f_sel)

In [None]:
#mags_predictions = jax.tree_map(lambda x,y : calc_obs_mag(ssp_data.ssp_wave, sed_attenuated,x,y,z_obs, *DEFAULT_COSMOLOGY),list_wls_f_sel,list_trans_f_sel)

In [None]:
data_selected_mags =  jnp.array(data_mags[index_selected_filters])
data_selected_magserr = jnp.array(data_magserr[index_selected_filters])

In [None]:
data_selected_mags

In [None]:
data_selected_magserr

## Fits

### Fit with magnitudes only

The magnitudes associated with the Fors2 spectrum set the SED mass scale, and thus the
flux scale. The following code fits the photometry by calling the jaxopt optimiser.

In [None]:

lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik_mag, method="L-BFGS-B")

res_m = lbfgsb.run(init_params, bounds=(params_min ,params_max ), xf = Xf_sel, mags_measured = data_selected_mags, sigma_mag_obs = data_selected_magserr,z_obs=z_obs)
params_m,fun_min_m,jacob_min_m,inv_hessian_min_m = get_infos_mag(res_m, lik_mag,  xf = Xf_sel, mgs = data_selected_mags, mgse = data_selected_magserr,z_obs=z_obs)
print("params:",params_m,"\nfun@min:",fun_min_m,"\njacob@min:",jacob_min_m)
#      ,\n invH@min:",inv_hessian_min_m)

In [None]:
# Convert fitted parameters into a dictionnary
dict_params_m = paramslist_to_dict( params_m,p.PARAM_NAMES_FLAT) 

In [None]:
#mfluxes = vmap(lambda x : jnp.power(10.,-0.4*x), in_axes=0)(data_selected_mags)
#emfluxes = vmap(lambda x,y : jnp.power(10.,-0.4*x)*y)(data_selected_mags, data_selected_magserr)

In [None]:
xphot_rest,yphot_rest,eyphot_rest,factor = rescale_photometry(dict_params_m,list_wlmean_f_sel,data_selected_mags,data_selected_magserr,z_obs)

In [None]:
plot_fit_ssp_photometry(dict_params_m,list_wlmean_f_sel,data_selected_mags,data_selected_magserr,z_obs, subtit = title_spec ,ax=None)

### Fit with Spectrum only

The Fors2 spectra are not flux-calibrated. They have to be rescaled in amplitude
to the SED model previously fitted to the photometry.

#### rescale spectroscopic data

In [None]:
# Rescale the uncalibrated rest-frame spectrum (and its errors) using the
# photometry-only fit dict_params_m; the returned factor is not used later.
(Xspec_data_rest, Yspec_data_rest, EYspec_data_rest, factor) = rescale_spectroscopy(
    dict_params_m, Xspec_data, Yspec_data, EYspec_data, z_obs
)

#### fit spectroscopic data alone

In [None]:
# Bounded L-BFGS-B fit of the spectroscopic likelihood on the rescaled rest-frame spectrum.
lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik_spec, method="L-BFGS-B")
res_s = lbfgsb.run(init_params, bounds=(params_min ,params_max ), wls=Xspec_data_rest, F=Yspec_data_rest, sigma_obs = EYspec_data_rest,z_obs=z_obs)
# Diagnostics are evaluated on the SAME rescaled data that was minimised; using the
# un-rescaled Xspec_data/Yspec_data/EYspec_data here would report a fun@min and a
# jacobian that do not belong to the fitted likelihood.
params_s,fun_min_s,jacob_min_s,inv_hessian_min_s = get_infos_spec(res_s, lik_spec, wls=Xspec_data_rest, F=Yspec_data_rest,eF=EYspec_data_rest,z_obs=z_obs)
print("params:",params_s,"\nfun@min:",fun_min_s,"\njacob@min:",jacob_min_s)
#,     "\n invH@min:",inv_hessian_min_s)


#### Convert the parameters fitted on the spectroscopic data into a dictionary

In [None]:
# Convert the spectroscopy-only fit parameters into a dictionary (names from p.PARAM_NAMES_FLAT)
dict_params_s = paramslist_to_dict( params_s,p.PARAM_NAMES_FLAT) 

#### plot the SED models and the spectroscopic data

In [None]:
# Plot the SED model of the spectroscopy-only fit over the rescaled rest-frame spectrum.
plot_fit_ssp_spectroscopy(dict_params_s,Xspec_data_rest,Yspec_data_rest,EYspec_data_rest,z_obs,subtit = title_spec)

### Fit by combining Fors2 Spectrum and Photometry

- Combine Fors2 data and Photometric data.
- Both are properly rescaled and in rest frame

In [None]:
# Value passed as `weight` to lik_comb; presumably the relative weight of the
# spectroscopic term against the photometric one — confirm in fitter_jaxopt.
weight_spec = 0.5
# Parallel lists: element 0 = rescaled rest-frame spectrum, element 1 = selected photometry.
Xc = [Xspec_data_rest, Xf_sel]
Yc = [Yspec_data_rest, data_selected_mags]
EYc = [EYspec_data_rest, data_selected_magserr]

#### Do the combined fit

In [None]:
# Keyword arguments shared by the optimiser and by the diagnostics, so that both
# see exactly the same data, uncertainties, redshift and spectral weight.
comb_kwargs = dict(xc=Xc, datac=Yc, sigmac=EYc, z_obs=z_obs, weight=weight_spec)

# Bounded L-BFGS-B minimisation of the combined (spectrum + photometry) likelihood.
lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik_comb, method="L-BFGS-B")
res_c = lbfgsb.run(init_params, bounds=(params_min, params_max), **comb_kwargs)
params_c, fun_min_c, jacob_min_c, inv_hessian_min_c = get_infos_comb(res_c, lik_comb, **comb_kwargs)
print("params:", params_c, "\nfun@min:", fun_min_c, "\njacob@min:", jacob_min_c)


##### check the value of the chi2 for the photometric part

In [None]:
# Evaluate the photometric likelihood alone at the combined-fit result res_c,
# to see the contribution of the magnitudes to the combined fit.
# NOTE(review): assumes get_infos_mag only reads the optimum from res_c — confirm.
params_cm,fun_min_cm,jacob_min_cm,inv_hessian_min_cm  = get_infos_mag(res_c, lik_mag,  xf = Xf_sel, mgs = data_selected_mags, mgse = data_selected_magserr,z_obs=z_obs)
print("params:",params_cm,"\nfun@min:",fun_min_cm,"\njacob@min:",jacob_min_cm)

##### check the value of the chi2 for the spectroscopic part

In [None]:
# Evaluate the spectroscopic likelihood alone at the combined-fit result res_c,
# on the same rescaled rest-frame spectrum used in the combined fit.
# NOTE(review): assumes get_infos_spec only reads the optimum from res_c — confirm.
params_cs,fun_min_cs,jacob_min_cs,inv_hessian_min_cs = get_infos_spec(res_c, lik_spec, wls=Xspec_data_rest, F=Yspec_data_rest,eF=EYspec_data_rest,z_obs=z_obs)
print("params:",params_cs,"\nfun@min:",fun_min_cs,"\njacob@min:",jacob_min_cs)

#### Convert fitted parameters into a dictionary

In [None]:
# Convert the combined-fit parameters into a dictionary (names from p.PARAM_NAMES_FLAT); this is what gets saved below.
dict_params_c = paramslist_to_dict( params_c,p.PARAM_NAMES_FLAT) 

#### Plot combined fit

In [None]:
# Plot the combined-fit SED model with the rescaled spectrum and the rest-frame photometric points.
plot_fit_ssp_spectrophotometry(dict_params_c ,Xspec_data_rest,Yspec_data_rest,EYspec_data_rest,xphot_rest,yphot_rest,eyphot_rest,z_obs=z_obs,subtit = title_spec )

## Add StarLight model for comparison

### Get StarLight spectrum

In [None]:
# StarLight model spectrum for the same tag; its "wl" and "fnu" entries are used below.
dict_sl = sl.getspectrum_fromgroup(selected_spectrum_tag)

### Rescale Starlight spectrum

In [None]:
# Rescale the StarLight model onto the rescaled Fors2 spectrum (presumably over the Fors2
# wavelength range, as the function name suggests); the third returned value is discarded.
w_sl ,fnu_sl , _ = rescale_starlight_inrangefors2(dict_sl["wl"],dict_sl["fnu"],Xspec_data_rest,Yspec_data_rest )

### Plot all data and models

In [None]:
# Same combined-fit plot, with the rescaled StarLight model (w_sl, fnu_sl) overlaid for comparison.
plot_fit_ssp_spectrophotometry_sl(dict_params_c ,Xspec_data_rest,Yspec_data_rest,EYspec_data_rest,xphot_rest,yphot_rest,eyphot_rest,w_sl,fnu_sl,z_obs=z_obs,subtit = title_spec )

## Plot the SFH model

In [None]:
# Star-formation history (SFH) implied by the combined-fit parameters.
plot_SFH(dict_params_c,z_obs,subtit = title_spec , ax=None)

## save fitted data

In [None]:
# One output file per spectrum, written relative to the current working directory.
filename_params = f"fitparams_{selected_spectrum_tag}.pickle"

In [None]:
# Persist the combined-fit parameter dictionary.
with open(filename_params, 'wb') as f:
    pickle.dump(dict_params_c, f)

In [None]:
# Reload example. Only unpickle files from a trusted source: pickle.load can execute arbitrary code.
#with open(filename_params, 'rb') as f:
#    loaded_dict = pickle.load(f)