In [None]:
import numpy as np
import matplotlib.pyplot as plt
from my_functions import *
import pickle as pkl
from three_filter import three_filter_method
from time import time

In [None]:
## Load the source catalog
# `cat` is used below as a mapping from column name to per-source arrays
# (e.g. cat['flx'], cat['flx_err'], cat['odds']).
# NOTE: unpickling can execute arbitrary code; only load files you trust.
catalog_path = 'pkl/cat_flambda_photoz_gaia.pkl'
with open(catalog_path, mode='rb') as file:
    cat = pkl.load(file)

In [None]:
# Load the filter transmission curves for the tags returned by load_filter_tags().
# Later cells index the result as tcurves['t'][i] and tcurves['w'][i] for filter i
# (presumably transmission and wavelength samples -- confirm in my_functions).
tcurves = load_tcurves(load_filter_tags())

In [None]:
# Filter indices used in this notebook:
#   nb_ind_arr -> the narrow-band filters under consideration (8 through 15)
#   bb_ind     -> the gSDSS broad band, counted from the end of the filter list
nb_ind_arr = list(range(8, 16))
bb_ind = -3

In [None]:
# Function to drop sources
def make_masks(cat, nb_ind):
    """Build the boolean mask of sources to KEEP for one narrow band.

    A source is kept only if it passes all three cuts:

    * photo-z odds below 0.98,
    * no significant Gaia parallax / proper motion, or no Gaia measurement,
    * relative flux error below 0.2 (i.e. SNR > 5) in filter ``nb_ind``.

    Parameters
    ----------
    cat : mapping of str -> array
        Catalog providing 'odds', 'parallax', 'parallax_err', 'pmra',
        'pmra_err', 'pmdec', 'pmdec_err' (one entry per source) and
        'flx_err' (indexed as ``[source, filter]``).
    nb_ind : int
        Filter column of ``cat['flx_err']`` on which the SNR cut is applied.

    Returns
    -------
    Boolean array with one entry per source, True where every cut is passed.
    """
    # Drop sources with high photo-z odds (keep odds < 0.98)
    mask_pz_odds = cat['odds'] < 0.98

    # Drop sources that move according to Gaia. The three signal-to-noise
    # ratios are combined in quadrature; sqrt(27) = sqrt(3 * 3**2) is the value
    # reached when each of the three quantities sits at 3 sigma.
    parallax_sn = np.abs(cat['parallax'] / cat['parallax_err'])
    pmra_sn = np.abs(cat['pmra'] / cat['pmra_err'])
    pmdec_sn = np.abs(cat['pmdec'] / cat['pmdec_err'])
    mask_pmotion = (
        (np.sqrt(parallax_sn**2 + pmra_sn**2 + pmdec_sn**2) < 27.**0.5)
        # Comparisons against NaN are False, so sources lacking any Gaia
        # measurement have to be let through explicitly.
        | ( np.isnan(parallax_sn) | np.isnan(pmra_sn) | np.isnan(pmdec_sn) )
    )

    # Drop sources with SNR < 5 in the selected NB (relative error >= 0.2)
    mask_snr = cat['flx_err'][:,nb_ind] < 0.2

    # NOTE(review): a gSDSS brightness cut used to be computed here but was
    # never combined into the returned mask, so it has been removed as dead
    # code. If a bright-source cut is wanted, add it explicitly and check both
    # the comparison direction and the flux units first.
    mask_total = mask_pz_odds & mask_pmotion & mask_snr
    return mask_total

In [None]:
# The model function
def model_f(x, m, b):
    """Evaluate the straight line ``m*x + b``.

    ``x`` is the abscissa (scalar or array), ``m`` the slope and ``b`` the
    intercept. The result has the shape produced by broadcasting the inputs.
    """
    slope_term = m * x
    return slope_term + b

In [None]:
# Continuum estimation around one narrow band, for the sources kept by make_masks.
nb_ind = 11
mask = make_masks(cat, nb_ind)

