# PtyRAD - PTYchographic Reconstruction with Automatic Differentiation

Chia-Hao Lee

cl2696@cornell.edu

Created 2024.03.08

# 01. Imports

In [None]:
%reload_ext autoreload
%autoreload 2

import os
import matplotlib.pyplot as plt

import numpy as np
import torch
from tifffile import imread, imwrite

# Pick the execution device: the requested GPU when CUDA is usable, otherwise fall
# back to the CPU instead of crashing on torch.cuda.get_device_name below.
GPUID = 0
USE_CUDA = torch.cuda.is_available() and GPUID < torch.cuda.device_count()
DEVICE = torch.device(f"cuda:{GPUID}" if USE_CUDA else "cpu")
print("Execution device: ", DEVICE)
print("PyTorch version: ", torch.__version__)
print("CUDA available: ", torch.cuda.is_available())
print("CUDA version: ", torch.version.cuda)
if USE_CUDA:
    # Querying the device name requires an actual CUDA device at index GPUID.
    print("CUDA device:", torch.cuda.get_device_name(GPUID))
else:
    print("CUDA device: none (running on CPU)")

In [None]:
from ptyrad.initialization import Initializer
from ptyrad.models import PtychoAD
from ptyrad.optimization import CombinedLoss, CombinedConstraint, ptycho_recon, loss_logger
from ptyrad.visualization import plot_forward_pass
from ptyrad.utils import test_loss_fn, make_batches, time_sync, select_center_rectangle_indices, make_save_dict, make_recon_params_dict, shuffle_batches

# 02. Initialize optimization

In [None]:
# # tBL_WSe2
# # Alternative dataset configuration, kept commented out. Uncommenting it defines the
# # same names as the active "PSO 128" cell (ptycho_output_mat_path, exp_CBED_path,
# # exp_params, source_params), so whichever configuration cell runs last wins.

# ptycho_output_mat_path = 'data/tBL_WSe2/Fig_1h_24.9mrad_Themis/1/roi1_Ndp128_step128/MLs_L1_p10_g128_pc0_noModel_updW100_mm_dpFlip_ud_T/Niter9000_v7.mat'
# exp_CBED_path          = 'data/tBL_WSe2/Fig_1h_24.9mrad_Themis/1/data_roi1_Ndp128_step128_dp.hdf5'

# exp_params = {
#     "kv"                : 80,  # kV
#     "conv_angle"        : 24.9, # mrad, semi-convergence angle
#     "Npix"              : 128, # Detector pixel number, EMPAD is 128. Only supports square detector for simplicity
#     "rbf"               : None, # Pixels of radius of BF disk, used to calculate dk
#     "dx_spec"           : 0.1494,# Ang
#     "defocus"           : 0, # Ang, positive defocus here refers to actual underfocus or weaker lens strength following Kirkland/abtem/ptychoshelves convention
#     "c3"                : 0, # Ang, spherical aberration coefficients
#     "z_distance"        : 8, # Ang
#     "Nlayer"            : 1,
#     "N_scans"           : 16384,
#     "omode_max"         : 1,
#     "pmode_max"         : 10,
#     "pmode_init_pows"   : [0.02],
#     "probe_permute"     : None,
#     "cbeds_permute"     : None,
#     "cbeds_reshape"     : None,
#     "cbeds_flip"        : (1), # NOTE(review): (1) is the int 1, not a 1-tuple; write (1,) if a tuple of axes is expected
#     "probe_simu_params" : None
#     }

# # Source and params, note that these should be changed in accordance with each other
# source_params = {
#     'measurements_source': 'hdf5',
#     'measurements_params': [exp_CBED_path, 'dp'],
#     'obj_source'         : 'simu', 
#     'obj_params'         : (1,1,592,592),
#     'probe_source'       : 'PtyShv',
#     'probe_params'       : ptycho_output_mat_path,  # probe_simu_params
#     'pos_source'         : 'PtyShv',
#     'pos_params'         : ptycho_output_mat_path,
#     'omode_occu_source'  : 'uniform',
#     'omode_occu_params'  : None
# }

