In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from paus_utils import *
from astropy.table import Table
from load_paus_cat import paus_flux_units

from jpasLAEs.utils import mag_to_flux, flux_to_mag

from LAE_selection_method import *

In [None]:
import seaborn as sns
import pickle

# Field and narrow-band (NB) range that both the purity map and the stored
# LF selection below refer to.
region_name = 'W3'
nb1, nb2 = 0, 2

corr_dir = '/home/alberto/almacen/PAUS_data/LF_corrections'
# Grid edges (200 bins each) for r magnitude and Lya luminosity; these are the
# axes of the heatmap drawn below (assumed to match the grid used to build the
# puri2D map — TODO confirm against the code that wrote the .npy file).
r_bins = np.linspace(17, 24, 200 + 1)
L_bins = np.linspace(40, 47, 200 + 1)
# Use region_name rather than a hard-coded field so that the purity map and
# the LF selection always come from the same field.
puri2d = np.load(f'{corr_dir}/puri2D_{region_name}_nb{nb1}-{nb2}.npy')

LF_name = f'Lya_LF_nb{nb1}-{nb2}_{region_name}'
pathname = f'/home/alberto/almacen/PAUS_data/Lya_LFs/{LF_name}'
# NOTE: pickle.load can execute arbitrary code; only load files produced by
# this pipeline.
with open(f'{pathname}/selection.pkl', 'rb') as f:
    selection = pickle.load(f)

def L_to_bins(L_Arr):
    """Convert Lya luminosity values to fractional indices on the L_bins grid.

    Uses linear interpolation between edges, so a value equal to edge k maps
    to k. Values outside the grid are clamped to 0 or len(L_bins) - 1
    (np.interp's default behaviour). The result is in the bin-index units of
    the heatmap axes used below.
    """
    edge_index = np.arange(L_bins.size)
    return np.interp(L_Arr, L_bins, edge_index)


def r_to_bins(mag):
    """Convert r-band magnitudes to fractional indices on the r_bins grid.

    Same convention as L_to_bins: edge k maps to k, and out-of-range values
    are clamped to the first/last index.
    """
    edge_index = np.arange(r_bins.size)
    return np.interp(mag, r_bins, edge_index)

fig, ax = plt.subplots(figsize=(7, 7))

# Overlay the LF-selected candidates on the 2D purity map. The heatmap axes
# are in bin-index units, so L_lya and r_mag are converted to those units.
selec_L = selection['L_lya_corr']
selec_L_err = selection['L_lya_corr_err']
L_Arr_b = L_to_bins(selec_L)
mag_b = r_to_bins(selection['r_mag'])
# Half-width of the +-1 sigma L interval in bin units (np.interp clamps it at
# the grid edges).
L_err_Arr_b = (L_to_bins(selec_L + selec_L_err)
               - L_to_bins(selec_L - selec_L_err)) * 0.5
ax.errorbar(L_Arr_b, mag_b, fmt='s', xerr=L_err_Arr_b,
            color='k', capsize=3, linestyle='', ms=2, alpha=0.5)

sns.heatmap(puri2d.T, vmin=0, vmax=1, ax=ax,
            cmap='Spectral', rasterized=True)

# Replace the heatmap's index ticks with physical values every 20 bins.
yticks = np.arange(len(r_bins))[::20]
ax.set_yticks(yticks)
ax.set_yticklabels([f'{s:0.1f}' for s in r_bins[yticks]])

xticks = np.arange(len(L_bins))[::20]
# Must be xticks (the L grid); using yticks only worked because both grids
# happen to have the same number of edges.
ax.set_xticks(xticks)
ax.set_xticklabels([f'{s:0.1f}' for s in L_bins[xticks]])

ax.set_xlabel(r'log $L_{\rm Ly\alpha}$')
ax.set_ylabel('r [mag]')

# Zoom on the populated region (bin-index units). The reversed y-limits put
# the faint end (r_bins[200] = 24) at the bottom.
ax.set(xlim=(80, 170), ylim=(200, 15))

plt.show()

In [None]:
# tab = Table.read('/home/alberto/almacen/PAUS_data/catalogs/EDR_1_v2.fits').to_pandas()

# cat = {} # Initialize catalog dict

