### This notebook contains preliminary code to analyze temporal wavefront error data.

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt

import numpy as np
import scipy.io
import hcipy

import pastis.util as util
from pastis.config import CONFIG_PASTIS
from pastis.simulators.luvoir_imaging import LuvoirA_APLC
from pastis.simulators.scda_telescopes import HexRingAPLC

# Switch into the sibling ULTRA directory and import its `config` module
# (presumably the chdir is what makes `from config import ...` resolve there).
# NOTE(review): os.chdir changes the working directory for the rest of the
# session, so any later relative paths resolve against ../ULTRA — confirm intended.
os.chdir('../ULTRA')
from config import CONFIG_ULTRA

In [None]:
# Locations of the raw ULTRA data and of the analysis output directory.
data_path = CONFIG_ULTRA.get('local_path', 'local_data_path')
analysis_path = CONFIG_ULTRA.get('local_path', 'local_analysis_path')

# Semicolon-delimited misalignment tables (input first, then residual).
input_misalignments, output_misalignments = (
    np.genfromtxt(os.path.join(data_path, fname), delimiter=';')
    for fname in ('INPUT_MISALIGNMENTS.txt', 'RESIDUAL_MISALIGNMENTS.txt')
)

# MATLAB sensitivity files: LUVOIR (LEC) and HWO.
sensitivities = scipy.io.loadmat(os.path.join(data_path, 'dWFE_sensitivities_LEC.mat'))
hwo_sensitivities = scipy.io.loadmat(os.path.join(data_path, 'HWO_sens.mat'))

In [None]:
# Hexike sensitivity tables of the first two segments; segment n's table is
# sensitivities['hfit_all'][:, :, n].
seg1_hexike_coeffs, seg2_hexike_coeffs = (
    sensitivities['hfit_all'][:, :, n] for n in (0, 1)
)

# To export a table: np.savetxt("seg1.csv", seg1_hexike_coeffs, delimiter=",")

In [None]:
# Report the shapes of the loaded misalignment tables.
for label, arr in (("input_misalignment", input_misalignments),
                   ("output_misalignment", output_misalignments)):
    print(f"{label} shape:", arr.shape)

#### Instantiate the LUVOIR-A telescope simulator

In [None]:
# Build the LUVOIR-A APLC simulator ('small' design) from the repo optics
# directory and the configured pupil sampling.
sampling = CONFIG_PASTIS.getfloat('LUVOIR', 'sampling')
optics_input = os.path.join(
    util.find_repo_location(),
    CONFIG_PASTIS.get('LUVOIR', 'optics_path_in_repo'),
)
tel = LuvoirA_APLC(optics_input, 'small', sampling)

#### The following cell takes ~5 minutes to execute on a local machine.

In [None]:
# Number of hexike polynomials per hexagonal segment.
# Building the segmented mirror is slow (~5 minutes on a local machine).
n_zernikes = 11
tel.create_segmented_mirror(n_zernikes)

In [None]:
# Static (unaberrated) coronagraphic contrast of LUVOIR-A: normalize the
# coronagraphic PSF by the peak of the reference PSF, then average over the
# dark-hole pixels.
unaberrated_coro_psf, ref = tel.calc_psf(ref=True, display_intermediate=False, norm_one_photon=True)
norm = np.max(ref)
dh_intensity = (unaberrated_coro_psf / norm) * tel.dh_mask
contrast_floor = np.mean(dh_intensity[tel.dh_mask != 0])
print(f'static contrast floor for luvoirA small APLC design: {contrast_floor}')

In [None]:
num_actuators = tel.sm.actuators.shape[0]  # should equal (number of segments) * (hexikes per segment)

In [None]:
tel.sm.actuators.shape  # displayed as the cell output

In [None]:
# Raw LUVOIR sensitivity array; segment n's table is sensitivities_table[:, :, n].
sensitivities_table = sensitivities['hfit_all']

