In [None]:
import numpy as np

import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 12})

from my_functions import *
from load_jpas_catalogs import load_minijpas_jnep

from scipy.stats import binned_statistic

In [None]:
# Per-filter properties provided by the project helpers.
# NOTE(review): units are presumably Angstrom, like w_lya — confirm.
w_central = central_wavelength()
nb_fwhm_Arr = nb_fwhm(range(0, 60))  # FWHM for filter indices 0..59

In [None]:
# Unpack the catalog returned by load_minijpas_jnep(); the last two returned
# values are not needed here and are discarded.
(
    pm_flx, pm_err, tile_id, pmra_sn, pmdec_sn, parallax_sn,
    starprob, starlhood, spCl, zsp, photoz, photoz_chi_best,
    photoz_odds, N_minijpas, _, _,
) = load_minijpas_jnep()

# Sources are indexed along the second axis of pm_flx.
N_sources = pm_flx.shape[1]

# Per-source flag: True for the first N_minijpas entries, False afterwards.
is_minijpas_source = np.ones(N_sources, dtype=bool)
is_minijpas_source[N_minijpas:] = False

In [None]:
# Magnitude computed from the second-to-last filter of the photometry array.
# NOTE(review): presumably one of the broad bands (the last 4 filters are
# left out of the NB redshift array) — confirm which one.
mag = flux_to_mag(pm_flx[-2, :], w_central[-2])

In [None]:
# Lya wavelength, in Angstrom
w_lya = 1215.67

# Redshift at which Lya falls on the central wavelength of each filter,
# leaving out the last 4 filters.
z_nb_Arr = w_central[:-4] / w_lya - 1

print(N_sources)

In [None]:
# Mask derived from the parallax and proper-motion S/N columns.
# NOTE(review): presumably flags sources without significant proper motion
# (i.e. non-stellar) — confirm against mask_proper_motion.
pm_mask = mask_proper_motion(parallax_sn, pmra_sn, pmdec_sn)

# Magnitude window: 17 < mag < 24
mag_mask = np.logical_and(mag > 17, mag < 24)

# Combined pre-selection mask handed to the line search
mask = pm_mask & mag_mask

In [None]:
# Number of sources passing `mask`, split by survey and in total
print(
    f'miniJPAS: {count_true(is_minijpas_source & mask)}',
    f'J-NEP: {count_true(~is_minijpas_source & mask)}',
    f'Total: {count_true(mask)}',
    sep='\n',
)

In [None]:
# Cut handed to is_there_line for the Lya search
ew0_cut = 30

# --- Lya search -------------------------------------------------------------
# Continuum estimated with the IGM_T_correct option switched on.
cont_est_lya, cont_err_lya = estimate_continuum(
    pm_flx, pm_err, IGM_T_correct=True
)
line = is_there_line(
    pm_flx, pm_err, cont_est_lya, cont_err_lya, ew0_cut, mask=mask
)
lya_lines, lya_cont_lines, line_widths = identify_lines(
    line, pm_flx, pm_err, first=True, return_line_width=True
)
lya_lines = np.array(lya_lines)

# --- Other lines ------------------------------------------------------------
# Same procedure with IGM_T_correct off, a cut of 400 and obs=True.
cont_est_other, cont_err_other = estimate_continuum(
    pm_flx, pm_err, IGM_T_correct=False
)
line_other = is_there_line(
    pm_flx, pm_err, cont_est_other, cont_err_other, 400, obs=True, mask=mask
)
other_lines = identify_lines(line_other, pm_flx, pm_err)

In [None]:
# Magnitude window for the candidate selection
mag_min = 17
mag_max = 24

# Luminosity limits: only defined here, they are used later in the notebook
L_min = 40
L_max = 50

mag_cut = np.logical_and(mag > mag_min, mag < mag_max)

# Redshift per source; entries stay at 0 where no Lya line was identified
# (lya_lines == -1).
has_lya_line = np.array(lya_lines) != -1
z_Arr = np.zeros(N_sources)
z_Arr[has_lya_line] = z_NB(np.array(lya_cont_lines)[has_lya_line])

