In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.table import Table
from my_utilities import *
import csv

from scipy.stats import binned_statistic

In [None]:
# Input catalogue (one row per source).
# NOTE(review): absolute, machine-specific path; consider a configurable DATA_DIR.
dirname = '/home/alberto/almacen/Stack_catalogs'
cat_name = 'cat_Unique_ALL_FIELDS_Pband_gri.Table.fits'

cat = Table.read(f'{dirname}/{cat_name}')

In [None]:
# Aperture fluxes and errors, transposed so rows index filters and columns
# index sources, then scaled by 1e-19 (presumably to erg cm^-2 s^-1 A^-1 —
# confirm against the catalogue units).
pm_flx, pm_err = (
    cat[column].T * 1e-19 for column in ('FLUX_APER_3_0', 'FLUX_ERR_APER_3_0')
)
# Aperture magnitude in the second-to-last filter (the r band, given the
# g/r/i ordering of the last three filters used in nice_lya_select).
mag = cat['MAG_APER_3_0'].T[-2]

In [None]:
def central_wavelength():
    '''
    Return the central wavelength of every filter as a plain numpy array.

    The values are read from the 'wavelength' column of
    ../LAEs/fits/FILTERs_table.fits; element i belongs to filter index i.
    '''
    filters_table = Table.read('../LAEs/fits/FILTERs_table.fits', format='fits')
    return np.array(filters_table['wavelength'])


def nb_fwhm(nb_ind, give_fwhm=True):
    '''
    Return the FWHM of the filter(s) at index `nb_ind`.

    Values are read from the 'width' column of
    ../LAEs/fits/FILTERs_table.fits.

    If `give_fwhm` is falsy, return the pair
    (w_central + fwhm/2, w_central - fwhm/2) instead, i.e. the UPPER edge
    first, where w_central is the 'wavelength' column.

    Parameters
    ----------
    nb_ind : int, slice or index sequence (e.g. range(56))
        Row(s) of the filters table.
    give_fwhm : bool, default True
        Choose between the FWHM and the (upper, lower) edge pair.
    '''
    data_tab = Table.read('../LAEs/fits/FILTERs_table.fits', format='fits')
    fwhm = data_tab['width'][nb_ind]

    # Truthiness instead of `== True` / `== False`, so no input falls
    # through to an implicit `None` return.
    if give_fwhm:
        return fwhm

    w_central = data_tab['wavelength'][nb_ind]
    return w_central + fwhm / 2, w_central - fwhm / 2

def load_filter_tags():
    '''
    Return the list of filter names from the miniJPAS filter CSV.

    The first two lines of the file are skipped, and the second column of
    every remaining row is collected. The first name is then overwritten
    with 'J0348'.
    '''
    filepath = '../LAEs/JPAS_Transmission_Curves_20170316/minijpas.Filter.csv'

    with open(filepath, mode='r') as handle:
        reader = csv.reader(handle, delimiter=',')
        # Skip the first two lines (tolerates a shorter file).
        for _ in range(2):
            next(reader, None)
        tags = [row[1] for row in reader]

    tags[0] = 'J0348'
    return tags


# Filter metadata shared by the cells below.
w_central = central_wavelength()   # central wavelength of each filter
nb_fwhm_Arr = nb_fwhm(range(60))   # FWHM of filters 0..59
w_lya = 1215.67                    # Lya rest-frame wavelength (A)
filter_tags = load_filter_tags()   # filter names from the miniJPAS CSV

