 # PtyRAD - PTYchographic Reconstruction with Automatic Differentiation

 Chia-Hao Lee

cl2696@cornell.edu

Created 2024.03.08

# 01. Imports

In [None]:
%reload_ext autoreload
%autoreload 2


import matplotlib.pyplot as plt
import os
import numpy as np
import torch
from tifffile import imwrite

# Select the execution device. Fall back to CPU when CUDA is not available so
# this cell does not crash (torch.cuda.get_device_name raises without a usable GPU).
GPUID = 0
CUDA_AVAILABLE = torch.cuda.is_available()
DEVICE = torch.device(f"cuda:{GPUID}" if CUDA_AVAILABLE else "cpu")
print("Execution device: ", DEVICE)
print("PyTorch version: ", torch.__version__)
print("CUDA available: ", CUDA_AVAILABLE)
print("CUDA version: ", torch.version.cuda)
if CUDA_AVAILABLE:
    # Only query the device name when a CUDA device actually exists.
    print("CUDA device:", torch.cuda.get_device_name(GPUID))

In [None]:
from ptyrad.initialization import Initializer
from ptyrad.models import PtychoAD
from ptyrad.optimization import CombinedLoss, ptycho_recon, loss_logger
from ptyrad.visualization import plot_forward_pass
from ptyrad.utils import test_loss_fn, make_batches, select_center_rectangle_indices

# 02. Initialize optimization

In [None]:
# # tBL_WSe2

# ptycho_output_mat_path = 'data/tBL_WSe2/Fig_1h_24.9mrad_Themis/1/roi1_Ndp128_step128/MLs_L1_p10_g128_pc0_noModel_updW100_mm_dpFlip_ud_T/Niter9000_v7.mat'
# exp_CBED_path          = 'data/tBL_WSe2/Fig_1h_24.9mrad_Themis/1/data_roi1_Ndp128_step128_dp.hdf5'

# exp_params = {
#     "kv"                : 80,  # kV
#     "conv_angle"        : 24.9, # mrad, semi-convergence angle
#     "Npix"              : 128, # Detector pixel number, EMPAD is 128. Only supports square detector for simplicity
#     "rbf"               : None, # Pixels of radius of BF disk, used to calculate dk
#     "dx_spec"           : 0.1494,# Ang
#     "defocus"           : 0, # Ang, positive defocus here refers to actual underfocus or weaker lens strength following Kirkland/abtem/ptychoshelves convention
#     "c3"                : 0, # Ang, spherical aberration coefficients
#     "z_distance"        : 8, # Ang
#     "Nlayer"            : 1,
#     "N_scans"           : 16384,
#     "omode_max"         : 1,
#     "pmode_max"         : 10,
#     "pmode_init_pows"   : [0.02],
#     "probe_permute"     : None,
#     "cbeds_permute"     : None,
#     "cbeds_reshape"     : None,
#     "cbeds_flip"        : (1),
#     "probe_simu_params" : None
#     }

# # Source and params, note that these should be changed in accordance with each other
# source_params = {
#     'measurements_source': 'hdf5',
#     'measurements_params': [exp_CBED_path, 'dp'],
#     'obj_source'         : 'simu', 
#     'obj_params'         : (1,1,592,592),
#     'probe_source'       : 'PtyShv',
#     'probe_params'       : ptycho_output_mat_path,  # probe_simu_params
#     'pos_source'         : 'PtyShv',
#     'pos_params'         : ptycho_output_mat_path,
#     'omode_occu_source'  : 'uniform',
#     'omode_occu_params'  : None
# }

In [None]:
# CNS
ptycho_output_mat_path = "data/CNS_from_Hari/Niter10000.mat"
exp_CBED_path =          "data/CNS_from_Hari/240327_fov_23p044A_x_24p402A_thickness_9p978A_step0p28_conv30_dfm100_det70_TDS_2configs_xdirection_Co_0p25_Nb_0_S_0.mat" 

# Dataset sizes used by several entries below. Defining them once keeps
# "Npix", "N_scans" and "cbeds_reshape" consistent when this config is edited.
N_SCANS = 7134  # Number of diffraction patterns (7134 = 87 x 82; presumably the scan grid used in the main loop - confirm)
NPIX = 164      # Detector pixels per side (square detector)

