In [None]:
import numpy as np

import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 12})

from my_functions import *
from load_jpas_catalogs import load_minijpas_jnep

from scipy.stats import binned_statistic

In [None]:
# Central wavelength of each filter; the same filter index is used for the
# first axis of the flux arrays loaded below (e.g. pm_flx[-2] with w_central[-2])
w_central = central_wavelength()
# FWHM of filters 0-59 (the `nb_` prefix suggests these are the narrow bands — confirm)
nb_fwhm_Arr = nb_fwhm(range(60))

In [None]:
# Load the combined miniJPAS + J-NEP photometric catalogue. The loader returns a
# long tuple; fields not needed here are discarded with `_`.
pm_flx, pm_err, tile_id, pmra_sn, pmdec_sn, parallax_sn, starprob, _,\
        spCl, zsp, _, _, _, N_minijpas, x_im, y_im,\
                ra, dec = load_minijpas_jnep(['minijpas', 'jnep'])
# Second load with selection=True, only to extract field [5] (`number`).
# NOTE(review): this re-reads the catalogues; confirm the row order matches the load above.
number = load_minijpas_jnep(['minijpas', 'jnep'], selection=True)[5]
# Magnitude from the second-to-last filter (labelled `r` in the inspection cell at the end)
mag = flux_to_mag(pm_flx[-2], w_central[-2])
# Per-source mask built from the parallax / proper-motion S/N columns
mask_pm = mask_proper_motion(parallax_sn, pmra_sn, pmdec_sn)

In [None]:
N_sources = len(tile_id)
# True for the first N_minijpas rows (miniJPAS), False for the remaining rows
# (presumably J-NEP, given the ['minijpas', 'jnep'] load order)
is_minijpas_source = np.full(N_sources, True)
is_minijpas_source[N_minijpas:] = False

In [None]:
w_lya = 1215.67  # Lya rest-frame wavelength [A]

print(N_sources)

# Lya redshift that places the line at each filter's central wavelength
# (the last 4 filters are excluded)
z_nb_Arr = w_central[:-4] / w_lya - 1

In [None]:
from minijpas_LF_and_puricomp import nb_or_3fm_cont

# Thresholds and continuum method for the line searches
ew0_cut = 30  # EW threshold passed to the Lya line search
ew_oth = 100  # EW threshold for the search of other lines (run with obs=True)
cont_est_m = 'nb'  # continuum-estimation method passed to nb_or_3fm_cont

# Continuum estimates (and errors): one set for the Lya search, one for other lines
cont_est_lya, cont_err_lya, cont_est_other, cont_err_other =\
    nb_or_3fm_cont(pm_flx, pm_err, cont_est_m)

# Lya search
line = is_there_line(pm_flx, pm_err, cont_est_lya,
                        cont_err_lya, ew0_cut)
# lya_lines: per-source filter index of the identified line. Downstream code
# treats -1 as a sentinel (presumably "no line found").
lya_lines, lya_cont_lines, _ = identify_lines(
    line, pm_flx, cont_est_lya, first=True, return_line_width=True
)
lya_lines = np.array(lya_lines)

# Other lines
line_other = is_there_line(pm_flx, pm_err, cont_est_other, cont_err_other,
                            ew_oth, obs=True, sigma=5)
other_lines = identify_lines(line_other, pm_flx, cont_est_other)

In [None]:
# Magnitude window (exclusive) and allowed filter-index range for the Lya line
mag_min, mag_max = 17, 24
nb_min, nb_max = 1, 20

mag_cut = (mag_min < mag) & (mag < mag_max)

# Redshift from the filter where Lya was identified; sources flagged -1 keep z = 0
z_Arr = np.zeros(N_sources)
has_lya_line = lya_lines != -1
z_Arr[has_lya_line] = z_NB(np.array(lya_cont_lines)[has_lya_line])