def estimate_continuum(NB_flx, NB_err, N_nb=7, IGM_T_correct=True,
                       only_right=False, N_nb_min=0, N_nb_max=47):
    '''
    Returns a matrix with the continuum estimate at any NB in all sources.

    For every filter index nb_idx with 1 <= nb_idx and
    N_nb_min <= nb_idx <= N_nb_max, the continuum is the inverse-variance
    weighted mean of neighbouring filters. Filters nb_idx - 1, nb_idx and
    nb_idx + 1 are always excluded. The neighbours used are:
      * up to N_nb - 1 filters with lower index (all available ones when
        nb_idx < N_nb), and
      * up to N_nb - 1 filters starting at nb_idx + 2.
    When only_right is True, only the higher-index side is used.

    Parameters
    ----------
    NB_flx, NB_err : 2-D arrays, filters x sources
        Only the first 56 rows are used.
    N_nb : int
        Half-width of the averaging window (see above).
    IGM_T_correct : bool
        If True, the lower-index fluxes and errors are divided by
        IGM_TRANSMISSION(w_central) before averaging.
    only_right : bool
        Average only the higher-index side (see NOTE in the loop).
    N_nb_min, N_nb_max : int
        Range of filter indices for which an estimate is computed.

    Returns
    -------
    cont_est, cont_err : arrays with the shape of NB_flx[:56]
        Weighted mean, and its error (sum of inverse variances) ** -0.5.
        Filters for which no estimate is made keep the placeholder 99.
    '''
    # Keep only the first 56 rows; trailing rows are not used here.
    NB_flx = NB_flx[:56]
    NB_err = NB_err[:56]

    # 99 is the placeholder for filters that are never estimated
    # (index 0 and indices outside [N_nb_min, N_nb_max]).
    cont_est = np.ones(NB_flx.shape) * 99.
    cont_err = np.ones(NB_flx.shape) * 99.
    w_central = central_wavelength()

    for nb_idx in range(1, NB_flx.shape[0]):
        # Limits on where to make the estimation
        if nb_idx < N_nb_min:
            continue
        if nb_idx > N_nb_max:
            break

        # Case 1: not enough filters on the low-index side for a full window
        # (or right-only mode): take every filter below nb_idx - 1.
        if (nb_idx < N_nb) or only_right:
            if IGM_T_correct:
                # NOTE(review): IGM_TRANSMISSION is not defined in this notebook;
                # presumably it comes from `from my_utilities import *`.
                IGM_T = IGM_TRANSMISSION(
                    np.array(w_central[: nb_idx - 1])
                ).reshape(-1, 1)
            else:
                IGM_T = 1.

            # Stack filters at both sides or only at the right of the central one
            if not only_right:
                NBs_to_avg = np.vstack((
                    NB_flx[: nb_idx - 1] / IGM_T,
                    NB_flx[nb_idx + 2: nb_idx + N_nb + 1]
                ))
                NBs_errs = np.vstack((
                    NB_err[: nb_idx - 1] / IGM_T,
                    NB_err[nb_idx + 2: nb_idx + N_nb + 1]
                ))
            if only_right:
                NBs_to_avg = NB_flx[nb_idx + 2: nb_idx + N_nb + 1]
                NBs_errs = NB_err[nb_idx + 2: nb_idx + N_nb + 1]

        # Case 2: symmetric window of N_nb - 1 filters on each side.
        # The upper bound (row count - 6 = 50 for 56 rows) keeps the
        # right-hand slice away from the end of the array.
        if (N_nb <= nb_idx < (NB_flx.shape[0] - 6)) and not only_right:
            if IGM_T_correct:
                IGM_T = IGM_TRANSMISSION(
                    np.array(w_central[nb_idx - N_nb: nb_idx - 1])
                ).reshape(-1, 1)
            else:
                IGM_T = 1.
            NBs_to_avg = np.vstack((
                NB_flx[nb_idx - N_nb: nb_idx - 1] / IGM_T,
                NB_flx[nb_idx + 2: nb_idx + N_nb + 1]
            ))
            NBs_errs = np.vstack((
                NB_err[nb_idx - N_nb: nb_idx - 1] / IGM_T,
                NB_err[nb_idx + 2: nb_idx + N_nb + 1]
            ))

        # Case 3: near the high-index end; use every remaining filter above
        # nb_idx + 1.
        # NOTE(review): this branch does not check only_right, so with
        # only_right=True it overrides case 1 with both sides. It also uses a
        # negative slice start if nb_idx < N_nb. Neither case is reachable
        # with the default N_nb_max=47 — confirm before raising N_nb_max.
        if nb_idx >= (NB_flx.shape[0] - 6):
            if IGM_T_correct:
                IGM_T = IGM_TRANSMISSION(
                    np.array(w_central[nb_idx - N_nb: nb_idx - 1])
                ).reshape(-1, 1)
            else:
                IGM_T = 1.
            NBs_to_avg = np.vstack((
                NB_flx[nb_idx - N_nb: nb_idx - 1] / IGM_T,
                NB_flx[nb_idx + 2:]
            ))
            NBs_errs = np.vstack((
                NB_err[nb_idx - N_nb: nb_idx - 1] / IGM_T,
                NB_err[nb_idx + 2:]
            ))

        # Weights for the average
        w = NBs_errs ** -2

        # Inverse-variance weighted mean and its error, per source.
        cont_est[nb_idx] = np.average(NBs_to_avg, weights=w, axis=0)
        cont_err[nb_idx] = np.sum(NBs_errs ** -2, axis=0) ** -0.5

    return cont_est, cont_err

def nb_or_3fm_cont(pm_flx, pm_err):
    '''
    Run estimate_continuum twice on the same photometry.

    Returns (est_lya, err_lya, est_oth, err_oth):
      * the first pair uses the IGM-corrected estimate (for the Lya search);
      * the second pair uses the uncorrected estimate (for other lines).
    '''
    lya_pair = estimate_continuum(pm_flx, pm_err, IGM_T_correct=True)
    other_pair = estimate_continuum(pm_flx, pm_err, IGM_T_correct=False)
    return (*lya_pair, *other_pair)

