In [None]:
import sys
sys.path.insert(0, '..')

import glob

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams.update({'font.size': 16})

import numpy as np

from astropy.io import fits
from astropy.visualization import ZScaleInterval
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
from astropy.nddata import Cutout2D

from scipy import stats
from scipy.ndimage import gaussian_filter

import lmfit

# Central wavelengths of the 40 narrow bands: 455, 465, ..., 845
# (NOTE(review): presumably nm). Indexed by the catalogue's 'lya_NB' column.
NB_wav_Arr = np.arange(455, 855, 10, dtype=int)

In [None]:
# Visually inspected PAUS LAE selection (table in FITS extension 1).
# fits.getdata opens, reads and closes the file, so no handle stays open for the session.
vi_cat = fits.getdata(
    '/home/alberto/almacen/PAUS_data/catalogs/PAUS_LAE_selection_visual_insp_AT.fits', ext=1
)

# Alternative, stricter selection kept for reference:
# vi_cat = vi_cat[vi_cat['is_LAE_VI']  & (vi_cat['lya_NB'] > 3) & (vi_cat['L_lya'] < 49)]
# Keep sources with Lya NB index > 0 and L_lya < 49
# (NOTE(review): L_lya presumably log10 luminosity; rows with NaN L_lya are dropped since NaN < 49 is False).
vi_cat = vi_cat[(vi_cat['lya_NB'] > 0) & (vi_cat['L_lya'] < 49)]
print(len(vi_cat))

In [None]:
# Collect the per-image PSF FWHMs of every source, for its Lya NB and for its
# continuum NB set. Each list entry is a 1D array read from a psf_fwhm.txt file.
psf_fwhm_root = '/home/alberto/almacen/PAUS_data/cutouts/PSF_FWHM'
lya_fwhm_list = []
cont_fwhm_list = []

for src in range(len(vi_cat)):
    lya_NB = int(vi_cat['lya_NB'][src])
    ref_id = int(vi_cat['ref_id'][src])
    NB_int_wav = NB_wav_Arr[lya_NB]
    this_dir = f'{psf_fwhm_root}/ID{ref_id}'

    # Lya band. np.atleast_1d: genfromtxt returns a 0-d array when the file holds a
    # single value, which would break the per-image indexing in the PSF-matching cell.
    lya_fwhm_list.append(
        np.atleast_1d(np.genfromtxt(f'{this_dir}/NB{NB_int_wav}/psf_fwhm.txt'))
    )

    # Continuum: NB indices lya_NB-6 .. lya_NB-2 and lya_NB+2 .. lya_NB+6, clipped
    # at 0 below and to at most 38 above (NOTE(review): index 39 is never used —
    # presumably deliberate; confirm). The direct neighbours lya_NB +/- 1 are skipped.
    NB_Arr_cont = np.concatenate(
        [np.arange(max(0, lya_NB - 6), max(0, lya_NB - 1)),
         np.arange(min(lya_NB + 2, 39), min(lya_NB + 6 + 1, 39))]
    )
    NB_wav_Arr_cont = NB_wav_Arr[NB_Arr_cont]
    # Directory name such as 'NB455_465_475' built from the continuum wavelengths.
    cont_dir = 'NB' + '_'.join(NB_wav_Arr_cont.astype(str))
    cont_fwhm_list.append(
        np.atleast_1d(np.genfromtxt(f'{this_dir}/{cont_dir}/psf_fwhm.txt'))
    )

In [None]:
cut_size = 50

cutouts_dir = '/home/alberto/almacen/PAUS_data/cutouts'


