In [None]:
from astropy.io import fits
import numpy as np
from LAE_selection_method import Lya_L_estimation, estimate_continuum
from load_paus_cat import load_paus_cat
from paus_utils import z_NB
import matplotlib.pyplot as plt

In [None]:
# Visual-inspection (VI) catalog of high-z LAE candidates; the table lives in HDU 1.
# fits.getdata returns the HDU data and closes the file itself, whereas
# fits.open(...)[1].data left the HDUList (and its file handle) open.
hiz_cat = fits.getdata('/home/alberto/almacen/PAUS_data/catalogs/LAE_selection_VI_hiZ.fits', ext=1)

In [None]:
# Accumulators for all visually-confirmed hiZ LAEs across the three fields.
L_lya_list = []
z_list = []
lya_NB_list = []

# Lya NB index for every source in hiz_cat: the pipeline value, overridden by the
# visually-inspected value wherever one was assigned (lya_NB_VI > 0).
# This does not depend on the field, so it is computed once outside the loop.
lya_NB = np.array(hiz_cat['lya_NB'])
has_VI_NB = hiz_cat['lya_NB_VI'] > 0
lya_NB[has_VI_NB] = hiz_cat['lya_NB_VI'][has_VI_NB]

for field_name in ['W1', 'W2', 'W3']:
    path_to_cat = [f'/home/alberto/almacen/PAUS_data/catalogs/PAUS_3arcsec_{field_name}_extinction_corrected.pq']
    paus_cat = load_paus_cat(path_to_cat)
    # Number of sources; assumes 'flx' is indexed (filter, source) -- TODO confirm.
    n_src = paus_cat['flx'].shape[1]

    # Visually-confirmed hiZ LAEs that belong to this field.
    mask = hiz_cat['is_hiz_LAE'] & (hiz_cat['field'] == field_name)

    # Row of each selected LAE in paus_cat (first match on ref_id).
    paus_ids = np.asarray(paus_cat['ref_id'])
    LAE_vi_IDs = np.array(hiz_cat['ref_id'][mask])
    where_LAEs_in_cat = np.empty(len(LAE_vi_IDs), dtype=int)
    for i, thisid in enumerate(LAE_vi_IDs):
        matches = np.flatnonzero(paus_ids == thisid)
        if matches.size == 0:
            raise ValueError(f'ref_id {thisid} from field {field_name} not found in {path_to_cat[0]}')
        where_LAEs_in_cat[i] = matches[0]

    # Per-source columns (presumably read by Lya_L_estimation); only the selected
    # LAEs receive real values, everything else keeps a placeholder.
    paus_cat['lya_NB'] = np.full(n_src, -1, dtype=int)
    paus_cat['nice_lya'] = np.zeros(n_src, dtype=bool)
    paus_cat['nice_lya_0'] = np.zeros(n_src, dtype=bool)
    # Deterministic placeholder for non-LAEs (np.empty left this memory uninitialised).
    paus_cat['z_NB'] = np.zeros(n_src)

    paus_cat['lya_NB'][where_LAEs_in_cat] = lya_NB[mask]
    paus_cat['nice_lya'][where_LAEs_in_cat] = True
    paus_cat['nice_lya_0'][where_LAEs_in_cat] = True
    paus_cat['z_NB'][where_LAEs_in_cat] = z_NB(lya_NB[mask])

    cont_est, cont_err = estimate_continuum(paus_cat['flx'], paus_cat['err'],
                                            IGM_T_correct=True, N_nb=6)

    paus_cat = Lya_L_estimation(paus_cat, cont_est, cont_err)

    # Keep only the selected LAEs' luminosity, redshift and NB index.
    is_LAE = paus_cat['nice_lya']
    L_lya_list += list(paus_cat['L_lya'][is_LAE])
    z_list += list(paus_cat['z_NB'][is_LAE])
    lya_NB_list += list(paus_cat['lya_NB'][is_LAE])

In [None]:
# Distribution of log10 Lya luminosity of the selected LAEs (10 bins).
fig, ax = plt.subplots()
ax.hist(L_lya_list, 10)
ax.set(xlabel=r'$\log_{10} L_{\mathrm{Ly}\alpha}$', ylabel='N sources',
       title='Selected hiZ LAEs: Lya luminosity')
plt.show()

# Distribution of the NB-derived Lya redshift of the same sources (10 bins).
fig, ax = plt.subplots()
ax.hist(z_list, 10)
ax.set(xlabel=r'$z_{\mathrm{NB}}$', ylabel='N sources',
       title='Selected hiZ LAEs: redshift')
plt.show()

In [None]:
from paus_utils import Lya_effective_volume

FIELDS = ('W1', 'W2', 'W3')

# Edges of the Lya-NB bins. With half-integer edges, each bin holds the integer
# NB indices strictly between consecutive edges: [19, 24] and [25, 30].
lya_NB_bins = [18.5, 24.5, 30.5]


def lya_luminosity_density(log_L_lya, nb_index, nb_min, nb_max, fields=FIELDS):
    """Total Lya luminosity divided by the summed effective volume for an NB range.

    Parameters
    ----------
    log_L_lya : array-like
        log10 Lya luminosities of the selected sources.
    nb_index : array-like
        Lya NB index of each source (same length as ``log_L_lya``).
    nb_min, nb_max : int
        Inclusive NB range. The same range selects the sources and is passed to
        ``Lya_effective_volume``, so numerator and denominator cannot drift apart.
    fields : iterable of str
        Fields whose effective volumes are summed.

    Returns
    -------
    float
        sum(10**log_L_lya for sources in range) / sum of per-field volumes.
    """
    log_L_lya = np.asarray(log_L_lya, dtype=float)
    nb_index = np.asarray(nb_index)
    in_bin = (nb_index >= nb_min) & (nb_index <= nb_max)
    # Luminosities are stored as log10; convert to linear before summing.
    total_L = np.sum(10. ** log_L_lya[in_bin])
    volume = sum(Lya_effective_volume(nb_min, nb_max, field) for field in fields)
    return total_L / volume


# One inclusive NB range per bin: first integer above the lower edge up to the
# last integer below the upper edge.
nb_ranges = [(int(np.ceil(lo)), int(np.floor(hi)))
             for lo, hi in zip(lya_NB_bins[:-1], lya_NB_bins[1:])]
rho1, rho2 = (lya_luminosity_density(L_lya_list, lya_NB_list, nb_min, nb_max)
              for nb_min, nb_max in nb_ranges)

print(rho1, rho2)