In [None]:
from astropy.io import fits
from astropy.table import Table
import pandas as pd
import numpy as np
import threading
import time
import pickle

In [None]:
# Directory that holds the PAUS data products
path_to_paus_data = '/home/alberto/almacen/PAUS_data'

# Filter transmission curves; later cells read the 'w', 't' and 'tag' entries.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
with open(f'{path_to_paus_data}/paus_tcurves.pkl', 'rb') as tcurves_file:
    tcurves = pickle.loads(tcurves_file.read())

In [None]:
# Catalog file from which the QSO sample is selected
filename = '/home/alberto/almacen/SDSS_spectra_fits/DR16/DR16Q_Superset_v3.fits'
with fits.open(filename) as fread:
    # Table data of HDU 1, looked up once instead of once per column
    qso_cat = fread[1].data

    # Criteria in Queiroz et al. 2022 (kept for reference, not applied):
    # good_qso = (
    #     (qso_cat['ZWARNING'] == 0)
    #     & (qso_cat['SN_MEDIAN_ALL'] > 0)
    #     & (qso_cat['Z_CONF'] == 3)
    #     & ((qso_cat['CLASS_PERSON'] == 3) | (qso_cat['CLASS_PERSON'] == 30))
    # )

    # Selection actually applied: no redshift warning, positive median S/N
    # and flagged as QSO in IS_QSO_FINAL
    good_qso = (
        (qso_cat['ZWARNING'] == 0)
        & (qso_cat['SN_MEDIAN_ALL'] > 0)
        & (qso_cat['IS_QSO_FINAL'] > 0)
    )

    # Identifiers used later to build the name of each spectrum file
    plate = qso_cat['PLATE'][good_qso]
    mjd = qso_cat['MJD'][good_qso]
    fiber = qso_cat['FIBERID'][good_qso]

    # Redshift column Z_VI of the selected sources
    z_Arr = qso_cat['Z_VI'][good_qso]

    # Drop the extra reference so that no handle on the HDU data outlives
    # the open file
    del qso_cat

    # count_nonzero counts in C; the builtin sum() would iterate the whole
    # mask element by element in Python
    print(f'Good QSOs: {np.count_nonzero(good_qso)}')

In [None]:
# Redshifts of the selected sources as a single-column DataFrame
z = pd.DataFrame(np.reshape(z_Arr, (-1, 1)))
# Optional export, currently disabled:
# z.to_csv('/home/alberto/Desktop/z.csv', header=['z'])

In [None]:
# Wavelength and transmission samples of every filter curve
wt_Arr = [np.array(w) for w in tcurves['w']]
t_Arr = [np.array(t) for t in tcurves['t']]

# Number of filters, taken from the filter tags instead of a hard-coded 46
phot_len = len(tcurves['tag'])

# np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists
_trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz

# Keep only the part of each curve where the transmission exceeds 0.05
# (original note: "for performance and bugs").
# The cut at the blue limit of each spectrum is NOT done here: synth_phot
# applies it per spectrum.
t_int_Arr = []
for fil in range(phot_len):
    cut_t_curve = (t_Arr[fil] > 0.05)
    wt_Arr[fil] = wt_Arr[fil][cut_t_curve]
    t_Arr[fil] = t_Arr[fil][cut_t_curve]

    w = wt_Arr[fil]
    t = t_Arr[fil]
    # Normalisation of the filter: integral of w * t over wavelength
    t_int_Arr.append(_trapezoid(w * t, w))

# Indices of the filters evaluated by synth_phot (all of them)
which_filters = np.arange(phot_len)
def synth_phot(SEDs, w_Arr):
    """Compute the synthetic photometry of one spectrum in every filter.

    For each filter the spectrum is interpolated onto the filter's wavelength
    grid, and the integral of ``w * t * SED`` over wavelength is divided by
    the precomputed normalisation ``t_int_Arr[fil]`` (integral of ``w * t``).

    The shared transmission curves in ``t_Arr`` are never modified: the
    per-spectrum blue cut is applied to a copy, so calls do not affect each
    other (also when running from several threads).

    Parameters
    ----------
    SEDs : array-like
        Flux values of the spectrum, sampled at ``w_Arr``.
    w_Arr : array-like
        Wavelengths of the spectrum samples, on the same wavelength scale as
        the filter curves and increasing (as ``np.interp`` requires).

    Returns
    -------
    numpy.ndarray
        One value per entry of ``which_filters``. Filters that reach beyond
        the red end of ``w_Arr`` come out non-finite, because the spectrum is
        extrapolated with ``inf`` there.
    """
    pm = np.zeros(phot_len)

    # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists
    integrate = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz

    # Bluest wavelength covered by the spectrum (same for every filter)
    w_min = np.min(w_Arr)

    for fil in which_filters:
        w = wt_Arr[fil]

        # Set t to zero where w < min(w_Arr). np.where builds a new array, so
        # the module-level curve t_Arr[fil] stays intact for later spectra.
        t = np.where(w < w_min, 0.0, t_Arr[fil])

        wt = w * t

        # left=1e99 is harmless because t is zero there; right=np.inf marks
        # filters that extend past the red end of the spectrum
        sed_interp = np.interp(w, w_Arr, SEDs, left=1e99, right=np.inf)
        sed_int = integrate(wt * sed_interp, w)

        pm[fil] = sed_int / t_int_Arr[fil]
    return pm[which_filters]

