# Loop over Gaia spectra to find, for each one, its nearest-neighbour Pickles SED, using magnitudes computed from the Gaia, Calspec (HDF5 file) and Pickles spectra

- author Sylvie Dagoret-Campagne
- affiliation IJCLab
- creation date : 2024/10/06
- update : 2024/10/07 : 

The goal is to find the nearest Pickles SED to a Gaia spectrum. Each Pickles spectrum is first renormalized so that its Z-band magnitude matches that of the Gaia spectrum. The nearest neighbour is then found by matching the magnitudes in G, R, I and Z (the Z magnitudes agree by construction).

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib as mpl
import matplotlib.colors as colors
import matplotlib.cm as cmx
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LogNorm
from matplotlib.gridspec import GridSpec
import pandas as pd

import matplotlib.ticker                         # here's where the formatter is
import os,sys
import re
import pandas as pd

from astropy.io import fits
from astropy import units as u
from astropy import constants as c

plt.rcParams["figure.figsize"] = (8,6)
plt.rcParams["axes.labelsize"] = 'xx-large'
plt.rcParams['axes.titlesize'] = 'xx-large'
plt.rcParams['xtick.labelsize']= 'xx-large'
plt.rcParams['ytick.labelsize']= 'xx-large'

from scipy.interpolate import interp1d
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree, BallTree
import h5py

In [None]:
pd.set_option('display.max_columns', 500)

In [None]:
machine_name = os.uname().nodename
path_rubinsimphot = "repos/repos_w_2024_38/rubinsimphot/src"

# Host-name patterns for which rubinsimphot is taken from a checkout under
# $HOME, paired with the message announcing the detected environment.
# Patterns are tried in order and the first match wins.
_hosts_with_local_checkout = (
    ("sdf", "Set environment for USDF"),
    ("dagoret-nb", "Set environment for USDF Rubin Science Platform"),
)

for _pattern, _message in _hosts_with_local_checkout:
    if _pattern in machine_name:
        print(_message)
        newpythonpath = os.path.join(os.getenv("HOME"),path_rubinsimphot)
        sys.path.append(newpythonpath)
        break
else:
    # No checkout to add to sys.path: only tell the user what to check
    if 'mac' in machine_name:
        print("Be sure to run this notebook in conda environment named conda_py310")
    else:
        print(f"Your current machine name is {machine_name}. Check your python environment")

In [None]:
# reference flux in Jy: the flux density obtained by converting an
# AB magnitude of 0 to Jansky with astropy units
F0 = ((0.*u.ABmag).to(u.Jy)).value
# bare expression: displays the value in the notebook
F0

## Config 

In [None]:
# Plot switches: each flag enables the optional figure(s) of one section below
FLAG_PLOTSPECTRA = False  # Gaia/Calspec spectra (single star and all stars)
FLAG_PLOTATMOSPHERETRANSM = False  # standard atmosphere transmission
FLAG_PLOTDETECTORTRANSM = False  # instrument throughput, one curve per filter
FLAG_PLOTTOTALTRANSM = False  # total filter throughput, one curve per filter
FLAG_PLOTPICKLESSED = False  # two-panel overview of the Pickles SEDs

In [None]:
# Location of the HDF5 file holding the Gaia and Calspec spectra
# (the directory is relative to the current working directory)
input_path = "data_gaiacalspecspectra"
input_file_h5  = 'GAIACALSPECspectra.hdf5'
input_fullfile_h5 = os.path.join(input_path,input_file_h5)

## Read file spectra

In [None]:
# Open the HDF5 file read-only. The handle stays open because the cells below
# keep reading groups and attributes from `hf`.
# NOTE(review): `hf` is never closed in the visible part of the notebook.
hf =  h5py.File(input_fullfile_h5, 'r') 
# top-level keys of the file; each one is later fetched with hf.get(key)
list_of_keys = list(hf.keys())

In [None]:
list_of_keys

In [None]:
# pick one key (the first one) as a sample    
key_sel =  list_of_keys[0]
# pick the corresponding group; its attributes are inspected below to
# discover the attribute names
group = hf.get(key_sel)  

In [None]:
# Names of all the attributes attached to the sample group
all_subgroup_keys = list(group.attrs.keys())

In [None]:
all_subgroup_keys

In [None]:
def GetColumnHfData(hff,list_of_keys,nameval):
    """
    Collect one attribute across several groups of an h5 file.

    parameters
      hff          : descriptor of the h5 file; ``hff.get(key)`` must return
                     an object exposing an ``attrs`` mapping
      list_of_keys : keys of the groups to visit
      nameval      : name of the attribute to read in each group

    return
      list holding the value of attribute ``nameval`` for every key,
      in the same order as ``list_of_keys``
    """
    return [hff.get(group_key).attrs[nameval] for group_key in list_of_keys]

In [None]:
# Info table: one column per attribute name, one row per key of the file
# (rows follow the order of list_of_keys)
df_info = pd.DataFrame(
    {attr_name: GetColumnHfData(hf, list_of_keys, attr_name) for attr_name in all_subgroup_keys}
)

In [None]:
df_info

In [None]:
NSPEC = len(df_info)

## Extract the spectra

In [None]:
def _spectrum_to_dataframe(h5group, prefix):
    """
    Gather the datasets <prefix>WAVELENGTH, <prefix>FLUX, <prefix>STATERROR
    and <prefix>SYSERROR of one h5 group into a dataframe whose columns
    carry the same names without the prefix.
    """
    frame = pd.DataFrame()
    for column in ("WAVELENGTH", "FLUX", "STATERROR", "SYSERROR"):
        frame[column] = np.array(h5group.get(prefix + column))
    return frame


