In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt

from astropy.io import fits
from matplotlib.colors import LogNorm, TwoSlopeNorm
from matplotlib.ticker import LogFormatter 
import numpy as np
import scipy.io
import hcipy

import pastis.util as util
from pastis.config import CONFIG_PASTIS
from pastis.simulators.scda_telescopes import HexRingAPLC

os.chdir('../ULTRA')
from config import CONFIG_ULTRA

In [None]:
# Folders for the raw input tables and for the analysis products, both read
# from the 'local_path' section of the ULTRA configuration.
data_path, analysis_path = (
    CONFIG_ULTRA.get('local_path', option)
    for option in ('local_data_path', 'local_analysis_path')
)

# Semicolon-delimited misalignment tables: the input misalignments are bound
# to `open_loop`, the residual misalignments to `close_loop`.
open_loop, close_loop = (
    np.genfromtxt(os.path.join(data_path, table_name), delimiter=';')
    for table_name in ('INPUT_MISALIGNMENTS.txt', 'RESIDUAL_MISALIGNMENTS.txt')
)

In [None]:
# Simulator inputs: optics files live under data/SCDA of the PASTIS repository.
optics_dir = os.path.join(util.find_repo_location(), 'data', 'SCDA')
NUM_RINGS = 2
sampling = 4

# Wavelengths to propagate: 475, 500 and 525 (converted with 1e-9 when passed to calc_psf).
wvls = np.arange(475, 530, 25)

# One entry per wavelength, filled in the loop below.
coron_psfs, direct_psfs, contrast_floors = [], [], []

for wvl in wvls:
    # NOTE(review): the simulator is rebuilt on every pass; the instance from
    # the last wavelength is the `tel2` that the following cells keep using.
    tel2 = HexRingAPLC(optics_dir, NUM_RINGS, sampling)

    unaberrated_psf, ref, intermediates = tel2.calc_psf(
        ref=True,
        display_intermediate=True,
        return_intermediate='intensity',
        norm_one_photon=True,
        wv=wvl*1e-9,
    )

    # Normalise the coronagraphic image by the peak of the reference image.
    norm = np.max(ref)
    normalized_unaberrated_psf = unaberrated_psf / norm

    # Mean of the masked image over the pixels where the dark-hole mask is non-zero.
    dark_hole = np.where(tel2.dh_mask != 0)
    unaberr_roi = normalized_unaberrated_psf * tel2.dh_mask
    contrast_floor = np.mean(unaberr_roi[dark_hole])

    for collection, item in (
        (coron_psfs, normalized_unaberrated_psf),
        (direct_psfs, ref),
        (contrast_floors, contrast_floor),
    ):
        collection.append(item)

    print( "wavelength:", wvl, "contrast_floor:", contrast_floor, "central wavelength:", tel2.wvln)

In [None]:
# Average of the per-wavelength dark-hole contrast floors collected so far.
mean_broadband_contrast = np.mean(np.array(contrast_floors))
print("Mean broadband contrast:", mean_broadband_contrast)

### Plot broadband / monochromatic PSF

In [None]:
# Sum of the normalised coronagraphic PSFs over all wavelengths, accumulated
# in list order starting from a zeroed copy of the first PSF.
broadband_psf = 0 * coron_psfs[0]
for single_wavelength_psf in coron_psfs:
    broadband_psf = broadband_psf + single_wavelength_psf

# Tick positions/labels for the (currently disabled) lambda/D axis labelling below.
x_pixels = [0, 20 ,40, 60, 80, 100]
x_ld = [r"-14.25$\lambda/D$",r"-9.25$\lambda/D$", r"-4.25$\lambda/D$", r"0.75$\lambda/D$",
        r"5.75$\lambda/D$", r"10.75$\lambda/D$"]

# Dividing the sum by the number of PSFs shows the band-averaged image.
band_averaged_psf = broadband_psf / (len(coron_psfs))