# S/N of the flux in the filter where the Lya line was identified.
# Sources with lya_lines == -1 get NaN: indexing pm_flx with -1 would silently
# read the last filter and yield a meaningless S/N. NaN compares False with any
# threshold, so such sources fail S/N cuts (they are already excluded by the
# NB-range mask).
src_idx = np.arange(N_sources)
has_line = lya_lines != -1
snr = np.full(N_sources, np.nan)
snr[has_line] = (
    pm_flx[lya_lines[has_line], src_idx[has_line]]
    / pm_err[lya_lines[has_line], src_idx[has_line]]
)

# Hard-coded source indices (positions in the combined catalogue arrays) whose
# NB images are bad; these sources are excluded from the selection.
bad_NB_image = np.array([4380, 30395, 30513, 30977, 40306, 43721, 11771, 2583])
mask_bad_NB = np.full(N_sources, True)
mask_bad_NB[bad_NB_image] = False

# Individual selection masks
mask_snr = (snr > 6)  # S/N > 6 in the filter where the Lya line was identified
# Lya line found in a filter with index in [nb_min, nb_max]
lya_lines_mask = (lya_lines >= nb_min) & (lya_lines <= nb_max)
mask = (lya_lines_mask & mag_cut & mask_snr & mask_bad_NB & mask_pm)

# Candidate mask plus the auxiliary colour mask and `ml_mask` from nice_lya_select
nice_lya_raw, c_mask, ml_mask = nice_lya_select(
    lya_lines, other_lines, pm_flx, pm_err, cont_est_lya, z_Arr,
    return_color_mask=True
)
# "Raw" sample: candidates with only the NB-range, bad-image and magnitude cuts
nice_lya_raw = lya_lines_mask & nice_lya_raw & mask_bad_NB & mag_cut
# Final sample: additionally requires S/N, proper-motion, colour and ml masks
nice_lya = nice_lya_raw & mask & c_mask & ml_mask
sum(nice_lya)

In [None]:
nice_lya_raw.sum()  # size of the "raw" sample

In [None]:
# EW, luminosity (log10: it is used as 10 ** L_Arr downstream) and flux density,
# each with its error, from the NB measurement of the identified Lya line.
(EW_nb_Arr, EW_nb_e,
 L_Arr, L_e_Arr,
 flambda, flambda_e) = EW_L_NB(pm_flx, pm_err, cont_est_lya, cont_err_lya,
                               z_Arr, lya_lines, N_nb=0)

In [None]:
# Luminosity bin edges (linear luminosity), per-bin luminosity error and per-bin
# median bias, all precomputed elsewhere.
# NOTE(review): L_nb_err.npy is assumed to hold one value per bin, i.e.
# len(L_binning) - 1 entries (the original clip to len(L_binning) - 2 implies this).
L_binning = np.load('npy/L_nb_err_binning.npy')
L_Lbin_err = np.load('npy/L_nb_err.npy')
median_L = np.load('npy/L_bias.npy')

# Apply bin err: give every source the error of the bin its luminosity falls in.
# scipy's `binnumber` is 1-based (0 = below the first edge, len(L_binning) =
# above the last edge), so shift it to a 0-based bin index and clip sources
# outside the binning into the first/last bin.
L_binning_position = binned_statistic(
        10 ** L_Arr, None, 'count', bins=L_binning
).binnumber
L_binning_position = np.clip(L_binning_position - 1, 0, len(L_binning) - 2)
L_e_Arr = L_Lbin_err[L_binning_position]

# Bin centres of the linear luminosity binning: midpoint of each pair of
# consecutive edges. A one-element slice such as L_binning[i:i + 1] would
# give half the left edge instead.
L_bin_c = 0.5 * (L_binning[:-1] + L_binning[1:])

# Correct L_Arr with the median bias, interpolated at each source's luminosity.
# A non-positive corrected luminosity gives -inf/NaN from log10.
L_Arr = np.log10(10 ** L_Arr - np.interp(10 ** L_Arr, L_bin_c, median_L))

