# PtyRAD - PTYchographic Reconstruction with Automatic Differentiation

Chia-Hao Lee

cl2696@cornell.edu

Created 2024.03.08

# 01. Imports

In [None]:
%reload_ext autoreload
%autoreload 2

from random import shuffle
import numpy as np
import matplotlib.pyplot as plt
import torch

# Select the execution device and report the PyTorch/CUDA environment.
# Falls back to the CPU when CUDA is not available, because
# torch.cuda.get_device_name() raises on a machine without a usable GPU.
GPUID = 0
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:' + str(GPUID))
else:
    DEVICE = torch.device('cpu')
print('Execution device: ', DEVICE)
print('PyTorch version: ', torch.__version__)
print('CUDA available: ', torch.cuda.is_available())
print('CUDA version: ', torch.version.cuda)
if torch.cuda.is_available():
    print('CUDA device:', torch.cuda.get_device_name(GPUID))

In [None]:
from ptyrad.initialization import Initializer
from ptyrad.models import PtychoAD
from ptyrad.optimization import CombinedLoss, CombinedConstraint, ptycho_recon, loss_logger
from ptyrad.visualization import plot_forward_pass, plot_scan_positions, plot_summary, plot_pos_grouping, plot_sigmoid_mask, plot_probe_modes, plot_affine_transformation, plot_obj_tilts
from ptyrad.utils import test_loss_fn, select_scan_indices, make_batches, make_recon_params_dict, make_output_folder, save_results, get_blob_size, test_constraint_fn, get_date, get_local_obj_tilts

# 02. Initialize optimization

In [None]:
# from ptyrad.inputs.params_multi_obj import exp_params, source_params
# from ptyrad.inputs.params_CNS import exp_params, source_params
# from ptyrad.inputs.params_PSO_128 import exp_params, source_params
from ptyrad.inputs.params_tBL_WSe2 import exp_params, source_params

# from ptyrad.inputs.params_BaM_128 import exp_params, source_params
# from ptyrad.inputs.params_BaM_256 import exp_params, source_params
# from ptyrad.inputs.params_eCNS10 import exp_params, source_params
# from ptyrad.inputs.params_HEO_124 import exp_params, source_params
# from ptyrad.inputs.params_tSTO import exp_params, source_params
# from ptyrad.inputs.params_NNO3 import exp_params, source_params
# from ptyrad.inputs.params_Si_128 import exp_params, source_params
# from ptyrad.inputs.params_PdPt import exp_params, source_params
# from ptyrad.inputs.params_NMC_64 import exp_params, source_params
# from ptyrad.inputs.params_NMC_128 import exp_params, source_params
# from ptyrad.inputs.params_tNNO import exp_params, source_params
# from ptyrad.inputs.params_bSTO import exp_params, source_params

In [None]:
# Build the initial reconstruction variables (exposed as `init.init_variables`)
# from the experimental and source parameter sets imported above.
initializer = Initializer(exp_params, source_params)
init = initializer.init_all()

In [None]:
# Optional (commented out): estimate local object tilts with `get_local_obj_tilts` and store
# them as the initial 'obj_tilts' in `init.init_variables`. Uncomment the lines below to enable.
# # Estimate the local obj tilts
# pos = init.init_variables['crop_pos'] + exp_params['Npix']//2
# objp = np.angle(init.init_variables['obj'].mean(0)) # Take the average omode
# dx = exp_params['dx_spec']
# z_distance = exp_params['z_distance']
# slice_indices = [2,9]
# blob_params = {'min_sigma':1, 'max_sigma':5, 'overlap':0.1, 'threshold':0.35, 'exclude_border':(150,150)}
# obj_tilts = get_local_obj_tilts(pos, objp, dx, z_distance, slice_indices, blob_params, window_size=11)
# init.init_variables['obj_tilts'] = obj_tilts

In [None]:
# Initial scan positions: the crop positions plus the initial probe position shifts.
scan_vars = init.init_variables
pos = scan_vars['crop_pos'] + scan_vars['probe_pos_shifts']
plot_scan_positions(pos)

