In [None]:
import numpy as np
import matplotlib.pyplot as plt
from my_functions import *
import seaborn as sns

import astropy.units as u
import astropy.constants as c

from LumFunc_miniJPAS import LumFunc_hist

In [None]:
# Filter metadata from the project helpers: central wavelength of each filter
# (treated as Angstrom in the flux conversion) and the FWHM for filter
# indices 0..59.
w_central = central_wavelength()
nb_fwhm_Arr = nb_fwhm(range(0, 60))

In [None]:
## Load everything
# NOTE(review): hard-coded absolute local path; consider a configurable data directory.
filename = '/home/alberto/cosmos/photo-z/data/mock_cat_4.fits'

# `Table` is not imported in the import cell; presumably it arrives via
# `from my_functions import *` (astropy.table.Table?) — TODO confirm.
mock = Table.read(filename, format='fits').to_pandas()

# Units of the mock fluxes are mJy
# The quantity below is the F_nu (in mJy) of a source with F_lambda = 1 erg/s/cm2/A
# at each filter's central wavelength (F_nu = F_lambda * lambda**2 / c); its inverse
# is therefore the factor that converts mJy into erg/s/cm2/A.
convert_factor = (
    1. * u.erg * u.cm**-2 * u.s**-1 * u.AA**-1
    * c.c**-1 * w_central**2 * u.AA**2
    ).to(u.mJy) ** -1
# Plain numbers shaped (n_filters, 1) so they broadcast over the source axis.
convert_factor = convert_factor.value.reshape(-1, 1)

# Presumably columns 2..61 hold the 60 fluxes and columns 62..121 their errors;
# .T puts filters on axis 0 and sources on axis 1.
# NOTE(review): np.roll without `axis` flattens the array before shifting. Here
# that moves every element one position along the SOURCE axis and wraps the last
# source of each filter row into column 0 of the next row, so pm_flx[:, s] holds
# the mock photometry of source s-1 while `zspec` below is not shifted. If the
# aim was to cycle the filter order, this needs `axis=0` — confirm against the
# mock column layout.
pm_flx = np.roll(mock.to_numpy()[:, 2 : 60 + 2].T, 1)
pm_err = np.roll(mock.to_numpy()[:, 60 + 2 : 120 + 2].T, 1)

# Keep filter rows 0-55 in place and reverse the order of the last four rows.
swap_idx = np.array([*range(56)] + [59, 58, 57, 56])
pm_flx = pm_flx[swap_idx]
pm_err = pm_err[swap_idx]

# Disabled: perturb the fluxes with Gaussian noise scaled by their errors.
# pm_flx += np.random.normal(size=(pm_err.shape)) * pm_err

# mJy -> erg/s/cm2/A, one factor per filter row.
pm_flx *= convert_factor
pm_err *= convert_factor

zspec = mock['redshift']

# Magnitudes per filter; NaN results are replaced by the 99. placeholder value.
mag = flux_to_mag(pm_flx, w_central.reshape(-1, 1))
mag[np.isnan(mag)] = 99.

In [None]:
# Rest-frame Lyman-alpha wavelength, in Angstrom.
w_lya = 1215.67
# Sources run along axis 1 of the (filter, source) flux array.
N_sources = np.shape(pm_flx)[1]
N_sources

In [None]:
# Distribution of the rSDSS magnitude (row -2 of `mag`) over the whole mock,
# restricted to the 20-28 mag range.
fig, ax = plt.subplots(figsize=(8, 7))
ax.hist(mag[-2], bins=np.linspace(20, 28, num=50))
ax.set(xlabel='rSDSS')
plt.show()

