# This notebook generates both static and temporal segment tolerances using segment-level defocus aberrations only

### Import the necessary Python libraries and pre-built PASTIS functions

In [None]:
import os
os.chdir("/Users/asahoo/repos/PASTIS")
import time
from shutil import copy
from astropy.io import fits
import astropy.units as u
import hcipy
import numpy as np
import pastis.util as util    
from pastis.config import CONFIG_PASTIS 
from pastis.e2e_simulators.luvoir_imaging import LuvoirA_APLC 
from pastis.e2e_simulators.generic_segmented_telescopes import SegmentedAPLC
import matplotlib.pyplot as plt
import pandas as pd
from scipy.interpolate import griddata
import exoscene.image
import exoscene.star
import exoscene.planet
from exoscene.planet import Planet
from astropy.io import fits as pf

### Set initial parameters or read them from the config file

In [None]:
# Coronagraph design and simulation parameters taken from the PASTIS config.
coronagraph_design = 'small'
sampling = CONFIG_PASTIS.getfloat('LUVOIR', 'sampling')
nb_seg = CONFIG_PASTIS.getint('LUVOIR', 'nb_subapertures')
# The config value is scaled by 1e-9 (presumably nm -> m; it is printed in metres later).
nm_aber = 1e-9 * CONFIG_PASTIS.getfloat('LUVOIR', 'calibration_aberration')

### Define and create the output directories

In [None]:
# Machine-specific locations; edit these before running on another computer.
data_dir = "/Users/asahoo/Desktop/data_repos/harris_data"
repo_dir = "/Users/asahoo/repos/PASTIS"

# Run directory for this telescope/design; numerical matrices go to <run dir>/matrix_numerical.
overall_dir = util.create_data_path(data_dir, telescope=f'luvoir_{coronagraph_design}')
resDir = os.path.join(overall_dir, 'matrix_numerical')
os.makedirs(resDir, exist_ok=True)

### Instantiate LUVOIR-A 

In [None]:
# Locate the LUVOIR optics files inside the PASTIS repository, then build the
# LUVOIR-A APLC simulator for the selected coronagraph design and sampling.
repo_root = util.find_repo_location()
optics_input = os.path.join(repo_root, CONFIG_PASTIS.get('LUVOIR', 'optics_path_in_repo'))
luvoir = LuvoirA_APLC(optics_input, coronagraph_design, sampling)

### Create the segment-level defocus mirror

In [None]:
# Attach a segmented mirror (luvoir.sm) to the simulator. NOTE(review): the argument 4 is
# presumably the number of Zernike modes per segment (piston, tip, tilt, defocus) — confirm
# against LuvoirA_APLC.create_segmented_mirror; if so, later loops over all actuators poke
# every mode, not only defocus.
luvoir.create_segmented_mirror(4) 


In [None]:
# Display the segmented-mirror object as cell output.
luvoir.sm
# Total number of mirror actuators; each one is poked individually in the loops below.
n_MID = luvoir.sm.num_actuators

### Flatten the DM 

In [None]:
# Flatten the segmented mirror (presumably zeroes all actuators) before the reference PSF.
luvoir.sm.flatten()

### Calculate the unaberrated coronagraphic PSF

In [None]:
# Coronagraphic PSF with a flat mirror, and the direct reference PSF (ref=True).
unaberrated_coro_psf, ref = luvoir.calc_psf(
    ref=True,
    display_intermediate=True,
    norm_one_photon=True,
)

### Calculate the peak value of the reference PSF and the static coronagraphic contrast floor

In [None]:
# Peak of the direct reference PSF: divides intensities to express them as contrast.
norm = np.max(ref)
# Normalized coronagraphic image, masked to the dark hole.
dh_intensity = (unaberrated_coro_psf / norm) * luvoir.dh_mask
dh_pixels = np.where(luvoir.dh_mask != 0)
# Static contrast floor = mean normalized intensity over the dark-hole pixels.
contrast_floor = np.mean(dh_intensity[dh_pixels])
print(f'norm: {norm}', f'constrast floor: {contrast_floor}')

### Poke each segment with a Zernike defocus