In [None]:
# # CNS
# # Alternative dataset configuration, kept commented out. Uncommenting it defines the
# # same names as the active "PSO 128" cell (ptycho_output_mat_path, exp_CBED_path,
# # exp_params, source_params), so whichever configuration cell runs last wins.
# # Here the measurements come from a .mat file and are permuted/reshaped to
# # (N_scans, Npix, Npix) = (7134, 164, 164) via cbeds_permute / cbeds_reshape.
# ptycho_output_mat_path = "data/CNS_from_Hari/Niter10000.mat"
# exp_CBED_path =          "data/CNS_from_Hari/240327_fov_23p044A_x_24p402A_thickness_9p978A_step0p28_conv30_dfm100_det70_TDS_2configs_xdirection_Co_0p25_Nb_0_S_0.mat" 

# exp_params = {
#     "kv"                : 300,  # kV
#     "conv_angle"        : 30, # mrad, semi-convergence angle
#     "Npix"              : 164, # Detector pixel number, EMPAD is 128. Only supports square detector for simplicity
#     "rbf"               : None, # Pixels of radius of BF disk, used to calculate dk
#     "dx_spec"           : 0.1406,# Ang
#     "defocus"           : -100, # Ang, positive defocus here refers to actual underfocus or weaker lens strength following Kirkland/abtem/ptychoshelves convention
#     "c3"                : 0, # Ang, spherical aberration coefficients
#     "z_distance"        : 1.25, # Ang
#     "Nlayer"            : 8,
#     "N_scans"           : 7134,
#     "omode_max"         : 1,
#     "pmode_max"         : 2,
#     "pmode_init_pows"   : [0.02],
#     "probe_permute"     : None,
#     "cbeds_permute"     : (0,1,3,2),
#     "cbeds_reshape"     : (7134,164,164),
#     "cbeds_flip"        : None,
#     "probe_simu_params" : None
#     }

# # Source and params, note that these should be changed in accordance with each other
# source_params = {
#     'measurements_source': 'mat',
#     'measurements_params': [exp_CBED_path, 'cbed'],
#     'obj_source'         : 'simu',
#     'obj_params'         : (1,8,391,403),
#     'probe_source'       : 'simu',
#     'probe_params'       : None, 
#     'pos_source'         : 'PtyShv',
#     'pos_params'         : ptycho_output_mat_path,
#     'omode_occu_source'  : 'uniform',
#     'omode_occu_params'  : None
# }


In [None]:
# # PSO 256
# # Alternative dataset configuration, kept commented out. Uncommenting it defines the
# # same names as the active "PSO 128" cell (ptycho_output_mat_path, exp_CBED_path,
# # exp_params, source_params), so whichever configuration cell runs last wins.

# ptycho_output_mat_path = 'data/PSO/MLs_L1_p8_g32_pc50_noModel_vp1_Ns21_dz10_reg1/Niter200.mat'
# exp_CBED_path          = 'data/PSO/MLs_L1_p8_g32_pc50_noModel_vp1_Ns21_dz10_reg1/PSO_data_roi0_Ndp256_dp.hdf5'

# exp_params = {
#     "kv"                : 300,  # kV
#     "conv_angle"        : 21.4, # mrad, semi-convergence angle
#     "Npix"              : 256, # Detector pixel number, EMPAD is 128. Only supports square detector for simplicity
#     "rbf"               : None, # Pixels of radius of BF disk, used to calculate dk
#     "dx_spec"           : 0.0934,# Ang
#     "defocus"           : -200, # Ang, positive defocus here refers to actual underfocus or weaker lens strength following Kirkland/abtem/ptychoshelves convention
#     "c3"                : 0, # Ang, spherical aberration coefficients
#     "z_distance"        : 10, # Ang
#     "Nlayer"            : 21,
#     "N_scans"           : 4096,
#     "omode_max"         : 1,
#     "pmode_max"         : 8,
#     "pmode_init_pows"   : [0.02],
#     "probe_permute"     : None,
#     "cbeds_permute"     : (0,2,1),
#     "cbeds_reshape"     : None,
#     "cbeds_flip"        : None, # trailing comma restored: without it, uncommenting this dict is a SyntaxError
#     "probe_simu_params" : None
#     }