def is_there_line(pm_flx, pm_err, cont_est, cont_err, ew0min,
                  mask=True, obs=False, sigma=3):
    '''
    Flag a significant emission-line excess for every narrow band and source.

    All rows of pm_flx / pm_err except the last four are compared with
    cont_est / cont_err. A (filter, source) pair is flagged when all of
    these hold:
      * flux - continuum > sigma * sqrt(flux_err**2 + cont_err**2);
      * flux / continuum > 1 + EW / FWHM, where EW = ew0min * (1 + z) and
        z is the Lya redshift sampled by the filter (EW = ew0min if obs);
      * flux > continuum;
      * mask is True;
      * the continuum estimate is finite.

    Returns a boolean array with the broadcast shape of the inputs
    (filters x sources for the callers in this notebook).
    '''
    nb_w_central = central_wavelength()[:-4]
    fwhm_col = nb_fwhm(range(56)).reshape(-1, 1)

    if obs:
        # ew0min is already an observed-frame EW.
        ew_thresh = ew0min
    else:
        # Move the rest-frame threshold to the observed frame at the Lya
        # redshift of each narrow band.
        z_col = (nb_w_central / 1215.67 - 1).reshape(-1, 1)
        ew_thresh = ew0min * (1 + z_col)

    nb_flx = pm_flx[:-4]
    nb_err = pm_err[:-4]

    # Excess above `sigma` times the combined uncertainty
    significant = nb_flx - cont_est > sigma * (nb_err**2 + cont_err**2) ** 0.5
    # EW threshold
    strong_enough = nb_flx / cont_est > 1 + ew_thresh / fwhm_col
    above_cont = nb_flx > cont_est
    # Check that cont_est is ok
    finite_cont = np.isfinite(cont_est)

    return significant & strong_enough & above_cont & mask & finite_cont

# NOTE: this re-definition is identical to the earlier nb_fwhm in this cell
# and shadows it; one of the two copies can be removed.
def nb_fwhm(nb_ind, give_fwhm=True):
    '''
    Return the FWHM of the filter(s) at index `nb_ind`.

    Values are read from the 'width' column of
    ../LAEs/fits/FILTERs_table.fits.

    If `give_fwhm` is falsy, return the pair
    (w_central + fwhm/2, w_central - fwhm/2) instead, i.e. the UPPER edge
    first, where w_central is the 'wavelength' column.

    Parameters
    ----------
    nb_ind : int, slice or index sequence (e.g. range(56))
        Row(s) of the filters table.
    give_fwhm : bool, default True
        Choose between the FWHM and the (upper, lower) edge pair.
    '''
    data_tab = Table.read('../LAEs/fits/FILTERs_table.fits', format='fits')
    fwhm = data_tab['width'][nb_ind]

    # Truthiness instead of `== True` / `== False`, so no input falls
    # through to an implicit `None` return.
    if give_fwhm:
        return fwhm

    w_central = data_tab['wavelength'][nb_ind]
    return w_central + fwhm / 2, w_central - fwhm / 2

def _contiguous_true_runs(flags):
    '''Yield lists of consecutive indices where `flags` is truthy.'''
    run = []
    for idx, flag in enumerate(flags):
        if flag:
            run.append(idx)
        elif run:
            yield run
            run = []
    if run:
        # A run reaching the last element is still a line.
        yield run


