In [None]:
import sys
sys.path.insert(0, '..')

import os.path as op

import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
# Global figure style: serif (Computer Modern) fonts, all text rendered through
# LaTeX (usetex=True needs a LaTeX installation), default font size 16.
matplotlib.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
matplotlib.rc('text', usetex=True)
matplotlib.rcParams.update({'font.size': 16})

import numpy as np

from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.table import Table

import pickle

In [None]:
# PAUS catalog: collect the selection of every NB interval and field
region_list = ['W1', 'W2', 'W3']
LFs_dir = '/home/alberto/almacen/PAUS_data/Lya_LFs'
nb_list = [[0, 3], [2, 5], [4, 7], [6, 9], [8, 11], [10, 13], [12, 15], [14, 18]]

sel_dfs = []
for nb1, nb2 in nb_list:
    for region in region_list:
        sel_path = f'{LFs_dir}/Lya_LF_nb{nb1}-{nb2}_{region}/selection.pkl'
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(sel_path, 'rb') as file:
            this_dict = pickle.load(file)

        # 'L_lya_corr_err' is indexed as [0] -> down, [1] -> up; store each
        # as its own column so the dict maps cleanly onto a DataFrame.
        lya_err = this_dict.pop('L_lya_corr_err')
        this_dict['L_lya_corr_err_up'] = lya_err[1]
        this_dict['L_lya_corr_err_down'] = lya_err[0]

        this_df = pd.DataFrame(this_dict)
        # Tag every row with the field it comes from (second column)
        this_df.insert(1, 'field', region)
        sel_dfs.append(this_df)

# Merge everything; overlapping NB intervals repeat sources, so keep the
# first occurrence of each ref_id.
selection = pd.concat(sel_dfs)
selection['ref_id'] = selection['ref_id'].astype(int)
selection = selection.drop_duplicates(subset='ref_id')

# For lya_NB > 18 the 'nice_lya' flag is replaced by nice_color & nice_ml.
# NOTE(review): np.asarray may return a view of the DataFrame column, in which
# case the assignment below also updates selection['nice_lya'] in place
# (and fails on a read-only view under pandas Copy-on-Write) -- confirm intent.
nice_lya = np.asarray(selection['nice_lya'])
mask_high_NB = selection['lya_NB'] > 18
nice_color_and_ml = selection['nice_color'] & selection['nice_ml']
nice_lya[mask_high_NB] = nice_color_and_ml[mask_high_NB]

# Optional switch to accept every high-z candidate:
# nice_lya[mask_high_NB] = True

# Keep the uncut catalog for the "raw selection" analysis further down
selection_raw = selection.copy()

# Apply the cut, sort by redshift and renumber the rows
selection = (
    selection[nice_lya]
    .sort_values(by='z_NB', ascending=True)
    .reset_index(drop=True)
)
print('PAUS cat loaded.')
selection_0 = selection.copy()

# SDSS catalog (first table extension of the spAll file)
sdss_path = '/home/alberto/almacen/SDSS_spectra_fits/DR18/spAll-v5_13_2.fits'
sdss_cat = fits.open(sdss_path)[1].data
print('SDSS cat loaded.')

# HETDEX catalog: read the ECSV table and expose it as a pandas DataFrame
hetdex_dir = '/home/alberto/almacen/HETDEX_catalogs/hetdex_source_catalog_1'
hetdex_version = 'v3.2'
hetdex_table = Table.read(op.join(hetdex_dir, f'hetdex_sc1_{hetdex_version}.ecsv'))
hetdex_cat = pd.DataFrame(dict(hetdex_table))
print('HETDEX cat loaded.')

# DESI catalog (first table extension)
desi_path = '/home/alberto/almacen/DESI/zall-pix-fuji.fits'
desi_cat = fits.open(desi_path)[1].data
print('DESI cat loaded.')

# Quick look at the PAUS selection
selection.head()

In [None]:
print(f'{len(selection)=}')

# Histogram of the 'lya_NB' column, one bin per integer value from 0 to 39
fig, ax = plt.subplots(figsize=(6, 4))

