# NRC-23 - Image Quality Verification by Filter   

## Notebook: Create an empirical PSF

**Author**: Matteo Correnti, STScI Scientist II
<br>
**Created**: November, 2021
<br>
**Last Updated**: February, 2022

## Table of contents
1. [Introduction](#intro)<br>
2. [Setup](#setup)<br>
    2.1 [Python imports](#py_imports)<br>
    2.2 [Plotting functions imports](#matpl_imports)<br>
    2.3 [PSF FWHM dictionary](#psf_fwhm)<br>
3. [Import images to analyze](#data)<br>
    3.1 [Select Detector/Filter to analyze](#sel_data)<br>
    3.2 [Display image](#display_data)<br>
    3.3 [Convert image units and apply pixel area map](#convert_data)<br>
4. [Create the empirical PSF model](#epsf_intro)<br>
    4.1 [Calculate the background](#bkg)<br>
    4.2 [Find sources in the image](#find)<br>
    4.3 [Select sources](#select)<br>
    4.4 [Create catalog of selected sources](#create_cat)<br>
    4.5 [Build the empirical PSFs](#build_epsf)<br>
    4.6 [Display the empirical PSFs](#display_epsf)<br>
5. [Create a single or grid of empirical PSFs](#epsf_intro2)<br>
    5.1 [Count stars in N x N grid](#count_stars)<br>
    5.2 [Build effective PSF (single or grid)](#epsf_grid)<br>

1.<font color='white'>-</font>Introduction <a class="anchor" id="intro"></a>
------------------

This notebook shows how to create an empirical PSF or a grid of empirical PSFs. The choice between a single PSF and a grid is dictated by the number of good PSF stars available in the image. 

<div class="alert alert-block alert-warning">
    <h3><u><b>Warning</b></u></h3>

The primary goal of this notebook is to show how to create an empirical PSF using PhotUtils (in particular the `EPSFBuilder` class). Note that an accurate empirical PSF (or a grid of empirical PSFs) can be derived from a single image **only** if the image contains a significant number of sources.  
</div>

2.<font color='white'>-</font>Setup <a class="anchor" id="setup"></a>
------------------

In this section we import all the necessary Python packages and we define some plotting parameters.

### 2.1<font color='white'>-</font>Python imports<a class="anchor" id="py_imports"></a> ###

In [None]:
import os 

import sys
import time
import copy

import glob as glob

import numpy as np

import pickle

from astropy.io import fits
from astropy.visualization import simple_norm
from astropy.table import Table
from astropy.nddata import NDData
from astropy.stats import sigma_clipped_stats, SigmaClip

from photutils.background import MMMBackground, MADStdBackgroundRMS, Background2D
from photutils.detection import DAOStarFinder
from photutils.psf import extract_stars
from photutils.psf import EPSFBuilder, GriddedPSFModel

from collections import OrderedDict

### 2.2<font color='white'>-</font>Plotting function imports<a class="anchor" id="matpl_imports"></a> ###

In [None]:
%matplotlib inline
from matplotlib import style, pyplot as plt
import matplotlib.patches as patches
import matplotlib.ticker as ticker

from mpl_toolkits.axes_grid1 import make_axes_locatable

plt.rcParams['image.cmap'] = 'viridis'
plt.rcParams['image.origin'] = 'lower'
plt.rcParams['axes.titlesize'] = plt.rcParams['axes.labelsize'] = 30
plt.rcParams['xtick.labelsize'] = plt.rcParams['ytick.labelsize'] = 20

font1 = {'family': 'helvetica', 'color': 'black', 'weight': 'normal', 'size': '12'}
font2 = {'family': 'helvetica', 'color': 'black', 'weight': 'normal', 'size': '20'}

### 2.3<font color='white'>-</font>PSF FWHM dictionary<a class="anchor" id="psf_fwhm"></a> ###

The dictionary contains the NIRCam point spread function (PSF) FWHM in pixels, from the [NIRCam Point Spread Function](https://jwst-docs.stsci.edu/near-infrared-camera/nircam-predicted-performance/nircam-point-spread-functions) JDox page. The FWHM values are derived from the analysis of the expected NIRCam PSFs simulated with [WebbPSF](https://www.stsci.edu/jwst/science-planning/proposal-planning-toolbox/psf-simulation-tool). 

The FWHM is used by the source-finding step to provide a first-order discrimination between real sources and spurious detections.

**Note**: this dictionary needs to be updated once the FWHM values measured for each detector during commissioning become available.
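
If a Gaussian sigma is needed instead of the FWHM (for instance to size a smoothing kernel), the standard relation for a Gaussian is sigma = FWHM / (2 sqrt(2 ln 2)). The snippet below is only a minimal sketch of that conversion: the helper `fwhm_to_sigma` and the two-entry dictionary `example_fwhm` are illustrative names that are not used elsewhere in this notebook.

```python
import math

# Conversion factor between the FWHM and sigma of a Gaussian: 2 * sqrt(2 * ln 2) ~ 2.3548
FWHM_TO_SIGMA = 2.0 * math.sqrt(2.0 * math.log(2.0))


def fwhm_to_sigma(fwhm_px):
    """Convert a Gaussian FWHM (pixels) into the corresponding sigma (pixels)."""
    return fwhm_px / FWHM_TO_SIGMA


# Two entries copied from the FWHM list of this section (illustrative subset)
example_fwhm = {'F070W': {'psf fwhm': 0.987}, 'F200W': {'psf fwhm': 2.141}}

sigma_f200w = fwhm_to_sigma(example_fwhm['F200W']['psf fwhm'])
print('F200W: FWHM = 2.141 px, sigma = {:.3f} px'.format(sigma_f200w))
```
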

In [None]:
filters = ['F070W', 'F090W', 'F115W', 'F140M', 'F150W2', 'F150W', 'F162M', 'F164N', 'F182M',
           'F187N', 'F200W', 'F210M', 'F212N', 'F250M', 'F277W', 'F300M', 'F322W2', 'F323N',
           'F335M', 'F356W', 'F360M', 'F405N', 'F410M', 'F430M', 'F444W', 'F460M', 'F466N', 'F470N', 'F480M']

psf_fwhm = [0.987, 1.103, 1.298, 1.553, 1.628, 1.770, 1.801, 1.494, 1.990, 2.060, 2.141, 2.304, 2.341, 1.340,
            1.444, 1.585, 1.547, 1.711, 1.760, 1.830, 1.901, 2.165, 2.179, 2.300, 2.302, 2.459, 2.507, 2.535, 2.574]

dict_utils = {filters[i]: {'psf fwhm': psf_fwhm[i]} for i in range(len(filters))}

3.<font color='white'>-</font>Import images to analyze<a class="anchor" id="data"></a>
------------------

We load all the images and create a dictionary that contains all of them, divided by detector and filter. This is useful to check which detectors and filters are available and to perform the analysis presented in this notebook on a detector/filter basis. 

We retrieve the NIRCam detector and filter from the image header. Note that for the LW channels, we transform the detector name derived from the header (**NRCALONG** or **NRCBLONG**) to **NRCA5** or **NRCB5**, for consistency with the other notebooks created for NRC-23. 
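
The header logic used in the cells below can be summarized with three small helpers. This is a sketch only: the function names are hypothetical, and plain dictionaries stand in for the FITS headers so that the example runs without any data files.

```python
def normalize_detector(detector):
    """Map LW header names to the NRC?5 convention; SW names pass through."""
    long_names = {'NRCALONG': 'NRCA5', 'NRCBLONG': 'NRCB5'}
    return long_names.get(detector, detector)


def effective_filter(filter_name, pupil_name):
    """Use the PUPIL element as the filter unless the pupil wheel is CLEAR."""
    return filter_name if pupil_name == 'CLEAR' else pupil_name


def is_long_wavelength(filter_name):
    """Same heuristic as the cell below: the two digits after 'F' exceed 24."""
    return float(filter_name[1:3]) > 24


# Stand-in header values (plain dicts instead of FITS headers)
headers = [
    {'DETECTOR': 'NRCB1', 'FILTER': 'F200W', 'PUPIL': 'CLEAR'},
    {'DETECTOR': 'NRCBLONG', 'FILTER': 'F444W', 'PUPIL': 'CLEAR'},
    {'DETECTOR': 'NRCALONG', 'FILTER': 'F322W2', 'PUPIL': 'F323N'},
]

for hdr in headers:
    det_name = normalize_detector(hdr['DETECTOR'])
    filt_name = effective_filter(hdr['FILTER'], hdr['PUPIL'])
    channel = 'LW' if is_long_wavelength(filt_name) else 'SW'
    print(det_name, filt_name, channel)
```
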

In [None]:
dict_images = {'NRCA1': {}, 'NRCA2': {}, 'NRCA3': {}, 'NRCA4': {}, 'NRCA5': {},
               'NRCB1': {}, 'NRCB2': {}, 'NRCB3': {}, 'NRCB4': {}, 'NRCB5': {}}

dict_filter_short = {}
dict_filter_long = {}

ff_short = []
det_short = []
det_long = []
ff_long = []
detlist_short = []
detlist_long = []
filtlist_short = []
filtlist_long = []

images_dir = '../Simulation/Pipeline_Outputs/Level2_Outputs'
images = sorted(glob.glob(os.path.join(images_dir, "*cal.fits")))

for image in images:

    im = fits.open(image)
    f = im[0].header['FILTER']
    d = im[0].header['DETECTOR']
    p = im[0].header['PUPIL']

    if d == 'NRCBLONG':
        d = 'NRCB5'
    elif d == 'NRCALONG':
        d = 'NRCA5'
    else:
        d = d
    
    if p == 'CLEAR':
        f = f
    else:
        f = p
    
    wv = float(f[1:3])

    if wv > 24:         
        ff_long.append(f)
        det_long.append(d)

    else:
        ff_short.append(f)
        det_short.append(d)   

    detlist_short = sorted(list(dict.fromkeys(det_short)))
    detlist_long = sorted(list(dict.fromkeys(det_long)))

    unique_list_filters_short = []
    unique_list_filters_long = []

    for x in ff_short:

        if x not in unique_list_filters_short:

            dict_filter_short.setdefault(x, {})
                 
    for x in ff_long:
        if x not in unique_list_filters_long:
            dict_filter_long.setdefault(x, {})   
            
    for d_s in detlist_short:
        dict_images[d_s] = copy.deepcopy(dict_filter_short)

    for d_l in detlist_long:
        dict_images[d_l] = copy.deepcopy(dict_filter_long)

    filtlist_short = sorted(list(dict.fromkeys(dict_filter_short)))
    filtlist_long = sorted(list(dict.fromkeys(dict_filter_long)))

print("Available Detectors for SW channel:", detlist_short)
print("Available Detectors for LW channel:", detlist_long)
print("Available SW Filters:", filtlist_short)
print("Available LW Filters:", filtlist_long)

In [None]:
for image in images:
    
    im = fits.open(image)
    f = im[0].header['FILTER']
    d = im[0].header['DETECTOR']
    p = im[0].header['PUPIL']

    if d == 'NRCBLONG':
        d = 'NRCB5'
    elif d == 'NRCALONG':
        d = 'NRCA5'
    else:
        d = d
    
    if p == 'CLEAR':
        f = f
    else:
        f = p

    if len(dict_images[d][f]) == 0:
        dict_images[d][f] = {'images': [image]}
    else:
        dict_images[d][f]['images'].append(image)


### 3.1<font color='white'>-</font>Select detector/filter to analyze<a class="anchor" id="sel_data"></a> ###

In [None]:
det = 'NRCB1'
filt = 'F200W'

num_images = len(dict_images[det][filt]['images'])
images_original = dict_images[det][filt]['images']

print('Number of images for detector {}, filter {}:'.format(det, filt), num_images)

### 3.2<font color='white'>-</font>Display the image<a class="anchor" id="display_data"></a> ###

To check that our images do not present artifacts and can be used in the analysis, we display them. 

In [None]:
if len(images_original) > 2:

    nn = int(np.sqrt(len(images_original)))
    figsize = (12, 12)
    fig, ax = plt.subplots(nn, nn, figsize=figsize)

    for ix in range(nn):
        for iy in range(nn):
            
            i = ix * nn + iy
            
            im = fits.open(dict_images[det][filt]['images'][i])
            data_sb = im[1].data
            
            ax[nn - 1 - ix, iy].set_xlabel('X [px]', fontsize=15)
            ax[nn - 1 - ix, iy].set_ylabel('Y [px]', fontsize=15)
            
            norm = simple_norm(data_sb, 'sqrt', percent=99.)
            ax[nn - 1 - ix, iy].set_title(det + ' - ' + filt +  ' - image' + str(i+1), fontsize=20)
            ax[nn - 1 - ix, iy].imshow(data_sb, norm=norm, cmap='Greys')
            
            plt.tight_layout()
else:
    
    plt.figure(figsize = (14, 14))
    nn = len(images_original)  # one panel per image (1 or 2 images in this branch)
    for i in range(nn):
        ax = plt.subplot(1, nn, i + 1)
        
        im = fits.open(dict_images[det][filt]['images'][i])
        data_sb = im[1].data
        
        ax.set_xlabel('X [px]')
        ax.set_ylabel('Y [px]')
        ax.set_title(det + ' - ' + filt +  ' - image' + str(i+1), fontsize=20)
        norm = simple_norm(data_sb, 'sqrt', percent=99.)
        ax.imshow(data_sb, norm=norm, cmap='Greys')
       
        plt.tight_layout()


### 3.3<font color='white'>-</font>Convert image units and apply pixel area map<a class="anchor" id="convert_data"></a> ###

The Level-2 and Level-3 images from the pipeline are in units of MJy/sr (i.e. a surface brightness). The actual unit of the image can be checked from the header keyword **BUNIT**. The scalar conversion constant is stored in the header keyword **PHOTMJSR**, which gives the conversion from DN/s to MJy/sr. For our analysis we convert back to DN/s, which is done by setting `units = True` in the function below.

For images that have not been transformed into a distortion-free frame (i.e. not drizzled), a correction must be applied to account for the different on-sky pixel size across the field of view. A pixel area map (PAM), which is an image where each pixel value describes that pixel's area on the sky relative to the native plate scale, is used for this correction. In stage 2 of the JWST pipeline, the PAM is copied into an image extension called **AREA** in the science data product. To apply the PAM correction, set the parameter `distorted = True` in the function below.
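
To make the two operations concrete, here is a minimal sketch on a toy 2 x 2 array. The `photmjsr` and `area` values are made up for illustration (real values come from the **PHOTMJSR** keyword and the **AREA** extension), and the mask follows the same idea as the function below: pixels that are NaN or exactly zero are flagged as unusable.

```python
import numpy as np

# Toy 2 x 2 "image" in MJy/sr; the NaN and 0 pixels stand in for pixels with no data
data_sb = np.array([[2.0, 4.0], [np.nan, 0.0]])
photmjsr = 2.0                                  # hypothetical PHOTMJSR value
area = np.array([[1.0, 1.01], [0.99, 1.0]])     # hypothetical pixel area map

data_dn = data_sb / photmjsr                    # MJy/sr -> DN/s
bad = np.isnan(data_dn) | (data_dn == 0)        # True where the pixel is unusable
data_corrected = data_dn * area                 # apply the PAM

print(data_corrected)
print(bad)
```
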

In [None]:
def convert_pam_image(image, units=True, distorted=True):
    im = fits.open(image)
    data_sb = im[1].data
    imh = im[1].header
    dq = im[3].data
    
    if units:
        
        print('Converting units from {0} to DN/s'.format(imh['BUNIT']))
        data = data_sb / imh['PHOTMJSR']
    
    else:
        
        data = data_sb
    
    zero_mask = np.where(data == 0,0,1)
    nan_mask  = np.where(np.isnan(data),0,1)
    zero_mask = nan_mask * zero_mask
    
    nan_mask = np.where(zero_mask == 0,True,False)
    data_mask = nan_mask
    
    if distorted:
        print('Analyzing Level-2 image - *cal.fits and correcting for PAM')
        area = im[4].data
        data_corrected = data * area
    else:
        print('Analyzing Level-3 image - *i2d.fits or not using PAM correction')
        data_corrected = data
    
    return data_corrected, data_mask

4.<font color='white'>-</font>Create the empirical PSF model<a class="anchor" id="epsf_intro"></a>
------------------

More information on the PhotUtils Effective PSF can be found [here](https://photutils.readthedocs.io/en/stable/epsf.html).

The process of creating an effective PSF can be summarized as follows:

* Find the stars in the image.
* Select the stars we want to use for building the effective PSF. 
* Build the effective PSF.

### 4.1<font color='white'>-</font>Calculate the background<a class="anchor" id="bkg"></a> ###

As background estimator we can adopt [MMMBackground](https://photutils.readthedocs.io/en/stable/api/photutils.background.MMMBackground.html#photutils.background.MMMBackground), which calculates the background of the whole image using the DAOPHOT MMM algorithm, i.e. a mode estimator of the form `(3 * median) - (2 * mean)`. 

However, when dealing with a variable background and/or the need to mask the regions where we have no data (for example, if we are analyzing an image with all the 4 NIRCam SW detectors, i.e. containing the chip gaps), we can set `var_bkg = True` and use a more complex algorithm that takes into account those issues.
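
As a quick illustration of the mode estimator quoted above, the sketch below applies `(3 * median) - (2 * mean)` to a short list of pixel values skewed towards bright values. It uses only the standard library and omits the sigma clipping that the photutils class also applies by default; `mode_estimate` is a hypothetical helper name.

```python
import statistics


def mode_estimate(values):
    """Mode estimator of the form 3 * median - 2 * mean."""
    return 3.0 * statistics.median(values) - 2.0 * statistics.mean(values)


# Sky pixels clustered around 10, with a tail towards brighter values
pixels = [9.0, 10.0, 10.0, 10.0, 11.0, 11.0, 12.0, 13.0]

print('mean   =', statistics.mean(pixels))
print('median =', statistics.median(pixels))
print('mode   =', mode_estimate(pixels))
```

For a distribution with a bright tail like this one, the estimate sits below both the mean and the median, closer to the most common sky value.
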

In [None]:
def calc_bkg(data, mask, var_bkg=False):
    
    bkgrms = MADStdBackgroundRMS()
    mmm_bkg = MMMBackground()

    if var_bkg:
        print('Using 2D Background')
        sigma_clip = SigmaClip(sigma=3.)
        

        bkg = Background2D(data, (100, 100), filter_size=(3, 3), sigma_clip=sigma_clip, bkg_estimator=mmm_bkg,
                           coverage_mask=mask, fill_value=0.0)

        data_bkgsub = data.copy()
        data_bkgsub = data_bkgsub - bkg.background

        median = bkg.background_median
        std = bkg.background_rms_median        
        print('Background median and rms using Background 2D:', median, std)


    else:

        std = bkgrms(data)
        bkg = mmm_bkg(data)
        print('Background median and rms:', bkg, std)
        data_bkgsub = data.copy()
        data_bkgsub -= bkg

    return data_bkgsub, std

### 4.2<font color='white'>-</font>Find sources in the image<a class="anchor" id="find"></a> ###

To find sources in the image, we use the [DAOStarFinder](https://photutils.readthedocs.io/en/stable/api/photutils.detection.DAOStarFinder.html) function. 

[DAOStarFinder](https://photutils.readthedocs.io/en/stable/api/photutils.detection.DAOStarFinder.html) detects stars in an image using the DAOFIND ([Stetson 1987](https://ui.adsabs.harvard.edu/abs/1987PASP...99..191S/abstract)) algorithm. DAOFIND searches images for local density maxima that have a peak amplitude greater than `threshold` (approximately; threshold is applied to a convolved image) and have a size and shape similar to the defined 2D Gaussian kernel.
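
In the function below, the value handed to `DAOStarFinder` is `threshold * std`, where `std` is a robust estimate of the background rms. The sketch below shows the idea behind such a robust estimate, using the median absolute deviation (MAD) scaled by approximately 1.4826, which is the usual factor for Gaussian noise. The helper `mad_std` and the pixel values are illustrative only.

```python
import statistics


def mad_std(values):
    """Robust standard deviation: ~1.4826 * median(|x - median(x)|)."""
    med = statistics.median(values)
    mad = statistics.median([abs(v - med) for v in values])
    return 1.4826 * mad


# Background-subtracted pixels: noise around 0 plus one bright star pixel
background_pixels = [-1.0, -0.5, 0.0, 0.0, 0.5, 1.0, 25.0]
nsigma = 5

std = mad_std(background_pixels)
threshold = nsigma * std
print('robust std = {:.3f}, detection threshold = {:.3f}'.format(std, threshold))
```

Because the MAD is based on medians, the single bright pixel barely affects the noise estimate, so the detection threshold is not inflated by the sources themselves.
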

In [None]:
def find_stars(image, det='NRCA1', filt='F070W', threshold=3, var_bkg=False):
    
    '''
    Parameters
    ----------

    image : str
        Path of the image to analyze.

    det : str
        Detector name (used only in the printed messages).

    filt : str
        Filter name, used to look up the PSF FWHM (in pixels) in dict_utils.

    threshold : float
        Detection threshold in units of the background rms; the absolute
        value passed to DAOStarFinder is threshold * std.

    var_bkg : bool
        Use Background2D (see description above)

    '''
    
    print('Finding stars --- Detector: {d}, Filter: {f}'.format(f=filt, d=det))
    
    fwhm_psf = dict_utils[filt]['psf fwhm']

    print('FWHM for the filter {f}:'.format(f=filt), fwhm_psf, "px")
    
    print('Converting from MJy/sr to DN/s and applying Pixel Area Map (if needed)')
    
    data_converted, data_mask = convert_pam_image(image, distorted=True)
    
    data_bkgsub, std = calc_bkg(data_converted, data_mask, var_bkg=var_bkg)
    
    # threshold is given in units of the background rms
    daofind = DAOStarFinder(threshold=threshold * std, fwhm=fwhm_psf)
    found_stars = daofind(data_bkgsub, mask=data_mask)
    
    print('')
    print('Number of sources found in the image:', len(found_stars))
    print('-------------------------------------')
    print('')
    
    return found_stars, data_bkgsub

In [None]:
tic = time.perf_counter()

found_stars_tot = []
data_bkgsub_tot = []



for i, image in enumerate(images_original):
    
    print('Working on image: {}'.format(i + 1))
    print('')
    found_stars, data_bkgsub = find_stars(image, det=det, filt=filt, threshold=5, var_bkg=True)
    
    found_stars_tot.append(found_stars)
    data_bkgsub_tot.append(data_bkgsub)

toc = time.perf_counter()
print("Elapsed Time for finding stars:", toc - tic)

### 4.3<font color='white'>-</font>Select sources<a class="anchor" id="select"></a> ###

We can adopt different methods to select sources we want to use to build an effective PSF. Here, we select objects applying a brightness cut (we do not want to include objects that are too faint) and using the `roundness2` and `sharpness` parameters provided in the [DAOStarFinder](https://photutils.readthedocs.io/en/stable/api/photutils.detection.DAOStarFinder.html) output catalog.

`roundness2` measures the ratio of the difference in the height of the best fitting Gaussian function in x minus the best fitting Gaussian function in y, divided by the average of the best fitting Gaussian functions in x and y.

`sharpness` measures the ratio of the difference between the height of the central pixel and the mean of the surrounding non-bad pixels in the convolved image, to the height of the best fitting Gaussian function at that point.

We derive the cuts from one image and verify that they are appropriate for all the other images.

**Note**: when we derive the selection cuts for a particular detector/filter, they are stored in a file at the end of section 4.3. Hence, below we check whether the file exists, in which case we can skip the selection from the images.  
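
The selection itself is a set of simple inequalities on the catalog columns. The sketch below applies them to a four-row toy catalog written as plain dictionaries; the column names and the cut keys match those used in this notebook, while the catalog values and the helper `passes_cuts` are made up for illustration.

```python
# Toy catalog with the DAOStarFinder column names used in this notebook
catalog = [
    {'id': 1, 'mag': -4.0, 'sharpness': 0.62, 'roundness2': 0.05},   # passes all cuts
    {'id': 2, 'mag': -1.0, 'sharpness': 0.63, 'roundness2': 0.02},   # too faint
    {'id': 3, 'mag': -5.0, 'sharpness': 0.90, 'roundness2': 0.00},   # too sharp
    {'id': 4, 'mag': -3.5, 'sharpness': 0.60, 'roundness2': -0.45},  # not round enough
]

cuts = {'sh inf': 0.57, 'sh sup': 0.70, 'mag lim': -2.5,
        'round inf': -0.30, 'round sup': 0.30}


def passes_cuts(src, cuts):
    """True if the source is bright enough and within the sharpness/roundness limits."""
    return (src['mag'] < cuts['mag lim']
            and cuts['sh inf'] < src['sharpness'] < cuts['sh sup']
            and cuts['round inf'] < src['roundness2'] < cuts['round sup'])


selected_ids = [src['id'] for src in catalog if passes_cuts(src, cuts)]
print('Selected source ids:', selected_ids)
```

Note that instrumental magnitudes are more negative for brighter sources, which is why the brightness cut keeps sources with `mag` below the limit.
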

In [None]:
dict_dir = 'CUT_FOR_ePSF/'

if not os.path.exists(dict_dir):
    os.makedirs(dict_dir)

dict_filename = 'epsf_selection_cuts_{}_{}.pkl'.format(det, filt)

if os.path.exists(os.path.join(dict_dir, dict_filename)):
    
    with open(os.path.join(dict_dir, dict_filename), 'rb') as handle:
        
    
        dict_cut_values = pickle.load(handle)
        print('Load dictionary with selection cuts for detector {} - filter {}'.format(det, filt))
        print('Skip to section 4.4')

else:
    print('The selection cuts have not been created yet for detector {} - filter {}'.format(det, filt))
    print('Determine the parameters from the plots below')

In [None]:
plt.figure(figsize=(12, 8))
plt.clf()

found_stars_test = found_stars_tot[0]

ax1 = plt.subplot(2, 1, 1)

plt.title('Sharpness and Roundness plots for image 1')

ax1.set_xlabel('mag', fontdict=font2)
ax1.set_ylabel('sharpness', fontdict=font2)

xlim0 = np.min(found_stars_test['mag']) - 0.25
xlim1 = np.max(found_stars_test['mag']) + 0.25
ylim0 = np.min(found_stars_test['sharpness']) - 0.15
ylim1 = np.max(found_stars_test['sharpness']) + 0.15

ax1.set_xlim(xlim0, xlim1)
ax1.set_ylim(ylim0, ylim1)

ax1.xaxis.set_major_locator(ticker.AutoLocator())
ax1.xaxis.set_minor_locator(ticker.AutoMinorLocator())
ax1.yaxis.set_major_locator(ticker.AutoLocator())
ax1.yaxis.set_minor_locator(ticker.AutoMinorLocator())

ax1.scatter(found_stars_test['mag'], found_stars_test['sharpness'], s=10, color='k')

ax2 = plt.subplot(2, 1, 2)

ax2.set_xlabel('mag', fontdict=font2)
ax2.set_ylabel('roundness', fontdict=font2)

ylim0 = np.min(found_stars_test['roundness2']) - 0.25
ylim1 = np.max(found_stars_test['roundness2']) + 0.25

ax2.set_xlim(xlim0, xlim1)
ax2.set_ylim(ylim0, ylim1)

ax2.xaxis.set_major_locator(ticker.AutoLocator())
ax2.xaxis.set_minor_locator(ticker.AutoMinorLocator())
ax2.yaxis.set_major_locator(ticker.AutoLocator())
ax2.yaxis.set_minor_locator(ticker.AutoMinorLocator())

ax2.scatter(found_stars_test['mag'], found_stars_test['roundness2'], s=10, color='k')

plt.tight_layout()

**Note**: we need to record the values for the selected limits (expectation is that they will not change for the different images of the same detector/filter combination). We create a dictionary to store the values for the different images.
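
Since the cuts are saved to disk with `pickle` at the end of this section, the following sketch shows a save/load round trip for a dictionary with the same layout. It writes to a temporary directory so that it does not touch the `CUT_FOR_ePSF/` folder; the file name and the single-image dictionary are illustrative.

```python
import os
import pickle
import tempfile

# Dictionary with the same layout as dict_cut_values, for one image only
cut_values = {'image 1': {'sh inf': 0.57, 'sh sup': 0.70, 'mag lim': -2.5,
                          'round inf': -0.30, 'round sup': 0.30}}

# Write to a temporary directory so the sketch leaves no files behind
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'epsf_selection_cuts_example.pkl')

    with open(path, 'wb') as handle:
        pickle.dump(cut_values, handle, protocol=pickle.HIGHEST_PROTOCOL)

    with open(path, 'rb') as handle:
        restored = pickle.load(handle)

print(restored == cut_values)
```
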

In [None]:
dict_cut_values = {}
for i in np.arange(num_images):
    j = str(i+1)
    dict_cut_values['image '+j] = {'sh inf': 0, 
                   'sh sup': 0,
                   'mag lim': 0,
                   'round inf': 0,
                   'round sup': 0}
dict_cut_values

In [None]:
plt.figure(figsize=(12, 8))
plt.clf()

j = 3

found_stars_test = found_stars_tot[j]

num = str(j+1)

ax1 = plt.subplot(2, 1, 1)

plt.title('Sharpness and Roundness plots for image ' + num + ' - selection cuts')

ax1.set_xlabel('mag', fontdict=font2)
ax1.set_ylabel('sharpness', fontdict=font2)

xlim0 = np.min(found_stars_test['mag']) - 0.25
xlim1 = np.max(found_stars_test['mag']) + 0.25
ylim0 = np.min(found_stars_test['sharpness']) - 0.15
ylim1 = np.max(found_stars_test['sharpness']) + 0.15

ax1.set_xlim(xlim0, xlim1)
ax1.set_ylim(ylim0, ylim1)

ax1.xaxis.set_major_locator(ticker.AutoLocator())
ax1.xaxis.set_minor_locator(ticker.AutoMinorLocator())
ax1.yaxis.set_major_locator(ticker.AutoLocator())
ax1.yaxis.set_minor_locator(ticker.AutoMinorLocator())

ax1.scatter(found_stars_test['mag'], found_stars_test['sharpness'], s=10, color='k')

sh_inf = 0.57
sh_sup = 0.70
mag_lim = -2.5


dict_cut_values['image '+num]['sh inf'] = sh_inf
dict_cut_values['image '+num]['sh sup'] = sh_sup
dict_cut_values['image '+num]['mag lim'] = mag_lim

ax1.plot([xlim0, xlim1], [sh_sup, sh_sup], color='r', lw=3, ls='--')
ax1.plot([xlim0, xlim1], [sh_inf, sh_inf], color='r', lw=3, ls='--')
ax1.plot([mag_lim, mag_lim], [ylim0, ylim1], color='r', lw=3, ls='--')

ax2 = plt.subplot(2, 1, 2)

ax2.set_xlabel('mag', fontdict=font2)
ax2.set_ylabel('roundness', fontdict=font2)

ylim0 = np.min(found_stars_test['roundness2']) - 0.25
ylim1 = np.max(found_stars_test['roundness2']) + 0.25

ax2.set_xlim(xlim0, xlim1)
ax2.set_ylim(ylim0, ylim1)

ax2.xaxis.set_major_locator(ticker.AutoLocator())
ax2.xaxis.set_minor_locator(ticker.AutoMinorLocator())
ax2.yaxis.set_major_locator(ticker.AutoLocator())
ax2.yaxis.set_minor_locator(ticker.AutoMinorLocator())

round_inf = -0.30
round_sup = 0.30

dict_cut_values['image '+num]['round inf'] = round_inf
dict_cut_values['image '+num]['round sup'] = round_sup


ax2.scatter(found_stars_test['mag'], found_stars_test['roundness2'], s=10, color='k')

ax2.plot([xlim0, xlim1], [round_sup, round_sup], color='r', lw=3, ls='--')
ax2.plot([xlim0, xlim1], [round_inf, round_inf], color='r', lw=3, ls='--')
ax2.plot([mag_lim, mag_lim], [ylim0, ylim1], color='r', lw=3, ls='--')

plt.tight_layout()

In [None]:
dict_cut_values

In [None]:
dict_filename = 'epsf_selection_cuts_{}_{}.pkl'.format(det, filt)

with open(os.path.join(dict_dir, dict_filename), 'wb') as handle:
    pickle.dump(dict_cut_values, handle, protocol=pickle.HIGHEST_PROTOCOL)

### 4.4<font color='white'>-</font>Create catalog of selected sources<a class="anchor" id="create_cat"></a> ###

We can also include a separation criterion if we want to retain in the final catalog only the stars that are well isolated. In particular, we can select only the stars that do not have a neighbour closer than X pixels, where X is a parameter that can be set manually (`min_sep` in the cell below).
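
The isolation test amounts to computing, for every selected star, the distance to its nearest neighbour in the full catalog and comparing it with the minimum separation. A minimal standard-library sketch, with made-up positions and a hypothetical helper `nearest_neighbour_distance`:

```python
import math


def nearest_neighbour_distance(index, positions):
    """Distance from positions[index] to the closest other position."""
    x0, y0 = positions[index]
    return min(math.hypot(x - x0, y - y0)
               for k, (x, y) in enumerate(positions) if k != index)


# (x, y) centroids in pixels; the first two sources are only 5 px apart
positions = [(10.0, 10.0), (13.0, 14.0), (100.0, 100.0), (200.0, 50.0)]
min_sep = 10  # px

isolated = [i for i in range(len(positions))
            if nearest_neighbour_distance(i, positions) > min_sep]
print('Isolated sources (indices):', isolated)
```
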

In [None]:
found_stars_sel_tot = []
found_stars_sel_dist_tot = []

for i, found_stars in enumerate(found_stars_tot):
    
    j = str(i+1)
    
    mask = ((found_stars['mag'] < dict_cut_values['image '+j]['mag lim']) & 
            (found_stars['roundness2'] > dict_cut_values['image '+j]['round inf']) 
            & (found_stars['roundness2'] < dict_cut_values['image '+j]['round sup']) 
            & (found_stars['sharpness'] > dict_cut_values['image '+j]['sh inf']) 
            & (found_stars['sharpness'] < dict_cut_values['image '+j]['sh sup']))

    found_stars_sel = found_stars[mask]

    print('Number of stars selected to build ePSF for image {}:'.format(i+1), len(found_stars_sel))

    # if we include the separation criteria:

    d = []

    # we do not want any stars in a 10 px radius. 

    min_sep = 10

    x_tot = found_stars['xcentroid']
    y_tot = found_stars['ycentroid']

    for xx, yy in zip(found_stars_sel['xcentroid'], found_stars_sel['ycentroid']):

        sep = []
        dist = np.sqrt((x_tot - xx)**2 + (y_tot - yy)**2)
        sep = np.sort(dist)[1:2][0]
        d.append(sep)

    found_stars_sel['min distance'] = d
    mask_dist = (found_stars_sel['min distance'] > min_sep)

    found_stars_sel_dist = found_stars_sel[mask_dist]

    print('Number of stars selected to build ePSF including "minimum distance closest neighbour" selection for image {}:'.format(i+1), len(found_stars_sel_dist))
    print('--------------------------')
    print('')
    found_stars_sel_tot.append(found_stars_sel)
    found_stars_sel_dist_tot.append(found_stars_sel_dist)



### 4.5<font color='white'>-</font>Build the empirical PSFs<a class="anchor" id="build_epsf"></a> ###

We build the effective PSF using the [EPSFBuilder](https://photutils.readthedocs.io/en/stable/api/photutils.psf.EPSFBuilder.html#photutils.psf.EPSFBuilder) class.

First, we exclude the objects whose bounding box extends beyond the detector edge. Then, we extract cutouts of the stars using the [extract_stars()](https://photutils.readthedocs.io/en/stable/api/photutils.psf.extract_stars.html#photutils.psf.extract_stars) function. The size of the cutout is set by the parameter `size` of our function *build_epsf*. Once we have the object containing the cutouts of the selected stars, we can build the ePSF with the [EPSFBuilder](https://photutils.readthedocs.io/en/stable/api/photutils.psf.EPSFBuilder.html#photutils.psf.EPSFBuilder) class. 
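
The edge test is the same inequality used for `pos_mask` in the function below: a star is kept only if a `size` x `size` box centred on it lies fully inside the image. The sketch below assumes a 2048 x 2048 array and uses a hypothetical helper `cutout_fits` to show the check on three positions.

```python
def cutout_fits(x, y, shape, size):
    """True if a size x size cutout centred on (x, y) lies inside the image."""
    hsize = (size - 1) / 2
    ny, nx = shape
    return hsize < x < nx - 1 - hsize and hsize < y < ny - 1 - hsize


shape = (2048, 2048)   # (rows, columns); assumed image size for illustration
size = 41              # cutout side in pixels

print(cutout_fits(1000.0, 1000.0, shape, size))  # well inside the image
print(cutout_fits(15.0, 1000.0, shape, size))    # too close to the left edge
print(cutout_fits(1000.0, 2030.0, shape, size))  # too close to the top edge
```
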

In [None]:
def build_epsf(data, table, det='NRCA1', filt='F070W', size=11, oversample=4, iters=10):
    
    hsize = (size - 1) / 2
    
    x = table['xcentroid']
    y = table['ycentroid']
    
    pos_mask = ((x > hsize) & (x < (data.shape[1] - 1 - hsize)) & (y > hsize) & (y < (data.shape[0] - 1 - hsize)))

    stars_tbl = Table()
    stars_tbl['x'] = x[pos_mask]
    stars_tbl['y'] = y[pos_mask]
    
    nddata = NDData(data=data)
    stars = extract_stars(nddata, stars_tbl, size=size)

    print('Creating ePSF --- Detector {d}, filter {f}'.format(f=filt, d=det))
    print('---------------')

    epsf_builder = EPSFBuilder(oversampling=oversample, maxiters=iters, progress_bar=True)

    epsf, fitted_stars = epsf_builder(stars)        
    
    return epsf

In [None]:
epsf_tot = []

size = 41
num_psf = 1
oversample = 4

distorted = True

if distorted:
    epsf_dir = 'ePSF_MODELS/Distorted/Single_PSF/Fov{}px_numPSFs{}_oversample{}'.format(size, num_psf, oversample)

    if not os.path.exists(epsf_dir):
        os.makedirs(epsf_dir)
else:
    epsf_dir = 'ePSF_MODELS/Undistorted/Single_PSF/Fov{}px_numPSFs{}_oversample{}'.format(size, num_psf, oversample)
    
    if not os.path.exists(epsf_dir):
        os.makedirs(epsf_dir)
        
save_epsf = True

for i, (data, table) in enumerate(zip(data_bkgsub_tot, found_stars_sel_tot)):
    
    print('----------------')
    print('Working on image {}'.format(str(i+1)))
    print('----------------')
    print('')
    
    epsf = build_epsf(data, table=table, det=det, filt=filt, size=size, oversample=oversample, iters=3)
    
    if save_epsf:
        hdu = fits.PrimaryHDU(epsf.data)
        hdul = fits.HDUList([hdu])
        epsf_name = 'ePSF_{}_{}_fov{}px_image_{}.fits'.format(det, filt, size, str(i+1))
        hdul.writeto(os.path.join(epsf_dir, epsf_name), overwrite = True)
    
    epsf_tot.append(epsf)
    

### 4.6<font color='white'>-</font>Display the empirical PSFs<a class="anchor" id="display_epsf"></a> ###

In [None]:
if len(images_original) > 2:

    nn = int(np.sqrt(len(images_original)))
    figsize = (12, 12)
    fig, ax = plt.subplots(nn, nn, figsize=figsize)

    for ix in range(nn):
        for iy in range(nn):
            
            i = ix * nn + iy
            
            epsf = epsf_tot[i].data
            
            ax[nn - 1 - ix, iy].set_xlabel('X [px]', fontsize=15)
            ax[nn - 1 - ix, iy].set_ylabel('Y [px]', fontsize=15)
            
            norm = simple_norm(epsf, 'sqrt', percent=99.)
            ax[nn - 1 - ix, iy].set_title(det + ' - ' + filt +  ' - image' + str(i+1), fontsize=20)
            ax[nn - 1 - ix, iy].imshow(epsf, norm=norm)
            
            plt.tight_layout()
else:
    
    plt.figure(figsize = (14, 14))
    nn = len(epsf_tot)  # one panel per ePSF (1 or 2 in this branch)
    for i in range(nn):
        ax = plt.subplot(1, nn, i + 1)
        
        epsf = epsf_tot[i].data
        
        ax.set_xlabel('X [px]')
        ax.set_ylabel('Y [px]')
        ax.set_title(det + ' - ' + filt +  ' - ePSF' + str(i+1), fontsize=20)
        norm = simple_norm(epsf, 'sqrt', percent=99.)
        
        ax.imshow(epsf, norm=norm)
       
        plt.tight_layout()


5.<font color='white'>-</font>Create a single or grid of empirical PSFs <a class="anchor" id="epsf_intro2"></a>
------------------

### 5.1<font color='white'>-</font>Count stars in N x N grid<a class="anchor" id="count_stars"></a> ###

The function `count_stars_grid` counts how many good PSF stars fall in each cell of an N x N grid. It starts from a grid of size N x N (where N = sqrt(num_psfs)) and iterates down to the minimum grid size of 2 x 2. Depending on the number of PSF stars wanted in each cell, users can choose the appropriate grid size, or modify the threshold values and/or the selection parameters adopted during the star detection, in Sections 4.3, 4.4.

The minimum number of PSF stars needed in each cell can also be set using the parameter `min_numpsf_stars`. This is useful when inspecting the plot, since in cells with fewer PSF stars than this minimum the value is shown in red. Moreover, when `verbose = True`, the function prints, for each N x N combination, which cells (if any) do not have enough PSF stars.

This function produces sqrt(num_psfs) - 1 figures, one per N x N combination, showing the number of PSF stars in each cell.
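
Before reading the full function, it may help to see the counting idea in isolation. The sketch below bins a handful of made-up star positions into a 2 x 2 grid over a square image and flags the cells below a chosen minimum. It is deliberately simpler than `count_stars_grid`, which also excludes a margin of half the cutout size along the borders of each cell; `count_per_cell` and `min_stars` are illustrative names.

```python
def count_per_cell(xs, ys, image_size, n):
    """Count sources in each cell of an n x n grid over a square image."""
    cell = image_size / n
    counts = [[0] * n for _ in range(n)]
    for x, y in zip(xs, ys):
        # Clamp to the last cell so a source exactly on the far edge is still counted
        col = min(int(x // cell), n - 1)
        row = min(int(y // cell), n - 1)
        counts[row][col] += 1
    return counts


xs = [100.0, 300.0, 1500.0, 1600.0, 1900.0]
ys = [100.0, 200.0, 400.0, 1700.0, 1800.0]
min_stars = 2  # hypothetical per-cell requirement

counts = count_per_cell(xs, ys, image_size=2048, n=2)

for row_index, row in enumerate(counts):
    for col_index, value in enumerate(row):
        flag = '' if value >= min_stars else '  <-- below minimum'
        print('cell (row {}, col {}): {} stars{}'.format(row_index, col_index, value, flag))
```
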

In [None]:
def find_centers(num):
    points = int(((data.shape[0] / num) / 2) - 1)
    x_center = np.arange(points, 2 * points * num, 2 * points)
    y_center = np.arange(points, 2 * points * num, 2 * points)

    centers = np.array(np.meshgrid(x_center, y_center)).T.reshape(-1, 2)

    return points, centers

def count_stars_grid(table, data, num_psfs=4, min_numpsf_stars=40, size=11, verbose=True, savefig=True):
    
    # Count the stars selected in the previous sections that fall in each cell of the grid.
    # The maximum number of cells is defined by num_psfs; the function iterates from an
    # N x N grid (where N = sqrt(num_psfs)) down to a 2 x 2 grid.

    if np.sqrt(num_psfs).is_integer():
        grid_points = int(np.sqrt(num_psfs))

    else:
        raise ValueError("You must choose a square number of cells to create (E.g. 9, 16, etc.)")

    num_grid = np.arange(2, grid_points + 1, 1)
    num_grid = num_grid[::-1]


    for num in num_grid:
        print("--------------------")
        print("")
        print("Calculating the number of PSF stars in a %d x %d grid:" % (num, num))
        print("")

        s = (data.shape[1], data.shape[0])
        temp_arr = np.zeros(s)
        num_psfs_stars = []

        points, centers = find_centers(num)

        for n, val in enumerate(centers):

            x = table['xcentroid']
            y = table['ycentroid']

            half_size = (size - 1) / 2

            lim1 = int(val[0] - points + half_size)
            lim2 = int(val[0] + points - half_size)
            lim3 = int(val[1] - points + half_size)
            lim4 = int(val[1] + points - half_size)

            number_psf_stars = (x > lim1) & (x < lim2) & (y > lim3) & (y < lim4)
            count_psfs_stars = np.count_nonzero(number_psf_stars)

            lim_x1 = int(lim1 - half_size)
            lim_x2 = int(lim2 + half_size)
            lim_y1 = int(lim3 - half_size)
            lim_y2 = int(lim4 + half_size)

            if verbose:

                if count_psfs_stars < min_numpsf_stars:
                    print('Center coordinates of grid cell {:d} are ({:d}, {:d}) --- Not enough stars in the cell '
                            '(number of stars < {:d})'.format(n + 1, val[0], val[1], min_numpsf_stars))

                else:
                    print(f'Center coordinates of grid cell {n + 1:d} are ({val[0]:d}, {val[1]:d}) '
                            '--- Number of stars:', count_psfs_stars)
                    print("")

            temp_arr[lim_y1:lim_y2, lim_x1:lim_x2] = count_psfs_stars
            num_psfs_stars.append(count_psfs_stars)

        if savefig:
            plot_count_grid(temp_arr, num, num_psfs_stars, centers, min_numpsf_stars=min_numpsf_stars)

def plot_count_grid(arr, num, nstars, centers, min_numpsf_stars=40):
    
    plt.clf()
    
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    plt.figure(figsize=(10, 10))
    ax = plt.subplot(1, 1, 1)

    plt.xlabel('X [px]', font={'size': 20})
    plt.ylabel('Y [px]', font={'size': 20})
    plt.title('%dx%d grid - ' % (num, num) + det + ' - ' + filt, font={'size': 25})
    im = ax.imshow(arr, origin='lower', vmin=np.min(arr[arr > 0]), vmax=np.max(arr))
    for i in range(num ** 2):
        # Flag cells with fewer than min_numpsf_stars stars in red
        color = 'r' if nstars[i] < min_numpsf_stars else 'w'
        ax.text(centers[i][0] - 100, centers[i][1] - 50, "%d" % nstars[i], c=color, font={'size': 30})
    ax.text(2300, 750, "# of PSF stars", rotation=270, font={'size': 25})
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)

    plt.tight_layout()
    
    filename = 'number_PSFstars_{}x{}grid_fov{}_{}_{}_image_{}.pdf'.format(num, num, size, det, filt, nimage)

    plt.savefig(os.path.join(figures_dir, filename))

In [None]:
figures_dir = 'FIGURES/'

if not os.path.exists(figures_dir):
    os.makedirs(figures_dir)

for i, (data, table) in enumerate(zip(data_bkgsub_tot, found_stars_sel_tot)):
    
    nimage = str(i+1)
    size = 11
    
    count_stars_grid(table, data, num_psfs=25, min_numpsf_stars=40, size=size, verbose=True, savefig=True)
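
As a quick cross-check of the per-cell counts, the following minimal sketch bins source positions into an N × N grid with `np.histogram2d`. It is not the method used by `count_stars_grid`: it does not exclude stars within half a cutout of the cell edges, so its counts may be slightly higher. The helper name `count_stars_per_cell`, the 2048-pixel image size, and the random positions are assumptions made for illustration only.

```python
import numpy as np


def count_stars_per_cell(x, y, num, npix=2048):
    """Count sources in each cell of a num x num grid covering an npix x npix image."""
    edges = np.linspace(0, npix, num + 1)
    # counts[i, j] is the number of sources in x-bin i and y-bin j
    counts, _, _ = np.histogram2d(x, y, bins=[edges, edges])
    return counts.astype(int)


# Synthetic positions standing in for table['xcentroid'] and table['ycentroid']
rng = np.random.default_rng(seed=0)
x = rng.uniform(0, 2048, 500)
y = rng.uniform(0, 2048, 500)

counts = count_stars_per_cell(x, y, num=2)
print(counts)
print("Cells below the 40-star threshold:", np.count_nonzero(counts < 40))
```

With real data, replace the synthetic `x` and `y` with the centroid columns of the selected-source table.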

### 5.2<font color='white'>-</font>Build effective PSF (single or grid)<a class="anchor" id="epsf_grid"></a> ###

This function creates a grid of ePSFs with EPSFBuilder (or a single ePSF when **num_psfs**=1). It returns a GriddedPSFModel object whose data are a 3D array of shape N × n × n, i.e. N 2D ePSFs of n × n pixels each. The model metadata include a grid_xypos key that lists the detector position of each ePSF. The order of the tuples in grid_xypos matches the order of the ePSFs along the first axis of the 3D array.
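
To make the layout concrete, here is a minimal NumPy-only sketch of how the ePSF stack and the position list line up. It does not build a real model: the stack is filled with zeros, and the cell centers (512 and 1536 pixels) are assumed values for a 2 × 2 grid. The center ordering reuses the same `meshgrid` expression that `find_centers` returns, and the array side follows the allocation used in `build_epsf_grid`.

```python
import numpy as np

num_grid = 2                 # 2 x 2 grid, so N = 4 ePSFs
size, oversample = 11, 4
n = size * oversample + 1    # side of each ePSF array, as allocated in build_epsf_grid

# Placeholder stack standing in for the real ePSF data (N x n x n)
epsf_stack = np.zeros((num_grid ** 2, n, n))

# Assumed cell centers, ordered with the same expression used in find_centers
x_center = np.array([512.0, 1536.0])
y_center = np.array([512.0, 1536.0])
centers = np.array(np.meshgrid(x_center, y_center)).T.reshape(-1, 2)

# Index i along the first axis pairs with the i-th (x, y) position
for i, (xc, yc) in enumerate(centers):
    print(f"ePSF {i}: shape {epsf_stack[i].shape}, detector position (x, y) = ({xc}, {yc})")
```

The x coordinate varies slowest in this ordering, which is why `plot_epsf` recovers the index as `ix * nn + iy`.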

In [None]:
def build_epsf_grid(table, data, num_psfs=4, size=11, oversample=4, save=True, savefig=True, overwrite=True):

    if np.sqrt(num_psfs).is_integer():
        num_grid = int(np.sqrt(num_psfs))

    else:
        raise ValueError("You must choose a square number of cells to create (E.g. 9, 16, etc.)")


    points, centers = find_centers(num_grid)

    epsf_size = size * oversample
    epsf_arr = np.empty((num_grid ** 2, epsf_size + 1, epsf_size + 1))

    for i, val in enumerate(centers):

        x = table['xcentroid']
        y = table['ycentroid']

        half_size = (size - 1) / 2

        lim1 = int(val[0] - points + half_size)
        lim2 = int(val[0] + points - half_size)
        lim3 = int(val[1] - points + half_size)
        lim4 = int(val[1] + points - half_size)

        mask = ((x > lim1) & (x < lim2) & (y > lim3) & (y < lim4))

        stars_tbl = Table()
        stars_tbl['x'] = x[mask]
        stars_tbl['y'] = y[mask]
        print('Number of sources in cell %d used to build the ePSF:' % (i + 1), len(stars_tbl['x']))

        nddata = NDData(data=data)
        stars = extract_stars(nddata, stars_tbl, size=size)

        print("Creating ePSF for cell %d - Coordinates (%d, %d)" % (i + 1, val[0], val[1]))
        print("")

        epsf_builder = EPSFBuilder(oversampling=oversample, maxiters=3, progress_bar=False)

        epsf, fitted_stars = epsf_builder(stars)

        epsf_arr[i, :, :] = epsf.data

    # Build the metadata and the gridded model once, after all cells have been processed
    meta = OrderedDict()
    meta["DETECTOR"] = (det, "Detector name")
    meta["FILTER"] = (filt, "Filter name")
    meta["NUM_PSFS"] = (num_grid ** 2, "The total number of ePSFs")
    for h, loc in enumerate(centers):
        # Store plain Python floats so the string reads "(y, x)" regardless of the NumPy version
        meta["DET_YX{}".format(h)] = (str((float(loc[1]), float(loc[0]))),
                                      "The #{} PSF's (y,x) detector pixel position".format(h))

    meta["OVERSAMP"] = (oversample, "Oversampling Factor in EPSFBuilder")

    model_epsf = create_model(epsf_arr, meta)

    if savefig:
        plot_epsf(model_epsf, num_psfs)

    if save:
        writeto(epsf_arr, meta, num_psfs, overwrite=overwrite)

    # Return the model whether or not it was written to disk
    return model_epsf

def writeto(data, meta, num_psfs, overwrite=True):

    primaryhdu = fits.PrimaryHDU(data)
    
    # Convert meta dictionary to header
    tuples = [(a, b, c) for (a, (b, c)) in meta.items()]
    primaryhdu.header.extend(tuples)

    # Add extra descriptors for how the file was made
    primaryhdu.header["COMMENT"] = "For a given filter and detector, one file is produced in "
    primaryhdu.header["COMMENT"] = "the form [i, y, x] where i is the ePSF position on the detector grid "
    primaryhdu.header["COMMENT"] = "and (y,x) is the 2D PSF. The order of PSFs can be found under the "
    primaryhdu.header["COMMENT"] = "header DET_YX* keywords"

    hdu = fits.HDUList(primaryhdu)

    filename = "ePSF_{}_{}_fov{}_nepsf{}_image_{}.fits".format(det, filt, size, num_psfs, nimage)
    
    file = os.path.join(epsf_dir, filename)

    hdu.writeto(file, overwrite=overwrite)

def plot_epsf(model, num):

    if num == 1:
        plt.clf()
        plt.figure(figsize=(10, 10))
        ax = plt.subplot(1, 1, 1)

        norm_epsf = simple_norm(model.data[0], 'log', percent=99.)
        plt.suptitle(det + ' - ' + filt, font={'size': 20})
        plt.title(model.meta['grid_xypos'][0], font={'size': 20})
        ax.imshow(model.data[0], norm=norm_epsf)
        plt.tight_layout()

        filename = 'ePSF_single_{}_{}_fov{}_image_{}.pdf'.format(det, filt, size, nimage)
        
        plt.savefig(os.path.join(figures_dir, filename))
 
    else:
        plt.clf()

        nn = int(np.sqrt(num))
        figsize = (12, 12)
        fig, ax = plt.subplots(nn, nn, figsize=figsize)

        for ix in range(nn):
            for iy in range(nn):
                i = ix * nn + iy
                norm_epsf = simple_norm(model.data[i], 'log', percent=99.)
                ax[nn - 1 - iy, ix].imshow(model.data[i], norm=norm_epsf)
                ax[nn - 1 - iy, ix].set_title(model.meta['grid_xypos'][i], font={'size': 20})

        plt.suptitle(det + ' - ' + filt, font={'size': 40})
        plt.tight_layout()
    
        filename = 'ePSF_{}x{}grid_{}_{}_fov{}_image_{}.pdf'.format(nn, nn, det, filt, size, nimage)
        
        plt.savefig(os.path.join(figures_dir, filename))


def create_model(data, meta):

    ndd = NDData(data, meta=meta, copy=True)

    ndd.meta['grid_xypos'] = [((float(ndd.meta[key][0].split(',')[1].split(')')[0])),
                                   (float(ndd.meta[key][0].split(',')[0].split('(')[1]))) for key in ndd.meta.keys() if
                                  "DET_YX" in key]

    ndd.meta['oversampling'] = meta["OVERSAMP"][0]
    ndd.meta = {key.lower(): ndd.meta[key] for key in ndd.meta}
    model = GriddedPSFModel(ndd)

    return model
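
The chained `split` calls in `create_model` recover the (x, y) position from the stored "(y, x)" string. A more readable alternative is sketched below with `ast.literal_eval`. It assumes the DET_YX values hold plain Python floats written as `str((y, x))`, and the helper name `yx_string_to_xy` is hypothetical.

```python
import ast


def yx_string_to_xy(value):
    """Convert a '(y, x)' string into an (x, y) tuple of floats."""
    y, x = ast.literal_eval(value)
    return (float(x), float(y))


# Same form as a stored DET_YX value: y first, then x
example = str((1536.0, 512.0))
print(yx_string_to_xy(example))
```

`ast.literal_eval` only accepts Python literals, so it raises an error on malformed strings instead of silently returning a wrong value.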

In [None]:
epsf_grid_tot = []

distorted = True
size = 41
num_psf = 4
oversample = 4

# Pick the output sub-directory according to the type of image being analyzed
subdir = 'Distorted' if distorted else 'Undistorted'
epsf_dir = 'ePSF_MODELS/{}/Grid/Fov{}px_numPSFs{}_oversample{}'.format(subdir, size, num_psf, oversample)

os.makedirs(epsf_dir, exist_ok=True)

        
for i, (data, table) in enumerate(zip(data_bkgsub_tot, found_stars_sel_tot)):
    
    nimage = str(i+1)
    
    print('----------------')
    print('Working on image {}'.format(str(i+1)))
    print('----------------')

    epsf_grid = build_epsf_grid(table, data, num_psfs=num_psf, size=size, oversample=oversample, save=True, 
                                savefig=True, overwrite=True)
    epsf_grid_tot.append(epsf_grid)