In [None]:
# Reference (flat-mirror) coronagraphic result; its electric field is the baseline that
# every poked-mode field is compared against.
nonaberrated_coro_psf, ref, efield = luvoir.calc_psf(
    ref=True,
    display_intermediate=False,
    return_intermediate='efield',
    norm_one_photon=True,
)
Efield_ref = nonaberrated_coro_psf.electric_field

In [None]:
print('Generating the E-fields for harris modes in science plane')
print(f'Calibration aberration used: {nm_aber} m')

start_time = time.time()
focus_fieldS, focus_fieldS_Re, focus_fieldS_Im = [], [], []

# Poke one actuator at a time with half the calibration aberration and keep the
# resulting coronagraphic output (complex, plus real/imag parts for FITS export).
# The mirror is left with the last poke applied; later cells flatten it again.
for i in range(n_MID):
    print(f'Working on "bulk" thermal mode, segment: {i}')
    poke = np.zeros(n_MID)
    poke[i] = nm_aber / 2
    luvoir.sm.actuators = poke
    aberrated_coro_psf, inter = luvoir.calc_psf(
        display_intermediate=False,
        return_intermediate='efield',
        norm_one_photon=True,
    )
    focus_field1 = aberrated_coro_psf
    focus_fieldS.append(focus_field1)
    focus_fieldS_Re.append(focus_field1.real)
    focus_fieldS_Im.append(focus_field1.imag)
    
    

In [None]:
# PASTIS matrix: M[i, j] = mean over dark-hole pixels of Re{dE_i * conj(dE_j)} * mask / norm,
# where dE_k is the science-plane E-field change produced by poking actuator k.
# Vectorized: one (n_MID x n_dh) array of field deltas and a single matrix product replace
# the n_MID**2 Python loop that rebuilt full-frame products for every pair.
dh_pix = np.where(luvoir.dh_mask != 0)[0]
dh_weight = np.asarray(luvoir.dh_mask, dtype=float)[dh_pix]  # mask values (1 for a boolean mask)
ref_dh = np.asarray(Efield_ref)[dh_pix]
delta_E = np.array([np.asarray(focus_fieldS[k].electric_field)[dh_pix] - ref_dh for k in range(n_MID)])
# Re{sum_p w_p dE_i,p conj(dE_j,p)} == sum_p w_p Re{dE_i,p conj(dE_j,p)}.
mat_bulk = np.real((delta_E * dh_weight) @ delta_E.conj().T) / norm / dh_pix.size

In [None]:
# Log-scale view of |M| (absolute value because off-diagonal terms can be negative).
plt.figure(figsize=(10, 8))
plt.imshow(np.log10(np.abs(mat_bulk)))
plt.title(r"PASTIS matrix $M$ for defocus zernike", fontsize=20)
for set_axis_label in (plt.xlabel, plt.ylabel):
    set_axis_label("Mode Index", fontsize=20)
plt.tick_params(labelsize=15)
plt.colorbar()
plt.tight_layout()

In [None]:
# Save the PASTIS matrix and the per-mode science-plane E-fields (real and imaginary parts
# in separate FITS files). The trailing '\n' is a separate print() argument: inside
# os.path.join it would be treated as an extra path component and appended to the path.
filename_matrix1 = 'PASTISmatrix_n_harris_' + str(n_MID)
path_matrix1 = os.path.join(resDir, filename_matrix1 + '.fits')
hcipy.write_fits(mat_bulk, path_matrix1)
print('Matrix saved to:', path_matrix1, '\n')

filename_matrix2 = 'EFIELD_Re_matrix_n_harris_' + str(n_MID)
path_matrix2 = os.path.join(resDir, filename_matrix2 + '.fits')
hcipy.write_fits(focus_fieldS_Re, path_matrix2)
print('Efield Real saved to:', path_matrix2, '\n')

filename_matrix3 = 'EFIELD_Im_matrix_n_harris_' + str(n_MID)
path_matrix3 = os.path.join(resDir, filename_matrix3 + '.fits')
hcipy.write_fits(focus_fieldS_Im, path_matrix3)
print('Efield Imag saved to:', path_matrix3, '\n')

