In [None]:
import os
import sys
# Put the parent directory first on the import path; presumably this is where
# the local `load_paus_cat` module imported below lives.
sys.path.insert(0, '..')

import matplotlib
import matplotlib.pyplot as plt
# Optional LaTeX text rendering (requires a TeX installation); currently disabled,
# so text labels are rendered by matplotlib's built-in mathtext engine.
# matplotlib.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
# matplotlib.rc('text', usetex=True)
matplotlib.rcParams.update({'font.size': 16})

import numpy as np

from astropy.io import fits

# Project-local catalogue loader (found via the sys.path entry added above).
from load_paus_cat import load_paus_cat

import pandas as pd

In [None]:
# Visual-inspection catalogue of PAUS LAE candidates (the table lives in HDU 1).
# fits.getdata opens the file, reads extension 1 and closes the HDUList again,
# instead of leaving an open file handle alive in the kernel.
vi_cat_path = '/home/alberto/almacen/PAUS_data/catalogs/PAUS_LAE_selection_visual_insp_AT.fits'
vi_cat = fits.getdata(vi_cat_path, ext=1)

In [None]:
paus_tcurves_dir = '/home/alberto/almacen/PAUS_data/OUT_FILTERS'
tcurves_file_list = sorted(os.listdir(paus_tcurves_dir))


def _filter_label(fname):
    """Map a transmission-curve filename to a short filter label.

    Files with 'D' at position 4 get 'NB' + characters 6-8 of the name;
    any other file is labelled by the single character at position -5.
    """
    if fname[4] != 'D':
        return fname[-5]
    return f'NB{fname[6:9]}'


# Label every curve, then drop the first six (sorted) entries.
# NOTE(review): downstream cells index the first 40 of the remaining labels as
# narrow bands — TODO confirm the six dropped entries are the broad bands.
paus_fil_names = [_filter_label(fname) for fname in tcurves_file_list][6:]

In [None]:
# Gather, for every visually-confirmed LAE in fields W1-W3, the narrow-band
# fluxes/errors and redshift in the column layout written out for stonp, plus
# the Lyα NB index and luminosity used later to split the sample into bins.
N_NB = 40  # number of narrow bands copied per object


def _match_rows(ids, cat_ids):
    """Return the row index in ``cat_ids`` of each value in ``ids``.

    Vectorised replacement for a per-ID ``np.where(id == cat_ids)[0][0]`` scan:
    a stable argsort + searchsorted returns the *first* matching row, exactly
    like the scan did, in O((N + M) log M) instead of O(N * M).

    Raises
    ------
    ValueError
        If any of ``ids`` is absent from ``cat_ids`` (the per-ID scan failed
        with an uninformative IndexError in that case).
    """
    ids = np.asarray(ids)
    cat_ids = np.asarray(cat_ids)
    if ids.size == 0:
        return np.empty(0, dtype=int)
    if cat_ids.size == 0:
        raise ValueError('Catalogue is empty; cannot match visually-selected IDs.')
    # Stable sort keeps equal IDs in original order, so searchsorted(side='left')
    # lands on the lowest original index among duplicates.
    order = np.argsort(cat_ids, kind='stable')
    pos = np.searchsorted(cat_ids, ids, sorter=order)
    rows = order[np.minimum(pos, cat_ids.size - 1)]  # clip IDs past the maximum
    missing = cat_ids[rows] != ids
    if missing.any():
        raise ValueError(
            f'{np.count_nonzero(missing)} visually-selected IDs not found in '
            f'the catalogue, e.g. {ids[missing][:10]}'
        )
    return rows.astype(int)


# Initialize output data dict: one list per NB flux/error column plus redshift;
# each list collects one array per field and is concatenated at the end.
to_stack_data = {}
for nbi in range(N_NB):
    to_stack_data[paus_fil_names[nbi]] = []
    to_stack_data[f'{paus_fil_names[nbi]}_error'] = []
to_stack_data['z'] = []
to_stack_lya_NB = []
to_stack_L_lya = []

for field_name in ['W1', 'W2', 'W3']:
    path_to_cat = [f'/home/alberto/almacen/PAUS_data/catalogs/PAUS_3arcsec_{field_name}_extinction_corrected.pq']
    cat = load_paus_cat(path_to_cat)

    # Visually-confirmed LAEs in this field (mask computed once, reused below).
    is_field_LAE = vi_cat['is_LAE_VI'] & (vi_cat['field'] == field_name)
    LAE_vi_IDs = vi_cat['ref_id'][is_field_LAE]
    redshifts = vi_cat['z_NB'][is_field_LAE]
    lya_NB = vi_cat['lya_NB'][is_field_LAE]
    L_lya = vi_cat['L_lya_corr'][is_field_LAE]
    print(len(LAE_vi_IDs))

    # Row of each selected LAE in the photometric catalogue.
    where_LAEs_in_cat = _match_rows(LAE_vi_IDs, cat['ref_id'])

    # Data to save to a .csv to be read by stonp: fluxes and errors of the
    # narrow bands (cat['flx'] / cat['err'] are indexed [band, object]).
    for nbi in range(N_NB):
        to_stack_data[paus_fil_names[nbi]].append(cat['flx'][nbi, where_LAEs_in_cat])
        to_stack_data[f'{paus_fil_names[nbi]}_error'].append(cat['err'][nbi, where_LAEs_in_cat])
    to_stack_data['z'].append(redshifts)
    to_stack_lya_NB.append(lya_NB)
    to_stack_L_lya.append(L_lya)

