In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import astropy.units as u
from astropy.io import fits
from pathlib import Path
from IPython.display import clear_output, display, HTML
# Widen the notebook's main container to 90% of the browser window
# (CSS targeting the classic Jupyter Notebook UI).
display(HTML("<style>.container { width:90% !important; }</style>"))
from importlib import reload
import copy
import os

import poppy

# Logging: the root logger is configured to print INFO and above to stdout.
# The poppy logger is set to DEBUG but then disabled entirely; set
# `poppy_log.disabled = False` to see poppy's debug output.
import logging, sys
poppy_log = logging.getLogger('poppy')
poppy_log.setLevel('DEBUG')
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
poppy_log.disabled = True

# SCOOB optical model. `xp` is the array backend exposed by scoobpsf.math_module
# (NOTE(review): presumably NumPy or CuPy depending on GPU availability — confirm).
# The star import provides the imshow1/imshow2 plotting helpers used below.
import scoobpsf
from scoobpsf.math_module import xp, _scipy
from scoobpsf.imshows import *
from scoobpsf import scoobm

# Wavefront-control routines (EFC, dark-hole masks, Jacobian utilities).
import lina

# --- Optical system parameters ---
pupil_diam = 6.75*u.mm       # entrance pupil diameter
wavelength_c = 632.8e-9*u.m  # central wavelength
lyot_diam = 3.6*u.mm         # Lyot stop diameter

# --- DM flat ---
# Flat wavefront as determined by Kyle doing phase diversity measurements.
dm_flat = fits.getdata(scoobm.module_path/'scoob_dm_flat.fits')
# Untouched copy of the flat: dm_flat itself is modified later (actuator pinning).
dm_flat0 = copy.copy(dm_flat)
# Known bad actuator is (26, 21). Its flat value is cached here because
# dm_flat evolves throughout the notebook.
bad_act_value = dm_flat[26, 21]

# --- Coronagraph optics ---
# Charge-6 vortex focal-plane mask and a circular Lyot stop (no gray edge pixels).
vortex = scoobpsf.agpm.IdealAGPM(name='VVC', wavelength=wavelength_c, charge=6, rotation=20)
lyot_stop = poppy.CircularAperture(name='Lyot Stop', radius=lyot_diam/2, gray_pixel=False)

In [None]:
# Visualize the measured DM flat map loaded in the setup cell.
imshow1(dm_flat)

In [None]:
# Check GPU availability/usage (IPython shell escape; requires NVIDIA drivers).
!nvidia-smi

In [None]:
reload(scoobm)

# Dead actuators to include in the model, as (row, col) DM indices.
bad_acts = [(26, 21)]  # 1 dead actuator
# bad_acts = [(26, 21), (18, 12)]  # 2 dead

# Offset subtracted from the central wavelength when building the model.
# Zero reproduces wavelength_c; set e.g. to 10*u.nm to model another wavelength.
wavelength_offset = 0*u.nm
model = scoobm.SCOOBM(bad_acts, wavelength=wavelength_c - wavelength_offset, use_opds=True)

In [None]:
# Calculate the base metrics of the flat wavefront, over the DM pixels
# selected by model.dm_mask. Values are printed in nm (surface assumed in m).
flat_vals = dm_flat[model.dm_mask]
med_val = np.median(flat_vals)
std_val = np.std(flat_vals)
ptv = np.max(flat_vals) - np.min(flat_vals)
print(f'Median is: {med_val*1e9:0.1f} nm')
print(f'stddev is: {std_val*1e9:0.1f} nm')
print(f'PtV is: {ptv*1e9:0.1f} nm')
# Report the known bad actuator's own value (bad_act_value) rather than the
# map maximum, which is not guaranteed to belong to that actuator.
print(f'Bad actuator is {(bad_act_value-med_val)*1e9:0.1f} nm above the median value')