# Keep only segments 0-90 (the author's note: this removes the outer ring).
luvoir_hexike_coeffs = [sensitivities_table[:, :, seg] for seg in range(91)]

In [None]:
# Segment index 0's hexike sensitivity table (hfit_all is indexed [:, :, seg]).
seg1_hexike_coeffs = sensitivities['hfit_all'][..., 0]

In [None]:
# Inspect the segment table shape and one sample coefficient. A notebook cell
# only displays its last bare expression, so the shape must be printed
# explicitly or it is silently discarded.
print(seg1_hexike_coeffs.shape)
luvoir_hexike_coeffs[0][0][0]

In [None]:
# Apply per-segment hexike coefficients to LUVOIR-A (hexikes follow the Noll
# convention). Actuators are laid out segment-major:
# actuators[seg * n_zernikes + hexike].
# NOTE(review): row 2 of each segment table is used and values are scaled by
# 1e-12 (presumably pm -> m); unlike the HWO cell, no 1/2 wavefront-to-surface
# factor is applied — confirm both choices.
tel.sm.flatten()
for seg, seg_table in enumerate(luvoir_hexike_coeffs[:91]):
    for hexike in range(n_zernikes):
        tel.sm.actuators[seg * n_zernikes + hexike] = seg_table[2][hexike] * 1e-12

aberrated_coro_psf, efields_inter = tel.calc_psf(
    display_intermediate=False,
    return_intermediate='efield',
    norm_one_photon=True,
)

In [None]:
# Phase of the field just after the segmented mirror, masked to the pupil.
plt.figure()
hcipy.imshow_field(
    efields_inter['seg_mirror'].phase,
    mask=tel.aperture,
    cmap='jet',
    vmin=-1e-4,
    vmax=0.0,
)
plt.colorbar()
# plt.title('Segmented mirror phase')
# plt.savefig(os.path.join(analysis_path, 'luvoir_drz.png'))

#### Instantiate the 2-ring hexagonal segmented-aperture (HexRingAPLC) simulator.

In [None]:
# Build a 2-ring hexagonal-segment APLC from the repo's SCDA optics and
# compute its unaberrated coronagraphic PSF.
NUM_RINGS = 2
sampling = 4
optics_dir = os.path.join(util.find_repo_location(), 'data', 'SCDA')

tel2 = HexRingAPLC(optics_dir, NUM_RINGS, sampling)

unaberrated_psf, ref, intermediates = tel2.calc_psf(
    ref=True,
    display_intermediate=False,
    return_intermediate='intensity',
    norm_one_photon=True,
)

# Normalize by the reference PSF peak and average inside the dark hole.
norm = np.max(ref)
normalized_unaberrated_psf = unaberrated_psf / norm
unaberr_roi = normalized_unaberrated_psf * tel2.dh_mask
contrast_floor = np.mean(unaberr_roi[tel2.dh_mask != 0])
print("contrast_floor:", contrast_floor)

In [None]:
# Take the [0, 0] record of the 'HWO_sens' entry (presumably a MATLAB struct
# array) and convert it to a Python sequence of its fields.
hwo_table = hwo_sensitivities['HWO_sens']
first_element = hwo_table[0, 0]
data_list = first_element.tolist()  # tuple of length 4

In [None]:
# Unpack the first four fields of the HWO record (shapes as noted by the author):
#   hwo_hexike_coeffs: (11, 19, 6)
#   mask, dopd:        (256, 256, 19, 6)
#   units:             (1,) array whose only element is a string
hwo_hexike_coeffs, mask, dopd, units = data_list[:4]

In [None]:
seg1_table = np.transpose(hwo_hexike_coeffs[:, 0, :])  # segment 0 table, transposed (assumed DOF x hexike)