def identify_lines(line_Arr, qso_flx, cont_flx, nb_min=0, first=False,
                   return_line_width=False):
    '''
    Returns a list of N lists with the index positions of the lines.

    Every maximal run of consecutive detections in a source's column of
    line_Arr is one line. The line's position is the filter of maximum
    qso_flx within the run.

    Parameters
    ----------
    line_Arr : 2-D bool array, N_filters x N_sources
        Detections (e.g. from is_there_line). Row j corresponds to row
        j + nb_min of qso_flx / cont_flx.
    qso_flx : 2-D array, filters x sources
        Flambda data, used to locate and weight each run.
    cont_flx : 2-D array, filters x sources
        Continuum, only used when first=True.
    nb_min : int
        Row offset of line_Arr with respect to qso_flx / cont_flx.
    first : bool
        If True, keep per source only the run with the largest summed
        qso_flx minus the continuum at its peak.
    return_line_width : bool
        With first=True, also return each source's list of all peaks.

    Returns
    -------
    first=False : list (per source) of lists of peak indices.
    first=True : (line_list, line_cont_list) or, with return_line_width,
        (line_list, line_cont_list, line_len_list), where:
          * line_list[src] is the peak index of the selected run;
          * line_cont_list[src] is the qso_flx**2-weighted mean index of
            that run;
          * line_len_list[src] is the list of all peaks (despite its name).
        Sources without detections get -1, -1 and [-1].
    All indices are qso_flx row indices (nb_min included).
    '''
    N_fil, N_src = line_Arr.shape
    line_list = []
    line_len_list = []
    line_cont_list = []

    for src in range(N_src):
        this_src_lines = []     # peak index of each line
        this_cont_lines = []    # flux-weighted mean index of each line
        this_src_line_flx = []  # summed flux of each line

        for this_line in _contiguous_true_runs(line_Arr[:, src]):
            run = np.array(this_line) + nb_min
            run_flx = qso_flx[run, src]

            if first:
                # Offset by nb_min like the peak index, so both refer to
                # qso_flx rows.
                this_cont_lines.append(np.average(run, weights=run_flx ** 2))
            this_src_lines.append(np.argmax(run_flx) + run[0])
            this_src_line_flx.append(run_flx.sum())

        if not first:
            line_list.append(this_src_lines)
        elif this_src_lines:
            idx = np.argmax(
                np.array(this_src_line_flx)
                - cont_flx[np.array(this_src_lines), src]
            )
            line_list.append(this_src_lines[idx])
            line_len_list.append(this_src_lines)
            line_cont_list.append(this_cont_lines[idx])
        else:
            # No detection in this source
            line_list.append(-1)
            line_len_list.append([-1])
            line_cont_list.append(-1)

    if first:
        if return_line_width:
            return line_list, line_cont_list, line_len_list
        return line_list, line_cont_list
    return line_list

def z_NB(cont_line_pos, w_central=None):
    '''
    Computes the Lya z for a continuum NB index.

    The position may be fractional. The observed wavelength is linearly
    interpolated between the central wavelengths of filters int(pos) and
    int(pos) + 1, then converted to a Lya redshift (rest 1215.67 A).

    Parameters
    ----------
    cont_line_pos : number or array_like
        Filter index; the fractional part interpolates towards the next
        filter.
    w_central : array_like, optional
        Central wavelength per filter index. Defaults to
        central_wavelength().

    Returns
    -------
    np.ndarray
        Lya redshift for each position (always at least 1-D).
    '''
    if w_central is None:
        w_central = central_wavelength()
    w_central = np.asarray(w_central)

    # Convert to numpy arr
    cont_line_pos = np.atleast_1d(cont_line_pos)
    idx = cont_line_pos.astype(int)
    # Clip the upper neighbour so a position on the last filter does not
    # index past the end (its interpolation weight is 0 anyway).
    idx_next = np.minimum(idx + 1, len(w_central) - 1)

    w1 = w_central[idx]
    w2 = w_central[idx_next]

    # Take the fractional part BEFORE scaling. `(w2 - w1) * pos % 1`
    # evaluates left to right and collapses the offset to < 1 A.
    w = (w2 - w1) * (cont_line_pos % 1) + w1

    return w / 1215.67 - 1

