In [None]:
import sys
# Make modules from the parent directory importable; load_paus_cat and
# paus_utils below presumably live there — TODO confirm.
sys.path.insert(0, '..')

import os.path as op

import pandas as pd
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
# Render all text through LaTeX with a Computer Modern serif font. usetex needs
# a working LaTeX installation, and every label must be LaTeX-safe
# (e.g. underscores escaped as \_).
matplotlib.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
matplotlib.rc('text', usetex=True)
matplotlib.rcParams.update({'font.size': 16})

from load_paus_cat import load_paus_cat

from astropy.coordinates import SkyCoord

from paus_utils import w_central
from jpasLAEs.utils import flux_to_mag

from astropy.io import fits
from astropy import units as u
from astropy.table import Table

In [None]:
# LAE candidate selection, plus its sky positions (RA/DEC given in degrees).
selection_csv = '/home/alberto/almacen/PAUS_data/catalogs/LAE_selection.csv'
selection = pd.read_csv(selection_csv)
sel_coords = SkyCoord(selection['RA'], selection['DEC'], unit='deg')

In [None]:
# Propagate the visual-inspection (VI) LAE flags onto the selection table:
# a selection row is flagged when a VI entry from the same field with the same
# ref_id has is_LAE_VI set.
vi_cat = fits.open('/home/alberto/almacen/PAUS_data/catalogs/PAUS_LAE_selection_visual_insp_AT.fits')
vi_ref_ID = vi_cat[1].data['ref_id']
# Cast explicitly so this always acts as a boolean mask; a 0/1 integer column
# used directly as an index would do integer fancy indexing instead.
vi_is_LAE = vi_cat[1].data['is_LAE_VI'].astype(bool)
vi_field = vi_cat[1].data['field']

selection['is_LAE_VI'] = np.zeros(len(selection), dtype=bool)

for field_name in ['W1', 'W2', 'W3']:
    in_field = vi_field == field_name
    LAE_IDs = vi_ref_ID[in_field & vi_is_LAE]
    # Assign through .loc: the chained form selection[col][mask] = True writes
    # into an intermediate object (SettingWithCopyWarning) and never updates
    # the DataFrame under Copy-on-Write. tolist() avoids handing isin a
    # non-native (big-endian) FITS array.
    to_flag = (selection['field'] == field_name) & selection['ref_id'].isin(LAE_IDs.tolist())
    selection.loc[to_flag, 'is_LAE_VI'] = True

In [None]:
# Broad-band PAUS fluxes for every selected object, taken from the nearest
# catalogue source within 3 arcsec in each field.
# Initialise with NaN rather than np.empty: objects without a counterpart keep
# an explicit missing value instead of uninitialised memory, which would
# otherwise be converted into meaningless magnitudes downstream.
u_flux = np.full(len(selection), np.nan)
g_flux = np.full(len(selection), np.nan)
r_flux = np.full(len(selection), np.nan)
i_flux = np.full(len(selection), np.nan)
z_flux = np.full(len(selection), np.nan)

PHOT_MATCH_RADIUS = 3 * u.arcsec

# (output array, row of cat['flx']) for each band.
# NOTE(review): assumes the last rows of cat['flx'] hold the broad bands in
# u, g, r, i, z order — confirm against load_paus_cat.
broad_bands = ((u_flux, -6), (g_flux, -5), (r_flux, -4), (i_flux, -3), (z_flux, -2))

for field_name in ['W1', 'W2', 'W3']:
    path_to_cat = [f'/home/alberto/almacen/PAUS_data/catalogs/PAUS_3arcsec_{field_name}_extinction_corrected.pq']
    cat = load_paus_cat(path_to_cat)

    this_cat_coords = SkyCoord(cat['RA'], cat['DEC'], unit='deg')

    # Every selected object is matched against this field's catalogue; only
    # matches inside the radius are written, so objects outside the field
    # are left untouched.
    this_ids, this_dist, _ = sel_coords.match_to_catalog_sky(this_cat_coords)
    mask_dist = np.asarray(this_dist < PHOT_MATCH_RADIUS)

    for band_flux, flx_row in broad_bands:
        band_flux[mask_dist] = cat['flx'][flx_row][this_ids][mask_dist]

In [None]:
# SDSS catalog
path_to_cat = '/home/alberto/almacen/SDSS_spectra_fits/DR18/spAll-v5_13_2.fits'
sdss_cat = fits.open(path_to_cat)[1].data
print('SDSS cat loaded.')

# HETDEX catalog
path_to_cat = '/home/alberto/almacen/HETDEX_catalogs/hetdex_source_catalog_1'
version = 'v3.2'
hetdex_cat = pd.DataFrame(
    dict(Table.read(op.join(path_to_cat, f'hetdex_sc1_{version}.ecsv')))
    )
print('HETDEX cat loaded.')