ax.hist(selection['lya_NB'], np.arange(-0.5, 40.5, 1))

ax.set_ylabel('N')
ax.set_xlabel('lya_NB')

plt.show()

# Histogram of the 'z_NB' column in 50 bins
fig, ax = plt.subplots(figsize=(6, 4))

ax.hist(selection['z_NB'], 50)

ax.set_ylabel('N')
# Fixed: this panel shows z_NB, it was mislabelled as 'lya_NB'
ax.set_xlabel('z_NB')

plt.show()

In [None]:
# Cross-match the PAUS selection with the spectroscopic catalogs.

# Maximum on-sky separation for a match to be accepted. Defined once so the
# three surveys (and the comments) cannot drift apart.
max_sep = 1.5 * u.arcsec

coords_paus = SkyCoord(ra=np.asarray(selection['RA']) * u.deg,
                       dec=np.asarray(selection['DEC']) * u.deg)
coords_sdss = SkyCoord(ra=sdss_cat['RA'] * u.deg,
                       dec=sdss_cat['DEC'] * u.deg)
coords_hetdex = SkyCoord(ra=np.asarray(hetdex_cat['RA']) * u.deg,
                         dec=np.asarray(hetdex_cat['DEC']) * u.deg)
coords_desi = SkyCoord(ra=desi_cat['TARGET_RA'] * u.deg,
                       dec=desi_cat['TARGET_DEC'] * u.deg)

# For every PAUS source: index of the nearest catalog entry and its separation
xm_id_sdss, ang_dist_sdss, _ = coords_paus.match_to_catalog_sky(coords_sdss)
xm_id_hetdex, ang_dist_hetdex, _ = coords_paus.match_to_catalog_sky(coords_hetdex)
xm_id_desi, ang_dist_desi, _ = coords_paus.match_to_catalog_sky(coords_desi)

# Accept a match when it lies within max_sep and the redshift is flagged as
# good (ZWARNING / ZWARN == 0 for SDSS / DESI, z_hetdex_conf > 0.9 for HETDEX).
# The HETDEX column is converted to a plain array first so that xm_id_hetdex is
# applied positionally rather than through the DataFrame's index labels.
hetdex_conf = np.asarray(hetdex_cat['z_hetdex_conf'])[xm_id_hetdex]

mask_dist_sdss = np.asarray(
    (ang_dist_sdss <= max_sep) & (sdss_cat['ZWARNING'][xm_id_sdss] == 0)
)
mask_dist_desi = np.asarray(
    (ang_dist_desi <= max_sep) & (desi_cat['ZWARN'][xm_id_desi] == 0)
)
mask_dist_hetdex = np.asarray(
    (ang_dist_hetdex <= max_sep) & (hetdex_conf > 0.9)
)

In [None]:
# Number of accepted matches, printed in the order DESI, SDSS, HETDEX
for xm_mask in (mask_dist_desi, mask_dist_sdss, mask_dist_hetdex):
    print(sum(xm_mask))

In [None]:
from paus_utils import z_NB

fig, ax = plt.subplots(figsize=(6, 3.25))

# z_phot (PAUS z_NB) against z_spec of the matched source, one marker style
# per survey. Entries: (z_spec of nearest neighbour, accepted-match mask, style)
open_marker_kw = dict(ls='', ms=6, mew=1.5, mfc='none')
survey_points = [
    (sdss_cat['Z'][xm_id_sdss], mask_dist_sdss,
     dict(marker='o', mec='g', zorder=11, label='SDSS')),
    (np.asarray(hetdex_cat['z_hetdex'])[xm_id_hetdex], mask_dist_hetdex,
     dict(marker='^', mec='peru', zorder=13, label='HETDEX')),
    (np.asarray(desi_cat['Z'])[xm_id_desi], mask_dist_desi,
     dict(marker='x', mec='deepskyblue', zorder=12, label='DESI')),
]
for xm_zspec, xm_mask, survey_kw in survey_points:
    ax.plot(xm_zspec[xm_mask], selection['z_NB'][xm_mask],
            **open_marker_kw, **survey_kw)