# cat['ref_id'] = np.array(tab['ref_id'])
# cat['photoz'] = np.array(tab['zb'])
# cat['photoz_odds'] = np.array(tab['odds'])
# cat['ra'] = np.array(tab['ra'])
# cat['dec'] = np.array(tab['dec'])
# cat['type'] = np.array(tab['type'])
# cat['zspec'] = np.array(tab['zspec'])
# cat['r_mag'] = np.array(tab['rmag'])

# # Flux units have to be converted to erg s^-1 A^-1
# cat['flx'] = paus_flux_units(tab.to_numpy()[:, 7 : 7 + 40],
#                                 w_central[:-6]).T
# cat['err'] = paus_flux_units(tab.to_numpy()[:, 7 + 40 : 7 + 40 + 40],
#                                 w_central[:-6]).T

# N_sources = len(cat['ref_id'])

# # Add provisional BBs
# cat['flx'] = np.vstack([cat['flx'],
#                         mag_to_flux(tab.to_numpy()[:, -5:].T, w_central[-6:-1].reshape(-1,1)),
#                         np.zeros(N_sources)])
# cat['err'] = np.vstack([cat['err'],
#                         np.zeros((6, N_sources))])

In [None]:
# from load_paus_mocks import load_gal_mock

# def load_this_cat():
#     lc_path = '/home/alberto/almacen/PAUS_data/catalogs/LightCone_mock.fits'
#     cat = load_gal_mock(lc_path)

#     nominal_errs = mag_to_flux(23, w_central) / 3
#     cat['err'] = np.ones_like(cat['flx_0']) * nominal_errs.reshape(-1, 1)
#     cat['flx'] = cat['flx_0'] #+ cat['err'] * np.random.normal(size=cat['flx_0'].shape)
#     # TODO: add_errors function
#     # mock['flx'], mock['err'] = add_errors(mock['flx_0'], field_name)

#     # Compute r_mag
#     cat['r_mag'] = flux_to_mag(cat['flx'][-4], w_central[-4])
#     # tab = Table.read(lc_path).to_pandas().to_numpy()
#     # mock_size = len(tab)

#     # cat_fraction = 0.5 # Load only a fraction of the mock
#     # np.random.seed(1312)
#     # sel = np.random.randint(0, mock_size, size=int(mock_size * cat_fraction))
    
#     # cat = {}

#     # cat['flx'] = mag_to_flux(tab[sel, 11 : 11 + 40],
#     #                                 w_central[:-6]).T

#     # # Add provisional BBs
#     # cat['flx'] = np.vstack([cat['flx'],
#     #                         mag_to_flux(tab[sel, -5:].T,
#     #                                     w_central[-6:-1].reshape(-1,1)),
#     #                         np.zeros(len(sel))])

#     # cat['r_mag'] = flux_to_mag(cat['flx'][-4], w_central[-4])
#     # cat['i_mag'] = flux_to_mag(cat['flx'][-3], w_central[-3])
#     # mag_mask = np.array(cat['r_mag'] < 24) & np.array(cat['i_mag'] < 23)
#     # cat['flx'] = cat['flx'][:, mag_mask]
#     # cat['r_mag'] = cat['r_mag'][mag_mask]

#     # nominal_errs = mag_to_flux(23, w_central) / 3
#     # cat['err'] = np.ones_like(cat['flx']) * nominal_errs.reshape(-1, 1)

#     # # Apply errs
#     # cat['flx'] += cat['err'] * np.random.normal(size=cat['flx'].shape)

#     # cat['zspec'] = np.array(tab[:, 4])[sel][mag_mask]

#     return cat

# cat = load_this_cat()

# N_sources = cat['flx'].shape[1]
# print(f'N_sources = {N_sources}')

In [None]:
# # Load QSO mock
# from load_paus_mocks import add_errors, load_qso_mock

# source_cats_dir = '/home/alberto/almacen/Source_cats'
# mock_path = f'{source_cats_dir}/QSO_PAUS_LAES_2'
# cat = load_qso_mock(mock_path)

# field_name = 'W3'
# cat['flx'], cat['err'] = add_errors(cat['flx_0'], field_name, True)

# # Compute r_mag
# cat['r_mag'] = flux_to_mag(cat['flx'][-4], w_central[-4])

# N_sources = len(cat['r_mag'])
# print(f'N_sources = {N_sources}')