In [None]:
# Pin the actuators to desired value(s): every entry of bad_acts is set in
# dm_flat to pinned_value (by default the bad actuator's own flat value).
pin = True
if pin:
    # Alternative pinned value that was tried:
    # pinned_value = med_val + (np.sqrt(49.)*(bad_act_value - med_val)/16)
    pinned_value = bad_act_value
if pin and bad_acts:
    print(f'Pinning {len(bad_acts)} bad actuators at {pinned_value*1e9:0.3f} nm')
    rows, cols = zip(*bad_acts)
    dm_flat[list(rows), list(cols)] = pinned_value

In [None]:
# Re-calculate the base metrics of the flat wavefront with the new actuator
# settings; statistics are over the pixels selected by model.dm_mask.
masked_flat = dm_flat[model.dm_mask]
med_val = np.median(masked_flat)
std_val = np.std(masked_flat)
ptv = masked_flat.max() - masked_flat.min()
for label, value in (('Median', med_val), ('stddev', std_val), ('PtV', ptv)):
    print(f'{label} is: {value*1e9:0.1f} nm')

In [None]:
# Standard flat, with sampling reduced by 4x (npix 512 -> 128, oversample 16 -> 4).
model.npix = 512 // 4
model.oversample = 16 // 4

model.det_rotation = 0
model.set_dm(dm_flat)

# Normalize to the peak of the un-occulted PSF (no FPM/Lyot stop set yet),
# then take the reference image (presumably snap() applies model.imnorm).
peak = model.snap().max()
model.imnorm = peak

normalized_im = model.snap()

In [None]:
# Now add the coronagraph: vortex focal-plane mask plus Lyot stop.
model.FPM = vortex
model.LYOT = lyot_stop

occ_im = model.snap()

# Un-occulted (left) vs occulted (right) images, log scale, axes in lambda/D.
imshow2(normalized_im, occ_im,
        lognorm1=True, lognorm2=True,
        pxscl=model.psf_pixelscale_lamD)

# Cache model geometry for later cells.
Nact = model.Nact
# Side length of the simulated camera image in pixels.
# NOTE(review): previously annotated as 400 — confirm for the current npix/oversample.
npsf = model.npsf
psf_pixelscale_lamD = model.psf_pixelscale_lamD

In [None]:
reload(lina.utils)
npsf = model.npsf
nact = model.Nact

# Focal-plane coordinates in lambda/D, centered on the optical axis
# (pixel-center sampling, hence the +1/2 shift).
half_width = npsf/2
xfp = (xp.linspace(-half_width, half_width - 1, npsf) + 1/2)*model.psf_pixelscale_lamD
fpx, fpy = xp.meshgrid(xfp, xfp)

# Dark-hole geometry, in the same lambda/D units as fpx/fpy.
edge = 2
iwa = 3
owa = 10
rot = 0

# Create the mask that is used to select which region to make dark.
dark_params = dict(
    inner_radius=iwa,
    outer_radius=owa,
    edge=edge,
    rotation=rot,
)
dark_mask = lina.utils.create_annular_focal_plane_mask(fpx, fpy, dark_params)
imshow2(dark_mask, dark_mask*occ_im, lognorm2=True)

In [None]:
# Calculate the base metrics for the DH (before any wavefront control).
# Statistics use the dark-hole pixels only: averaging dark_mask*occ_im over
# the whole frame would dilute them with the zeros outside the dark hole.
# occ_im is presumably normalized by model.imnorm, so its DH mean is the
# mean contrast there.
dh_pixels = occ_im[dark_mask.astype(bool)]
print(f'Total counts in DH: {np.sum(dh_pixels):0.3e}')
print(f'Mean contrast in DH: {np.mean(dh_pixels):0.3e}')
print(f'Contrast std dev in DH: {np.std(dh_pixels):0.3e}')

In [None]:
reload(lina.efc)