# Range of NB indices considered (an nb_max of 15 was also tried)
nb_min = 5
nb_max = 20

nbs_to_consider = np.arange(nb_min, nb_max + 1)

nb_cut = np.logical_and(
    np.array(lya_lines) >= nb_min, np.array(lya_lines) <= nb_max
)

# Redshift interval: Lya at (centre - FWHM/2) of filter nb_min up to
# (centre + FWHM/2) of filter nb_max.
z_min = (w_central[nb_min] - 0.5 * nb_fwhm_Arr[nb_min]) / w_lya - 1
z_max = (w_central[nb_max] + 0.5 * nb_fwhm_Arr[nb_max]) / w_lya - 1

z_cut = (z_Arr > z_min) & (z_Arr < z_max)

# NOTE: this rebinds `mask`; from here on it is the redshift & magnitude cut.
mask = z_cut & mag_cut

nice_lya = nice_lya_select(
    lya_lines,
    other_lines,
    pm_flx,
    pm_err,
    cont_est_lya,
    z_Arr,
    mask=mask,
)

In [None]:
# Number of selected candidates (non-zero entries of nice_lya)
np.count_nonzero(nice_lya)

In [None]:
# EW, luminosity and line flux (each with its error) from EW_L_NB
EW_nb_Arr, EW_nb_e, L_Arr, L_e_Arr, flambda, flambda_e = EW_L_NB(
    pm_flx,
    pm_err,
    cont_est_lya,
    cont_err_lya,
    z_Arr,
    lya_lines,
    N_nb=0,
)

# Disabled: replace L_Arr with ML_predict_L predictions, split by magnitude.
#
# ML_predict_mask = (mag < 23) & (L_Arr > 0)
# L_Arr[ML_predict_mask] = ML_predict_L(
#     pm_flx[:, ML_predict_mask], pm_err[:, ML_predict_mask],
#     z_Arr[ML_predict_mask], L_Arr[ML_predict_mask], 'RFmag15-23'
# )
#
# ML_predict_mask = (mag > 23) & (L_Arr > 0)
# L_Arr[ML_predict_mask] = ML_predict_L(
#     pm_flx[:, ML_predict_mask], pm_err[:, ML_predict_mask],
#     z_Arr[ML_predict_mask], L_Arr[ML_predict_mask], 'RFmag23-23.5'
# )

In [None]:
# Pre-computed tables: luminosity bin edges, an error table indexed by bin
# number, and a median bias sampled at the bin centres.
L_binning = np.load('npy/L_nb_err_binning.npy')
L_Lbin_err = np.load('npy/L_nb_err.npy')
median_L = np.load('npy/L_bias.npy')

# Apply bin err: locate the bin of each (linear) luminosity and look up the
# tabulated error for it.
L_binning_position = binned_statistic(
        10 ** L_Arr, None, 'count', bins=L_binning
).binnumber
# Clamp bin numbers past the end of the table to the last index.
L_binning_position[L_binning_position > len(L_binning) - 2] = len(L_binning) - 2
# NOTE(review): scipy's binnumber is 1-based (0 means "below the first edge").
# If L_Lbin_err holds one entry per bin, this lookup is shifted by one bin —
# confirm against the script that produced npy/L_nb_err.npy.
L_e_Arr = L_Lbin_err[L_binning_position]

# Bin centres: mean of two consecutive edges. The slice has to span BOTH edges
# ([i : i + 2]); a one-element slice would give half of the left edge instead.
L_bin_c = [L_binning[i : i + 2].sum() * 0.5 for i in range(len(L_binning) - 1)]

# Correct L_Arr with the median bias, interpolated at each linear luminosity.
# NOTE: L_Arr is overwritten, so re-running this cell applies the correction
# a second time; re-run the EW_L_NB cell first.
L_Arr =  np.log10(10 ** L_Arr - np.interp(10 ** L_Arr, L_bin_c, median_L))

