In [None]:
import psfr
from psfr.psfr import shift_psf
import astropy.io.fits as pyfits
import os
import matplotlib.pyplot as plt
import numpy as np
import copy
import scipy
from scipy.ndimage import interpolation
from lenstronomy.Util import util, kernel_util, image_util
from mpl_toolkits.axes_grid1 import AxesGrid, make_axes_locatable
%matplotlib inline

%load_ext autoreload
%autoreload 2


# log10 colour-scale limits (lower, upper) used for the PSF images plotted below
vmin = -5
vmax = -1

In [None]:
# Load the example JWST F090W star cutouts that ship in psfr's Data folder.
N_STARS = 5  # number of example star cutouts psf_f090w_star0.fits ... psf_f090w_star4.fits

package_path = psfr.__path__[0]
path_stars = os.path.join(package_path, 'Data', 'JWST')
star_name = 'psf_f090w_star'

star_list_web = []
for i in range(N_STARS):
    path = os.path.join(path_stars, star_name + str(i) + '.fits')
    # Close the FITS file once the primary HDU has been read; the data are copied so the array
    # stays valid after the (possibly memory-mapped) file handle is closed.
    with pyfits.open(path) as hdulist_star:
        star = np.array(hdulist_star[0].data)
    star_list_web.append(star)



In [None]:
# Plot every star cutout on a log10 scale (pixels <= 0 have no finite log10 value).
# `squeeze=False` keeps `axes` two-dimensional so indexing also works for a single star.
n_stars = len(star_list_web)
f, axes = plt.subplots(1, n_stars, figsize=(4 * n_stars, 4), sharex=False, sharey=False, squeeze=False)
for i, star in enumerate(star_list_web):
    ax = axes[0, i]
    im = ax.imshow(np.log10(star), origin='lower')
    ax.autoscale(False)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ax.set_title('star {}'.format(i))
plt.show()


## Using photutils to reconstruct the PSF
For comparison, we first reconstruct the PSF with photutils, a well-established library.


In [None]:

from photutils.psf import EPSFStar
from photutils.psf import EPSFStars
from photutils.psf import EPSFBuilder

star_list_epsf = []
for star_ in star_list_web:
    # flux-weighted centroid of the cutout on lenstronomy's pixel grid (make_grid with left_lower=True)
    x_grid, y_grid = util.make_grid(numPix=len(star_), deltapix=1, left_lower=True)
    x_grid, y_grid = util.array2image(x_grid), util.array2image(y_grid)
    flux_total = np.sum(star_)
    x_c, y_c = np.sum(star_ * x_grid) / flux_total, np.sum(star_ * y_grid) / flux_total
    c_ = (len(star_) - 1) / 2
    # NOTE(review): the centroid is mirrored through the cutout center AND its x/y are swapped before
    # being passed to photutils, whose cutout_center is (x, y). Kept as in the original; confirm this
    # compensates for a convention difference rather than being an accident.
    x_s, y_s = 2 * c_ - y_c, 2 * c_ - x_c
    star_list_epsf.append(EPSFStar(star_, cutout_center=[x_s, y_s]))

stars_epsf = EPSFStars(star_list_epsf)

oversampling = 4
epsf_builder_super = EPSFBuilder(oversampling=oversampling, maxiters=1, progress_bar=False)
epsf_super, fitted_stars_super = epsf_builder_super(stars_epsf)

epsf_builder = EPSFBuilder(oversampling=1, maxiters=5, progress_bar=False)
epsf, fitted_stars_regular = epsf_builder(stars_epsf)

# The per-star residual plot further down pairs `fitted_stars` with `epsf_super`, so it must hold the
# stars fitted against the oversampled ePSF. Previously the second build silently overwrote it with
# the stars fitted against the non-oversampled ePSF.
fitted_stars = fitted_stars_super


In [None]:

# Degrade the oversampled photutils ePSF to the native pixel grid and compare it with the
# ePSF built directly without oversampling.
n_pix_star = len(star_list_web[0])
epsf_degraded = kernel_util.cut_psf(kernel_util.degrade_kernel(epsf_super.data, oversampling), n_pix_star)
epsf_regular = kernel_util.cut_psf(epsf.data, n_pix_star)