# Poke amplitude handed to lina.efc.build_jacobian (DM surface units,
# presumably meters).
epsilon = 1e-9
model.set_dm(dm_flat)

In [None]:
# Cache filename for the Jacobian. It must encode every notebook setting the
# Jacobian depends on: model sampling, number of dead actuators, poke size and
# dark-hole geometry. build_jacobian is given dark_mask, so without the DH
# parameters a Jacobian built for a different mask would be silently reused.
badacts_tag = f'-badacts{len(bad_acts)}' if bad_acts else ''
dh_tag = f'-iwa{iwa}-owa{owa}-edge{edge}-rot{rot}-eps{epsilon:g}'
filename = f'jac-efc-npix{model.npix}-oversample{model.oversample}{badacts_tag}{dh_tag}.fits'
print(f'{filename=}')

In [None]:
# Build and save the Jacobian only when no cached file exists; the next cell
# loads it from `filename` either way.
if not os.path.isfile(filename):
    jac = lina.efc.build_jacobian(model, epsilon, dark_mask, plot=False)
    reload(scoobpsf.utils)
    scoobpsf.utils.save_fits(filename, scoobpsf.math_module.ensure_np_array(jac))
else:
    print('Jacobian exists, skipping')

In [None]:
# Load the cached Jacobian from disk and copy it into an `xp` array.
jac = xp.array(fits.getdata(filename))

In [None]:
reload(lina.utils)
# Root-sum-square of each Jacobian column (one value per actuator), mapped back
# onto the DM grid. ensure_np_array (also used when saving the Jacobian) converts
# to NumPy for either backend, whereas `.get()` exists only on CuPy arrays.
act_rss = xp.sqrt((jac**2).sum(axis=0))
response = lina.utils.map_acts_to_dm(scoobpsf.math_module.ensure_np_array(act_rss), model.dm_mask)
imshow1(response, lognorm=True, vmin=1e3)

In [None]:
# Pick up any edits to the lina modules (efc, then utils, then math_module),
# then reset the DM to the flat before running EFC.
for module in (lina.efc, lina.utils, lina.math_module):
    reload(module)
model.set_dm(dm_flat)

In [None]:
%time
# declare penalty matrix value, and number of iterations for each value (10)
# -4 --> -1 represent the beta value for regularization

# Sidick starts at -4, then goes down to -1
reg_conds = [(-4, 10), (-3,10), (-2,10), (-1, 10)]

for i in range(len(reg_conds)):
    print(f'{i=}')
    # Derive the control matrix, which is the gain matrix from Sidick 2012
    # details are described in utils.beta_reg.
    # matrix is then flattened, therefore not in 2d and matched to dm coords.
    control_matrix = lina.utils.beta_reg(jac, reg_conds[i][0])
    
    # Assume a system with perfect knowledge of the E-field
    ims, commands, sms_fig = lina.efc.run_efc_perfect(model, 
                                            jac, 
                                            control_matrix,
                                            dark_mask, 
                                            Imax_unocc=1,
                                            efc_loop_gain=0.5, 
                                            iterations=reg_conds[i][1], 
                                            plot_all=True, 
                                            plot_sms=False,
                                            plot_radial_contrast=True)

In [None]:
# Show the final DM shape with the flat removed, i.e. how the DM moved as a
# result of the EFC runs.
dm_delta = (model.get_dm() - dm_flat)*model.dm_mask
imshow1(dm_delta, 'DM shape - flat')

# Metrics of the EFC-applied DM change (not of the dark hole), over the
# pixels selected by model.dm_mask.
delta_vals = dm_delta[model.dm_mask]
med_val = np.median(delta_vals)
std_val = np.std(delta_vals)
ptv = np.max(delta_vals) - np.min(delta_vals)
print(f'Median is: {med_val*1e9:0.1f} nm')
print(f'stddev is: {std_val*1e9:0.1f} nm')
print(f'PtV is: {ptv*1e9:0.1f} nm')