# # Source and params, note that these should be changed in accordance with each other
# source_params = {
#     'measurements_source': 'hdf5',
#     'measurements_params': [exp_CBED_path, 'dp'],
#     'obj_source'         : 'PtyShv',
#     'obj_params'         : ptycho_output_mat_path,
#     'probe_source'       : 'simu',
#     'probe_params'       : ptycho_output_mat_path,  # probe_simu_params  # NOTE(review): 'simu' source is paired with a .mat path here, while the PSO 128 cell pairs 'simu' with None; confirm which is intended
#     'pos_source'         : 'PtyShv',
#     'pos_params'         : ptycho_output_mat_path,
#     'omode_occu_source'  : 'uniform',
#     'omode_occu_params'  : None
# }


In [None]:
# PSO 128
# Active configuration. The commented-out configuration cells define the same names
# (ptycho_output_mat_path, exp_CBED_path, exp_params, source_params), so run only one.

ptycho_output_mat_path = 'data/PSO/MLs_L1_p8_g192_Ndp128_pc50_noModel_vp1_Ns21_dz10_reg1/Niter200.mat'
exp_CBED_path          = 'data/PSO/MLs_L1_p8_g192_Ndp128_pc50_noModel_vp1_Ns21_dz10_reg1/PSO_data_roi0_Ndp128_dp.hdf5'

exp_params = {
    "kv"                : 300,  # kV
    "conv_angle"        : 21.4, # mrad, semi-convergence angle
    "Npix"              : 128, # Detector pixel number, EMPAD is 128. Only supports square detector for simplicity
    "rbf"               : None, # Pixels of radius of BF disk, used to calculate dk
    "dx_spec"           : 0.1868,# Ang
    "defocus"           : -200, # Ang, positive defocus here refers to actual underfocus or weaker lens strength following Kirkland/abtem/ptychoshelves convention
    "c3"                : 0, # Ang, spherical aberration coefficients
    "z_distance"        : 10, # Ang
    "Nlayer"            : 21,
    "N_scans"           : 4096,
    "omode_max"         : 1,
    "pmode_max"         : 8,
    "pmode_init_pows"   : [0.02],
    "probe_permute"     : None,
    "cbeds_permute"     : (0,2,1),
    "cbeds_reshape"     : None,
    "cbeds_flip"        : None,
    "probe_simu_params" : None
    }

# Lateral size of the simulated initial object in pixels (last two dims of obj_params).
OBJ_SHAPE_2D = (320, 320)

# Source and params, note that these should be changed in accordance with each other
source_params = {
    'measurements_source': 'hdf5',
    'measurements_params': [exp_CBED_path, 'dp'],
    # Alternative: 'PtyShv' together with obj_params = ptycho_output_mat_path
    'obj_source'         : 'simu',
    # (object modes, slices, *lateral size): the mode/slice counts are taken from exp_params
    # so they cannot drift out of sync with "omode_max" / "Nlayer" (currently (1, 21, 320, 320)).
    'obj_params'         : (exp_params["omode_max"], exp_params["Nlayer"], *OBJ_SHAPE_2D),
    # Alternative: 'PtyShv' together with probe_params = ptycho_output_mat_path
    'probe_source'       : 'simu',
    'probe_params'       : None,
    'pos_source'         : 'PtyShv',
    'pos_params'         : ptycho_output_mat_path,
    'omode_occu_source'  : 'uniform',
    'omode_occu_params'  : None
}


