In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from my_functions import *

%matplotlib inline

In [None]:
# Rest-frame Lyman-alpha wavelength (Angstrom); used to convert NB wavelengths to z.
w_lya = 1215.67

# Per-filter metadata provided by my_functions (star-imported above).
# NOTE(review): presumably indexed by filter number 0..59 — TODO confirm.
w_central = central_wavelength()
nb_fwhm_Arr = nb_fwhm(range(60))
filter_tags = load_filter_tags()

In [None]:
# Area (deg2) used below to turn raw selection counts into surface densities.
# NOTE(review): 100000 and 2190 are unexplained here — presumably number of
# mock objects and their density per deg2; TODO confirm and name them.
star_area = 1e5 / 2190
print(f'Star area = {star_area:0.1f} deg2')

In [None]:
_STAR_PRIOR_MOCK_PATH = (
    '/home/alberto/cosmos/JPAS_mocks_sep2021/'
    'JPAS_mocks_classification_01sep_model11/Fluxes/Qso_jpas_mock_flam_train.cat'
)


def load_STAR_prior_mock(filename=_STAR_PRIOR_MOCK_PATH):
    """Load mock J-PAS fluxes and errors, reordered into this notebook's filter order.

    NOTE(review): despite the function name, the default file is the QSO mock
    training catalog (``Qso_jpas_mock_flam_train.cat``) — confirm which
    catalog is intended.

    Parameters
    ----------
    filename : str, optional
        Space-separated catalog. Its first line is treated as a header (pandas
        default). Columns 2..61 hold the 60 fluxes and columns 62..121 the 60
        matching errors, one row per source.

    Returns
    -------
    pm_flx, pm_err : numpy.ndarray, shape (60, N_sources)
        Fluxes and errors with one row per filter. The filters found at file
        positions 1, 12, 28 and 43 (presumably the broad bands) are moved to the
        last four rows. The other 56 keep their file order in rows 0..55.
    """
    n_filters = 60

    # Output row k takes file filter my_filter_order[k]. The former incremental
    # `my_filter_order[a:-4] += 1` construction shifted the later slices using
    # file positions instead of output positions. That duplicated file filters
    # 12/28/43 and dropped 13/30/46. Building the order explicitly guarantees
    # a true permutation.
    bb_positions = np.array([1, 12, 28, 43])
    nb_positions = np.setdiff1d(np.arange(n_filters), bb_positions)
    my_filter_order = np.concatenate([nb_positions, bb_positions])

    # Fluxes and errors are contiguous columns, so read them in a single pass.
    # (Column 127, zspec, used to be read as well but was never returned.)
    catalog = pd.read_csv(
        filename, sep=' ', usecols=range(2, 2 + 2 * n_filters)
    ).to_numpy()

    pm_flx = catalog[:, :n_filters].T[my_filter_order]
    pm_err = catalog[:, n_filters:].T[my_filter_order]

    return pm_flx, pm_err

In [None]:
pm_flx, pm_err = load_STAR_prior_mock()

# Magnitude in filter -2 (presumably rSDSS after the reorder — TODO confirm).
# Sources whose magnitude comes out NaN get the sentinel value 99.
mag = flux_to_mag(pm_flx[-2], w_central[-2])
nan_mag = np.isnan(mag)
mag[nan_mag] = 99.

# pm_flx is (n_filters, n_sources), so the last axis counts sources.
N_sources = pm_flx.shape[-1]
print(pm_flx.shape)

In [None]:
ew0_cut = 30
ew_other = 400

# Lya search
cont_est_lya, cont_err_lya = estimate_continuum(pm_flx, pm_err, IGM_T_correct=True)
line = is_there_line(pm_flx, pm_err, cont_est_lya, cont_err_lya, ew0_cut)
lya_lines, lya_cont_lines, line_widths = identify_lines(
    line, pm_flx, cont_est_lya, first=True, return_line_width=True
)
lya_lines = np.array(lya_lines)

# Other lines
cont_est_other, cont_err_other = estimate_continuum(pm_flx, pm_err, IGM_T_correct=False)
line_other = is_there_line(pm_flx, pm_err, cont_est_other, cont_err_other,
    ew_other, obs=True)
other_lines = identify_lines(line_other, cont_est_other, pm_err)

# Compute z
z_Arr = np.zeros(N_sources)
z_Arr[np.where(np.array(lya_lines) != -1)] =\
    z_NB(np.array(lya_cont_lines)[np.where(np.array(lya_lines) != -1)])

# %xdel cont_est_other
%xdel cont_err_other

mag_min = 17
mag_max = 24

nb_min = 1
nb_max = 25

nbs_to_consider = np.arange(nb_min, nb_max + 1)

nb_cut = (np.array(lya_lines) >= nb_min) & (np.array(lya_lines) <= nb_max)

z_min = (w_central[nb_min] - nb_fwhm_Arr[nb_min] * 0.5) / w_lya - 1
z_max = (w_central[nb_max] + nb_fwhm_Arr[nb_max] * 0.5) / w_lya - 1
print(f'z interval: ({z_min:0.2f}, {z_max:0.2f})')

z_cut = (z_min < z_Arr) & (z_Arr < z_max)
mag_cut = (mag > mag_min) & (mag < mag_max)

snr = np.empty(N_sources)
for src in range(N_sources):
    l = lya_lines[src]
    snr[src] = pm_flx[l, src] / pm_err[l, src]

nice_lya_mask = z_cut & mag_cut & (snr > 6)
nice_lya = nice_lya_select(
    lya_lines, other_lines, pm_flx, pm_err, cont_est_lya, z_Arr, mask=nice_lya_mask
)

In [None]:
# Surface density (per deg2) of the nice_lya selection; shown as the cell output.
nice_lya_density = count_true(nice_lya) / star_area
nice_lya_density

In [None]:
# Shuffle the selected source indices so the preview shows a random subset.
selected = np.random.permutation(np.where(nice_lya)[0])
# `selected` holds source *indices*, so count them with len(). count_true
# (presumably a truthy-count) would skip a selected source with index 0.
print(len(selected))

# Reference wavelengths; alternatives kept for plotting experiments.
# qso_lines = [1025.7220, 1397.61, 1549.48, 1908.73, 2799.12]
# Actually gal lines
# qso_lines = [4861, 5007, 3727, 6549, 6564, 6585]
# This is the peak of the gal contaminant distribution
qso_lines = [3200]

# Filters pushed off the plot (presumably the two u bands — TODO confirm).
# This cell used to overwrite pm_flx/pm_err in place, which corrupted the data
# for any later or re-run analysis cell; mask per-source copies instead.
hidden_filters = [0, -4]
n_preview = 10

for src in selected[:n_preview]:
    print(src)
    flx_plot = pm_flx[:, src].copy()
    err_plot = pm_err[:, src].copy()
    flx_plot[hidden_filters] = 99999
    err_plot[hidden_filters] = 0

    lya_obs_w = w_central[lya_lines[src]]

    plt.figure(figsize=(10, 5))
    ax = plot_JPAS_source(flx_plot, err_plot, e17scale=True, set_ylim=False)

    ax.axvline(lya_obs_w, linestyle='--', color='r', label='Retrieved Lya line')
    ax.plot(w_central[:56], cont_est_lya[:, src] * 1e17)

    ax.legend()
    ax.set_ylim((-1, 5))
    ax.set_xlim((3000, 10000))

    plt.show()

In [None]:
# Magnitude distribution of the nice_lya selection.
mag_bins = np.linspace(16, 25, 20)

fig, ax = plt.subplots(figsize=(5, 4))
ax.hist(mag[nice_lya], bins=mag_bins)
plt.show()