def nice_lya_select(lya_lines, other_lines, pm_flx, pm_err, cont_est, z_Arr, mask=None):
    '''
    Flag sources whose Lya candidate passes a set of consistency checks.

    A source with lya_lines[src] != -1 is kept unless one of these holds:
      * The inverse-variance weighted mean flux of the filters blueward of
        the rest-frame Lyman limit (912 A at z_Arr[src]) exceeds the flux in
        row -3 (g band) by more than 3 sigma.
      * Some index in other_lines[src] within [1, 50] is farther than
        1.5 FWHM from the observed Lya, Lyb, SiIV, CIV, CIII and MgII
        wavelengths, and is not redward of MgII + FWHM.
      * other_lines[src] has at most one entry and the source fails the
        g-r / r-i color cut chosen by its redshift.

    Parameters
    ----------
    lya_lines : sequence of int, one per source (-1 = no Lya candidate)
    other_lines : sequence of lists of filter indices, one per source
    pm_flx, pm_err : 2-D arrays, filters x sources
        The last three rows are used as g, r, i.
    cont_est : not used in this function
    z_Arr : Lya redshift per source
    mask : optional boolean array, ANDed with the result

    Returns
    -------
    np.ndarray of bool, one entry per source.
    '''
    N_sources = len(lya_lines)
    w_central = central_wavelength()
    fwhm_Arr = nb_fwhm(range(56))
    nice_lya = np.zeros(N_sources).astype(bool)

    # Line rest-frame wavelengths (Angstroms)
    w_lyb = 1025.7220
    w_lya = 1215.67
    w_SiIV = 1397.61
    w_CIV = 1549.48
    w_CIII = 1908.73
    w_MgII = 2799.12

    # Magnitudes in the last three filters.
    # NOTE(review): flux_to_mag is not defined in this notebook; presumably
    # it comes from `from my_utilities import *`.
    i = flux_to_mag(pm_flx[-1], w_central[-1])
    r = flux_to_mag(pm_flx[-2], w_central[-2])
    g = flux_to_mag(pm_flx[-3], w_central[-3])
    gr = g - r
    ri = r - i
    # For z > 3
    # NOTE(review): the loop below applies color_aux1 when z_src <= 3 and
    # color_aux2 when z_src > 3, i.e. the opposite of the "For z > 3" /
    # "For z < 3" labels here. Confirm which assignment is intended.
    # color_aux1 = (-1.5 * ri + 1.7 > gr)
    color_aux1 = (ri < 0.6) & (gr < 1.5)
    # For z < 3
    # color_aux2 = (-1.5 * ri + 2.5 > gr) & (ri < 1.)
    color_aux2 = (ri < 0.6) & (gr < 0.7)

    # color_aux1 = np.ones(g.shape).astype(bool)
    # color_aux2 = np.ones(g.shape).astype(bool)

    # Only sources with a Lya candidate are examined; the rest stay False.
    for src in np.where(np.array(lya_lines) != -1)[0]:
        # l_lya = lya_lines[src]
        z_src = z_Arr[src]

        # Observed-frame wavelengths of the reference lines at this redshift
        w_obs_lya = (1 + z_src) * w_lya
        w_obs_lyb = (1 + z_src) * w_lyb
        w_obs_SiIV = (1 + z_src) * w_SiIV
        w_obs_CIV = (1 + z_src) * w_CIV
        w_obs_CIII = (1 + z_src) * w_CIII
        w_obs_MgII = (1 + z_src) * w_MgII

        this_nice = True

        # Check the Lyman limit
        w_central_0 = w_central / (1 + z_src)
        # Filters 0 and -4 get a huge rest wavelength, which excludes them
        # from the Lyman-limit average below.
        w_central_0[0] = 99999
        w_central_0[-4] = 99999
        Lybreak_flx_Arr = pm_flx[w_central_0 < 912, src]
        Lybreak_err_Arr = pm_err[w_central_0 < 912, src]
        if len(Lybreak_flx_Arr) != 0:
            # Inverse-variance weighted mean flux blueward of 912 A (rest)
            Lybreak_flx = np.average(
                Lybreak_flx_Arr, weights=Lybreak_err_Arr ** -2)
            Lybreak_err = np.sum(Lybreak_err_Arr ** -2) ** -0.5

            # Reject when that flux exceeds the g-band flux by > 3 sigma
            if Lybreak_flx - pm_flx[-3, src] > 3 * Lybreak_err:
                this_nice = False

        for l in other_lines[src]:
            # Ignore very red and very blue NBs
            if (l > 50) | (l < 1):
                continue

            w_obs_l = w_central[l]
            fwhm = fwhm_Arr[l]

            # The other line must match a known line within 1.5 FWHM, or lie
            # redward of MgII + FWHM.
            good_l = (
                (np.abs(w_obs_l - w_obs_lya) < fwhm * 1.5)
                | (np.abs(w_obs_l - w_obs_lyb) < fwhm * 1.5)
                | (np.abs(w_obs_l - w_obs_SiIV) < fwhm * 1.5)
                | (np.abs(w_obs_l - w_obs_CIV) < fwhm * 1.5)
                | (np.abs(w_obs_l - w_obs_CIII) < fwhm * 1.5)
                | (np.abs(w_obs_l - w_obs_MgII) < fwhm * 1.5)
                | (w_obs_l > w_obs_MgII + fwhm)
            )

            if ~good_l:
                this_nice = False
                break

        # Sources with more than one other line skip the color cut.
        if not this_nice:
            continue
        elif len(other_lines[src]) > 1:
            pass
        else:
            if z_src > 3.:
                good_colors = color_aux2[src]
            else:
                good_colors = color_aux1[src]
            if ~good_colors:
                this_nice = False

        if this_nice:
            nice_lya[src] = True

    if mask is None:
        return nice_lya
    else:
        return nice_lya & mask

