In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt

from astropy.io import fits
from matplotlib.colors import LogNorm, TwoSlopeNorm
import numpy as np
import scipy.io
import hcipy

import pastis.util as util
from pastis.config import CONFIG_PASTIS
from pastis.simulators.luvoir_imaging import LuvoirA_APLC
from pastis.simulators.scda_telescopes import HexRingAPLC

# Switch the process working directory to the sibling ULTRA folder before importing
# its `config` module (presumably the import relies on the working directory being
# on the module search path -- TODO confirm).
# Side effect: every relative path used after this line resolves against ../ULTRA.
os.chdir('../ULTRA')
from config import CONFIG_ULTRA

In [None]:
# Look up the local data and analysis directories in the ULTRA configuration.
data_path = CONFIG_ULTRA.get('local_path', 'local_data_path')
analysis_path = CONFIG_ULTRA.get('local_path', 'local_analysis_path')

# Both misalignment tables are semicolon-delimited text files stored in data_path.
input_misalignments, output_misalignments = (
    np.genfromtxt(os.path.join(data_path, table_file), delimiter=';')
    for table_file in ('INPUT_MISALIGNMENTS.txt', 'RESIDUAL_MISALIGNMENTS.txt')
)

In [None]:
# Report the dimensions of the two misalignment tables loaded above.
for table_label, table in (("input_misalignment", input_misalignments),
                           ("output_misalignment", output_misalignments)):
    print(table_label + " shape:", table.shape)

In [None]:
# Sensitivities are read from a FITS file in the analysis directory and halved to
# compensate for a factor of ~2 in the Hexike coefficients (see email for the RMS).
sensitivity_file = os.path.join(analysis_path, 'optical_sensitvity.fits')
tel2_actuators = 0.5 * fits.getdata(sensitivity_file)

### Instantiate the 2Hex Simulator

In [None]:
# Simulator inputs: SCDA optics folder inside the PASTIS repository, number of
# hexagonal rings and sampling passed to HexRingAPLC.
optics_dir = os.path.join(util.find_repo_location(), 'data', 'SCDA')
NUM_RINGS = 2
sampling = 4
# Total number of hexike polynomials you want over the hexagonal segment.
n_zernikes = 11

tel2 = HexRingAPLC(optics_dir, NUM_RINGS, sampling)
tel2.create_segmented_mirror(n_zernikes)

# PSF of the un-aberrated telescope, together with the `ref` image used for normalization.
psf_options = dict(ref=True,
                   display_intermediate=True,
                   return_intermediate='intensity',
                   norm_one_photon=True)
unaberrated_psf, ref, intermediates = tel2.calc_psf(**psf_options)

# Normalize by the peak of `ref`; every later contrast in the notebook reuses `norm`.
norm = np.max(ref)
normalized_unaberrated_psf = unaberrated_psf / norm

# Mean normalized intensity over the pixels where the dark-hole mask is non-zero.
unaberr_roi = normalized_unaberrated_psf * tel2.dh_mask
contrast_floor = np.mean(unaberr_roi[np.where(tel2.dh_mask != 0)])

print("contrast_floor:", contrast_floor)
print("Wavelength:", tel2.wvln)

#### Load HWO sensitivities .mat file

#### Map HWO segment number to the SCDA segment number:

In [None]:
# SCDA segment number of each HWO segment, listed in HWO order (HWO segment 1 first).
_scda_by_hwo_order = [4, 5, 6, 7, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19, 8, 9, 10, 11, 1]

# Lookup keyed by the HWO segment number written as a string ("1" ... "19").
hwo_to_scda = {str(hwo_number): scda_number
               for hwo_number, scda_number in enumerate(_scda_by_hwo_order, start=1)}

# The same mapping as an array: entry i is the SCDA number of HWO segment i + 1.
hwo_scda_segments = np.array(_scda_by_hwo_order)

#### Sort coefficient table for each segment.  

In [None]:
# One coefficient table per segment: take that segment's slice along the second
# axis of hwo_hexike_coeffs and transpose it.
# NOTE(review): hwo_hexike_coeffs is not defined in any cell visible here; presumably
# it comes from the HWO sensitivities .mat file -- confirm it is loaded before this cell.
segs_tables = [hwo_hexike_coeffs[:, seg_index, :].T for seg_index in range(tel2.nseg)]

### Sort Zernike coefficients, and initialize the actuators

In [None]:
# Panel titles, indexed by degree of freedom (row of tel2_actuators).
titles = np.array(["X nm RMS/um", "Y nm RMS/um", "Z nm RMS/um",
                   "Rx nm RMS/urad", "Ry nm RMS/urad", "Rz nm RMS/urad"])

# Only the Z panel (dof 2) gets a fixed colour range; the other panels autoscale.
fixed_color_range = {2: dict(vmin=-2000, vmax=-1960)}

plt.figure(figsize=(14, 9))
# The Z panel is drawn last, so the mirror is left holding the Z sensitivities.
for dof in (0, 1, 3, 4, 5, 2):
    plt.subplot(2, 3, dof + 1)

    # Load the sensitivity vector of this degree of freedom onto a flat mirror.
    tel2.sm.flatten()
    tel2.sm.actuators = tel2_actuators[dof]

    # RMS of the OPD over the pixels where the aperture is non-zero, scaled by 1e9 (nm).
    rms_scda = np.sqrt(np.mean((tel2.sm.opd[np.where(tel2.aperture != 0)])**2)) * 1e9

    plt.title(titles[dof], fontweight='bold')
    hcipy.imshow_field(tel2.sm.opd * 1e9, cmap='jet', **fixed_color_range.get(dof, {}))
    plt.xlabel(f'RMS: {rms_scda:.2f} nm')
    cbar = plt.colorbar()
    cbar.set_label("in nm", loc='center')
