In [None]:
import os
import sys
sys.path.insert(0, '..')

import matplotlib
import matplotlib.pyplot as plt
# matplotlib.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
# matplotlib.rc('text', usetex=True)
matplotlib.rcParams.update({'font.size': 16})

import numpy as np

from astropy.io import fits
from astropy.cosmology import Planck18 as cosmo
import astropy.units as u

from load_paus_cat import load_paus_cat
from paus_utils import w_central, z_NB

import pandas as pd

In [None]:
# Visual-inspection catalogue of the LAE selection (table in FITS extension 1).
# fits.getdata opens, reads and closes the file; `fits.open(path)[1].data`
# discards the HDUList and leaves the file handle open until garbage collection.
vi_cat_path = '/home/alberto/almacen/PAUS_data/catalogs/LAE_selection_VI_hiZ.fits'
vi_cat = fits.getdata(vi_cat_path, ext=1)
# Alternative inputs used previously:
#   '/home/alberto/almacen/PAUS_data/catalogs/PAUS_LAE_selection_visual_insp_AT.fits'
#   pd.read_csv('/home/alberto/almacen/PAUS_data/catalogs/LAE_selection.csv')

In [None]:
# Derive a short filter label from each PAUS transmission-curve file name.
paus_tcurves_dir = '/home/alberto/almacen/PAUS_data/OUT_FILTERS'
tcurves_file_list = sorted(os.listdir(paus_tcurves_dir))


def _tcurve_label(fname):
    """Map a transmission-curve file name to its filter label.

    Files with 'D' at position 4 are narrow bands, labelled 'NB' + characters
    6..8 of the name; any other file is labelled by its 5th-from-last character.
    """
    if fname[4] != 'D':
        return fname[-5]
    return f'NB{fname[6:9]}'


# Drop the first 6 labels of the sorted list (presumably the broad-band curves,
# which sort before the narrow bands — TODO confirm against the directory).
paus_fil_names = [_tcurve_label(fname) for fname in tcurves_file_list][6:]

In [None]:
# Build the per-source, normalized narrow-band photometry that stonp will stack.
#
# For each field, the visually confirmed high-z LAEs are located in the PAUS
# catalogue, their fluxes/errors are divided by a per-source normalization
# measured in a rest-frame UV window, and everything is concatenated across
# fields into `to_stack_data` (NB flux + '_error' columns and 'z'),
# `to_stack_lya_NB` and `to_stack_L_lya`.

N_NB = 40  # number of narrow bands exported (rows 0..39 of cat['flx'] / cat['err'])

# Rest-frame normalization window, in the same units as w_central (presumably Å),
# i.e. 130-137 nm if so.
norm_wl_min, norm_wl_max = 1300, 1370


def _match_ref_ids(wanted_ids, catalog_ids):
    """Return the index of the first occurrence in `catalog_ids` of each ID in `wanted_ids`.

    Vectorized (sort + binary search) instead of one full `np.where` scan of
    the catalogue per source.

    Raises
    ------
    ValueError
        If any requested ID is missing from `catalog_ids`.
    """
    wanted_ids = np.asarray(wanted_ids)
    # return_index gives the position of the FIRST occurrence of each unique ID,
    # matching the previous `np.where(...)[0][0]` behaviour for duplicates.
    unique_ids, first_idx = np.unique(np.asarray(catalog_ids), return_index=True)
    pos = np.searchsorted(unique_ids, wanted_ids)
    found = pos < len(unique_ids)
    found[found] = unique_ids[pos[found]] == wanted_ids[found]
    if not found.all():
        raise ValueError(f'ref_id(s) not found in the PAUS catalogue: {wanted_ids[~found]}')
    return first_idx[pos].astype(int)


