# Detailed walkthrough for PtyRAD

 Chia-Hao Lee

cl2696@cornell.edu

Updated on 2025.02.13

Note: This notebook showcases only the "reconstruction" mode. Most of the wrapper classes/functions are exposed so that you can see how the different components work together.

# 01. Imports

In [None]:
import os
from random import shuffle

import numpy as np
import torch

# Change into the PtyRAD repo root so that the relative paths used later in
# this notebook (e.g. "params/demo/...") resolve. Edit this to your own checkout.
# A raw string keeps the Windows backslashes literal; in a plain string "\w"
# and "\p" are invalid escape sequences (SyntaxWarning on Python 3.12+).
work_dir = r"H:\workspace\ptyrad"
os.chdir(work_dir)
print("Current working dir: ", os.getcwd())

In [None]:

from ptyrad.data_io import load_params
from ptyrad.initialization import Initializer
from ptyrad.models import PtychoAD
from ptyrad.optimization import CombinedConstraint, CombinedLoss, create_optimizer
from ptyrad.reconstruction import recon_step
from ptyrad.utils import (
    copy_params_to_dir,
    CustomLogger,
    get_blob_size,
    make_batches,
    make_output_folder,
    parse_sec_to_time_str,
    print_system_info,
    save_results,
    select_scan_indices,
    set_gpu_device,
    test_loss_fn,
    time_sync,
    vprint,
)
from ptyrad.visualization import (
    plot_forward_pass,
    plot_pos_grouping,
    plot_scan_positions,
    plot_summary,
)


In [None]:
# Create the logger before anything is printed so the system/GPU info below
# can be captured too (presumably CustomLogger mirrors printed output — verify).
logger = CustomLogger(
    log_file="ptyrad_log.txt",
    log_dir="auto",
    prefix_date=True,
    append_to_file=False,
    show_timestamp=True,
)

print_system_info()
# Device handle passed to the model, loss, and constraint objects created later
device = set_gpu_device(gpuid=0)

# 02. Initialize optimization

In [None]:
params_path = "params/demo/tBL_WSe2_reconstruct.yml"
params = load_params(params_path)

# Unpack each top-level section of the loaded parameter file into its own variable.
(exp_params, source_params, hypertune_params, model_params,
 loss_params, constraint_params, recon_params) = (
    params[section]
    for section in (
        "exp_params",
        "source_params",
        "hypertune_params",  # It's parsed but not needed in this demo notebook
        "model_params",
        "loss_params",
        "constraint_params",
        "recon_params",
    )
)

In [None]:
# Build the initializer from the experiment/source settings, then run every
# initialization step. `init.init_variables` holds the starting quantities used
# by later cells (e.g. "crop_pos", "probe_pos_shifts", "N_scans", "dx").
init = Initializer(exp_params, source_params)
init = init.init_all()

In [None]:
# Total probe positions = crop positions + initial probe position shifts
# (units presumably pixels — confirm in Initializer). Plotted as a quick sanity
# check of the scan pattern before building the model.
pos = init.init_variables["crop_pos"] + init.init_variables["probe_pos_shifts"]
plot_scan_positions(pos, figsize=(8, 8))

In [None]:
# Build the PtychoAD model on `device` from the initialized variables, then
# create the optimizer from the model's optimizer settings and its
# optimizable parameters.
model = PtychoAD(init.init_variables, model_params, device=device)
optimizer = create_optimizer(model.optimizer_params, model.optimizable_params)

## Check the forward pass

In [None]:
# Visual check of the forward model at 2 randomly chosen scan positions.
dp_power = 0.5  # exponent handed to the plot, presumably for diffraction-pattern display contrast
indices = np.random.randint(low=0, high=init.init_variables["N_scans"], size=2)
plot_forward_pass(model, indices, dp_power)

## Fine-tune the loss params

In [None]:
# Evaluate the configured loss terms on 16 random scan positions, so the
# loss_params can be tuned before the full reconstruction.
indices = np.random.randint(low=0, high=init.init_variables["N_scans"], size=16)
loss_fn = CombinedLoss(loss_params, device=device)  # reused by the main loop
test_loss_fn(model, indices, loss_fn)

## Set up the iteration-wise constraint params

In [None]:
# Combined constraint object built from `constraint_params`; it is passed to
# `recon_step` in the main loop together with the iteration number
# (presumably so each constraint can follow its own schedule — confirm).
constraint_fn = CombinedConstraint(constraint_params, device=device)

# 03. Main optimization loop