# Fluxes are rescaled by 1e-19 (presumably to erg cm^-2 s^-1 A^-1, the unit on
# the plot labels -- confirm). 'flx_err' is multiplied by the flux, i.e. it is
# treated as a relative error, giving errors in the same units as pm_flx.
pm_flx = cat['flx'][mask] * 1e-19
pm_err = cat['flx_err'][mask] * pm_flx

filters_tags = load_filter_tags()
fwhm_nb = nb_fwhm(tcurves, nb_ind, True)
nb_fwhm_Arr = [nb_fwhm(tcurves, idx, True) for idx in np.arange(len(filters_tags))]
# Reuse the curves already held in `tcurves` (also used two lines above)
# instead of re-loading the filter tags and transmission curves from scratch.
w_central = np.array(central_wavelength(tcurves))
N_nb = 6 # Number of nb on each side of the central one
ew0min = 50

# Two independent continuum estimates at the selected narrow band:
# one from stacking neighbouring NBs, one from a linear fit with model_f.
cont_stack, cont_err_stack = stack_estimation(pm_flx.T, pm_err.T, nb_ind, N_nb, w_central, nb_fwhm_Arr, ew0min)
_, cf, cont_err_fit = nbex_cont_estimate(model_f, pm_flx, pm_err, nb_ind, w_central, N_nb, ew0min, fwhm_nb)

In [None]:
# 3-filter continuum estimate: for every source, three_filter_method is fed the
# narrow-band flux plus the fluxes of filters -3 and -2, and returns the
# coefficients A, B (used later as B + A * wavelength) with their errors.
N_sources = pm_flx.shape[0]
A = np.zeros(N_sources)
B = np.zeros(N_sources)
A_err = np.zeros(N_sources)
B_err = np.zeros(N_sources)

t0 = time()

# The filter curves and the line wavelength are the same for every source, so
# they are built once here rather than on each loop iteration.
t_NB = np.array(tcurves['t'][nb_ind])
w_NB = np.array(tcurves['w'][nb_ind])
t_LC = np.array(tcurves['t'][-3])
w_LC = np.array(tcurves['w'][-3])
t_LU = np.array(tcurves['t'][-2])
w_LU = np.array(tcurves['w'][-2])
w_EL = np.array(w_central[nb_ind])

for i in range(N_sources):
    # Progress counter, overwritten in place thanks to the carriage return
    print('{}/{}'.format(i+1, N_sources), end='\r')
    NB = pm_flx[i, nb_ind]
    BB_LC = pm_flx[i, -3]
    BB_LU = pm_flx[i, -2]
    NB_err = pm_err[i, nb_ind]
    BB_LC_err = pm_err[i, -3]
    BB_LU_err = pm_err[i, -2]
    _, A[i], B[i], A_err[i], B_err[i] = three_filter_method(
        NB, BB_LC, BB_LU,
        NB_err, BB_LC_err, BB_LU_err,
        t_NB, w_NB,
        t_LC, t_LU, w_LC, w_LU,
        w_EL
    )
# NOTE(review): this combines B_err with A * A_err. For a continuum evaluated
# as B + A * w (as the plotting cell does), uncorrelated error propagation
# would weight A_err by the wavelength w instead of by A -- confirm against
# the definition of the values returned by three_filter_method.
tf_err = (B_err**2 + A**2 * A_err**2)**0.5
t1 = time()
print('Elapsed: {} s'.format(t1-t0))

In [None]:
# Redshift at which a line with rest wavelength 1215.67 A (Lyman-alpha) falls
# on the selected narrow band: 1 + z = lambda_observed / lambda_rest.
# (The ratio used to be inverted, which yields a negative z and shrinks the
# equivalent-width threshold below instead of stretching it.)
z = w_central[nb_ind] / 1215.67 - 1
# Observed-frame equivalent width corresponding to the rest-frame minimum
ew = ew0min*(1 + z)
# A source is flagged when its NB flux exceeds the stacked continuum by more
# than the EW threshold allows, at > 3 sigma of the combined NB + continuum error.
line = (
    pm_flx[:, nb_ind] - cont_stack - (ew * cont_stack) / fwhm_nb
    > 3 * (cont_err_stack**2 + pm_err[:, nb_ind]**2)**0.5
)