def _uv_normalization(flx, err, z, wl_min, wl_max, ref_wl=1450):
    """Per-source normalization used to put all LAEs on a common scale.

    Inverse-variance weighted mean of `flx` over the filters whose central
    wavelength lies strictly inside [wl_min, wl_max] * (1 + z), multiplied by
    ref_wl * (1 + z) * D_L(z)**2 (D_L in cm). The last 6 filters of
    `w_central` are never used (presumably the broad bands — TODO confirm).

    `flx` and `err` are one source's column of cat['flx'] / cat['err'], so they
    are aligned with `w_central`. `ref_wl` = 1450 is kept from the original
    code; presumably a reference rest wavelength — TODO confirm (note it lies
    outside the averaging window).

    Raises
    ------
    ValueError
        If no usable filter falls inside the redshifted window.
    """
    z = float(z)
    in_band = (w_central > wl_min * (1 + z)) & (w_central < wl_max * (1 + z))
    in_band[-6:] = False
    if not in_band.any():
        raise ValueError(
            f'No normalization filters for a source at z={z:.3f}: the window '
            f'[{wl_min}, {wl_max}]*(1+z) falls outside the usable filters'
        )
    # Work in float64: dl**2 alone is ~1e57-1e58 cm^2 at these redshifts,
    # beyond float32's ~3.4e38 range, so single-precision inputs could overflow.
    flx_band = np.asarray(flx, dtype=np.float64)[in_band]
    err_band = np.asarray(err, dtype=np.float64)[in_band]
    dl = cosmo.luminosity_distance(z).to(u.cm).value
    # The constant factor cancels in inverse-variance weights, so it is applied
    # once to the weighted mean rather than to every flux and error.
    scale = ref_wl * (1 + z) * dl**2
    return scale * np.average(flx_band, weights=err_band**-2)


# Initialize output data dict
to_stack_data = {}
for nbi in range(N_NB):
    to_stack_data[paus_fil_names[nbi]] = []
    to_stack_data[f'{paus_fil_names[nbi]}_error'] = []
to_stack_data['z'] = []
to_stack_lya_NB = []
to_stack_L_lya = []

for field_name in ['W1', 'W2', 'W3']:
    path_to_cat = [f'/home/alberto/almacen/PAUS_data/catalogs/PAUS_3arcsec_{field_name}_extinction_corrected.pq']
    cat = load_paus_cat(path_to_cat)

    # LAEs visually confirmed as high-z in this field
    mask = vi_cat['is_hiZ_LAE'] & (vi_cat['field'] == field_name)
    LAE_vi_IDs = np.array(vi_cat['ref_id'][mask])

    # Lyα narrow band: pipeline value, replaced by the visual-inspection one
    # wherever that is > 0 (the assignment keeps lya_NB's original dtype).
    lya_NB = np.array(vi_cat['lya_NB'][mask])
    lya_NB_VI = np.asarray(vi_cat['lya_NB_VI'][mask])
    has_vi_NB = lya_NB_VI > 0
    lya_NB[has_vi_NB] = lya_NB_VI[has_vi_NB]

    redshifts = z_NB(lya_NB)
    L_lya = np.array(vi_cat['L_lya_corr'][mask])
    print(len(LAE_vi_IDs))

    where_LAEs_in_cat = _match_ref_ids(LAE_vi_IDs, cat['ref_id'])

    # One normalization per source, always float64 (independent of L_lya's dtype).
    norm = np.array(
        [
            _uv_normalization(cat['flx'][:, cat_idx], cat['err'][:, cat_idx],
                              z_src, norm_wl_min, norm_wl_max)
            for cat_idx, z_src in zip(where_LAEs_in_cat, redshifts)
        ],
        dtype=np.float64,
    )

    # Data to save to a .csv to be read by stonp: the first N_NB narrow bands
    for nbi in range(N_NB):
        to_stack_data[paus_fil_names[nbi]].append(cat['flx'][nbi, where_LAEs_in_cat] / norm)
        to_stack_data[f'{paus_fil_names[nbi]}_error'].append(cat['err'][nbi, where_LAEs_in_cat] / norm)
    to_stack_data['z'].append(redshifts)
    to_stack_lya_NB.append(lya_NB)
    to_stack_L_lya.append(L_lya)

# Merge the per-field chunks into flat arrays
for nbi in range(N_NB):
    to_stack_data[paus_fil_names[nbi]] = np.concatenate(to_stack_data[paus_fil_names[nbi]])
    to_stack_data[f'{paus_fil_names[nbi]}_error'] = np.concatenate(to_stack_data[f'{paus_fil_names[nbi]}_error'])