# Rest-frame wavelengths of the lines drawn as confusion tracks
zsp_xx = np.linspace(0, 5, 50)
w_lya = 1215.67
w_CIV = 1549.48
w_CIII = 1908.73
w_MgII = 2799.12
w_OII = 3727
w_OIII = (4960 + 5008) * 0.5
w_Hbeta = 4862

# One-to-one relation (correct redshift)
ax.plot(zsp_xx, zsp_xx, ls='-', c='r', zorder=-99)
# For each line: the z_spec at which that line is observed at the wavelength
# Lya would have at z_phot. The last entry (1025) is the one labelled Lyb below.
for w_line in (w_CIV, w_CIII, w_MgII, w_OII, w_OIII, w_Hbeta, 1025):
    ax.plot((zsp_xx + 1) * w_lya / w_line - 1, zsp_xx,
            ls='--', c='dimgray', zorder=-99)

# Line names, placed by hand: [x, y, text]
text_plot = [[0.29, 3.9, '[OII]'],
             [-0.07, 4.35, '[OIII]'],
             [0.37, 4.3, r'H$\beta$'],
             [0.8, 3.9, 'MgII'],
             [1.8, 3.9, 'CIII]'],
             [2.6, 3.9, 'CIV'],
             [3.5, 3.9, r'\bf Ly$\mathbf{\alpha}$'],
             [4.0, 3.5, r'Ly$\beta$']]
for x_text, y_text, line_name in text_plot:
    ax.text(x_text, y_text, line_name, fontsize=11, zorder=99999)

ax.set_xlim(-0.1, 4.5)
ax.set_ylim(2.5, 4.5)

ax.set_xlabel(r'$z_\mathrm{spec}$', fontsize=15)
ax.set_ylabel(r'$z_\mathrm{phot}$', fontsize=15)
ax.tick_params(direction='in', which='both', labelsize=13)
for axis in (ax.yaxis, ax.xaxis):
    axis.set_ticks_position('both')

ax.legend(fontsize=11, framealpha=1)

fig.savefig('../figures/redshift_confusion_lines.pdf', bbox_inches='tight', pad_inches=0.1,
            facecolor='w')
plt.show(block=False)

# Left panel: z_phot - z_spec against z_phot. Right panel: histogram of the
# same differences. Mean and variance of the differences are printed.
fig, [ax, axh] = plt.subplots(1, 2, figsize=(6, 3.5), width_ratios=[1, 0.25])

# z_spec of the nearest catalog entry for every PAUS source (mask before use)
xm_z_sdss = sdss_cat['Z'][xm_id_sdss]
xm_z_hetdex = np.asarray(hetdex_cat['z_hetdex'])[xm_id_hetdex]
xm_z_desi = np.asarray(desi_cat['Z'])[xm_id_desi]

# Redshift returned by z_NB for the 'lya_NB' column
NB_redshift = z_NB(selection['lya_NB'])

# z_phot - z_spec for the accepted matches of each survey
dz_phot_sdss = selection['z_NB'][mask_dist_sdss] - xm_z_sdss[mask_dist_sdss]
dz_phot_hetdex = selection['z_NB'][mask_dist_hetdex] - xm_z_hetdex[mask_dist_hetdex]
dz_phot_desi = selection['z_NB'][mask_dist_desi] - xm_z_desi[mask_dist_desi]

# Same differences using the NB redshift instead of z_phot
dz_NB_sdss = NB_redshift[mask_dist_sdss] - xm_z_sdss[mask_dist_sdss]
dz_NB_hetdex = NB_redshift[mask_dist_hetdex] - xm_z_hetdex[mask_dist_hetdex]
dz_NB_desi = NB_redshift[mask_dist_desi] - xm_z_desi[mask_dist_desi]

ax.plot(selection['z_NB'][mask_dist_sdss], dz_phot_sdss,
        ls='', marker='o', ms=6, mew=1.5,
        mec='g', mfc='none', zorder=11, label='SDSS')