# one dataframe per key: Gaia spectra ("g" datasets) and Calspec spectra ("c" datasets)
all_dfg = []
all_dfc = []

# idx counts the groups processed so far
idx = 0
for idx, key in enumerate(list_of_keys, start=1):
    group = hf.get(key)
    dfg = _spectrum_to_dataframe(group, "g")
    dfc = _spectrum_to_dataframe(group, "c")
    all_dfg.append(dfg)
    all_dfc.append(dfc)

## Check by plot

In [None]:
# Select the first star as a sample for the check plot
index = 0
row = df_info.iloc[index]
hdname = row["HD_name"]
gaianame = row["GAIA_ED3_Name"]
# tag used in plot titles: "<HD name>_<Gaia name>"
tag = f"{hdname}_{gaianame}"
# Gaia and Calspec spectra of that star
dfg = all_dfg[index]
dfc = all_dfc[index]

In [None]:
def plotspec(tag,dfc,dfg):
    """
    Plot the Calspec and Gaia spectra of one star on a common axis.

    tag : name of the star to appear in the title
    dfc : dataframe for calspec (columns WAVELENGTH and FLUX are used)
    dfg : dataframe for gaia (columns WAVELENGTH and FLUX are used)

    The x range is fixed to 300-1100 and the y range runs from 0 to 1.2 times
    the largest Calspec flux found inside that wavelength window.
    """

    fig, ax = plt.subplots(1,1,figsize=(8,5))
    title = "calspec-gaia : " + tag
    # pandas' own legend is disabled; a single legend is drawn below
    dfc.plot(x="WAVELENGTH",y="FLUX",ax=ax,marker='.',color='b',legend=False,label="calspec")
    dfg.plot(x="WAVELENGTH",y="FLUX",ax=ax,marker='.',color='r',legend=False,label="gaia")
    ax.set_xlim(300.,1100.)

    # A boolean mask selects the rows whatever the dataframe index is
    # (positional indices from np.where are only valid labels for a default index)
    in_window = (dfc.WAVELENGTH > 300.) & (dfc.WAVELENGTH < 1100.)
    flmax = dfc.loc[in_window, "FLUX"].max()*1.2
    # An empty or all-NaN window gives NaN, which set_ylim rejects:
    # in that case keep the automatic y range
    if np.isfinite(flmax) and flmax > 0.:
        ax.set_ylim(0.,flmax)

    ax.legend()
    ax.set_xlabel("$\\lambda$ (nm)")
    ax.set_ylabel("Flux erg/cm$^2$/s/nm ")
    ax.set_title(title)
    plt.show()

In [None]:
if FLAG_PLOTSPECTRA: 
    plotspec(tag,dfc,dfg)

## Plot all spectra

In [None]:
if FLAG_PLOTSPECTRA:
    # one figure per star: Calspec and Gaia spectra overlaid
    for index in range(NSPEC):
        row = df_info.iloc[index]

        # title built from the HD and Gaia names of the star
        hdname, gaianame = row["HD_name"], row["GAIA_ED3_Name"]
        tag = f"{hdname}_{gaianame}"

        # spectra of the same star
        dfg, dfc = all_dfg[index], all_dfc[index]

        plotspec(tag,dfc,dfg)
    

## Atmospheric emulator

In [None]:
from importlib.metadata import version
the_ver = version('getObsAtmo')
print(f"Version of getObsAtmo : {the_ver}")

In [None]:
from getObsAtmo import ObsAtmo
# Atmospheric emulator instantiated for the "AUXTEL" configuration.
# NOTE(review): `emul` is not referenced again in the visible part of the notebook.
emul = ObsAtmo("AUXTEL")

## Process transmission

In [None]:
import sys
# Make the local helper library importable; the path is relative to the
# current working directory
sys.path.append('../lib')
#import libAtmosphericFit

In [None]:
# This package encapsulate the calculation on calibration used in this nb
from libPhotometricCorrections import *

In [None]:
def set_photometric_parameters(exptime, nexp, readnoise=None):
    """
    Build a PhotometricParameters object from the exposure settings.

    exptime   : forwarded as ``exptime``
    nexp      : forwarded as ``nexp``
    readnoise : forwarded as ``readnoise``, in electrons/pixel; None lets
                PhotometricParameters use its default
                (8.8 e/pixel according to the original author's note)
    """
    return PhotometricParameters(exptime=exptime, nexp=nexp, readnoise=readnoise)

In [None]:
def scale_sed(ref_mag, ref_filter, sed):
    """
    Rescale ``sed`` in place so that its magnitude in ``ref_filter`` is ``ref_mag``.

    ref_mag    : target magnitude
    ref_filter : key of the reference bandpass in ``lsst_std``
    sed        : SED object exposing calc_flux_norm / multiply_flux_norm

    Returns the same ``sed`` object, after rescaling.

    NOTE(review): ``lsst_std`` is not defined in the visible part of this
    notebook; confirm that it is provided elsewhere before calling this.
    """
    sed.multiply_flux_norm(sed.calc_flux_norm(ref_mag, lsst_std[ref_filter]))
    return sed

## library rubin_sim defining LSST parameters, namely for photometric calculations

In [None]:
from rubinsimphot.phot_utils import Bandpass, Sed
from rubinsimphot.data import get_data_dir

### Config of atmosphere