In [None]:
# Unpack the reconstruction settings. Some of these (prefix_date, prefix,
# postfix, save_result, result_modes) are not used by the cells below; they are
# kept so the notebook namespace mirrors the config.
NITER             = recon_params['NITER']
INDICES_MODE      = recon_params['INDICES_MODE']
batch_size        = recon_params['BATCH_SIZE'].get("size")
grad_accumulation = recon_params['BATCH_SIZE'].get("grad_accumulation")
GROUP_MODE        = recon_params['GROUP_MODE']
SAVE_ITERS        = recon_params['SAVE_ITERS']
output_dir        = recon_params['output_dir']
recon_dir_affixes = recon_params['recon_dir_affixes']
prefix_date       = recon_params['prefix_date']
prefix            = recon_params['prefix']
postfix           = recon_params['postfix']
save_result       = recon_params['save_result']
result_modes      = recon_params['result_modes']
selected_figs     = recon_params['selected_figs']
copy_params       = recon_params['copy_params']

# Current probe positions as a NumPy array. Tensor.numpy() cannot convert
# bfloat16, so upcast to float32 only in that case; other dtypes pass through.
pos = (model.crop_pos + model.opt_probe_pos_shifts).detach()
if pos.dtype == torch.bfloat16:
    pos = pos.to(torch.float32)
pos = pos.cpu().numpy()

# Probe intensity |probe|^2 summed over dim 0 (presumably the probe-mode axis — confirm)
probe_int = model.get_complex_probe_view().abs().pow(2).sum(0).detach().cpu().numpy()
dx = init.init_variables["dx"]
d_out = get_blob_size(dx, probe_int, output="d90")  # d_out unit is in Ang

indices = select_scan_indices(
    init.init_variables['N_scan_slow'],
    init.init_variables['N_scan_fast'],
    subscan_slow=INDICES_MODE['subscan_slow'],
    subscan_fast=INDICES_MODE['subscan_fast'],
    mode=INDICES_MODE['mode'],
)

batches = make_batches(indices, pos, batch_size, mode=GROUP_MODE)
vprint(f"The effective batch size (i.e., how many probe positions are simultaneously used for 1 update of ptychographic parameters) is batch_size * grad_accumulation = {batch_size} * {grad_accumulation} = {batch_size*grad_accumulation}")

fig_grouping = plot_pos_grouping(
    pos,
    batches,
    circle_diameter=d_out / dx,  # convert the Ang diameter into units of dx
    diameter_type="90%",
    dot_scale=1,
    show_fig=True,
    pass_fig=True,
)

if SAVE_ITERS is not None:
    output_path = make_output_folder(
        output_dir,
        indices,
        exp_params,
        recon_params,
        model,
        constraint_params,
        loss_params,
        recon_dir_affixes
    )

    fig_grouping.savefig(os.path.join(output_path, "summary_pos_grouping.png"))

    if copy_params:
        copy_params_to_dir(params_path, output_path)

    # Flush to file after the output_path is created. This must stay inside
    # this branch: output_path only exists when SAVE_ITERS is set, so flushing
    # outside it raised NameError (or, on a notebook re-run, silently reused a
    # stale output_path from a previous run).
    if logger is not None and logger.flush_file:
        logger.flush_to_file(log_dir=output_path)

In [None]:
start_t = time_sync()
vprint("### Starting the PtyRADSolver in reconstruction mode ###")
vprint(" ")

for niter in range(1, NITER + 1):

    # Visit the mini-batches in a fresh random order every iteration (in-place shuffle)
    shuffle(batches)
    batch_losses = recon_step(
        batches, grad_accumulation, model, optimizer, loss_fn, constraint_fn, niter
    )

    ## Saving intermediate results every SAVE_ITERS iterations
    if SAVE_ITERS is not None and niter % SAVE_ITERS == 0:
        with torch.no_grad():
            # Note that `params` (incl. `exp_params`) stores the initial settings,
            # while `model` contains the actual params, which could have been
            # updated if either meas_crop or meas_resample is not None.
            save_results(
                output_path,
                model,
                params,
                optimizer,
                niter,
                indices,
                batch_losses,
            )

            ## Saving summary
            plot_summary(
                output_path,
                model,
                niter,
                indices,
                init.init_variables,
                selected_figs=selected_figs,
                show_fig=False,
                save_fig=True,
            )

vprint(f"### Finish {NITER} iterations, averaged iter_t = {np.mean(model.iter_times):.5g} sec ###")
vprint(" ")
end_t = time_sync()
solver_t = end_t - start_t
# time_str starts with ", " when non-empty, so it attaches directly to "sec"
# (avoids the previous "sec , or ..." / "sec  ###" spacing glitches).
time_str = f", or {parse_sec_to_time_str(solver_t)}" if solver_t > 60 else ""
vprint(f"### The PtyRADSolver is finished in {solver_t:.3f} sec{time_str} ###")

if logger is not None and logger.flush_file:
    logger.close()