# DESI catalog
path_to_cat = '/home/alberto/almacen/DESI/zall-pix-fuji.fits'
desi_cat = fits.open(path_to_cat)[1].data
print('DESI cat loaded.')

In [None]:
# Do the cross-matches
coords_paus = SkyCoord(ra=np.asarray(selection['RA']) * u.deg,
                       dec=np.asarray(selection['DEC']) * u.deg)
coords_sdss = SkyCoord(ra=sdss_cat['RA'] * u.deg,
                       dec=sdss_cat['DEC'] * u.deg)
coords_hetdex = SkyCoord(ra=np.asarray(hetdex_cat['RA']) * u.deg,
                         dec=np.asarray(hetdex_cat['DEC']) * u.deg)
coords_desi = SkyCoord(ra=desi_cat['TARGET_RA'] * u.deg,
                       dec=desi_cat['TARGET_DEC'] * u.deg)

xm_id_sdss, ang_dist_sdss, _= coords_paus.match_to_catalog_sky(coords_sdss)
xm_id_hetdex, ang_dist_hetdex, _= coords_paus.match_to_catalog_sky(coords_hetdex)
xm_id_desi, ang_dist_desi, _= coords_paus.match_to_catalog_sky(coords_desi)

# Objects with 1 arcsec of separation
mask_dist_sdss = (ang_dist_sdss <= 1.5 * u.arcsec) & (sdss_cat['ZWARNING'][xm_id_sdss] == 0)
mask_dist_desi = (ang_dist_desi <= 1.5 * u.arcsec) & (desi_cat['ZWARN'][xm_id_desi] == 0)
mask_dist_hetdex = (ang_dist_hetdex <= 1.5 * u.arcsec) & (hetdex_cat['z_hetdex_conf'][xm_id_hetdex] > 0.9)

mask_dist_sdss = np.asarray(mask_dist_sdss)
mask_dist_desi = np.asarray(mask_dist_desi)
mask_dist_hetdex = np.asarray(mask_dist_hetdex)

mask_spec = mask_dist_sdss | mask_dist_desi | mask_dist_hetdex

In [None]:
# Convert each broad-band flux to a magnitude at its matching w_central entry
# (indices -6 ... -2 for u, g, r, i, z respectively).
u_mag, g_mag, r_mag, i_mag, z_mag = [
    flux_to_mag(band_flux, w_central[band_idx])
    for band_flux, band_idx in zip((u_flux, g_flux, r_flux, i_flux, z_flux),
                                   range(-6, -1))
]

In [None]:
# Colour-colour diagrams of the selection.
# Colour-code by L_lya_corr on a fixed 43-45 scale via vmin/vmax. Rescaling the
# values as (L - 43) / (45 - 43) has no effect on the colours, because
# matplotlib autoscales the colour limits to the data; it only relabels the
# colorbar.
L_LYA_CLIM = (43, 45)
COLOR_LIMS = (-1, 3)
# Set True to overlay VI-confirmed LAEs (is_LAE_VI) and spec-z matches (mask_spec).
SHOW_OVERLAYS = False

is_vi = np.asarray(selection['is_LAE_VI'], dtype=bool)
gr = g_mag - r_mag
ri = r_mag - i_mag
iz = i_mag - z_mag

cmap = plt.get_cmap('rainbow')

fig, ax = plt.subplots(figsize=(12, 12))
scatter = ax.scatter(gr, ri, c=selection['L_lya_corr'], cmap=cmap,
                     vmin=L_LYA_CLIM[0], vmax=L_LYA_CLIM[1])
if SHOW_OVERLAYS:
    ax.scatter(gr[is_vi], ri[is_vi])
    ax.scatter(gr[mask_spec], ri[mask_spec])
# usetex is on, so the underscores in the column name must be escaped.
fig.colorbar(scatter, ax=ax, label=r'L\_lya\_corr')
ax.set(xlim=COLOR_LIMS, ylim=COLOR_LIMS, xlabel=r'$g - r$', ylabel=r'$r - i$')
plt.show()

fig, ax = plt.subplots()
ax.scatter(ri, iz)
if SHOW_OVERLAYS:
    ax.scatter(ri[is_vi], iz[is_vi])
    ax.scatter(ri[mask_spec], iz[mask_spec])
ax.set(xlim=COLOR_LIMS, ylim=COLOR_LIMS, xlabel=r'$r - i$', ylabel=r'$i - z$')
plt.show()

In [None]:
# L_lya_corr against r-band magnitude for the whole selection; objects flagged
# is_LAE_VI are overplotted in the next colour of the default colour cycle.
# Uses the pyplot current-figure state, so later calls in this cell (if any)
# act on the same axes. TODO: add axis labels (LaTeX-safe, usetex is on).
plt.scatter(r_mag, selection['L_lya_corr'])
plt.scatter(r_mag[selection['is_LAE_VI']], selection['L_lya_corr'][selection['is_LAE_VI']])

plt.xlim(18, 25)