## This notebook contains the end-to-end (e2e) analysis used to allocate tolerances for each $\color{red}{\text{Global Zernike Aberration}}$ mode of a segmented telescope.

In [None]:
import os
os.chdir("/Users/asahoo/repos/PASTIS")
import time
from shutil import copy
from astropy.io import fits
import astropy.units as u
import hcipy
import numpy as np
import pastis.util as util    
from pastis.config import CONFIG_PASTIS 
from pastis.e2e_simulators.luvoir_imaging import LuvoirA_APLC 
from pastis.e2e_simulators.generic_segmented_telescopes import SegmentedAPLC
import matplotlib.pyplot as plt
import pandas as pd
from scipy.interpolate import griddata
import exoscene.image
import exoscene.star
import exoscene.planet
from exoscene.planet import Planet
from astropy.io import fits as pf
from matplotlib.colors import TwoSlopeNorm
import matplotlib.gridspec as gridspec
from pastis.analytical_pastis.temporal_analysis import req_closedloop_calc_batch

In [None]:
# Coronagraph design label; it is passed to the LuvoirA_APLC simulator and is
# also used to build the name of the output directory.
coronagraph_design = 'small'

# Simulation parameters read from the 'LUVOIR' section of the PASTIS config.
sampling = CONFIG_PASTIS.getfloat('LUVOIR', 'sampling')
nb_seg = CONFIG_PASTIS.getint('LUVOIR', 'nb_subapertures')

# Calibration aberration scaled by 1e-9; it is printed later with unit "m",
# so the configured value is presumably given in nanometres (TODO confirm).
nm_aber = CONFIG_PASTIS.getfloat(
    'LUVOIR', 'calibration_aberration'
) * 1e-9

In [None]:
# Hard-coded, machine-specific locations; adapt these before running elsewhere.
data_dir = "/Users/asahoo/Documents/ultra/global_zernike"
repo_dir = "/Users/asahoo/repos/PASTIS"

# Run directory created by the PASTIS helper, plus the sub-directory that
# receives the matrices and E-field FITS files written by this notebook.
overall_dir = util.create_data_path(
    data_dir,
    telescope=f'luvoir_{coronagraph_design}',
)
resDir = os.path.join(overall_dir, 'matrix_numerical')

# exist_ok=True makes re-running this cell harmless.
os.makedirs(resDir, exist_ok=True)

In [None]:
# Location of the optics files: the repository root joined with the relative
# path stored in the 'LUVOIR' section of the PASTIS configuration.
optics_input = os.path.join(
    util.find_repo_location(),
    CONFIG_PASTIS.get('LUVOIR', 'optics_path_in_repo'),
)

# End-to-end simulator object used by every following cell.
luvoir = LuvoirA_APLC(optics_input, coronagraph_design, sampling)

In [None]:
# Create the global Zernike mirror on the simulator and read back how many
# actuators it exposes; n_LO sizes every mode vector and loop below.
# NOTE(review): max_LO is presumably the number of global Zernike modes to
# create - confirm against the PASTIS simulator documentation.
max_LO = 20
luvoir.create_global_zernike_mirror(max_LO)
n_LO = luvoir.zernike_mirror.num_actuators

In [None]:
# Command an all-zero actuator vector on the global Zernike mirror, i.e. no
# aberration is applied on any of the n_LO modes.
LO_modes = np.zeros(n_LO)
luvoir.zernike_mirror.actuators = LO_modes

In [None]:
# Flatten the global Zernike mirror before computing the reference PSF
# (presumably resets all actuators to zero - confirm against hcipy).
luvoir.zernike_mirror.flatten()

In [None]:
# Coronagraphic image and reference image for the flattened mirror; both are
# used in the next cell for the normalisation factor and the contrast floor.
# display_intermediate=True presumably also plots the intermediate planes.
unaberrated_coro_psf, ref = luvoir.calc_psf(ref=True, display_intermediate=True, norm_one_photon=True)


In [None]:
# Normalise the coronagraphic image by the peak of the reference image and
# average it over the dark-hole pixels to obtain the static contrast floor.
norm = np.max(ref)
dh_intensity = (unaberrated_coro_psf / norm) * luvoir.dh_mask
# Only pixels where the dark-hole mask is non-zero enter the mean.
contrast_floor = np.mean(dh_intensity[np.where(luvoir.dh_mask != 0)])
# Label spelling fixed ("constrast" -> "contrast").
print(f'norm: {norm}', f'contrast floor: {contrast_floor}')