In [None]:
def nanomaggie_to_flux(nmagg, wavelength):
    """Convert an SDSS nanomaggie value to flux via an AB magnitude.

    One nanomaggie is 1e-9 maggies, so m_AB = -2.5 * log10(nmagg * 1e-9).
    That magnitude is converted with ``mag_to_flux`` at ``wavelength``; the
    output units are whatever ``mag_to_flux`` returns. Non-positive inputs
    give -inf/NaN magnitudes.
    """
    return mag_to_flux(-2.5 * np.log10(nmagg * 1e-9), wavelength)

# Reload fields [4:6] (tile_id, number) of the selection=True catalogue; these
# are matched against the SDSS crossmatch (number, tile id) in the next cell.
# NOTE(review): this overwrites the `tile_id` unpacked from the full load at the
# top and omits the ['minijpas', 'jnep'] argument used there — confirm the
# default selects the same catalogues, so rows still align with `nice_lya`.
tile_id, number = load_minijpas_jnep(selection=True)[4:6]


In [None]:
from visual_inspection import load_sdss_xmatch
sdss_xm_num, sdss_xm_tid, sdss_xm_spObjID, f_zsp, xm_zsp = load_sdss_xmatch() 

# For every SDSS crossmatch, whether its catalogue counterpart (same `number`
# and `tile_id`) passed the Lya selection. Crossmatches without a counterpart
# stay False. An explicit empty-match check replaces a bare `except:`, which
# also swallowed unrelated errors (and KeyboardInterrupt).
nice_xm = np.zeros_like(xm_zsp).astype(bool)
for i in range(len(sdss_xm_num)):
    matches = np.where(
        (number == sdss_xm_num[i]) & (tile_id == sdss_xm_tid[i])
    )[0]
    if len(matches) == 0:
        continue
    nice_xm[i] = nice_lya[matches[0]]

# Lya line fits for SDSS DR16 spectra, matched to the SDSS crossmatches by
# (plate, mjd, fiber).
# NOTE(review): `pd`, `cosmo` and `u` are not imported explicitly in this
# notebook; presumably they come from `from my_functions import *`.
Lya_fts = pd.read_csv('csv/Lya_fts_DR16_v2.csv')

# Plain arrays, so indexing below is positional regardless of the DataFrame index
fts_plate = Lya_fts['plate'].to_numpy()
fts_mjd = Lya_fts['mjd'].to_numpy()
fts_fiber = Lya_fts['fiberid'].to_numpy()
fts_F = Lya_fts['LyaF'].to_numpy()
fts_z = Lya_fts['Lya_z'].to_numpy().flatten()
fts_EW = Lya_fts['LyaEW'].to_numpy()
fts_EW_err = Lya_fts['LyaEW_err'].to_numpy()


def decode_spobjid(spobjid):
    """Decompose an SDSS SpecObjID into (plate, mjd, fiber).

    Bit layout: plate in bits 50-63, fiber in bits 38-49 and MJD - 50000 in
    bits 24-37. The ID is first reinterpreted as an unsigned 64-bit integer,
    so IDs stored as negative signed int64 (plate >= 8192, bit 63 set) still
    decode correctly.
    """
    bits = int(spobjid) & 0xFFFFFFFFFFFFFFFF
    plate = (bits >> 50) & 0x3FFF
    fiber = (bits >> 38) & 0xFFF
    mjd = ((bits >> 24) & 0x3FFF) + 50000
    return plate, mjd, fiber