In [None]:
# mat_bulk is real and symmetric (up to rounding): M[i, j] and M[j, i] are real parts of
# conjugate products. The Hermitian solver guarantees real eigenvalues and orthonormal
# eigenvectors. The general np.linalg.eig can return spurious complex parts for
# (near-)degenerate eigenvalues. eigh reads only one triangle and returns ascending order.
evals, evecs = np.linalg.eigh(mat_bulk)
sorted_indices = np.argsort(evals)
sorted_evals = evals[sorted_indices]
sorted_evecs = evecs[:, sorted_indices]

In [None]:
# Target mean dark-hole contrast used to derive the per-mode tolerances.
c_target_log = -10
c_target = 10 ** c_target_log
n_repeat = 20

In [None]:
# IPython introspection of c_target (notebook-only syntax, not valid in a plain .py file).
c_target?

In [None]:
# Static per-mode tolerance: share the contrast budget equally over the n_MID modes and
# divide by each mode's own sensitivity (diagonal of the PASTIS matrix).
# NOTE(review): mat_bulk was built from pokes of nm_aber/2 and is not divided by that
# amplitude, so these tolerances are in units of the poke — confirm intended scaling.
per_mode_budget = c_target / n_MID
mu_map_defous = np.sqrt(per_mode_budget / np.diag(mat_bulk))

In [None]:
# Plot the static tolerance for each mode index.
plt.plot(mu_map_defous)

In [None]:
# Out-of-band WFS (OBWFS) sampling: pupil-plane fields are downsampled by
# z_pup_downsample onto grid_zernike (N_pup_z pixels across).
z_pup_downsample = CONFIG_PASTIS.getfloat('numerical', 'z_pup_downsample')
N_pup_z = int(luvoir.pupil_grid.shape[0] / z_pup_downsample)  # e.g. 100; grid of the out-of-band E-field
grid_zernike = hcipy.field.make_pupil_grid(N_pup_z, diameter=luvoir.diam)

# Linear sizes (pixels) of the square pupil and focal-plane detector grids.
npup = int(np.sqrt(luvoir.pupil_grid.x.shape[0]))
nimg = int(np.sqrt(luvoir.focal_det.x.shape[0]))

# Stellar photon rate in the [minlam, maxlam] band (presumably nm) and detector noise terms.
sptype, Vmag = 'A0V', 5.0
minlam, maxlam = 500, 600
dark_current, CIC = 0, 0
star_flux = exoscene.star.bpgs_spectype_to_photonrate(
    spectype=sptype, Vmag=Vmag, minlam=minlam, maxlam=maxlam
)  # ph/s/m^2
# Photon rate through the apodizer. NOTE(review): 15**2 presumably stands for the 15 m
# telescope diameter squared — confirm the intended collecting-area scaling.
Nph = star_flux.value*15**2*np.sum(luvoir.apodizer**2) / npup**2

In [None]:
# Re-derive the flat-mirror coronagraphic E-field. The direct reference PSF (second return
# value) is not needed here, so it goes into an explicitly throwaway name.
luvoir.sm.flatten()
nonaberrated_coro_psf, _direct_ref, inter_ref = luvoir.calc_psf(
    ref=True,
    display_intermediate=False,
    return_intermediate='efield',
    norm_one_photon=True,
)
Efield_ref = nonaberrated_coro_psf.electric_field

In [None]:
# Flat-mirror field on the out-of-band WFS detector, downsampled onto grid_zernike
# (real and imaginary parts separately, mean statistic) and rescaled by z_pup_downsample.
luvoir.sm.flatten()
defocus_ref2 = luvoir.calc_out_of_band_wfs(norm_one_photon=True)  # wavefront on the OBWFS detector
defocus_ref2_sub_real, defocus_ref2_sub_imag = (
    hcipy.field.subsample_field(part, z_pup_downsample, grid_zernike, statistic='mean')
    for part in (defocus_ref2.real, defocus_ref2.imag)
)
Efield_ref_OBWFS = (defocus_ref2_sub_real + 1j*defocus_ref2_sub_imag) * z_pup_downsample

In [None]:
# Sampling used for the Nyquist-sampled science images.
nyquist_sampling = 2.