In [None]:
# Learning rate for each optimizable variable.
# NOTE(review): an lr of 0 presumably keeps that variable fixed; confirm in PtychoAD.
lr_params = {
    'obja'            : 5e-4,
    'objp'            : 5e-4,
    'obj_tilts'       : 0,
    'probe'           : 1e-4,
    'probe_pos_shifts': 5e-4,
}

model_params = {
    'obj_preblur_std'  : None,  # scalar(px), None
    'detector_blur_std': None,  # scalar(px), None
    'lr_params'        : lr_params,
}

model = PtychoAD(init.init_variables, model_params, device=DEVICE)

# To change learning rates later, call model.set_optimizer(new_lr_params), which updates the
# variable flags and model.optimizer_params, then rebuild the optimizer from them.
optimizer = torch.optim.Adam(model.optimizer_params)

In [None]:
# Optional (commented out): change the learning rates after the model is built, then rebuild
# the Adam optimizer from the updated model.optimizer_params. Uncomment to enable.
# model.set_optimizer(lr_params={'obja'            : 5e-4,
#                                'objp'            : 5e-4,
#                                'obj_tilts'       : 1e-4,
#                                'probe'           : 1e-4, 
#                                'probe_pos_shifts': 1e-4})
# optimizer=torch.optim.Adam(model.optimizer_params)

## Check the forward pass

In [None]:
# Visually check the model's forward pass at 2 randomly drawn scan indices (duplicates possible).
num_scans = init.init_variables['N_scans']
indices = np.random.randint(0, num_scans, 2)
dp_power = 0.5  # presumably a power-law display scaling for the diffraction patterns
plot_forward_pass(model, indices, dp_power)

## Fine-tune the loss params

In [None]:
# Loss terms passed to CombinedLoss. 'state' switches each term on/off and 'weight'
# scales it (presumably; confirm in ptyrad.optimization.CombinedLoss).
loss_params = dict(
    loss_single=dict(state=True, weight=1.0, dp_pow=0.5),
    loss_poissn=dict(state=False, weight=1.0, dp_pow=1.0),
    loss_pacbed=dict(state=False, weight=0.5, dp_pow=0.2),
    loss_sparse=dict(state=True, weight=0.1, ln_order=1),
    loss_simlar=dict(state=False, weight=1.0, obj_type='both', scale_factor=[1, 1, 1], blur_std=1),
)

# Evaluate the combined loss on 16 randomly drawn scan indices.
sample_indices = np.random.randint(0, init.init_variables['N_scans'], 16)
indices = sample_indices
loss_fn = CombinedLoss(loss_params, device=DEVICE)
test_loss_fn(model, indices, loss_fn)

## Set up the iteration-wise constraint params

In [None]:
# Iteration-wise constraints passed to CombinedConstraint.
# NOTE(review): 'freq' presumably sets how often (in iterations) each constraint is applied,
# with None disabling it; confirm in ptyrad.optimization.CombinedConstraint.
constraint_params = dict(
    ortho_pmode=dict(freq=1),
    # k-radius should be larger than 2*rbf/Npix to avoid cutting out the BF disk
    probe_mask_k=dict(freq=None, radius=0.22, width=0.05),
    fix_probe_int=dict(freq=1),
    # Ideally the kernel size is odd and larger than 6*std+1 so the kernel decays to 0
    obj_rblur=dict(freq=None, obj_type='both', kernel_size=5, std=0.4),
    obj_zblur=dict(freq=None, obj_type='both', kernel_size=5, std=1),
    kr_filter=dict(freq=None, obj_type='both', radius=0.15, width=0.05),
    kz_filter=dict(freq=None, obj_type='both', beta=1, alpha=1),
    obja_thresh=dict(freq=None, relax=0, thresh=[0.8**(1/15), 1.2**(1/15)]),
    objp_postiv=dict(freq=1, relax=0),
    # obj_gauss_fit=dict(freq=1, obj_type='both', num_fp_configs=1, num_groups=None),
    tilt_smooth=dict(freq=None, std=2),
)

# Optional: visualize a sigmoid k-space mask over the probe or the averaged measurements
# probe_k = np.fft.fftshift(np.abs(np.fft.fft2(np.fft.fftshift(init.init_variables['probe'], axes=(-2,-1))))[0])
# plot_sigmoid_mask(Npix=exp_params['Npix'], relative_radius=0.61, relative_width=0.05, img = probe_k, show_circles=True)
# plot_sigmoid_mask(Npix=exp_params['Npix'], relative_radius=0.61, relative_width=0.05, img = init.init_variables['measurements'].mean(0), show_circles=True)
constraint_fn = CombinedConstraint(constraint_params, device=DEVICE)