# Flatten the per-field chunks into one array per column (one entry per LAE).
for key, chunks in to_stack_data.items():
    to_stack_data[key] = np.concatenate(chunks)
to_stack_lya_NB = np.concatenate(to_stack_lya_NB)
to_stack_L_lya = np.concatenate(to_stack_L_lya)

In [None]:
# Save one CSV per Lyα narrow-band bin. Each bin [nb1, nb2] is inclusive on
# both ends and keeps only LAEs above the luminosity cut.
nb_list = [[0, 3], [4, 7], [8, 11], [12, 15], [16, 19]]
L_LYA_MIN = 44  # threshold on the 'L_lya_corr' column; presumably log10 L in erg/s — TODO confirm

# Build the DataFrame once and slice it per bin (it was rebuilt every iteration).
to_stack_df = pd.DataFrame(to_stack_data)

for nb1, nb2 in nb_list:
    this_mask = (to_stack_lya_NB >= nb1) & (to_stack_lya_NB <= nb2) & (to_stack_L_lya > L_LYA_MIN)
    print(np.count_nonzero(this_mask))  # number of LAEs in this bin
    to_stack_df[this_mask].to_csv(f'fluxes_to_stack_NB{nb1}_{nb2}.csv')

In [None]:
import stonp

# Rest-frame stack every NB bin with stonp, save it to disk, and keep the
# wavelength grid plus stacked SEDs of each bin for the plots below.
wl_list = []
stacked_seds_list = []

for nb_lo, nb_hi in nb_list:
    st = stonp.Stacker()
    st.load_catalog(f'fluxes_to_stack_NB{nb_lo}_{nb_hi}.csv', z_label='z')

    # Shift to the rest frame, then stack.
    st.to_rest_frame()
    st.stack()
    st.save_stack(f'stacked_seds_NB{nb_lo}_{nb_hi}', overwrite=True)

    wl_list.append(st.wl_grid)
    stacked_seds_list.append(st.stacked_seds)

In [None]:
# Overlay the rest-frame stack of every NB bin, each divided by its mean flux
# in the 130 < wl < 137 window so the curves share a common scale.
fig, ax = plt.subplots(figsize=(12, 6))

for wl, seds, (nb1, nb2) in zip(wl_list, stacked_seds_list, nb_list):
    this_norm = np.mean(seds[0][(wl > 130) & (wl < 137)])
    ax.plot(wl, seds[0] / this_norm, label=f'NB{nb1}-{nb2}')

# The grid is in nm, not Å: the window above and the range of interest
# (110-140) bracket rest-frame Lyα at 121.6 nm. In Å they would lie far below
# the Lyman limit. The old label '[\AA]' was also an invalid escape sequence
# that, without usetex, was rendered literally.
ax.set_xlabel('Rest-frame wavelength [nm]')
ax.set_ylabel('Flux [A. U.]')
ax.set_title('Stacked LAE SEDs by Lyα narrow band')
# ax.set_xlim(110, 140)
ax.legend()

plt.show()

In [None]:
# Residuals of every normalised stack with respect to the first NB bin's stack.


def _normalized(wl, seds):
    """Return the first stacked SED divided by its mean flux in 130 < wl < 137."""
    window = (wl > 130) & (wl < 137)
    return seds[0] / np.mean(seds[0][window])


wl0 = wl_list[0]
seds0 = _normalized(wl0, stacked_seds_list[0])
ref_label = f'NB{nb_list[0][0]}-{nb_list[0][1]}'

fig, ax = plt.subplots(figsize=(12, 6))

for wl, seds, (nb1, nb2) in zip(wl_list, stacked_seds_list, nb_list):
    # The reference stack is interpolated onto each bin's grid; np.interp
    # requires wl0 to be increasing — TODO confirm stonp's grid ordering.
    residual = _normalized(wl, seds) - np.interp(wl, wl0, seds0)
    ax.plot(wl, residual, label=f'NB{nb1}-{nb2}')

ax.axhline(0, c='k', zorder=-999)

ax.set_xlabel('Rest-frame wavelength [nm]')
ax.set_ylabel(f'Normalised flux - {ref_label} stack [A. U.]')
ax.set_title(f'Stack residuals relative to {ref_label}')
ax.legend()

plt.show()