In [None]:
from astropy.io import fits
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patheffects
matplotlib.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
matplotlib.rc('text', usetex=True)
import numpy as np
import pandas as pd
from LAEs.my_functions import *
from LAEs.load_jpas_catalogs import load_minijpas_jnep, load_sdss_xmatch
import os
import os.path as op
from astropy.table import Table
from visual_inspection import nanomaggie_to_flux

# Filter metadata from the LAEs helpers, for filter indices 0-59:
# tag names, central wavelengths and widths (FWHM).
filter_labels = load_filter_tags()
w_central = central_wavelength()
fwhm_Arr = nb_fwhm(np.arange(60))

# Exposure times, in seconds, for the broad-band (BB) and narrow-band (NB) images
bb_exp_time, nb_exp_time = 30, 120

######
# Selected candidates; their image positions are used later to locate each
# one in the full catalog.
selection = pd.read_csv('csv/selection.csv')
sel_x_im, sel_y_im = selection['x_im'], selection['y_im']

print('Loading catalogs...')
(pm_flx, pm_err, x_im, y_im, tile_id, number, starprob, spCl,
 photoz, photoz_chi_best, photoz_odds, RA, DEC) = load_minijpas_jnep(selection=True)
N_sel = len(selection['src'])
N_sources = len(tile_id)

# Continuum estimate, used to mark the continuum level on the plots
cont_est_lya, cont_err_lya = estimate_continuum(pm_flx, pm_err, IGM_T_correct=False)

# Only the first three outputs of the SDSS cross-match are needed here
sdss_xm_num, sdss_xm_tid, sdss_xm_spObjID = load_sdss_xmatch()[:3]

# Local data store holding the external catalogs and spectra
almacen = '/home/alberto/almacen'

# SDSS spectra (.fits files) of the cross-matched sources
fits_dir = f'{almacen}/SDSS_spectra_fits/miniJPAS_Xmatch'

# HETDEX source catalog (SC1) and AGN catalog
path_to_cat = f'{almacen}/HETDEX_catalogs/hetdex_source_catalog_1'
path_to_agn = f'{almacen}/HETDEX_catalogs/agn_catalog_v1.0'
version = 'v3.2'

# The HDU list stays open: the plotting cell reads spectra from it later.
hdu_hetdex_spec = fits.open(op.join(path_to_cat, f'hetdex_sc1_spec_{version}.fits'))

source_table = Table.read(op.join(path_to_cat, f'hetdex_sc1_{version}.ecsv'))
det_table = Table.read(op.join(path_to_cat, f'hetdex_sc1_detinfo_{version}.ecsv'))
xm_hetdex_id = np.load('npy/hetdex_crossmatch_ids.npy')

fname = f'{path_to_agn}/hetdex_agn.fits'
agn = Table.read(fname, format='fits', hdu=1)
# HETDEX quantities for every catalog source. Sentinel values are kept for
# sources without a HETDEX counterpart; the EW arrays are only filled for
# non-AGN matches whose source_type is 'lae'.
z_hetdex = np.ones(N_sources) * -1
L_lya_hetdex = np.ones(N_sources) * -1
EW_lya_hetdex = np.ones(N_sources) * -9999999999999
EW_lya_hetdex_err = np.ones(N_sources) * 9999
type_hetdex = np.zeros(N_sources).astype(str)
for src in range(N_sources):
    if not xm_hetdex_id[src] > 0:
        continue  # no HETDEX counterpart for this source

    wh = np.where(xm_hetdex_id[src] == source_table['source_id'])[0][0]
    wh_det = np.where(xm_hetdex_id[src] == det_table['source_id'])[0][0]
    # NOTE(review): the cross-match id is compared with `source_id` above but
    # with the AGN catalog's `detectid_best` here -- confirm both refer to the
    # same identifier.
    wh_agn = np.where(xm_hetdex_id[src] == agn['detectid_best'])[0]

    z_hetdex[src] = source_table['z_hetdex'][wh]
    type_hetdex[src] = source_table['source_type'][wh]
    if len(wh_agn) > 0:
        # Use the matched AGN row only; the whole column cannot be stored in
        # the scalar slot L_lya_hetdex[src].
        # NOTE(review): the 1e-17 factor assumes `flux_LyA` is given in units
        # of 1e-17 erg s^-1 cm^-2. `cosmo` and `u` are not imported explicitly
        # in this notebook; presumably they come from `LAEs.my_functions`.
        F_lya = agn['flux_LyA'][wh_agn[0]] * 1e-17
        dL = cosmo.luminosity_distance(z_hetdex[src]).to(u.cm).value
        L_lya_hetdex[src] = np.log10(F_lya * 4 * np.pi * dL ** 2)
    else:
        L_lya_hetdex[src] = np.log10(source_table['lum_lya'][wh])
        if type_hetdex[src] == 'lae':
            # EW = line flux / continuum, with first-order error propagation.
            # NOTE(review): not divided by (1 + z) -- observed-frame unless the
            # catalog values are already rest-frame; confirm.
            det_flux = det_table['flux'][wh_det]
            det_cont = det_table['continuum'][wh_det]
            EW_lya_hetdex[src] = det_flux / det_cont
            EW_lya_hetdex_err[src] = (
                (det_table['flux_err'][wh_det] / det_cont) ** 2
                + (det_flux * det_cont ** -2 * det_table['continuum_err'][wh_det]) ** 2
            ) ** 0.5



