### This notebook contains preliminary code to analyze temporal wavefront error data.

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import hcipy

import pastis.util as util
from pastis.config import CONFIG_PASTIS
from pastis.simulators.luvoir_imaging import LuvoirA_APLC
from pastis.simulators.scda_telescopes import HexRingAPLC

# Move into the ULTRA directory so that its local `config` module can be
# imported on the next line (presumably because the notebook's working
# directory is on sys.path).
# NOTE(review): the relative path assumes the notebook is launched from a
# sibling directory of ULTRA — confirm.
os.chdir('../ULTRA')
from config import CONFIG_ULTRA

In [None]:
# Input-data and analysis-output directories, both read from the ULTRA config.
data_path = CONFIG_ULTRA.get('local_path', 'local_data_path')
analysis_path = CONFIG_ULTRA.get('local_path', 'local_analysis_path')

# Both misalignment tables are ';'-delimited text files in the data directory.
input_misalignments, output_misalignments = (
    np.genfromtxt(os.path.join(data_path, fname), delimiter=';')
    for fname in ('INPUT_MISALIGNMENTS.txt', 'RESIDUAL_MISALIGNMENTS.txt')
)

In [None]:
# Report the dimensions of the two loaded misalignment tables.
for label, table in (("input_misalignment shape:", input_misalignments),
                     ("output_misalignment shape:", output_misalignments)):
    print(label, table.shape)

## For LUVOIR-A

### Unpacking data from .mat file 

In [None]:
#sensitivities = scipy.io.loadmat(os.path.join(data_path, 'dWFE_sensitivities_LEC.mat'))

# seg1_hexike_coeffs = (sensitivities['hfit_all'][:, :, 0])
# seg2_hexike_coeffs = (sensitivities['hfit_all'][:, :, 1])   # segn = (mat['hfit_all'][:, :, n])

#np.savetxt("/Users/asahoo/Downloads/seg1.csv", seg1, delimiter=",")

#### Instantiate the LUVOIR-A telescope simulator

In [None]:
# optics_input = os.path.join(util.find_repo_location(), CONFIG_PASTIS.get('LUVOIR', 'optics_path_in_repo'))
# sampling = CONFIG_PASTIS.getfloat('LUVOIR', 'sampling')
# tel = LuvoirA_APLC(optics_input, 'small', sampling)

#### The following line takes ~5 minutes to execute on a local machine.

In [None]:
# n_zernikes = 11      # Set the total number of hexike polynomials you want over the hexagonal segment.
# tel.create_segmented_mirror(n_zernikes) 

In [None]:
# unaberrated_coro_psf, ref = tel.calc_psf(ref=True, display_intermediate=False, norm_one_photon=True)
# norm = np.max(ref)
# dh_intensity = (unaberrated_coro_psf / norm) * tel.dh_mask
# contrast_floor = np.mean(dh_intensity[np.where(tel.dh_mask != 0)])
# print(f'static contrast floor for luvoirA small APLC design: {contrast_floor}')

In [None]:
# num_actuators = len(tel.sm.actuators) # ensure this to be equal to (total_segments) * (total_hexikes_per_segment)

In [None]:
# (tel.sm.actuators).shape

In [None]:
# sensitivities_table = sensitivities['hfit_all']
# luvoir_hexike_coeffs = []

# #removing the outer_ring
# for seg in range(0, 91):
#     seg_hexike_coeffs = sensitivities['hfit_all'][:, :, seg]
#     luvoir_hexike_coeffs.append(seg_hexike_coeffs)

In [None]:
#sensitivities_table.shape
# seg1_hexike_coeffs = (sensitivities['hfit_all'][:, :, 0])

In [None]:
# seg1_hexike_coeffs.shape
# luvoir_hexike_coeffs[0][0][0]

In [None]:
# Note, Hexikes are set according to Noll Convention.

# tel.sm.flatten()
# for seg in range(0, 91):
#     for hexike in range(0, n_zernikes):
#         tel.sm.actuators[hexike + seg * n_zernikes] = (luvoir_hexike_coeffs[seg][2][hexike])*1e-12
        

