 # Detailed walkthrough for PtyRAD

 Chia-Hao Lee

cl2696@cornell.edu

Updated on 2024.08.18

# 01. Imports

In [None]:
import os
from random import shuffle

import numpy as np
import torch

# Index of the GPU to run on. When CUDA is not available the notebook falls back
# to the CPU instead of failing: both torch.device("cuda:N") usage and
# torch.cuda.get_device_name raise on machines without a usable CUDA device.
GPUID = 0
if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{GPUID}")
else:
    DEVICE = torch.device("cpu")

print("Execution device: ", DEVICE)
print("PyTorch version: ", torch.__version__)
print("CUDA available: ", torch.cuda.is_available())
print("CUDA version: ", torch.version.cuda)
if DEVICE.type == "cuda":
    # Only query the device name when a GPU is actually in use.
    print("CUDA device:", torch.cuda.get_device_name(GPUID))

In [None]:
# Root of the PtyRAD checkout; relative paths used later (e.g. params/...) are
# resolved against it. A raw string keeps the Windows backslashes literal:
# "\w" and "\p" are invalid escape sequences and raise a SyntaxWarning on
# Python 3.12+ (the resulting path value is unchanged).
work_dir = r"H:\workspace\ptyrad"
os.chdir(work_dir)
print("Current working dir: ", os.getcwd())

In [None]:
from ptyrad.data_io import load_params
from ptyrad.initialization import Initializer
from ptyrad.models import PtychoAD
from ptyrad.optimization import CombinedConstraint, CombinedLoss
from ptyrad.reconstruction import loss_logger, recon_step
from ptyrad.utils import (
    copy_params_to_dir,
    get_blob_size,
    make_batches,
    make_output_folder,
    save_results,
    select_scan_indices,
    test_loss_fn,
)
from ptyrad.visualization import (
    plot_forward_pass,
    plot_pos_grouping,
    plot_scan_positions,
    plot_summary,
)

# 02. Initialize optimization

In [None]:
# YAML parameter file, relative to the working directory set above.
params_path = "params/demo/tBL_WSe2_1probe_1obj_1slice.yml"

params = load_params(params_path)

# Split the loaded parameters into their named sections, in this fixed order.
# Each section is fetched with `.get`, so an absent section comes back as None
# (assuming `params` is a dict).
(
    exp_params,
    source_params,
    hypertune_params,
    model_params,
    loss_params,
    constraint_params,
    recon_params,
) = (
    params.get(section_name)
    for section_name in (
        "exp_params",
        "source_params",
        "hypertune_params",
        "model_params",
        "loss_params",
        "constraint_params",
        "recon_params",
    )
)

In [None]:
# Construct the Initializer from the experiment/source sections and run all of
# its initialization steps; later cells read `init.init_variables`.
init = (
    Initializer(exp_params, source_params)
    .init_all()
)

In [None]:
# # Estimate the local obj tilts
# from ptyrad.utils import get_local_obj_tilts
# pos = init.init_variables['crop_pos'] + exp_params['Npix']//2
# objp = np.angle(init.init_variables['obj'].mean(0)) # Take the average omode
# dx = exp_params['dx_spec']
# z_distance = exp_params['z_distance']
# slice_indices = [2,9]
# blob_params = {'min_sigma':1, 'max_sigma':5, 'overlap':0.1, 'threshold':0.35, 'exclude_border':(150,150)}
# obj_tilts = get_local_obj_tilts(pos, objp, dx, z_distance, slice_indices, blob_params, window_size=11)
# init.init_variables['obj_tilts'] = obj_tilts

In [None]:
# Scan positions to display: the cropped positions offset by the probe-position
# shifts from the initialized variables.
pos = (
    init.init_variables["crop_pos"]
    + init.init_variables["probe_pos_shifts"]
)
plot_scan_positions(pos, figsize=(8, 8))

In [None]:
# Model settings normally come from the YAML `model_params` section.
# To override them here, uncomment and edit this example:
# model_params = {
#     'obj_preblur_std'     : None, # scalar(px), None
#     'detector_blur_std'   : None, # scalar(px), None
#     'lr_params':{
#         'obja'            : 5e-4,
#         'objp'            : 5e-4,
#         'obj_tilts'       : 0,
#         'probe'           : 1e-4,
#         'probe_pos_shifts': 1e-4}}

# Build the PtychoAD model from the initialized variables on DEVICE.
model = PtychoAD(init.init_variables, model_params, device=DEVICE)

# Adam over whatever the model exposes as `optimizer_params`
# (presumably per-variable groups using the rates in `lr_params` — TODO confirm).
optimizer = torch.optim.Adam(params=model.optimizer_params)

## Check the forward pass