In [None]:
import os.path as op
from astropy.io import fits

# HETDEX spectra: re-opens the same file (same path and version) that the
# first cell already opened into `hdu_hetdex_spec`.
version = 'v3.2'
path_to_cat = '/home/alberto/almacen/HETDEX_catalogs/hetdex_source_catalog_1'
spec_file = op.join(path_to_cat, f'hetdex_sc1_spec_{version}.fits')
hdu_hetdex_spec = fits.open(spec_file)

In [None]:
# Quick check: shape of the RA column in extension 1 of the HETDEX spectra file
np.shape(hdu_hetdex_spec[1].data['RA'])

In [None]:
# Sources to plot, given as catalog indices (values of `selection['src']`)
src_to_plot = [
    13871, 402, 33296, 46964, 51714,  # Good LAEs
    7234, 7653, 24054, 3453,  # Bad QSOs
    23401,  # low-z Galaxies
]

# Row of each of those sources inside `selection`
src_column = selection['src'].to_numpy()
n_sel_list = [np.flatnonzero(src_column == s)[0] for s in src_to_plot]

# 5 rows x 2 halves; each source uses three panels (spectrum + two images),
# and column 3 is an invisible spacer between the two halves.
fig, axs = plt.subplots(
    5, 7, figsize=(12, 9),
    gridspec_kw={'width_ratios': [1.3, 0.4, 0.4, 0.1, 1.3, 0.4, 0.4]},
)
fig.subplots_adjust(hspace=0, wspace=0.05)
for spacer_ax in axs[:, 3]:
    spacer_ax.set_visible(False)

# Rest-frame wavelengths (Angstrom) of the spectral features marked on the
# spectrum panels. Order matters: `_plot_line_markers` staggers the labels
# by position in this dict.
Lyb_str = r'Ly$\beta$'
rest_frame_lines = {
    'CIV': 1549.48,
    f'OVI\n+\n{Lyb_str}': 1033.82,
    'CIII': 1908.73,
    'MgII': 2799.12,
    'CII': 2326.0,
    'OII': (3727.092 + 3729.875) * 0.5,
    r'H$\beta$': 4862.68,
    'OIII': (4960.295 + 5008.240) * 0.5,
    r'H$\alpha$': 6564.61,
}
# UV lines marked on the LAE and QSO panels (the first five entries above)
uv_line_names = list(rest_frame_lines)[:5]
# Redshift at which lines are marked for the last panel (low-z galaxy)
z_att = 0.54


def _sdss_spec_name(sp_obj_id):
    """Return the 'spec-PLATE-MJD-FIBER.fits' file name encoded in an SDSS specObjID.

    Bit layout: plate in bits 50-63, fiberID in bits 38-49, MJD - 50000 in bits 24-37.
    """
    sp_obj_id = int(sp_obj_id)
    plate = (sp_obj_id >> 50) & 0x3FFF
    fiber = (sp_obj_id >> 38) & 0xFFF
    mjd = ((sp_obj_id >> 24) & 0x3FFF) + 50000
    return f'spec-{plate:04d}-{mjd:05d}-{fiber:04d}.fits'