def _luvoir_focal_grid(q, num_airy):
    """Build a LUVOIR focal grid with hcipy sampling ``q`` and extent ``num_airy``."""
    return hcipy.make_focal_grid(
        q,
        num_airy,
        pupil_diameter=luvoir.diam,
        focal_length=1,
        reference_wavelength=luvoir.wvln,
    )


# Grid of the LUVOIR images at their native sampling.
grid_test = _luvoir_focal_grid(luvoir.sampling, luvoir.imlamD)
# Same extent (rounded down) at Nyquist sampling.
grid_det_subsample = _luvoir_focal_grid(nyquist_sampling, np.floor(luvoir.imlamD))
n_nyquist = int(np.sqrt(grid_det_subsample.x.shape[0]))

In [None]:
# Dark-hole annulus (IWA..OWA of the selected apodizer) on the Nyquist grid.
# Tied to coronagraph_design so this mask always matches the simulated design.
design = coronagraph_design

dh_outer_nyquist = hcipy.circular_aperture(2 * luvoir.apod_dict[design]['owa'] * luvoir.lam_over_d)(grid_det_subsample)
dh_inner_nyquist = hcipy.circular_aperture(2 * luvoir.apod_dict[design]['iwa'] * luvoir.lam_over_d)(grid_det_subsample)
dh_mask_nyquist = (dh_outer_nyquist - dh_inner_nyquist).astype('bool')

# Flat indices of the dark-hole pixels (native and Nyquist grids), computed once each.
dh_index = np.where(luvoir.dh_mask != 0)[0]
dh_index_nyquist = np.where(dh_mask_nyquist != 0)[0]
dh_size = len(dh_index)
dh_size_nyquist = len(dh_index_nyquist)

In [None]:
# Reference E-fields in the (pixel, 1, re/im) layout that req_closedloop_calc_batch adds
# the mode-induced fields to (as its E0sensor / E0coro arguments).
E0_OBWFS = np.zeros([N_pup_z*N_pup_z, 1, 2])
E0_OBWFS[:, 0, 0] = Efield_ref_OBWFS.real
E0_OBWFS[:, 0, 1] = Efield_ref_OBWFS.imag

# Static science-plane field: the baseline on which the drifting modes act.
E0_coron = np.zeros([nimg*nimg, 1, 2])
E0_coron[:, 0, 0] = Efield_ref.real
E0_coron[:, 0, 1] = Efield_ref.imag

# Reference field linearly interpolated onto the Nyquist grid.
# NOTE(review): the (sampling ratio)**2 factor is applied to the field, so intensity scales
# by its square — confirm this is the intended flux normalization.
tmp0 = hcipy.interpolation.make_linear_interpolator_separated(Efield_ref, grid=grid_test)
Efield_ref_nyquist = (luvoir.sampling/nyquist_sampling)**2*tmp0(grid_det_subsample)

E0_coron_nyquist = np.zeros([n_nyquist*n_nyquist, 1, 2])
E0_coron_nyquist[:, 0, 0] = Efield_ref_nyquist.real
E0_coron_nyquist[:, 0, 1] = Efield_ref_nyquist.imag

E0_coron_DH = np.zeros([dh_size, 1, 2])
E0_coron_DH[:, 0, 0] = Efield_ref.real[dh_index]
E0_coron_DH[:, 0, 1] = Efield_ref.imag[dh_index]

E0_coron_DH_nyquist = np.zeros([dh_size_nyquist, 1, 2])
E0_coron_DH_nyquist[:, 0, 0] = Efield_ref_nyquist.real[dh_index_nyquist]
# Imaginary slot takes the imaginary part (it previously duplicated the real part).
E0_coron_DH_nyquist[:, 0, 1] = Efield_ref_nyquist.imag[dh_index_nyquist]

In [None]:
# Reload the per-mode science-plane E-fields and build the Nyquist-sampled sensitivity
# G[:, re/im, mode] = (interpolated mode field) - (interpolated reference field).
filename_matrix2 = 'EFIELD_Re_matrix_n_harris_' + str(n_MID) + '.fits'
G_harris_real = fits.getdata(os.path.join(overall_dir, 'matrix_numerical', filename_matrix2))
filename_matrix3 = 'EFIELD_Im_matrix_n_harris_' + str(n_MID) + '.fits'
G_harris_imag = fits.getdata(os.path.join(overall_dir, 'matrix_numerical', filename_matrix3))