# aberrated_coro_psf, efields_inter = tel.calc_psf(display_intermediate=False,return_intermediate='efield', norm_one_photon=True)

In [None]:
# plt.figure()
# hcipy.imshow_field((efields_inter['seg_mirror']).phase, mask = tel.aperture, cmap='jet', vmin=-1e-4, vmax=0.0)
# plt.colorbar()
#plt.title('Segmented mirror phase)
#plt.savefig(os.path.join(analysis_path, 'luvoir_drz.png'))

## For SCDA 2hex Simulator

#### Instantiate the 2hex simulator.

In [None]:
# Build the SCDA hexagonal-ring APLC simulator (2 rings of segments).
optics_dir = os.path.join(util.find_repo_location(), 'data', 'SCDA')
NUM_RINGS = 2
sampling = 4
tel2 = HexRingAPLC(optics_dir, NUM_RINGS, sampling)

# Unaberrated coronagraphic PSF together with the reference PSF used for
# normalisation (intermediate planes returned as intensities).
psf_kwargs = dict(ref=True, display_intermediate=False,
                  return_intermediate='intensity', norm_one_photon=True)
unaberrated_psf, ref, intermediates = tel2.calc_psf(**psf_kwargs)

# Normalise to the reference peak and average inside the dark-hole mask.
norm = np.max(ref)
normalized_unaberrated_psf = unaberrated_psf / norm
unaberr_roi = normalized_unaberrated_psf * tel2.dh_mask
in_dark_hole = tel2.dh_mask != 0
contrast_floor = np.mean(unaberr_roi[in_dark_hole])
print("contrast_floor:", contrast_floor)

#### Load HWO sensitivities .mat file

In [None]:
# Load the HWO sensitivity struct from its MATLAB file.
mat_path = os.path.join(data_path, 'HWO_sens.mat')
hwo_sensitivities = scipy.io.loadmat(mat_path)
hwo_table = hwo_sensitivities['HWO_sens']

# The struct is stored at element [0, 0]; tolist() turns its fields into a
# tuple (4 fields for this file).
first_element = hwo_table[0, 0]
data_list = first_element.tolist()

In [None]:
# Field order inside the struct, with the shapes observed for this file:
#   hwo_hexike_coeffs (11, 19, 6)       - hexike coefficients [hexike, segment, dof]
#   mask              (256, 256, 19, 6) - per-segment pupil masks
#   dopd              (256, 256, 19, 6) - per-segment OPD sensitivity maps
#   units             (1,)              - single string describing the units
hwo_hexike_coeffs, mask, dopd, units = data_list[:4]

#### Locate segment number using mask data

In [None]:
# Segment index runs 0..18 and degree-of-freedom index 0..5
# (X, Y, Z, Rx, Ry, Rz).
seg = 15
dof = 5

# Show that segment's mask next to its OPD sensitivity map.
plt.figure(figsize=(8, 3))
panels = (("mask", mask, None), ("dopd", dopd, 'jet'))
for panel, (label, cube, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel)
    plt.title(label)
    plt.imshow(cube[:, :, seg, dof], cmap=cmap)
    plt.colorbar()

#### Plot surface maps using the opd data only

In [None]:
# Co-add the per-segment OPD sensitivity maps into one full-pupil map per
# degree of freedom. `dopd` is indexed [row, col, segment, dof], so summing
# over axis 2 adds the segments together. This replaces a manual loop whose
# accumulator was allocated as (shape[1], shape[1]) and so silently assumed
# square maps. The float64 dtype matches the old np.zeros accumulator.
dopds = [dopd[:, :, :, dof].sum(axis=2, dtype=np.float64)
         for dof in range(dopd.shape[3])]

titles = np.array(["X nm RMS/um", "Y nm RMS/um", "Z nm RMS/um",
                   "Rx nm RMS/urad", "Ry nm RMS/urad", "Rz nm RMS/urad" ])

