In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, TwoSlopeNorm
import numpy as np
import os

from scipy.interpolate import RectBivariateSpline

import hcipy
import astropy.io.fits as fits

from pastis.simulators.scda_telescopes import HexRingAPLC
import pastis.util as util
from pastis.config import CONFIG_PASTIS

In [None]:
def _scaled_length(n_pixels, resize_factor):
    """Return the number of output samples for an axis of ``n_pixels`` scaled by ``resize_factor``.

    Truncates like ``int(n_pixels * resize_factor)``, except that products lying within
    floating-point round-off of an integer are snapped to that integer: e.g.
    ``49 * (64 / 49)`` evaluates to ``63.99999999999999`` and would otherwise lose a pixel.
    """
    scaled = n_pixels * resize_factor
    nearest = np.round(scaled)
    if np.isclose(scaled, nearest, rtol=1e-12, atol=0.0):
        return int(nearest)
    return int(scaled)


def resize_img(img, resize_factor, kx=3, ky=3):
    """Resize a 2-D image by ``resize_factor`` using rectangular bivariate spline interpolation.

    Adapted from
    https://github.com/spacetelescope/hicat-package2/blob/develop/hicat2/experiments/calibration/measure_ncpa.py

    Input pixel (i, j) is placed at coordinates (i * resize_factor, j * resize_factor) and the
    interpolating spline is evaluated on the integer output grid. Output samples beyond
    ``(n - 1) * resize_factor`` lie outside the input coordinate range; their values come
    from SciPy's out-of-range spline evaluation.

    :param img: (2d array) input image to be resized; any (rows, cols) shape
    :param resize_factor: (float) positive scaling factor applied to both axes
    :param kx: (int) degree of the spline along axis 0 (rows)
    :param ky: (int) degree of the spline along axis 1 (columns)
    :return: (2d array) resized image of shape
        (int(rows * resize_factor), int(cols * resize_factor)), where products within
        round-off of an integer are rounded instead of truncated
    :raises ValueError: if ``img`` is not 2-D or ``resize_factor`` is not positive
    """
    img = np.asarray(img)
    if img.ndim != 2:
        raise ValueError(f"img must be 2-D, got shape {img.shape}")
    if not resize_factor > 0:  # also rejects NaN
        raise ValueError(f"resize_factor must be positive, got {resize_factor}")

    n_rows, n_cols = img.shape
    # RectBivariateSpline(x, y, z) requires z.shape == (len(x), len(y)):
    # x must index rows (axis 0) and y columns (axis 1), otherwise non-square images fail.
    row_coords = np.arange(n_rows) * resize_factor
    col_coords = np.arange(n_cols) * resize_factor
    spline_interp = RectBivariateSpline(row_coords, col_coords, img, kx=kx, ky=ky)

    out_rows = np.arange(_scaled_length(n_rows, resize_factor))
    out_cols = np.arange(_scaled_length(n_cols, resize_factor))
    return spline_interp(out_rows, out_cols)

In [None]:
# Load the Ball-provided sample pupil map.
# The directory defaults to ~/Downloads (the author's original location); set the
# BALL_PUPIL_DIR environment variable to point elsewhere instead of editing the path.
ball_data_dir = os.environ.get('BALL_PUPIL_DIR', os.path.expanduser('~/Downloads'))
ball_pupil = fits.getdata(os.path.join(ball_data_dir, 'Sample_Pupil.fits'))
print("Shape of ball_pupil:", ball_pupil.shape)

plt.imshow(ball_pupil)
plt.title("Ball sample pupil (raw)")
plt.colorbar()

In [None]:
# Build the segmented-aperture APLC and compute its unaberrated PSF together with the
# reference PSF returned by calc_psf(ref=True).
optics_dir = os.path.join(util.find_repo_location(), 'data', 'SCDA')
NUM_RINGS = 2
sampling = 4

tel = HexRingAPLC(optics_dir, NUM_RINGS, sampling)

unaberrated_psf, ref, intermediates = tel.calc_psf(
    ref=True,
    display_intermediate=False,
    return_intermediate='intensity',
    norm_one_photon=True,
)

# Normalize by the peak of the reference PSF.
norm = np.max(ref)
normalized_unaberrated_psf = unaberrated_psf / norm

# Contrast floor = mean normalized intensity over the dark-hole pixels (dh_mask != 0).
unaberr_roi = normalized_unaberrated_psf * tel.dh_mask
in_dark_hole = tel.dh_mask != 0
contrast_floor = np.mean(unaberr_roi[in_dark_hole])
print("contrast_floor:", contrast_floor)

In [None]:
# Log10 views of the raw (un-normalized) unaberrated PSF and of the reference PSF
# returned by calc_psf(ref=True).
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title("Unaberrated PSF (log10)")
hcipy.imshow_field(np.log10(unaberrated_psf))
plt.colorbar()