#plt.savefig(os.path.join(analysis_path, 'hwo_PTT_sorted.png'))

### Note: the plots above should match the OPDs from Garrett for the PTT case. If they do not match, there may be an error in the segment-number mapping or in the ordering of the Zernike coefficients.

In [None]:
# Number of time steps read from the residual-misalignment table, and number of
# rigid-body degrees of freedom stored per time step (x, y, z, Rx, Ry, Rz).
N_TIME_STEPS = 401
N_DOF = 6


def _dof_time_series(dof):
    """Return the per-segment motion of one degree of freedom for every time step.

    For time step ``t`` the values are read from column ``t * N_DOF + dof`` of
    ``output_misalignments``, multiplied by 1e12 (to pm, per the original author's
    note) and the last row is dropped.

    Parameters
    ----------
    dof : int
        Index of the degree of freedom, 0 to N_DOF - 1.

    Returns
    -------
    list of numpy.ndarray
        One 1D array per time step.
    """
    return [np.delete(output_misalignments[:, step * N_DOF + dof] * 1e12, -1)
            for step in range(N_TIME_STEPS)]


tel2.sm.flatten()

# Same six lists the notebook always exposed, now built by a single helper.
dx_times, dy_times, dz_times, rdx_times, rdy_times, rdz_times = (
    _dof_time_series(dof) for dof in range(N_DOF)
)
motions_by_dof = (dx_times, dy_times, dz_times, rdx_times, rdy_times, rdz_times)

# Actuator index of every (HWO segment, hexike mode) pair, shape (nseg, n_zernikes).
# Row i belongs to HWO segment i + 1, mapped to its 0-based SCDA segment.
scda_offsets = n_zernikes * np.array(
    [hwo_to_scda[str(hwo_seg + 1)] - 1 for hwo_seg in range(tel2.nseg)]
)
actuator_index = scda_offsets[:, np.newaxis] + np.arange(n_zernikes)

contrasts_dof = []
for step in range(N_TIME_STEPS):
    tel2.sm.flatten()

    # Sum sensitivity * motion * 1e-6 over the six degrees of freedom.
    # Terms are accumulated in x, y, z, Rx, Ry, Rz order, matching the original
    # scalar sum, so the floating-point result is unchanged.
    commands = None
    for dof, series in enumerate(motions_by_dof):
        motion = series[step][:tel2.nseg, np.newaxis]
        term = tel2_actuators[dof][actuator_index] * motion * 1e-6
        commands = term if commands is None else commands + term
    # In-place write, as before; actuators not addressed by actuator_index keep
    # the zeros set by flatten().
    tel2.sm.actuators[actuator_index] = commands

    aberrated_psf, ref, intermediates = tel2.calc_psf(ref=True, display_intermediate=False,
                                                      return_intermediate='intensity',
                                                      norm_one_photon=True)
    # Normalize with `norm` from the unaberrated run, then average over the
    # pixels where the dark-hole mask is non-zero.
    normalized_aberrated_psf = aberrated_psf / norm
    aberr_roi = normalized_aberrated_psf * tel2.dh_mask
    aber_contrast_floor = np.mean(aberr_roi[np.where(tel2.dh_mask != 0)])
    contrasts_dof.append(aber_contrast_floor)
    print("contrast_floor:", aber_contrast_floor)

In [None]:
#np.savetxt(os.path.join(analysis_path, 'contrast_dof_openloop1.csv'), contrasts_dof, delimiter=',')

# Contrast time series previously saved to the closed-loop CSV file.
closeloop_file = os.path.join(analysis_path, 'contrast_dof_closeloop.csv')
contrasts_dof2 = np.genfromtxt(closeloop_file, delimiter=',')

LABEL_FONTSIZE = 25
TICK_FONTSIZE = 20

times = np.arange(401)
# The curve is the element-wise difference between the contrasts computed above
# and the ones loaded from file.
# NOTE(review): the legend entry ('total') and y-label ("Mean DH Contrast") describe
# a contrast, not a difference -- confirm the labels are intended.
contrast_difference = np.asarray(contrasts_dof) - contrasts_dof2

plt.figure(figsize=(20, 10))
plt.plot(times, contrast_difference, marker='s', markersize=8, linewidth=2, label='total')
# NOTE(review): hard-coded value, presumably the unaberrated contrast_floor from an
# earlier run -- confirm it still matches the current simulator settings.
plt.axhline(y=4.171337358217274e-11, color='r', linestyle='--', label='contrast_floor')
#plt.ylim(4.1e-11, 5.25e-11)
plt.xlabel("Time (in s)", fontsize=LABEL_FONTSIZE)
plt.ylabel("Mean DH Contrast", fontsize=LABEL_FONTSIZE)
plt.legend(fontsize=TICK_FONTSIZE)
plt.xticks(fontsize=TICK_FONTSIZE)
plt.yticks(fontsize=TICK_FONTSIZE)
plt.tight_layout()
#plt.savefig(os.path.join(analysis_path, 'open_close_loop.png'))