 # PtyRAD - PTYchographic Reconstruction with Automatic Differentiation

 Chia-Hao Lee

cl2696@cornell.edu

Created 2024.03.08

## Major workflow
- Rename `preprocess.py` to `initialization.py`
- Keep the modules able to interact freely, without too much coupling between them
- Initializer class: takes a user-defined dict of flags and params. Probably implement it as a NumPy-based class, kept separate from PtyAD for clarity, so that the same Initializer can also be reused elsewhere, e.g., by another reconstruction engine parallel to `Ptyrad`.

Example checkpoint layout:

```python

torch.save({'init_params':init_params,
            'lr_params':lr_params,
            'loss_params':loss_params,
            'model_state_dict':model.state_dict()
            }, PATH)

```

# 01. Imports

In [None]:
import data_io 
import initialization
import models
import forward
import optimization
import visualization
import utils


import os
from importlib import reload
import matplotlib.pyplot as plt
import numpy as np
import torch


GPUID = 0
# Fall back to the CPU when no CUDA device is available so the notebook still runs;
# torch.cuda.get_device_name() raises on machines without CUDA.
if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{GPUID}")
else:
    DEVICE = torch.device("cpu")
print("Execution device: ", DEVICE)
print("PyTorch version: ", torch.__version__)
print("CUDA available: ", torch.cuda.is_available())
print("CUDA version: ", torch.version.cuda)
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(GPUID))

In [None]:
# If any of the modules are modified, use this to reload
reload(data_io)
reload(initialization)
reload(models)
reload(forward)
reload(optimization)
reload(visualization)
reload(utils)

from initialization import Initializer, make_stem_probe, make_mixed_probe
from models import PtychoAD
from optimization import CombinedLoss, ptycho_recon, loss_logger
from visualization import plot_forward_pass
from utils import test_loss_fn, make_batches

### Note: The initialization below still produces (complex) NumPy arrays; they are converted to tensors later, when the model is built.

# 02. Initialize optimization

In [None]:
# ptycho_output_mat_path = "output/abtem/20240328_multi_object/cbed_WSe2_and_Cu_fov_42p86_39p97_dp180/1/roi_1_Ndp_180/MLs_L1_p2_g256_dpFlip_T/Niter100.mat"
# exp_CBED_path =          "output/abtem/20240328_multi_object/cbed_WSe2_and_Cu_fov_42p86_39p97_dp180.tif" 

# exp_params = {
#     "kv": 200,
#     "alpha":20,
#     "dx_spec":0.1199,
#     "z_distance":1,
#     "Nlayer":1,
#     "N_scans":9494,
#     "omode":1,
#     "pmode":1,
#     "probe_permute": (2,0,1),
#     "cbeds_permute":None,
#     "cbeds_reshape": (9494,180,180)
# }

ptycho_output_mat_path = "data/CNS_from_Hari/Niter10000.mat"
exp_CBED_path = "data/CNS_from_Hari/240327_fov_23p044A_x_24p402A_thickness_9p978A_step0p28_conv30_dfm100_det70_TDS_2configs_xdirection_Co_0p25_Nb_0_S_0.mat"

# Basic parameters
voltage = 300  # kV
conv_angle = 30  # mrad, semi-convergence angle
Npix = 164  # Detector pixel number, EMPAD is 128. Only supports square detector for simplicity
N_scans = 7134  # Number of scan positions (diffraction patterns) in the dataset
rbf = None  # Pixels of radius of BF disk, used to calculate dk
dx_spec = 0.1406  # Ang
df = -100  # Ang, positive defocus here refers to actual underfocus or weaker lens strength following Kirkland/abtem/ptychoshelves convention
omode_max = 1
pmode_max = 10
pmode_init_pows = [0.02]

# N_scans and Npix are defined once above and reused here, so the reshape target of the
# measurements cannot drift out of sync with them.
exp_params = {
    "kv": voltage,
    "conv_angle": conv_angle,
    "Npix": Npix,
    "rbf": rbf,
    "dx_spec": dx_spec,
    "defocus": df,
    "z_distance": 1,
    "Nlayer": 1,
    "N_scans": N_scans,
    "omode_max": omode_max,
    "pmode_max": pmode_max,
    "pmode_init_pows": pmode_init_pows,
    "probe_permute": None,  # alternative: (2, 0, 1)
    "cbeds_permute": (0, 1, 3, 2),
    "cbeds_reshape": (N_scans, Npix, Npix),  # square detector, see Npix
}

probe_simu_params = {
    ## Basic params
    "kv"             : exp_params['kv'],
    "conv_angle"     : exp_params['conv_angle'],
    "Npix"           : exp_params['Npix'],
    "rbf"            : exp_params['rbf'],  # dk = conv_angle/1e3/rbf/wavelength
    "dx"             : exp_params['dx_spec'],  # dx = 1/(dk*Npix) #angstrom
    "print_info"     : True,
    "pmodes"         : exp_params['pmode_max'],
    "pmode_init_pows": exp_params['pmode_init_pows'],
    ## Aberration coefficients
    "df": exp_params['defocus'],  # first-order aberration (defocus) in angstrom, positive defocus here refers to actual underfocus or weaker lens strength following Kirkland's notation
    "c3": 0,  # third-order spherical aberration in angstrom
    "c5": 0,  # fifth-order spherical aberration in angstrom
    "c7": 0,  # seventh-order spherical aberration in angstrom
    "f_a2": 0,  # twofold astigmatism in angstrom
    "f_a3": 0,  # threefold astigmatism in angstrom
    "f_c3": 0,  # coma in angstrom
    "theta_a2": 0,  # azimuthal orientation in radian
    "theta_a3": 0,  # azimuthal orientation in radian
    "theta_c3": 0,  # azimuthal orientation in radian
    "shifts": [0, 0],  # shift probe center in angstrom
}