In [None]:
# Parameters of the standard atmosphere handed to PhotometricCorrections
am0 =1.20    # airmass
pwv0 = 3.0  # Precipitable water vapor vertical column depth in mm
oz0 = 300.  # Ozone vertical column depth in Dobson Unit (DU)
ncomp=1     # Number of aerosol components (not passed to PhotometricCorrections below)
tau0= 0.0 # Vertical Aerosol depth (VAOD) 
beta0 = 1.2 # Aerosol Angstrom exponent
# `pc` provides the bandpasses (pc.bandpass_inst, pc.bandpass_total_std) and
# the atmosphere curve (pc.WL, pc.atm_std) used in the rest of the notebook
pc = PhotometricCorrections(am0,pwv0,oz0,tau0,beta0)

In [None]:
# Optional figure: standard atmosphere transmission versus wavelength
if FLAG_PLOTATMOSPHERETRANSM:
    fig, axs = plt.subplots(1,1,figsize=(6,4))
    axs.plot(pc.WL,pc.atm_std,'k-')
    axs.set_xlabel("$\\lambda$ (nm)")
    axs.set_title("Standard atmosphere transmission")
    plt.show()

In [None]:
if FLAG_PLOTDETECTORTRANSM:
    fig, axs = plt.subplots(1,1,figsize=(6,4))
    # one throughput curve per filter, shaded underneath, plus a vertical
    # marker taken from the third column of FILTERWL
    for index,f in enumerate(filter_tagnames):
        band = pc.bandpass_inst[f]
        band_color = filter_color[index]
        axs.plot(band.wavelen,band.sb,color=band_color)
        axs.fill_between(band.wavelen,band.sb,color=band_color,alpha=0.2)
        axs.axvline(FILTERWL[index,2],color=band_color,linestyle="-.")
    axs.set_xlabel("$\\lambda$ (nm)")
    axs.set_title("Instrument throughput (rubin-obs)")
    plt.show()

In [None]:
if FLAG_PLOTTOTALTRANSM:
    fig, axs = plt.subplots(1,1,figsize=(6,4))
    # one total-throughput curve per filter, shaded underneath, plus a
    # vertical marker taken from the third column of FILTERWL
    for index,f in enumerate(filter_tagnames):
        band = pc.bandpass_total_std[f]
        band_color = filter_color[index]
        axs.plot(band.wavelen,band.sb,color=band_color)
        axs.fill_between(band.wavelen,band.sb,color=band_color,alpha=0.2)
        axs.axvline(FILTERWL[index,2],color=band_color,linestyle="-.")
    axs.set_xlabel("$\\lambda$ (nm)")
    axs.set_title("Total filter throughput (rubin-obs)")
    plt.show()


## Convert Gaia-Calspec sed into rubin-sim SED

In [None]:
def _dataframe_to_sed(df_spec, sed_name):
    """
    Wrap the WAVELENGTH (nm) and FLUX (ergs/cm^2/s/nm) columns of a spectrum
    dataframe into a rubin-sim Sed named ``sed_name``.
    """
    wl = df_spec.WAVELENGTH.values
    flambda = df_spec.FLUX.values
    sed = Sed(wavelen=wl, flambda=flambda, name=sed_name)
    # same explicit set_sed call as in the original notebook
    sed.set_sed(wavelen=wl, flambda=flambda, name=sed_name)
    return sed


# parallel lists, one entry per star, in the order of df_info
all_sed_gaia, all_sed_calspec, all_sed_names = [], [], []

for index in range(NSPEC):
    row = df_info.iloc[index]

    # "<HD name>_<Gaia name>" identifies the star in the SED names
    hdname, gaianame = row["HD_name"], row["GAIA_ED3_Name"]
    tag = f"{hdname}_{gaianame}"
    spectype = tag

    # the spectra of this star
    dfg, dfc = all_dfg[index], all_dfc[index]

    the_sed_c = _dataframe_to_sed(dfc, "calspec_" + spectype)
    the_sed_g = _dataframe_to_sed(dfg, "gaiaspec_" + spectype)

    all_sed_gaia.append(the_sed_g)
    all_sed_calspec.append(the_sed_c)
    all_sed_names.append(tag)

In [None]:
#the_sed_g.get_sed_fnu()

## Compute magnitudes 

In [None]:
def _mags_in_all_filters(sed):
    """
    Compute the magnitude of ``sed`` in every filter of filter_tagnames.

    The SED is first resampled on a 1 nm grid spanning the filter wavelength
    range (padded by 1 nm on each side); outside the range covered by the
    spectrum, nearest-neighbour extrapolation repeats the edge flux values.

    Returns a one-row dataframe: row label = SED name, one column per filter.
    """
    # The interpolator depends on the SED only: build it once, not per filter
    finterp = interp1d(sed.wavelen, sed.flambda, kind = 'nearest',fill_value="extrapolate")

    mags = {}
    for f in filter_tagnames:
        bandpass = pc.bandpass_total_std[f]
        WL = np.arange(bandpass.wavelen.min()-1.,bandpass.wavelen.max()+1.,1.)
        sed_extrapolated = Sed(wavelen=WL, flambda=finterp(WL), name=sed.name)
        mags[f] = sed_extrapolated.calc_mag(bandpass)

    return pd.DataFrame(mags, index=[sed.name])


# containers for all magnitudes: one single-row dataframe per star, in the
# order of all_sed_gaia / all_sed_calspec
all_mags_std_gaia = [_mags_in_all_filters(all_sed_gaia[index]) for index in range(NSPEC)]
all_mags_std_calspec = [_mags_in_all_filters(all_sed_calspec[index]) for index in range(NSPEC)]
    