plt.figure()
plt.title("Broadband PSF")
hcipy.imshow_field(band_averaged_psf, norm=LogNorm(vmin=1e-14, vmax=1e-1), cmap='inferno')
plt.colorbar()
# plt.xticks(x_pixels, x_ld, fontsize = 8)
# plt.yticks(x_pixels, x_ld, fontsize = 8)
# plt.xlabel(r"Angular Separation $(\lambda_{0}/D)$")
# plt.ylabel(r"Angular Separation $(\lambda_{0}/D)$")
plt.tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)
plt.tight_layout()

#plt.savefig(os.path.join(analysis_path, f'broadband_psf.png'))

### Plot mean DH contrast floor vs wavelength

In [None]:
# Dark-hole contrast floor of each monochromatic propagation, on a log axis.
fig, ax = plt.subplots()
ax.set_title("Monochromatic Contrasts floors vs wavelength")
ax.plot(wvls, contrast_floors)
ax.set_ylabel("Mean DH Contrast")
ax.set_xlabel('Wavelength (in nm)')
ax.set_yscale('log')

#plt.savefig(os.path.join(analysis_path, 'contrast_floors_wavelength.png'))

### Create segmented zernike mirror

In [None]:
# Number of modes per segment handed to the simulator's segmented mirror.
# Later cells address an actuator as `mode + segment * n_zernikes`, so this
# value also fixes the layout of the actuator vector.
n_zernikes = 11
tel2.create_segmented_mirror(n_zernikes)

#### Set the type of data you want to analyze, i.e. open-loop (`open_loop`) or closed-loop (`close_loop`).

In [None]:
# Misalignment table propagated by the cells below: `open_loop` holds the input
# misalignments, `close_loop` the residual ones (both loaded from data_path).
# NOTE(review): output file names further down contain "openloop" regardless
# of this choice - rename them when switching to `close_loop`.
time_series = open_loop

#### Get pre-saved sensitivity OPDs corresponding to 1 micron dx, dy, dz, rdx, rdy, rdz

In [None]:
# Pre-saved sensitivities, one row per rigid-body degree of freedom, scaled by 0.5.
# NOTE(review): the 0.5 factor presumably converts OPD into mirror-surface
# actuator commands (reflection doubles the path) - confirm.
sensitivity_file = os.path.join(analysis_path, 'optical_sensitvity.fits')
tel2_actuators = 0.5 * fits.getdata(sensitivity_file)

#### Plot OPDs corresponding to PTT via SCDA simulator.
The following plots, including colorbars and RMS values, should closely match the OPD pictures from Garrett for the PTT case. If they do not match, there may be an error in the sorting of the segment numbers or of the Zernike coefficients. The simulator uses the Noll convention.

In [None]:
# Panel titles, indexed by degree of freedom (0..5).
titles = np.array(["X nm RMS/um", "Y nm RMS/um", "Z nm RMS/um",
                  "Rx nm RMS/urad", "Ry nm RMS/urad", "Rz nm RMS/urad" ])

# Panels that need a fixed colour range (in nm); all others autoscale.
color_limits = {2: {'vmin': -2000, 'vmax': -1960}}

# Pixels inside the aperture, over which the RMS of the OPD is evaluated.
pupil = np.where(tel2.aperture != 0)

plt.figure(figsize = (14, 9))
# The Z panel (dof 2) is drawn last, so the mirror is left holding the Z sensitivity.
for dof in (0, 1, 3, 4, 5, 2):
    # Load the sensitivity for this degree of freedom onto a flat mirror.
    tel2.sm.flatten()
    tel2.sm.actuators = tel2_actuators[dof]
    rms_scda = np.sqrt(np.mean((tel2.sm.opd[pupil])**2)) * 1e9

    plt.subplot(2, 3, dof + 1)
    plt.title(titles[dof], fontweight = 'bold')
    hcipy.imshow_field(tel2.sm.opd*1e9, cmap='jet', **color_limits.get(dof, {}))
    plt.xlabel(f'RMS: {rms_scda:.2f} nm')
    cbar = plt.colorbar()
    cbar.set_label("in nm", loc='center')

### Map HWO segment number to the SCDA segment number:

In [None]:
# SCDA segment number for each HWO segment, listed in HWO order (HWO segment 1 first).
scda_ids_in_hwo_order = [4, 5, 6, 7, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19, 8, 9, 10, 11, 1]