In [None]:
# Segment-1 HWO sensitivity table: rows are degrees of freedom, columns are
# hexike coefficients (seg1_table is the transpose of a (hexike, DOF) slice).
plt.imshow(seg1_table)
plt.ylabel('Degree of freedom')
plt.xlabel('Hexike coefficient index')
plt.colorbar()

In [None]:
seg1_table[0, 0] * 1e-9  # nm -> m (assuming the table is in nm; the result is in metres)

In [None]:
# One transposed (DOF x hexike) table per HWO segment. hwo_hexike_coeffs is
# indexed [hexike, segment, DOF] (shape note: (11, 19, 6)), so the segment
# count is taken from axis 1 instead of being hard-coded as 19.
segs_tables = [hwo_hexike_coeffs[:, seg, :].T for seg in range(hwo_hexike_coeffs.shape[1])]

In [None]:
# Quick-look of the segment-1 table with a colorbar bound to the image.
plt.colorbar(plt.imshow(seg1_table))

In [None]:
# Hexike modes per segment for the HWO mirror (same count as for LUVOIR-A).
n_zernikes = 11
tel2.create_segmented_mirror(n_zernikes)

In [None]:
# Apply one HWO degree of freedom (row 5 of each segment table) to every
# segment, propagate, and display/save the resulting mirror surface.
# NOTE(review): the 1/2 factor presumably converts wavefront error to mirror
# surface and 1e-9 converts nm to m; the output filename suggests row 5 is the
# Rz DOF — confirm.
tel2.sm.flatten()

for seg in range(tel2.nseg):
    dof_coeffs = segs_tables[seg][5]
    for hexike in range(n_zernikes):
        tel2.sm.actuators[seg * n_zernikes + hexike] = 1/2 * dof_coeffs[hexike] * 1e-9

aberrated_coro_psf2, efields_inter2 = tel2.calc_psf(
    display_intermediate=False,
    return_intermediate='efield',
    norm_one_photon=True,
)

plt.title("surface")
hcipy.imshow_field(tel2.sm.surface, cmap='jet')
cbar = plt.colorbar()
cbar.set_label("in m", loc='center')
plt.savefig(os.path.join(analysis_path, 'hwo_Rz_dof.png'))
plt.show()

In [None]:
-(1970 * 1e-9)  # scratch: -1970 nm expressed in metres

In [None]:
# Reset the HWO segmented mirror to flat. Hexikes follow the Noll convention.
# Author's note: the mirror can handle aberrations of at most ~1000 nm.
# Only tel2.sm.flatten() below is live; the commented code is an earlier
# exploratory pass (single-segment pokes, phase/surface plots) kept for reference.

tel2.sm.flatten()


# for seg in range(0, 1):
#     for hexike in range(0, n_zernikes):
#         print("Seg num:", seg, "hexike_num:",hexike, "Coeffs in nm: ", segs_tables[seg][5][hexike])
#         tel2.sm.actuators[hexike + seg * n_zernikes] = segs_tables[seg][5][hexike] * 1e-9


# rms_aber = np.sqrt(np.mean((tel2.sm.actuators)**2))*1e9 # in nm 

# tel2.sm.actuators[1] = 1/2 * 1000 *1e-9
# tel2.sm.actuators[12] = 1/2*1000 *1e-9
# aberrated_coro_psf2, efields_inter2 = tel2.calc_psf(display_intermediate=False,return_intermediate='efield', norm_one_photon=True)

# plt.figure()
# plt.subplot(1,3,1)
# hcipy.imshow_field(np.angle(np.exp(1j * (tel2.sm.phase_for(500e-9)))), cmap='jet')
# plt.colorbar()
# plt.subplot(1,3,2)
# hcipy.imshow_field((tel2.sm.phase_for(500e-9)), cmap='jet')
# plt.colorbar()

# plt.subplot(1,3,3)
# plt.title("surface")
# hcipy.imshow_field((tel2.sm.surface), cmap='jet')
# plt.colorbar()
# plt.show()