def _load_normalized_cutouts(cutout_path, coords, size, half_box=5):
    """Cut and normalize every single-epoch image listed in ``{cutout_path}/img_list.txt``.

    Parameters
    ----------
    cutout_path : str
        Directory holding ``img_list.txt`` and the FITS images it lists.
    coords : SkyCoord
        Sky position each cutout is centred on.
    size : int
        Side of the square cutout in pixels; parts falling off the image are filled with 0.
    half_box : int
        Half-width of the central box whose summed flux normalizes each cutout
        (5 -> an 11x11 box).

    Returns
    -------
    list of 2D arrays
        One per listed image, in list order. A cutout whose central box sums to
        exactly 0 is returned as all-NaN so the nan-aware stacking skips it.
    """
    # atleast_1d: genfromtxt returns a 0-d array for a one-line list, which cannot be iterated.
    img_entries = np.atleast_1d(np.genfromtxt(f'{cutout_path}/img_list.txt', dtype=str))

    normalized = []
    for entry in img_entries:
        # Only the second '/'-separated component (the file name) is used.
        # NOTE(review): assumes entries look like '<subdir>/<file>'.
        fname = entry.split('/')[1]
        # Context manager so every file is closed right away; opening hundreds of
        # images without closing them can exhaust the process's file handles.
        with fits.open(f'{cutout_path}/{fname}') as hdul:
            cut = Cutout2D(hdul[0].data, coords, size=(size, size),
                           wcs=WCS(hdul[0]), mode='partial', fill_value=0.).data
            cy, cx = cut.shape[0] // 2, cut.shape[1] // 2
            norm = np.sum(cut[cy - half_box: cy + half_box + 1,
                              cx - half_box: cx + half_box + 1])
            if norm == 0:
                # Dividing would produce inf/NaN; mark the whole cutout as unusable instead.
                normalized.append(np.full(cut.shape, np.nan))
            else:
                # The division creates a fresh array, so nothing references the file after close.
                normalized.append(cut / norm)
    return normalized


lya_img_list = []
cont_img_list = []

for src in range(len(vi_cat)):
    lya_NB = int(vi_cat['lya_NB'][src])
    ref_id = int(vi_cat['ref_id'][src])
    NB_int_wav = NB_wav_Arr[lya_NB]
    src_dir = f'{cutouts_dir}/single_epoch/single_epoch_cutouts_{ref_id}'
    # The source position is the same for all of its images: build it once.
    coords = SkyCoord(vi_cat['RA'][src], vi_cat['DEC'][src], unit='deg')

    # Lya NB
    lya_img_list.append(_load_normalized_cutouts(f'{src_dir}/{NB_int_wav}', coords, cut_size))

    # Continuum NBs: indices lya_NB-6 .. lya_NB-2 and lya_NB+2 .. lya_NB+6, clipped to [0, 39).
    NB_Arr_cont = np.concatenate(
        [np.arange(max(0, lya_NB - 6), max(0, lya_NB - 1)),
         np.arange(min(lya_NB + 2, 39), min(lya_NB + 6 + 1, 39))]
    )
    NB_wav_Arr_cont = NB_wav_Arr[NB_Arr_cont]
    # NOTE(review): the on-disk directory name is numpy's str() of the array, e.g.
    # '[455 465 475 485 495]', so it depends on numpy's print options.
    cont_img_list.append(
        _load_normalized_cutouts(f'{src_dir}/{NB_wav_Arr_cont}', coords, cut_size)
    )

In [None]:
# PSF homogenization: convolve every image with a Gaussian so that all of them end
# up with the same Gaussian PSF width (the worst one allowed).
# TODO: check for pathologically bad PSFs and remove them.

FWHM_TO_SIGMA = 2.35482  # Gaussian FWHM / sigma = 2 * sqrt(2 ln 2)
# NOTE(review): presumably arcsec/pixel; the fit cells use pixel_size = 0.4293497 — confirm which is right.
PSF_PIXEL_SCALE = 0.42

# worst_psf = np.nanmax(np.concatenate(lya_fwhm_list))
# Instead of the worst measured PSF, cap the target at FWHM = 2 (same units as the
# psf_fwhm.txt values) and discard every image whose PSF is broader.
worst_psf = 2 / FWHM_TO_SIGMA / PSF_PIXEL_SCALE  # target Gaussian sigma after the same conversion as each image