In [None]:
# With return_intermediate='efield' the first returned object exposes an
# .electric_field attribute; keep the unaberrated focal-plane E-field as the
# reference that is subtracted from every aberrated field below.
nonaberrated_coro_psf, ref, efield = luvoir.calc_psf(ref=True, display_intermediate=False, return_intermediate='efield',norm_one_photon=True)
Efield_ref = nonaberrated_coro_psf.electric_field

In [None]:
# LO_modes = np.zeros(n_LO)
# LO_modes[3] = 100*(nm_aber)/2
# luvoir.zernike_mirror.actuators  = LO_modes
# aberrated_coro_psf, ref2 = luvoir.calc_psf(ref=True, display_intermediate=True)

# dh_intensity_aberrated = (aberrated_coro_psf/ norm) * luvoir.dh_mask
# aberrated_contrast = np.mean(dh_intensity_aberrated[np.where(luvoir.dh_mask != 0)])
# print(f'contrast floor: {aberrated_contrast}')

In [None]:
# Build the science-plane E-field response of every global Zernike mode: poke
# one mode at a time and store the E-field object returned by the simulator,
# together with its real and imaginary parts.
print('Generating the E-fields for low order zernike modes in science plane')
print(f'Calibration aberration used: {nm_aber} m')

start_time = time.time()
focus_fieldS, focus_fieldS_Re, focus_fieldS_Im = [], [], []

# Actuator index 0 is skipped, so n_LO - 1 modes are probed.
for i in range(1, n_LO):
    print(f'Working on global zernike mode: {i}')

    # Only mode i is poked, with half the calibration aberration.
    LO_modes = np.zeros(n_LO)
    LO_modes[i] = nm_aber / 2
    luvoir.zernike_mirror.actuators = LO_modes

    # With return_intermediate='efield' the first return value carries the
    # complex field (it exposes .real, .imag and .electric_field).
    aberrated_coro_psf, inter = luvoir.calc_psf(
        display_intermediate=False,
        return_intermediate='efield',
        norm_one_photon=True,
    )
    focus_field1 = aberrated_coro_psf
    focus_fieldS.append(focus_field1)
    focus_fieldS_Re.append(focus_field1.real)
    focus_fieldS_Im.append(focus_field1.imag)
    

In [None]:
# Display the first stored entry (actuator index 1) as a quick sanity check.
focus_fieldS[0]

In [None]:
# Assemble the matrix for the global Zernike modes: entry (i, j) is the mean,
# over the dark-hole pixels, of Re[dE_i * conj(dE_j)] / norm, where dE_k is the
# E-field of poked mode k minus the unaberrated reference E-field.
n_modes = n_LO - 1
dh_pixels = np.where(luvoir.dh_mask != 0)

# The differential fields do not depend on the pair (i, j); compute them once.
delta_fields = [field.electric_field - Efield_ref for field in focus_fieldS]

mat_LO = np.zeros([n_modes, n_modes])
for i in range(n_modes):
    for j in range(n_modes):
        test = np.real(delta_fields[i] * np.conj(delta_fields[j]))
        dh_test = (test / norm) * luvoir.dh_mask
        contrast = np.mean(dh_test[dh_pixels])
        mat_LO[i, j] = contrast

In [None]:
# Sanity check: the matrix was allocated as (n_LO - 1) x (n_LO - 1).
mat_LO.shape

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Show the matrix as an image, with a colour bar and enlarged fonts.
plt.figure(figsize=(10, 8))
plt.imshow(mat_LO)
plt.title(r"PASTIS matrix $M$ for global zernike", fontsize=20)
for set_axis_label in (plt.xlabel, plt.ylabel):
    # Rows and columns are both indexed by mode.
    set_axis_label("Mode Index", fontsize=20)
plt.tick_params(labelsize=15)

cbar = plt.colorbar(fraction=0.046, pad=0.04)
cbar.set_label(r"in units of $1/{nm^2}$", fontsize=15)
plt.tight_layout()