In [None]:
# Continuum estimates used by the Lya search and by the "other lines" search.
# NOTE(review): both calls pass identical arguments, so unless
# estimate_continuum is stochastic the two estimates coincide — confirm whether
# the "other lines" continuum was meant to use different settings (or simply
# reuse the Lya one).
cont_est_lya, cont_err_lya = estimate_continuum(
    pm_flx, pm_err, N_nb=7, IGM_T_correct=False
)
cont_est_other, cont_err_other = estimate_continuum(
    pm_flx, pm_err, N_nb=7, IGM_T_correct=False
)

In [None]:
# EW thresholds scanned: Lya EW0 (rows of the later grids) and the
# "other lines" threshold (columns).
ew0lya_min, ew0lya_max, ew0lya_step = 0, 70, 8
ew0oth_min, ew0oth_max, ew0oth_step = 0, 400, 11

# Per "other lines" threshold: the line identification for every source.
other_select_list = []
for ew0min in np.linspace(ew0oth_min, ew0oth_max, num=ew0oth_step):
    print(ew0min)
    line_other = is_there_line(
        pm_flx, pm_err, cont_est_other, cont_err_other, ew0min, obs=True
    )
    other_lines = identify_lines(line_other, pm_flx, pm_err, first=False)
    other_select_list.append(other_lines)

# Per Lya threshold: the selected NB index for every source (-1 = none) and the
# z_NB redshift of the matched continuum line (-1 where nothing was selected).
# The loop variables keep their names because later cells read them.
lya_select_list = []
lya_z_nb = []
for ew0min in np.linspace(ew0lya_min, ew0lya_max, num=ew0lya_step):
    print(ew0min)
    line = is_there_line(pm_flx, pm_err, cont_est_lya, cont_err_lya, ew0min)
    lya_lines, lya_cont_lines = identify_lines(line, pm_flx, pm_err, first=True)
    has_lya = np.array(lya_lines) != -1
    z_nb_Arr = np.full(N_sources, -1.)
    z_nb_Arr[has_lya] = z_NB(np.array(lya_cont_lines)[has_lya])
    lya_select_list.append(lya_lines)
    lya_z_nb.append(z_nb_Arr)

In [None]:
# Source-level cuts used by the selection grid.
mag_min = 17
mag_max = 24

# rSDSS magnitude window (row -2 of `mag`).
mag_cut = (mag[-2] > mag_min) & (mag[-2] < mag_max)

# Lya redshift per source for the highest Lya EW0 threshold (last entry of the
# threshold scan), 0 where no line was selected. Reading the stored per-threshold
# arrays instead of the loop variables that leak out of the scan loop avoids
# hidden state; the values are identical because lya_z_nb[-1] already holds the
# z_NB redshifts of the detected sources.
lya_lines_last = np.array(lya_select_list[-1])
lya_detected_last = lya_lines_last != -1
z_Arr = np.zeros(N_sources)
z_Arr[lya_detected_last] = lya_z_nb[-1][lya_detected_last]

## Narrow-band window
nb_min = 5
nb_max = 20

nbs_to_consider = np.arange(nb_min, nb_max + 1)

nb_cut = (lya_lines_last >= nb_min) & (lya_lines_last <= nb_max)

# Lya redshift range covered from the blue edge of NB `nb_min` to the red edge
# of NB `nb_max`.
z_min = (w_central[nb_min] - nb_fwhm_Arr[nb_min] * 0.5) / w_lya - 1
z_max = (w_central[nb_max] + nb_fwhm_Arr[nb_max] * 0.5) / w_lya - 1

z_cut = (z_min < z_Arr) & (z_Arr < z_max)

# NOTE(review): z_cut only uses the highest Lya threshold, so applying this mask
# to lower thresholds restricts them to sources detected at the highest one.
mask = z_cut & mag_cut


In [None]:
# Selection grid over (Lya EW0 threshold index i, other-lines threshold index j):
#   select_grid[i, j]  number of sources returned by nice_lya_select
#   rightz_grid[i, j]  of those, how many have |z_NB - zspec| < 0.12
select_grid = np.zeros((ew0lya_step, ew0oth_step))
rightz_grid = np.zeros((ew0lya_step, ew0oth_step))