ax.plot(selection['z_NB'][mask_dist_hetdex], dz_phot_hetdex,
        ls='', marker='^', ms=6, mew=1.5,
        mec='peru', mfc='none', zorder=13, label='HETDEX')
ax.plot(selection['z_NB'][mask_dist_desi], dz_phot_desi,
        ls='', marker='x', ms=6, mew=1.5,
        mec='deepskyblue', mfc='none', zorder=12, label='DESI')

# NB redshifts (faint dots behind the main markers; one legend entry only)
nb_dot_kw = dict(ls='', marker='.', ms=6, c='k', zorder=-99, alpha=0.4)
ax.plot(NB_redshift[mask_dist_sdss], dz_NB_sdss, **nb_dot_kw)
ax.plot(NB_redshift[mask_dist_hetdex], dz_NB_hetdex, **nb_dot_kw)
ax.plot(NB_redshift[mask_dist_desi], dz_NB_desi, **nb_dot_kw,
        label=r'$z_{\rm NB}$')

# Reference bar annotated as the NB width
ax.errorbar(4.4, 0.09, yerr=0.053, mec='k',
            mfc='w', fmt='o', capsize=2, mew=1.5,
            ms=6, ecolor='k')
ax.text(4.325, 0.086, 'NB width =', fontsize=13,
        horizontalalignment='right')

ax.axhline(0, c='k', ls='--', zorder=-9999)

ax.set_xlim(2.5, 4.5)
ax.set_ylim(-0.1, 0.15)

ax.set_xlabel(r'$z_{\rm phot}$')
ax.set_ylabel(r'$z_{\rm phot} - z_{\rm spec}$')

ax.tick_params(direction='in', which='both', labelsize=13)
ax.yaxis.set_ticks_position('both')
ax.xaxis.set_ticks_position('both')

ax.legend(fontsize=11, ncol=2)

###############

bins = np.linspace(-0.1, 0.15, 25)

# Statistics are restricted to |dz| < dz_max. The previous expression,
# np.abs(thisdiff < 0.15), took the absolute value of the boolean comparison,
# which is just `thisdiff < 0.15`: large negative outliers were not removed.
dz_max = 0.15

axh.hist(dz_phot_sdss,
         bins, orientation='horizontal', histtype='step', density=True,
         lw=2, color='g')
thisdiff = dz_phot_sdss
print(np.mean(thisdiff[np.abs(thisdiff) < dz_max]))
print(np.var(thisdiff[np.abs(thisdiff) < dz_max]))
print()

axh.hist(dz_phot_desi,
         bins, orientation='horizontal', histtype='step', density=True,
         lw=2, color='deepskyblue')
thisdiff = dz_phot_desi
print(np.mean(thisdiff[np.abs(thisdiff) < dz_max]))
print(np.var(thisdiff[np.abs(thisdiff) < dz_max]))
print()

# NB-redshift differences, SDSS and DESI matches together
thisdiff = np.concatenate([dz_NB_sdss, dz_NB_desi])
axh.hist(thisdiff,
         bins, orientation='horizontal', alpha=0.3, density=True,
         lw=2, color='dimgray', zorder=-99)
print(np.mean(thisdiff[np.abs(thisdiff) < dz_max]))
print(np.var(thisdiff[np.abs(thisdiff) < dz_max]))


# The histogram panel shares the y-range of the main panel and shows no ticks
axh.tick_params(direction='in', which='both', labelsize=13)
axh.yaxis.set_ticks_position('both')
axh.xaxis.set_ticks_position('both')
axh.set_yticks([])
axh.set_xticks([])

axh.tick_params(labelleft=False)
axh.tick_params(labelbottom=False)
axh.set_ylim(-0.1, 0.15)

axh.axhline(0, c='k', ls='--', zorder=-9999)


fig.subplots_adjust(wspace=0.04)

fig.savefig('../figures/zphot_zspec_redshift.pdf', bbox_inches='tight', pad_inches=0.1,
            facecolor='w')
plt.show()

In [None]:
# Spectroscopic redshift of every source in the selection; -1 means "no
# accepted match". One array per survey plus a combined one.
n_selected = mask_dist_desi.shape