# Number of flagged sources (displayed as the cell output)
len(np.where(line)[0])

In [None]:
# Widths returned by nb_fwhm for the last four filters (indices -4 to -1), kept
# in that order so that bb_fwhm[k] pairs with filter k for k in -4..-1.
bb_fwhm = [nb_fwhm(tcurves, bb_idx, True) for bb_idx in (-4, -3, -2, -1)]

In [None]:
# Plot the photometry and the three continuum estimates (NB stack, linear fit,
# 3-filter) for the first `max_plots` sources flagged in `line`.
max_plots = 20
bb_colors = ('purple', 'green', 'red', 'dimgray')  # colours for filters -4, -3, -2, -1
w_grid = np.linspace(4000, 6000, 1000)  # wavelength grid for the fitted lines

for i in np.where(line)[0][:max_plots]:
    pm = pm_flx[i]
    errors = pm_err[i]

    fig, ax = plt.subplots(figsize=(12,9))
    # NOTE(review): this slice stops at -3 while four trailing filters are drawn
    # as broad bands below, so filter -4 is plotted twice -- confirm whether
    # [:-4] was intended.
    ax.errorbar(w_central[:-3], pm[:-3], yerr=errors[:-3], fmt='.', c='gray')
    ax.scatter(w_central[nb_ind], pm[nb_ind], c='black')

    # Broad bands: square markers first, then thick error bars whose horizontal
    # extent is half the filter width on each side.
    for bb_idx, color in zip(range(-4, 0), bb_colors):
        ax.scatter(w_central[bb_idx], pm[bb_idx], c=color, marker='s')
    for bb_idx, color in zip(range(-4, 0), bb_colors):
        ax.errorbar(w_central[bb_idx], pm[bb_idx], xerr=bb_fwhm[bb_idx]/2, yerr=errors[bb_idx],
                    fmt='none', color=color, elinewidth=4)

    # Raw strings keep the LaTeX backslashes without invalid-escape warnings
    ax.set_xlabel(r'$\lambda\ (\AA)$', size=15)
    ax.set_ylabel(r'$f_\lambda$ (erg cm$^{-2}$ s$^{-1}$ $\AA^{-1}$)', size=15)

    # Continuum from the NB stack, drawn 10 A to the right of the NB for legibility
    ax.errorbar(w_central[nb_ind]+10, cont_stack[i], yerr=cont_err_stack[i],
                c='violet', marker='^', markersize=9,
                capsize=4, label='Stack NBs', elinewidth=2, capthick=2)

    # Continuum from the linear fit; line and marker share the same +20 A shift
    cont_fit_value = cf[i, 1] + cf[i, 0]*w_central[nb_ind]
    ax.plot(w_grid+20,
            cf[i, 1] + cf[i, 0]*w_grid,
            c='saddlebrown', linestyle='dashed')
    ax.errorbar(w_central[nb_ind]+20, cont_fit_value, yerr=cont_err_fit[i],
                c='saddlebrown', marker='*', markersize=9,
                capsize=4, label='Linear fit', elinewidth=2, capthick=2)

    # Continuum from the 3-filter method; line and marker share the same +40 A shift
    cont_fit_value_3 = B[i] + A[i]*w_central[nb_ind]
    ax.plot(w_grid+40,
            B[i] + A[i]*w_grid,
            c='slateblue', linestyle='dashed')
    ax.errorbar(w_central[nb_ind] + 40, cont_fit_value_3, yerr=tf_err[i],
                c='slateblue', marker='*', markersize=9,
                capsize=4, label='3-filter', elinewidth=2, capthick=2)

    plt.legend()

    # Catalog identifier of the plotted source
    print(cat['number'][mask][i])

    plt.show()
    # Release the figure so a long run of plots does not accumulate in memory
    plt.close(fig)