# Fit Single Fors2 Spectrum and Photometry with DSPS

This notebook implements the fit using the `fors2tostellarpopsynthesis` package.

- Author Sylvie Dagoret-Campagne
- Affiliation : IJCLab/IN2P3/CNRS
- Organisation : LSST-DESC
- creation date : 2023-11-21
- last update : 2023-11-21


## Import

In [None]:
import h5py
import pandas as pd
import numpy as np
import os
import re
from astropy.io import fits
from astropy.table import Table
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.colors as colors
import matplotlib.cm as cmx
import collections
from collections import OrderedDict
import re
import matplotlib.gridspec as gridspec
from sklearn.gaussian_process import GaussianProcessRegressor, kernels

In [None]:
import jax
import jax.numpy as jnp
import jaxopt
import optax
# Switch JAX to 64-bit precision; without this flag JAX defaults to 32-bit floats.
jax.config.update("jax_enable_x64", True)

In [None]:
# Global matplotlib defaults for the figures of this notebook:
# default figure size (12, 6), 'xx-large' text for axis labels, titles and
# tick labels, and a legend font size of 20.
plt.rcParams.update({
    "figure.figsize": (12, 6),
    "axes.labelsize": 'xx-large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
    'legend.fontsize': 20,
})

### Filters

In [None]:
from fors2tostellarpopsynthesis.filters import FilterInfo

### Fors2 and Starlight

In [None]:
from fors2tostellarpopsynthesis.fors2starlightio import Fors2DataAcess, SLDataAcess,convert_flux_torestframe,gpr

### fitter jaxopt

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (lik_spec,lik_mag,lik_comb,
get_infos_spec,get_infos_mag,get_infos_comb)

## Check filters

In [None]:
# Build the FilterInfo helper (kept as `ps`; it is reused further down to
# select filters) and call its plot_transmissions() method as a visual check.
ps = FilterInfo()
ps.plot_transmissions()

## Fors2 data

In [None]:
# Accessor object for the FORS2 data; every `fors2.*` call below goes through it.
fors2 = Fors2DataAcess()

In [None]:
# Overview figure produced by the accessor's plot_allspectra() method.
fors2.plot_allspectra()

In [None]:
# Group keys reported by the FORS2 accessor; the cell output is their count.
fors2_tags = fors2.get_list_of_groupkeys()
len(fors2_tags)

In [None]:
# Sub-group keys reported by the FORS2 accessor (kept here as the list of
# FORS2 attributes), printed for reference.
list_of_fors2_attributes = fors2.get_list_subgroup_keys()
print(list_of_fors2_attributes)

## StarLight data

In [None]:
# Accessor object for the StarLight data, the counterpart of `fors2` above.
sl = SLDataAcess()

In [None]:
# Overview figure produced by the StarLight accessor's plot_allspectra() method.
sl.plot_allspectra()

In [None]:
# Group keys reported by the StarLight accessor; the cell output is their count.
sl_tags = sl.get_list_of_groupkeys()
len(sl_tags)

## Select Spectrum

In [None]:
# Number of the spectrum to fit (411 is kept here as an alternative choice).
selected_spectrum_number = 560
# Group tag of that spectrum: the literal prefix "SPEC" followed by the number.
selected_spectrum_tag = "SPEC" + str(selected_spectrum_number)

In [None]:
# Attribute record of the selected spectrum; its 'redshift' entry is the
# redshift passed to the rest-frame conversion and to the fits below.
fors2_attr =fors2.getattribdata_fromgroup(selected_spectrum_tag)
z_obs = fors2_attr['redshift']

### Get magnitude data

In [None]:
# Photometric magnitudes and their errors for the selected spectrum.
# Both are later indexed with `index_selected_filters`.
data_mags, data_magserr = fors2.get_photmagnitudes(selected_spectrum_tag)

### Get Fors2 spectrum in rest-frame

In [None]:
# Spectrum of the selected object as returned by
# getspectrumcleanedemissionlines_fromgroup(); it is read below through its
# 'wl', 'fnu' and 'bg' entries.
spec_obs = fors2.getspectrumcleanedemissionlines_fromgroup(selected_spectrum_tag)

In [None]:
# Display the returned spectrum container as the cell output.
spec_obs

In [None]:
# Short aliases for the spectrum columns: 'wl', 'fnu', and 'bg'
# (the latter is used as the flux uncertainty further down).
Xs, Ys, EYs = spec_obs['wl'], spec_obs['fnu'], spec_obs['bg']

In [None]:
# Lower and upper edges of the fnu -/+ bg band that is shaded in the next plot.
flmin, flmax = (
    spec_obs['fnu'] - spec_obs['bg'],
    spec_obs['fnu'] + spec_obs['bg'],
)

In [None]:
# Quick look at the spectrum before any conversion: 'fnu' against 'wl' as a
# solid blue line, with the [flmin, flmax] band shaded in light grey and no
# outline drawn around it.
_, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
ax.plot(spec_obs['wl'], spec_obs['fnu'], '-b')
ax.fill_between(
    spec_obs['wl'],
    flmin,
    flmax,
    facecolor='lightgrey',
    edgecolors="None",
)

In [None]:
# Convert the spectrum (wavelengths and fluxes) with convert_flux_torestframe,
# using the redshift z_obs of the selected object.
Xspec_data, Yspec_data = convert_flux_torestframe(Xs,Ys,z_obs)
# The flux errors are scaled by (1+z_obs) by hand.
# NOTE(review): this presumes convert_flux_torestframe applies the same
# (1+z) factor to the fluxes — confirm against its implementation.
EYspec_data = EYs*(1+z_obs)