hetdex_matched_z = np.asarray(hetdex_cat['z_hetdex'])[xm_id_hetdex][mask_dist_hetdex]
desi_matched_z = desi_cat['Z'][xm_id_desi][mask_dist_desi]
sdss_matched_z = sdss_cat['Z'][xm_id_sdss][mask_dist_sdss]

z_spec_hetdex = np.full(n_selected, -1.0)
z_spec_hetdex[mask_dist_hetdex] = hetdex_matched_z

z_spec_desi = np.full(n_selected, -1.0)
z_spec_desi[mask_dist_desi] = desi_matched_z

z_spec_sdss = np.full(n_selected, -1.0)
z_spec_sdss[mask_dist_sdss] = sdss_matched_z

# Combined redshift, filled in the order HETDEX -> DESI -> SDSS, so a later
# survey overwrites an earlier one when a source is matched in several.
z_spec = np.full(n_selected, -1.0)
z_spec[mask_dist_hetdex] = hetdex_matched_z
z_spec[mask_dist_desi] = desi_matched_z
z_spec[mask_dist_sdss] = sdss_matched_z

selection['z_spec'] = z_spec

In [None]:
# Write the selection as it was right after loading (selection_0)
save_sel_to = '/home/alberto/almacen/PAUS_data/catalogs/LAE_selection.csv'
selection_0.to_csv(save_sel_to)

# Copy for visual inspection: add the SDSS plate / mjd / fiber of the matched
# spectrum (0 where there is no accepted SDSS match).
sel_to_visual_insp = selection.copy()
for out_col, sdss_col in (('plate', 'PLATE'), ('mjd', 'MJD'), ('fiber', 'FIBERID')):
    sdss_values = np.zeros(len(mask_dist_sdss))
    sdss_values[mask_dist_sdss] = sdss_cat[sdss_col][xm_id_sdss][mask_dist_sdss]
    sel_to_visual_insp[out_col] = sdss_values.astype(int)

save_sel_to = '/home/alberto/almacen/PAUS_data/catalogs/LAE_selection_vi.csv'
sel_to_visual_insp.to_csv(save_sel_to)

In [None]:
from jpasLAEs.utils import smooth_hist

fig, ax = plt.subplots(figsize=(6, 4))

# Selection thresholds. They are also used to build the legend label, so the
# label always states the cut that was really applied (it used to read
# "L_lya > 44" while the cut was > 40).
L_lya_min = 40
EW0_min = 0
mask_L = (selection['L_lya_corr'] > L_lya_min) & (selection['EW0_lya'] > EW0_min)
# Sources passing the cut that have a spectroscopic redshift (z_spec > 0)
sel_mask = (z_spec > 0) & mask_L

# Define nice z: photometric and spectroscopic redshift agree within 0.12
nice_z = np.abs(selection['z_NB'] - z_spec) < 0.12

nice_mask = sel_mask & nice_z

# Smoothed z_NB histograms of the sources with a good redshift and of all the
# sources with a spectrum (same smooth_hist arguments for both).
nice_h_smooth, to_plot_c = smooth_hist(selection['z_NB'][nice_mask],
                                       2.7, 4.5, 0.05, 0.15)
sel_h_smooth, to_plot_c = smooth_hist(selection['z_NB'][sel_mask],
                                      2.7, 4.5, 0.05, 0.15)

# Purity = good-redshift fraction; error from sqrt(N) terms on numerator and
# denominator added in quadrature.
sdss_p = nice_h_smooth / sel_h_smooth
p_err = ((nice_h_smooth ** 0.5 / sel_h_smooth) ** 2
         + (sdss_p * sel_h_smooth ** -0.5) ** 2) ** 0.5
# p_err was computed but never drawn; pass it as the error bar
ax.errorbar(to_plot_c, sdss_p, yerr=p_err, lw=2,
            label=f'L_lya$>${L_lya_min}, EW0$>${EW0_min}')

ax.legend()

ax.set_ylim(0, 1.01)
ax.set_xlabel('Redshift')
ax.set_ylabel('Purity')

