 # PtyRAD - PTYchographic Reconstruction with Automatic Differentiation

 Chia-Hao Lee

cl2696@cornell.edu

Created 2024.03.08

## Worklog
- 20240308: Start working on this
- 20240310: Got it working as a closed loop for a single object with mixed probe
- 20240311: Implemented mixed object and unified dimension for 2D and 3D
- 20240313: Tried position correction, need to do it with real-valued object. Switch to 2 channels for probe and object.
- 20240314: Got probe position correction roughly implemented with STN on object. Implemented obj phase L1, obj phase positivity, obj phase TV.
- 20240315: Implemented forward model with batch, but turns out it's bottlenecking at the get_obj_ROI.
- 20240316: Decided to roll back to integer get_obj_ROI. Implemented get_probes with torchvision.transforms.v2.functional.affine as well. Now everything is batched.




## Note

## Done feature
- Probe mode optimization (pmode, Ny, Nx)
- Multi-object with 2/3D optimization (omode, Nz, Ny, Nx, 2)
- Probe position correction with STN (roll back to integer obj position)
- Obj phase L1, obj phase positivity, obj phase TV
- Probe position correction with v2.functional.affine with shifted probe



## TODO

### Notebook
- Loading steps
- Preprocessing steps


### Optimization
- Batch calculation of forward model without breaking the backprop
- Test regularization / constraint / multiscale consistency loss
- Multiscale pyramidal reconstruction
- Specify MLs / MLc

### Initialization
- Add loading initial guess
- Add object and probe initialization (probe modes and object modes)

### Input / Output
- Add saving intermediate results
- Add checkpoints



In [None]:
from preprocess import preprocess_CBED, preprocess_ptycho_output_dict 
from optimization import cbed_rmse
from forward_model import multislice_forward_model_batch_all
from utils import cplx_from_np, complex_object_interp3d, near_field_evolution
from data_io import load_fields_from_mat, load_hdf5
from scipy.ndimage import gaussian_filter

#from visualization import plot_recon_progress


import numpy as np
import matplotlib.pyplot as plt

import torch
from torchvision.transforms import v2


# Index of the GPU to run on. When CUDA is unavailable (or the index does not
# exist) we fall back to the CPU instead of failing later in get_device_name().
GPUID = 0
DEVICE = torch.device(
    f"cuda:{GPUID}"
    if torch.cuda.is_available() and GPUID < torch.cuda.device_count()
    else "cpu"
)
print("Execution device: ", DEVICE)
print("PyTorch version: ", torch.__version__)
print("CUDA available: ", torch.cuda.is_available())
print("CUDA version: ", torch.version.cuda)
# get_device_name() raises without a CUDA device, so only query it on the GPU path
print("CUDA device:", torch.cuda.get_device_name(GPUID) if DEVICE.type == "cuda" else None)

In [None]:
# Setup data path
# Only one (ptycho_output_mat_path, exp_CBED_path) pair should be active at a time;
# the commented pairs below are alternative datasets.
#ptycho_output_mat_path = "data/data3D_300kV_df_20nm_alpha_21.4mrad_Cs_0.0um_dp_128_blur_0px_dose_1.0e+08ePerAng2/1/roi_1_Ndp_128/MLs_L1_p4_g32_Ns21_dz10_reg1_dec0.92_dpFlip_T/Niter500.mat"
#exp_CBED_path = "data/data3D_300kV_df_20nm_alpha_21.4mrad_Cs_0.0um_dp_128_blur_0px_dose_1.0e+08ePerAng2/1/data_roi_1_Ndp_128_dp.hdf5"

# Exp PSO
# ptycho_output_mat_path = "data/MLs_L1_p8_g192_Ndp128_pc50_noModel_vp1_Ns21_dz10_reg1/Niter200.mat"
# exp_CBED_path = "data/MLs_L1_p8_g192_Ndp128_pc50_noModel_vp1_Ns21_dz10_reg1/PSO_data_roi0_Ndp256_dp.hdf5"

# Exp PSO
# ptycho_output_mat_path = "data/MLs_L1_p4_g32_Ns21_dz10_dpFlip_T/Niter50_ML.mat"
# exp_CBED_path = "data/MLs_L1_p4_g32_Ns21_dz10_dpFlip_T/data_roi_1_Ndp_128_dp.hdf5"