# (image, title, extra imshow keyword arguments), drawn left to right
log_kwargs = {'vmin': vmin, 'vmax': vmax}
panels = [
    (np.log10(epsf_super.data), 'photutils oversampled', log_kwargs),
    (np.log10(epsf_degraded), 'photutils oversampled degraded', log_kwargs),
    (np.log10(epsf_regular), 'photutils no oversampling', log_kwargs),
    (epsf_degraded - epsf_regular, 'photutils degraded - no oversampling', {}),
]

f, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4), sharex=False, sharey=False)
for ax, (image, title, imshow_kwargs) in zip(axes, panels):
    im = ax.imshow(image, origin='lower', **imshow_kwargs)
    ax.autoscale(False)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ax.set_title(title)

plt.show()


In [None]:


def plot_individual_fit_photutils(epsf, fitted_stars, oversampling, residual_range=0.001):
    """Plot, for every fitted star, the residual between the shifted photutils ePSF and the star.

    The ePSF is moved to each star's ``cutout_center`` with ``shift_psf(..., degrade=True)`` and the
    flux-normalized star cutout is subtracted; one panel per star.

    :param epsf: photutils ePSF model; its ``.data`` array is passed to ``shift_psf``
    :param fitted_stars: iterable of photutils stars with ``.data`` and ``.cutout_center`` (x, y)
    :param oversampling: oversampling factor of ``epsf.data`` relative to the star pixels
    :param residual_range: symmetric colour limit (+/- value) of the residual maps
    """
    n_stars = len(fitted_stars)
    # squeeze=False keeps `axes` 2D so a single star can be plotted as well
    f, axes = plt.subplots(1, n_stars, figsize=(4 * n_stars, 4), sharex=False, sharey=False, squeeze=False)
    for i, star in enumerate(fitted_stars):
        # retrieve the star's offset from the central pixel of its cutout; cutout_center is (x, y), so
        # the x reference comes from the number of columns and the y reference from the number of rows
        n_y, n_x = np.shape(star.data)
        shift = np.asarray(star.cutout_center) - np.array([(n_x - 1) / 2, (n_y - 1) / 2])
        psf_degraded = shift_psf(epsf.data, star.data, oversampling, shift, degrade=True)
        ax = axes[0, i]
        im = ax.imshow(psf_degraded - star.data / np.sum(star.data), origin='lower',
                       vmin=-residual_range, vmax=residual_range)
        ax.autoscale(False)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im, cax=cax)
        ax.set_title('star {}: ePSF - star'.format(i))
    plt.show()


plot_individual_fit_photutils(epsf_super, fitted_stars, oversampling)
    

## Stacking the stars with PSF-r

In [None]:

from psfr.psfr import stack_psf

# Stacking settings shared by both runs, so the two results differ only in the oversampling factor.
kwargs_stack = {'saturation_limit': None, 'num_iteration': 50, 'n_recenter': 4}

# PSF-r stack at the native pixel resolution
psf_psfr, center_list_psfr, mask_list_psfr = stack_psf(star_list_web, oversampling=1, **kwargs_stack)

# PSF-r stack oversampled by the same factor used for photutils above
psf_psfr_super, center_list_psfr_super, mask_list_psfr_super = stack_psf(
    star_list_web, oversampling=oversampling, **kwargs_stack)

# previously both runs wrote `mask_list`, so only the oversampled masks survived; keep that name for
# backward compatibility while the native-resolution masks are no longer discarded
mask_list = mask_list_psfr_super




In [None]:
# Compare the oversampled PSF-r PSF, degraded back to the native pixel grid, with the PSF-r PSF
# stacked directly at native resolution.
psf_psfr_super_degraded = kernel_util.cut_psf(
    kernel_util.degrade_kernel(psf_psfr_super, oversampling), len(psf_psfr))

# (image, title, extra imshow keyword arguments), drawn left to right. The oversampled PSF is
# multiplied by oversampling**2, presumably to put it on the per-pixel scale of the other panels.
log_kwargs = {'vmin': vmin, 'vmax': vmax}
panels = [
    (np.log10(psf_psfr_super * oversampling**2), 'PSF-r oversampled', log_kwargs),
    (np.log10(psf_psfr_super_degraded), 'PSF-r oversampled degraded', log_kwargs),
    (np.log10(psf_psfr), 'PSF-r no oversampling', log_kwargs),
    (psf_psfr_super_degraded - psf_psfr, 'PSF-r degraded - no oversampling', {}),
]

f, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4), sharex=False, sharey=False)
for ax, (image, title, imshow_kwargs) in zip(axes, panels):
    im = ax.imshow(image, origin='lower', **imshow_kwargs)
    ax.autoscale(False)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ax.set_title(title)