# mock = cat

In [None]:
from load_paus_cat import load_paus_cat

field_name = 'W3'
path_to_cat = [f'/home/alberto/almacen/PAUS_data/catalogs/PAUS_3arcsec_{field_name}.csv']
cat = load_paus_cat(path_to_cat)

# Keep only sources with NB_number > 39 (presumably: measured in all 40 NBs).
# 'flx', 'err' and 'NB_mask' are indexed (band, source), so they are masked on
# the second axis. 'area' is left untouched. Every other entry is masked on
# its first axis.
mask_NB_number = cat['NB_number'] > 39
_per_band_keys = ('flx', 'err', 'NB_mask')
for key in _per_band_keys:
    cat[key] = cat[key][:, mask_NB_number]
for key in cat.keys():
    if key not in _per_band_keys and key != 'area':
        cat[key] = cat[key][mask_NB_number]

# Synthetic r-band magnitude: inverse-variance weighted mean flux of NBs
# 12..16, converted to magnitudes at the r-band pivot (w_central[-4]).
stack_nb_ids = np.arange(12, 17)
inv_var = cat['err'][stack_nb_ids] ** -2
synth_BB_flx = np.average(cat['flx'][stack_nb_ids], axis=0, weights=inv_var)
cat['synth_r_mag'] = flux_to_mag(synth_BB_flx, w_central[-4])

N_sources = len(cat['ref_id'])

In [None]:
# import matplotlib.colors as colors

# rows = np.arange(16)
# cols = [8, 16, 24, 32, 40]

# mask_NB_number = (cat['NB_number'] > 39)

# nfilter_mat = np.zeros((len(rows), len(cols)))
# for i, n_row in enumerate(rows):
#     for j, n_col in enumerate(cols):
#         fil_0 = n_row
#         fil_delta = np.min([39, fil_0 + n_col])

#         if fil_delta == 39 and fil_0 == 0:
#             extra_mask = mask_NB_number
#         else:
#             extra_mask = ~mask_NB_number

#         N_src = sum(np.all(np.isfinite(cat['flx'][fil_0 : fil_0 + fil_delta]),
#                            axis=0) & extra_mask)
#         nfilter_mat[i, j] = N_src

# fig, ax = plt.subplots(figsize=(8, 10))


# sns.heatmap(nfilter_mat, annot=True, ax=ax,
#             cbar=False, fmt='0.0f',
#             xticklabels=[f'+{c}' for c in cols],
#             yticklabels=rows,
#             norm=colors.LogNorm())

# ax.set_ylabel('First filter (f$_0$)')
# ax.set_xlabel('$\Delta f$')

# # fig.savefig('/home/alberto/Desktop/fil_table.pdf', facecolor='w')
# plt.show()

In [None]:
from PAUS_Lya_LF_corrections import L_lya_bias_apply

# Search window: NB index range [nb_min, nb_max] and r-magnitude range.
nb_min, nb_max = 0, 20
r_min, r_max = 17, 24

z_lo, z_hi = z_NB(nb_min), z_NB(nb_max)
print(f'Searching for LAEs at {z_lo:0.2f} < z < {z_hi:0.2f}')

cat = select_LAEs(cat, nb_min, nb_max, r_min, r_max)
# L_lya_bias_apply is currently disabled:
# cat = L_lya_bias_apply(cat, field_name, nb_min, nb_max)
cont_est, cont_err = estimate_continuum(
    cat['flx'], cat['err'], IGM_T_correct=True, N_nb=6
)

In [None]:
# fig, ax = plt.subplots(figsize=(6, 6))

# # nb_min, nb_max = 6, 8

# z_min, z_max = z_NB(nb_min), z_NB(nb_max)

# # ax.scatter(cat['L_lya_spec'], cat['L_lya'])
# # ax.scatter(cat['L_lya_spec'], cat['L_lya_corr'])

# xx = [-100, 100]
# ax.plot(xx, xx, ls='--', c='r')

# ax.set(xlim=(43, 46), ylim=(43, 46))

# plt.show()

In [None]:
from Make_Lya_LF import puricomp2d_weights