In [None]:
# Write the matrix and the real/imaginary parts of the science-plane E-fields
# to FITS files in resDir, and report where each file went.
#
# Each output path is built once and reused for both the write and the
# message. The trailing '\n' is passed to print() as its own argument; it used
# to be handed to os.path.join(), which appended a bogus "/\n" component to
# the reported path.
filename_matrix1 = 'PASTISmatrix_n_LO_' + str(n_LO)
matrix_path = os.path.join(resDir, filename_matrix1 + '.fits')
hcipy.write_fits(mat_LO, matrix_path)
print('Matrix saved to:', matrix_path, '\n')

filename_matrix2 = 'EFIELD_Re_matrix_n_LO_' + str(n_LO)
efield_re_path = os.path.join(resDir, filename_matrix2 + '.fits')
hcipy.write_fits(focus_fieldS_Re, efield_re_path)
print('Efield Real saved to:', efield_re_path, '\n')

filename_matrix3 = 'EFIELD_Im_matrix_n_LO_' + str(n_LO)
efield_im_path = os.path.join(resDir, filename_matrix3 + '.fits')
hcipy.write_fits(focus_fieldS_Im, efield_im_path)
print('Efield Imag saved to:', efield_im_path, '\n')

In [None]:
# Eigen-decomposition of the mode matrix, sorted by increasing eigenvalue.
#
# mat_LO is real and symmetric by construction (its entries are means of
# Re[dE_i * conj(dE_j)], which is unchanged when i and j are swapped), so the
# symmetric solver is used: it guarantees real eigenvalues and orthonormal
# eigenvectors, which the general-purpose eig() does not.
evals, evecs = np.linalg.eigh(mat_LO)

# One argsort drives both sorted outputs so they cannot get out of step.
sorted_indices = np.argsort(evals)
sorted_evals = evals[sorted_indices]
sorted_evecs = evecs[:, sorted_indices]

In [None]:
# Target contrast expressed through its base-10 exponent, plus a repeat count.
c_target_log = -11
n_repeat = 20

# 10 raised to a negative integer power yields a float (here 1e-11).
c_target = 10 ** c_target_log

In [None]:
# Per-mode tolerance: the target contrast is split evenly over the n_LO - 1
# probed modes and divided by each mode's diagonal matrix entry.
contrast_per_mode = c_target / (n_LO - 1)
mu_map_LO = np.sqrt(contrast_per_mode / np.diag(mat_LO))
# Optional export, currently disabled:
# np.savetxt('/Users/asahoo/Documents/ultra/segment_zernike/mu_map_zernike_1e-11.csv', mu_map_LO, delimiter=',')

In [None]:
# Down-sampled pupil grid on which the out-of-band E-field is defined.
z_pup_downsample = CONFIG_PASTIS.getfloat('numerical', 'z_pup_downsample')
N_pup_z = int(luvoir.pupil_grid.shape[0] / z_pup_downsample)
grid_zernike = hcipy.field.make_pupil_grid(N_pup_z, diameter=luvoir.diam)

# Side lengths (in pixels) of the pupil and focal-plane grids, recovered from
# the length of the flattened coordinate arrays (assumes square grids).
npup = int(np.sqrt(luvoir.pupil_grid.x.shape[0]))
nimg = int(np.sqrt(luvoir.focal_det.x.shape[0]))

# Stellar and detector parameters used for the photon budget.
sptype, Vmag = 'A0V', 5.0
minlam, maxlam = 500, 600
dark_current, CIC = 0, 0

# Stellar photon rate in ph/s/m^2 for the chosen spectral type and band.
star_flux = exoscene.star.bpgs_spectype_to_photonrate(
    spectype=sptype,
    Vmag=Vmag,
    minlam=minlam,
    maxlam=maxlam,
)
# NOTE(review): 15**2 is presumably the squared telescope diameter in metres;
# confirm, and consider deriving it from luvoir.diam.
Nph = star_flux.value*15**2*np.sum(luvoir.apodizer**2) / npup**2

In [None]:
# Flatten the Zernike mirror again (the E-field loop above leaves the last
# mode poked) and recompute the unaberrated reference E-field.
luvoir.zernike_mirror.flatten()
# The second and third return values are not used in the cells shown here.
nonaberrated_coro_psf ,refshit ,inter_ref = luvoir.calc_psf(ref=True, display_intermediate=False, return_intermediate='efield',norm_one_photon=True)
Efield_ref = nonaberrated_coro_psf.electric_field