plt.subplot(1, 2, 2)
plt.title("Reference PSF (log10)")
hcipy.imshow_field(np.log10(ref))
plt.colorbar()

In [None]:
# Entrance aperture of the HexRingAPLC telescope.
hcipy.imshow_field(tel.aperture)
plt.title("Telescope aperture")
plt.colorbar()

In [None]:
# The aperture Field is flat; recover the side length of its (assumed square) grid
# and view it as a 2-D array for comparison with the Ball pupil.
n_aperture_pixels = len(tel.aperture)
size = int(np.sqrt(n_aperture_pixels))
hex2_aperture = np.reshape(tel.aperture, (size, size))

In [None]:
# Replace NaNs with 0 so the spline interpolation below only sees finite values.
# nan=0.0 is nan_to_num's default; note it also maps +/-inf to the dtype's largest finite values.
ball_pupil = np.nan_to_num(ball_pupil, nan=0.0)
plt.imshow(ball_pupil)
plt.colorbar()

In [None]:
# Resample the Ball pupil onto the telescope pupil grid (size x size pixels).
# The scale factor is derived from axis 0 only, so the input must be square.
assert ball_pupil.shape[0] == ball_pupil.shape[1], (
    f"expected a square Ball pupil, got shape {ball_pupil.shape}")

scale = size / ball_pupil.shape[0]
ball_pupil_resized = resize_img(ball_pupil, scale, kx=3, ky=3)

print("New size of the pupil:", ball_pupil_resized.shape)
# Later cells reshape to (size, size); fail here (e.g. on a one-pixel float-truncation
# loss) rather than with an opaque reshape error downstream.
assert ball_pupil_resized.shape == (size, size), (
    f"resized pupil {ball_pupil_resized.shape} does not match telescope grid {(size, size)}")

In [None]:
# Overlap check: difference between the scaled telescope aperture and the resampled
# Ball pupil, to judge how well the two footprints line up.
# NOTE(review): the 1e6 factor on the aperture is unexplained — presumably to bring it
# to the Ball map's scale; confirm.
overlap_residual = 1e6 * hex2_aperture - ball_pupil_resized
plt.imshow(overlap_residual)
plt.colorbar()

In [None]:
# Snapshot of the unaberrated aperture-plane field. Without .copy() this name presumably
# aliases tel.wf_aper.electric_field, so the later in-place edit of its imaginary part
# would silently change what this variable (and any re-run of the plots below) shows.
pupil_field = tel.wf_aper.electric_field.copy()

In [None]:
# Amplitude |E| and phase arg(E) of the pupil-plane field.
# Plotting E.real / E.imag under these labels is only equivalent for a real,
# non-negative field, so compute the actual amplitude and phase.
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title("Amplitude")
hcipy.imshow_field(hcipy.field.Field(np.abs(pupil_field), pupil_field.grid))
plt.colorbar()

plt.subplot(1, 2, 2)
plt.title("Phase")
hcipy.imshow_field(hcipy.field.Field(np.angle(pupil_field), pupil_field.grid))
plt.colorbar()

In [None]:
# Flatten the resampled Ball OPD map (row-major) and wrap it as an hcipy Field on the
# same grid as the telescope's pupil field, so it can be applied to the wavefront later.
ball_pupil_1d = ball_pupil_resized.reshape(size**2)
ball_pupil_field = hcipy.field.Field(ball_pupil_1d, pupil_field.imag.grid)

In [None]:
# Orientation check: the hcipy rendering of the Field (left) against the 2-D array drawn
# with origin='lower' (right); they should look the same if flattening kept orientation.
plt.figure(figsize=(10, 5))

plt.subplot(121)
hcipy.imshow_field(ball_pupil_field)
plt.colorbar()

plt.subplot(122)
plt.imshow(ball_pupil_resized, origin='lower')
plt.colorbar()

In [None]:
# Check that flattening into the Field did not reorder pixels.
ball_pupil_field_2d = np.reshape(ball_pupil_field, (size, size))
print(ball_pupil_field_2d[13, 13], ball_pupil_resized[13, 13])
# A single sample on the diagonal cannot reveal a transpose (and may well be 0 in both
# if it lies outside the pupil); compare every pixel instead. This only verifies the
# reshape round-trip — orientation relative to the hcipy grid is checked visually above.
print("Same pixel ordering:", np.array_equal(ball_pupil_field_2d, ball_pupil_resized))

In [None]:
# Zero out small residual values (|value| <= threshold) of the interpolated OPD map.
# Vectorized: the per-element Python loop over ~size**2 pixels was very slow.
# NOTE(review): the threshold is in the Ball map's units — confirm 1e-1 is intended.
OPD_ZERO_THRESHOLD = 1e-1
ball_pupil_field[np.abs(ball_pupil_field) <= OPD_ZERO_THRESHOLD] = 0.0