In [None]:
# Build every initial variable from the two configs; init.init_variables is consumed
# by the model, plotting, and diagnostics cells that follow.
init = Initializer(exp_params, source_params)
init = init.init_all()

In [None]:
# Per-variable learning rates handed to the model (a rate of 0 presumably keeps that
# variable fixed -- here the object amplitude 'obja'; confirm against PtychoAD).
lr_params = {
    'obja': 0,
    'objp': 5e-4,
    'probe': 5e-4,
    'probe_pos_shifts': 5e-4,
}
model_params = dict(init_variables=init.init_variables, lr_params=lr_params)

model = PtychoAD(model_params, device=DEVICE)

# To change the rates later, call model.set_optimizer(new_lr_params), which updates the
# variable flags and model.optimizer_params, then rebuild the Adam optimizer from them.
optimizer = torch.optim.Adam(model.optimizer_params)

In [None]:
# # Use this to edit the learning rates if further refinement is needed.
# # set_optimizer() updates the variable flags and model.optimizer_params, so the Adam
# # optimizer must be re-created afterwards (last line). Here the position shifts are frozen (lr 0).

# model.set_optimizer(lr_params={'obja': 0,
#                                'objp': 5e-4,
#                                'probe': 5e-4, 
#                                'probe_pos_shifts': 0})
# optimizer=torch.optim.Adam(model.optimizer_params)

## Check the forward pass

In [None]:
# Sanity-check the forward model on two randomly chosen scan positions.
dp_power = 0.5  # exponent forwarded to plot_forward_pass (presumably for displaying diffraction patterns)
indices = np.random.randint(0, exp_params['N_scans'], size=2)
plot_forward_pass(model, indices, dp_power, init.init_variables['obj'])

## Fine-tune the loss parameters

In [None]:
# Loss terms: each entry is toggled by 'state' and scaled by 'weight'; 'dp_pow' is an
# extra per-term setting forwarded to CombinedLoss for the diffraction-based terms.
loss_params = dict(
    loss_single={'state': True,  'weight': 1.0, 'dp_pow': 0.5},
    loss_pacbed={'state': False, 'weight': 1.0, 'dp_pow': 0.2},
    loss_tv={'state': False, 'weight': 1e-5},
    loss_l1={'state': True,  'weight': 0.1},
    loss_l2={'state': False, 'weight': 1.0},
    loss_postiv={'state': False, 'weight': 1.0},
)

# Evaluate the combined loss once on 128 random scan positions (drawn with replacement).
indices = np.random.randint(0, exp_params['N_scans'], size=128)
loss_fn = CombinedLoss(loss_params, device=DEVICE)
test_loss_fn(model, indices, loss_fn)

## Set up the iteration-wise constraint parameters

In [None]:
# Iteration-wise constraints; 'freq' presumably sets how often (in iterations) each one
# is applied, with None disabling it -- confirm against CombinedConstraint.
constraint_params = dict(
    postiv={'freq': 1},
    ortho_pmode={'freq': 1},
    ortho_omode={'freq': None},
    kz_filter={'freq': 1, 'beta': 1, 'alpha': 1, 'z_pad': None},
)

constraint_fn = CombinedConstraint(constraint_params, device=DEVICE)

# 03. Main optimization loop

In [None]:
# Reconstruction hyper-parameters
NITER = 120        # total number of iterations
BATCH_SIZE = 32    # scan positions per batch
GROUP = 'sparse'   # batch grouping mode passed to make_batches / shuffle_batches
SAVE_EVERY = 10    # checkpoint interval in iterations (the final iteration is always saved)

# Use every scan position; to restrict to a sub-region use e.g.
# indices = np.array(select_center_rectangle_indices(87,82,32,32))
indices = np.arange(model.measurements.size(0))
# detach() is a no-op for tensors without grad tracking and keeps .numpy() safe if they have it
pos = model.crop_pos.detach().cpu().numpy()
batches = make_batches(indices, pos, BATCH_SIZE, group = GROUP)