# Use forward slashes throughout: a backslash inside a normal string literal is an
# escape character and is not a path separator outside Windows.
ptycho_output_mat_path = "data/Fig_1h_24.9mrad_Themis/1/roi1_Ndp128_step128/MLs_L1_p10_g128_pc0_noModel_updW100_mm_dpFlip_ud_T/Niter9000_v7.mat"
exp_CBED_path =          "data/Fig_1h_24.9mrad_Themis/1/data_roi1_Ndp128_step128_dp.hdf5"

print("Loading ptycho output and input CBED")
ptycho_output_dict = load_fields_from_mat(ptycho_output_mat_path, 'All', squeeze_me=True, simplify_cells=True)
#exp_CBED = load_empad_as_4D(exp_CBED_path, 128,130,128,128,'C')
input_CBED, CBED_source = load_hdf5(exp_CBED_path, dataset_key='dp')
# Note that a Matlab-generated F-order HDF5 (kx, ky, Nscan) is read into Python as C-order (Nscan, ky, kx)

In [None]:
# Preprocessing and setting up the data dimension

print("Preprocessing ptycho output and experimental CBED\n")
probe, object, exp_params = preprocess_ptycho_output_dict(ptycho_output_dict)

# Flip along axis 1 so that (Nscan, ky, kx) matches the .tif view, then zero out negative counts.
# (preprocess_CBED(exp_CBED) is the alternative route that is currently not used.)
cbeds = np.flip(input_CBED, axis=1)
cbeds[cbeds<0] = 0

# Experimental parameters needed by the forward model and the dataset generation
lambd, dx_spec, probe_positions, Nlayer, N_scans = (
    exp_params[key]
    for key in ('lambd', 'dx_spec', 'probe_positions', 'Nlayer', 'N_scans')
)
# Ang. Hard-coded instead of exp_params['z_distance_arr'][0]; for 2D input, put the final
# desired total thickness here if you're planning to do multislice.
z_distance = 8

# Preprocessing variables for tBL-WSe2 with Themis
probe = probe.transpose((2,0,1))   # Converting probe into (pmode, Ny, Nx)
object = object[None, :,:]         # Prepend the Nz axis (alternative: object.transpose((2,0,1)))

print(f"\nobject dtype/shape          (Nz, Ny, Nx) = {object.dtype}, {object.shape},\
        \nprobe data dtype/shape   (pmode, Ny, Nx) = {probe.dtype}, {probe.shape},\
        \ncbeds data dtype/shape       (N, Ky, Kx) = {cbeds.dtype}, {cbeds.shape}")

In [None]:

# NOTE: this cell rewrites z_distance, object and probe_positions in place, so
# running it twice without re-running the preprocessing cell applies the rescale /
# column swap twice.

# Reslice the z slices
final_z = 8 # z slices #21 for PSO
z_zoom = final_z / Nlayer
z_distance = z_distance / z_zoom # Scale the interlayer distance based on z_zoom
object = complex_object_interp3d(object, (z_zoom, 1, 1), z_axis = 0, use_np_or_cp = 'np') # Use cp for faster interpolation and convert it back to np with .get() as a default postprocessing

# Calculate the crop coordinates with floating points
probe_positions = probe_positions[:, [1,0]] # The first index after shifting is the row index (along vertical axis)
# Offset from the object centre to the top-left corner of the crop, computed per axis
# (row offset from the heights, column offset from the widths) so that non-square
# objects / probes are handled. For square arrays both entries are equal.
obj_hw = np.array(object.shape[-2:])
probe_hw = np.array(probe.shape[-2:])
center_offset = np.ceil(obj_hw/2 - probe_hw/2) - 1 # -1 for Matlab - Python index shift
crop_coordinates = probe_positions + center_offset
sub_px_shift = crop_coordinates - np.round(crop_coordinates) # This shift (tH, tW) would be added to the probe to compensate the integer obj cropping
crop_indices = np.round(crop_coordinates) # Rounded crop corner used for the integer object cropping

## Calculate propagator for multislice forward model
extent = dx_spec * np.array(probe.shape[-2:])
_, H, _, _ = near_field_evolution(probe.shape[-2:], z_distance, lambd, extent, use_ASM_only=True, use_np_or_cp='np')

# Specify forward model accuracy options
N_max = 16384
pmode_max = 10 # 4
omode_max = 1 # By default we only do 1 object mode