# NOTE(review): ball_pupil_field was built from a reshape of ball_pupil_resized and may
# share its memory, in which case the thresholding above is also visible in this array.
plt.imshow(ball_pupil_resized)
plt.colorbar()

ball_data_dir = os.environ.get('BALL_PUPIL_DIR', os.path.expanduser('~/Downloads'))
# overwrite=True so re-running the notebook does not fail because the file already exists.
fits.writeto(os.path.join(ball_data_dir, 'Sample_Pupil_resized2.fits'), ball_pupil_resized,
             overwrite=True)

In [None]:
# (A stray no-op `print()` debugging cell was removed here; nothing to run.)

In [None]:
# Number of pixels in the flattened Ball OPD Field (expected: size**2).
ball_pupil_field.shape[0]

In [None]:
# Write the scaled Ball OPD map into the imaginary part of the aperture-plane field;
# the real part is left untouched and any previous imaginary part is replaced.
# NOTE(review): this is a first-order stand-in for a phase screen
# (E*exp(i*phi) ~ E + i*E*phi), matching E + i*phi only where E == 1; the 1e-3 scale
# (and hence phi's units) should be confirmed against the Ball map's units.
magnitude = 1e-3
tel.wf_aper.electric_field.imag = ball_pupil_field * magnitude

In [None]:
# Imaginary part of the aperture-plane field after the Ball OPD was written into it.
aperture_field_imag = tel.wf_aper.electric_field.imag
hcipy.imshow_field(aperture_field_imag)
plt.colorbar()

In [None]:
# Propagate through the coronagraph again, now with the modified aperture field.
# NOTE(review): assumes calc_psf starts from tel.wf_aper (modified above) rather than
# rebuilding the aperture wavefront — confirm in HexRingAPLC.calc_psf.
psf_aber, intermediates_aber = tel.calc_psf(
    display_intermediate=False,
    return_intermediate='intensity',
    norm_one_photon=True,
)

In [None]:
# Pixels passed by the focal-plane mask (1) vs. blocked (0), used to hide the blocked
# region in the "after fpm" panel.
# NOTE(review): the previous loop sized the mask from `intermediates_aber` but read values
# from the unaberrated `intermediates`; the aberrated run is now used consistently.
# NOTE(review): 'after_fpm' is noted below as log-scaled; if so, a value of exactly 0.0
# means unit intensity rather than "blocked" — confirm what calc_psf stores here.
fpm_mask = (np.asarray(intermediates_aber['after_fpm']) != 0.0).astype(float)


def _show_train_panel(position, title, field, **imshow_kwargs):
    """Draw panel `position` of the 2x3 optical-train figure: title, image, no ticks, colorbar."""
    plt.subplot(2, 3, position)
    plt.title(title)
    hcipy.imshow_field(field, **imshow_kwargs)
    plt.tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)
    plt.colorbar()


plt.figure(figsize=(10, 5))

# Entrance pupil: imaginary part of the field, which was set from the Ball OPD map.
_show_train_panel(1, "Phase (Im E)", tel.wf_aper.electric_field.imag,
                  mask=tel.aperture, cmap='inferno')

# Before the focal-plane mask.
_show_train_panel(2, "before fpm", intermediates_aber['before_fpm'],
                  norm=LogNorm(vmin=1e-8, vmax=1e-1), cmap='inferno')

# After the focal-plane mask; calc_psf returns int_after_fpm in log scale, hence 10**(...).
_show_train_panel(3, "after fpm", 10**(intermediates_aber['after_fpm']),
                  norm=LogNorm(vmin=1e-8, vmax=1e-1), mask=fpm_mask, cmap='inferno')

# Before the Lyot stop.
_show_train_panel(4, "before lyot", intermediates_aber['before_lyot'],
                  norm=LogNorm(vmin=1e-3, vmax=1), cmap='inferno')

# After the Lyot stop, masked to the Lyot-stop footprint.
_show_train_panel(5, "after lyot", intermediates_aber['after_lyot'],
                  mask=tel.lyotstop, norm=LogNorm(vmin=1e-3, vmax=1), cmap='inferno')

# Final aberrated PSF.
_show_train_panel(6, "aberrated PSF", psf_aber,
                  norm=LogNorm(vmin=1e-14, vmax=1e-3), cmap='inferno')

plt.tight_layout()

In [None]:
# Mean normalized intensity over the dark hole for the aberrated PSF, using the same
# normalization (`norm`, peak of the reference PSF) as the unaberrated contrast floor.
normalized_aberrated_psf = psf_aber / norm

roi_aber = normalized_aberrated_psf * tel.dh_mask
contrast_aber = np.mean(roi_aber[tel.dh_mask != 0])
print("contrast_aber:", contrast_aber)