# Source of each initial variable; swap in the commented "mat" entries to restart from a
# previous reconstruction instead of a simulated object/probe.
init_params = {
    "exp_params"        : exp_params,
    "measurements"      : {"source": "mat",      "params": [exp_CBED_path, 'cbed']},
    "obj"               : {"source": "simu",     "params": (1, 1, 391, 403)},
#    "obj"               : {"source": "mat",      "params": ptycho_output_mat_path},
    "probe"             : {"source": "simu",     "params": probe_simu_params},
#    "probe"             : {"source": "mat",      "params": ptycho_output_mat_path},
    "pos"               : {"source": "mat",      "params": ptycho_output_mat_path},
    "omode_occu"        : {"source": "uniform",  "params": None},
}

# Build every initial variable described by init_params; the cells below read the
# results from init.init_variables.
init = (
    Initializer(init_params)
    .init_all()
)

In [None]:
# Learning rate per optimizable variable. Bound to a name (rather than inlined) so the same
# dict can be stored in a checkpoint next to init_params/loss_params and reused with
# model.set_optimizer. NOTE(review): a value of 0 presumably leaves that variable
# unoptimized — confirm against PtychoAD.
lr_params = {
    'obja': 0,
    'objp': 1e-3,
    'probe': 1e-3,
    'probe_pos_shifts': 0,
}
model = PtychoAD(init.init_variables, lr_params=lr_params, device=DEVICE)

# Use model.set_optimizer(new_lr_params) to update the variable flag and optimizer_params
optimizer = torch.optim.Adam(model.optimizer_params)

In [None]:
# # Use this to edit the learning rates if further refinement is needed

# model.set_optimizer(lr_params={'obja': 0,
#                                'objp': 1e-3,
#                                'probe': 1e-3, 
#                                'probe_pos_shifts': 0})
# optimizer=torch.optim.Adam(model.optimizer_params)

## Check the forward pass

In [None]:
# Draw a single random scan index and visualize the forward pass for it.
# dp_power is presumably the exponent applied to the diffraction patterns for display — confirm.
indices = np.random.randint(low=0, high=exp_params['N_scans'], size=1)
dp_power = 0.5
plot_forward_pass(
    model,
    indices,
    dp_power,
    init.init_variables['obj'],
)

## Fine-tune the loss parameters

In [None]:
# Loss terms: each entry has an on/off 'state' and a 'weight'; loss_single and loss_pacbed
# additionally carry a 'dp_pow' exponent. Tweak these before starting the main loop.
loss_params = dict(
    loss_single={'state': True, 'weight': 1.0, 'dp_pow': 0.5},
    loss_pacbed={'state': False, 'weight': 1.0, 'dp_pow': 0.2},
    loss_tv={'state': False, 'weight': 1e-4},
    loss_l1={'state': False, 'weight': 1e-2},
    loss_l2={'state': False, 'weight': 1.0},
    loss_postiv={'state': True, 'weight': 1.0},
)

# Evaluate the combined loss on 256 randomly drawn scan indices (drawn with replacement,
# so repeats are possible).
indices = np.random.randint(low=0, high=exp_params['N_scans'], size=256)
loss_fn = CombinedLoss(loss_params, device=DEVICE)
test_loss_fn(model, indices, loss_fn)

# 03. Main optimization loop

In [None]:
NITER = 10
BATCH_SIZE = 256
N_scans = exp_params['N_scans']
# Number of mini-batches per iteration. It is a count, so keep it an integer, and keep it
# >= 1 so a dataset with fewer than BATCH_SIZE scans still forms one batch. The actual batch
# size is only "close" to BATCH_SIZE when N_scans is not divisible by it.
# NOTE(review): assumes make_batches splits the N_scans indices into num_batch groups
# (np.array_split-style) — confirm.
num_batch = max(1, N_scans // BATCH_SIZE)
cbed_shape = model.measurements.shape[1:]
loss_iters = []

# output_path = 'output/multi_obj_WSe2_and_Cu_4obj_8slice_1e-3_dppow0.5_b64/'
# os.makedirs(output_path, exist_ok=True)

# Runs NITER + 1 passes (niter = 0..NITER inclusive). `niter` avoids shadowing builtin iter().
for niter in range(NITER + 1):
    batches = make_batches(N_scans, num_batch)
    batch_losses, iter_t = ptycho_recon(batches, model, optimizer, loss_fn)
    loss_iters.append(loss_logger(batch_losses, niter, iter_t))

    # # ## Saving (needs `imwrite`, e.g. from tifffile, which is not imported in this notebook)
    # if niter % 10 == 0:
    #     torch.save(model.state_dict(), os.path.join(output_path, f"model_iter{str(niter).zfill(4)}.pt"))
    #     imwrite(os.path.join(output_path, f"objp_iter{str(niter).zfill(4)}_4D.tif"), model.opt_objp.detach().cpu().numpy().astype('float32'))
    #     imwrite(os.path.join(output_path, f"objp_iter{str(niter).zfill(4)}_zsum.tif"), model.opt_objp.sum(1).detach().cpu().numpy().astype('float32'))