# Explicit float arrays: np.zeros_like(f_zsp) would inherit f_zsp's dtype and,
# if that is an integer flag column, silently truncate the values stored below.
N_xm = len(f_zsp)
L_lya = np.zeros(N_xm)
EW_lya = np.zeros(N_xm)
EW_lya_err = np.full(N_xm, 99.)  # 99 is kept for crossmatches without a fit
for i, this_spObjID in enumerate(sdss_xm_spObjID):
    if this_spObjID == 0:
        continue
    plate, mjd, fiber = decode_spobjid(this_spObjID)

    wh_in_fts = np.where((fts_plate == plate)
                         & (fts_mjd == mjd)
                         & (fts_fiber == fiber))[0]
    if len(wh_in_fts) == 0:
        continue
    # If a spectrum appears more than once in the fit table, use the first row
    j = wh_in_fts[0]

    F_line = fts_F[j] * 1e-17  # fit table flux scaled by 1e-17
    dL = cosmo.luminosity_distance(fts_z[j]).to(u.cm).value
    L_lya[i] = np.log10(F_line * 4 * np.pi * dL ** 2)
    EW_lya[i] = fts_EW[j]
    EW_lya_err[i] = fts_EW_err[j]

EW_lya_err = np.abs(EW_lya_err)

In [None]:
import os.path as op
from astropy.table import Table

# Load HETDEX Source Catalog 1 (sources + detection info) and the HETDEX AGN catalogue.
# NOTE(review): machine-specific absolute paths; consider a configurable base dir.
path_to_cat = '/home/alberto/almacen/HETDEX_catalogs/hetdex_source_catalog_1'
path_to_agn = '/home/alberto/almacen/HETDEX_catalogs/agn_catalog_v1.0'
version = 'v3.2'

sc1_file = op.join(path_to_cat, f'hetdex_sc1_{version}.ecsv')
detinfo_file = op.join(path_to_cat, f'hetdex_sc1_detinfo_{version}.ecsv')
source_table = Table.read(sc1_file)
det_table = Table.read(detinfo_file)

# HETDEX source_id crossmatched to each catalogue source
xm_hetdex_id = np.load('npy/hetdex_crossmatch_ids.npy')

fname = op.join(path_to_agn, 'hetdex_agn.fits')
agn = Table.read(fname, format='fits', hdu=1)

In [None]:
# Per-source HETDEX crossmatch quantities. The initial values are sentinels for
# sources without a HETDEX counterpart.
z_hetdex = np.ones(N_sources) * -1
z_hetdex_conf = np.ones(N_sources) * -1
L_lya_hetdex = np.ones(N_sources) * -1
EW_lya_hetdex = np.ones(N_sources) * -9999999999999
EW_lya_hetdex_err = np.ones(N_sources) * 9999
type_hetdex = np.zeros(N_sources).astype(str)
for src in range(N_sources):
    hetdex_id = xm_hetdex_id[src]
    if not (hetdex_id > 0):
        continue  # no HETDEX counterpart

    wh = np.where(hetdex_id == source_table['source_id'])[0][0]
    wh_det = np.where(hetdex_id == det_table['source_id'])[0][0]
    # NOTE(review): AGN-catalogue rows are found by exact float equality of
    # RA/DEC with the source table; this only works if both catalogues store
    # bit-identical coordinates. Consider matching on an ID or with a tolerance.
    wh_agn = np.where(
        (source_table['RA'][wh] == agn['ra'])
        & (source_table['DEC'][wh] == agn['dec'])
    )[0]

    z_hetdex[src] = source_table['z_hetdex'][wh]
    z_hetdex_conf[src] = source_table['z_hetdex_conf'][wh]
    type_hetdex[src] = source_table['source_type'][wh]

    if len(wh_agn) > 0:
        # Lya luminosity from the AGN-catalogue line flux (scaled by 1e-17)
        F_lya = agn['flux_LyA'][wh_agn[0]] * 1e-17
        dL = cosmo.luminosity_distance(z_hetdex[src]).to(u.cm).value
        L_lya_hetdex[src] = np.log10(F_lya * 4*np.pi * dL ** 2)
        continue

    L_lya_hetdex[src] = np.log10(source_table['lum_lya'][wh])
    if type_hetdex[src] != 'lae':
        continue

    # EW = line flux / continuum, with first-order error propagation.
    # NOTE(review): no (1 + z) factor is applied, so this is presumably an
    # observed-frame EW; confirm that the frame matches the SDSS EW it is
    # plotted against.
    flux = det_table['flux'][wh_det]
    flux_err = det_table['flux_err'][wh_det]
    cont = det_table['continuum'][wh_det]
    cont_err = det_table['continuum_err'][wh_det]
    EW_lya_hetdex[src] = flux / cont
    EW_lya_hetdex_err[src] = (
        (flux_err / cont) ** 2
        + (flux * cont**-2 * cont_err) ** 2
    ) ** 0.5