def _match_psf(img_lists, fwhm_lists, target_sigma):
    """Degrade every image to a Gaussian PSF of width ``target_sigma``, in place.

    ``img_lists[s][i]`` is paired with ``fwhm_lists[s][i]`` (one FWHM per image).
    Images whose converted PSF sigma exceeds ``target_sigma``, or is NaN, cannot be
    matched and are replaced by all-NaN arrays so nan-aware stacking ignores them.

    NOTE: the lists are overwritten in place, so re-running this cell smooths
    already-smoothed images again; re-run the cutout cell first.
    """
    for src_imgs, src_fwhms in zip(img_lists, fwhm_lists):
        src_fwhms = np.atleast_1d(src_fwhms)
        assert len(src_fwhms) == len(src_imgs), (
            f'{len(src_imgs)} images but {len(src_fwhms)} PSF FWHMs: image/PSF pairing would be wrong'
        )
        for img_i, img in enumerate(src_imgs):
            this_psf = src_fwhms[img_i] / FWHM_TO_SIGMA / PSF_PIXEL_SCALE
            if this_psf <= target_sigma:
                # Convolving Gaussians adds variances, so the kernel sigma is the quadrature difference.
                src_imgs[img_i] = gaussian_filter(img, (target_sigma**2 - this_psf**2) ** 0.5)
            else:
                # Broader than the target (or NaN): cannot be matched, drop it from the stack.
                src_imgs[img_i] = img * np.nan


_match_psf(lya_img_list, lya_fwhm_list, worst_psf)
_match_psf(cont_img_list, cont_fwhm_list, worst_psf)

In [None]:
# Mean-stack all (PSF-matched, normalized) cutouts; NaN images/pixels are ignored.
stacked_lya_img = np.nanmean(np.concatenate(lya_img_list), axis=0)
stacked_cont_img = np.nanmean(np.concatenate(cont_img_list), axis=0)

# Side-by-side z-scaled views of the two stacks.
fig, axes = plt.subplots(1, 2, figsize=(9, 5))
zscale = ZScaleInterval(contrast=0.1)
panels = zip(axes, (stacked_lya_img, stacked_cont_img), ('Ly-alpha NB', 'Continuum'))
for ax, stack, label in panels:
    lo, hi = zscale.get_limits(stack.flatten())
    ax.imshow(stack, vmin=lo, vmax=hi)
    ax.set_xlabel(label)

plt.show()

In [None]:
def gaussian_2d(x, y, centerx, centery, sigmax, sigmay, amplitude, c):
    """Axis-aligned elliptical 2D Gaussian on a constant background.

    Returns ``amplitude * exp(-[(x-centerx)^2/(2 sigmax^2) + (y-centery)^2/(2 sigmay^2)]) + c``,
    element-wise for scalars or broadcastable arrays. The argument names are the
    parameter names lmfit.Model derives from this signature, so keep them unchanged.
    """
    x_term = (x - centerx)**2 / (2 * sigmax**2)
    y_term = (y - centery)**2 / (2 * sigmay**2)
    return amplitude * np.exp(-(x_term + y_term)) + c

In [None]:
# Fit a 2D Gaussian + constant background to the stacked Ly-alpha image.
pixel_size = 0.4293497  # NOTE(review): presumably arcsec/pixel; the PSF-matching cell uses 0.42 — confirm

model = lmfit.Model(gaussian_2d, independent_vars=('x', 'y'))

# np.indices yields grids with the image's own (row, col) shape, so the coordinates
# stay aligned with the row-major flatten() of the data even for non-square stacks.
row_idx, col_idx = np.indices(stacked_lya_img.shape)
xx = col_idx.flatten() * pixel_size
yy = row_idx.flatten() * pixel_size
data = stacked_lya_img.flatten()
# Pixels where every input image was discarded are NaN after nanmean; lmfit's
# default nan_policy='raise' would abort, so fit only the finite pixels.
finite = np.isfinite(data)
assert finite.any(), 'stacked Lya image has no finite pixels'

# Data-driven starting values, clipped into the bounds. The centre starts at the image
# centre instead of exactly on its lower bound, where a bounded lmfit parameter can stall.
center_x0 = float(np.clip((stacked_lya_img.shape[1] - 1) / 2 * pixel_size, 8., 15))
center_y0 = float(np.clip((stacked_lya_img.shape[0] - 1) / 2 * pixel_size, 8., 15))
c0 = float(np.clip(np.median(data[finite]), -5, 10))
amplitude0 = float(np.clip(np.max(data[finite]) - c0, 0.0001, 2))

params = lmfit.Parameters()

