In [None]:
import numpy as np

import jax.numpy as jnp
import jax.scipy as jsc
from jax import jit, vmap
import jax

import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cmx
import matplotlib.offsetbox
%matplotlib inline
# Global matplotlib defaults for this notebook: larger figures and fonts.
params = {
    'legend.fontsize': 'large',
    'figure.figsize': (12, 8),
    'axes.labelsize': 'xx-large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
}
plt.rcParams.update(params)
# Box style (round, white, half-transparent), usable as a matplotlib `bbox=` argument.
props = {'boxstyle': 'round', 'edgecolor': "w", 'facecolor': "w", 'alpha': 0.5}

import h5py

import copy
import sys, os

from interpax import interp1d

# Switch JAX to double precision (float64). This must run before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

from dsps import (calc_obs_mag, calc_rest_mag,
                  calc_rest_sed_sfh_table_lognormal_mdf,
                  calc_rest_sed_sfh_table_met_table, load_ssp_templates)
from dsps.cosmology import DEFAULT_COSMOLOGY, age_at_z
from dsps.dust.att_curves import (RV_C00, _frac_transmission_from_k_lambda,
                                  sbl18_k_lambda)

from fors2tostellarpopsynthesis.parameters import (SSPParametersFit,\
                                                   SSPParametersFitAgeDepMet,\
                                                   SSPParametersFit_AgeDepMet_Q,\
                                                   paramslist_to_dict)

In [None]:
# Directory holding the SSP template files. The path is resolved relative to the
# notebook's current working directory.
data_dir = os.path.abspath(os.path.join('../../../', 'src', 'fors2tostellarpopsynthesis', 'fitters', 'data'))

# One template file per stellar spectral library.
fn_data_miles = 'tempdata.h5'
fn_data_c3k = 'test_fspsData_v3_2_C3K.h5'
fn_data_basel = 'test_fspsData_v3_2_BASEL.h5'

path_data_miles, path_data_c3k, path_data_basel = (
    os.path.join(data_dir, _fn) for _fn in (fn_data_miles, fn_data_c3k, fn_data_basel)
)

# Load the templates in the order MILES, C3K, BASEL, keyed by library name.
SSP_DATA = {
    _library: load_ssp_templates(fn=_path)
    for _library, _path in (
        ("MILES", path_data_miles),
        ("C3K", path_data_c3k),
        ("BASEL", path_data_basel),
    )
}

In [None]:
# Display the loaded MILES template container. Later cells read its fields
# ssp_lgmet, ssp_lg_age_gyr, ssp_wave and ssp_flux.
SSP_DATA["MILES"]

In [None]:
# Quick look at the MILES templates: for every metallicity, plot the spectrum at the
# last index of the age axis. ssp_flux is indexed as (metallicity, age, wavelength),
# as in the per-metallicity loops further down.
fig, ax = plt.subplots()
for _fl in SSP_DATA["MILES"].ssp_flux[:, -1]:
    ax.plot(SSP_DATA["MILES"].ssp_wave, _fl)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Wavelength [Å]')
ax.set_ylabel('SSP flux')
ax.set_title('MILES: last age bin, all metallicities');

In [None]:
def _age_color_mapping(ssp_templates, cmap_name, vmin_offset=0.2):
    """Assign one RGBA colour of a matplotlib colormap to each SSP age.

    Parameters
    ----------
    ssp_templates : object
        Template set exposing an ``ssp_lg_age_gyr`` array, as returned by
        ``load_ssp_templates``.
    cmap_name : str
        Name of a matplotlib colormap (e.g. 'Blues').
    vmin_offset : float, optional
        Subtracted from the smallest log-age to set the lower normalisation bound.
        This keeps the youngest age away from the lightest end of the colormap.

    Returns
    -------
    tuple
        ``(cmap, norm, scalar_map, rgba)``. ``scalar_map`` can be passed to
        ``plt.colorbar``. ``rgba`` holds one colour per entry of ``ssp_lg_age_gyr``.
    """
    lg_age = ssp_templates.ssp_lg_age_gyr
    cmap = plt.get_cmap(cmap_name)
    norm = colors.Normalize(vmin=min(lg_age - vmin_offset), vmax=max(lg_age))
    scalar_map = cmx.ScalarMappable(norm=norm, cmap=cmap)
    return cmap, norm, scalar_map, scalar_map.to_rgba(lg_age, alpha=1)


# One colour sequence per spectral library. Individual names are kept for later cells.
miles_cmp, miles_cNorm, miles_scalarMap, miles_colors = _age_color_mapping(SSP_DATA["MILES"], 'Blues')
c3k_cmp, c3k_cNorm, c3k_scalarMap, c3k_colors = _age_color_mapping(SSP_DATA["C3K"], 'Oranges')
basel_cmp, basel_cNorm, basel_scalarMap, basel_colors = _age_color_mapping(SSP_DATA["BASEL"], 'Greens')