def EW_L_NB(pm_flx, pm_err, cont_flx, cont_err, z_Arr, lya_lines, F_bias=None,
            nice_lya=None, N_nb=0):
    '''
    Returns the EW0 and the luminosity from a NB selection given by lya_lines.

    For each selected source with a line filter l (lya_lines[src] != -1),
    the line flux is built from the NB fluxes in filters [l - N_nb, l + N_nb]:
      1. IGM-correct and floor them at the continuum;
      2. multiply by the filter FWHMs and sum;
      3. subtract the flux double-counted where adjacent filters overlap,
         and the continuum integrated over the covered wavelength span.

    Parameters
    ----------
    pm_flx, pm_err : 2-D arrays, filters x sources
        Photometric fluxes and errors. pm_flx is NOT modified; a copy is
        corrected internally.
    cont_flx, cont_err : 2-D arrays, filters x sources
        Continuum estimate and its error.
    z_Arr : array, one value per source
        Lya redshift of each source.
    lya_lines : sequence of int, one per source
        Lya filter index, or -1 for no line.
    F_bias : array indexed by filter, optional
        If given, flambda is divided by F_bias[lya_lines].
        flambda_e is not rescaled.
    nice_lya : bool array, optional
        Sources to process; defaults to all.
    N_nb : int
        Number of filters on each side of the line filter to integrate.

    Returns
    -------
    EW_nb_Arr, EW_nb_e : rest-frame EW and its error, flambda / cont / (1 + z)
    L_Arr : log10 of 4 pi dL**2 flambda
    L_e_Arr : luminosity error in LINEAR units (not log10)
    flambda, flambda_e : integrated line flux and its error

    Unprocessed sources keep flambda = cont = 0, so their EW is NaN and
    their L_Arr is -inf (numpy emits warnings).
    '''
    w_central = central_wavelength()

    # The IGM correction and continuum floor below write into the flux
    # array. Work on a copy so the caller's pm_flx is left untouched;
    # previously the edit leaked out and compounded on every re-run.
    pm_flx = pm_flx.copy()

    N_sources = pm_flx.shape[1]
    nb_fwhm_Arr = np.array(nb_fwhm(range(56)))

    if nice_lya is None:
        nice_lya = np.ones(N_sources).astype(bool)

    cont = np.zeros(N_sources)
    cont_e = np.zeros(N_sources)
    flambda = np.zeros(N_sources)
    flambda_e = np.zeros(N_sources)

    for src in np.where(nice_lya)[0]:
        l = lya_lines[src]
        if l == -1:
            continue

        cont[src] = cont_flx[l, src]
        cont_e[src] = cont_err[l, src]

        # Let's integrate the NB flux over the transmission curves to obtain Flambda
        l_start = np.max([0, l - N_nb])
        lw = np.arange(l_start, l + N_nb + 1)

        # IGM transmission: full correction for the lower-index filters, and
        # the mean of the transmission and 1 for the line filter itself.
        # NOTE(review): IGM_TRANSMISSION is not defined in this notebook;
        # presumably it comes from `from my_utilities import *`.
        IGM_T_Arr = np.ones(len(lw))
        IGM_T_Arr[: l - l_start] = IGM_TRANSMISSION(w_central[lw[: l - l_start]])
        IGM_T_Arr[l - l_start] = (IGM_TRANSMISSION(w_central[lw[l - l_start]]) + 1) * 0.5

        pm_flx[l_start: l + N_nb + 1, src] /= IGM_T_Arr
        # Floor the corrected fluxes at the continuum level
        nb_vals = pm_flx[l_start: l + N_nb + 1, src]
        nb_vals[nb_vals < cont[src]] = cont[src]

        # Flux double-counted where adjacent filters overlap
        intersec = 0.
        for i in range(lw[0], lw[-1]):
            intersec_dlambda = (
                (nb_fwhm_Arr[i] + nb_fwhm_Arr[i + 1]) * 0.5
                - (w_central[i + 1] - w_central[i])
            )
            intersec += np.min(
                [pm_flx[i, src] * intersec_dlambda,
                 pm_flx[i + 1, src] * intersec_dlambda]
            )

        # Continuum integrated over the full wavelength span of the filters
        flambda_cont = cont[src] * (
            w_central[l + N_nb] + nb_fwhm_Arr[l + N_nb] * 0.5
            - (w_central[l_start] - nb_fwhm_Arr[l_start] * 0.5)
        )

        flambda[src] = np.sum(
            pm_flx[lw[0]: lw[-1] + 1, src] * nb_fwhm_Arr[lw[0]: lw[-1] + 1]
        ) - intersec - flambda_cont
        flambda_e[src] = (
            np.sum(
                (pm_err[lw[0]: lw[-1] + 1, src] *
                 nb_fwhm_Arr[lw[0]: lw[-1] + 1]) ** 2
            )
            + (flambda_cont / cont[src] * cont_e[src]) ** 2
        ) ** 0.5

    if F_bias is not None:
        flambda /= F_bias[np.array(lya_lines)]

    EW_nb_Arr = flambda / cont / (1 + z_Arr)
    EW_nb_e = flambda_e / cont / (1 + z_Arr)

    # NOTE(review): `cosmo` and `u` are not defined in this notebook;
    # presumably they come from `from my_utilities import *`.
    def LumDist(z):
        return cosmo.luminosity_distance(z).to(u.cm).value

    def Redshift(w):
        return w / 1215.67 - 1

    dL = LumDist(z_Arr)
    # Distance uncertainty from a half-FWHM shift of the line filter
    dL_e = (
        LumDist(Redshift(w_central[lya_lines] + 0.5 * nb_fwhm_Arr[lya_lines]))
        - LumDist(Redshift(w_central[lya_lines]))
    )

    L_Arr = np.log10(flambda * 4*np.pi * dL ** 2)
    L_e_Arr = (
        (dL ** 2 * flambda_e) ** 2
        + (2*dL * dL_e * flambda) ** 2
    ) ** 0.5 * 4*np.pi

    return EW_nb_Arr, EW_nb_e, L_Arr, L_e_Arr, flambda, flambda_e