G_coron_harris_nyquist = np.zeros([n_nyquist*n_nyquist, 2, n_MID])
for pp in range(n_MID):
    tmp0 = G_harris_real[pp] + 1j*G_harris_imag[pp]
    tmp1 = hcipy.interpolation.make_linear_interpolator_separated(tmp0, grid=grid_test)
    tmp2 = (luvoir.sampling/nyquist_sampling)**2*tmp1(grid_det_subsample)
    G_coron_harris_nyquist[:, 0, pp] = tmp2.real - Efield_ref_nyquist.real
    # Imaginary part of the mode field minus imaginary part of the reference
    # (tmp2.imag, not tmp2.real).
    G_coron_harris_nyquist[:, 1, pp] = tmp2.imag - Efield_ref_nyquist.imag

In [None]:
# Science-plane sensitivity restricted to dark-hole pixels: [DH pixel, re/im, mode].
G_coron_harris_DH = np.zeros([dh_size, 2, n_MID])
ref_re_dh = Efield_ref.real[dh_index]
ref_im_dh = Efield_ref.imag[dh_index]
for mode in range(n_MID):
    G_coron_harris_DH[:, 0, mode] = G_harris_real[mode, dh_index] - ref_re_dh
    G_coron_harris_DH[:, 1, mode] = G_harris_imag[mode, dh_index] - ref_im_dh

In [None]:
# Nyquist-sampled sensitivity restricted to the dark hole: [DH pixel, re/im, mode].
G_coron_harris_DH_nyquist = np.zeros([dh_size_nyquist, 2, n_MID])
for pp in range(n_MID):
    tmp0 = G_harris_real[pp] + 1j*G_harris_imag[pp]
    tmp1 = hcipy.interpolation.make_linear_interpolator_separated(tmp0, grid=grid_test)
    tmp2 = (luvoir.sampling/nyquist_sampling)**2*tmp1(grid_det_subsample)
    # Column pp holds mode pp; indexing with pp-1 shifted every mode by one column and
    # wrapped mode 0 into the last column.
    G_coron_harris_DH_nyquist[:, 0, pp] = tmp2.real[dh_index_nyquist] - Efield_ref_nyquist.real[dh_index_nyquist]
    G_coron_harris_DH_nyquist[:, 1, pp] = tmp2.imag[dh_index_nyquist] - Efield_ref_nyquist.imag[dh_index_nyquist]

In [None]:
# Full-frame science-plane sensitivity: [pixel, re/im, mode] = mode field - reference field.
G_coron_harris = np.zeros([nimg*nimg, 2, n_MID])
ref_re, ref_im = Efield_ref.real, Efield_ref.imag
for mode in range(n_MID):
    G_coron_harris[:, 0, mode] = G_harris_real[mode] - ref_re
    G_coron_harris[:, 1, mode] = G_harris_imag[mode] - ref_im

In [None]:
# Restart the timer and empty the field accumulators before the OBWFS poke loop.
start_time = time.time()
focus_fieldS, focus_fieldS_Re, focus_fieldS_Im = [], [], []

In [None]:
# OBWFS sensitivity fields: poke each actuator in turn with nm_aber/2 and record the field
# on the out-of-band WFS detector, downsampled onto grid_zernike. The loop starts at 0 so
# that every actuator is measured and list index == actuator index, matching the
# science-plane fields (starting at 1 skipped actuator 0 and shifted every column).
for i in range(n_MID):
    print(f'Working on "defocus" zernike mode, segment: {i}')

    # Apply calibration aberration to the current actuator only
    sm_mode = np.zeros(n_MID)
    sm_mode[i] = (nm_aber)/2
    luvoir.sm.actuators = sm_mode
    harris_meas = luvoir.calc_out_of_band_wfs(norm_one_photon=True)
    harris_meas_sub_real = hcipy.field.subsample_field(harris_meas.real, z_pup_downsample, grid_zernike, statistic='mean')
    harris_meas_sub_imag = hcipy.field.subsample_field(harris_meas.imag, z_pup_downsample, grid_zernike, statistic='mean')
    focus_field1 = harris_meas_sub_real + 1j * harris_meas_sub_imag
    focus_fieldS.append(focus_field1)
    focus_fieldS_Re.append(focus_field1.real)
    focus_fieldS_Im.append(focus_field1.imag)