plt.show(block=False)

In [None]:
# Visual-inspection results: copy the is_LAE_VI / is_junk_VI flags onto the
# selection, matching rows by ref_id. Sources absent from the VI table keep
# False in both columns.
vi_sel = fits.open('/home/alberto/almacen/PAUS_data/catalogs/PAUS_LAE_selection_visual_insp_AT.fits')[1].data

selection['is_LAE_VI'] = np.zeros(len(selection), dtype=bool)
selection['is_junk_VI'] = np.zeros(len(selection), dtype=bool)

sel_ref_id = selection['ref_id'].to_numpy()
for iii, refid in enumerate(vi_sel['ref_id']):
    matched = sel_ref_id == refid
    if not matched.any():
        continue

    # Assign through .loc on the DataFrame itself. The previous chained form
    # selection[col][idx] = value writes into a temporary Series and is not
    # guaranteed to reach the DataFrame (it never does under Copy-on-Write).
    selection.loc[matched, 'is_LAE_VI'] = bool(vi_sel['is_LAE_VI'][iii])
    selection.loc[matched, 'is_junk_VI'] = bool(vi_sel['is_junk_VI'][iii])



In [None]:
# ref_id of sources flagged as LAE by visual inspection that have a
# spectroscopic redshift which does not satisfy nice_z
vi_lae_bad_z = selection['is_LAE_VI'] & ~nice_z & (z_spec > 0)
selection.loc[vi_lae_bad_z, 'ref_id']

In [None]:
# Sample masks
mask_L = (selection['L_lya_corr'] > 40) & (selection['EW0_lya'] > 0)

all_mask = ~selection['is_junk_VI']      # everything not flagged as junk
is_LAE_mask = selection['is_LAE_VI']     # visually confirmed LAEs

sel_mask = (z_spec > 0) & mask_L & all_mask          # has any spectrum
sel_sdss = (z_spec_sdss > 0) & mask_L & all_mask     # has an SDSS spectrum
sel_desi = (z_spec_desi > 0) & mask_L & all_mask     # has a DESI spectrum

nice_mask = sel_mask & nice_z & all_mask
# Fraction of sources with a spectrum whose redshift satisfies nice_z
print(sum(nice_mask) / sum(sel_mask))

bins = np.linspace(43, 45.5, 20)
dbin = bins[1] - bins[0]

fig, ax = plt.subplots(figsize=(6, 4))

# One histogram per sample, drawn in this order (which is also the legend
# order). Every entry is weighted by 1 / dbin / 35.
# NOTE(review): 35 is presumably the survey area in deg^2 -- confirm.
# NOTE(review): the '<0.1' label differs from the 0.12 used to define nice_z.
hist_samples = [
    (all_mask, dict(histtype='step', color='crimson', label='Full sample', ls='--')),
    (is_LAE_mask, dict(histtype='step', color='crimson', label='Golden sample')),
    (sel_mask, dict(color='teal', label='With spectrum', alpha=0.5)),
    (nice_mask, dict(color='teal',
                     label=r'$\vert z_{\rm phot}-z_{\rm spec}\vert <0.1$')),
    (sel_sdss, dict(color='black', histtype='step', ls=':', label=r'SDSS DR16')),
    (sel_desi, dict(color='peru', histtype='step', ls=':', label=r'DESI EDR')),
]
for sample_mask, sample_kw in hist_samples:
    ax.hist(selection['L_lya_corr'][sample_mask], bins=bins, lw=2,
            weights=np.ones(sum(sample_mask)) / dbin / 35,
            **sample_kw)

ax.legend(fontsize=11, frameon=False, ncol=2)
ax.set_yscale('log')
ax.set_xlim(43.25, 45.5)
ax.set_ylim(1e-1, 3e2)

ax.tick_params(direction='in', which='both')
for axis in (ax.yaxis, ax.xaxis):
    axis.set_ticks_position('both')

ax.set_xlabel(r'$\log_{10}(L_{\mathrm{Ly}\alpha}/\mathrm{erg\,s}^{-1})$')
ax.set_ylabel(r'\# [deg$^{-2}\Delta z^{-1}$]')