to_stack_data['z'] = np.concatenate(to_stack_data['z'])
to_stack_lya_NB = np.concatenate(to_stack_lya_NB)
to_stack_L_lya = np.concatenate(to_stack_L_lya)

In [None]:
# Export the normalized photometry to CSV, one file per [nb1, nb2] bin in the
# Lyα narrow band, keeping only sources with L_lya_corr > 40.
# Finer binnings (pairs of NBs, or triplets between NB3 and NB18) were used
# before; a single bin spanning every NB is used now.
nb_list = [[0, 50]]

stack_table = pd.DataFrame(to_stack_data)
for nb1, nb2 in nb_list:
    in_nb_bin = (to_stack_lya_NB >= nb1) & (to_stack_lya_NB <= nb2)
    this_mask = in_nb_bin & (to_stack_L_lya > 40)
    print(np.count_nonzero(this_mask))
    stack_table.loc[this_mask].to_csv(f'fluxes_to_stack_NB{nb1}_{nb2}.csv')

In [None]:
import stonp

# Percentiles of the stacked distribution, computed in this order for each bin.
_stack_percentiles = (50, 16, 84)

wl_list = []
stacked_seds_50 = []
stacked_seds_16 = []
stacked_seds_84 = []

for nb_lo, nb_hi in nb_list:
    stacker = stonp.Stacker()
    stacker.load_catalog(f'fluxes_to_stack_NB{nb_lo}_{nb_hi}.csv', z_label='z')
    stacker.to_rest_frame(flux_conversion='luminosity')

    # stack() (re)sets stacker.stacked_seds, so it is captured right after each
    # call. NOTE(review): assumes stonp assigns a new object each time; if it
    # mutated in place, the three entries would alias — confirm.
    seds_by_q = {}
    for q in _stack_percentiles:
        stacker.stack(error_type='flux_error', percentile_q=q)
        seds_by_q[q] = stacker.stacked_seds

    wl_list.append(stacker.wl_grid)
    stacked_seds_50.append(seds_by_q[50])
    stacked_seds_16.append(seds_by_q[16])
    stacked_seds_84.append(seds_by_q[84])

In [None]:
## Compare the new stack with a previously saved reference stack.
# The reference files were written once from an earlier run with:
#   np.save('ref_lum.npy', stacked_seds_50[0][0])
#   np.save('ref_wav.npy', wl_list[0])
ref_lum = np.load('ref_lum.npy')
ref_wav = np.load('ref_wav.npy')

fig, ax = plt.subplots(figsize=(12, 6))

ax.plot(ref_wav, ref_lum, c='C0', label='Reference')

# Multiplicative shift applied to the stacked wavelength grid.
# TODO: document why 1.01 (presumably a small redshift correction).
z_corr_factor = 1.01

for wl, sed_50, sed_16, sed_84 in zip(wl_list, stacked_seds_50, stacked_seds_16, stacked_seds_84):
    wavelength = wl * z_corr_factor
    # Median stack plus the 16th-84th percentile band; [0] takes the first
    # (presumably only) stacked SED.
    ax.plot(wavelength, sed_50[0], label='$z=4.4$', c='C1')
    ax.fill_between(wavelength, sed_16[0], sed_84[0], alpha=0.3, zorder=-999,
                    color='C1', lw=0)

# Rest-frame reference lines in nm (presumably Lyman limit, Lyα, Si IV, C IV, C III]).
for line_wl in (91.2, 154.9, 121.56, 139.98, 190.8734):
    ax.axvline(line_wl, ls=':', c='k', zorder=-99)

ax.axhline(0, c='k')

# The stacked grid is in nm (see the line positions above). r'[\AA]' was also
# printed literally, because usetex is disabled and it was outside math mode.
ax.set_xlabel('Rest-frame wavelength [nm]')
ax.set_ylabel(r'$L_\lambda$ [A. U.]')
ax.legend()

plt.show()