exp_params = {
    "kv"                : 300,  # kV
    "conv_angle"        : 30, # mrad, semi-convergence angle
    "Npix"              : NPIX, # Detector pixel number, EMPAD is 128. Only supports square detector for simplicity
    "rbf"               : None, # Pixels of radius of BF disk, used to calculate dk
    "dx_spec"           : 0.1406,# Ang
    "defocus"           : -100, # Ang, positive defocus here refers to actual underfocus or weaker lens strength following Kirkland/abtem/ptychoshelves convention
    "c3"                : 0, # Ang, spherical aberration coefficients
    "z_distance"        : 1.25, # Ang
    "Nlayer"            : 8,
    "N_scans"           : N_SCANS,
    "omode_max"         : 1,
    "pmode_max"         : 1,
    "pmode_init_pows"   : [0.02],
    "probe_permute"     : None,
    "cbeds_permute"     : (0,1,3,2),
    "cbeds_reshape"     : (N_SCANS, NPIX, NPIX), # 4-D permuted data -> 3-D with N_scans patterns of NPIX x NPIX
    "cbeds_flip"        : None,
    "probe_simu_params" : None
    }

# Data sources and their parameters. Each *_source must be changed together
# with its matching *_params entry.
source_params = {
    'measurements_source': 'mat',
    'measurements_params': [exp_CBED_path, 'cbed'],  # file path and variable name inside the .mat file
    'obj_source'         : 'simu',
    'obj_params'         : (1,1,391,403),
    'probe_source'       : 'simu',
    'probe_params'       : None,  # presumably the probe is simulated from exp_params when None - confirm in Initializer
    'pos_source'         : 'PtyShv',
    'pos_params'         : ptycho_output_mat_path,
    'omode_occu_source'  : 'uniform',
    'omode_occu_params'  : None
}


In [None]:
# # PSO 256

# ptycho_output_mat_path = 'data/PSO/MLs_L1_p8_g32_pc50_noModel_vp1_Ns21_dz10_reg1/Niter200.mat'
# exp_CBED_path          = 'data/PSO/MLs_L1_p8_g32_pc50_noModel_vp1_Ns21_dz10_reg1/PSO_data_roi0_Ndp256_dp.hdf5'

# exp_params = {
#     "kv"                : 300,  # kV
#     "conv_angle"        : 21.4, # mrad, semi-convergence angle
#     "Npix"              : 256, # Detector pixel number, EMPAD is 128. Only supports square detector for simplicity
#     "rbf"               : None, # Pixels of radius of BF disk, used to calculate dk
#     "dx_spec"           : 0.0934,# Ang
#     "defocus"           : -200, # Ang, positive defocus here refers to actual underfocus or weaker lens strength following Kirkland/abtem/ptychoshelves convention
#     "c3"                : 0, # Ang, spherical aberration coefficients
#     "z_distance"        : 10, # Ang
#     "Nlayer"            : 21,
#     "N_scans"           : 4096,
#     "omode_max"         : 1,
#     "pmode_max"         : 8,
#     "pmode_init_pows"   : [0.02],
#     "probe_permute"     : None,
#     "cbeds_permute"     : (0,2,1),
#     "cbeds_reshape"     : None,
#     "cbeds_flip"        : None,
#     "probe_simu_params" : None
#     }

# # Source and params, note that these should be changed in accordance with each other
# source_params = {
#     'measurements_source': 'hdf5',
#     'measurements_params': [exp_CBED_path, 'dp'],
#     'obj_source'         : 'PtyShv',
#     'obj_params'         : ptycho_output_mat_path,
#     'probe_source'       : 'simu',
#     'probe_params'       : ptycho_output_mat_path,  # probe_simu_params
#     'pos_source'         : 'PtyShv',
#     'pos_params'         : ptycho_output_mat_path,
#     'omode_occu_source'  : 'uniform',
#     'omode_occu_params'  : None
# }


In [None]:
# Build every initial variable (measurements, object, probe, positions, ...) from the config above.
initializer = Initializer(exp_params, source_params)
init = initializer.init_all()  # the result is used below via init.init_variables