fig.savefig('../figures/Purity_histograms_spec_xmatch.pdf', bbox_inches='tight', pad_inches=0.1,
            facecolor='w')
plt.show()

In [None]:
# Sample masks (here the L_lya / EW0 cut is also applied to the full and
# golden samples)
mask_L = (selection['L_lya_corr'] > 40) & (selection['EW0_lya'] > 0)

all_mask = ~selection['is_junk_VI'] & mask_L
is_LAE_mask = selection['is_LAE_VI'] & mask_L

sel_mask = (z_spec > 0) & mask_L & all_mask          # has any spectrum
sel_sdss = (z_spec_sdss > 0) & mask_L & all_mask     # has an SDSS spectrum
sel_desi = (z_spec_desi > 0) & mask_L & all_mask     # has a DESI spectrum

nice_mask = sel_mask & nice_z & all_mask
# Fraction of sources with a spectrum whose redshift satisfies nice_z
print(sum(nice_mask) / sum(sel_mask))

bins = np.linspace(-29., -20, 22)
dbin = bins[1] - bins[0]

fig, ax = plt.subplots(figsize=(6, 4))

# One M_UV histogram per sample, drawn in this order (also the legend order).
# Every entry is weighted by 1 / dbin / 35.
# NOTE(review): 35 is presumably the survey area in deg^2 -- confirm.
# NOTE(review): the '<0.1' label differs from the 0.12 used to define nice_z.
hist_samples = [
    (all_mask, dict(histtype='step', color='crimson', label='Full sample', ls='--')),
    (is_LAE_mask, dict(histtype='step', color='crimson', label='Golden sample')),
    (sel_mask, dict(color='teal', label='With spectrum', alpha=0.5)),
    (nice_mask, dict(color='teal',
                     label=r'$\vert z_{\rm phot}-z_{\rm spec}\vert <0.1$')),
    (sel_sdss, dict(color='black', histtype='step', ls=':', label=r'SDSS DR16')),
    (sel_desi, dict(color='peru', histtype='step', ls=':', label=r'DESI EDR')),
]
for sample_mask, sample_kw in hist_samples:
    ax.hist(selection['M_UV'][sample_mask], bins=bins, lw=2,
            weights=np.ones(sum(sample_mask)) / dbin / 35,
            **sample_kw)


ax.legend(fontsize=11, frameon=False, ncol=2)
ax.set_yscale('log')
# Reversed x-axis: brighter (more negative) magnitudes to the right
ax.set_xlim(-21, -29)
ax.set_ylim(1e-2, 5e2)

ax.tick_params(direction='in', which='both')
for axis in (ax.yaxis, ax.xaxis):
    axis.set_ticks_position('both')

ax.set_xlabel(r'$M_{\rm UV}$')
ax.set_ylabel(r'\# [deg$^{-2}\Delta z^{-1}$]')

fig.savefig('../figures/Purity_histograms_spec_xmatch_MUV.pdf', bbox_inches='tight', pad_inches=0.1,
            facecolor='w')
plt.show()

In [None]:
# Now with the raw selection (before the nice_lya cut): repeat the cross-match.
# NOTE: this overwrites coords_paus, xm_id_* , ang_dist_* and mask_dist_* from
# the cross-match of the cleaned selection, so cells above this one must not be
# re-run afterwards without re-running that earlier cross-match first.

# Maximum on-sky separation for a match to be accepted
max_sep = 1.5 * u.arcsec

coords_paus = SkyCoord(ra=np.asarray(selection_raw['RA']) * u.deg,
                       dec=np.asarray(selection_raw['DEC']) * u.deg)
coords_sdss = SkyCoord(ra=sdss_cat['RA'] * u.deg,
                       dec=sdss_cat['DEC'] * u.deg)
coords_hetdex = SkyCoord(ra=np.asarray(hetdex_cat['RA']) * u.deg,
                         dec=np.asarray(hetdex_cat['DEC']) * u.deg)