In [None]:
df_maggaia = pd.concat(all_mags_std_gaia)
df_maggaia

In [None]:
# 1 nm wavelength grid spanning the z bandpass, padded by 1 nm on each side.
# It is the same for every star.
WLMIN_f = pc.bandpass_total_std["z"].wavelen.min()
WLMAX_f = pc.bandpass_total_std["z"].wavelen.max()
WL_f = np.arange(WLMIN_f-1.,WLMAX_f+1.,1.)

all_sed_gaia_renorm = []
for the_sed_g, zmag in zip(all_sed_gaia[:NSPEC], df_maggaia["z"].values[:NSPEC]):

    # resample the Gaia SED on that grid (nearest-neighbour, extrapolated at the edges)
    finterp = interp1d(the_sed_g.wavelen, the_sed_g.flambda, kind = 'nearest',fill_value="extrapolate")
    the_sed_g_extrapolated = Sed(wavelen=WL_f, flambda= finterp(WL_f), name=the_sed_g.name)

    # normalise the resampled SED to the z magnitude computed for this star
    flux_norm = the_sed_g_extrapolated.calc_flux_norm(zmag, pc.bandpass_total_std['z'])
    the_sed_g_extrapolated.multiply_flux_norm(flux_norm)
    all_sed_gaia_renorm.append(the_sed_g_extrapolated)

In [None]:
df_magcalspec = pd.concat(all_mags_std_calspec)
df_magcalspec

In [None]:
# 1 nm wavelength grid spanning the z bandpass, padded by 1 nm on each side.
# It is the same for every star.
WLMIN_f = pc.bandpass_total_std["z"].wavelen.min()
WLMAX_f = pc.bandpass_total_std["z"].wavelen.max()
WL_f = np.arange(WLMIN_f-1.,WLMAX_f+1.,1.)

all_sed_calspec_renorm = []
for the_sed_c, zmag in zip(all_sed_calspec[:NSPEC], df_magcalspec["z"].values[:NSPEC]):

    # resample the Calspec SED on that grid (nearest-neighbour, extrapolated at the edges)
    finterp = interp1d(the_sed_c.wavelen, the_sed_c.flambda, kind = 'nearest',fill_value="extrapolate")
    the_sed_c_extrapolated = Sed(wavelen=WL_f, flambda= finterp(WL_f), name=the_sed_c.name)

    # normalise the resampled SED to the z magnitude computed for this star
    flux_norm = the_sed_c_extrapolated.calc_flux_norm(zmag, pc.bandpass_total_std['z'])
    the_sed_c_extrapolated.multiply_flux_norm(flux_norm)
    all_sed_calspec_renorm.append(the_sed_c_extrapolated)


In [None]:
# Gaia minus Calspec magnitudes, star by star.
# DataFrame.sub / DataFrame.subtract align on row labels, and the two tables
# are labelled "gaiaspec_<tag>" and "calspec_<tag>": no label is shared, so
# the label-aligned subtraction only produces NaN. Both tables list the stars
# in the same order with the same filter columns, so subtract the raw values
# position-wise; the result keeps the Gaia row labels.
df_maggaia - df_magcalspec.values

## Convert Gaia-Calspec sed into rubin-sim SED

## Access to the Pickles model

In [None]:
# Find the throughputs directory 
#fdir = os.getenv('RUBIN_SIM_DATA_DIR')
fdir = get_data_dir()
# fall back to $HOME/rubin_sim_data when get_data_dir() returns None
if fdir is None:  #environment variable not set
    fdir = os.path.join(os.getenv('HOME'), 'rubin_sim_data')

In [None]:
# Directory of the Pickles SED library and its catalogue file
seddir = os.path.join(fdir, 'pysynphot', 'pickles')
seddir_uvk = os.path.join(seddir,"dat_uvk")
file_ref = os.path.join(seddir_uvk, "pickles_uk.fits")

# Catalogue of the library: one row per SED (columns FILENAME and SPTYPE are
# used below). The context manager releases the FITS file handle once the
# table has been loaded into the dataframe.
with fits.open(file_ref) as hdul:
    df_pickle = pd.DataFrame(hdul[1].data)
NSED = len(df_pickle)

In [None]:
# sed colors: one RGBA colour per Pickles SED, obtained by mapping the SED
# index (0..NSED-1) onto the 'jet' colormap normalised to [0, NSED]
jet = plt.get_cmap('jet')
cNorm = mpl.colors.Normalize(vmin=0, vmax=NSED)
scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=jet)
all_colors = scalarMap.to_rgba(np.arange(NSED), alpha=1)

In [None]:
# Stand-alone horizontal colourbar translating the SED index into a spectral type
fig, ax = plt.subplots(figsize=(18, 0.6), layout='constrained')
# `cmap` and `norm` are reused by the Pickles figure with a vertical colourbar further down
cmap = mpl.cm.jet
norm = mpl.colors.Normalize(vmin=0, vmax=NSED)
cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),cax=ax, orientation='horizontal', label='spectral type')

# change the number of labels: ask for 4 times more tick bins, then read the
# tick label texts again
labels = [item.get_text() for item in cbar.ax.get_xticklabels()]
cbar.ax.locator_params(axis='x', nbins=4*len(labels)) 
labels = [item.get_text() for item in cbar.ax.get_xticklabels()]

# rename the labels: each tick text is parsed as an integer SED index and
# replaced by the SPTYPE of that row of df_pickle
# NOTE(review): the last label is left untouched (range stops at Nlabels-1),
# and int() requires every tick text to be a plain integer -- confirm both
# are intended
Nlabels = len(labels)
for ilab in range(0,Nlabels-1):
    the_label = int(labels[ilab])
    labels[ilab] = df_pickle.loc[the_label,"SPTYPE"]