In [None]:
# Write the OBWFS fields (real and imaginary parts) to separate FITS files.
for part_tag, part_label, part_fields in (('Re', 'Real', focus_fieldS_Re), ('Im', 'Imag', focus_fieldS_Im)):
    filename_matrix = 'EFIELD_OBWFS_' + part_tag + '_matrix_num_harris_' + str(n_MID)
    out_path = os.path.join(resDir, filename_matrix + '.fits')
    hcipy.write_fits(part_fields, out_path)
    print(f'Efield {part_label} saved to:', out_path)

In [None]:
# Reload the OBWFS fields (real and imaginary parts) from the numerical-matrix directory.
matrix_dir = os.path.join(overall_dir, 'matrix_numerical')
filename_matrix = f'EFIELD_OBWFS_Re_matrix_num_harris_{n_MID}.fits'
G_OBWFS_real = fits.getdata(os.path.join(matrix_dir, filename_matrix))
filename_matrix = f'EFIELD_OBWFS_Im_matrix_num_harris_{n_MID}.fits'
G_OBWFS_imag = fits.getdata(os.path.join(matrix_dir, filename_matrix))

In [None]:
# Sanity check: number of OBWFS imaginary-part fields loaded (should equal n_MID, one per actuator).
len(G_OBWFS_imag)

In [None]:
# OBWFS sensitivity [pixel, re/im, mode]: downsampled mode field rescaled by
# z_pup_downsample (same scaling as Efield_ref_OBWFS) minus the reference field.
# Every actuator must be present so that column pp refers to the same actuator as
# column pp of G_coron_harris; a short list would silently shift the columns.
if len(G_OBWFS_real) != n_MID or len(G_OBWFS_imag) != n_MID:
    raise ValueError(
        f'Expected {n_MID} OBWFS mode fields, got {len(G_OBWFS_real)} (real) and '
        f'{len(G_OBWFS_imag)} (imag); regenerate them with one entry per actuator.'
    )

G_OBWFS = np.zeros([N_pup_z*N_pup_z, 2, n_MID])
for pp in range(n_MID):
    G_OBWFS[:, 0, pp] = G_OBWFS_real[pp]*z_pup_downsample - Efield_ref_OBWFS.real
    G_OBWFS[:, 1, pp] = G_OBWFS_imag[pp]*z_pup_downsample - Efield_ref_OBWFS.imag