def plot_JPAS_source(flx, err, set_ylim=True, e17scale=False, fs=15):
    '''
    Generates a plot with the JPAS data on the current matplotlib axes.

    All filters except the last four are drawn as circles. The last four
    (presumably the broad bands) are drawn as larger squares. Marker colors
    come from the 'color_representation' column of the filters table.

    Parameters
    ----------
    flx, err : 1-D arrays, one value per filter
        Flux and error of a single source.
    set_ylim : bool
        Set y limits from 0.3 of the data range below the minimum to 2/3 of
        the range above the maximum. Non-finite limits are skipped.
    e17scale : bool
        Multiply the data by 1e17 and label the axis accordingly.
    fs : int
        Axis-label font size.

    Returns
    -------
    The matplotlib Axes that was drawn on.
    '''
    if e17scale:
        flx = flx * 1e17
        err = err * 1e17

    data_tab = Table.read('../LAEs/fits/FILTERs_table.fits', format='fits')
    cmap = data_tab['color_representation']
    w_central = data_tab['wavelength']

    data_max = np.max(flx)
    data_min = np.min(flx)
    y_max = (data_max - data_min) * 2/3 + data_max
    y_min = data_min - (data_max - data_min) * 0.3

    ax = plt.gca()
    for i, w in enumerate(w_central[:-4]):
        ax.errorbar(w, flx[i], yerr=err[i],
                    marker='o', markeredgecolor='dimgray', markerfacecolor=cmap[i],
                    markersize=8, ecolor='dimgray', capsize=4, capthick=1, linestyle='',
                    zorder=-99)
    # The last four filters share one marker style
    for i in range(-4, 0):
        ax.errorbar(w_central[i], flx[i], yerr=err[i], markeredgecolor='dimgray',
                    fmt='s', markerfacecolor=cmap[i], markersize=10,
                    ecolor='dimgray', capsize=4, capthick=1)

    if set_ylim:
        try:
            ax.set_ylim((y_min, y_max))
        except (TypeError, ValueError):
            # matplotlib rejects NaN/Inf limits (e.g. NaN fluxes); keep the
            # autoscaled limits in that case.
            pass

    # Raw strings: the LaTeX backslashes are not valid Python escapes.
    ax.set_xlabel(r'$\lambda\ (\AA)$', size=fs)
    if e17scale:
        ax.set_ylabel(
            r'$f_\lambda\cdot10^{17}$ (erg cm$^{-2}$ s$^{-1}$ $\AA^{-1}$)', size=fs)
    else:
        ax.set_ylabel(
            r'$f_\lambda$ (erg cm$^{-2}$ s$^{-1}$ $\AA^{-1}$)', size=fs)

    return ax


In [None]:
# Selection thresholds
ew0_cut = 30        # rest-frame EW threshold for the Lya search
ew_oth = 100        # observed-frame EW threshold for other lines (obs=True)
mag_min, mag_max = 16, 30
nb_min, nb_max = 1, 24

cont_est_lya, cont_err_lya, cont_est_other, cont_err_other =\
    nb_or_3fm_cont(pm_flx, pm_err)

# Lya search
line = is_there_line(pm_flx, pm_err, cont_est_lya,
                     cont_err_lya, ew0_cut)
lya_lines, lya_cont_lines, _ = identify_lines(
    line, pm_flx, cont_est_lya, first=True, return_line_width=True
)
lya_lines = np.array(lya_lines)

# Other lines
line_other = is_there_line(pm_flx, pm_err, cont_est_other, cont_err_other,
                           ew_oth, obs=True, sigma=5)
other_lines = identify_lines(line_other, pm_flx, cont_est_other)

N_sources = pm_flx.shape[1]

mag_cut = (mag > mag_min) & (mag < mag_max)