In [None]:
# Distribution of starprob among the selected candidates
fig, ax = plt.subplots(figsize=(7, 6))

ax.hist(starprob[nice_lya], bins=40)

ax.set_ylabel('N', fontsize=15)
ax.set_xlabel('p(star)', fontsize=15)

plt.show()

In [None]:
# Sources whose NB redshift lies within 0.2 of the spectroscopic redshift
nice_z = (np.abs(z_Arr - zsp) < 0.2)

# Breakdown of the candidates by spectroscopic class
print('{} candidates'.format(count_true(nice_lya)))
print('{} QSO ({} w/ right z)'.format(
    count_true(spCl[nice_lya] == 'QSO'),
    count_true((spCl[nice_lya] == 'QSO') & nice_z[nice_lya]),
))
print('{} GALAXY ({} w/ right z)'.format(
    count_true(spCl[nice_lya] == 'GALAXY'),
    count_true((spCl[nice_lya] == 'GALAXY') & nice_z[nice_lya]),
))
# Entries whose class reads 'nan' once cast to str are counted as unmatched
print('{} No SDSS counterpart'.format(count_true(spCl[nice_lya].astype(str) == 'nan')))

fig, ax = plt.subplots(figsize=(7, 3))

ax.scatter(zsp[nice_lya], z_Arr[nice_lya], c='k')

# Raw string: '\m' is not a valid Python escape sequence and triggers a
# warning in a normal string literal. The label text itself is unchanged.
ax.set_xlabel(r'SDSS z$_\mathrm{spec}$', fontsize=15)
ax.set_ylabel('NB z', fontsize=15)

plt.show()

In [None]:
# photo-z vs. best chi2: all sources, the candidates, and the sources with
# small (< 0.1) and large (> 0.3) |photoz - zsp|.
fig, ax = plt.subplots(figsize=(10, 8))

photoz_abs_err = np.abs(photoz - zsp)

ax.scatter(photoz, photoz_chi_best)
ax.scatter(photoz[nice_lya], photoz_chi_best[nice_lya])

where = photoz_abs_err < 0.1
ax.scatter(photoz[where], photoz_chi_best[where])

where = photoz_abs_err > 0.3
ax.scatter(photoz[where], photoz_chi_best[where])

# Axis ranges, then log scale on y
ax.set_xlim(-0.05, 1.55)
ax.set_ylim(2e0, 1e6)
ax.set_yscale('log')

ax.set_xlabel(r'$z_\mathrm{phot}$', fontsize=15)
ax.set_ylabel(r'$\chi^{2}$', fontsize=15)

plt.show()

In [None]:
# photo-z odds vs. best chi2: all sources, then the sources with small (< 0.1)
# and large (> 0.3) |photoz - zsp|.
fig, ax = plt.subplots(figsize=(7, 6))

photoz_abs_err = np.abs(photoz - zsp)

ax.scatter(photoz_odds, photoz_chi_best)

where = photoz_abs_err < 0.1
ax.scatter(photoz_odds[where], photoz_chi_best[where])

where = photoz_abs_err > 0.3
ax.scatter(photoz_odds[where], photoz_chi_best[where])

# Axis ranges first, then log scale on both axes
ax.set_ylim(1, 1e6)
ax.set_xlim(3e-2, 1.5)
ax.set_yscale('log')
ax.set_xscale('log')

plt.show()

In [None]:
# Normalised histograms of log10(chi2): full catalog vs. selected candidates
fig, ax = plt.subplots(figsize=(7, 6))

ax.hist(
    np.log10(photoz_chi_best),
    bins=np.linspace(-1, 4, 50),
    log=True,
    density=True,
)
ax.hist(
    np.log10(photoz_chi_best[nice_lya]),
    bins=np.linspace(0, 4, 50),
    log=True,
    density=True,
    alpha=0.6,
)

plt.show()

In [None]:
# Sources with both starlhood and starprob below 0.1
count_true(np.logical_and(starlhood < 0.1, starprob < 0.1))