In [None]:
# Directory that holds the individual spectrum files
fits_dir = '/home/alberto/almacen/SDSS_spectra_fits/DR16/QSO'
def do_qso_phot(plate, mjd, fiber, pm_SEDs, slc):
    """Fill the columns ``slc`` of ``pm_SEDs`` with synthetic photometry.

    For every (plate, mjd, fiber) triplet the spectrum file
    ``spec-PLATE-MJD-FIBER.fits`` is read from ``fits_dir`` and its synthetic
    photometry is added to the matching column of ``pm_SEDs``.

    Parameters
    ----------
    plate, mjd, fiber : sequences of int
        Identifiers of the spectra of this chunk, in column order.
    pm_SEDs : numpy.ndarray
        2-D output array (filters x sources), modified in place.
    slc : slice or index array
        Columns of ``pm_SEDs`` that correspond to this chunk. Any valid
        1-D NumPy index works, not only a basic slice.

    Returns
    -------
    None
        Errors raised while reading a spectrum propagate to the caller.
    """
    # Absolute column index of every source in this chunk. Writing through
    # pm_SEDs[:, slc][:, src] only works when slc is a basic slice (a view);
    # with an index array it would silently update a temporary copy.
    columns = np.arange(pm_SEDs.shape[1])[slc]

    for src, (pl, mj, fi) in enumerate(zip(plate, mjd, fiber)):
        spec_name = f'{fits_dir}/spec-{pl:04d}-{mj:05d}-{fi:04d}.fits'

        spec = Table.read(spec_name, hdu=1, format='fits')

        # The range of SDSS is 3561-10327 Angstroms. synth_phot zeroes the
        # filter transmission blueward of the spectrum and returns a
        # non-finite value for filters reaching past its red end.
        # FLUX is scaled by 1e-17 (presumably the catalog flux unit -- TODO
        # confirm) and LOGLAM is log10 of the wavelength.
        pm_SEDs[:, columns[src]] += synth_phot(spec['FLUX'] * 1e-17, 10 ** spec['LOGLAM'])

In [None]:
# Serial run of do_qso_phot on one chunk of the catalog
N_src = len(mjd)
N_thr = 16
N_src_thr, rem = divmod(N_src, N_thr)

pm_SEDs = np.zeros((46, N_src))

# Build the chunk boundaries exactly as the threaded cell does; the last
# chunk also takes the remainder of the division
for thr_i in range(N_thr):
    chunk_start = thr_i * N_src_thr
    chunk_stop = chunk_start + N_src_thr
    if thr_i == N_thr - 1:
        chunk_stop += rem
    slc = slice(chunk_start, chunk_stop)
    args = (plate[slc], mjd[slc], fiber[slc], pm_SEDs, slc)

# NOTE(review): this call sits outside the loop, so only the last chunk built
# above is processed here, and the following cell re-creates pm_SEDs anyway.
# Looks like a serial smoke test -- confirm whether it should be kept.
do_qso_phot(*args)

In [None]:
# Synthetic photometry of the whole catalog, split across worker threads
N_src = len(mjd)
# One row per filter, one column per source
pm_SEDs = np.zeros((phot_len, N_src))

# Divide in 16 threads (threading.Thread workers, not processes)
N_thr = 16
N_src_thr, rem = divmod(N_src, N_thr)

# Keep the Thread objects: completion is tracked on these workers only,
# instead of inferring it from the global thread count of the kernel
workers = []
for thr_i in range(N_thr):
    chunk_start = thr_i * N_src_thr
    chunk_stop = chunk_start + N_src_thr
    if thr_i == N_thr - 1:
        # The last chunk also takes the remainder of the division
        chunk_stop += rem
    slc = slice(chunk_start, chunk_stop)
    args = (plate[slc], mjd[slc], fiber[slc], pm_SEDs, slc)
    worker = threading.Thread(target=do_qso_phot, args=args)
    worker.start()
    workers.append(worker)

# NOTE(review): an exception inside a worker ends that thread only; the
# columns it had not reached yet stay at zero.
t0 = time.time()
while any(worker.is_alive() for worker in workers):
    # A source counts as done once its column holds any non-zero value
    N_done = np.count_nonzero(np.any(pm_SEDs != 0, axis=0))
    time_i = time.time() - t0
    # Guard against a zero elapsed time on coarse clocks
    speed = N_done / time_i if time_i > 0 else 0.0
    print(f'{N_done} / {N_src},\tspeed = {speed:0.1f} s^-1,\telapsed = {time_i:0.1f} s')
    time.sleep(5)

# All workers have finished at this point; join() just reaps them
for worker in workers:
    worker.join()

In [None]:
# Write the catalog: one row per source, one column per filter followed by
# the mjd, plate and fiber identifiers
hdr = tcurves['tag'] + ['mjd', 'plate', 'fiber']

# Every non-finite value (NaN or +-inf) is stored as +inf
pm_SEDs[~np.isfinite(pm_SEDs)] = np.inf

savedir = '/home/alberto/almacen/PAUS_data'

# Identifier arrays as column vectors, so they can be stacked next to the
# (sources x filters) photometry
id_columns = [ident.reshape(-1, 1) for ident in (mjd, plate, fiber)]
out_table = pd.DataFrame(data=np.hstack([pm_SEDs.T] + id_columns))
out_table.to_csv(f'{savedir}/PAUS-PHOTOSPECTRA_QSO_Superset_DR16_v2.csv', header=hdr)

print('\nCatalog saved\n')