In [None]:
def req_closedloop_calc_batch(Gcoro, Gsensor, E0coro, E0sensor, Dcoro, Dsensor, t_exp, flux, Q, Niter, dh_mask, norm):
    """Simulate a batch wavefront-sensing loop and track the mean dark-hole contrast.

    Each iteration draws a random modal error ``eps`` from N(0, P + Q * t_exp), where P is
    the estimate covariance left by the previous iteration (zero initially). The sensor
    information matrix for that error is formed and P is replaced by its pseudo-inverse.
    The coronagraphic intensity for the same error gives the iteration's contrast.

    Parameters
    ----------
    Gcoro : ndarray, shape (N_img, 2, r)
        Science-plane sensitivity (real/imag part per pixel, per mode).
    Gsensor : ndarray, shape (N, 2, r)
        Wavefront-sensor sensitivity.
    E0coro : ndarray, shape (N_img, 1, 2)
        Static science-plane field (real/imag).
    E0sensor : ndarray, shape (N, 1, 2)
        Static sensor field (real/imag).
    Dcoro, Dsensor : float
        Noise term added to the coronagraph / sensor intensities.
    t_exp : float
        Exposure time; scales the photon count and the drift covariance Q.
    flux : float
        Photon rate.
    Q : ndarray, shape (r, r)
        Modal drift covariance (multiplied by t_exp).
    Niter : int
        Number of iterations.
    dh_mask : array-like, shape (N_img,)
        Dark-hole mask; contrast is averaged over the pixels where it is non-zero.
    norm : float
        Peak of the direct PSF, converting intensity to contrast.

    Returns
    -------
    dict
        'intensity_WFS_hist', 'cal_I_hist', 'eps_hist', 'averaged_hist' (running mean of
        the contrasts) as arrays of length Niter, and 'contrasts' as a list.
    """
    P = np.zeros(Q.shape)  # WFE modes covariance estimate
    r = Gsensor.shape[2]
    N = Gsensor.shape[0]
    N_img = Gcoro.shape[0]
    c = 1  # one complex value (real + imag) per pixel
    # Average over the dark-hole pixels themselves, as done elsewhere for contrast_floor and
    # the PASTIS matrix (not over pixels whose intensity happens to be non-zero).
    dh_pixels = np.asarray(dh_mask) != 0

    intensity_WFS_hist = np.zeros(Niter)
    cal_I_hist = np.zeros(Niter)
    eps_hist = np.zeros([Niter, r])
    averaged_hist = np.zeros(Niter)
    contrasts = []
    for pp in range(Niter):
        eps = np.random.multivariate_normal(np.zeros(r), P + Q * t_exp).reshape((1, 1, r))  # random modes
        G_eps = np.sum(Gsensor * eps, axis=2).reshape((N, 1, 2 * c)) + E0sensor  # sensor electric field
        G_eps_squared = np.sum(G_eps * G_eps, axis=2, keepdims=True)
        G_eps_G = np.matmul(G_eps, Gsensor)
        # Scale before the outer product instead of forming a per-pixel (r, r) tensor (saves RAM).
        G_eps_G_scaled = G_eps_G / np.sqrt(G_eps_squared + Dsensor / flux / t_exp)
        cal_I = 4 * flux * t_exp * np.einsum("ijk,ijl->kl", G_eps_G_scaled, G_eps_G_scaled)  # information matrix
        P = np.linalg.pinv(cal_I)

        # Coronagraph
        G_eps_coron = np.sum(Gcoro * eps, axis=2).reshape((N_img, 1, 2 * c)) + E0coro
        G_eps_coron_squared = np.sum(G_eps_coron * G_eps_coron, axis=2, keepdims=True)
        intensity = G_eps_coron_squared * flux * t_exp + Dcoro

        # Wavefront sensor
        intensity_WFS = G_eps_squared * flux * t_exp + Dsensor

        # Archive
        test_DH = np.mean((intensity[:, 0, 0] * dh_mask)[dh_pixels])
        contrasts.append(test_DH / flux / t_exp / norm)
        intensity_WFS_hist[pp] = np.sum(intensity_WFS) / flux
        cal_I_hist[pp] = np.mean(cal_I) / flux
        eps_hist[pp] = eps.ravel()
        averaged_hist[pp] = np.mean(contrasts)

    outputs = {'intensity_WFS_hist': intensity_WFS_hist,
               'cal_I_hist': cal_I_hist,
               'eps_hist': eps_hist,
               'averaged_hist': averaged_hist,
               'contrasts': contrasts}

    return outputs


In [None]:
# Photon rate for the closed-loop runs, and the modal drift covariance: a diagonal matrix
# holding the squared static tolerance of each mode.
flux = Nph
Qharris = np.diag(np.asarray(mu_map_defous) ** 2)

In [None]:
# Sweep definitions: log-spaced exposure times and drift scales, linear flux values.
Ntimes, TimeMinus, TimePlus = 20, -2, 5.5  # TimePlus was previously 3.5
Nwavescale, WaveScaleMinus, WaveScalePlus = 8, -2, 1
Nflux, fluxPlus, fluxMinus = 3, 10, 0

timeVec = np.logspace(TimeMinus, TimePlus, Ntimes)
WaveVec = np.logspace(WaveScaleMinus, WaveScalePlus, Nwavescale)
fluxVec = np.linspace(fluxMinus, fluxPlus, Nflux)
wavescaleVec = np.logspace(WaveScaleMinus, WaveScalePlus, Nwavescale)

