# Starting to set up LUVOIR B simulator

In [None]:
import os

from astropy.io import fits
import hcipy
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import label

from pastis.config import CONFIG_PASTIS
import pastis.util

## Loading data 

In [None]:
# Directory holding the LUVOIR-B optics files, as configured in the PASTIS config.
datadir = os.path.join(
    pastis.util.find_repo_location(),
    CONFIG_PASTIS.get('LUVOIR-B', 'optics_path_in_repo'),
)

# Read every optical element from its FITS file, in the same order as the names below.
(aperture_data, apod_stop_data, dm2_stop_data,
 lyot_stop_data, dm1_data, dm2_data) = (
    fits.getdata(os.path.join(datadir, fits_name))
    for fits_name in ('Pupil1.fits', 'APOD.fits', 'DM2stop.fits',
                      'LS.fits', 'surfDM1.fits', 'surfDM2.fits')
)

# Report the array sizes; they differ between files and drive the padding done later.
for array_name, loaded in (
        ('aperture_data', aperture_data),
        ('apod_stop_data', apod_stop_data),
        ('dm2_stop_data', dm2_stop_data),
        ('lyot_stop_data', lyot_stop_data),
        ('dm1_data', dm1_data),
        ('dm2_data', dm2_data)):
    print(f'{array_name}.shape: {loaded.shape}')

In [None]:
# Show all six raw input arrays side by side in a 2x3 grid.
raw_panels = (
    (aperture_data, 'Primary aperture'),
    (apod_stop_data, 'Apodizer stop'),
    (dm2_stop_data, 'DM2 stop'),
    (lyot_stop_data, 'Lyot stop'),
    (dm1_data, 'DM1'),
    (dm2_data, 'DM2'),
)

plt.figure(figsize=(18, 12))
for panel_index, (panel_data, panel_title) in enumerate(raw_panels, start=1):
    plt.subplot(2, 3, panel_index)
    plt.imshow(panel_data)
    plt.title(panel_title)

## Parameters

In [None]:
# Simulation parameters read from the 'LUVOIR-B' section of the PASTIS configuration.
nPup = CONFIG_PASTIS.getfloat('LUVOIR-B', 'pupil_pixels')
D_pup = CONFIG_PASTIS.getfloat('LUVOIR-B', 'D_pup')
samp_foc = CONFIG_PASTIS.getfloat('LUVOIR-B', 'sampling')
rad_foc = CONFIG_PASTIS.getfloat('LUVOIR-B', 'imlamD')
# The config value is scaled by 1e-9 to get meters (so it is presumably stored in nm).
wavelength = CONFIG_PASTIS.getfloat('LUVOIR-B', 'lambda') * 1e-9  # m

print(f'nPup: {nPup}')
print(f'D_pup: {D_pup}')
print(f'samp_foc: {samp_foc}')
print(f'rad_foc: {rad_foc}')
print(f'wavelength: {wavelength}')

# Side lengths (in pixels) of the differently sized input arrays.
nPup_arrays = apod_stop_data.shape[0]
nPup_dms = dm1_data.shape[0]
nPup_dm_stop = dm2_stop_data.shape[0]
# Propagation distance used for the Fresnel propagators between the DM planes.
# NOTE(review): 549.1429 is hard-coded; the expression has the form
# z = a^2 / (lambda * N), so it is presumably a Fresnel number - TODO confirm.
zDM = (D_pup/2)**2 / (wavelength * 549.1429)

print(f'nPup_arrays: {nPup_arrays}')
print(f'nPup_dms: {nPup_dms}')
print(f'nPup_dm_stop: {nPup_dm_stop}')
print(f'zDM: {zDM}')

## Handling the differently sized (padded) input arrays

The two DM surface maps have the largest arrays, so all other pupil-plane optics are zero-padded to match their size.