params.add('centerx', value=center_x0, min=8., max=15)
params.add('centery', value=center_y0, min=8., max=15)
params.add('sigmax', value=2, min=0.01, max=5)
params.add('sigmay', value=2, min=0.01, max=5)
params.add('amplitude', value=amplitude0, min=0.0001, max=2)
params.add('c', value=c0, min=-5, max=10)

result = model.fit(data[finite], x=xx[finite], y=yy[finite], params=params)

lmfit.report_fit(result)


In [None]:
# Check the Lya fit: the central row of the stack against the best-fit 2D model
# evaluated along that same row.
# NOTE: `result` is whichever fit ran last — run this right after the Lya fit cell.
fig, ax = plt.subplots()

row = stacked_lya_img.shape[0] // 2
plot_xx = np.arange(stacked_lya_img.shape[1]) * pixel_size
ax.plot(plot_xx, stacked_lya_img[row, :], label='Ly-alpha stack, central row')

# Evaluating the full 2D model at this row's y keeps the comparison valid even
# when the fitted centery does not lie exactly on the plotted row.
fit_xx = np.linspace(0, stacked_lya_img.shape[1] * pixel_size, 200)
model_row = gaussian_2d(fit_xx, row * pixel_size, **result.best_values)
ax.plot(fit_xx, model_row, label='2D Gaussian fit')

ax.set_xlabel('x [pixel_size units]')
ax.set_ylabel('Normalized flux')
ax.legend()

plt.show()


In [None]:
# Fit a 2D Gaussian + constant background to the stacked continuum image.
pixel_size = 0.4293497  # NOTE(review): presumably arcsec/pixel; the PSF-matching cell uses 0.42 — confirm

model = lmfit.Model(gaussian_2d, independent_vars=('x', 'y'))

# np.indices yields grids with the image's own (row, col) shape, so the coordinates
# stay aligned with the row-major flatten() of the data even for non-square stacks.
row_idx, col_idx = np.indices(stacked_cont_img.shape)
xx = col_idx.flatten() * pixel_size
yy = row_idx.flatten() * pixel_size
data = stacked_cont_img.flatten()
# Pixels where every input image was discarded are NaN after nanmean; lmfit's
# default nan_policy='raise' would abort, so fit only the finite pixels.
finite = np.isfinite(data)
assert finite.any(), 'stacked continuum image has no finite pixels'

# Data-driven starting values, clipped into the bounds (the initial c=5 was far
# above the normalized pixel values).
center_x0 = float(np.clip((stacked_cont_img.shape[1] - 1) / 2 * pixel_size, 5., 15))
center_y0 = float(np.clip((stacked_cont_img.shape[0] - 1) / 2 * pixel_size, 5., 15))
c0 = float(np.clip(np.median(data[finite]), -5, 10))
amplitude0 = float(np.clip(np.max(data[finite]) - c0, 0.0001, 2))

params = lmfit.Parameters()

params.add('centerx', value=center_x0, min=5., max=15)
params.add('centery', value=center_y0, min=5., max=15)
params.add('sigmax', value=2, min=0.001, max=5)
params.add('sigmay', value=2, min=0.001, max=5)
params.add('amplitude', value=amplitude0, min=0, max=2)
params.add('c', value=c0, min=-5, max=10)

result = model.fit(data[finite], x=xx[finite], y=yy[finite], params=params)

lmfit.report_fit(result)


In [None]:
# Check the continuum fit: the central row of the stack against the best-fit 2D
# model evaluated along that same row.
# NOTE: `result` is whichever fit ran last — run this right after the continuum fit cell.
fig, ax = plt.subplots()

row = stacked_cont_img.shape[0] // 2
plot_xx = np.arange(stacked_cont_img.shape[1]) * pixel_size
ax.plot(plot_xx, stacked_cont_img[row, :], label='Continuum stack, central row')

# Evaluating the full 2D model at this row's y keeps the comparison valid even
# when the fitted centery does not lie exactly on the plotted row.
fit_xx = np.linspace(0, stacked_cont_img.shape[1] * pixel_size, 200)
model_row = gaussian_2d(fit_xx, row * pixel_size, **result.best_values)
ax.plot(fit_xx, model_row, label='2D Gaussian fit')

ax.set_xlabel('x [pixel_size units]')
ax.set_ylabel('Normalized flux')
ax.legend()

plt.show()