# Build the run-name suffix from the actual settings instead of hard-coding it, so the
# folder name stays correct when the learning rates or loss weights above are edited.
run_tags = []
if model_params['lr_params']['obja'] == 0:
    run_tags.append("fix_obja")
if loss_params['loss_l1']['state']:
    run_tags.append(f"l1{loss_params['loss_l1']['weight']}")
run_suffix = "".join(f"_{tag}" for tag in run_tags)

output_path = (
    f"output/PSO/{model.opt_objp.size(0)}obj_{model.opt_objp.size(1)}slice"
    f"_N{len(indices)}_dp{model.measurements.size(1)}"
    f"_lr{model_params['lr_params']['objp']:.0e}_b{BATCH_SIZE}_{GROUP}{run_suffix}/"
)
print(f"output_path = {output_path}")
os.makedirs(output_path, exist_ok=True)
recon_params = make_recon_params_dict(NITER, BATCH_SIZE, GROUP, batches, output_path)

loss_iters = []
for iter_num in range(1, NITER + 1):  # 'iter_num' avoids shadowing the builtin iter()

    batches = shuffle_batches(batches, BATCH_SIZE, GROUP)
    batch_losses, iter_t = ptycho_recon(batches, model, optimizer, loss_fn, constraint_fn, iter_num)
    loss_iters.append((iter_num, loss_logger(batch_losses, iter_num, iter_t)))

    ## Saving: periodic checkpoints plus the last iteration, so the final state is kept
    ## even when NITER is not a multiple of SAVE_EVERY.
    if iter_num % SAVE_EVERY == 0 or iter_num == NITER:
        save_dict = make_save_dict(model, exp_params, source_params, loss_params, constraint_params, recon_params, loss_iters, iter_t, iter_num, batch_losses)
        iter_tag = str(iter_num).zfill(4)
        torch.save(save_dict, os.path.join(output_path, f"model_iter{iter_tag}.pt"))
        imwrite(os.path.join(output_path, f"objp_iter{iter_tag}.tif"), model.opt_objp[0].detach().cpu().numpy().astype('float32'))

In [None]:
# Quick plot of the loss curve: column 1 of loss_iters holds the value returned by
# loss_logger for each iteration (column 0 is the iteration number).
loss_table = np.array(loss_iters)
plt.figure()
plt.plot(loss_table[:, 1])
plt.show()

In [None]:
# Normalization sanity check: total intensity of the initial probe, the summed mean
# measured CBED (mean over axis 0), and the total intensity of the optimized probe.
probe_start = init.init_variables['probe']
cbeds_start = init.init_variables['measurements']
init_probe_int = np.sum(np.abs(probe_start) ** 2)
init_CBED_int = np.sum(np.mean(cbeds_start, 0))
opt_probe_int = model.opt_probe.abs().pow(2).sum().detach().cpu().numpy()
print(f"{init_probe_int}, {init_CBED_int}, {opt_probe_int}")

In [None]:
# Visualize the amplitude of each probe mode: initial guess (top row) vs. optimized (bottom row).
init_probe = init.init_variables['probe']
opt_probe = model.opt_probe.detach().cpu().numpy()
n_modes = len(opt_probe)
# squeeze=False keeps axs 2-D even for a single probe mode; otherwise axs[0, i] raises IndexError.
fig, axs = plt.subplots(2, n_modes, figsize=(n_modes*2.5, 6), squeeze=False)
for i, opt_mode in enumerate(opt_probe):
    for row, (label, mode) in enumerate((("Init", init_probe[i]), ("Opt", opt_mode))):
        ax = axs[row, i]
        ax.set_title(f"{label} probe {i}")
        im = ax.imshow(np.abs(mode))
        ax.axis('off')
        plt.colorbar(im, ax=ax, shrink=0.6)

plt.tight_layout()
plt.show()