In [None]:
# Compare the MILES, BASEL and C3K spectral libraries: one figure per metallicity,
# with one curve per age, coloured by that library's age colormap.
# NOTE(review): `Zid` comes from the MILES metallicity grid and is reused to index
# BASEL and C3K. This assumes all three libraries share the same metallicity grid.
spectral_libraries = [("MILES", miles_colors), ("BASEL", basel_colors), ("C3K", c3k_colors)]

for Zid, lg_met in enumerate(SSP_DATA["MILES"].ssp_lgmet):
    f, a = plt.subplots(1, 1)
    for _lib_name, _lib_colors in spectral_libraries:
        _ssp = SSP_DATA[_lib_name]
        # The plotted wavelength window does not depend on age, so build the mask once.
        _sel = (_ssp.ssp_wave > 1000.) * (_ssp.ssp_wave < 11000.)
        _n_ages = _ssp.ssp_flux.shape[1]
        for ageId, _fl in enumerate(_ssp.ssp_flux[Zid, :]):
            # Label only the last age so the legend has a single entry per library.
            lab = _lib_name if ageId == _n_ages - 1 else ""
            a.plot(_ssp.ssp_wave[_sel], _fl[_sel], c=_lib_colors[ageId], label=lab, alpha=0.3)

    a.set_yscale('log')
    a.set_xscale('log')
    a.set_xlabel('Wavelength [Å]')
    a.set_title(f'lgmet = {lg_met:.3f}')
    a.legend()
    plt.colorbar(miles_scalarMap, location='left', label='Log-age[GYr]', ax=a)

In [None]:
# Same MILES spectral library computed with two isochrone sets.
# NOTE(review): "MIST+MILES" re-reads path_data_miles, the same file already loaded
# as SSP_DATA["MILES"].
fn_data_padova_miles = 'fspsData_v3_2_PADOVA_MILES.h5'
path_data_padova_miles = os.path.join(data_dir, fn_data_padova_miles)
SSP_DATA_ISOCHRONES = {
    "MIST+MILES": load_ssp_templates(fn=path_data_miles),
    "PADOVA+MILES": load_ssp_templates(fn=path_data_padova_miles),
}

# Age colours for PADOVA+MILES ('Reds'). The lower bound is shifted by 0.2 in log-age
# so the youngest age does not get the lightest colour.
_padova_lg_age = SSP_DATA_ISOCHRONES["PADOVA+MILES"].ssp_lg_age_gyr
padova_miles_cmp = plt.get_cmap('Reds')
padova_miles_cNorm = colors.Normalize(
    vmin=min(_padova_lg_age - 0.2),
    vmax=max(_padova_lg_age),
)
padova_miles_scalarMap = cmx.ScalarMappable(norm=padova_miles_cNorm, cmap=padova_miles_cmp)
padova_miles_colors = padova_miles_scalarMap.to_rgba(_padova_lg_age, alpha=1)

In [None]:
# Compare the MIST and PADOVA isochrone sets for the MILES library: one figure per
# metallicity, with one curve per age.
# MIST+MILES curves reuse miles_colors, which was built from SSP_DATA["MILES"]
# (loaded from the same file).
# NOTE(review): `Zid` comes from the MIST+MILES metallicity grid and is reused for
# PADOVA+MILES. This assumes both files share the same metallicity grid.
isochrone_sets = [("MIST+MILES", miles_colors), ("PADOVA+MILES", padova_miles_colors)]

for Zid, lg_met in enumerate(SSP_DATA_ISOCHRONES["MIST+MILES"].ssp_lgmet):
    f, a = plt.subplots(1, 1)
    for _iso_name, _iso_colors in isochrone_sets:
        _ssp = SSP_DATA_ISOCHRONES[_iso_name]
        # The plotted wavelength window does not depend on age, so build the mask once.
        _sel = (_ssp.ssp_wave > 1000.) * (_ssp.ssp_wave < 11000.)
        _n_ages = _ssp.ssp_flux.shape[1]
        for ageId, _fl in enumerate(_ssp.ssp_flux[Zid, :]):
            # Label only the last age so the legend has a single entry per isochrone set.
            lab = _iso_name if ageId == _n_ages - 1 else ""
            a.plot(_ssp.ssp_wave[_sel], _fl[_sel], c=_iso_colors[ageId], label=lab, alpha=0.3)

    a.set_yscale('log')
    a.set_xscale('log')
    a.set_xlabel('Wavelength [Å]')
    a.set_title(f'lgmet = {lg_met:.3f}')
    a.legend()
    plt.colorbar(miles_scalarMap, location='left', label='Log-age[GYr]', ax=a)