In [None]:
# Reference (unaberrated) field on the out-of-band wavefront sensor, binned
# down onto grid_zernike.
luvoir.zernike_mirror.flatten()
defocus_ref2 = luvoir.calc_out_of_band_wfs(norm_one_photon=True)

# Real and imaginary parts are subsampled separately, then recombined.
defocus_ref2_sub_real = hcipy.field.subsample_field(
    defocus_ref2.real, z_pup_downsample, grid_zernike, statistic='mean'
)
defocus_ref2_sub_imag = hcipy.field.subsample_field(
    defocus_ref2.imag, z_pup_downsample, grid_zernike, statistic='mean'
)

# The binned field is rescaled by the down-sampling factor; the same factor is
# applied to the per-mode fields when G_OBWFS is built.
Efield_ref_OBWFS = (defocus_ref2_sub_real + 1j*defocus_ref2_sub_imag) * z_pup_downsample

In [None]:
nyquist_sampling = 2.

# Keyword arguments shared by both focal grids.
focal_grid_kwargs = dict(
    pupil_diameter=luvoir.diam,
    focal_length=1,
    reference_wavelength=luvoir.wvln,
)

# Grid matching the simulator's own sampling and field of view.
grid_test = hcipy.make_focal_grid(
    luvoir.sampling,
    luvoir.imlamD,
    **focal_grid_kwargs,
)

# Grid sampled at nyquist_sampling, with the field of view rounded down to a
# whole number of lambda/D.
grid_det_subsample = hcipy.make_focal_grid(
    nyquist_sampling,
    np.floor(luvoir.imlamD),
    **focal_grid_kwargs,
)

# Side length of the Nyquist grid in pixels (assumes a square grid).
n_nyquist = int(np.sqrt(grid_det_subsample.x.shape[0]))

In [None]:
design = 'small'

# Annular dark-hole mask on the Nyquist-sampled grid: a disc out to the outer
# working angle minus a disc out to the inner working angle. circular_aperture
# is given a diameter, hence the factor of 2.
apod_params = luvoir.apod_dict[design]
dh_outer_nyquist = hcipy.circular_aperture(2 * apod_params['owa'] * luvoir.lam_over_d)(grid_det_subsample)
dh_inner_nyquist = hcipy.circular_aperture(2 * apod_params['iwa'] * luvoir.lam_over_d)(grid_det_subsample)
dh_mask_nyquist = (dh_outer_nyquist - dh_inner_nyquist).astype('bool')

# Flat indices of the dark-hole pixels on both grids, and their counts.
dh_index = np.where(luvoir.dh_mask != 0)[0]
dh_index_nyquist = np.where(dh_mask_nyquist != 0)[0]
dh_size = len(dh_index)
dh_size_nyquist = len(dh_index_nyquist)

In [None]:
# Reference OBWFS field as a real array of shape (pixels, 1, 2); the last axis
# holds the real part (index 0) and the imaginary part (index 1).
E0_OBWFS = np.zeros((N_pup_z ** 2, 1, 2))
for part_index, part in enumerate((Efield_ref_OBWFS.real, Efield_ref_OBWFS.imag)):
    E0_OBWFS[:, 0, part_index] = part

In [None]:
# Reference coronagraph field as a real array of shape (pixels, 1, 2); the
# last axis holds the real part (index 0) and the imaginary part (index 1).
E0_coron = np.zeros((nimg ** 2, 1, 2))
for part_index, part in enumerate((Efield_ref.real, Efield_ref.imag)):
    E0_coron[:, 0, part_index] = part

In [None]:
# Read back the science-plane E-fields (real and imaginary parts) that were
# written to the 'matrix_numerical' directory earlier in the notebook.
matrix_dir = os.path.join(overall_dir, 'matrix_numerical')

filename_matrix2 = 'EFIELD_Re_matrix_n_LO_' + str(n_LO) + '.fits'
filename_matrix3 = 'EFIELD_Im_matrix_n_LO_' + str(n_LO) + '.fits'

G_zernike_real = fits.getdata(os.path.join(matrix_dir, filename_matrix2))
G_zernike_imag = fits.getdata(os.path.join(matrix_dir, filename_matrix3))