# Lookup keyed by the HWO segment number as a string ("1".."19").
hwo_to_scda = {
    str(hwo_number): scda_number
    for hwo_number, scda_number in enumerate(scda_ids_in_hwo_order, start=1)
}

# Same mapping as an array: entry i is the SCDA number of HWO segment i + 1.
hwo_scda_segments = np.array(scda_ids_in_hwo_order)

#### Load coefficients from .mat file

In [None]:
# Read the HWO sensitivity structure from the MATLAB file and unwrap its single record.
hwo_mat_file = os.path.join(data_path, 'HWO_sens_old.mat')
hwo_sensitivities = scipy.io.loadmat(hwo_mat_file)
hwo_table = hwo_sensitivities['HWO_sens']
first_element = hwo_table[0,0]
data_list = first_element.tolist()  # tuple of length 4 (per the original notes)

# Fields of the record, in stored order. Shapes quoted from the original
# author's notes, not checked here:
#   hwo_hexike_coeffs : (11, 19, 6)
#   mask              : (256, 256, 19, 6)
#   dopd              : (256, 256, 19, 6)
#   units             : (1,), a single string
hwo_hexike_coeffs, mask, dopd, units = (data_list[field] for field in range(4))

### Propagate time-series at a single wavelength (475, 500 or 525 nm; set `wvl` in the cell below)

In [None]:
# Wavelength (nm) of this propagation; it also ends up in the output file name.
wvl = 525

N_TIME_STEPS = 401   # number of time steps read from the table
N_DOF = 6            # columns per time step: dx, dy, dz, rdx, rdy, rdz


def _dof_time_series(dof_index):
    """Return one array per time step holding a single degree of freedom.

    For time step ``t`` the column ``t * N_DOF + dof_index`` of ``time_series``
    is scaled by 1e12 and its last row is dropped.
    NOTE(review): the original comment calls the scaled values picometres,
    which presumes the table is stored in metres - confirm.
    """
    history = []
    for step in range(N_TIME_STEPS):
        column = time_series[:, step * N_DOF + dof_index] * 1e12
        history.append(np.delete(column, -1))
    return history


tel2.sm.flatten()

dx_times = _dof_time_series(0)
dy_times = _dof_time_series(1)
dz_times = _dof_time_series(2)
rdx_times = _dof_time_series(3)
rdy_times = _dof_time_series(4)
rdz_times = _dof_time_series(5)

# Same order as the rows of tel2_actuators.
dof_times = (dx_times, dy_times, dz_times, rdx_times, rdy_times, rdz_times)

contrasts_dof = []
for step in range(N_TIME_STEPS):
    tel2.sm.flatten()
    for hwo_seg in range(tel2.nseg):
        # Table rows are in HWO order; the mirror is addressed by 0-based SCDA segment.
        scda_seg = hwo_to_scda[str(hwo_seg + 1)] - 1
        for hexike in range(n_zernikes):
            actuator = hexike + scda_seg * n_zernikes
            # Sensitivity x motion x 1e-6 for each degree of freedom
            # (presumably pm -> um to match per-micron sensitivities - TODO confirm).
            contributions = [
                tel2_actuators[dof][actuator] * dof_times[dof][step][hwo_seg] * 1e-6
                for dof in range(N_DOF)
            ]
            # Left-to-right sum: dx + dy + dz + rdx + rdy + rdz.
            tel2.sm.actuators[actuator] = sum(contributions[1:], contributions[0])

    aberrated_psf, ref, intermediates = tel2.calc_psf(
        ref=True,
        display_intermediate=False,
        return_intermediate='intensity',
        norm_one_photon=True,
        wv=wvl*1e-9,
    )
    # Normalise by the peak of the reference image.
    norm = np.max(ref)
    normalized_aberrated_psf = aberrated_psf / norm

    # Mean of the masked image over the pixels where the dark-hole mask is non-zero.
    dark_hole = np.where(tel2.dh_mask != 0)
    aberr_roi = normalized_aberrated_psf * tel2.dh_mask
    aber_contrast_floor = np.mean(aberr_roi[dark_hole])
    contrasts_dof.append(aber_contrast_floor)

    print("contrast_floor:", aber_contrast_floor)