cbar.ax.set_xticklabels(labels,rotation=45);
plt.show()


In [None]:
if FLAG_PLOTPICKLESSED:
    fig,(ax,ax2) = plt.subplots(1,2,figsize=(16,6))

    # Each Pickles file is read once and drawn on the panel(s) it belongs to
    for index in np.arange(NSED):
        filename = df_pickle.loc[index,"FILENAME"].strip()+".fits"
        fullfilename = os.path.join(seddir_uvk,filename)
        # the context manager releases the FITS handle once the table is loaded
        with fits.open(fullfilename) as hdul:
            dff = pd.DataFrame(hdul[1].data)

        # left panel: every SED except indices 100 to 104
        if index < 100 or index > 104:
            ax.plot(dff.WAVELENGTH,dff.FLUX,color=all_colors[index])

        # right panel: indices 99 to 104 only, the first and last ones highlighted
        if index >= 99 and index <= 104:
            label= f"{index}, {filename}"
            if index == 99:
                ax2.plot(dff.WAVELENGTH,dff.FLUX,color="k",label=label)
            elif index == 104:
                ax2.plot(dff.WAVELENGTH,dff.FLUX,color="purple",label=label)
            else:
                ax2.plot(dff.WAVELENGTH,dff.FLUX,color=all_colors[index],ls=":",label=label)

    # same cosmetics on both panels; raw strings keep the LaTeX backslashes
    # without relying on invalid escape sequences
    for panel in (ax, ax2):
        panel.set_yscale('log')
        panel.set_xlim(3000.,11000)
        panel.set_ylim(1e-11,1e-6)
        panel.grid()
        panel.set_title(r"Pickles $F_\lambda$")
        panel.set_ylabel(r"$F_\lambda$")
        panel.set_xlabel(r"$\lambda \, (\AA$)")
    ax2.legend()
    plt.show()

In [None]:
# All Pickles SEDs except indices 100 to 104, coloured by SED index
fig,ax = plt.subplots(1,1,figsize=(10,6))
for index in np.arange(NSED):
    if index < 100 or index > 104:
        filename = df_pickle.loc[index,"FILENAME"].strip()+".fits"
        fullfilename = os.path.join(seddir_uvk,filename)
        # the context manager releases the FITS handle once the table is loaded
        with fits.open(fullfilename) as hdul:
            dff = pd.DataFrame(hdul[1].data)
        ax.plot(dff.WAVELENGTH,dff.FLUX,color=all_colors[index])
ax.set_yscale('log')
ax.set_xlim(3000.,11000)
ax.set_ylim(1e-11,1e-7)
ax.grid()
# raw strings keep the LaTeX backslashes without relying on invalid escape sequences
ax.set_title(r"Pickles $F_\lambda$")
ax.set_ylabel(r"$F_\lambda$")
ax.set_xlabel(r"$\lambda \, (\AA$)")

# vertical colourbar sharing the norm and colormap of the stand-alone colourbar
cbar =fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),ax=ax, orientation='vertical', label='spectral type')

# ask for twice as many tick bins, then read the tick label texts again
labels = [item.get_text() for item in cbar.ax.get_yticklabels()]
cbar.ax.locator_params(axis='y', nbins=2*len(labels)) 
labels = [item.get_text() for item in cbar.ax.get_yticklabels()]

# replace each tick text (an SED index) by the spectral type of that row;
# as in the horizontal colourbar, the last label is left untouched
Nlabels = len(labels)
for ilab in range(0,Nlabels-1):
    the_label = int(labels[ilab])
    labels[ilab] = df_pickle.loc[the_label,"SPTYPE"]
cbar.ax.set_yticklabels(labels,rotation=0);
plt.tight_layout()
plt.show()

## Select the gaia-calspec

In [None]:
# Select one star (third entry) for the single-star walk-through below
index = 2
row = df_info.iloc[index]
hdname = row["HD_name"]
gaianame = row["GAIA_ED3_Name"]
# `tag` is also read as a global by the plotting functions (legend labels)
tag = f"{hdname}_{gaianame}"
dfg = all_dfg[index]
dfc = all_dfc[index]
# bare expression: displays the tag in the notebook
tag

In [None]:
# SEDs of the selected star, as built from the input spectra
the_sed_gaia = all_sed_gaia[index]
the_sed_calspec = all_sed_calspec[index]

# SEDs of the same star after resampling and z-band renormalisation;
# `the_sed_gaia_renorm` is read as a global by the F_nu plotting functions
the_sed_gaia_renorm = all_sed_gaia_renorm[index]
the_sed_calspec_renorm = all_sed_calspec_renorm[index]

the_sed_name = all_sed_names[index]
the_sed_name

In [None]:
plotspec(tag,dfc,dfg)

In [None]:
df_maggaia.head()

In [None]:
# Magnitudes of the selected star in all filters: the row of df_maggaia
# labelled with the Gaia SED name ("gaiaspec_" + tag)
ser_mags = df_maggaia.loc["gaiaspec_" + tag]
ser_mags

## Convert all SED-pickles in rubin-sim format