# 2D purity / completeness maps for the NB 0-2 slice of this field, plus the
# L and r grids they are defined on.
nb_min, nb_max = 0, 2
corr_dir = '/home/alberto/almacen/PAUS_data/LF_corrections'
corr_tag = f'{field_name}_nb{nb_min}-{nb_max}'
puri2d = np.load(f'{corr_dir}/puri2D_{corr_tag}.npy')
comp2d = np.load(f'{corr_dir}/comp2D_{corr_tag}.npy')
puricomp2d_L_bins, puricomp2d_r_bins = [
    np.load(f'{corr_dir}/puricomp2D_{axis}_bins.npy') for axis in ('L', 'r')
]

# Per-source purity and completeness evaluated from the maps (presumably a
# lookup at each source's (r_mag, L_lya) — see puricomp2d_weights).
puri, comp = puricomp2d_weights(cat['r_mag'], cat['L_lya'], puri2d, comp2d,
                                puricomp2d_L_bins, puricomp2d_r_bins)

In [None]:
# Broad-band colour distributions of well-matched mock LAEs in each NB slice.
# Prints the 1st/99th colour percentiles of the high-L QSO LAE mock.
import pickle
from load_paus_mocks import load_mock_dict

nb_list = [[0, 2], [2, 4], [4, 6], [6, 8],
           [8, 10], [10, 12], [12, 14], [14, 16], [16, 18]]

field_name = 'W3'
savedir = '/home/alberto/almacen/PAUS_data/LF_corrections'

# Set to True to draw per-colour histograms with the percentile markers.
# (Previously the markers were drawn on whatever `ax` was left over from an
# earlier cell, which is wrong and fails on a fresh kernel.)
plot_color_hists = False

color_dict = {
    'SFG': 'C1',
    'QSO_cont': 'dimgray',
    'QSO_LAEs_loL': 'C0',
    'QSO_LAEs_hiL': 'blue',
    'GAL': 'C2'
}

m_color_bins = np.linspace(-2, 5, 30)

# Reversed so index i pairs with BB rows -3 - i and -2 - i, i.e.
# i=0 -> 'iz', 1 -> 'ri', 2 -> 'gr', 3 -> 'ug' (r is row -4, i is row -3).
color_name = ['ug', 'gr', 'ri', 'iz'][::-1]

for nb_min, nb_max in nb_list:
    print(f'\nnbs: {nb_min}-{nb_max}')

    # NOTE: pickle.load can execute arbitrary code; only load pipeline output.
    with open(f'{savedir}/mock_dict_{field_name}_nb{nb_min}-{nb_max}.pkl', 'rb') as f:
        mock_dict = pickle.load(f)

    for i, cname in enumerate(color_name):
        if plot_color_hists:
            fig, ax = plt.subplots(figsize=(6, 4))

        for mock_name, mock in mock_dict.items():
            # Combine the mock's nice_lya flag with a check that the spec-z is
            # within 0.2 of the redshift implied by its Lya NB.
            nice_lya = mock['nice_lya'] & (np.abs(mock['zspec'] - z_NB(mock['lya_NB'])) < 0.2)
            m_color = (flux_to_mag(mock['flx'][-3 - i][nice_lya], w_central[-3 - i])
                       - flux_to_mag(mock['flx'][-2 - i][nice_lya], w_central[-2 - i]))

            if plot_color_hists:
                ax.hist(m_color, m_color_bins,
                        color=color_dict[mock_name],
                        histtype='step', density=True,
                        lw=2)

            if mock_name == 'QSO_LAEs_hiL':
                perc = np.nanpercentile(m_color, [1, 99])
                print(f'{cname} = {perc}')
                if plot_color_hists:
                    for p in perc:
                        ax.axvline(p, ls='--', color='k')

        if plot_color_hists:
            ax.set_title(cname)
            plt.show()

In [None]:
# SDSS cross-match for the same field as `cat`. Later cells index its rows with
# mask_NB_number, so it is assumed to be row-aligned with the unmasked PAUS
# catalogue — TODO confirm. Using field_name keeps the two files consistent.
sdss_xm = pd.read_csv(f'/home/alberto/almacen/PAUS_data/catalogs/Xmatch_SDSS_{field_name}.csv')