In [None]:
def _pad_to_dm_size(data):
    """Zero-pad ``data`` by the same amount on every side up to the DM array size.

    The pad width is derived from the array's own first dimension, so each
    optic is padded correctly even if the input files do not all share one size.

    Raises
    ------
    ValueError
        If the array is larger than the DM arrays, or if the size difference is
        odd (the array could then not be padded equally on both sides).
    """
    missing = nPup_dms - data.shape[0]
    if missing < 0 or missing % 2:
        raise ValueError(
            f'Cannot pad array of size {data.shape[0]} symmetrically to DM array size {nPup_dms}.')
    return np.pad(data, missing // 2, mode='constant')


# Pad apodizer to DM array size
apod_stop_data_pad = _pad_to_dm_size(apod_stop_data)
# Pad DM2 stop to DM array size
DM2Stop_data_pad = _pad_to_dm_size(dm2_stop_data)
# Pad LS to DM array size
lyot_stop_data_pad = _pad_to_dm_size(lyot_stop_data)
# Pad primary aperture to DM array size
aperture_data_pad = _pad_to_dm_size(aperture_data)

# Create pupil grids
#pupil_grid_real = hcipy.make_pupil_grid(nPup, D_pup)
# The pixel counts are passed directly as integers: computing them as
# nPup * (n / nPup) goes through floating point and can land just below the
# integer, which would then be truncated to one pixel too few.
# Create pupil grid with everything scaled to apodizer and LS array
pupil_grid_arrays = hcipy.make_pupil_grid(nPup_arrays, D_pup * (nPup_arrays / nPup))   # 1024 px, 0.049152 m
# Create pupil grid with everything scaled to DM array
pupil_grid_dms = hcipy.make_pupil_grid(nPup_dms, D_pup * (nPup_dms / nPup))    # 1168 px, 0.05606 m

# Create all optical components on DM pupil grids.
# np.reshape to nPup_dms**2 doubles as a size check: it raises if an array
# does not have exactly the DM array size.
apod_stop = hcipy.Field(np.reshape(apod_stop_data_pad, nPup_dms**2), pupil_grid_dms)
DM2_circle = hcipy.Field(np.reshape(DM2Stop_data_pad, nPup_dms**2), pupil_grid_dms)
lyot_mask = hcipy.Field(np.reshape(lyot_stop_data_pad, nPup_dms**2), pupil_grid_dms)
aperture = hcipy.Field(np.reshape(aperture_data_pad, nPup_dms**2), pupil_grid_dms)
DM1 = hcipy.Field(np.reshape(dm1_data, nPup_dms**2), pupil_grid_dms)
DM2 = hcipy.Field(np.reshape(dm2_data, nPup_dms**2), pupil_grid_dms)

In [None]:
# Display the diameter that was passed to make_pupil_grid for the DM-sized pupil grid.
D_pup*(nPup_dms/nPup)

In [None]:
# Show the six optical elements after they were put on the DM-sized pupil grid.
field_panels = (
    (apod_stop, 'Apodizer stop'),
    (DM2_circle, 'DM2_circle'),
    (aperture, 'aperture'),
    (lyot_mask, 'Lyot stop'),
    (DM1, 'DM1'),
    (DM2, 'DM2'),
)

plt.figure(figsize=(18, 12))
for panel_index, (panel_field, panel_title) in enumerate(field_panels, start=1):
    plt.subplot(2, 3, panel_index)
    hcipy.imshow_field(panel_field)
    plt.title(panel_title)

In [None]:
# Focal-plane grid. samp_foc and rad_foc are the 'sampling' and 'imlamD' config
# entries; the grid is referenced to the pupil diameter D_pup, a unit focal
# length and the simulation wavelength.
focal_grid = hcipy.make_focal_grid(
    samp_foc, rad_foc,
    reference_wavelength=wavelength,
    focal_length=1.,
    pupil_diameter=D_pup,
)

In [None]:
# Larger view of the padded primary aperture field, with a colorbar for its values.
plt.figure(figsize=(10, 10))
hcipy.imshow_field(aperture)
plt.colorbar()

## Propagators

In [None]:
# Propagator from the DM-sized pupil grid to the focal grid.
prop = hcipy.FraunhoferPropagator(pupil_grid_dms, focal_grid)

# Fresnel propagators over the distance zDM in the forward (+zDM) and
# backward (-zDM) direction, both on the DM-sized pupil grid.
fres_dms, fres_dms_minus = (
    hcipy.propagation.FresnelPropagator(pupil_grid_dms, distance, num_oversampling=1)
    for distance in (zDM, -zDM)
)

# Lyot stop as an apodizing optical element.
lyot_stop = hcipy.Apodizer(lyot_mask)

# Vortex coronagraph of charge 6 defined on the DM-sized pupil grid.
charge = 6
coro = hcipy.VortexCoronagraph(pupil_grid_dms, charge, scaling_factor=4)

## Actual code for propagation

In [None]:
# Factor that converts a DM map into the complex exponent applied to the field.
dm_phase_factor = 4*1j*np.pi/wavelength

# Aperture plane with the DM1 map applied.
wf = hcipy.Wavefront(aperture * np.exp(dm_phase_factor*DM1), wavelength)

# Propagate by +zDM, apply the DM2 map and the DM2 stop there.
wf2 = fres_dms(wf)
dm2_plane_field = wf2.electric_field*np.exp(dm_phase_factor*DM2)*DM2_circle
wf3 = hcipy.Wavefront(dm2_plane_field, wavelength)

# Propagate back by -zDM and apply the apodizer stop.
wf4 = fres_dms_minus(wf3)
wf5 = hcipy.Wavefront(wf4.electric_field * apod_stop, wavelength)

# Through the vortex coronagraph and the Lyot stop.
lyot_plane = coro.forward(wf5)
post_lyot_mask = lyot_stop(lyot_plane)

# Reference image: wf4 propagated straight to the focal plane, i.e. without
# the apodizer stop, the coronagraph or the Lyot stop.
img_ref = prop(wf4).intensity
# Coronagraphic image.
img = prop(post_lyot_mask).intensity

In [None]:
# Display the wavelength carried by the wavefront after the Lyot stop.
post_lyot_mask.wavelength

## Displaying some planes

In [None]:
# Both focal-plane images are normalized to the peak of the reference image.
reference_peak = img_ref.max()

# (field to show, panel title, extra imshow_field keyword arguments)
plane_panels = [
    (wf4.amplitude, 'wf4.amplitude', {}),
    (wf4.phase, 'wf4.phase', {}),
    (lyot_plane.intensity, 'lyot_plane.intensity', {}),
    (post_lyot_mask.intensity, 'post_lyot_mask.intensity', {}),
    (np.log10(img_ref / reference_peak), 'direct image',
     {'vmin': -5, 'cmap': 'inferno'}),
    (np.log10(img / reference_peak), 'coron image',
     {'cmap': 'inferno', 'vmin': -10, 'vmax': -5}),
]

plt.figure(figsize=(18, 12))
for panel_index, (panel_field, panel_title, show_kwargs) in enumerate(plane_panels, start=1):
    plt.subplot(2, 3, panel_index)
    hcipy.imshow_field(panel_field, **show_kwargs)
    plt.title(panel_title)
    plt.colorbar()
plt.show()

In [None]:
# Larger view of the coronagraphic image on a log10 scale, normalized to the
# peak of the reference image; the color range here extends down to 1e-12.
plt.figure(figsize=(10,10))
hcipy.imshow_field(np.log10(img / img_ref.max()), cmap='inferno', vmin=-12, vmax=-5)
plt.title('coron image')
plt.colorbar()

## Getting the segment locations from primary aperture

In [None]:
# Label connected regions of the thresholded (unpadded) aperture array; each
# labelled region is treated as one segment.
segments, num_segments = label(aperture_data > 0.154)   # Fudged number based on what "looks right"

# Centroid of every labelled region in the coordinates of pupil_grid_arrays.
grid_x = pupil_grid_arrays.x
grid_y = pupil_grid_arrays.y
centroids = []
for seg_label in range(1, num_segments + 1):
    seg_mask = (segments == seg_label).ravel()
    centroid_x = np.sum(seg_mask * grid_x) / np.sum(seg_mask)
    centroid_y = np.sum(seg_mask * grid_y) / np.sum(seg_mask)
    centroids.append(np.array([centroid_x, centroid_y]))

# NOTE(review): 962 is hard-coded here; it is presumably the nominal pupil size in
# pixels (cf. nPup) - TODO confirm.
segment_circum_diameter = D_pup * (nPup_arrays/962) / 8 * 1.024   # Fudged number based on what "looks right"

# Arrange the centroids as a (2, num_segments) array and wrap them in an hcipy grid.
seg_pos_pre = np.array(centroids).T
seg_pos = hcipy.CartesianGrid(hcipy.UnstructuredCoords(seg_pos_pre))

## Making segmented primary

In [None]:
def make_segment_zernike_primary(Nzernike):
    """Build a segmented primary mirror with a local Zernike basis on each segment.

    For every segment position in the module-level ``seg_pos``, a Zernike basis
    with ``Nzernike`` modes is generated on ``pupil_grid_arrays`` shifted to that
    segment's position. Each mode is multiplied by the sampled hexagonal
    aperture of that segment, and the per-segment bases are concatenated
    (segment by segment) into one mode basis.

    Relies on the module-level ``segment_circum_diameter``, ``seg_pos`` and
    ``pupil_grid_arrays``.

    Parameters
    ----------
    Nzernike : int
        Number of Zernike modes generated per segment, starting at mode 1.

    Returns
    -------
    sm : hcipy.optics.DeformableMirror
        Segmented mirror (primary) as a DM object, with ``Nzernike`` modes per
        segment.

    Raises
    ------
    ValueError
        If no segments are available.
    """
    segment = hcipy.hexagonal_aperture(segment_circum_diameter, np.pi / 2)
    # Only the individual segment generators are needed; the combined aperture is discarded.
    _, segment_generators = hcipy.make_segmented_aperture(segment, seg_pos, segment_transmissions=1,
                                                          return_segments=True)
    seg_evaluated = [hcipy.evaluate_supersampled(seg_generator, pupil_grid_arrays, 1)
                     for seg_generator in segment_generators]
    if not seg_evaluated:
        raise ValueError('No segments available to build the segmented mirror from.')

    # The number of segments follows from the segment positions instead of being hard-coded.
    mode_basis_local_zernike = None
    for seg_num, seg_mask in enumerate(seg_evaluated):
        segment_basis = hcipy.mode_basis.make_zernike_basis(Nzernike, segment_circum_diameter,
                                                            pupil_grid_arrays.shifted(-seg_pos[seg_num]),
                                                            starting_mode=1,
                                                            ansi=False, radial_cutoff=True, use_cache=True)
        # Multiply every mode by this segment's sampled aperture.
        for qq in range(Nzernike):
            segment_basis._transformation_matrix[:, qq] = seg_mask * segment_basis._transformation_matrix[:, qq]

        if mode_basis_local_zernike is None:
            mode_basis_local_zernike = segment_basis
        else:
            mode_basis_local_zernike.extend(segment_basis)

    sm = hcipy.optics.DeformableMirror(mode_basis_local_zernike)
    return sm

## Quick sanity check of segmented primary

In [None]:
# Segmented mirror with a single Zernike mode per segment.
sm_test = make_segment_zernike_primary(1)
aperture_small = hcipy.Field(np.reshape(aperture_data, nPup_arrays**2), pupil_grid_arrays)
input_wf = hcipy.Wavefront(aperture_small, wavelength)
# Set every actuator to one wavelength. The vector length is taken from the
# mirror itself rather than a hard-coded segment count.
sm_test.actuators = np.ones_like(sm_test.actuators) * wavelength
tmp_pupil = sm_test(input_wf)
# Compare the mirror's OPD map with the raw aperture array.
plt.figure(figsize=(10, 10))
hcipy.imshow_field(sm_test.opd)
plt.colorbar()
plt.figure(figsize=(10, 10))
plt.imshow(aperture_data)
plt.colorbar()

In [None]:
# Larger view of the amplitude of wf4 (the wavefront after the -zDM Fresnel propagation).
plt.figure(figsize=(12,12))
hcipy.imshow_field(wf4.amplitude)
plt.title('wf4.amplitude')