plt.show()

In [None]:
# Per-star residuals of the oversampled PSF-r PSF


def plot_individual_fit_psfr(psfr, fitted_stars, center_list, oversampling, residual_range=0.001):
    """Plot, for every star, the residual between the shifted PSF-r PSF and the flux-normalized star.

    :param psfr: PSF array to shift onto each star (this name shadows the ``psfr`` package inside the function)
    :param fitted_stars: list of 2D star cutouts
    :param center_list: per-star centers, indexed like ``fitted_stars``; should come from the same
        ``stack_psf`` run that produced ``psfr``
    :param oversampling: oversampling factor of ``psfr`` relative to the star pixels
    :param residual_range: symmetric colour limit (+/- value) of the residual maps
    """
    n_stars = len(fitted_stars)
    # squeeze=False keeps `axes` 2D so a single star can be plotted as well
    f, axes = plt.subplots(1, n_stars, figsize=(4 * n_stars, 4), sharex=False, sharey=False, squeeze=False)
    for i, star in enumerate(fitted_stars):
        center = center_list[i]
        psf_degraded = shift_psf(psfr, star, oversampling, center, degrade=True)
        ax = axes[0, i]
        im = ax.imshow(psf_degraded - star / np.sum(star), origin='lower',
                       vmin=-residual_range, vmax=residual_range)
        ax.autoscale(False)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im, cax=cax)
        ax.set_title('star {}: PSF - star'.format(i))
    plt.show()


# Use the centers estimated together with the oversampled PSF (center_list_psfr_super); the previous
# call paired psf_psfr_super with the centers from the native-resolution run.
plot_individual_fit_psfr(psf_psfr_super, star_list_web, center_list_psfr_super, oversampling)



In [None]:
# Regular (non-oversampled) reconstruction: PSF-r compared with photutils

psf_photutil = kernel_util.cut_psf(epsf.data, psf_size=len(star_list_web[0]))

# (image, title, extra imshow keyword arguments), drawn left to right
log_kwargs = {'vmin': vmin, 'vmax': vmax}
panels = [
    (np.log10(psf_psfr), 'PSF-r method', log_kwargs),
    (np.log10(psf_photutil), 'photutils method', log_kwargs),
    (psf_psfr - psf_photutil, 'PSF-r - photutils', {}),
]

f, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4), sharex=False, sharey=False)
for ax, (image, title, imshow_kwargs) in zip(axes, panels):
    im = ax.imshow(image, origin='lower', **imshow_kwargs)
    ax.autoscale(False)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ax.set_title(title)

plt.show()


# Oversampled reconstruction: PSF-r compared with photutils.
# Both oversampled PSFs are multiplied by oversampling**2 before plotting. The log10 colour limits are
# the shared `vmin`/`vmax` from the configuration cell (previously they were re-assigned here, which
# would silently override any change made in the configuration cell).
psf_photutils_super = kernel_util.cut_psf(epsf_super.data, psf_size=len(psf_psfr_super))
psf_psfr_super_scaled = psf_psfr_super * oversampling**2
psf_photutils_super_scaled = psf_photutils_super * oversampling**2

# (image, title, extra imshow keyword arguments), drawn left to right
log_kwargs = {'vmin': vmin, 'vmax': vmax}
panels = [
    (np.log10(psf_psfr_super_scaled), 'PSF-r method', log_kwargs),
    (np.log10(psf_photutils_super_scaled), 'photutils method', log_kwargs),
    (psf_psfr_super_scaled - psf_photutils_super_scaled, 'PSF-r - photutils', {}),
]

f, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4), sharex=False, sharey=False)
for ax, (image, title, imshow_kwargs) in zip(axes, panels):
    im = ax.imshow(image, origin='lower', **imshow_kwargs)
    ax.autoscale(False)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ax.set_title(title)

plt.show()


## Testing
These tests are incomplete and need to be improved during packaging.