zspec_Arr = np.asarray(zspec)

for i in range(ew0lya_step):
    print(i)
    z_nb_i = lya_z_nb[i]
    # Redshift window evaluated with THIS threshold's NB redshifts (-1 for
    # non-detections, which fails the cut), so each row is masked consistently
    # with the lines it actually selected instead of with one fixed threshold.
    mask_i = mag_cut & (z_min < z_nb_i) & (z_nb_i < z_max)
    # The correct-redshift flag depends only on i, so compute it once per row.
    nice_z = np.abs(z_nb_i - zspec_Arr) < 0.12
    for j in range(ew0oth_step):
        nice_lya = nice_lya_select(
            lya_select_list[i], other_select_list[j], pm_flx, pm_err, cont_est_lya,
            z_nb_i, mask=mask_i
        )
        selected = nice_lya

        select_grid[i, j] = np.count_nonzero(selected)
        rightz_grid[i, j] = np.count_nonzero(selected & nice_z)

# TODO: N_target = 1 is a placeholder; completeness needs the true number of
# target sources in the masked sample.
N_target = 1
completeness = rightz_grid / N_target
# Purity: fraction of selected sources with the correct redshift. Cells with no
# selected sources are set to 0 instead of dividing by zero.
purity = np.divide(
    rightz_grid, select_grid,
    out=np.zeros_like(rightz_grid), where=select_grid > 0
)

In [None]:
# Visual check of 5 random sources with rSDSS < 24 (row -2 of `mag`).
np.random.seed(26)
bright_idx = np.flatnonzero(mag[-2] < 24)
# replace=False so the same source cannot be drawn (and plotted) twice.
for src in np.random.choice(bright_idx, 5, replace=False):
    fig = plt.figure(figsize=(10, 6))
    # Ensure the new figure has current axes; presumably plot_JPAS_source draws on them.
    plt.gca()
    ax = plot_JPAS_source(pm_flx[:, src], pm_err[:, src])
    plt.show()
    print('rSDSS = {0:0.2f}'.format(mag[-2, src]))

In [None]:
# Re-display the number of mock sources.
N_sources

In [None]:
# 2x2 panel figure of the selection grids. Top row: counts (correct z / all
# selected) on one shared colour scale. Bottom row: purity and completeness (0-1).
fig = plt.figure(figsize=(8, 8))

width = 0.5
height = 0.5
spacing = 0.06
cbar_width = 0.05

# Define axes (figure-fraction coordinates); right column shares y with the left,
# bottom row shares x with the top.
ax00 = fig.add_axes([0, height + 1.5 * spacing, width, height])
ax01 = fig.add_axes([width + spacing, height + 1.5 * spacing, width, height], sharey=ax00)
ax10 = fig.add_axes([0, 0, width, height], sharex=ax00)
ax11 = fig.add_axes([width + spacing, 0, width, height], sharex=ax01, sharey=ax10)
axcbar0 = fig.add_axes([2 * width + 1.5 * spacing, height + 1.5 * spacing, cbar_width, height])
axcbar1 = fig.add_axes([2 * width + 1.5 * spacing, 0, cbar_width, height])

# Both count panels use the same colour range so they can be compared directly.
vmax = np.max([np.max(rightz_grid), np.max(select_grid)])
cmap = 'Spectral'

sns.heatmap(rightz_grid, ax=ax00, vmin=0, vmax=vmax, cbar_ax=axcbar0, cmap=cmap)
sns.heatmap(select_grid, ax=ax01, vmin=0, vmax=vmax, cbar_ax=axcbar0, cmap=cmap)

sns.heatmap(purity, ax=ax10, vmin=0, vmax=1, cbar_ax=axcbar1, cmap=cmap)
sns.heatmap(completeness, ax=ax11, vmin=0, vmax=1, cbar=False, cmap=cmap)