In [None]:
# Visual inspection of LAE candidates in NBs 0-2 with an SDSS redshift in
# (0, 2.7), synthetic r > 17 and L_lya > 44, with QSO lines overplotted at the
# NB redshift.
qso_lines = [1025.7220, 1397.61, 1549.48, 1908.73, 2799.12, 2326, 3727]
qso_lines_name = ['LyB', 'SiIV', 'CIV', 'CIII', 'MgII', 'CII', 'OII']

line_dict = {'halpha': 6564.61,
             'hbeta': 4862.68,
             'oii3727': 3727.10,
             'oiii4959': 4960.30,
             'oiii5007': 5008.24}
line_dict.update(zip(qso_lines_name, qso_lines))

nice_lya = cat['nice_lya']
z_Arr = cat['z_NB']
lya_lines = cat['lya_NB']
other_lines = cat['other_lines_NBs']

# SDSS columns restricted to the sources kept in `cat` (extracted once).
z_best_xm = np.asarray(sdss_xm['z_best'])[mask_NB_number]
L_lya_xm = np.asarray(sdss_xm['L_lya'])[mask_NB_number]
EW0_lya_xm = np.asarray(sdss_xm['EW0_lya'])[mask_NB_number]

# NOTE: this boolean mask reuses the name `selection`, which an earlier cell
# binds to the pickled LF selection dict.
selection = (nice_lya & (cat['NB_number'] > 39)
    # & (z_NB(nb_min) <= sdss_xm['z_CIV']) & (z_NB(nb_max) >= sdss_xm['z_CIV']) & (sdss_xm['L_lya'] > 44)
    # & (sdss_xm['z_CIV'] < z_NB(nb_min) - 0.5) & (cat['L_lya'] > 44) & (sdss_xm['z_CIV'] > 0)
    & (z_best_xm > 0)
    & (z_best_xm < 2.7)
    & (lya_lines >= 0) & (lya_lines <= 2)
    & (cat['synth_r_mag'] > 17)
    & (cat['L_lya'] > 44)
)

print(f'N selected = {np.count_nonzero(selection)}')

# Fixed seed so the same random subset of candidates is shown on re-run.
rng = np.random.default_rng(1312)
n_show = 10

for src in rng.permutation(np.flatnonzero(selection))[:n_show]:
    # Guard the purity/completeness ratio against comp == 0.
    puri_comp_ratio = puri[src] / comp[src] if comp[src] > 0 else np.nan
    print(f'r = {cat["r_mag"][src]:0.2f}, L_lya = {cat["L_lya"][src]:0.2f}, '
          f'EW0_Lya = {cat["EW0_lya"][src]:0.2f}, '
          f'b_frac = {cat["bulge_fraction"][src]:0.2f}, '
          f'flattening = {cat["flattening"][src]:0.2f}'
          f'\nclass_pred = {cat["class_pred"][src]}'
          f'\nnice_lya={cat["nice_lya"][src]}, nice_color={cat["nice_color"][src]}, '
          f'nice_lines={cat["nice_ml"][src]}'
          f'\nz_NB = {z_Arr[src]:0.2f}'
          f'\nz_best = {z_best_xm[src]:0.2f}'
          f'\nL_lya_spec = {L_lya_xm[src]:0.2f}'
          f'\nEW0_lya_spec = {EW0_lya_xm[src]:0.2f}'
          f'\npuri = {puri[src]:0.2f}, comp = {comp[src]:0.2f} ({puri_comp_ratio:0.2f})')

    # Blank the last band on copies only: the catalogue itself is not modified.
    src_flx = cat['flx'][:, src].copy()
    src_err = cat['err'][:, src].copy()
    src_flx[-1] = 0
    src_err[-1] = 0

    fig, ax = plt.subplots(figsize=(8, 3.5))
    plot_PAUS_source(src_flx, src_err,
                     ax=ax, plot_BBs=True, set_ylim=False)

    ax.axvline(w_central[lya_lines[src]], ls='--', c='r')

    # # Show other detected lines
    # for l in other_lines[src]:
    #     ax.axvline(w_central[l], ls=':', c='dimgray', zorder=99)

    # Label height: faintest of the first 40 bands, scaled by 1e17 (the factor
    # the original code used). nanmin so a missing band does not give NaN.
    ypos = np.nanmin(src_flx[:40]) * 1e17
    for name, w in zip(qso_lines_name, qso_lines):
        qso_obs_w = w * (1 + z_Arr[src])
        if not 4400 <= qso_obs_w <= 8600:
            continue
        ax.axvline(qso_obs_w, linestyle=':', color='orange')
        ax.text(qso_obs_w, ypos, name, color='k')

    # ax.plot(w_central[:30], cont_est[:30, src] * 1e17)

    plt.show()

