 # PtyRAD - PTYchographic Reconstruction with Automatic Differentiation

 Chia-Hao Lee

cl2696@cornell.edu

Created 2024.03.08

# 01. Imports

In [None]:
%reload_ext autoreload
%autoreload 2

from random import shuffle

import numpy as np
import matplotlib.pyplot as plt
import torch

# Select the execution device. Fall back to the CPU when CUDA is unavailable so
# this cell does not crash in torch.cuda.get_device_name() on GPU-less machines.
GPUID = 0
CUDA_AVAILABLE = torch.cuda.is_available()
DEVICE = torch.device(f'cuda:{GPUID}' if CUDA_AVAILABLE else 'cpu')
print('Execution device: ', DEVICE)
print('PyTorch version: ', torch.__version__)
print('CUDA available: ', CUDA_AVAILABLE)
print('CUDA version: ', torch.version.cuda)
if CUDA_AVAILABLE:
    # Only meaningful (and only safe to call) when a CUDA device exists.
    print('CUDA device:', torch.cuda.get_device_name(GPUID))

In [None]:
from ptyrad.initialization import Initializer
from ptyrad.models import PtychoAD
from ptyrad.optimization import CombinedLoss, CombinedConstraint, ptycho_recon, loss_logger
from ptyrad.visualization import plot_forward_pass, plot_scan_positions, plot_summary, plot_pos_grouping
from ptyrad.utils import test_loss_fn, select_scan_indices, make_batches, make_recon_params_dict, make_output_folder, save_results, get_blob_size, imshift_batch

# 02. Initialize optimization

In [None]:
from ptyrad.inputs.params_CNS import exp_params, source_params
# from ptyrad.inputs.params_PSO_128 import exp_params, source_params
# from ptyrad.inputs.params_tBL_WSe2 import exp_params, source_params

# from ptyrad.inputs.params_BaM_128 import exp_params, source_params
# from ptyrad.inputs.params_BaM_256 import exp_params, source_params
# from ptyrad.inputs.params_STO_128 import exp_params, source_params
# from ptyrad.inputs.params_NNO3 import exp_params, source_params
# from ptyrad.inputs.params_Si_128 import exp_params, source_params
# from ptyrad.inputs.params_PdPt import exp_params, source_params



In [None]:
# Build the initial reconstruction variables from the experiment and source
# parameter dicts; the resulting object exposes them as `init.init_variables`.
initializer = Initializer(exp_params, source_params)
init = initializer.init_all()

In [None]:
# Initial scan positions = crop positions + initial probe-position shifts.
init_vars = init.init_variables
pos = init_vars['crop_pos'] + init_vars['probe_pos_shifts']
plot_scan_positions(pos)

In [None]:
# Model hyperparameters. 'lr_params' holds one learning rate per optimizable
# variable (object amplitude/phase, probe, probe-position shifts — names
# presumably; confirm against PtychoAD).
model_params = dict(
    detector_blur_std=1,
    lr_params=dict(
        obja=5e-4,
        objp=5e-4,
        probe=1e-3,
        probe_pos_shifts=1e-3,
    ),
)

model = PtychoAD(init.init_variables, model_params, device=DEVICE)

# The Adam optimizer is built from the parameter groups the model exposes.
# To change which variables are optimized (and their learning rates) later,
# call model.set_optimizer(new_lr_params) to refresh the flags and optimizer_params.
optimizer = torch.optim.Adam(model.optimizer_params)

## Check the forward pass

In [None]:
# Spot-check the forward model on two randomly drawn scan indices (drawn with
# replacement, so they may coincide). dp_power is forwarded to
# plot_forward_pass — presumably a display exponent for the diffraction patterns.
dp_power = 0.5
indices = np.random.randint(0, exp_params['N_scans'], size=2)

plot_forward_pass(model, indices, dp_power)

## Fine-tune the loss parameters

In [None]:
# Loss configuration consumed by CombinedLoss. Every term carries a 'state'
# flag and a 'weight' (presumably on/off switch and scale factor — confirm in
# CombinedLoss); 'dp_pow' appears only on loss_single and loss_pacbed.
loss_params = dict(
    loss_single=dict(state=True, weight=1.0, dp_pow=0.5),
    loss_pacbed=dict(state=False, weight=1.0, dp_pow=0.2),
    loss_tv=dict(state=False, weight=2e-4),
    loss_l1=dict(state=False, weight=0.1),
    loss_l2=dict(state=False, weight=1.0),
    loss_postiv=dict(state=False, weight=1.0),
)