plt.figure(figsize = (10, 5))
for dof in range(0, len(dopds)):
    if dof!=2:
        plt.subplot(2, 3, dof+1)
        plt.title(titles[dof], fontweight = 'bold', fontsize= 7)
        plt.imshow(dopds[dof], cmap='jet')
        plt.tick_params(top=False, bottom=True, left=True, right=False, labelleft=True, labelbottom=True, labelsize=7)
        cbar = plt.colorbar()
        cbar.ax.tick_params(labelsize=10)

# Z is drawn separately with a narrow colour stretch. Its values appear to sit
# on a large common offset, so the default range would hide structure.
plt.subplot(2, 3, 3)
plt.title("Z nm RMS/um", fontweight = 'bold', fontsize=10)
plt.imshow(dopds[2], cmap='jet', vmin=-2000, vmax= -1966)
plt.tick_params(top=False, bottom=True, left=True, right=False, labelleft=True, labelbottom=True, labelsize=7)
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=7)

plt.tight_layout()
plt.savefig(os.path.join(analysis_path, 'hwo_dopds_PTT.png'))

In [None]:
# The data provide one mask per degree of freedom. Co-adding them and
# differencing pairs showed they agree to ~1e-30, so they appear identical;
# all six are plotted here for inspection.

# Co-add the per-segment masks for each dof. `mask` is indexed
# [row, col, segment, dof]. Summing over axis 2 avoids the old
# (shape[1], shape[1]) accumulator, which assumed square maps.
masks = [mask[:, :, :, dof].sum(axis=2, dtype=np.float64)
         for dof in range(mask.shape[3])]

titles = np.array(["mask (X nm RMS/um)", "mask (Y nm RMS/um)", "mask (Z nm RMS/um)",
                   "mask (Rx nm RMS/urad)", "mask (Ry nm RMS/urad)", "mask (Rz nm RMS/urad)"])

plt.figure(figsize = (14, 7))
# Iterate over the masks themselves, not over `dopds` from another cell.
for dof in range(0, len(masks)):
    plt.subplot(2, 3, dof+1)
    plt.title(titles[dof], fontweight = 'bold')
    plt.imshow(masks[dof], cmap='jet')
    plt.colorbar()

In [None]:
dopds = np.array(dopds)
masks = np.array(masks)

titles = np.array(["X nm RMS/um", "Y nm RMS/um", "Z nm RMS/um",
                   "Rx nm RMS/urad", "Ry nm RMS/urad", "Rz nm RMS/urad" ])

# RMS of every co-added OPD map, taken only over pixels where its mask is non-zero.
for dof, dof_map in enumerate(dopds):
    in_pupil = masks[dof] != 0
    rms = np.sqrt(np.mean(dof_map[in_pupil] ** 2))
    print(titles[dof], "----", rms)

### Map HWO segment number to the SCDA segment number:

In [None]:
# HWO segment number (string key, 1-based) -> SCDA segment number (1-based).
# NOTE(review): this alternative ordering is not used by the visible cells and
# the meaning of the "_L" suffix is not stated — confirm before relying on it.
hwo_to_scda_L = {
    str(hwo_seg): scda_seg
    for hwo_seg, scda_seg in enumerate(
        [4, 3, 2, 7, 6, 5, 12, 11, 10, 9, 8, 19, 18, 17, 16, 15, 14, 13, 1], start=1
    )
}

In [None]:
# SCDA segment number (1-based) for each HWO segment, listed in HWO order 1..19.
hwo_scda_segments = np.array([4, 5, 6, 7, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19, 8, 9, 10, 11, 1])

# The same mapping as a dict keyed by the HWO segment number as a string.
# It is derived from the array above so the two cannot drift apart.
hwo_to_scda = {str(hwo_seg): int(scda_seg)
               for hwo_seg, scda_seg in enumerate(hwo_scda_segments, start=1)}

In [None]:
# Coefficient table of the first HWO segment. Transposing the
# [hexike, segment, dof] cube slice puts the degrees of freedom on the rows
# and the hexike coefficients on the columns.
seg1_table = hwo_hexike_coeffs[:, 0, :].T

plt.imshow(seg1_table)
plt.ylabel('Degree of freedom')
plt.xlabel('Zernike coefficients')
plt.colorbar()