# Optional: try the constraints on a CPU copy of the model
# test_constraint_fn(PtychoAD(init.init_variables, model_params, device='cpu'), constraint_fn, plot_forward_pass)

# 03. Main optimization loop

In [None]:
# Reconstruction settings
NITER = 5
INDICES_MODE = 'full'   # 'full', 'center', 'sub'
BATCH_SIZE = 32
# GROUP_MODE: 'random', 'sparse', 'compact'. Note that 'sparse' for a 256x256 scan could take more
# than 10 mins on CPU; PtychoShelves automatically switches to 'random' for Nscans > 1e3.
GROUP_MODE = 'random'
SAVE_ITERS = 1          # scalar or None; results are saved every SAVE_ITERS iterations

# Output folder and file-name prefix/postfix; the needed '/' and '_' are generated automatically
output_dir = 'output/tBL-WSe2'
prefix = get_date(date_format='%Y%m%d')
postfix = ''
# Figures for the summary: 'loss', 'forward', 'probe_r_amp', 'probe_k_amp', 'probe_k_phase', 'pos', 'tilt', or 'all'
fig_list = ['loss', 'forward', 'probe_r_amp', 'probe_k_amp', 'probe_k_phase', 'pos']

# Current scan positions and the intensity of the first probe mode, as numpy arrays
pos = (model.crop_pos + model.opt_probe_pos_shifts).detach().cpu().numpy()
probe_int = model.opt_probe[0].abs().pow(2).detach().cpu().numpy()
dx = init.init_variables['dx']
d_out = get_blob_size(dx, probe_int, output='d90')  # probe d90 diameter, in Ang

n_scan_slow = init.init_variables['N_scan_slow']
n_scan_fast = init.init_variables['N_scan_fast']
indices = select_scan_indices(n_scan_slow, n_scan_fast, subscan_slow=None, subscan_fast=None, mode=INDICES_MODE)
batches = make_batches(indices, pos, BATCH_SIZE, mode=GROUP_MODE)
recon_params = make_recon_params_dict(NITER, INDICES_MODE, BATCH_SIZE, GROUP_MODE, SAVE_ITERS)
output_path = make_output_folder(output_dir, indices, exp_params, recon_params, model, constraint_params, loss_params, prefix, postfix)

# Plot the mini-batch grouping; the circle diameter is the d90 probe size converted from Ang to px
fig_grouping = plot_pos_grouping(
    pos, batches,
    circle_diameter=d_out / dx,
    diameter_type='90%',
    dot_scale=1,
    show_fig=True,
    pass_fig=True,
)
# fig_grouping.savefig(output_path + "/summary_pos_grouping.png")

In [None]:
# Main optimization loop: each iteration shuffles the batch order and calls ptycho_recon on it.
# loss_iters collects (iteration, loss_logger output) pairs for saving and summary plots.
loss_iters = []
for niter in range(1, NITER + 1):

    shuffle(batches)  # randomize the batch order in place every iteration
    batch_losses, iter_t = ptycho_recon(batches, model, optimizer, loss_fn, constraint_fn, niter)
    loss_iters.append((niter, loss_logger(batch_losses, niter, iter_t)))

    ## Saving intermediate results
    # Save every SAVE_ITERS iterations, and always after the last iteration so the final
    # reconstruction is kept even when NITER is not a multiple of SAVE_ITERS.
    # SAVE_ITERS = None disables saving entirely.
    is_save_iter = SAVE_ITERS is not None and (niter % SAVE_ITERS == 0 or niter == NITER)
    if is_save_iter:
        # Note that `exp_params` stores the initial exp_params, while `model` contains the actual params that could be updated if either meas_crop or meas_resample is not None
        save_results(output_path, model, exp_params, source_params, loss_params, constraint_params, recon_params, loss_iters, iter_t, niter, batch_losses)

        ## Saving summary
        plot_summary(output_path, loss_iters, niter, indices, init.init_variables, model, fig_list, show_fig=False, save_fig=True)