In [None]:
def get_rubinsim_sed_pickles(magref,bandref,df_pickle):
    """
    Load every Pickles SED listed in ``df_pickle`` as a rubin-sim Sed
    renormalised to magnitude ``magref`` in band ``bandref``.

    magref    : target magnitude in the reference band
    bandref   : key of the reference band in pc.bandpass_total_std
    df_pickle : dataframe giving access to the Pickles SED files
                (columns SPTYPE and FILENAME are used)

    Returns the list of Sed objects, one per row of ``df_pickle``, each named
    "<row label>_<spectral type>".

    The files are read from the module-level directory ``seddir_uvk`` and the
    bandpass comes from the module-level ``pc``.
    """
    all_sed_pickles = []

    # iterate over the rows of the dataframe actually received, rather than
    # over the module-level NSED
    for index in df_pickle.index:
        spectype = df_pickle.loc[index,"SPTYPE"].strip()
        filename = df_pickle.loc[index,"FILENAME"].strip()+".fits"
        fullfilename = os.path.join(seddir_uvk,filename)

        # the context manager releases the FITS handle; the unit conversion is
        # done inside it so that new arrays are created while the file is open
        with fits.open(fullfilename) as hdul:
            dff = pd.DataFrame(hdul[1].data)
            #wavelen (nm)
            #flambda (ergs/cm^2/s/nm)
            # hence wavelength / 10 and flux * 10 when going from the file to Sed
            wl_nm = dff.WAVELENGTH.values/10.
            flambda_nm = dff.FLUX.values*10.

        sed_label = f"{index}_{spectype}"
        the_sed = Sed(wavelen=wl_nm, flambda=flambda_nm, name=sed_label)
        flux_norm = the_sed.calc_flux_norm(magref, pc.bandpass_total_std[bandref])
        the_sed.multiply_flux_norm(flux_norm)
        all_sed_pickles.append(the_sed)
    return all_sed_pickles

In [None]:
# Renormalise the whole Pickles library to the z magnitude of the selected star
zmag = ser_mags["z"]
all_sed_pickles = get_rubinsim_sed_pickles(zmag,'z',df_pickle)
NPICKLES = len(all_sed_pickles)

### Check the normalisation over 2 pickles

In [None]:
def PlotFlambdaFnuGaiaPickle(the_sed_gaia,all_sed_pickles):
    """
    Check the normalisation: draw the first and the last Pickles SED together
    with the Gaia SED, in F_lambda (left panel) and F_nu (right panel), with
    the filter throughputs shaded in the background.

    the_sed_gaia    : Gaia SED drawn on the F_lambda panel
    all_sed_pickles : list of Pickles SEDs; only the first and last are drawn

    NOTE: besides its arguments this function reads module-level names:
    ``tag`` (legend label), ``the_sed_gaia_renorm`` (its ``fnu`` is what the
    F_nu panel shows for Gaia), ``pc``, ``filter_tagnames`` and
    ``filter_color``. Keep them in step with the arguments.
    """

    the_sed1 = all_sed_pickles[0]
    the_sed2 = all_sed_pickles[-1]

    fig,(ax1,ax2) = plt.subplots(1,2,figsize=(16,6))

    # left panel: F_lambda
    ax1.plot(the_sed1.wavelen,the_sed1.flambda,"b-",label=the_sed1.name)
    ax1.plot(the_sed2.wavelen,the_sed2.flambda,"r-",label=the_sed2.name)
    ax1.plot(the_sed_gaia.wavelen,the_sed_gaia.flambda,"k-",lw=3,label = "gaia_"+tag)
    ax1.set_yscale("log")
    ax1.legend()
    ax1.set_ylim(1e-13,1e-9)
    ax1.set_xlim(300.,2000.)
    # raw strings keep the LaTeX backslashes without invalid escape sequences
    ax1.set_title(r"Pickles $F_\lambda$")
    ax1.set_ylabel(r"$F_\lambda$")
    ax1.set_xlabel(r"$\lambda \, (nm)$")

    # right panel: F_nu
    ax2.plot(the_sed1.wavelen,the_sed1.fnu,"b-",label=the_sed1.name)
    ax2.plot(the_sed2.wavelen,the_sed2.fnu,"r-",label=the_sed2.name)
    ax2.plot(the_sed_gaia_renorm.wavelen,the_sed_gaia_renorm.fnu,"k-",lw=3,label = "gaia_"+tag)
    ax2.set_yscale("log")
    ax2.legend()
    ax2.set_ylim(1e-3,100.)
    ax2.set_xlim(300.,2000.)
    ax2.set_title("Pickles $F_\\nu$")
    ax2.set_ylabel("$F_\\nu$")
    ax2.set_xlabel(r"$\lambda \, (nm)$")

    # filter throughputs as a faint background on a twin axis of each panel
    for panel in (ax1, ax2):
        background = panel.twinx()
        for ifilt,f in enumerate(filter_tagnames):
            band = pc.bandpass_total_std[f]
            background.fill_between(band.wavelen,band.sb,color=filter_color[ifilt],alpha=0.1)
        background.set_yticks([])

    plt.show()

In [None]:
PlotFlambdaFnuGaiaPickle(the_sed_gaia,all_sed_pickles)

## Compute pickles magnitudes

In [None]:
def get_rubinsim_mag_pickles(all_sed_pickles):
    """
    Compute the magnitudes of every Pickles SED of ``all_sed_pickles`` in all
    the filters listed in the module-level ``filter_tagnames``, using the
    bandpasses pc.bandpass_total_std.

    Returns a pandas dataframe with one row per SED (labelled with the SED
    name) and one column per filter.
    """
    per_sed_tables = []
    for sed in all_sed_pickles:
        band_mags = {f: sed.calc_mag(pc.bandpass_total_std[f]) for f in filter_tagnames}
        # single-row table labelled with the SED name
        per_sed_tables.append(pd.DataFrame(band_mags, index=[sed.name]))

    return pd.concat(per_sed_tables)
   