In [None]:
# One (dof x hexike) coefficient table per HWO segment, in HWO segment order.
segs_tables = [hwo_hexike_coeffs[:, seg_idx, :].T for seg_idx in range(19)]

# Number of hexike polynomials placed over each hexagonal segment.
n_zernikes = 11
tel2.create_segmented_mirror(n_zernikes)

In [None]:
# Show each hexike mode by itself. Actuator index `z` (< n_zernikes) belongs
# to the first segment, so poke it by 1e-9 m with all other actuators flat.
n_cols = 3
n_rows = (n_zernikes + n_cols - 1) // n_cols  # enough rows for every mode
plt.figure(figsize = (10, 12))

for z in range(n_zernikes):
    tel2.sm.flatten()
    tel2.sm.actuators[z] = 1e-9

    plt.subplot(n_rows, n_cols, z + 1)
    plt.title(f"Zernike {z}", fontweight = 'bold')
    hcipy.imshow_field(tel2.sm.surface, cmap='jet')
    cbar = plt.colorbar()
    cbar.set_label("in m", loc='center')

# Leave the mirror flat so the following cells start from a known state.
tel2.sm.flatten()

In [None]:
tel2.sm.flatten()

# For every degree of freedom, write each HWO segment's hexike coefficients
# (scaled by 1e-9) onto the matching SCDA segment's block of actuators and
# record the resulting mirror surface.
# NOTE(review): open question from the author — should this be multiplied by 1/2?
tel2_surfaces = []
for dof in range(6):
    tel2.sm.flatten()
    for hwo_seg in range(tel2.nseg):
        scda_seg = hwo_to_scda[str(hwo_seg + 1)] - 1
        first = scda_seg * n_zernikes
        tel2.sm.actuators[first:first + n_zernikes] = segs_tables[hwo_seg][dof][:n_zernikes] * 1e-9
    tel2_surfaces.append(tel2.sm.surface)

titles = np.array(["X nm RMS/um", "Y nm RMS/um", "Z nm RMS/um",
                   "Rx nm RMS/urad", "Ry nm RMS/urad", "Rz nm RMS/urad" ])

plt.figure(figsize = (14, 7))
for dof, surface in enumerate(tel2_surfaces):
    if dof == 2:
        continue  # Z gets its own narrow colour stretch below
    plt.subplot(2, 3, dof + 1)
    plt.title(titles[dof], fontweight='bold')
    hcipy.imshow_field(surface * 1e9, cmap='jet')
    cbar = plt.colorbar()
    cbar.set_label("in nm", loc='center')

plt.subplot(2, 3, 3)
plt.title("Z nm RMS/um", fontweight='bold')
hcipy.imshow_field(tel2_surfaces[2] * 1e9, cmap='jet', vmin=-2000, vmax=-1960)
cbar = plt.colorbar()
cbar.set_label("in nm", loc='center')

#plt.savefig(os.path.join(analysis_path, 'hwo_PTT.png'))


#### Sorting Zernike coefficients

In [None]:
# Rebuild the per-segment (dof x hexike) coefficient tables from the raw cube.
segs_tables = [hwo_hexike_coeffs[:, seg_idx, :].T for seg_idx in range(19)]

In [None]:
# Segment 0 (HWO order), dof index 5 (Rz): all hexike coefficients scaled by 1e9.
1e9 * segs_tables[0][5]

In [None]:
# Map each SCDA hexike actuator index -> (HWO coefficient column, sign).
# The pairs 4<->5 (astigmatism), 6<->7 (coma) and 8<->10 are swapped, and
# hexike 2 is sign-flipped. Indices not listed keep their own column with +1.
hexike_remap = {2: (2, -1), 4: (5, 1), 5: (4, 1), 6: (7, 1), 7: (6, 1),
                8: (10, 1), 10: (8, 1)}

# Index into the (X, Y, Z, Rx, Ry, Rz) ordering used throughout; 2 is Z.
plot_dof = 2

