In [None]:
from my_functions import *
from load_mocks import ensemble_mock
from minijpas_LF_and_puricomp import purity_or_completeness_plot

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Filter-system properties from the project helpers: central wavelengths,
# FWHMs for filter indices 0-59 (presumably the full filter set — confirm),
# and the filter tag list.
w_central = central_wavelength()
nb_fwhm_Arr = nb_fwhm(range(60))
filter_tags = load_filter_tags()

# Lya rest-frame wavelength (same units as w_central, see the z-range cell) and
# the weight applied to galaxy contaminant counts in `puricomp`.
# NOTE(review): gal_factor looks like a mock-normalisation factor — confirm its origin.
w_lya, gal_factor = 1215.67, 12.57

In [None]:
# Names of the QSO, SF and galaxy mock catalogues combined below.
qso_name = 'QSO_double_train_minijpas_0'
sf_name = 'LAE_12.5deg_z2-4.25_train_minijpas_0'
gal_name = 'GAL_100000_minijpas_0'

# Build the combined mock; the last two returned values are not used here.
(pm_flx, pm_err, zspec, EW_lya, L_lya,
 is_qso, is_sf, is_gal, _, _) = ensemble_mock(qso_name, gal_name, sf_name)

# Magnitude from the second-to-last filter; NaN magnitudes get the sentinel 99.
mag = flux_to_mag(pm_flx[-2], w_central[-2])
mag[np.isnan(mag)] = 99.

# Sources are laid out along axis 1 of the photometry array.
N_sources = np.shape(pm_flx)[1]
N_sources

In [None]:
# Continuum estimates for the two line searches: IGM_T_correct=True is used
# for the Lya search, IGM_T_correct=False for the search of other lines.
(cont_est_lya, cont_err_lya), (cont_est_other, cont_err_other) = [
    estimate_continuum(pm_flx, pm_err, IGM_T_correct=igm_flag)
    for igm_flag in (True, False)
]

In [None]:
# Grid of EW thresholds scanned by `puricomp`: 11 Lya cuts in [0, 100] and
# 11 other-line cuts in [0, 500].
ew0_lya_Arr = np.linspace(start=0, stop=100, num=11)
ew_oth_Arr = np.linspace(start=0, stop=500, num=11)

# 'xy' indexing: x varies along columns (Lya cut), y along rows (other-line cut).
x, y = np.meshgrid(ew0_lya_Arr, ew_oth_Arr, indexing='xy')

In [None]:
# Magnitude range of the sample.
mag_min, mag_max = 17, 24

# Narrow-band filter index range searched for Lya
# (an alternative range previously tried: nb_min, nb_max = 16, 23).
nb_min, nb_max = 5, 15

# Luminosity limits: unused in this section, but flagged in the original as used later.
L_min, L_max = 40, 50

# Redshift interval covered by the NB range: Lya at the blue edge of filter
# nb_min up to the red edge of filter nb_max (edge = centre -/+ FWHM / 2).
z_min = (w_central[nb_min] - 0.5 * nb_fwhm_Arr[nb_min]) / w_lya - 1
z_max = (w_central[nb_max] + 0.5 * nb_fwhm_Arr[nb_max]) / w_lya - 1
print(f'z interval: ({z_min:0.2f}, {z_max:0.2f})')

In [None]:
# NOTE(review): these module-level arrays are never filled in the cells shown —
# `puricomp` below computes and returns its own values. They are kept only in
# case another cell relies on these names.
puri = np.zeros((len(ew0_lya_Arr) * len(ew_oth_Arr)))
comp = np.copy(puri)


def puricomp(ew0_lya, ew_oth):
    """Purity and completeness of the Lya selection for one pair of EW cuts.

    Reads the notebook-level mock arrays (pm_flx, pm_err, zspec, EW_lya, mag,
    is_qso / is_sf / is_gal), the continuum estimates, and the selection limits
    (mag_min/mag_max, z_min/z_max, gal_factor).

    Parameters
    ----------
    ew0_lya : float
        EW threshold passed to `is_there_line` for the Lya search, which uses
        the IGM-corrected continuum. It is also the cut `EW_lya > ew0_lya`
        that defines the target sample for completeness.
    ew_oth : float
        EW threshold passed to `is_there_line(..., obs=True)` for the search
        of other lines.

    Returns
    -------
    (purity, completeness) : tuple of float
        completeness = good / N(target sample)
        purity       = good / (good + bad + gal_factor * bad_gal)
        Either value is ``nan`` when its denominator is zero, instead of raising.
    """
    # Lya candidates from the IGM-corrected continuum.
    line = is_there_line(pm_flx, pm_err, cont_est_lya, cont_err_lya, ew0_lya)
    lya_lines, lya_cont_lines, _ = identify_lines(
        line, pm_flx, cont_est_lya, first=True, return_line_width=True
    )
    lya_lines = np.array(lya_lines)

    # Other emission lines from the uncorrected continuum.
    line_other = is_there_line(pm_flx, pm_err, cont_est_other, cont_err_other,
                               ew_oth, obs=True)
    # NOTE(review): the argument order differs from the Lya call above, which
    # passes (line, pm_flx, cont_est). Here the continuum sits in the flux
    # slot and pm_err in the continuum slot. Confirm against `identify_lines`.
    other_lines = identify_lines(line_other, cont_est_other, pm_err)

    # Redshift from the NB position of the Lya line; 0 where none was found.
    lya_idx = np.where(lya_lines != -1)
    z_Arr = np.zeros(N_sources)
    z_Arr[lya_idx] = z_NB(np.array(lya_cont_lines)[lya_idx])

    # A selection counts as correct when the inferred z is within 0.16 of zspec.
    nice_z = np.abs(z_Arr - zspec) < 0.16

    nice_lya = nice_lya_select(
        lya_lines, other_lines, pm_flx, pm_err, cont_est_lya, z_Arr
    )

    # Shared mask terms.
    mag_cut = (mag >= mag_min) & (mag <= mag_max)
    zspec_cut = (zspec >= z_min) & (zspec <= z_max)
    zphot_cut = (z_Arr >= z_min) & (z_Arr <= z_max)
    # Selected, inside the z range by the inferred z, but with the wrong z.
    wrong_z = nice_lya & ~nice_z & mag_cut & zphot_cut

    good = count_true(nice_lya & nice_z & zspec_cut & mag_cut)
    bad = count_true(wrong_z & (is_qso | is_sf))
    # Galaxy contaminants are up-weighted by gal_factor.
    bad_gal = count_true(wrong_z & is_gal) * gal_factor
    n_target = count_true((EW_lya > ew0_lya) & zspec_cut & mag_cut)

    n_selected = good + bad + bad_gal
    comp = good / n_target if n_target > 0 else np.nan
    puri = good / n_selected if n_selected > 0 else np.nan

    return puri, comp


# Declaring both output types up front stops np.vectorize from calling the
# (expensive) puricomp an extra time on the first grid point to infer them.
puricomp_func = np.vectorize(puricomp, otypes=[float, float])

In [None]:
# Evaluate puricomp at every node of the EW grid; M is the tuple
# (purity grid, completeness grid), each shaped like x and y.
M = puricomp_func(ew0_lya=x, ew_oth=y)
M