# Sanity-check the combined loss on 16 randomly drawn scan indices.
indices = np.random.randint(0, exp_params['N_scans'], size=16)
loss_fn = CombinedLoss(loss_params, device=DEVICE)
test_loss_fn(model, indices, loss_fn)

## Set up the iteration-wise constraint parameters

In [None]:
# Iteration-wise constraints consumed by CombinedConstraint. 'freq' presumably
# sets how often (in iterations) each constraint is applied, with None
# disabling it — confirm in CombinedConstraint.
constraint_params = dict(
    objp_blur=dict(freq=None, std=None),
    ortho_pmode=dict(freq=1),
    ortho_omode=dict(freq=None),
    kz_filter=dict(freq=1, beta=1, alpha=1, z_pad=None, obj_type='both'),
    postiv=dict(freq=1),
    fix_probe_int=dict(freq=1),
)

constraint_fn = CombinedConstraint(constraint_params, device=DEVICE)

# 03. Main optimization loop

In [None]:
import os

# --- Reconstruction settings -------------------------------------------------
NITER        = 100
INDICES_MODE = 'full'    # 'full', 'center', 'sub'
BATCH_SIZE   = 256
GROUP_MODE   = 'sparse'  # 'random', 'sparse', 'compact'
SAVE_ITERS   = 10        # save every SAVE_ITERS iterations; None disables saving

output_dir   = 'output/CNS'
postfix      = '_testplot'

# --- Derived quantities ------------------------------------------------------
# Current scan positions (crop positions + learned shifts) as a NumPy array.
pos          = (model.crop_pos + model.opt_probe_pos_shifts).detach().cpu().numpy()
# |probe|^2 of model.opt_probe[0] (presumably the first probe mode).
probe_int    = model.opt_probe[0].abs().pow(2).detach().cpu().numpy()
dx           = exp_params['dx_spec']
d90          = get_blob_size(dx, probe_int)
indices      = select_scan_indices(exp_params['N_scan_slow'], exp_params['N_scan_fast'], subscan_slow=None, subscan_fast=None, mode=INDICES_MODE)
batches      = make_batches(indices, pos, BATCH_SIZE, mode=GROUP_MODE)
recon_params = make_recon_params_dict(NITER, INDICES_MODE, BATCH_SIZE, GROUP_MODE, SAVE_ITERS)
output_path  = make_output_folder(output_dir, indices, exp_params, recon_params, model, constraint_params, postfix)

# Visualize the batch grouping; d90 / dx converts the blob size into the units
# used for the circle diameter (presumably pixels) — confirm in plot_pos_grouping.
fig_grouping = plot_pos_grouping(pos, batches, circle_diameter=d90/dx, dot_scale=1, pass_fig=True)
# os.path.join instead of string concatenation (and no placeholder-less f-string).
fig_grouping.savefig(os.path.join(output_path, "summary_pos_grouping.png"))

In [None]:
# Main reconstruction loop. Each iteration reshuffles the batch order, runs one
# pass of ptycho_recon over all batches, and logs the resulting losses.
loss_iters = []
# None (or 0) disables saving; 0 would otherwise raise ZeroDivisionError below.
save_enabled = bool(SAVE_ITERS)

for niter in range(1, NITER + 1):

    # random.shuffle works in place, so `batches` must be a mutable sequence.
    shuffle(batches)
    batch_losses, iter_t = ptycho_recon(batches, model, optimizer, loss_fn, constraint_fn, niter)
    loss_iters.append((niter, loss_logger(batch_losses, niter, iter_t)))

    ## Saving intermediate results
    # Save every SAVE_ITERS iterations, and always after the last iteration so the
    # final state is kept even when NITER is not a multiple of SAVE_ITERS.
    if save_enabled and (niter % SAVE_ITERS == 0 or niter == NITER):
        save_results(output_path, model, exp_params, source_params, loss_params, constraint_params, recon_params, loss_iters, iter_t, niter, batch_losses)

        ## Saving summary
        plot_summary(output_path, loss_iters, niter, indices, init.init_variables, model, show_fig=False, save_fig=True)