In [None]:
# Store the contrast time series of the current wavelength as a CSV file.
# NOTE(review): the file name always says "openloop", whichever table
# `time_series` points to.
contrast_csv_path = os.path.join(analysis_path, f'contrast_dof_new_openloop_{wvl :.0f}.csv')
np.savetxt(contrast_csv_path, contrasts_dof, delimiter=',')

In [None]:
# Per-wavelength contrast time series written by the propagation cell
# (one CSV per value of `wvl`).
contrasts_dof_475 = np.genfromtxt(os.path.join(analysis_path, 'contrast_dof_new_openloop_475.csv'), delimiter=',')
contrasts_dof_500 = np.genfromtxt(os.path.join(analysis_path, 'contrast_dof_new_openloop_500.csv'), delimiter=',')
contrasts_dof_525 = np.genfromtxt(os.path.join(analysis_path, 'contrast_dof_new_openloop_525.csv'), delimiter=',')

# Broadband contrast = plain average of the three monochromatic series.
broadband_contrast = (contrasts_dof_475 + contrasts_dof_500 + contrasts_dof_525) / 3

# Time axis sized from the data instead of a hard-coded 401 samples.
times = np.arange(np.size(broadband_contrast))

plt.figure(figsize=(20, 10))

# Each series is paired with its wavelength and the legend label is derived
# from that pairing, so data and label cannot get out of sync (the 475 nm and
# 500 nm labels were previously swapped).
monochromatic_curves = (
    (475, contrasts_dof_475, 's'),
    (500, contrasts_dof_500, '^'),
    (525, contrasts_dof_525, 'p'),
)
for wavelength_nm, contrast_series, marker in monochromatic_curves:
    plt.plot(times, contrast_series, marker=marker, markersize=4, linewidth=2,
             label=f'{wavelength_nm} nm')

plt.plot(times, broadband_contrast,  marker='s', markersize=8, linewidth=2, label='broadband' )


#plt.plot(times, contrasts_dof_500 -  broadband_contrast, marker='s', markersize=8, linewidth=2, label='difference')
# plt.plot(times, contrasts_dof_500_new - contrasts_dof_500, label="subtraction")

# NOTE(review): hard-coded reference level; presumably an unaberrated contrast
# floor printed by an earlier run - confirm which wavelength/band it belongs to.
plt.axhline(y=4.171337358217274e-11, color='r', linestyle='--', label='contrast_floor')
#plt.ylim(4.1e-11, 5.25e-11)
plt.ylabel("Mean DH Contrast", fontsize=25)
plt.xlabel("Time (in s)", fontsize=25)
plt.legend(fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
#plt.yscale('log')
plt.tight_layout()
plt.savefig(os.path.join(analysis_path, 'openloop_broadband_timeseries.png'))

In [None]:


# Reference (unaberrated) PSF at the current `wvl`.
# The segmented mirror may still carry actuator commands set by an earlier
# cell; flatten it so that the PSF and contrast floor below really are the
# unaberrated ones.
tel2.sm.flatten()

unaberrated_psf, ref, intermediates = tel2.calc_psf(ref=True, display_intermediate=False,
                                  return_intermediate='intensity',
                                  norm_one_photon=True, wv=wvl*1e-9)

# Normalise by the peak of the reference image.
norm = np.max(ref)
normalized_unaberrated_psf = unaberrated_psf / norm

# This cell used to append to an undefined list `coron_psf` (NameError on a
# fresh kernel) and to `contrast_floors`. The per-wavelength PSFs and floors
# are already collected in `coron_psfs` / `contrast_floors` by the wavelength
# loop, so nothing is appended here: those lists stay aligned with `wvls` and
# this cell can be re-run safely.
unaberr_roi = normalized_unaberrated_psf * tel2.dh_mask
contrast_floor = np.mean(unaberr_roi[np.where(tel2.dh_mask != 0)])
print("contrast_floor:", contrast_floor, "central wavelength:", tel2.wvln)

# overwrite=True: without it the cell fails as soon as the file exists from a previous run.
fits.writeto(os.path.join(analysis_path, f'ref_{wvl}.fits'), ref.shaped, overwrite=True)

# Display mask for the after-FPM plane: 0 where the intermediate is exactly
# zero, 1 everywhere else (vectorised form of the former element-wise loop).
fpm_mask = (np.asarray(intermediates['after_fpm']) != 0.0).astype(float)

plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)

