In [1]:
import matplotlib.pyplot as plt
import numpy as np
# from functions import complex_correlation, colorize, show_colormap_image
import os
import json

# for Jupyter lab
%matplotlib widget
# for Jupyter notebook
# %matplotlib notebook

data_folder = '../Data/'

In [2]:
import torch
from torch.autograd import Variable

from functions import getZernikeCoefs, showZernikeCoefs
from PyTorchAberrations.aberration_functions import pt_to_cpx, cpx_to_pt, tm_to_pt
from PyTorchAberrations.aberration_models import AberrationModes
from PyTorchAberrations.aberration_functions import conjugate, complex_matmul, normalize

# 1. Load data

## 1.1 Load ideal mode bases 

In [3]:
path = os.path.sep.join([data_folder,'conversion_matrices.npz'])
data_dict = np.load(path)
modes_in = data_dict.f.modes_in
nmodes = modes_in.shape[0]
# number of input pixels
m = np.int(np.sqrt(modes_in.shape[1]))
modes_out = data_dict.f.modes_out
# number of output pixels
n = np.int(np.sqrt(modes_out.shape[1]))

## 1.2 Load mask corresponding to degenerate groups of mode

In [4]:
path = os.path.sep.join([data_folder,'mask_near_degenerate.npy'])
mask_near_degenerate = np.load(path)

## 1.3 Load a measured pixel basis TM

In [5]:
# in two parts because of Github 100 Mo size limit
file_path = os.path.sep.join([data_folder, 'TM5_0.npy'])
part1 = np.load(file_path)
file_path = os.path.sep.join([data_folder, 'TM5_1.npy'])
part2 = np.load(file_path)
TM_ref_pix = np.concatenate([part1, part2], axis = 0)
del part1
del part2

# 2. Parameters

In [6]:
# padding coefficient for FFT, float > 0.
# the bigger, the more precise the FFT during the calculations
padding_coeff = 0.05

# list of Zernike polynomials to consider in the Fourier plane (both input and output)
list_zernike_ft = list(range(9))

# list of Zernike polynomials to consider in the direct plane (both input and output)
list_zernike_direct = list(range(14))#14

# learning rate for the optimizer
learning_rate = 10e-2

# number of epoch to train the model
num_epoch = 500

deformation = 'scaling' # for test, do not modify

# 3. Prepare data

## 3.1 Select a quadrant of the pixel TM measured
(it corresponds to one input and one output polarization)

In [7]:
# TM_pix = TM_ref_pix[:n**2,:m**2]
TM_pix = TM_ref_pix[n**2:,m**2:]
# TM_pix = TM_ref_pix[n**2:,:m**2]
# TM_pix = TM_ref_pix[:n**2,m**2:]

## 3.2 Select device (GPU if available)

In [8]:
device_number = 0 # if you have multiple GPUs
if torch.cuda.is_available():
    print(f'Using GPU: {torch.cuda.get_device_name(device_number)}')
    device = torch.device(f'cuda:{device_number}')
else:
    print(f'No GPU available, running on CPU. Will be slow.')
    device = torch.device('cpu')

No GPU available, running on CPU. Will be slow.


## 3.3 Prepare data for PyTorch

In [9]:
pt_modes_in_var = Variable(
    cpx_to_pt(modes_in.reshape((-1,m,m)), device, torch.float32),
    requires_grad = False
).to(device)

pt_modes_out_var = Variable(
    cpx_to_pt(modes_out.reshape((-1,n,n)), device, torch.float32),
    requires_grad = False
).to(device)

# 4. Prepare optimization

## 4.1 Initialize the model

In [10]:
model = AberrationModes(
    inpoints = m, 
    onpoints = n, 
    padding_coeff = padding_coeff, # if the gpu memory gets full, reduce it to just above 0
    list_zernike_ft = list_zernike_ft,
    list_zernike_direct = list_zernike_direct,
    deformation = deformation
).to(device)

## 4.2 Define the cost function
We want to maximize the energy of the TM projected onto the mode basis.

The cost function (to minimize) is the inverse of the ratio of between 
the norm squared of the TM in the mode basis by the norm squared of the TM in the pixel basis.

In [11]:
def norm_mode_to_norm_pix(T_pix, pt_modes_out_var, pt_modes_in_var, onpoints, inpoints):
    '''
    Cost function to minimize.
    It is defined as the inverse of the ratio between the norm squared of the TM
    in the mode basis by the norm squared of the TM in the pixel basis.
    '''
    # reshape the change of basis matrices
    pt_modes_out_var = pt_modes_out_var.reshape((-1,onpoints**2,2))
    pt_modes_in_var = pt_modes_in_var.reshape((-1,inpoints**2,2))
    # project the TM in the mode basis with the current aberration parameters
    T_mode = complex_matmul(complex_matmul(pt_modes_out_var,T_pix),conjugate(pt_modes_in_var.permute((1,0,2))))
    # use the ratio of energy between the projected matrix and the pixel basis
    energy_ratio = (torch.norm(T_mode)/torch.norm(T_pix))**2 
    # the cost function to minimize is the inverse of this quantity
    return 1./energy_ratio

loss_fn = lambda A: norm_mode_to_norm_pix(
                                    tm_to_pt(TM_pix,device),
                                    A[0],
                                    A[1], 
                                    n,
                                    m)

## 4.3 Initialize the optimizer

We use the Adam optimizer

In [12]:
optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=learning_rate, 
    betas=(0.9, 0.999), 
    eps=1e-08, 
    weight_decay=0.0, 
    amsgrad=False
)