In [None]:
# Draw 2 random scan indices in [0, N_scans) and plot the forward pass for them.
# `dp_power` is forwarded to the plot (presumably a display exponent for the
# diffraction patterns).
indices = np.random.randint(low=0, high=init.init_variables["N_scans"], size=2)
dp_power = 0.5
plot_forward_pass(model, indices, dp_power)

## Fine-tune the loss params

In [None]:
# Loss terms normally come from the YAML `loss_params` section.
# To override them here, uncomment and edit this example:
# loss_params = {
#     'loss_single': {'state': True, 'weight': 1.0, 'dp_pow': 0.5},
#     'loss_poissn': {'state': False, 'weight': 1.0, 'dp_pow':1.0},
#     'loss_pacbed': {'state': False, 'weight': 0.5, 'dp_pow': 0.2},
#     'loss_sparse': {'state': True, 'weight': 0.1, 'ln_order': 1},
#     'loss_simlar': {'state': False, 'weight': 1.0, 'obj_type':'both', 'scale_factor':[1,1,1], 'blur_std':1}
# }

# Sanity-check the combined loss on 16 random scan indices in [0, N_scans).
indices = np.random.randint(low=0, high=init.init_variables["N_scans"], size=16)
loss_fn = CombinedLoss(loss_params, device=DEVICE)
test_loss_fn(model, indices, loss_fn)

## Set up the iteration-wise constraint params

In [None]:
# Example structure of `constraint_params`; the values actually used come from
# the YAML file loaded earlier. Uncomment to override in-notebook.
# constraint_params = {
#     'ortho_pmode'   : {'freq': 1},
#     'probe_mask_k'  : {'freq': None, 'radius':0.16, 'width':0.02}, # k-radius should be larger than 2*rbf/Npix to avoid cutting out the BF disk
#     'fix_probe_int' : {'freq': 1},
#     'obj_rblur'     : {'freq': None, 'obj_type':'both', 'kernel_size': 5, 'std':0.4}, # Ideally kernel size is odd and larger than 6std+1 so it decays to 0
#     'obj_zblur'     : {'freq': None, 'obj_type':'both', 'kernel_size': 5, 'std':1},
#     'kr_filter'     : {'freq': None,    'obj_type':'both', 'radius':0.15, 'width':0.05},
#     'kz_filter'     : {'freq': 1,    'obj_type':'both', 'beta':0.3, 'alpha':1},
#     'obja_thresh'   : {'freq': None, 'relax':0, 'thresh':[0.8**(1/15), 1.2**(1/15)]},
#     'objp_postiv'   : {'freq': 1,    'relax':0},
#     # 'obj_gauss_fit' : {'freq': 1, 'obj_type':'both', 'num_fp_configs':1, 'num_groups':None},
#     'tilt_smooth'   : {'freq': None, 'std':2}
# }

# Build the combined constraint object; it is passed to `recon_step` in the
# main optimization loop.
constraint_fn = CombinedConstraint(constraint_params, device=DEVICE)


# Optional: inspect the k-space probe mask and test the constraints on a CPU
# copy of the model before reconstructing.
# NOTE(review): the plot_sigmoid_mask call hard-codes Npix=384 — TODO confirm it
# matches exp_params['Npix'] for the dataset in use.
# from ptyrad.utils import test_constraint_fn
# from ptyrad.visualization import plot_sigmoid_mask
# probe_k = np.fft.fftshift(np.abs(np.fft.fft2(np.fft.fftshift(init.init_variables['probe'], axes=(-2,-1))))[0])
# plot_sigmoid_mask(Npix=384, relative_radius=0.16, relative_width=0.02, img = probe_k, show_circles=True)
# constraint_fn = CombinedConstraint(constraint_params, device=DEVICE)
# test_constraint_fn(PtychoAD(init.init_variables, model_params, device='cpu'), constraint_fn, plot_forward_pass)

# 03. Main optimization loop

In [None]:

# Reconstruction settings normally come from the YAML `recon_params` section.
# NOTE(review): this older example is out of date. The code below reads
# INDICES_MODE as a dict with keys 'mode', 'subscan_slow' and 'subscan_fast'
# (not a bare string). It also reads 'dir_affixes' and 'copy_params', which the
# example omits. TODO refresh the example from a current params file.
# recon_params = {
#     'NITER'         : 10,
#     'INDICES_MODE'  : 'center',  # 'full', 'center', 'sub'
#     'BATCH_SIZE'    : 32,
#     'GROUP_MODE'    : 'random',  # 'random', 'sparse', 'compact' # Note that 'sparse' for 256x256 scan could take more than 10 mins on CPU. PtychoShelves automatically switch to 'random' for Nscans>1e3
#     'SAVE_ITERS'    : 10,  # scalar or None
#     'output_dir'    : 'output/tBL-WSe2',  # Output folder and pre/postfix, note that the needed / and _ are automatically generated
#     'prefix_date'   : True,
#     'prefix'        : '',
#     'postfix'       : '',
#     'fig_list'      : ['loss', 'forward', 'probe_r_amp', 'probe_k_amp', 'probe_k_phase', 'pos']  # 'loss', 'forward', 'probe_r_amp', 'probe_k_amp', 'probe_k_phase', 'pos', 'tilt', or 'all' for all the figures
# }