In [None]:
df_mags_std_pickles = get_rubinsim_mag_pickles(all_sed_pickles)
df_mags_std_pickles

## Find Nearest Neighbour

In [None]:
def FindNearestNeighbors(X_data,X_neigh):
    """
    For each point of ``X_data`` find its single nearest neighbour among the
    points of ``X_neigh`` with sklearn NearestNeighbors (ball-tree algorithm).

    Returns (indices, distances), in that order.
    """
    finder = NearestNeighbors(n_neighbors=1, algorithm='ball_tree')
    finder.fit(X_neigh)
    dist, idx = finder.kneighbors(X_data)
    return idx, dist

In [None]:
def FindNearestKDT(X_data,X_neigh):
    """
    For each point of ``X_data`` find its single nearest neighbour among the
    points of ``X_neigh`` with a sklearn KDTree (euclidean metric, leaf size 30).

    Returns (indices, distances), in that order.
    """
    tree = KDTree(X_neigh,leaf_size=30, metric='euclidean')
    dist, idx = tree.query(X_data, k=1, return_distance=True)
    return idx, dist

In [None]:
def FindNearestBallT(X_data,X_neigh):
    """
    For each point of ``X_data`` find its single nearest neighbour among the
    points of ``X_neigh`` with a sklearn BallTree (euclidean metric).

    Returns (indices, distances), in that order.
    """
    tree = BallTree(X_neigh, metric='euclidean')
    dist, idx = tree.query(X_data,k=1,return_distance=True)
    return idx, dist

In [None]:
# select in G,R,I,Z
# exclude border U,Y
# The selection is positional: the first and last magnitude columns are
# dropped -- assumes the filter order is u,g,r,i,z,y (TODO confirm against
# filter_tagnames)
X_pkl = df_mags_std_pickles.values[:,1:-1]
# single query point, wrapped in a list to form a 2-D input
X_data = [ser_mags.values[1:-1]]

In [None]:
# Run the three nearest-neighbour implementations on the same inputs so that
# their results can be compared
indices1,distances1 = FindNearestNeighbors(X_data,X_pkl)
indices2,distances2 = FindNearestKDT(X_data,X_pkl)
indices3,distances3 = FindNearestBallT(X_data,X_pkl)

In [None]:
np.hstack([indices1,indices2,indices3])

In [None]:
indice = indices1[0][0]

In [None]:
the_sed_pickle = all_sed_pickles[indice]

In [None]:
def find_fminfmax(wl,flux,wlmin=300.,wlmax=1200.):
    """
    Return the (minimum, maximum) of ``flux`` restricted to the samples whose
    wavelength ``wl`` lies in the closed interval [wlmin, wlmax].

    wl, flux     : arrays of the same length
    wlmin, wlmax : bounds of the wavelength window, in the units of ``wl``
    """
    in_window = (wl >= wlmin) & (wl <= wlmax)
    windowed_flux = flux[in_window]
    return windowed_flux.min(), windowed_flux.max()

In [None]:
def PlotFlambdaGaiaCalspecPickle(the_sed_gaia,the_sed_calspec,the_sed_pickle):
    """
    Overlay the F_lambda of a Gaia SED, a Calspec SED and a Pickles SED, with
    the filter throughputs shaded in the background.

    the_sed_gaia    : Gaia SED
    the_sed_calspec : Calspec SED
    the_sed_pickle  : Pickles SED; the y range is derived from it (half its
                      minimum to twice its maximum, as returned by find_fminfmax)

    NOTE: besides its arguments this function reads the module-level names
    ``tag`` (legend labels), ``pc``, ``filter_tagnames`` and ``filter_color``.
    """
    fig,ax = plt.subplots(1,1,figsize=(8,6))

    ax.plot(the_sed_gaia.wavelen,the_sed_gaia.flambda,"b-",lw=2,label = "gaia_"+tag)
    ax.plot(the_sed_calspec.wavelen,the_sed_calspec.flambda,"g:",lw=2,label = "calspec_"+tag)
    ax.plot(the_sed_pickle.wavelen,the_sed_pickle.flambda,"r-",lw=1,label=the_sed_pickle.name)
    ax.set_yscale("log")
    ax.legend()

    # a factor 2 of margin below and above the Pickles flux range
    fmin,fmax = find_fminfmax(the_sed_pickle.wavelen,the_sed_pickle.flambda)
    fmin /=2
    fmax *=2

    ax.set_ylim(fmin,fmax)
    ax.set_xlim(300.,1100.)
    # raw strings keep the LaTeX backslashes without invalid escape sequences
    ax.set_title(r"Pickles $F_\lambda$")
    ax.set_ylabel(r"$F_\lambda$")
    ax.set_xlabel(r"$\lambda \, (nm)$")

    # filter throughputs as a faint background on a twin axis
    ax2 = ax.twinx()
    for ifilt,f in enumerate(filter_tagnames):
        band = pc.bandpass_total_std[f]
        ax2.fill_between(band.wavelen,band.sb,color=filter_color[ifilt],alpha=0.1)
    ax2.set_yticks([])

    plt.show()
    

In [None]:
PlotFlambdaGaiaCalspecPickle(the_sed_gaia,the_sed_calspec,the_sed_pickle)