In [None]:
# Temporal tolerance sweep: for each drift scale (wavescale) and each WFS exposure time,
# run the batch closed loop and keep the final running-mean dark-hole contrast.
# result_wf_test is flat and wavescale-major: Ntimes consecutive values per wavescale.
res = np.zeros([Ntimes, Nwavescale, Nflux, 1])  # NOTE(review): allocated but never filled
result_wf_test = []

niter = 10
StarMag = 0.0
Starfactor = 10**(-StarMag/2.5)
# Real loop counters (the old i/j were commented out, so stale globals were printed/incremented).
for i, wavescale in enumerate(range(1, 15, 2)):
    print('Harris modes with batch OBWFS and noise %f'% wavescale, "i", i)
    timer1 = time.time()
    for j, tscale in enumerate(np.logspace(TimeMinus, TimePlus, Ntimes)):
        print(tscale)
        noise = dark_current + CIC/tscale  # same noise term for coronagraph and sensor
        run = req_closedloop_calc_batch(G_coron_harris, G_OBWFS, E0_coron, E0_OBWFS, noise,
                                        noise, tscale, flux*Starfactor, 0.0001*wavescale**2*Qharris,
                                        niter, luvoir.dh_mask, norm)
        result_wf_test.append(run['averaged_hist'][-1])

In [None]:
# Drift amplitude for each wavescale of the sweep (labelled in pm in the plot below).
for wavescale in range(1, 15, 2):
    variance_scale = 0.0001 * wavescale**2
    print(1e3 * np.sqrt(variance_scale))  # pm

In [None]:
# Contrast increase above the static floor vs WFS exposure time, one curve per drift scale.
texp = np.logspace(TimeMinus, TimePlus, Ntimes)
wavescales = list(range(1, 15, 2))
# result_wf_test is wavescale-major: Ntimes consecutive values per wavescale.
delta_contrast = np.asarray(result_wf_test[:len(wavescales) * Ntimes]).reshape(len(wavescales), Ntimes) - contrast_floor

plt.figure(figsize=(15, 10))
plt.title(f"c_target={c_target}", fontsize=20)
for wavescale, curve in zip(wavescales, delta_contrast):
    # 1e3 * sqrt(1e-4 * wavescale**2) = 10 * wavescale (pm)
    plt.plot(texp, curve, label=rf'$\Delta_{{wf}}={10 * wavescale}\ pm$')
plt.xlabel("$t_{WFS}$ in secs", fontsize=20)
plt.ylabel(r"$\Delta$ contrast", fontsize=20)  # raw string: "\D" is an invalid escape otherwise
plt.yscale('log')
plt.xscale('log')
plt.legend(loc='upper center', fontsize=20)
plt.tick_params(top=True, bottom=True, left=True,
                right=True, labelleft=True, labelbottom=True,
                labelsize=20)
plt.tick_params(axis='both', which='major', length=10, width=2)
plt.tick_params(axis='both', which='minor', length=6, width=2)
plt.grid()
plt.savefig('/Users/asahoo/Documents/ultra/temp_plots/zernike_cont_wf_19.png')
plt.show()

In [None]:
# for i in range(1, 121):
#     print(f'Working on "bulk" thermal mode, segment: {i+1}')
    
#     # Apply calibration aberration to used mode
#     sm_mode = np.zeros(n_MID)
#     sm_mode[6*i -3] = (nm_aber)/2 
#     luvoir.sm.actuators  = sm_mode
    
#     plt.figure()
#     plt.text(9.6, 7.6, 'm', fontsize = 10)
#     hcipy.imshow_field(luvoir.sm.surface, mask=luvoir.aperture, cmap='RdBu')
#     cbar = plt.colorbar()
#     plt.tight_layout()
#     #plt.savefig('/Users/asahoo/Documents/ultra/temp_plots/bulk_2_2_22/poke_%d.png'% i, dpi=165)
#     plt.savefig('/Users/asahoo/Documents/ultra/temp_plots/defocus_plots/poke_{0:03}.png'.format(i), dpi=165)