# Unpack the reconstruction settings (absent keys become None via `.get`,
# assuming recon_params is a dict).
NITER           = recon_params.get("NITER")
INDICES_MODE    = recon_params.get("INDICES_MODE")  # dict: 'mode', 'subscan_slow', 'subscan_fast'
BATCH_SIZE      = recon_params.get("BATCH_SIZE")
GROUP_MODE      = recon_params.get("GROUP_MODE")
SAVE_ITERS      = recon_params.get("SAVE_ITERS")    # None disables all saving below
output_dir      = recon_params.get("output_dir")
dir_affixes     = recon_params.get("dir_affixes")
prefix_date     = recon_params.get("prefix_date")   # read but not used in this notebook
prefix          = recon_params.get("prefix")        # read but not used in this notebook
postfix         = recon_params.get("postfix")       # read but not used in this notebook

fig_list        = recon_params.get("fig_list")
copy_params     = recon_params.get("copy_params")

# Current scan positions (crop positions plus the model's position shifts) as NumPy.
pos = (model.crop_pos + model.opt_probe_pos_shifts).detach().cpu().numpy()
# |probe|^2 of the first entry along the probe's leading axis (presumably the
# first probe mode), used to estimate the probe size.
probe_int = model.opt_probe[0].abs().pow(2).detach().cpu().numpy()
dx = init.init_variables["dx"]
d_out = get_blob_size(dx, probe_int, output="d90")  # d_out unit is in Ang

# Choose which scan positions take part in the reconstruction.
indices = select_scan_indices(
    init.init_variables["N_scan_slow"],
    init.init_variables["N_scan_fast"],
    subscan_slow=INDICES_MODE.get('subscan_slow'),
    subscan_fast=INDICES_MODE.get('subscan_fast'),
    mode=INDICES_MODE.get('mode'),
)

# Group the selected indices into mini-batches of BATCH_SIZE using GROUP_MODE.
batches = make_batches(indices, pos, BATCH_SIZE, mode=GROUP_MODE)

# Visualize the batch grouping with circles of diameter d_out / dx
# (the 90% probe diameter, presumably converted to pixels).
fig_grouping = plot_pos_grouping(
    pos,
    batches,
    circle_diameter=d_out / dx,
    diameter_type="90%",
    dot_scale=1,
    show_fig=True,
    pass_fig=True,
)

# Output files are only written when periodic saving is enabled; the main loop
# relies on `output_path` under the same condition.
if SAVE_ITERS is not None:
    output_path = make_output_folder(
        output_dir,
        indices,
        exp_params,
        recon_params,
        model,
        constraint_params,
        loss_params,
        dir_affixes
    )

    # os.path.join rather than manual "/" concatenation for portable paths.
    fig_grouping.savefig(os.path.join(output_path, "summary_pos_grouping.png"))

    if copy_params:
        # Store a copy of the params file alongside the results.
        copy_params_to_dir(params_path, output_path)

In [14]:
# Main reconstruction loop over iterations 1..NITER.
# Each iteration reshuffles the batch order, runs `recon_step` over the batches,
# and appends (niter, loss_logger(...)) to `loss_iters`. On every SAVE_ITERS-th
# iteration it also writes results and summary figures to `output_path`.
loss_iters = []
for niter in range(1, NITER + 1):
    shuffle(batches)  # in-place random reordering of the batch list
    batch_losses, iter_t = recon_step(batches, model, optimizer, loss_fn, constraint_fn, niter)
    loss_iters.append((niter, loss_logger(batch_losses, niter, iter_t)))

    # Guard clause: nothing more to do unless saving is enabled and due now.
    if SAVE_ITERS is None or niter % SAVE_ITERS != 0:
        continue

    # `params` holds the initial settings (including exp_params). The model holds
    # the values actually used, which can differ when meas_crop or meas_resample
    # is not None.
    save_results(output_path, model, params, loss_iters, iter_t, niter, batch_losses)

    # Summary figures are saved to disk but not displayed.
    plot_summary(
        output_path, model, loss_iters, niter, indices, init.init_variables,
        fig_list=fig_list, show_fig=False, save_fig=True,
    )

KeyboardInterrupt: 