In [None]:
# Renamed: this notebook later defines a second `test_one_step_psf_estimation`, which shadowed this one.
def test_one_step_psf_estimation_oversampled():
    """A single oversampled PSF-r iteration should bring a too-wide Gaussian guess closer to the true PSF.

    One Gaussian star (sigma=1) displaced by (0, 4.5) pixels from the cutout center is used, starting
    from a Gaussian guess with sigma=1.5 sampled on a grid oversampled by a factor of 2.
    """
    from lenstronomy.LightModel.light_model import LightModel
    # one_step_psf_estimate is not imported anywhere else in this notebook; importing it here lets
    # the test run on a fresh kernel instead of failing with NameError
    from psfr.psfr import one_step_psf_estimate

    numpix = 21
    n_c = (numpix - 1) / 2
    x_grid, y_grid = util.make_grid(numPix=numpix, deltapix=1, left_lower=True)
    gauss = LightModel(['GAUSSIAN'])
    sigma = 1
    kwargs_true = [{'amp': 1, 'sigma': sigma, 'center_x': n_c, 'center_y': n_c}]
    kwargs_guess = [{'amp': 1, 'sigma': 1.5, 'center_x': n_c, 'center_y': n_c}]

    # a single star displaced by (x_c, y_c) pixels from the cutout center
    x_c, y_c = 0., 4.5
    center_list = [np.array([x_c, y_c])]
    kwargs_model = [{'amp': 1, 'sigma': sigma, 'center_x': n_c + x_c, 'center_y': n_c + y_c}]
    star_list = [util.array2image(gauss.surface_brightness(x_grid, y_grid, kwargs_model))]

    oversampling = 2
    numpix_super = numpix * oversampling
    if oversampling % 2 == 0:
        # keep the oversampled grid size odd (presumably so the kernel has a central pixel)
        numpix_super -= 1

    x_grid_super, y_grid_super = util.make_grid(numPix=numpix_super, deltapix=1. / oversampling, left_lower=True)
    psf_guess_super = util.array2image(gauss.surface_brightness(x_grid_super, y_grid_super, kwargs_guess))
    psf_guess_super /= np.sum(psf_guess_super)
    psf_true_super = util.array2image(gauss.surface_brightness(x_grid_super, y_grid_super, kwargs_true))
    psf_true_super /= np.sum(psf_true_super)

    psf_after_super = one_step_psf_estimate(star_list, psf_guess_super, center_list, mask_list=None,
                                            error_map_list=None, step_factor=0.2, oversampling=oversampling,
                                            verbose=True)

    # the updated PSF should be closer (summed squared difference) to the truth than the guess
    diff_after = np.sum((psf_after_super - psf_true_super)**2)
    diff_before = np.sum((psf_guess_super - psf_true_super)**2)

    for image, title in [(psf_true_super - psf_after_super, 'after'),
                         (psf_true_super - psf_guess_super, 'before'),
                         (psf_after_super - psf_guess_super, 'after - before')]:
        plt.imshow(image)
        plt.colorbar()
        plt.title(title)
        plt.show()

    assert diff_after < diff_before


test_one_step_psf_estimation_oversampled()

In [None]:
import numpy.testing as npt

def test_linear_amplitude():
    """`_linear_amplitude` should recover the scale factor between the data and a unit model.

    Checked once without a mask and once with an all-ones mask; both must return the true amplitude.
    """
    # _linear_amplitude is not imported anywhere else in this notebook; importing it here lets the
    # test run on a fresh kernel instead of failing with NameError
    from psfr.psfr import _linear_amplitude

    amp = 2
    data = np.ones((5, 5)) * amp
    model = np.ones((5, 5))

    amp_return = _linear_amplitude(data, model)
    npt.assert_almost_equal(amp_return, amp)

    mask = np.ones_like(data)
    amp_return = _linear_amplitude(data, model, mask=mask)
    npt.assert_almost_equal(amp_return, amp)


test_linear_amplitude()


def test_fit_centroid():
    """`_fit_centroid` should recover the sub-pixel offset of a shifted, rescaled Gaussian.

    The "data" is a Gaussian with amplitude 2 offset by (x_c, y_c) from the cutout center; the model
    is a unit-amplitude Gaussian at the center. The fitted center must match (x_c, y_c) to one decimal.
    """
    from lenstronomy.LightModel.light_model import LightModel
    # _fit_centroid is not imported anywhere else in this notebook; importing it here lets the test
    # run on a fresh kernel instead of failing with NameError
    from psfr.psfr import _fit_centroid

    numpix = 21
    n_c = (numpix - 1) / 2
    x_grid, y_grid = util.make_grid(numPix=numpix, deltapix=1, left_lower=True)
    gauss = LightModel(['GAUSSIAN'])
    x_c, y_c = -0.6, 0.2
    kwargs_true = [{'amp': 2, 'sigma': 1, 'center_x': n_c + x_c, 'center_y': n_c + y_c}]
    kwargs_model = [{'amp': 1, 'sigma': 1, 'center_x': n_c, 'center_y': n_c}]
    flux_true = util.array2image(gauss.surface_brightness(x_grid, y_grid, kwargs_true))
    flux_model = util.array2image(gauss.surface_brightness(x_grid, y_grid, kwargs_model))

    mask = np.ones_like(flux_true)

    center = _fit_centroid(flux_true, flux_model, mask=mask, variance=None)
    npt.assert_almost_equal(center[0], x_c, decimal=1)
    npt.assert_almost_equal(center[1], y_c, decimal=1)