In [None]:
# Selected Lya candidates that HETDEX classifies as AGN at z > 2 with
# positive redshift confidence
hetdex_agn_lya = (
    (z_hetdex > 2)
    & (z_hetdex_conf > 0)
    & (type_hetdex == 'agn')
    & nice_lya
    # & (L_lya_hetdex > 43.8)
)
hetdex_agn_lya.sum()

In [None]:
# Sky positions of HETDEX SC1 sources and HETDEX AGN
fig, ax = plt.subplots(figsize=(5, 5))

ax.plot(source_table['RA'], source_table['DEC'], marker='.', ls='', markersize=4,
        label='HETDEX SC1 sources')
ax.plot(agn['ra'], agn['dec'], marker='.', ls='', markersize=4,
        label='HETDEX AGN')

# RA limits are given high -> low, so RA increases to the left (sky convention)
ax.set_ylim(51.67, 53.67)
ax.set_xlim(216.25, 213.12)
# Alternative window:
# ax.set_xlim(255.37 - 0.25, 255.37 + 0.25)
# ax.set_ylim(65.78 - 0.25, 65.78 + 0.25)

ax.set_xlabel('RA [deg]')
ax.set_ylabel('Dec [deg]')
ax.legend()

plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(6, 5))

# SDSS crossmatches with f_zsp == 0 (presumably no SDSS redshift warning —
# confirm). A dedicated name is used so this plotting cell does not overwrite
# the Lya selection `mask` built in the selection cell.
sdss_zsp_ok = (f_zsp == 0)
sdss_sel = nice_xm & sdss_zsp_ok
sdss_rej = ~nice_xm & sdss_zsp_ok

# Circles: SDSS EW; triangles: HETDEX EW. Green = passed the Lya selection,
# red = rejected.
ax.errorbar(xm_zsp[sdss_sel], EW_lya[sdss_sel], yerr=EW_lya_err[sdss_sel],
            fmt='o', markersize=7, c='g', label='SDSS, selected')
ax.errorbar(xm_zsp[sdss_rej], EW_lya[sdss_rej], yerr=EW_lya_err[sdss_rej],
            fmt='o', markersize=7, c='r', label='SDSS, rejected')

# Sources without a HETDEX match carry sentinel values (z = -1,
# EW = -9999999999999) and fall outside the axis limits set below.
ax.errorbar(z_hetdex[nice_lya], EW_lya_hetdex[nice_lya],
            yerr=EW_lya_hetdex_err[nice_lya],
            fmt='^', markersize=7, c='g', label='HETDEX, selected')
ax.errorbar(z_hetdex[~nice_lya], EW_lya_hetdex[~nice_lya],
            yerr=EW_lya_hetdex_err[~nice_lya],
            fmt='^', markersize=7, c='r', label='HETDEX, rejected')

ax.set_ylim(-700, 700)
ax.set_xlim(1.8, 3.7)
ax.set_xlabel('Spectroscopic redshift')
ax.set_ylabel(r'Ly$\alpha$ EW')
ax.legend(fontsize=9)

plt.show()

