In [None]:
import sys
sys.path.insert(0, '..')

from astropy.table import Table
from astropy import units as u
from astropy.coordinates import SkyCoord

import os.path as op

import numpy as np
import pandas as pd

from load_paus_cat import load_paus_cat
from jpasLAEs.utils import flux_to_mag
from paus_utils import w_central, plot_PAUS_source

import matplotlib
import matplotlib.pyplot as plt
# Figure style for the whole notebook: serif Computer Modern text rendered
# through LaTeX (needs a working LaTeX install), 16 pt base font size.
matplotlib.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Computer Modern'],
    'text.usetex': True,
    'font.size': 16,
})

In [None]:
# Load HETDEX source catalog 1 (release `version`) into a pandas DataFrame.
# The directory has its own name rather than reusing `path_to_cat`, which
# later holds the list of PAUS catalogue paths (a str vs. list clash).
# Absolute path to a local data store -- adjust for your machine.
hetdex_dir = '/home/alberto/almacen/HETDEX_catalogs/hetdex_source_catalog_1'
version = 'v3.2'

hetdex_table = pd.DataFrame(
    dict(Table.read(op.join(hetdex_dir, f'hetdex_sc1_{version}.ecsv')))
)

# Load the PAUS catalogue for this field.
field_name = 'W3'
path_to_cat = [f'/home/alberto/almacen/PAUS_data/catalogs/PAUS_3arcsec_{field_name}.csv']
cat = load_paus_cat(path_to_cat)

# Keep only sources whose NB_number exceeds MIN_NB_NUMBER.
# 'flx', 'err' and 'NB_mask' are indexed (filter, source), so they are cut
# along axis 1; every other entry except 'area' is cut along its first axis.
MIN_NB_NUMBER = 39
mask_NB_number = cat['NB_number'] > MIN_NB_NUMBER
for key in cat.keys():
    if key == 'area':
        # NOTE(review): presumably a field-level value, not per-source -- confirm
        continue
    if key in ('flx', 'err', 'NB_mask'):
        cat[key] = cat[key][:, mask_NB_number]
    else:
        cat[key] = cat[key][mask_NB_number]

# Synthetic broad-band flux: inverse-variance weighted mean of filters 12-16,
# converted to a magnitude at w_central[-4] (presumably the r-band pivot
# wavelength -- confirm against paus_utils).
# Filters whose error is non-finite or <= 0 get zero weight; err ** -2 would
# otherwise be inf/NaN there and turn the source's mean into NaN. Sources with
# no usable filter get NaN (np.average would raise ZeroDivisionError).
stack_nb_ids = np.arange(12, 16 + 1)
stack_flx = cat['flx'][stack_nb_ids]
stack_err = cat['err'][stack_nb_ids]
with np.errstate(divide='ignore', invalid='ignore'):
    usable = np.isfinite(stack_err) & (stack_err > 0)
    # Unusable errors are swapped for inf so their weight is exactly 0.
    stack_weights = np.where(usable, stack_err, np.inf) ** -2.
    weight_sum = stack_weights.sum(axis=0)
    # Zero the flux of unusable filters so a NaN there cannot leak in via 0 * NaN.
    weighted_flx_sum = (np.where(usable, stack_flx, 0.) * stack_weights).sum(axis=0)
    synth_BB_flx = np.where(weight_sum > 0, weighted_flx_sum / weight_sum, np.nan)
cat['synth_r_mag'] = flux_to_mag(synth_BB_flx, w_central[-4])

N_sources = len(cat['ref_id'])

In [None]:
# Column names available in the HETDEX catalogue.
hetdex_table.columns

In [None]:
# Cross-match every PAUS source to its nearest HETDEX source on the sky.
coords_paus = SkyCoord(ra=np.array(cat['RA']) * u.deg,
                       dec=np.array(cat['DEC']) * u.deg)
coords_hetdex = SkyCoord(ra=np.array(hetdex_table['RA']) * u.deg,
                         dec=np.array(hetdex_table['DEC']) * u.deg)

# xm_id[i]: index of the HETDEX source nearest to PAUS source i;
# ang_dist[i]: its on-sky separation. A nearest neighbour is always returned,
# so only pairs within MATCH_RADIUS are treated as real matches.
xm_id, ang_dist, _ = coords_paus.match_to_catalog_sky(coords_hetdex)

MATCH_RADIUS = 1 * u.arcsec
mask_dist = ang_dist <= MATCH_RADIUS

# Copy HETDEX redshift and source class onto the PAUS catalogue. Unmatched
# sources get the sentinels z = -1 and class ''; the scalar fills broadcast
# over the boolean selection and keep each column's original dtype.
cat['z_HETDEX'] = np.array(hetdex_table['z_hetdex'])[xm_id]
cat['z_HETDEX'][~mask_dist] = -1
cat['HETDEX_class'] = np.array(hetdex_table['source_type'])[xm_id]
cat['HETDEX_class'][~mask_dist] = ''

In [None]:
# Plot the PAUS photospectra of sources HETDEX classifies as LAEs and mark
# Lyman-alpha redshifted to the HETDEX redshift.
LYA_REST_WAV = 1215.67  # Lyman-alpha rest wavelength [Angstrom]; assumes the plot's x-axis is in Angstrom
N_PLOT_MAX = 100  # cap on the number of figures drawn

lae_src_ids = np.where(cat['HETDEX_class'] == 'lae')[0]
sel_id_Arr = cat['ref_id'][lae_src_ids]
print(len(sel_id_Arr))

# Iterate over catalogue row indices directly instead of searching for each
# ref_id again (assumes ref_id values are unique).
for src in lae_src_ids[:N_PLOT_MAX]:
    refid = cat['ref_id'][src]
    z_src = cat['z_HETDEX'][src]

    fig, ax = plt.subplots(figsize=(8, 3.5))

    # The last filter is blanked (flux and error set to 0) for display only.
    # Work on copies: writing into cat['flx'] / cat['err'] would permanently
    # zero that filter in the catalogue for every later cell.
    src_flx = cat['flx'][:, src].copy()
    src_err = cat['err'][:, src].copy()
    src_flx[-1] = 0
    src_err[-1] = 0
    plot_PAUS_source(src_flx, src_err,
                     ax=ax, plot_BBs=True, set_ylim=False)

    ax.axvline(LYA_REST_WAV * (z_src + 1))
    ax.set_title(f'PAUS {refid}, HETDEX $z = {z_src:.3f}$')

    plt.show()
    # Release the figure; this loop can create up to N_PLOT_MAX of them.
    plt.close(fig)