tel2.sm.flatten()
for hwo_seg in range(tel2.nseg):
    scda_seg = hwo_to_scda[str(hwo_seg + 1)] - 1  # loop-invariant for the hexike loop
    for hexike in range(n_zernikes):
        hexn, sign = hexike_remap.get(hexike, (hexike, 1))
        tel2.sm.actuators[hexike + scda_seg * n_zernikes] = sign * segs_tables[hwo_seg][plot_dof][hexn] * 1e-9

plt.figure()
# dof index 2 is Z, not Rz; the colour stretch matches the other Z maps.
plt.title("Z m RMS/um", fontweight = 'bold')
hcipy.imshow_field(tel2.sm.surface, mask = tel2.aperture, cmap='jet', vmin=-2000*1e-9, vmax = -1950*1e-9)
cbar = plt.colorbar()
cbar.set_label("in m", loc='center')
#plt.savefig(os.path.join(analysis_path, 'hwo_Z_PTT.png'))

In [None]:
# Map each SCDA hexike actuator index -> (HWO coefficient column, sign).
# Every index now gets an explicit sign (default +1). The previous if/elif
# chain never set `sign` in its else branch, so hexike 3 inherited -1 from
# hexike 2 and hexike 0 depended on leftover state.
# NOTE(review): hexike 4 carries a -1 sign here, unlike the Z map — confirm intent.
hexike_remap_x = {2: (2, -1), 4: (5, -1), 5: (4, 1), 6: (7, 1), 7: (6, 1),
                  8: (10, 1), 10: (8, 1)}

# Index into the (X, Y, Z, Rx, Ry, Rz) ordering used throughout; 0 is X.
plot_dof = 0

tel2.sm.flatten()
for hwo_seg in range(tel2.nseg):
    scda_seg = hwo_to_scda[str(hwo_seg + 1)] - 1  # loop-invariant for the hexike loop
    for hexike in range(n_zernikes):
        hexn, sign = hexike_remap_x.get(hexike, (hexike, 1))
        tel2.sm.actuators[hexike + scda_seg * n_zernikes] = sign * segs_tables[hwo_seg][plot_dof][hexn] * 1e-9

plt.figure()
plt.title("X m RMS/um", fontweight = 'bold')
hcipy.imshow_field(tel2.sm.surface, cmap='jet')
cbar = plt.colorbar()
cbar.set_label("in m", loc='center')
#plt.savefig(os.path.join(analysis_path, 'hwo_X.png'))
tel2.sm.flatten()

In [None]:
#A = np.array([[1, 8, 9, 4], [2, 6, 7, 0], [1, 3, 8, 9], [0, 0, 4,1]])
# A = np.array([[1.6, 8, 9, 4], [2, 6, 7, 0], [1, 3, 8, 9]])
# print(A)

# A[:,[0,3]] = A[:,[3,0]]
# print(A)

# seg1_table[:, [5, 4]] = seg1_table[:, [4, 5]] # swapping astigmatism
# seg1_table[:, [7, 6]] = seg1_table[:, [6, 7]] # swapping coma

In [None]:
# Reorder the hexike columns of every segment table: swap 4<->5 (astigmatism),
# 6<->7 (coma) and 8<->10. Fancy indexing returns a copy, so `segs_tables`
# itself is left untouched.
segs_tables2 = []
for seg in range(0, tel2.nseg):
    table = np.asarray(segs_tables[seg])
    column_order = np.arange(table.shape[1])
    column_order[[4, 5, 6, 7, 8, 10]] = [5, 4, 7, 6, 10, 8]
    segs_tables2.append(table[:, column_order])

In [None]:
# Sanity check: compare segment 0's column-sorted table against a reference
# built directly from the raw coefficient cube with an explicit hexike order.
# NOTE(review): `ref_table` was used here but never defined in this notebook
# (NameError). It is rebuilt from the actuator-index -> HWO-column order used
# by the hexike sorting cells; confirm this is the intended reference.
hexike_order = [0, 1, 2, 3, 5, 4, 7, 6, 10, 9, 8]
ref_table = hwo_hexike_coeffs[:, 0, :].T[:, hexike_order]

sorted_seg_table = segs_tables2[0]