def _plot_line_markers(ax, features, text_h, alt_text_h, alt_text_h_2):
    """Draw a dotted vertical line plus a label for each (name -> wavelength) feature.

    Labels are staggered to reduce overlap: the 2nd entry is written at
    `alt_text_h`, the 5th and 8th at `alt_text_h_2`, the rest at `text_h`.
    Features outside 3700-9000 Angstrom are skipped.
    """
    for jj, (name, w_value) in enumerate(features.items()):
        if jj == 1:
            this_text_h = alt_text_h
        elif jj in (4, 7):
            this_text_h = alt_text_h_2
        else:
            this_text_h = text_h
        if w_value < 3700 or w_value > 9000:
            continue
        ax.axvline(w_value, color='dimgray', linestyle=':')
        ax.text(w_value + 30, this_text_h, name,
                color='k', fontsize=8.5, in_layout=True, zorder=999999)


def _thumb_label(ax, text):
    """Write bold, white-outlined `text` near the lower-left corner of an image panel."""
    ax.text(ax.get_xlim()[0] + 2, ax.get_ylim()[0] - 1, text,
            fontsize=11, c='k', weight='bold',
            path_effects=[patheffects.withStroke(linewidth=2, foreground='w')])


for iii, n in enumerate(n_sel_list):
    # Panels 1-5 use the left half (columns 0-2), panels 6-10 the right half
    # (columns 4-6).
    if iii < 5:
        row, first_col = iii, 0
    else:
        row, first_col = iii - 5, 4
    ax0, ax1, ax2 = axs[row, first_col:first_col + 3]

    # ax1/ax2 hold the image cutouts: shift them down slightly and hide all
    # their ticks and tick labels.
    for ax in (ax1, ax2):
        newpos = np.array(ax.get_position())
        newpos[1] = newpos[1] - newpos[0]  # (x1, y1) -> (width, height)
        newpos[0, 1] -= 0.031
        ax.set_position(newpos.flatten())
        ax.tick_params(axis='both', bottom=False, top=False,
                       labelbottom=False, labeltop=False,
                       right=False, left=False,
                       labelright=False, labelleft=False)

    print(f'Plotting: {n + 1}')

    # Locate the source in the full catalog from its image position. This is
    # kept outside the SDSS try-block so that a failed match raises instead
    # of silently reusing the index from the previous iteration.
    matches = np.flatnonzero((sel_x_im[n] == x_im) & (sel_y_im[n] == y_im))
    if matches.size == 0:
        raise ValueError(f'Selected source n={n} not found in the catalog by (x_im, y_im)')
    where_mjj = matches[0]
    this_number = int(number[where_mjj])
    this_tile_id = int(tile_id[where_mjj])

    # SDSS spectrum: best effort, since not every source has one on disk.
    try:
        this_spObjID = sdss_xm_spObjID.to_numpy()[(this_number == sdss_xm_num)
                                                  & (this_tile_id == sdss_xm_tid)][0]
        spec_name = _sdss_spec_name(this_spObjID)
        print(spec_name)
        spec_path = f'{fits_dir}/{spec_name}'
        spec_sdss = Table.read(spec_path, hdu=1, format='fits')
        g_band_sdss = Table.read(spec_path, hdu=2, format='fits')['SPECTROFLUX']
        # g-band synthetic flux: entry 1 of SPECTROFLUX
        g_band_sdss = nanomaggie_to_flux(np.array(g_band_sdss)[0][1], 4750)
    except Exception:
        spec_sdss = None
        g_band_sdss = None

    # Quantities from the selection table
    src = selection['src'][n].astype(int)
    this_x_im = selection['x_im'][n].astype(int)
    this_y_im = selection['y_im'][n].astype(int)
    nb_sel = selection['nb_sel'][n]
    other_lines = selection['other_lines'][n]
    z_src = z_NB(nb_sel.astype(int))[0]
    z_SDSS = selection['SDSS_zspec'][n]

    # HETDEX spectrum, if the source has a HETDEX counterpart.
    # NOTE(review): assumes rows of the 'SPEC' HDU are aligned with the rows
    # of `source_table` -- confirm.
    if xm_hetdex_id[where_mjj] > 0:
        wh_hetdex = np.where(xm_hetdex_id[where_mjj] == source_table['source_id'])[0][0]
        spec_hetdex = hdu_hetdex_spec['SPEC'].data[wh_hetdex]
        spec_w_hetdex = hdu_hetdex_spec['WAVELENGTH'].data
        g_band_hetdex = mag_to_flux(source_table['gmag'][wh_hetdex], w_central[-3])
    else:
        spec_hetdex = spec_w_hetdex = g_band_hetdex = None

    # Other NBs flagged by the selection; presumably stored as a string
    # like '[12, 20]' (brackets and trailing commas are stripped).
    oth_raw_list = other_lines[1:-1].split()
    if oth_raw_list:
        oth_list = [int(item[:-1]) for item in oth_raw_list[:-1]] + [int(oth_raw_list[-1])]
    else:
        oth_list = []

    # Image cutouts around the source: r band (filter file 59) and the selected NB
    survey_name = 'jnep' if tile_id[where_mjj] == 2520 else 'minijpas'
    img_dir = f'/home/alberto/almacen/images_fits/{survey_name}'
    filenamer = f'{img_dir}/{tile_id[where_mjj]:0.0f}-59.fits'
    filenamenb = f'{img_dir}/{tile_id[where_mjj]:0.0f}-{nb_sel + 1}.fits'

    box_side = 16
    rows = slice(this_y_im - box_side, this_y_im + box_side + 1)
    cols = slice(this_x_im - box_side, this_x_im + box_side + 1)
    # Close each FITS file right away; copy the cutout so it does not depend
    # on the (possibly memory-mapped) file data after closing.
    with fits.open(filenamer) as hdul:
        im_r = np.array(hdul[1].data[rows, cols])
    with fits.open(filenamenb) as hdul:
        im_nb = np.array(hdul[1].data[rows, cols])

    # Scale by filter FWHM and exposure time. This is a constant factor per
    # image, so the autoscaled imshow rendering is unaffected by it.
    im_r = im_r / fwhm_Arr[-2] * bb_exp_time
    im_nb = im_nb / fwhm_Arr[nb_sel] * nb_exp_time

    # J-PAS photometry of the source
    plot_JPAS_source(pm_flx[:, src], pm_err[:, src],
                     e17scale=True, fs=15, mock_mode=False,
                     ax=ax0, bb_ms=8, nb_ms=6, BB_alpha=0.7)

    # Text heights, as fractions of the current y-range measured from the top
    y_lo, y_hi = ax0.get_ylim()
    y_span = y_hi - y_lo
    text_h = y_hi - 0.15 * y_span
    text_h_number = y_hi - 0.11 * y_span
    alt_text_h = y_hi - 0.33 * y_span
    alt_text_h_2 = y_hi - 0.25 * y_span

    # Selected NB (red dashed line)
    ax0.axvline(w_central[nb_sel], color='r', linestyle='--',
                zorder=9999, lw=1.5)
    # Panel identifier
    ax0.text(3150, text_h_number, rf'\bf{iii + 1})',
             fontsize=12, c='r')
    if iii < 5:
        ax0.text(w_central[nb_sel] + 30, text_h,
                 r'Ly$\alpha$', fontsize=8.5, color='k')
    # Other NBs flagged by the selection (green dashed lines)
    for oth_nb in oth_list:
        ax0.axvline(w_central[oth_nb], ls='--', c='g', zorder=-90)

    # Spectral features: at z_NB for panels 1-5 (LAEs), at the SDSS
    # spectroscopic redshift for panels 6-9 (QSOs; every line, CII included)
    # and at z_att for panel 10 (low-z galaxy).
    if iii < 5:
        features = {name: rest_frame_lines[name] * (1 + z_src) for name in uv_line_names}
    elif iii < 9:
        features = {name: rest_frame_lines[name] * (1 + z_SDSS) for name in uv_line_names}
    elif iii == 9:
        features = {name: w * (1 + z_att) for name, w in rest_frame_lines.items()}
    else:
        features = {}
    _plot_line_markers(ax0, features, text_h, alt_text_h, alt_text_h_2)

    # SDSS spectrum, scaled so its g-band flux matches pm_flx[-3]
    # (presumably the J-PAS g band).
    if g_band_sdss is not None and spec_sdss is not None:
        norm = pm_flx[-3, src] / g_band_sdss
        # Rebin the SDSS spectrum by a factor 5 for display
        spec_flx_sdss = rebin_1d_arr(spec_sdss['FLUX'] * norm, 5)
        spec_w_sdss = (10 ** spec_sdss['LOGLAM'])[::5][:len(spec_flx_sdss)]
        ax0.plot(spec_w_sdss, spec_flx_sdss,
                 c='dimgray', zorder=-99, alpha=0.7)

    # HETDEX spectrum, same normalization
    if spec_hetdex is not None:
        norm = pm_flx[-3, src] / g_band_hetdex
        # Rebin the HETDEX spectrum by a factor 5 for display
        spec_hetdex = rebin_1d_arr(spec_hetdex * norm, 5)
        spec_w_hetdex = spec_w_hetdex[::5][:len(spec_hetdex)]
        ax0.plot(spec_w_hetdex, spec_hetdex,
                 c='orange', zorder=-100, alpha=0.7)

    # Continuum estimate at the selected NB
    ax0.plot(w_central[nb_sel], cont_est_lya[nb_sel, src] * 1e17,
             ls='', marker='s', c='k', ms=6, zorder=999)

    # Image cutouts, each with the 3 arcsec diameter aperture drawn
    # (1.5 arcsec radius at 0.23 arcsec per pixel).
    ax1.imshow(im_r, cmap='binary')
    ax2.imshow(im_nb, cmap='binary')
    aper_r_px = 1.5 / 0.23
    for ax in (ax1, ax2):
        ax.add_patch(plt.Circle((box_side, box_side),
                                radius=aper_r_px, ec='g', fc='none'))
    _thumb_label(ax2, filter_labels[nb_sel])
    _thumb_label(ax1, f'r = {selection["r"][n]:0.1f}')

    # Zero line
    ax0.axhline(0, ls='-', c='dimgray', lw=1, zorder=-9999)

    # Source information text, placed to the right of the plotted x-range
    L_lya_str = r'$L_{\mathrm{Ly}\alpha}$'
    EW_lya_str = r'EW$_0^{\mathrm{Ly}\alpha}$'
    z_NB_str = r'$z_\mathrm{NB}$'
    z_spec_str = r'$z_\mathrm{spec}$'
    z_spec_value_str = f'{z_SDSS:0.2f}' if z_SDSS > 0 else '-'
    L_np = selection['L_lya'][n]
    L_err_np = selection['L_lya_err'][n]
    # Asymmetric errors in log L from the linear-luminosity error
    L_err_up = np.log10(10**L_np + L_err_np) - L_np
    L_err_down = L_np - np.log10(10**L_np - L_err_np)
    L_err_up_down_str = f'$^{{ + {L_err_up:0.2f}}}_{{ - {L_err_down:0.2f}}}$'
    erg_s = r'erg\,s$^{-1}$'
    src_info_str = (
        f'{L_lya_str} = {selection["L_lya"][n]:0.2f}'
        f'{L_err_up_down_str} {erg_s}'
        f'\n{EW_lya_str} = {selection["EW_lya"][n]:0.0f}'
        rf'$\pm${selection["EW_lya_err"][n]:0.0f} \AA'
        f'\n{z_NB_str} = {z_src:0.2f}, {z_spec_str} = {z_spec_value_str}'
    )
    # Re-read the limits: the spectra plotted above may have changed them.
    y_lo, y_hi = ax0.get_ylim()
    ax0.text(9900, y_hi - 0.375 * (y_hi - y_lo), src_info_str, fontsize=10)

    ax0.tick_params(labelsize=13, direction='in', which='both')
    ax0.set_xticks(np.arange(3000, 10000, 1000))
    ax0.yaxis.set_ticks_position('both')
    ax0.xaxis.set_ticks_position('both')
    ax0.set_xlabel('')
    ax0.set_ylabel('')
    ax0.set_xlim(3100, 9800)

    # Axis labels only on the bottom row and on the middle-left panel
    if iii in (4, 9):
        ax0.set_xlabel(r'$\lambda_\mathrm{obs}$ [\AA]',
                       fontsize=17)
    if iii == 2:
        ax0.set_ylabel(r'$f_\lambda\cdot10^{17}$ [erg cm$^{-2}$ s$^{-1}$ \AA$^{-1}$]',
                       fontsize=17)

# Write the figure, creating the output directory if it does not exist yet
# (savefig does not create missing directories).
os.makedirs('figures', exist_ok=True)
fig.savefig('figures/LAEs_Examples.pdf', bbox_inches='tight',
            pad_inches=0.1, facecolor='w')