plt.title("After FPM")
plt.xlabel(r"in $\lambda_{0}/D$")
# NOTE(review): the intermediate is raised as a power of ten before plotting,
# which presumes calc_psf returns log10 intensity for this plane - confirm.
hcipy.imshow_field(10**(intermediates['after_fpm']),
                   norm=LogNorm(vmin=1e-14, vmax=1e-1),mask = fpm_mask, cmap= 'inferno')
plt.tick_params(bottom=True, left=True, labelleft=True, labelbottom=True)
plt.colorbar()

plt.subplot(1, 2, 2)
plt.title(f"PSF at {wvl} nm ")
plt.xlabel(r"in $\lambda_{0}/D$")
hcipy.imshow_field((normalized_unaberrated_psf),norm=LogNorm(vmin=1e-14, vmax=1e-1), cmap='inferno')
plt.colorbar()
plt.tight_layout()

#plt.savefig(os.path.join(analysis_path, f'fpm_n_psf_{wvl}.png'))

### Test to see whether wavelength at all planes remains same. 

In [None]:
# Propagate once at 600 nm, asking for the electric field in every plane.
psf, ref, inter = tel2.calc_psf(
    ref=True,
    display_intermediate=False,
    return_intermediate='efield',
    norm_one_photon=True,
    wv=600*1e-9,
)

# Print the wavelength carried by the wavefront in each intermediate plane;
# all of them are expected to be identical.
for plane in ('active_pupil', 'seg_mirror', 'zernike_mirror', 'dm',
              'before_fpm', 'before_lyot', 'after_lyot'):
    print(inter[plane].wavelength)

In [None]:
# Display the simulator's `sampling` attribute.
tel2.sampling

In [None]:
# Ratio of the simulator's wavelength to its diameter
# (presumably lambda/D in radians - TODO confirm units of wvln and diam).
tel2.wvln/tel2.diam

In [None]:
# Copy every per-wavelength coronagraphic PSF and rescale its grid by
# 1 / (500 nm / D) (presumably radians -> lambda/D at 500 nm - TODO confirm).
# The PSFs come from `coron_psfs`, the list filled by the wavelength loop;
# the previously referenced `coron_psf` is never defined in this notebook and
# raised a NameError on a fresh kernel. Iterating over the list also removes
# the hard-coded count of three wavelengths.
scaled_coron_psf = []
for psf_at_wavelength in coron_psfs:
    copy_psf = psf_at_wavelength.copy()
    copy_psf.grid = copy_psf.grid.scaled(1/(500 * 1e-9 / tel2.diam))
    scaled_coron_psf.append(copy_psf)


In [None]:
# Rescale the grid of a copy of the latest normalised PSF by 1 / (500 nm / D)
# (presumably radians -> lambda/D at 500 nm - TODO confirm).
grid_scale = 1/(500 * 1e-9 / tel2.diam)
copy_psf = normalized_unaberrated_psf.copy()
copy_psf.grid = copy_psf.grid.scaled(grid_scale)

hcipy.imshow_field(np.log10(copy_psf), cmap='inferno')

# NOTE(review): the file name uses the simulator's `wvln`, not the `wvl` the
# PSF was computed at, so PSFs of different wavelengths map to the same file.
scaled_psf_path = os.path.join(analysis_path, f'int_coron_{tel2.wvln *1e9 :.0f}.fits')
hcipy.write_fits(copy_psf, scaled_psf_path)

In [None]:
# copy_psf = normalized_unaberrated_psf.copy()
# copy_psf.grid = copy_psf.grid.scaled(1/(700 * 1e-9 / tel2.diam))

# plt.figure(figsize = (12, 5))
# r1, profile1, _, _ = hcipy.radial_profile(copy_psf, 0.15, statistic = 'mean')
# plt.plot(r1, np.log10(profile1))
# plt.show()