test_fit_centroid()


def test_one_step_psf_estimation(seed=42):
    """One PSF-r iteration should move a too-wide Gaussian guess toward the true PSF.

    Four Gaussian stars (sigma=1), each displaced by a random sub-pixel offset, are used with a Gaussian
    guess of sigma=1.2. The check runs at native resolution and again with oversampling=2.

    :param seed: seed of the local random generator drawing the star offsets, so the test is reproducible
    """
    from lenstronomy.LightModel.light_model import LightModel
    # one_step_psf_estimate is not imported anywhere else in this notebook; importing it here lets
    # the test run on a fresh kernel instead of failing with NameError
    from psfr.psfr import one_step_psf_estimate

    # local generator: reproducible and does not touch numpy's global random state
    rng = np.random.default_rng(seed)

    numpix = 21
    n_c = (numpix - 1) / 2
    x_grid, y_grid = util.make_grid(numPix=numpix, deltapix=1, left_lower=True)
    gauss = LightModel(['GAUSSIAN'])
    sigma = 1
    kwargs_true = [{'amp': 1, 'sigma': sigma, 'center_x': n_c, 'center_y': n_c}]
    psf_true = util.array2image(gauss.surface_brightness(x_grid, y_grid, kwargs_true))
    psf_true /= np.sum(psf_true)

    kwargs_guess = [{'amp': 1, 'sigma': 1.2, 'center_x': n_c, 'center_y': n_c}]
    psf_guess = util.array2image(gauss.surface_brightness(x_grid, y_grid, kwargs_guess))
    psf_guess /= np.sum(psf_guess)

    center_list = []
    star_list = []
    displacement_scale = 1
    for _ in range(4):
        x_c = rng.uniform(-0.5, 0.5) * displacement_scale
        y_c = rng.uniform(-0.5, 0.5) * displacement_scale
        center_list.append(np.array([x_c, y_c]))
        kwargs_model = [{'amp': 1, 'sigma': sigma, 'center_x': n_c + x_c, 'center_y': n_c + y_c}]
        star_list.append(util.array2image(gauss.surface_brightness(x_grid, y_grid, kwargs_model)))

    # native resolution: the updated PSF should be closer to the truth than the guess
    psf_after = one_step_psf_estimate(star_list, psf_guess, center_list, mask_list=None,
                                      error_map_list=None, step_factor=0.2)
    diff_after = np.sum((psf_after - psf_true)**2)
    diff_before = np.sum((psf_guess - psf_true)**2)

    for image, title in [(psf_true - psf_after, 'after'),
                         (psf_true - psf_guess, 'before'),
                         (psf_after - psf_guess, 'after - before')]:
        plt.imshow(image)
        plt.colorbar()
        plt.title(title)
        plt.show()

    print(np.sum(psf_after), np.sum(psf_true))

    assert diff_after < diff_before

    # oversampled: same check on a grid oversampled by a factor of 2
    oversampling = 2
    numpix_super = numpix * oversampling
    if oversampling % 2 == 0:
        # keep the oversampled grid size odd (presumably so the kernel has a central pixel)
        numpix_super -= 1

    x_grid_super, y_grid_super = util.make_grid(numPix=numpix_super, deltapix=1. / oversampling, left_lower=True)
    psf_guess_super = util.array2image(gauss.surface_brightness(x_grid_super, y_grid_super, kwargs_guess))
    psf_guess_super /= np.sum(psf_guess_super)
    psf_true_super = util.array2image(gauss.surface_brightness(x_grid_super, y_grid_super, kwargs_true))
    psf_true_super /= np.sum(psf_true_super)

    psf_after_super = one_step_psf_estimate(star_list, psf_guess_super, center_list, mask_list=None,
                                            error_map_list=None, step_factor=0.2, oversampling=oversampling)
    diff_after = np.sum((psf_after_super - psf_true_super)**2)
    diff_before = np.sum((psf_guess_super - psf_true_super)**2)
    assert diff_after < diff_before