# Initialize probe / object if needed (expand them to desired dimension)
if object.ndim == 2:
      # (Ny, Nx) for 2D ptychography
      object_data = object[None, :, :]
      print(f"Expanding 2D object (Ny, Nx) from {object.shape} to 3D object (Nz, Ny, Nx) {object_data.shape}")
elif object.ndim == 3:
      # (Nz, Ny, Nx) for multislice (3D) ptycho
      object_data = object
      print(f"Object is already 3D (Nz, Ny, Nx) with {object_data.shape}")
else:
      # Fail here with a clear message rather than with a NameError on object_data below
      raise ValueError(f"object must be 2D (Ny, Nx) or 3D (Nz, Ny, Nx), got shape {object.shape}")

obj_power_factor = np.exp(np.log(0.02)/2/final_z) # Multiply the object by this factor would result a 2% total scattering intensity in the CBED after all the loss from all multiplicative slices

# Mode 0 is the object as-is; every further mode is a gaussian-blurred copy scaled by obj_power_factor
object_data = np.stack([object_data if i < 1 else gaussian_filter(object_data, sigma = (0,i//2, i%2)) * obj_power_factor for i in range(omode_max)], axis=0).astype('complex64') # Adding gaussian blur to the obj modes
probe_data = probe[:pmode_max, :, :].astype('complex64')
# Keep all per-position arrays truncated to the same N_max positions
crop_indices_data = crop_indices[:N_max].astype('int32')
shift_vec_data = sub_px_shift[:N_max].astype('float32')
cbeds_data = cbeds[:N_max].astype('float32')
H = H.astype('complex64')

print(f"\nobject_data dtype/shape (omode, Nz, Ny, Nx) = {object_data.dtype}, {object_data.shape}, \
        \nprobe_data dtype/shape      (pmode, Ny, Nx) = {probe_data.dtype}, {probe_data.shape}, \
        \ncrop_indices_data                     (N,2) = {crop_indices_data.dtype}, {crop_indices_data.shape}, \
        \nshift_vec_data                        (N,2) = {shift_vec_data.dtype}, {shift_vec_data.shape}, \
        \ncbeds_data dtype/shape          (N, Ky, Kx) = {cbeds_data.dtype}, {cbeds_data.shape}, \
        \nH dtype/shape                      (Ky, Kx) = {H.dtype}, {H.shape}")

In [None]:
# Inspect the floating-point crop coordinates (probe positions plus the centring offset)
crop_coordinates

In [None]:
# Inspect the sub-pixel remainder: crop_coordinates minus their rounded values
sub_px_shift

In [None]:
# Inspect the rounded int32 crop indices that are handed to the model
crop_indices_data



### Note: Up to this cell everything is still a NumPy array (complex-valued for the object, probe and propagator); the conversion to tensors happens later

## Build the model object

In [None]:
# The model currently lives in the notebook, but it will eventually move to its own module

class PtychoAD(torch.nn.Module):
    """Ptychography model whose object, probe and probe shifts are optimised by autograd.

    The object and the probe are converted with
    ``cplx_from_np(..., cplx_type="amp_phase", ndim=-1)`` and the forward pass
    returns whatever ``multislice_forward_model_batch_all`` produces for a batch
    of scan indices.

    Args:
        init_obj: Initial object array handed to ``cplx_from_np``. The comments
            in this class assume the result is (omode, D, H, W, 2).
        init_probe: Initial probe array handed to ``cplx_from_np``. Its last two
            dimensions define the size of the object ROI cropped per position.
        init_crop_pos: Integer (row, col) start of the object crop for every
            scan position, stored as int32.
        init_probe_pos_shifts: Per-position (tH, tW) probe shifts.
        H: Propagator passed on to the forward model, stored as complex64.
        lr_params: Optional mapping from 'obj' / 'probe' / 'probe_pos_shifts' to
            a learning rate. A tensor whose rate is 0, or whose key is absent,
            is not added to ``optimizer_params``. ``None`` means nothing is
            optimised.
        device: Device on which all tensors are placed.

    Attributes:
        optimizable_tensors: Dict of the three tensors that may be optimised.
        optimizer_params: List of ``{'params': [...], 'lr': lr}`` groups that can
            be passed directly to a ``torch.optim`` optimizer.
        shift_probes: True only when 'probe_pos_shifts' has a non-zero rate.
    """

    def __init__(self, init_obj, init_probe, init_crop_pos, init_probe_pos_shifts, H, lr_params=None, device='cuda:0'):
        super().__init__()
        # lr_params defaults to None, so normalise it to a dict before any lookup.
        lr_params = {} if lr_params is None else dict(lr_params)

        with torch.no_grad():
            self.device = device
            self.opt_obj = cplx_from_np(init_obj, cplx_type="amp_phase", ndim=-1).to(self.device)
            self.opt_probe = cplx_from_np(init_probe, cplx_type="amp_phase", ndim=-1).to(self.device)
            self.opt_probe_pos_shifts = torch.tensor(init_probe_pos_shifts, device=self.device)
            self.crop_pos = torch.tensor(init_crop_pos, dtype=torch.int32, device=self.device)
            self.H = torch.tensor(H, dtype=torch.complex64, device=self.device)
            self.roi_shape = init_probe.shape[-2:]
            # Sub-px shifted probes are only generated when the shifts are being optimised
            self.shift_probes = (lr_params.get('probe_pos_shifts', 0) != 0)

            # Create a dictionary to store the optimizable tensors
            self.optimizable_tensors = {
                'obj': self.opt_obj,
                'probe': self.opt_probe,
                'probe_pos_shifts': self.opt_probe_pos_shifts
            }

            # Build one optimizer parameter group per tensor with a non-zero learning rate
            self.optimizer_params = []
            for param_name, lr in lr_params.items():
                if param_name in self.optimizable_tensors:
                    self.optimizable_tensors[param_name].requires_grad = (lr != 0)  # Set requires_grad based on learning rate
                    if lr != 0:
                        self.optimizer_params.append({'params': [self.optimizable_tensors[param_name]], 'lr': lr})
                else:
                    print(f"Warning: '{param_name}' is not a valid parameter name.")

            print('PtychoAD major variables:')
            for name, tensor in self.optimizable_tensors.items():
                # .get() so that a tensor without an entry in lr_params is reported with lr 0
                print(f"{name}: {tensor.shape}, {tensor.dtype}, device:{tensor.device}, grad:{tensor.requires_grad}, lr:{lr_params.get(name, 0):.0e}")

    def get_obj_ROI(self, indices):
        """ Get object ROI with integer coordinates

        Args:
            indices: Sequence / array of scan-position indices into ``crop_pos``.

        Returns:
            Tensor of shape (N, *opt_obj.shape[:2], height, width, 2), where
            N = len(indices) and (height, width) = ``roi_shape``.
        """
        # It's strongly recommended to do integer version of get_obj_ROI
        # opt_obj.shape = (B,D,H,W,C) = (omode,D,H,W,2)
        # object_patches = (N,B,D,H,W,2), N is the additional sample index within the input batch, B is now used for omode.

        height, width = self.roi_shape[0], self.roi_shape[1]

        # Fetch all crop corners as Python ints in one transfer, instead of
        # slicing with 0-dim device tensors once per patch.
        batch_idx = torch.as_tensor(indices, dtype=torch.long, device=self.crop_pos.device)
        corners = self.crop_pos[batch_idx].tolist()

        if not corners:
            # torch.stack cannot handle an empty list; keep the (0, ...) shaped result
            return self.opt_obj.new_zeros((0, *self.opt_obj.shape[:2], height, width, 2))

        # Each patch is (omode, D, H, W, 2), 2 for the amp/phase channel
        object_patches = [
            self.opt_obj[:, :, top:top + height, left:left + width, :]
            for top, left in corners
        ]
        return torch.stack(object_patches, dim=0)

    def get_probes(self, indices):
        """ Get probes for each position

        Args:
            indices: Sequence / array of scan-position indices into
                ``opt_probe_pos_shifts``.

        Returns:
            (N, pmode, Ny, Nx, 2) shifted probes when ``shift_probes`` is True,
            otherwise the shared probe with a leading singleton dimension.
        """
        # If you're not trying to optimize probe positions, there's not much point using sub-px shifted stationary probes
        # This function will return a single probe when self.shift_probes = False,
        # and would only be returning multiple sub-px shifted probes if you're optimizing self.opt_probe_pos_shifts

        if self.shift_probes:
            temp_probe = self.opt_probe.permute(0,3,1,2) # (pmode, Ny, Nx, 2) -> (pmode, 2, Ny, Nx)
            # Allocate directly on the target device rather than on the CPU followed by a copy
            probes = torch.zeros((len(indices), *temp_probe.shape), device=self.device) # (N, pmode, 2, Ny, Nx)

            for i, idx in enumerate(indices):
                tH = self.opt_probe_pos_shifts[idx][0] # Note that translate (a,b) is in unit of px, although the doc says it's fractional
                tW = self.opt_probe_pos_shifts[idx][1] # positive is moving to right/down for tW and tH.
                probes[i] = v2.functional.affine(temp_probe, translate = (tW, tH), interpolation=v2.InterpolationMode.BILINEAR, angle=0, scale=1, shear=0)
            probes = probes.permute(0,1,3,4,2) # (N, pmode, Ny, Nx, 2)
        else:
            probes = self.opt_probe[None,...] # Extend a singleton N dimension, essentially using same probe for all samples

        return probes

    def forward(self, indices):
        """ Doing the forward pass and get an output diffraction pattern for each input index """
        # The indices are passed as an array and representing the whole batch

        object_patches = self.get_obj_ROI(indices)
        probes = self.get_probes(indices)
        dp_fwd = multislice_forward_model_batch_all(object_patches, probes, self.H)

        return dp_fwd

## Create optimization object

In [None]:
# Initial guesses handed to the model. The commented line is an alternative
# random weak-phase start instead of the loaded reconstruction.
#init_obj      = np.exp(1j * 0.05*np.random.rand(*object_data.shape)).astype('complex64') # init_obj = object_data
init_obj              = object_data
init_probe            = probe_data
init_crop_pos         = crop_indices_data
init_probe_pos_shifts = shift_vec_data
# Put the measurements on the same device as the model (DEVICE) rather than on the
# default CUDA device, which differs from DEVICE whenever GPUID != 0.
measurements  = torch.from_numpy(cbeds_data).to(DEVICE)

# A learning rate of 0 keeps that tensor fixed; here only the object is optimised.
model = PtychoAD(init_obj, init_probe, init_crop_pos, init_probe_pos_shifts, H,
                lr_params={'obj': 1e-3,
                           'probe': 0,
                           'probe_pos_shifts': 0},
                device=DEVICE)

opt = torch.optim.Adam(model.optimizer_params)

## Check the forward pass

In [None]:
# Visual sanity check: compare the model's object ROI / simulated CBED against the
# input object patch / measured CBED for a few scan positions.
indices = [10000,100,200,300]
dp_power = 0.5  # exponent applied to the CBEDs for display only

dp_fwd = model(indices).detach().cpu()
obj_ROI = model.get_obj_ROI(indices).detach().cpu()

# Use the probe height for the row extent and the probe width for the column extent,
# so the reference patch also matches the model ROI for non-square probes.
roi_h, roi_w = probe.shape[-2:]

for i, idx in enumerate(indices):

    fig, axs = plt.subplots(2, 2, figsize=(8, 8))

    top, left = crop_indices_data[idx]
    object_patch = np.angle(object_data[0, 0, top:top + roi_h, left:left + roi_w])

    # Channel 1 of the model object, summed over the slice axis
    axs[0, 0].imshow(obj_ROI[i,0,:,:,:,1].sum(0))
    axs[0, 0].set_title(f"Model obj {idx}")

    axs[0, 1].imshow(object_patch)
    axs[0, 1].set_title(f"Data obj {idx}")

    axs[1, 0].imshow(dp_fwd[i]**dp_power)
    axs[1, 0].set_title(f"Model CBED {idx}")

    axs[1, 1].imshow(cbeds_data[idx]**dp_power)
    axs[1, 1].set_title(f"Data CBED {idx}")

    plt.show()



In [None]:
# BATCH_SIZE = 128 # The actual batch size would only be "close" if len(measurements) is not divisible by it
# num_batch = len(measurements)/BATCH_SIZE
# shuffled_indices = np.random.choice(len(measurements), size = len(measurements), replace=False) # Creates a shuffled 1D array of indices
# batches = np.array_split(shuffled_indices, num_batch) # return a list of `num_batch` arrays, or [batch0, batch1, ...]

In [None]:
# torch.cuda.synchronize()
# with torch.no_grad():
#     for batch in batches:
#         dp_fwd = model.get_obj_ROI(batch)

## Main optimization loop

In [None]:
from time import time
from IPython.display import clear_output
from torchmetrics.image import TotalVariation
# https://lightning.ai/docs/torchmetrics/stable/image/total_variation.html 
# This TV only applies to the last 2 dim (N,C,H,W)

NITER = 100
BATCH_SIZE = 256 # The actual batch size would only be "close" if len(measurements) is not divisible by it
# np.array_split needs a positive integer number of sections. Floor division matches
# the truncation it applied to the former float value, and max(1, ...) keeps a
# dataset smaller than BATCH_SIZE from producing 0 sections.
num_batch = max(1, len(measurements) // BATCH_SIZE)
cbed_shape = measurements.shape[1:]
loss_iters = []  # mean batch loss of every iteration, stored as Python floats

# Only used by the optional (currently commented-out) constraint / regularisation lines below
Softplus = torch.nn.Softplus(beta=100, threshold=2)
tv = TotalVariation().to(DEVICE)
for epoch in range(NITER):  # named `epoch` so the builtin iter() is not shadowed
    loss_batches = []
    shuffled_indices = np.random.choice(len(measurements), size = len(measurements), replace=False) # Creates a shuffled 1D array of indices
    batches = np.array_split(shuffled_indices, num_batch) # return a list of `num_batch` arrays, or [batch0, batch1, ...]

    for batch_idx, batch in enumerate(batches):
        start_batch_t = time()

        model_CBEDs = model(batch)
        measured_CBEDs = measurements[batch]

        loss_single     = cbed_rmse(model_CBEDs.sqrt(), measured_CBEDs.sqrt())
        loss_pacbed     = 0 #cbed_rmse(model_CBEDs.mean(0).pow(0.2), measured_CBEDs.mean(0).pow(0.2)) # Ensuring the Position-averaged CBED are consistent as well
        loss_tv         = 0 #1e-7 * tv(model.opt_obj[:,1])
        loss_l1         = 0 #torch.mean(model.opt_obj[:,1].abs())
        loss_batch      = loss_single + loss_pacbed + loss_tv + loss_l1

        loss_batch.backward()
        opt.step()
        opt.zero_grad()
        end_batch_t = time()

        # #Update the plots after each update
        # AD_image = np.angle(model.opt_obj.detach().cpu()).sum(1)
        # Input_image = np.angle(object_data).sum(axis=1)[0]

        # # Show the figure per batch
        # clear_output(wait=True)
        # fig = plot_recon_progress(epoch, batch_idx, AD_image, Input_image)
        # plt.show()

        # Detach so the stored history does not keep autograd references alive
        loss_batches.append(loss_batch.detach())
        if batch_idx % 10 == 0:
            print(f"Done batch {batch_idx} in iter {epoch} in {(end_batch_t - start_batch_t):.1f} sec")
    # Reports the losses of the last batch of this iteration
    print(f"Iter: {epoch}, Loss_batch: {loss_batch:.3f}, Loss_single: {loss_single:.3f}, Loss_pacbed: {loss_pacbed:.3f}, Loss_tv: {loss_tv:.3f}, Loss_L1: {loss_l1:.3f}")

    # # Do a softplus constraint at the end of each iter without grad
    # with torch.no_grad():
    #     print(f"Applying softplus to obj phase for positivity after iter {epoch}")
    #     model.opt_obj[:,1] = Softplus(model.opt_obj[:,1])

    loss_iters.append((sum(loss_batches)/len(loss_batches)).item())
    print('Iter: {} Loss: {} '.format(epoch, loss_iters[-1]))



In [None]:
# Show channel 1 of the first object mode, summed over the slice axis
obj_cpu = model.opt_obj.detach().cpu()
summed_channel = obj_cpu[0,:,:,:,1].sum(0)

plt.figure(figsize=(12,12))
plt.imshow(summed_channel)
plt.show()

In [None]:
from pathlib import Path
from tifffile import imwrite

# Save channel 1 of the first object mode as a float32 TIFF.
output_path = Path("output/AD_image_WSe2_test8.tif")
# Create the output folder first so that a fresh checkout does not fail on the write
output_path.parent.mkdir(parents=True, exist_ok=True)
imwrite(str(output_path), model.opt_obj.detach().cpu().numpy()[0,...,1].astype('float32'))