In [None]:
# Figure produced by the accessor's plot_spectro_photom_rescaling() method
# for the selected spectrum.
fors2.plot_spectro_photom_rescaling(selected_spectrum_tag)

## Parameters

In [None]:
from fors2tostellarpopsynthesis.parameters import SSPParametersFit

In [None]:
# Parameter container of the fit; the initial values and bounds used by the
# minimisers below are read from it.
p = SSPParametersFit()

In [None]:
# Print the textual representation of the parameter container.
print(p)

In [None]:
# Display the container's DICT_PARAMS_true attribute as the cell output.
p.DICT_PARAMS_true

In [None]:
# Starting point and lower/upper bounds handed to the bounded minimisers below,
# all taken from the parameter container.
init_params, params_min, params_max = p.INIT_PARAMS, p.PARAMS_MIN, p.PARAMS_MAX

## Select filters

In [None]:
# Show what the filter helper offers, one list per output line:
# its index list, its survey list and its name list.
print(
    ps.filters_indexlist,
    ps.filters_surveylist,
    ps.filters_namelist,
    sep="\n",
)

### Choose the index of the filters

In [None]:
# Keep filter indices 1 to 10 inclusive (the stop value 11 is exclusive);
# index 0 is therefore not selected. The bare name displays the array.
index_selected_filters = np.arange(1,11)
index_selected_filters

In [None]:
# Result of get_2lists(), indexed per filter below: XF[0] supplies the
# wavelength grids and XF[1] the transmissions.
XF = ps.get_2lists()
# Number of entries in the first of the two sequences.
NF = len(XF[0])

In [None]:
# Gather, for the chosen filter indices only, the wavelength grids, the
# transmissions and the filter names.
list_wls_f_sel = [XF[0][idx] for idx in index_selected_filters]
list_trans_f_sel = [XF[1][idx] for idx in index_selected_filters]
list_name_f_sel = [ps.filters_namelist[idx] for idx in index_selected_filters]

# `wave_mean` of each selected filter's transmission object, as a JAX array.
list_wlmean_f_sel = jnp.array(
    [ps.filters_transmissionlist[idx].wave_mean for idx in index_selected_filters]
)

In [None]:
# Pair (wavelength grids, transmissions) of the selected filters, in the form
# passed as `xf` to the magnitude fit below.
Xf_sel = (list_wls_f_sel,list_trans_f_sel)

In [None]:
#mags_predictions = jax.tree_map(lambda x,y : calc_obs_mag(ssp_data.ssp_wave, sed_attenuated,x,y,z_obs, *DEFAULT_COSMOLOGY),list_wls_f_sel,list_trans_f_sel)

In [None]:
# Restrict the measured magnitudes and their errors to the chosen filters and
# turn both into JAX arrays.
data_selected_mags, data_selected_magserr = (
    jnp.array(data_mags[index_selected_filters]),
    jnp.array(data_magserr[index_selected_filters]),
)

In [None]:
# Display the selected magnitudes as the cell output.
data_selected_mags

## Fits

## Fit with magnitudes only

In [None]:

# Bounded L-BFGS-B minimiser wrapped around the magnitude-only function `lik_mag`.
lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik_mag, method="L-BFGS-B")

# Start from `init_params`, constrained to [params_min, params_max]; the
# remaining keyword arguments are the data given to the objective function.
res_m = lbfgsb.run(
    init_params,
    bounds=(params_min, params_max),
    xf=Xf_sel,
    mags_measured=data_selected_mags,
    sigma_mag_obs=data_selected_magserr,
    z_obs=z_obs,
)

# Extract the fitted parameters plus the function value, Jacobian and inverse
# Hessian at the minimum from the minimiser result.
params_m, fun_min_m, jacob_min_m, inv_hessian_min_m = get_infos_mag(
    res_m,
    lik_mag,
    xf=Xf_sel,
    mgs=data_selected_mags,
    mgse=data_selected_magserr,
    z_obs=z_obs,
)

# The inverse Hessian is returned above but deliberately left out of the printout.
print("params:",params_m,"\nfun@min:",fun_min_m,"\njacob@min:",jacob_min_m)

## Fit with Spectrum only

In [None]:
# Bounded L-BFGS-B minimiser wrapped around the spectrum-only function `lik_spec`.
lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik_spec, method="L-BFGS-B")

# Start from `init_params`, constrained to [params_min, params_max]; the data
# are the converted spectrum arrays (wavelengths, fluxes, flux errors).
res_s = lbfgsb.run(
    init_params,
    bounds=(params_min, params_max),
    wls=Xspec_data,
    F=Yspec_data,
    sigma_obs=EYspec_data,
    z_obs=z_obs,
)

# Extract the fitted parameters plus the function value, Jacobian and inverse
# Hessian at the minimum from the minimiser result.
params_s, fun_min_s, jacob_min_s, inv_hessian_min_s = get_infos_spec(
    res_s,
    lik_spec,
    wls=Xspec_data,
    F=Yspec_data,
    eF=EYspec_data,
    z_obs=z_obs,
)

# The inverse Hessian is returned above but deliberately left out of the printout.
print("params:",params_s,"\nfun@min:",fun_min_s,"\njacob@min:",jacob_min_s)