# plt.figure()
# hcipy.imshow_field((efields_inter2['seg_mirror']).phase, mask = tel2.aperture, cmap='jet', origin='lower')
# plt.title(f'Segmented mirror phase')
# plt.colorbar()

# plt.savefig(os.path.join(analysis_path, 'usort_Rz_dof.png'))

In [None]:
sm_wavefront_phase = efields_inter2['seg_mirror'].phase  # phase recorded at the 'seg_mirror' plane

In [None]:
 efields_inter2['seg_mirror'].wavelength

In [None]:
 segs_tables[1][0][0] 

#### Plot the hexike maps from Ball.
 

In [None]:
def cart2pol(x, y):
    """Convert Cartesian coordinates to polar coordinates.

    Works element-wise on scalars or NumPy arrays.

    Returns:
        A ``(rho, phi)`` tuple: the radial distance ``sqrt(x**2 + y**2)`` and
        the angle ``arctan2(y, x)`` in radians.
    """
    phi = np.arctan2(y, x)
    rho = np.sqrt(x**2 + y**2)
    return rho, phi

# Sample coordinates for the 1-D hexike profile plots.
# NOTE(review): X and Y are identical 1-D arrays, so the profiles are taken
# along the diagonal x == y, and [-10, 10] extends far beyond a unit hexagon —
# confirm this is intended.
X = np.linspace(-10, 10, 1000)
Y = np.linspace(-10, 10, 1000)

# cart2pol returns (rho, phi): radius first, angle second. The previous
# `th, R = ...` swapped them, so every R-based polynomial used the angle.
R, th = cart2pol(X, -Y)

In [None]:
# 1-D profiles of the first 11 hexagonal orthonormal polynomials (hexikes),
# with the normalizations of Mahajan & Dai (2006). z2/z3 use the Cartesian
# X/Y arrays; the others use the polar R/th arrays.
z1 = np.ones(X.shape)
z2 = np.sqrt(24/5)*X
z3 = np.sqrt(24/5)*Y
z4 = np.sqrt(720/43)*(R**2 - 5/12)
z5 = np.sqrt(60/7)*(R**2)*np.cos(2*th)
z6 = np.sqrt(60/7)*(R**2)*np.sin(2*th)
z7 = np.sqrt(84000/737)*(R**2 - 14/25)*R*np.cos(th)
z8 = np.sqrt(84000/737)*(R**2 - 14/25)*R*np.sin(th)
z9 = np.sqrt(1517040/4987)*(R**4 - 257/301*(R**2) + 737/6020)
# The two trefoil terms carry different normalizations on a hexagon:
# sqrt(1120/103) goes with cos(3*th) and sqrt(160/9) with sin(3*th).
# z11 previously repeated cos(3*th), duplicating z10's angular dependence.
z10 = np.sqrt(1120/103)*(R**3)*(np.cos(3*th))
z11 = np.sqrt(160/9)*(R**3)*(np.sin(3*th))

# One subplot per polynomial on a 3x4 grid.
plt.figure()
for idx, profile in enumerate([z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11], start=1):
    plt.subplot(3, 4, idx)
    plt.plot(X, profile)
    plt.title(f'z{idx}')

In [None]:
# Evaluate a few hexikes on a 2-D hcipy pupil grid and display z8.
grid = hcipy.make_pupil_grid(1000)
X, Y = grid.coords                    # Cartesian coordinates (not used below)
R, th = grid.as_('polar').coords      # polar (radius, angle) coordinates

z4 = np.sqrt(720/43)*(R**2 - 5/12)
z5 = np.sqrt(60/7)*(R**2)*np.cos(2*th)
z8 = np.sqrt(84000/737)*(R**2 - 14/25)*R*np.sin(th)
z9 = np.sqrt(1517040/4987)*(R**4 - 257/301*(R**2) + 737/6020)

hcipy.imshow_field(z8, grid)