In [None]:
# Differential science-plane response per mode, shape (pixels, 2, modes):
# axis 1 holds the real (0) and imaginary (1) parts of the aberrated field
# minus the reference field.
n_probed_modes = n_LO - 1
G_coron_zernike = np.zeros((nimg ** 2, 2, n_probed_modes))
for pp in range(n_probed_modes):
    G_coron_zernike[:, 0, pp] = G_zernike_real[pp] - Efield_ref.real
    G_coron_zernike[:, 1, pp] = G_zernike_imag[pp] - Efield_ref.imag

In [None]:
# Build the out-of-band WFS response of every global Zernike mode: poke one
# mode at a time and store the subsampled complex detector field.
# Note: this reuses (and overwrites) the focus_fieldS* lists from the
# science-plane loop.
start_time = time.time()
focus_fieldS, focus_fieldS_Re, focus_fieldS_Im = [], [], []

# Actuator index 0 is skipped, so n_LO - 1 modes are probed.
for i in range(1, n_LO):
    # Only mode i is poked, with half the calibration aberration.
    LO_modes = np.zeros(n_LO)
    LO_modes[i] = nm_aber / 2
    luvoir.zernike_mirror.actuators = LO_modes

    # Field on the OBWFS detector, binned onto grid_zernike; the real and
    # imaginary parts are subsampled separately and then recombined.
    zernike_meas = luvoir.calc_out_of_band_wfs(norm_one_photon=True)
    zernike_meas_sub_real = hcipy.field.subsample_field(
        zernike_meas.real, z_pup_downsample, grid_zernike, statistic='mean'
    )
    zernike_meas_sub_imag = hcipy.field.subsample_field(
        zernike_meas.imag, z_pup_downsample, grid_zernike, statistic='mean'
    )
    focus_field1 = zernike_meas_sub_real + 1j * zernike_meas_sub_imag

    focus_fieldS.append(focus_field1)
    focus_fieldS_Re.append(focus_field1.real)
    focus_fieldS_Im.append(focus_field1.imag)

In [None]:
# Write the real and imaginary parts of the OBWFS E-fields to FITS files in
# resDir and report each destination. The real part is written first, so
# filename_matrix ends up holding the imaginary-part file stem.
obwfs_outputs = (
    ('EFIELD_OBWFS_Re_matrix_num_LO_', 'Efield Real saved to:', focus_fieldS_Re),
    ('EFIELD_OBWFS_Im_matrix_num_LO_', 'Efield Imag saved to:', focus_fieldS_Im),
)
for file_prefix, message, field_parts in obwfs_outputs:
    filename_matrix = file_prefix + str(n_LO)
    hcipy.write_fits(field_parts, os.path.join(resDir, filename_matrix + '.fits'))
    print(message, os.path.join(resDir, filename_matrix + '.fits'))

In [None]:
# Read back the OBWFS E-fields (real and imaginary parts) from the
# 'matrix_numerical' directory.
matrix_dir = os.path.join(overall_dir, 'matrix_numerical')

filename_matrix = 'EFIELD_OBWFS_Re_matrix_num_LO_' + str(n_LO)+'.fits'
G_OBWFS_real = fits.getdata(os.path.join(matrix_dir, filename_matrix))

# filename_matrix is reused, so it ends up naming the imaginary-part file.
filename_matrix = 'EFIELD_OBWFS_Im_matrix_num_LO_' + str(n_LO)+'.fits'
G_OBWFS_imag = fits.getdata(os.path.join(matrix_dir, filename_matrix))

In [None]:
# Differential OBWFS response per mode, shape (pixels, 2, modes): axis 1 holds
# the real (0) and imaginary (1) parts. The stored fields are rescaled by the
# down-sampling factor, as was done for the reference field, before the
# reference is subtracted.
n_probed_modes = n_LO - 1
G_OBWFS = np.zeros((N_pup_z ** 2, 2, n_probed_modes))
for pp in range(n_probed_modes):
    G_OBWFS[:, 0, pp] = G_OBWFS_real[pp]*z_pup_downsample - Efield_ref_OBWFS.real
    G_OBWFS[:, 1, pp] = G_OBWFS_imag[pp]*z_pup_downsample - Efield_ref_OBWFS.imag