print(ref_table[:, 0], "\n", "\n")
print(sorted_seg_table[:, 0])

print(type(ref_table))

# Should be all zeros if the sorting matches the reference order.
plt.imshow(ref_table - sorted_seg_table)
plt.colorbar()

In [None]:
tel2.sm.flatten()

# For every degree of freedom, load the column-sorted coefficient tables
# (scaled by 1e-9) onto the matching SCDA segments and record the surface (m).
# NOTE(review): open question from the author — should this be multiplied by 1/2?
tel2_surfaces = []
for dof in range(0, 6):
    tel2.sm.flatten()
    for hwo_seg in range(0, tel2.nseg):
        scda_seg = hwo_to_scda[str(hwo_seg + 1)] - 1  # invariant for the hexike loop
        for hexike in range(0, n_zernikes):
            tel2.sm.actuators[hexike + scda_seg * n_zernikes] = segs_tables2[hwo_seg][dof][hexike] * 1e-9

    tel2_surfaces.append(tel2.sm.surface)

titles = np.array(["X nm RMS/um", "Y nm RMS/um", "Z nm RMS/um",
                   "Rx nm RMS/urad", "Ry nm RMS/urad", "Rz nm RMS/urad" ])

plt.figure(figsize = (14, 7))
# Iterate over the surfaces just computed rather than `dopds` from another cell.
for dof in range(0, len(tel2_surfaces)):
    if dof!=2:
        plt.subplot(2, 3, dof+1)
        plt.title(titles[dof], fontweight = 'bold')
        hcipy.imshow_field(tel2_surfaces[dof], cmap='jet')
        cbar = plt.colorbar()
        cbar.set_label("in m", loc='center')

# Z is drawn with a narrow colour stretch so segment-to-segment structure is visible.
plt.subplot(2, 3, 3)
plt.title("Z nm RMS/um", fontweight = 'bold')
hcipy.imshow_field(tel2_surfaces[2], cmap='jet', vmin = -2000*1e-9,  vmax= -1960*1e-9 )
cbar = plt.colorbar()
cbar.set_label("in m", loc='center')


In [None]:
# Note: hexikes are set according to the Noll convention.
# NOTE(review): the author notes a maximum of ~1000 nm of aberration can be
# handled — confirm where this limit comes from.
# Reset the segmented mirror to flat.

tel2.sm.flatten()


# for seg in range(0, 1):
#     for hexike in range(0, n_zernikes):
#         print("Seg num:", seg, "hexike_num:",hexike, "Coeffs in nm: ", segs_tables[seg][5][hexike])
#         tel2.sm.actuators[hexike + seg * n_zernikes] = segs_tables[seg][5][hexike] * 1e-9


# rms_aber = np.sqrt(np.mean((tel2.sm.actuators)**2))*1e9 # in nm 

# tel2.sm.actuators[1] = 1/2 * 1000 *1e-9
# tel2.sm.actuators[12] = 1/2*1000 *1e-9
# aberrated_coro_psf2, efields_inter2 = tel2.calc_psf(display_intermediate=False,return_intermediate='efield', norm_one_photon=True)

# plt.figure()
# plt.subplot(1,3,1)
# hcipy.imshow_field(np.angle(np.exp(1j * (tel2.sm.phase_for(500e-9)))), cmap='jet')
# plt.colorbar()
# plt.subplot(1,3,2)
# hcipy.imshow_field((tel2.sm.phase_for(500e-9)), cmap='jet')
# plt.colorbar()

# plt.subplot(1,3,3)
# plt.title("surface")
# hcipy.imshow_field((tel2.sm.surface), cmap='jet')
# plt.colorbar()
# plt.show()

# plt.figure()
# hcipy.imshow_field((efields_inter2['seg_mirror']).phase, mask = tel2.aperture, cmap='jet', origin='lower')
# plt.title(f'Segmented mirror phase')
# plt.colorbar()

# plt.savefig(os.path.join(analysis_path, 'usort_Rz_dof.png'))

