# Test: creating the true values of the parameters

In [None]:

import h5py
import pandas as pd
import numpy as np
import os
import re
import pickle 
import collections
from collections import OrderedDict
import copy

import jax
import jax.numpy as jnp
from jax import vmap
import jaxopt
import optax

from interpax import interp1d

from jax.lax import fori_loop
from jax.lax import select,cond
from jax.lax import concatenate

import numpyro
from numpyro import optim
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, HMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal, AutoMultivariateNormal
from numpyro.infer.reparam import NeuTraReparam
from numpyro.handlers import seed, trace, condition

from fors2tostellarpopsynthesis.parameters  import SSPParametersFit,paramslist_to_dict
from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (SSP_DATA,mean_spectrum,mean_mags,mean_sfr,ssp_spectrum_fromparam)
from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(PARAM_SIMLAW_NODUST,PARAM_SIMLAW_WITHDUST,
                            PARAM_NAMES,PARAM_VAL,PARAM_MIN,PARAM_MAX,PARAM_SIGMA)

from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(galaxymodel_nodust_av,galaxymodel_nodust,galaxymodel_withdust_av,galaxymodel_withdust)




In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# initialisation: switch JAX to 64-bit (double precision) mode
# NOTE(review): this flag is set after the project modules were imported
# above; JAX recommends enabling x64 at startup — confirm that nothing created
# at import time is expected to be float64.
jax.config.update("jax_enable_x64", True)


# observation setting, passed further down to ssp_spectrum_fromparam
# (presumably the redshift of the simulated galaxy — TODO confirm)
z_obs = 0.5

# initialisation of parameters: this object exposes DICT_PARAMS_true, the
# dictionary the "true" values are copied from in the next cell
p = SSPParametersFit()


In [None]:
# Start from an independent deep copy of the default "true" parameters so
# the overrides below never modify the dictionary owned by `p`.
dict_sel_params_true = copy.deepcopy(p.DICT_PARAMS_true)

# Values chosen for this test, replacing the corresponding defaults.
dict_sel_params_true.update(
    {
        'MAH_lgmO': 10.0,
        'MAH_logtc': 0.8,
        'MAH_early_index': 3.0,
        'MAH_late_index': 0.5,
        'AV': 0.5,
        'UV_BUMP': 2.5,
        'PLAW_SLOPE': -0.1,
    }
)

# Flat list of the selected values, in the dictionary's iteration order.
list_sel_params_true = [*dict_sel_params_true.values()]

In [None]:
# Display the flat list of selected true values (notebook cell output).
list_sel_params_true

In [None]:
# Display the name -> value mapping of the selected true values
# (notebook cell output).
dict_sel_params_true

In [None]:
# Summary table with one row per parameter: the fitter's name, bounds,
# default value and sigma, next to the true value selected above.
# NOTE(review): the "true" column is matched to "name" by position only;
# this assumes the dictionary order equals the order of PARAM_NAMES — confirm.
df_params = pd.DataFrame(
    {
        "name": PARAM_NAMES,
        "min": PARAM_MIN,
        "val": PARAM_VAL,
        "max": PARAM_MAX,
        "sig": PARAM_SIGMA,
        "true": list_sel_params_true,
    }
)

In [None]:
# Keep three decimals so the summary table stays readable, then display it.
df_params = df_params.round(3)
df_params

In [None]:
# generate spectrum from true selected values
# The call returns three arrays: the wavelength grid and two spectra, which
# the plot below labels "no dust" (spec_rest_noatt) and "with dust"
# (spec_rest_att).
# NOTE(review): units and rest/observed frame of the outputs are not visible
# here — check ssp_spectrum_fromparam before relying on them.
wlsall,spec_rest_noatt,spec_rest_att = ssp_spectrum_fromparam(dict_sel_params_true,z_obs)


In [None]:
# Plot the spectra generated from the true parameters on log-log axes,
# without and with dust attenuation.
fig = plt.figure(figsize=(10, 4))
ax = fig.add_subplot()
ax.set_xscale('log')
ax.set_yscale('log')
ax.plot(wlsall, spec_rest_noatt, linewidth=2, color="k", label="no dust")
ax.plot(wlsall, spec_rest_att, linewidth=2, color='r', label="with dust")
# Raw string: "\l" is not a valid Python escape sequence, so a plain string
# literal triggers a DeprecationWarning/SyntaxWarning on recent interpreters.
# NOTE(review): the label states nm — confirm the wavelength unit actually
# returned by ssp_spectrum_fromparam.
ax.set_xlabel(r"$\lambda$ (nm)")
ax.set_ylabel("DSPS SED true")
ax.legend()
# Vertical range: six decades below twice the peak of the unattenuated
# spectrum, so the displayed window follows the overall flux level.
ymax = jnp.max(spec_rest_noatt)*2.
ymin = ymax/1e6
ax.set_ylim(ymin, ymax)
ax.set_xlim(1e2, 1e6)
# Trailing semicolon suppresses the cell's textual output in the notebook.
ax.grid();