# Flip so row 0 (lowest Lya threshold) is at the bottom; the right-hand panels
# follow through their shared y axes.
ax00.invert_yaxis()
ax10.invert_yaxis()

# Axes ticks: seaborn draws cell k over [k, k + 1], so ticks go at the cell
# centres (k + 0.5) for each label to sit on its own row/column.
xticks = np.arange(ew0oth_step) + 0.5
yticks = np.arange(ew0lya_step) + 0.5
xtick_labels = ['{0:0.0f}'.format(n) for n in np.linspace(ew0oth_min, ew0oth_max, ew0oth_step)]
ytick_labels = ['{0:0.0f}'.format(n) for n in np.linspace(ew0lya_min, ew0lya_max, ew0lya_step)]

for panel in (ax00, ax01, ax10, ax11):
    panel.set_xticks(xticks)
    panel.set_xticklabels(xtick_labels, rotation='vertical')
    panel.set_yticks(yticks)
    panel.set_yticklabels(ytick_labels)

# Axes labels (raw strings: '\m' and '\A' are not valid Python escapes).
ylabel = r'Ly$\alpha$ EW$_0$ ($\AA$)'
xlabel = r'Other lines EW$_\mathrm{obs}$ ($\AA$)'
ax00.set_ylabel(ylabel, fontsize=12)
ax10.set_ylabel(ylabel, fontsize=12)
ax10.set_xlabel(xlabel, fontsize=12)
ax11.set_xlabel(xlabel, fontsize=12)

# Set titles
ax00.set_title('Selected w/ correct z', fontsize=15)
ax01.set_title('All selected', fontsize=15)
ax10.set_title('Purity', fontsize=15)
ax11.set_title('Completeness', fontsize=15)

# Build the file name from the magnitude limits (mag_cut is a boolean array,
# so str(mag_cut) would not make a usable name).
# plt.savefig('output/puri-comp_mag{0}-{1}.pdf'.format(mag_min, mag_max), dpi=600,
#     bbox_inches='tight')
# plt.show()

print('N_target = {}'.format(N_target))

In [None]:
# Inspect one cell of the selection grid: Lya EW0 threshold index i and
# other-lines threshold index j.
i = 2
j = 10

mag_min = 17
mag_max = 24

mag_cut = (mag[-2] > mag_min) & (mag[-2] < mag_max)

# NB indices and redshifts are taken from threshold i itself, so the mask and
# the selection passed to nice_lya_select refer to the same Lya candidates.
lya_lines_i = np.array(lya_select_list[i])
lya_detected_i = lya_lines_i != -1
# Redshift per source for threshold i, 0 where no Lya line was selected.
z_Arr = np.zeros(N_sources)
z_Arr[lya_detected_i] = lya_z_nb[i][lya_detected_i]

## Narrow-band window
nb_min = 5
nb_max = 20

nbs_to_consider = np.arange(nb_min, nb_max + 1)

nb_cut = (lya_lines_i >= nb_min) & (lya_lines_i <= nb_max)

# Lya redshift range from the blue edge of NB `nb_min` to the red edge of NB `nb_max`.
z_min = (w_central[nb_min] - nb_fwhm_Arr[nb_min] * 0.5) / w_lya - 1
z_max = (w_central[nb_max] + nb_fwhm_Arr[nb_max] * 0.5) / w_lya - 1

z_cut = (z_min < z_Arr) & (z_Arr < z_max)

mask = z_cut & mag_cut

# Redshift argument is lya_z_nb[i], as in the grid scan.
nice_lya = nice_lya_select(
    lya_select_list[i], other_select_list[j], pm_flx, pm_err, cont_est_lya,
    lya_z_nb[i], mask=mask
)
nice_z = np.abs(lya_z_nb[i] - np.asarray(zspec)) < 0.12

selected = nice_lya
print(np.count_nonzero(selected))
print('with correct z: {}'.format(np.count_nonzero(selected & nice_z)))