In [None]:
# fig, ax = plt.subplots()

# ax.scatter(cat['r_mag'], cat['synth_r_mag'],
#            s=1, alpha=0.2)
# print(np.nanmedian(cat['r_mag'] - cat['synth_r_mag']))

# xx = [-100, 100]
# ax.plot(xx, xx, ls='--', c='k')

# ax.set_ylim(15, 26)
# ax.set_xlim(15, 26)

# plt.show()

In [None]:
# bins = np.linspace(0, 1, 1000)

# fig, ax = plt.subplots(figsize=(6, 4))

# mask = cat['nice_lya']
# ax.hist(cat['flattening'][mask], bins, histtype='step', density=True)
# mask = cat['nice_lya']\
#     & (z_NB(nb_min) <= sdss_xm['z_CIV'][mask_NB_number]) & (z_NB(nb_max) >= sdss_xm['z_CIV'][mask_NB_number])
# ax.hist(cat['flattening'][mask], bins, histtype='step', density=True,
#         label='good')

# ax.legend()

# ax.set_xlabel('Flattening')

# plt.show()

# fig, ax = plt.subplots(figsize=(6, 4))

# mask = cat['nice_lya']
# ax.hist(cat['bulge_fraction'][mask], bins, histtype='step', density=True)
# mask = cat['nice_lya']\
#     & (z_NB(nb_min) <= sdss_xm['z_CIV'][mask_NB_number]) & (z_NB(nb_max) >= sdss_xm['z_CIV'][mask_NB_number])
# ax.hist(cat['bulge_fraction'][mask], bins, histtype='step', density=True,
#         label='good')

# ax.legend()

# ax.set_xlabel('Bulge fraction')

# plt.show()

In [None]:
# fig, ax = plt.subplots(figsize=(6, 4))

# nb_min, nb_max = 0, 10

# nice_z = (z_NB(nb_min) <= sdss_xm['z_CIV']) & (z_NB(nb_max) >= sdss_xm['z_CIV'])
# nb_mask = (cat['lya_NB'] >= nb_min) & (cat['lya_NB'] <= nb_max)

# L_bins = np.linspace(42, 46, 30)
# ax.hist(cat['L_lya'][nice_lya & nice_z & nb_mask], L_bins,
#         histtype='step', label='nice z')
# ax.hist(cat['L_lya'][nice_lya & ~nice_z & (sdss_xm['z_CIV'] > 0) & nb_mask], L_bins,
#         histtype='step', label='bad z')

# ax.legend(fontsize=11)
# ax.set_xlabel('log L_Lya')

# plt.show()

In [None]:
from sklearn.metrics import confusion_matrix

# Reference class from SDSS: 2 (LAE) if the SDSS redshift agrees with the NB
# redshift to within 0.2, otherwise 1 (contaminant).
z_SDSS = np.array(sdss_xm["z_best"])[mask_NB_number]
nice_SDSS_z = np.abs(z_SDSS - cat['z_NB']) < 0.2
class_SDSS = np.where(nice_SDSS_z, 2, 1)

# Only sources with a positive SDSS redshift and a positive predicted class.
mask = (z_SDSS > 0) & (cat['class_pred'] > 0)
# class_SDSS only takes values 1 and 2, so the rows for labels 4 and 5 are
# always empty; they are kept so the matrix is square over all labels.
class_labels = [1, 2, 4, 5]
label_names = ['Contaminants', 'LAEs', 'low-z Gal', '?']
cm = confusion_matrix(class_SDSS[mask], cat['class_pred'][mask],
                      labels=class_labels)

# Explicit figure/axes so the heatmap never lands on a stale current figure.
fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, cmap="Blues", fmt='0.0f',
            xticklabels=label_names, yticklabels=label_names,
            cbar=False, ax=ax)
ax.set_xlabel("Predicted Labels")
ax.set_ylabel("True Labels")
ax.set_title(f'Predicted class vs SDSS (N = {np.count_nonzero(mask)})')
plt.show()