In [None]:
# `efields_inter2` was only assigned inside commented-out code, so this cell
# raised NameError. Compute the intermediate E-fields for the mirror's
# current state, then take the phase at the segmented-mirror plane.
aberrated_coro_psf2, efields_inter2 = tel2.calc_psf(display_intermediate=False,
                                                    return_intermediate='efield',
                                                    norm_one_photon=True)
sm_wavefront_phase = efields_inter2['seg_mirror'].phase

In [None]:
 efields_inter2['seg_mirror'].wavelength

In [None]:
 segs_tables[1][0][0] 

#### Plot the hexike maps from Ball.
 

In [None]:
def cart2pol(x, y):
    """Convert Cartesian coordinates to polar coordinates.

    Parameters
    ----------
    x, y : float or array_like
        Cartesian coordinates; arrays are broadcast together by NumPy.

    Returns
    -------
    tuple
        ``(rho, phi)``: the radius and the angle in radians as given by
        ``np.arctan2(y, x)``.
    """
    radius = np.sqrt(x ** 2 + y ** 2)
    angle = np.arctan2(y, x)
    return radius, angle

# 1-D diagonal cut through the plane, used to preview the hexike profiles.
X = np.linspace(-10, 10, 1000)
Y = np.linspace(-10, 10, 1000)

# cart2pol returns (rho, phi): unpack radius first, then angle.
# NOTE(review): the y sign flip (-Y) is kept as-is — confirm the intended orientation.
R, th = cart2pol(X, -Y)

In [None]:
# Hexagonal Zernike ("hexike") polynomials 1-11, evaluated on R, th (and on
# X, Y for tip/tilt). NOTE(review): normalisations appear to follow the
# orthonormal hexagonal polynomials of Mahajan & Dai — confirm.
z1 = np.ones(X.shape)
z2 = np.sqrt(24/5)*X
z3 = np.sqrt(24/5)*Y
z4 = np.sqrt(720/43)*(R**2 - 5/12)
z5 = np.sqrt(60/7)*(R**2)*np.cos(2*th)
z6 = np.sqrt(60/7)*(R**2)*np.sin(2*th)
z7 = np.sqrt(84000/737)*(R**2 - 14/25)*R*np.cos(th)
z8 = np.sqrt(84000/737)*(R**2 - 14/25)*R*np.sin(th)
z9 = np.sqrt(1517040/4987)*(R**4 - 257/301*(R**2) + 737/6020)
z10 = np.sqrt(1120/103)*(R**3)*(np.cos(3*th))
# Second trefoil term: the sin(3θ) partner of z10. With cos(3θ) it was just
# a rescaled copy of z10.
z11 = np.sqrt(160/9)*(R**3)*(np.sin(3*th))

hexikes = [z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11]

plt.figure()
for panel, zk in enumerate(hexikes, start=1):
    plt.subplot(3, 4, panel)
    plt.plot(X, zk)

In [None]:
# Evaluate the same hexike polynomials on a 2-D hcipy pupil grid.
# NOTE(review): confirm that the grid extent matches the unit hexagon the
# normalisations assume.
grid = hcipy.make_pupil_grid(1000)
R, th = grid.as_('polar').coords
X, Y = grid.coords

z1 = np.ones(X.shape)
z2 = np.sqrt(24/5)*X
z3 = np.sqrt(24/5)*Y
z4 = np.sqrt(720/43)*(R**2 - 5/12)
z5 = np.sqrt(60/7)*(R**2)*np.cos(2*th)
z6 = np.sqrt(60/7)*(R**2)*np.sin(2*th)
z7 = np.sqrt(84000/737)*(R**2 - 14/25)*R*np.cos(th)
z8 = np.sqrt(84000/737)*(R**2 - 14/25)*R*np.sin(th)
z9 = np.sqrt(1517040/4987)*(R**4 - 257/301*(R**2) + 737/6020)
z10 = np.sqrt(1120/103)*(R**3)*(np.cos(3*th))
# Second trefoil term: the sin(3θ) partner of z10. With cos(3θ) it duplicated
# z10's angular dependence.
z11 = np.sqrt(160/9)*(R**3)*(np.sin(3*th))

hcipy.imshow_field(z1, grid)