In [None]:
# Per-variable learning rates handed to PtychoAD
# (presumably a rate of 0 keeps that variable fixed - confirm in PtychoAD).
lr_params = {
    'obja': 0,
    'objp': 1e-3,
    'probe': 1e-3,
    'probe_pos_shifts': 0,
}
model_params = {'init_variables': init.init_variables, 'lr_params': lr_params}

model = PtychoAD(model_params, device=DEVICE)

# Use model.set_optimizer(new_lr_params) to update the variable flag and optimizer_params
optimizer = torch.optim.Adam(model.optimizer_params)

In [None]:
# # Use this to edit the learning rates if further refinement is needed

# model.set_optimizer(lr_params={'obja': 0,
#                                'objp': 1e-3,
#                                'probe': 1e-3, 
#                                'probe_pos_shifts': 0})
# optimizer=torch.optim.Adam(model.optimizer_params)

## Check the forward pass

In [None]:
# Pick 2 random scan indices (drawn independently, so repeats are possible)
# and visualize the forward pass for them against the initial object.
indices = np.random.randint(0, exp_params['N_scans'], size=2)
dp_power = 0.5
initial_obj = init.init_variables['obj']
plot_forward_pass(model, indices, dp_power, initial_obj)

## Fine-tune the loss parameters

In [None]:
# Loss terms: 'state' toggles a term on/off, 'weight' scales its contribution,
# and 'dp_pow' is an extra exponent option used by the diffraction-based terms.
loss_params = {
    'loss_single': dict(state=True,  weight=1.0,  dp_pow=0.5),
    'loss_pacbed': dict(state=False, weight=1.0,  dp_pow=0.2),
    'loss_tv'    : dict(state=False, weight=1e-4),
    'loss_l1'    : dict(state=False, weight=1e-2),
    'loss_l2'    : dict(state=False, weight=1.0),
    'loss_postiv': dict(state=True,  weight=1.0),
}

# Evaluate the combined loss on 32 randomly drawn scan positions.
indices = np.random.randint(0, exp_params['N_scans'], size=32)
loss_fn = CombinedLoss(loss_params, device=DEVICE)
test_loss_fn(model, indices, loss_fn)

# 03. Main optimization loop

In [None]:
NITER = 50
BATCH_SIZE = 32
GROUP = 'random'
SAVE_EVERY = 5  # checkpoint interval in iterations (the final iteration is always saved too)

# Reconstruct a 32 x 32 rectangle of scan positions. 87 x 82 = 7134 = exp_params['N_scans'],
# so (87, 82) are presumably the scan-grid dimensions - confirm the argument order of
# select_center_rectangle_indices.
indices = np.array(select_center_rectangle_indices(87, 82, 32, 32))
# indices = np.arange(model.measurements.size(0))  # use every scan position instead
pos = model.crop_pos.cpu().numpy()
batches = make_batches(indices, pos, BATCH_SIZE, group=GROUP)

loss_iters = []
output_path = (
    f"output/CNS/{model.opt_objp.size(0)}obj_{model.opt_objp.size(1)}slice_N{len(indices)}"
    f"_lr{model_params['lr_params']['objp']:.0e}_b{BATCH_SIZE}_{GROUP}/"
)
print(f"output_path = {output_path}")
os.makedirs(output_path, exist_ok=True)

# `it` (not `iter`) so the builtin iter() is not shadowed in the notebook namespace.
for it in range(1, NITER + 1):

    batch_losses, iter_t = ptycho_recon(batches, model, optimizer, loss_fn)
    loss_iters.append(loss_logger(batch_losses, it, iter_t))

    # Saving: checkpoint periodically and always after the last iteration, so the
    # final state is kept even when NITER is not a multiple of SAVE_EVERY.
    if it % SAVE_EVERY == 0 or it == NITER:
        torch.save(model.state_dict(), os.path.join(output_path, f"model_iter{it:04d}.pt"))
        imwrite(
            os.path.join(output_path, f"objp_iter{it:04d}.tif"),
            model.opt_objp[0].detach().cpu().numpy().astype('float32'),
        )

In [None]:
# Plot the per-iteration values collected from loss_logger. Iterations are numbered
# from 1 in the optimization loop, so the x-axis starts at 1 as well.
plt.figure()
plt.plot(range(1, len(loss_iters) + 1), loss_iters)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()