In [None]:
# Closed-loop batch computation: for every wavefront scale factor and every
# exposure time, run req_closedloop_calc_batch and keep the last entry of its
# 'averaged_hist' output. Results are appended wavescale-major, i.e. Ntimes
# consecutive entries of result_wf_test belong to one wavescale.
flux = Nph
# Diagonal matrix of the squared per-mode tolerances.
Q_LO = np.diag(np.asarray(mu_map_LO**2))

Ntimes = 20
TimeMinus = -2
TimePlus = 5.5 #3.5
Nwavescale = 8
Nflux = 3

# NOTE(review): `res` is never filled in this cell, and range(1, 15, 2) yields
# 7 wavescales rather than Nwavescale = 8 - confirm whether both are leftovers.
res = np.zeros([Ntimes, Nwavescale, Nflux, 1])
result_wf_test = []

# Loop-invariant settings, hoisted out of the loops.
niter = 10
StarMag = 0.0
Starfactor = 10**(-StarMag/2.5)

# The counters i and j now come from enumerate(). Their initialisers had been
# commented out, so the loops were reading stale values left behind by
# unrelated earlier cells (or raising NameError on a fresh kernel).
for i, wavescale in enumerate(range(1, 15, 2)):
    print('Harris modes with batch OBWFS and noise %f'% wavescale, "i",i)
    timer1 = time.time()
    for j, tscale in enumerate(np.logspace(TimeMinus, TimePlus, Ntimes)):
        print(tscale)
        # The same noise term is passed for both detector-noise arguments.
        detector_noise = dark_current + CIC/tscale
        tmp0 = req_closedloop_calc_batch(G_coron_zernike, G_OBWFS, E0_coron, E0_OBWFS, detector_noise,
                                         detector_noise, tscale, flux*Starfactor, 0.0001*wavescale**2*Q_LO,
                                         niter, luvoir.dh_mask, norm)
        tmp1 = tmp0['averaged_hist']
        n_tmp1 = len(tmp1)
        # Keep only the last value of the history.
        result_wf_test.append(tmp1[n_tmp1-1])

In [None]:
# Plot the contrast change (result minus static contrast floor) against the
# exposure time, one curve per wavefront scale factor.

# Legend value per wavescale: 1e3 * sqrt(0.0001 * wavescale**2), i.e. the
# square root of the scale factor applied to Q_LO in the closed-loop cell.
delta_wf = [1e3*np.sqrt(0.0001*wavescale**2) for wavescale in range(1, 15, 2)]

texp = np.logspace(TimeMinus, TimePlus, Ntimes)

font = {'family': 'serif','color' : 'black','weight': 'normal','size'  :  20}
plt.figure(figsize =(15,10))

plt.title('Target contrast = %s, Vmag= %s'%(c_target, Vmag),fontdict=font)
# result_wf_test stores Ntimes consecutive entries per wavescale, so curve k
# is the k-th slice of length Ntimes. Deriving the slice and the label from
# the same index fixes the second curve, which was labelled with delta_wf[0].
# '%.0f' rounds the label instead of truncating it as '%d' would.
for curve_index, wf in enumerate(delta_wf):
    first = curve_index * Ntimes
    curve = np.asarray(result_wf_test[first:first + Ntimes]) - contrast_floor
    plt.plot(texp, curve, label=r'$\Delta_{wf}=%.0f\ pm$' % wf)
plt.xlabel("$t_{WFS}$ in secs",fontsize=20)
# Raw string: "\D" is an invalid escape sequence in a normal string literal.
plt.ylabel(r"$\Delta$ contrast",fontsize=20)
plt.yscale('log')
plt.xscale('log')
plt.legend(loc = 'upper center',fontsize=20)
plt.tick_params(top=False, bottom=True, left=True,
                right=True,labelleft=True, labelbottom=True,
                labelsize=20)
plt.tick_params(axis='both',which='major',length=10, width=2)
plt.tick_params(axis='both',which='minor',length=6, width=2)
plt.grid()
#plt.savefig('/Users/asahoo/Documents/ultra/segment_zernike/zernike_cont_wf.png')
plt.show()

In [None]:
# Display the legend value computed for the second wavescale.
delta_wf[1]