coords_desi = SkyCoord(ra=desi_cat['TARGET_RA'] * u.deg,
                       dec=desi_cat['TARGET_DEC'] * u.deg)

# For every raw-selection source: index of the nearest catalog entry and its
# separation
xm_id_sdss, ang_dist_sdss, _ = coords_paus.match_to_catalog_sky(coords_sdss)
xm_id_hetdex, ang_dist_hetdex, _ = coords_paus.match_to_catalog_sky(coords_hetdex)
xm_id_desi, ang_dist_desi, _ = coords_paus.match_to_catalog_sky(coords_desi)

# Accept a match when it lies within max_sep and the redshift is flagged as
# good (ZWARNING / ZWARN == 0 for SDSS / DESI, z_hetdex_conf > 0.9 for HETDEX).
# The HETDEX column is converted to a plain array first so that xm_id_hetdex is
# applied positionally rather than through the DataFrame's index labels.
hetdex_conf = np.asarray(hetdex_cat['z_hetdex_conf'])[xm_id_hetdex]

mask_dist_sdss = np.asarray(
    (ang_dist_sdss <= max_sep) & (sdss_cat['ZWARNING'][xm_id_sdss] == 0)
)
mask_dist_desi = np.asarray(
    (ang_dist_desi <= max_sep) & (desi_cat['ZWARN'][xm_id_desi] == 0)
)
mask_dist_hetdex = np.asarray(
    (ang_dist_hetdex <= max_sep) & (hetdex_conf > 0.9)
)

In [None]:
def _raw_sdss_then_desi(column):
    """Return selection_raw[column] for the SDSS matches followed by the DESI matches."""
    values = selection_raw[column]
    return np.concatenate([values[mask_dist_sdss], values[mask_dist_desi]])


# Spectroscopic redshifts and PAUS properties of the raw-selection sources
# with an accepted SDSS or DESI match (SDSS entries first, then DESI)
zspec = np.concatenate([sdss_cat['Z'][xm_id_sdss][mask_dist_sdss],
                        desi_cat['Z'][xm_id_desi][mask_dist_desi]])
this_class_pred = _raw_sdss_then_desi('class_pred')
this_L_lya = _raw_sdss_then_desi('L_lya_corr')
this_nice_ml = _raw_sdss_then_desi('nice_ml')
this_class_star = _raw_sdss_then_desi('class_star')

this_mask = (this_L_lya > 40) & (this_class_star > 0.0) & this_nice_ml


zbins = np.linspace(0, 4.5, 50)

fig, ax = plt.subplots(figsize=(6, 4))

# z_spec distribution of all candidates and of each predicted class
# (class_pred 2 / 1 / 4, labelled LAE / QSO contaminant / low-z galaxy)
outline_kw = dict(histtype='step', lw=2)

h_all, _, _ = ax.hist(zspec[this_mask], zbins,
                      alpha=0.3, color='teal',
                      label='All candidates w/spec')
h_2, _, _ = ax.hist(zspec[(this_class_pred == 2) & this_mask], zbins,
                    color='teal',
                    label='Classified as LAE')
h_1, _, _ = ax.hist(zspec[(this_class_pred == 1) & this_mask], zbins,
                    color='sienna', label='Classified as QSO cont.',
                    **outline_kw)
h_4, _, _ = ax.hist(zspec[(this_class_pred == 4) & this_mask], zbins,
                    color='limegreen', label='Classified as low-$z$ galaxy',
                    **outline_kw)

# mark the minimum Lya redshift
ax.axvline(2.7, lw=2, ls=':', c='r')

ax.set_xlabel('Spectroscopic redshift')
ax.set_ylabel('Number of objects')

ax.set_xlim(0, 4.2)
ax.set_ylim(0, 150)

ax.tick_params(direction='in', which='both')
for axis in (ax.yaxis, ax.xaxis):
    axis.set_ticks_position('both')

ax.legend(fontsize=11, ncol=2, frameon=True, framealpha=1)

fig.savefig('../figures/ML_class_spectroscopic_hist.pdf', bbox_inches='tight', pad_inches=0.1,
    facecolor='w')
plt.show()