In [None]:
from astropy.io import fits
# Open the HETDEX SC1 spectra file. The HDU list is left open; it is only read
# by the commented-out inspection cell that follows.
hetdex_spec_path = op.join(path_to_cat, f'hetdex_sc1_spec_{version}.fits')
hdu_hetdex_spec = fits.open(hetdex_spec_path)

In [None]:
# Directory of the spectra .fits files
# fits_dir = '/home/alberto/almacen/SDSS_spectra_fits/miniJPAS_Xmatch'

# wh = (z_hetdex > 2) & (type_hetdex == 'lae')
# print(sum(wh))
# count = 0
# for src in np.where(wh)[0]:
#     print(f'z_HETDEX = {z_hetdex[src]:0.2f}, z_Arr = {z_Arr[src]:0.2f}, r = {mag[src]:0.2f}')
#     print(f'HETDEX L_lya = {L_lya_hetdex[src]:0.2f}, EW_lya = {EW_lya_hetdex[src]:0.2f}')
#     sdss_src = (number[src] == sdss_xm_num) & (tile_id[src] == sdss_xm_tid)
#     if np.any(sdss_src):
#         this_spObjID = sdss_xm_spObjID.to_numpy()[sdss_src][0]
#         zw = f_zsp[sdss_src].to_numpy().astype(int)[0]
#         if zw > 0 or L_lya[sdss_src] <= 0:
#             g_band = None
#             spec = None
#         else:
#             count += 1
#             # Decompose SpObjID into plate, MJD and fiber
#             spObj_binary = np.binary_repr(this_spObjID)
#             plate = int(spObj_binary[::-1][50:64][::-1], 2)
#             mjd = int(spObj_binary[::-1][24:38][::-1], 2) + 50000
#             fiber = int(spObj_binary[::-1][38:50][::-1], 2)

#             spec_name = f'spec-{plate:04d}-{mjd:05d}-{fiber:04d}.fits'
#             print(spec_name)
#             print(f'ML = {ml_mask[src]}, Color = {c_mask[src]}, S/N = {mask_snr[src]} ({snr[src]:0.2f})')
#             print(f'L_lya = {L_lya[sdss_src][0]:0.2f}')
#             spec = Table.read(f'{fits_dir}/{spec_name}', hdu=1, format='fits')
#             g_band = Table.read(f'{fits_dir}/{spec_name}', hdu=2, format='fits')['SPECTROFLUX']
#             g_band = nanomaggie_to_flux(np.array(g_band)[0][1], 4750)
#     else:
#         g_band = None
#         spec = None

#     fig = plt.figure(figsize=(10, 3))

#     ax = plot_JPAS_source(pm_flx[:, src], pm_err[:, src], e17scale=True)

#     if g_band is not None and spec is not None:
#         # Normalizing factor:
#         norm = pm_flx[-3, src] / g_band
#         spec_flx = spec['FLUX'] * norm
#         spec_w = 10 ** spec['LOGLAM']

#         ax.plot(spec_w, spec_flx, c='dimgray', zorder=-99, alpha=0.7)

#     wh_hetdex = np.where(xm_hetdex_id[src] == source_table['source_id'])[0][0]
#     spec_hetdex = hdu_hetdex_spec['SPEC'].data[wh_hetdex]
#     spec_w_hetdex = hdu_hetdex_spec['WAVELENGTH'].data
#     g_band_hetdex = mag_to_flux(source_table['gmag'][wh_hetdex], w_central[-3])
#     norm = pm_flx[-3, src] / g_band_hetdex
#     spec_hetdex = spec_hetdex * norm
#     ax.plot(spec_w_hetdex, spec_hetdex, c='orange', zorder=-99, alpha=1.0)
#     ax.axvline(1215.67 * (1 + z_hetdex[src]), ls='--', c='r', zorder=-100)

#     # ax.set_xlim(3470, 5540)
#     ax.set_ylim(spec_hetdex.min() - 0.1, spec_hetdex.max() + 0.5)
    
#     plt.show()
# print(count)