In [None]:
def PlotFlambdaFnuGaiaCalspecPickle(the_sed_gaia,the_sed_calspec,the_sed_pickle):
    """
    Compare a Gaia SED with a Pickles SED in F_lambda (left panel) and F_nu
    (right panel), with the filter throughputs shaded in the background.

    the_sed_gaia    : Gaia SED drawn on the F_lambda panel
    the_sed_calspec : Calspec SED; accepted but not drawn by this function
    the_sed_pickle  : Pickles SED drawn on both panels; the y ranges are
                      derived from it (half its minimum to twice its maximum)

    NOTE: besides its arguments this function reads module-level names:
    ``tag`` (legend label), ``the_sed_gaia_renorm`` (its ``fnu`` is what the
    F_nu panel shows for Gaia), ``pc``, ``filter_tagnames`` and
    ``filter_color``. Keep them in step with the arguments.
    """

    fig,(ax1,ax2) = plt.subplots(1,2,figsize=(18,6))

    # left panel: F_lambda
    ax1.plot(the_sed_pickle.wavelen,the_sed_pickle.flambda,"r-",label=the_sed_pickle.name)
    ax1.plot(the_sed_gaia.wavelen,the_sed_gaia.flambda,"b-",lw=2,label = "gaia_"+tag)
    ax1.set_yscale("log")
    ax1.legend()

    fmin,fmax = find_fminfmax(the_sed_pickle.wavelen,the_sed_pickle.flambda)
    fmin /=2
    fmax *=2

    ax1.set_ylim(fmin,fmax)
    ax1.set_xlim(300.,1200.)
    # raw strings keep the LaTeX backslashes without invalid escape sequences
    ax1.set_title(r"Pickles $F_\lambda$")
    ax1.set_ylabel(r"$F_\lambda$")
    ax1.set_xlabel(r"$\lambda \, (nm)$")

    # right panel: F_nu
    # Call kept from the original; its return value is not used.
    # NOTE(review): confirm whether this call is still needed here.
    the_sed_pickle.get_sed_fnu()
    ax2.plot(the_sed_pickle.wavelen,the_sed_pickle.fnu,"r-",label=the_sed_pickle.name)
    ax2.plot(the_sed_gaia_renorm.wavelen,the_sed_gaia_renorm.fnu,"b-",lw=2,label = "gaia_"+tag)

    fmin,fmax = find_fminfmax(the_sed_pickle.wavelen,the_sed_pickle.fnu)
    fmin /=2
    fmax *=2

    ax2.set_yscale("log")
    ax2.legend()
    ax2.set_ylim(fmin,fmax)
    ax2.set_xlim(300.,1200.)
    ax2.set_title("Pickles $F_\\nu$")
    ax2.set_ylabel("$F_\\nu$")
    ax2.set_xlabel(r"$\lambda \, (nm)$")

    # Filter throughputs as a faint background on a twin axis of EACH panel of
    # this figure. The original twinned the module-level `ax`, which belongs
    # to another figure, instead of `ax1`.
    for panel in (ax1, ax2):
        background = panel.twinx()
        for ifilt,f in enumerate(filter_tagnames):
            band = pc.bandpass_total_std[f]
            background.fill_between(band.wavelen,band.sb,color=filter_color[ifilt],alpha=0.1)
        background.set_yticks([])

    plt.show()

In [None]:
PlotFlambdaFnuGaiaCalspecPickle(the_sed_gaia,the_sed_calspec,the_sed_pickle)

## Loop on Gaia 

In [None]:
# Main loop: for every Gaia star, renormalise the Pickles library to the z
# magnitude of the star, find the Pickles SED whose magnitudes are closest to
# those of the star, and plot Gaia, Calspec and that Pickles SED together.
# The module-level names rebound here (notably `tag`) are read as globals by
# the plotting function called at the end of each iteration, so the order of
# the statements matters.
for index in np.arange(NSPEC):

    # get info on the current Gaia 
    row = df_info.iloc[index]
    hdname = row["HD_name"]
    gaianame = row["GAIA_ED3_Name"]
    tag = f"{hdname}_{gaianame}"
    dfg = all_dfg[index]
    dfc = all_dfc[index]
    the_sed_gaia = all_sed_gaia[index]
    the_sed_calspec = all_sed_calspec[index]
    the_sed_gaia_renorm = all_sed_gaia_renorm[index]
    the_sed_calspec_renorm = all_sed_calspec_renorm[index]
    the_sed_name = all_sed_names[index]

    # retrieve the magnitudes of the gaia
    ser_mags = df_maggaia.loc["gaiaspec_" + tag]
    zmag = ser_mags["z"]

    # get the sed pickles renormalized
    # (the whole library is read again from the FITS files at every iteration)
    all_sed_pickles = get_rubinsim_sed_pickles(zmag,'z',df_pickle)
    NPICKLES = len(all_sed_pickles)


    # compare gaia wrt extreme blue red pickles
    #PlotFlambdaFnuGaiaPickle(the_sed_gaia,the_sed_pickle)

    # compute the magnitudes
    df_mags_std_pickles = get_rubinsim_mag_pickles(all_sed_pickles)

    # Nearest neighbour
    # first and last magnitude columns are dropped on both sides (positional
    # selection) -- assumes the filter order is u,g,r,i,z,y (TODO confirm)
    X_pkl = df_mags_std_pickles.values[:,1:-1]
    X_data = [ser_mags.values[1:-1]]

    indices,distances = FindNearestNeighbors(X_data,X_pkl)
    # single query point, single neighbour: take the only index returned
    indice = indices[0][0]
    the_sed_pickle = all_sed_pickles[indice]

    # Plot Flambda
    PlotFlambdaGaiaCalspecPickle(the_sed_gaia,the_sed_calspec,the_sed_pickle)
    
    