z_Arr = np.zeros(N_sources)
z_Arr[np.where(np.array(lya_lines) != -1)] =\
    z_NB(np.array(lya_cont_lines)[np.where(np.array(lya_lines) != -1)])

# For sources without a line, l == -1 reads the last filter. Those sources
# are removed by the `lya_lines >= nb_min` cut below.
snr = np.empty(N_sources)
for src in range(N_sources):
    l = lya_lines[src]
    snr[src] = pm_flx[l, src] / pm_err[l, src]

mask = (lya_lines >= nb_min) & (lya_lines <= nb_max) & mag_cut & (snr > 6)
nice_lya = nice_lya_select(
    lya_lines, other_lines, pm_flx, pm_err, cont_est_lya, z_Arr, mask=mask
)

# Estimate Luminosity
_, EW_Arr_e, L_Arr, _, _, _ = EW_L_NB(
    pm_flx, pm_err, cont_est_lya, cont_err_lya, z_Arr, lya_lines, N_nb=0
)

L_Lbin_err = np.load('../LAEs/npy/L_nb_err.npy')
median_L = np.load('../LAEs/npy/L_bias.npy')
L_binning = np.load('../LAEs/npy/L_nb_err_binning.npy')
L_bin_c = [L_binning[i: i + 2].sum() * 0.5 for i in range(len(L_binning) - 1)]

# Correct L_Arr with the median
mask_median_L = (median_L < 10)
L_Arr = L_Arr - np.interp(
    L_Arr, np.log10(L_bin_c)[mask_median_L], median_L[mask_median_L]
)

# Apply bin err
L_binning_position = binned_statistic(
    10 ** L_Arr, None, 'count', bins=L_binning
).binnumber
L_binning_position[L_binning_position > len(L_binning) - 2] = len(L_binning) - 2
# NOTE(review): scipy's `binnumber` is 1-based (0 = below the first edge).
# Indexing a per-bin array with it directly may be off by one bin — confirm
# the layout of L_nb_err.npy.
L_e_Arr = L_Lbin_err[L_binning_position]

bins = np.log10(L_binning)

# Compute EW_Arr: rest-frame EW of the retrieved Lya line
EW_Arr = np.empty(L_Arr.shape)
for src in range(N_sources):
    l = lya_lines[src]
    EW_Arr[src] = (pm_flx[l, src] / cont_est_lya[l, src] - 1) * nb_fwhm_Arr[l]
# Observed -> rest frame divides by (1 + z), as in EW_L_NB. Dividing by z
# alone was wrong and gave inf for sources with z_Arr == 0.
EW_Arr /= 1 + z_Arr

In [None]:
# Distributions of log-luminosity and rest-frame EW of the selected LAE
# candidates, split at r = 24. Dedicated names avoid clobbering the
# selection `mask` of the analysis cell.
bright_sel = nice_lya & (mag < 24)
faint_sel = nice_lya & (mag > 24)

hist_setup = [
    (L_Arr, np.linspace(41, 46, 30),
     r'$\log_{10}(L_{\mathrm{Ly}\alpha}\ /\ \mathrm{erg\,s^{-1}})$'),
    (EW_Arr, np.linspace(1e1, 1e3, 30), r'EW$_0$ ($\AA$)'),
]
for values, hist_bins, x_label in hist_setup:
    fig, ax = plt.subplots()
    ax.hist(values[bright_sel], hist_bins, histtype='step', label='r < 24')
    ax.hist(values[faint_sel], hist_bins, histtype='step', label='r > 24')
    ax.set_xlabel(x_label)
    ax.set_ylabel('N sources')
    ax.legend()
    plt.show()

In [None]:
# Inspect the photo-spectra of the faint (r > 24) LAE candidates.
for src in np.where(nice_lya & (mag > 24))[0]:
    fig = plt.figure(figsize=(8, 6))
    ax = plot_JPAS_source(pm_flx[:, src], pm_err[:, src])

    lya_obs_w = w_central[lya_lines[src]]
    Ly_lim_w = 912 * (z_Arr[src] + 1)  # observed-frame Lyman limit
    ax.axvline(lya_obs_w, linestyle='--', color='r', label='Retrieved Lya line', zorder=999)
    ax.axvline(Ly_lim_w, linestyle='--', color='dimgray', label='Ly lim', zorder=999)
    ax.axhline(0, c='k', zorder=-999)
    # The axvline labels only appear once a legend is drawn
    ax.legend()
    ax.set_title(f'Source {src} (z = {z_Arr[src]:0.2f})')

    print(f'L_Lya = {L_Arr[src]:0.2f}, EW_Lya = {EW_Arr[src]:0.2f}')
    plt.show()
    # Release the figure; this loop can create many of them
    plt.close(fig)