# 5. Optimization 

## 5.1 Calculate energy criteria before optimization

In [13]:
evol_energy_ratio = []


TM_modes_temp = modes_out @ TM_pix @ modes_in.T.conj()

# ratio of energy after projection to the energy of the full pixel matrix
energy_ratio = (np.linalg.norm(TM_modes_temp)/np.linalg.norm(TM_pix))**2
# ratio of the energy on the block diagonal (degenerate groups)
# of the mode basis TM compared to the total energy of the mode basis TM

energy_on_diagonal =np.linalg.norm(TM_modes_temp*mask_near_degenerate)**2
diag_ratio = energy_on_diagonal/np.linalg.norm(TM_modes_temp)**2

evol_energy_ratio.append(energy_ratio)

print(f"Initial ratios: conversion = {100*energy_ratio:.3f}%, \t diagonal = {100*diag_ratio:.3f}%")

Initial ratios: conversion = 60.227%, 	 diagonal = 21.543%


## 5.2 Optimization loop

In [15]:
# get the first guess using the initial parameters
y_pred = model(pt_modes_in_var,pt_modes_out_var)

best = 1e8 

for epoch in range(num_epoch):

    # infer the input and output projection matrices from the model
    y_pred = model(pt_modes_in_var,pt_modes_out_var)
    # normalize the input and output projection matrices
    y_pred = [normalize(y_pred[0].reshape((-1,n**2,2)),device = device).reshape((-1,n,n,2)), 
              normalize(y_pred[1].reshape((-1,m**2,2)),device = device).reshape((-1,m,m,2))]


    loss = loss_fn(y_pred)

    optimizer.zero_grad() 
    # backward propagation done by autograd
    loss.backward()

    # update the parameters
    optimizer.step()

    evol_energy_ratio.append(1./loss.item())

    if loss.item() < best :
        best  = loss.item()
        best_state = (model.abberation_output.state_dict(),
                      model.abberation_input.state_dict())

    # show some stats
    if epoch % 10 == 0:
   
        modes_out_corr = pt_to_cpx(y_pred[0].detach().cpu()).reshape(nmodes,-1)
        modes_in_corr = pt_to_cpx(y_pred[1].detach().cpu()).reshape(nmodes,-1)
        # project in the mode basis
        TM_modes_temp = modes_out_corr @ TM_pix @ modes_in_corr.T.conj()
 

        # ratio of energy after projection to the energy of the full pixel matrix
        energy_ratio = (np.linalg.norm(TM_modes_temp)/np.linalg.norm(TM_pix))**2
        # ratio of the energy on the block diagonal (degenerate groups)
        # of the mode basis TM compared to the total energy of the mode basis TM
        
        energy_on_diagonal =np.linalg.norm(TM_modes_temp*mask_near_degenerate)**2
        diag_ratio = energy_on_diagonal/np.linalg.norm(TM_modes_temp)**2
        print(f"--| epoch: {epoch+1}/{num_epoch} | {100.*(epoch+1)/num_epoch:3.2f}% |--")
        print(f" Current ratios: conversion = {100*energy_ratio:.3f}%, \t diagonal = {100*diag_ratio:.3f}%")
modes_out_corr = pt_to_cpx(y_pred[0].detach().cpu()).reshape(nmodes,-1)
modes_in_corr = pt_to_cpx(y_pred[1].detach().cpu()).reshape(nmodes,-1)
TM_modes_corr = modes_out_corr @ TM_pix @ modes_in_corr.T.conj()


--| epoch: 1/500 | 0.20% |--
 Current ratios: conversion = 60.046%, 	 diagonal = 21.661%
--| epoch: 11/500 | 2.20% |--
 Current ratios: conversion = 64.712%, 	 diagonal = 24.332%
--| epoch: 21/500 | 4.20% |--
 Current ratios: conversion = 68.714%, 	 diagonal = 27.015%
--| epoch: 31/500 | 6.20% |--
 Current ratios: conversion = 71.851%, 	 diagonal = 28.774%
--| epoch: 41/500 | 8.20% |--
 Current ratios: conversion = 74.868%, 	 diagonal = 33.574%
--| epoch: 51/500 | 10.20% |--
 Current ratios: conversion = 77.689%, 	 diagonal = 40.387%
--| epoch: 61/500 | 12.20% |--
 Current ratios: conversion = 80.432%, 	 diagonal = 47.236%
--| epoch: 71/500 | 14.20% |--
 Current ratios: conversion = 83.270%, 	 diagonal = 55.384%
--| epoch: 81/500 | 16.20% |--
 Current ratios: conversion = 86.021%, 	 diagonal = 64.303%
--| epoch: 91/500 | 18.20% |--
 Current ratios: conversion = 88.362%, 	 diagonal = 73.039%
--| epoch: 101/500 | 20.20% |--
 Current ratios: conversion = 90.408%, 	 diagonal = 80.784%
--| 

# 6. Results

## 6.1 Mode TM before and after

In [18]:
plt.figure()
plt.plot(evol_energy_ratio)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7fc0fa4c50b8>]

## 6.2 Show Zernike coefficients

In [19]:
best_Zernike_coeff = [getZernikeCoefs(model.abberation_output.state_dict()),
                   getZernikeCoefs(model.abberation_input.state_dict())]

In [20]:
showZernikeCoefs(best_Zernike_coeff[1])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(<Figure size 1200x700 with 2 Axes>,
 <matplotlib.axes._subplots.AxesSubplot at 0x7fc0fa47ac18>,
 [4, 12])