In [None]:
import matplotlib
matplotlib.use('TKAgg')
%load_ext autoreload
%autoreload 2
import os 
import datetime
import glob 
import numpy as np
import matplotlib.pyplot as plt
from pianoq.simulations.mplc_sim.mplc_sim import MPLCSim
from pianoq.simulations.mplc_sim.mplc_sim_result import MPLCMasks
from pianoq.lab.mplc.singles_scan import signal_scan, idler_scan, get_signal_scanner, get_idler_scanner
from pianoq.lab.mplc.phase_finder_result import PhaseFinderResult
from pianoq.lab.mplc.mask_utils import remove_input_modes, add_phase_input_spots, get_imaging_masks
from pianoq_results.scan_result import ScanResult
from pianoq.simulations.mplc_sim.mplc_sim_result import MPLCSimResult
from pianoq.lab.mplc.mplc_device import MPLCDevice
from pianoq.lab.photon_scan import PhotonScanner
from pianoq.misc.misc import run_in_thread, run_in_thread_simple
from pianoq.misc.mplt import mimshow, mplot
from pianoq.simulations.mplc_sim.create_wfm_masks import create_WFM_diffuser_masks
from pianoq.lab.mplc.discrete_photon_scanner import DiscretePhotonScanner
import time 
from pianoq.lab.mplc.phase_finder_result import PhaseFinderResult
from pianoq.lab.mplc.find_discreet_phases import PhaseFinder

# Results directory for this run. makedirs also creates missing parent folders and
# is a no-op when the directory already exists (no separate exists() check needed).
dir_path = r'G:\My Drive\Projects\MPLC\results\lab\2024_09_24_zernike'
os.makedirs(dir_path, exist_ok=True)

# Input modes passed to remove_input_modes() when the masks are loaded
modes_to_keep = np.array([3, 8, 13, 18, 23, 28, 33, 38, 43, 48])
U_no = 1  # selects the U{U_no}U*.masks file inside dir_path
phases_no = 2  # NOTE(review): not used in the visible cells

In [None]:
# Motors
from pianoq.lab.mplc.consts import thorlabs_x_serial, thorlabs_y_serial
from pianoq.lab.thorlabs_motor import ThorlabsKcubeDC, ThorlabsKcubeStepper
from pianoq.lab.zaber_motor import ZaberMotors

backlash = 0
wait_after_move = 0.3
# Every stage gets the same backlash / settle-time settings
motor_kwargs = dict(backlash=backlash, wait_after_move=wait_after_move)

# Zaber pair: motors[1] / motors[0] are later driven to the signal x / y locations
zaber_ms = ZaberMotors(**motor_kwargs)
mxs, mys = zaber_ms.motors[1], zaber_ms.motors[0]
print('Got Zaber motors')

# Thorlabs pair: later driven to the idler x / y locations
mxi = ThorlabsKcubeDC(thorlabs_x_serial, **motor_kwargs)
myi = ThorlabsKcubeStepper(thorlabs_y_serial, **motor_kwargs)
print('Got Thorlabs motors')

# MPLC
mplc = MPLCDevice()
mplc.restore_location()
print('Got MPLC')

# Timetagger (integration time is raised again before the optimization loop)
from pianoq.lab.time_tagger import QPTimeTagger
from pianoq.lab.mplc.consts import TIMETAGGER_DELAYS, TIMETAGGER_COIN_WINDOW
tt = QPTimeTagger(
    integration_time=1,
    remote=True,
    single_channel_delays=TIMETAGGER_DELAYS,
    coin_window=TIMETAGGER_COIN_WINDOW,
)
print('Got Time tagger')

# Optimization
## prepare

In [None]:
# Load the MPLC masks and the signal/idler detector locations, then park the
# stages on the (i, j) entry of the correlation matrix that should be strong.


def find_first_match(pattern):
    """Return the first path matching glob ``pattern``, in sorted order.

    ``glob.glob`` returns matches in arbitrary order, so ``glob(...)[0]`` can pick a
    different file between runs when several match; sorting makes the choice
    deterministic, and the chosen file is printed when there is ambiguity.

    Raises:
        FileNotFoundError: nothing matches ``pattern`` (instead of a bare IndexError).
    """
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f'No file matches {pattern!r}')
    if len(matches) > 1:
        print(f'{len(matches)} files match {pattern!r}; using {matches[0]}')
    return matches[0]