test_one_step_psf_estimation()


In [None]:
def test_stack_psf(seed=42):
    """`stack_psf` should recover a Gaussian PSF from shifted stars of random brightness.

    Ten Gaussian stars (sigma=1) with random sub-pixel offsets and amplitudes in [0.1, 10) are stacked.
    The check runs at native resolution and again with oversampling=4. In both cases the result must be
    closer to the true PSF than a Gaussian guess with sigma=1.5.

    :param seed: seed of the local random generator for offsets and amplitudes, so the test is reproducible
    """
    from lenstronomy.LightModel.light_model import LightModel
    # import locally so this test cell does not depend on an earlier cell having been run
    from psfr.psfr import stack_psf

    # local generator: reproducible and does not touch numpy's global random state
    rng = np.random.default_rng(seed)

    numpix = 21
    n_c = (numpix - 1) / 2
    x_grid, y_grid = util.make_grid(numPix=numpix, deltapix=1, left_lower=True)
    gauss = LightModel(['GAUSSIAN'])
    sigma = 1
    kwargs_true = [{'amp': 1, 'sigma': sigma, 'center_x': n_c, 'center_y': n_c}]
    psf_true = util.array2image(gauss.surface_brightness(x_grid, y_grid, kwargs_true))
    psf_true /= np.sum(psf_true)

    kwargs_guess = [{'amp': 1, 'sigma': 1.5, 'center_x': n_c, 'center_y': n_c}]
    psf_guess = util.array2image(gauss.surface_brightness(x_grid, y_grid, kwargs_guess))
    psf_guess /= np.sum(psf_guess)

    star_list = []
    scatter_scale = 1
    for _ in range(10):
        x_c = rng.uniform(-0.5, 0.5) * scatter_scale
        y_c = rng.uniform(-0.5, 0.5) * scatter_scale
        # random brightness per star; previously `np.random.uniform([0.1, 10])` returned a 2-element
        # array and the amplitude was never used (every star had amp=1)
        amp = rng.uniform(0.1, 10)
        kwargs_model = [{'amp': amp, 'sigma': sigma, 'center_x': n_c + x_c, 'center_y': n_c + y_c}]
        star_list.append(util.array2image(gauss.surface_brightness(x_grid, y_grid, kwargs_model)))

    # native resolution
    psf_after, _, _ = stack_psf(star_list, oversampling=1, saturation_limit=None, num_iteration=100, n_recenter=4)

    plt.imshow(psf_true - psf_after)
    plt.title('psf_true - psf_after')
    plt.colorbar()
    plt.show()
    plt.imshow(psf_true - psf_guess)
    plt.title('psf_true - psf_guess')
    plt.colorbar()
    plt.show()
    assert np.sum((psf_after - psf_true)**2) < np.sum((psf_guess - psf_true)**2)
    assert len(psf_after) == len(star_list[0])

    # oversampled by a factor of 4
    oversampling = 4
    numpix_super = numpix * oversampling
    if oversampling % 2 == 0:
        # keep the oversampled grid size odd (presumably so the kernel has a central pixel)
        numpix_super -= 1

    x_grid_super, y_grid_super = util.make_grid(numPix=numpix_super, deltapix=1. / oversampling, left_lower=True)
    psf_guess_super = util.array2image(gauss.surface_brightness(x_grid_super, y_grid_super, kwargs_guess))
    psf_guess_super /= np.sum(psf_guess_super)
    psf_true_super = util.array2image(gauss.surface_brightness(x_grid_super, y_grid_super, kwargs_true))
    psf_true_super /= np.sum(psf_true_super)

    psf_after_super, _, _ = stack_psf(star_list, oversampling=oversampling,
                                      saturation_limit=None, num_iteration=100,
                                      n_recenter=200, verbose=False, kwargs_one_step={'verbose': False})

    diff_after = np.sum((psf_after_super - psf_true_super)**2)
    diff_before = np.sum((psf_guess_super - psf_true_super)**2)

    plt.imshow(psf_true_super - psf_after_super)
    plt.title('psf_true - psf_after')
    plt.colorbar()
    plt.show()
    plt.imshow(psf_true_super - psf_guess_super)
    plt.title('psf_true - psf_guess')
    plt.colorbar()
    plt.show()
    plt.imshow(psf_after_super)
    plt.colorbar()
    plt.title('psf_after')
    plt.show()
    assert diff_after < diff_before


test_stack_psf()