# masks
masks_path = find_first_match(os.path.join(dir_path, f'U{U_no}U*.masks'))
msks = MPLCMasks()
msks.loadfrom(masks_path)
masks = msks.real_masks
masks = remove_input_modes(masks, modes_to_keep=modes_to_keep)
mplc.load_masks(masks, linear_tilts=True)

# locs
locs_idl_path = find_first_match(os.path.join(dir_path, '*idl*filter_80nm*.locs'))
locs_sig_path = find_first_match(os.path.join(dir_path, '*sig*filter_80nm*.locs'))
locs_idl = np.load(locs_idl_path)['locs']
locs_sig = np.load(locs_sig_path)['locs']

# i,j of corr matrix that is supposed to be strong
i = 2  # row of locs_idl the idler stages (mxi, myi) are sent to
j = 2  # row of locs_sig the signal stages (mxs, mys) are sent to
mxi.move_absolute(locs_idl[i, 0])
myi.move_absolute(locs_idl[i, 1])
mxs.move_absolute(locs_sig[j, 0])
mys.move_absolute(locs_sig[j, 1])
mplc.restore_location()
time.sleep(1)  # pause 1 s after restoring the MPLC location before measuring

## optimize

In [None]:
from aotools.functions import phaseFromZernikes

# The Zernike phase is written into rows [Z_MASK_CENTER_ROW +- Z_MASK_SIZE/2) of the
# first MPLC plane, across all columns. Assigning a (Z_MASK_SIZE, Z_MASK_SIZE) array
# into that slice requires plane 0 to be exactly Z_MASK_SIZE columns wide.
Z_MASK_SIZE = 140
Z_MASK_CENTER_ROW = 180
z_slice = np.index_exp[Z_MASK_CENTER_ROW - Z_MASK_SIZE // 2: Z_MASK_CENTER_ROW + Z_MASK_SIZE // 2, :]

# find phases: every trial starts from the masks loaded here
mplc.load_masks(masks, linear_tilts=True)
orig_masks = mplc.masks.copy()

iters = 3
N_zernike = 14
# Odd number of points so 0 (no correction) lies on the grid. With an even count
# (the previous np.linspace(-3, 3, 20)) the unperturbed setting was never measured
# and every coefficient was forced to a non-zero value.
magnitude_range = np.linspace(-3, 3, 21)
timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
integration_time = 10
tt.set_integration_time(integration_time)

all_costs = []  # one cost curve (over magnitude_range) per (iteration, coefficient)
v_zernike = np.zeros(N_zernike)  # NOTE(review): not used in this cell
best_z_vals = np.zeros(N_zernike)

# Coordinate descent: scan one Zernike coefficient at a time, keeping the others at
# their best values so far, and store the magnitude with the highest cost.
# `it` avoids shadowing the builtin `iter`.
for it in range(iters):
    print(f'{it=}')
    for z_coef in range(N_zernike):
        print(f'{z_coef=}')
        now_costs = np.zeros(len(magnitude_range))
        # `k`, not `i`, so the correlation-pair index `i` set in the prepare cell survives
        for k, magnitude in enumerate(magnitude_range):
            print(f'{magnitude=}')
            vals = best_z_vals.copy()
            vals[z_coef] = magnitude
            z_mask = phaseFromZernikes(vals, Z_MASK_SIZE)
            assert z_mask.shape == (Z_MASK_SIZE, Z_MASK_SIZE)
            new_masks = orig_masks.copy()
            # NOTE(review): this *replaces* the window contents of plane 0 rather than
            # adding to them — confirm discarding the original phase there is intended.
            new_masks[0][z_slice] = z_mask
            mplc.load_masks(new_masks, linear_tilts=True)
            s1, s2, c = tt.read_interesting()
            # Coincidences minus an accidental estimate 2 * coin_window * s1 * s2
            # (assumes s1, s2 are rates in units consistent with tt.coin_window — TODO confirm)
            cost = c - 2*tt.coin_window*s1*s2
            now_costs[k] = cost
        all_costs.append(now_costs)
        best_ind = np.argmax(now_costs)
        best_z_vals[z_coef] = magnitude_range[best_ind]

# Leave the device on the best correction found rather than on the last trial mask
best_z_mask = phaseFromZernikes(best_z_vals, Z_MASK_SIZE)
best_masks = orig_masks.copy()
best_masks[0][z_slice] = best_z_mask
mplc.load_masks(best_masks, linear_tilts=True)
print(f'{best_z_vals=}')