 # Detailed walkthrough of py4DSTEM ptychography reconstruction

 Chia-Hao Lee

cl2696@cornell.edu

Updated on 2024.10.26

In [None]:
%load_ext autoreload
%autoreload 2

import py4DSTEM
import numpy as np
import matplotlib.pyplot as plt
import os
from time import time

from py4DSTEM.process.phase.utils_CHL import make_output_folder, parse_sec_to_time_str, print_system_info, init_datacube, init_ptycho

In [None]:
# Windows path written as a raw string so the backslashes are taken literally.
# In a normal string "\w" and "\p" are invalid escape sequences, which emit a
# SyntaxWarning on Python 3.12+ (DeprecationWarning on earlier versions).
work_dir = r"H:\workspace\ptyrad"
# Relative paths used later (e.g. in exp_params / recon_params) are resolved against this directory.
os.chdir(work_dir)
print("Current working dir: ", os.getcwd())
print_system_info()

In [None]:
# Experimental parameters: microscope / scan settings and the options that
# control how the measured diffraction patterns are loaded and oriented.
exp_params = {
    # type: float, unit: kV. Acceleration voltage for relativistic electron
    # wavelength calculation.
    'kv': 80,
    # type: float, unit: mrad. Semi-convergence angle for probe-forming aperture.
    'conv_angle': 24.9,
    # type: integer, unit: px (k-space). Detector pixel number, EMPAD is 128.
    # Only supports square detector for simplicity.
    'Npix': 128,
    # type: float, unit: Ang. Real space pixel size calibration at specimen
    # plane (object, probe, and probe positions share the same pixel size).
    'dx_spec': 0.1494,
    # type: float, unit: Ang. Defocus (-C1) aberration coefficient for the
    # probe. Positive defocus here refers to actual underfocus or weaker lens
    # strength following Kirkland/abtem/ptychoshelves convention.
    'defocus': 0,
    # type: float, unit: Ang. Spherical aberration coefficient (Cs) for the probe.
    'c3': 0,
    # type: float, unit: Ang. Slice thickness for multislice ptychography.
    # Typical values are between 1 to 20 Ang.
    'slice_thickness': 2,
    # type: int, unit: #. Number of slices for multislice object.
    'Nlayer': 6,
    # type: int, unit: #. Number of probe positions (or equivalently
    # diffraction patterns since 1 DP / position).
    'N_scans': 4096,
    # type: int, unit: #. Number of scan positions along the slow scan
    # direction. Usually it's the vertical direction of the acquisition GUI.
    'N_scan_slow': 64,
    # type: int, unit: #. Number of scan positions along the fast scan
    # direction. Usually it's the horizontal direction of the acquisition GUI.
    'N_scan_fast': 64,
    # type: float, unit: Ang. Step size between probe positions in a
    # rectangular raster scan pattern.
    'scan_step_size': 0.4290,
    # type: None or list of 3 binary booleans (0 or 1) as
    # [flipup, fliplr, transpose] just like PtychoShelves. Default value is
    # None or equivalently [0,0,0]. This applies additional flip and transpose
    # to initialized scan patterns. Note that modifying 'scan_flipT' would
    # change the image orientation, so it's recommended to set this to None,
    # and only use 'meas_flipT' to get the orientation correct.
    'scan_flipT': None,
    # type: None or list of 4 floats as [scale, asymmetry, rotation, shear]
    # just like PtychoShelves. Default is None or equivalently [1,0,0,0],
    # rotation and shear are in unit of degree. This applies additional affine
    # transformation to initialized scan patterns to correct sample drift and
    # imperfect scan coils.
    # Note: in py4dstem by default global_affine_transformation = False, if set
    # True, the transform is estimated using RANSAC instead of taking user input.
    # Rotation angle alone could be either estimated by COM, or set by
    # force_com_rotation = [].
    'scan_affine': None,
    # type: int, unit: #. Maximum number of mixed probe modes. Set to
    # pmode_max = 1 for single probe state, pmode_max > 1 for mixed-state probe
    # during initialization. For simulated initial object, it'll be generated
    # with the specified number of probe modes. For loaded probe, the pmode
    # dimension would be capped at this number.
    'pmode_max': 12,
    # type: None or list of ints. This applies additional permutation (reorder
    # axes) for the initialized diffraction patterns. The syntax is the same as
    # np.transpose().
    'meas_permute': None,
    # type: None or list of 4 ints. This applies additional reshaping
    # (rearrange elements) for the initialized diffraction patterns. The syntax
    # is the same as np.reshape(). This is commonly needed to convert the 4D
    # diffraction dataset (Ry,Rx,ky,kx) into 3D (N_scans,ky,kx).
    'meas_reshape': [64, 64, 128, 128],
    # type: None or list of 3 binary booleans (0 or 1) as
    # [flipup, fliplr, transpose] just like PtychoShelves. Default is None or
    # [0,0,0] but you may need to find the correct flip and transpose to match
    # your dataset configuration. This applies additional flip and transpose to
    # initialized diffraction patterns. It's suggested to use 'meas_flipT' to
    # correct the dataset orientation and this is the only orientation-related
    # value attached to output reconstruction folder name.
    'meas_flipT': [0, 0, 1],
    # type: None or (4,2) nested list of ints as
    # [[scan_slow_start, scan_slow_end], [scan_fast_start, scan_fast_end],
    #  [ky_start, ky_end], [kx_start, kx_end]].
    # This applies additional cropping to the 4D dataset in both real and
    # k-space. This is useful for reconstructing a subset of real-space probe
    # positions, or to crop the kMax of diffraction patterns. The syntax
    # follows conventional numpy indexing so the upper bound is not included.
    'meas_crop': None,
    # type: None or list of 2 floats as [ky_zoom, kx_zoom]. This applies
    # additional resampling of initialized diffraction patterns along ky and kx
    # directions. This is useful for changing the k-space sampling of
    # diffraction patterns. See scipy.ndimage.zoom for more details.
    'meas_resample': None,
    # Where the measured diffraction patterns come from: source type, file
    # path, and the dataset key inside that file.
    'measurements_params': {
        'source': 'hdf5',
        'path': 'data/paper/simu_tBL_WSe2/phonon_temporal_spatial_N4096_dp128.hdf5',
        'key': 'dp',
    },
}

# Reconstruction parameters: iteration / batching settings, output folder
# naming, and which results are written to disk.
recon_params = {
    # type: int. Total number of reconstruction iterations. 1 iteration means a
    # full pass of all selected diffraction patterns. Usually 20-50 iterations
    # can get 90% of the work done with a proper learning rate between 1e-3 to
    # 1e-4. For faster trials in hypertune mode, set 'NITER' to a smaller
    # number than your typical reconstruction to save time. Usually 10-20
    # iterations are enough for the hypertune parameters to show their relative
    # performance.
    'NITER': 20,
    # type: int. Number of diffraction patterns processed simultaneously to get
    # the gradient update. "Batch size" is commonly used in the machine
    # learning community, while it's called "grouping" in PtychoShelves. Batch
    # size has an effect on both convergence speed and final quality; usually a
    # smaller batch size leads to better final quality for iterative gradient
    # descent, but a smaller batch size would also lead to longer computation
    # time per iteration because the GPU isn't as utilized as with large batch
    # sizes (due to less GPU parallelism). Generally a batch size of 32 to 128
    # is used, although certain algorithms (like ePIE) would prefer a large
    # batch size that is equal to the dataset size for robustness. For an
    # extremely large object (or with a lot of object modes), you'll need to
    # reduce batch_size to save GPU memory as well.
    'BATCH_SIZE': 16,
    # type: None or int. Number of completed iterations before saving the
    # current reconstruction results (model, probe, object) and summary
    # figures. If 'SAVE_ITERS' is 50, it'll create an output reconstruction
    # folder and save the results and figures into it every 50 iterations. If
    # None, the output reconstruction folder would not be created and no
    # reconstruction results or summary figures would be saved. If
    # 'SAVE_ITERS' > 'NITER', it'll create the output reconstruction folder but
    # no results / figs would be saved. Typically we set 'SAVE_ITERS' to 50 for
    # reconstruction mode with 'NITER' around 200 to 500. For hypertune mode,
    # it's suggested to set 'SAVE_ITERS' to None and set 'collate_results' to
    # true to save disk space, while also providing a convenient way to check
    # the hypertune performance by the collated results.
    'SAVE_ITERS': 10,
    # type: float. Update step size, default is 0.5 but 0.1 is numerically more
    # stable for multislice.
    'update_step_size': 0.5,
    # type: str. Path and name of the main output directory. Ideally the
    # 'output_dir' keeps a series of reconstructions of the same materials
    # system or project. The PtyRAD results and figs will be saved into a
    # reconstruction-specific folder under 'output_dir'. The 'output_dir'
    # folder will be automatically created if it doesn't exist.
    'output_dir': 'output/paper/simu_tBL_WSe2/phonon_partial/',
    # type: boolean. Whether to prefix a date str to the reconstruction folder
    # or not. Set to true to automatically prefix a date str like '20240903_'
    # in front of the reconstruction folder name. Suggested value is true for
    # both reconstruction and hypertune modes. In hypertune mode, the date
    # string would be applied on the hypertune folder instead of the
    # reconstruction folder.
    'prefix_date': True,
    # type: str. Prefix this string to the reconstruction folder name. Note
    # that "_" will be automatically generated, and the attached str would be
    # after the date str if 'prefix_date' is true. In hypertune mode, the
    # prefix string would be applied on the hypertune folder instead of the
    # reconstruction folder.
    'prefix': '',
    # type: str. Postfix this string to the reconstruction folder name. Note
    # that "_" will be automatically generated. In hypertune mode, the postfix
    # string would be applied on the hypertune folder instead of the
    # reconstruction folder.
    'postfix': '',
    # type: list of strings. This list specifies the available results to save.
    # Available options are 'model', 'obja', 'objp', and 'probe'. 'model' is
    # the pytorch dict stored as hdf5 with '.pt' file extension. 'model'
    # contains optimizable tensors and metadata so that you can always refine
    # from it and load whatever optimizable tensors (object, probe, positions,
    # tilts) if you want to continue the reconstruction. It's similar to the
    # NiterXXX.mat from PtychoShelves. 'object' and 'probe' output the
    # reconstructed object and probe as '.tif'. If you don't want to save
    # anything, set 'SAVE_ITERS' to None. Suggested setting is to save
    # everything (i.e., ['model', 'obja', 'objp', 'probe']).
    'save_result': ['model', 'objp', 'probe'],
    # type: dict. This dict specifies which object output is saved by their
    # final dimension ('obj_dim'), whether to save the full or cropped FOV
    # ('FOV') of object, and whether to save the raw or normalized bit depth
    # version of object and probe. A comprehensive (but probably redundant)
    # saving option looks like
    # {'obj_dim': [2,3,4], 'FOV': ['full', 'crop'], 'bit': ['raw', '32', '16', '8']}.
    # 'obj_dim' takes a list of int, the int ranges between 2 to 4,
    # corresponding to 2D to 4D object output. Set 'obj_dim': [2] if you only
    # want the zsum from multislice ptychography. Suggested value is [2,3,4] to
    # save all possible output. 'FOV' takes a list of strings, the available
    # strings are either 'full' or 'crop'. Suggested value is 'crop' so the
    # lateral padded region of object is not saved. 'bit' takes a list of
    # strings, the available strings are 'raw', '32', '16', and '8'. 'raw' is
    # the original value range, while '32' normalizes the value from 0 to 1.
    # '16' and '8' will normalize the value from 0 to 65535 and 255
    # correspondingly. Default is '8' to save only the normalized 8bit result
    # for quick visualization. You can set it to ['raw', '8'] if you want to
    # keep the original float32 bit results with normalized 8bit results. These
    # postprocessing steps would postfix corresponding labels to the result
    # files.
    'result_modes': {'obj_dim': [2, 3], 'FOV': ['crop'], 'bit': ['8']},
    # type: None or float. Regularization constant for kz regularization in
    # multislice ptycho. This is ignored if using single slice, or when the
    # value is set to None.
    'kz_regularization_gamma': 1,
    # type: None or dict. This allows users to pass more detailed args to
    # ptycho.reconstruct().
    'recon_kwargs': None,
}

In [None]:
# Build the datacube from the settings in exp_params using the project helper.
# init_datacube returns two values; the second is rebound to exp_params, so
# every later cell sees whatever parameter dict the helper hands back
# (presumably updated while loading the measurements -- confirm in utils_CHL).
datacube, exp_params = init_datacube(exp_params)

### (Optional) Visualize the data

In [None]:
# py4DSTEM.show(
#     datacube.get_dp_mean(),
#     cmap = 'magma',
#     scaling = 'log',
#     figsize = (3,3),
# )

In [None]:
# probe_radius_pixels, probe_qx0, probe_qy0 = datacube.get_probe_size(
#     datacube.tree('dp_mean').data,
#     plot = True,
#     figsize = (3,3),
# )

# # Print the estimated center and probe radius
# print('Estimated probe center =', 'qx = %.2f, qy = %.2f' % (probe_qx0, probe_qy0), 'pixels')
# print('Estimated probe radius =', '%.2f' % probe_radius_pixels, 'pixels')

In [None]:
# # Make a virtual bright field and dark field image
# expand_BF = 2.0  # expand radius by 2 pixels to encompass the full center disk

# center = (probe_qx0, probe_qy0)
# radius_BF = probe_radius_pixels + expand_BF
# radii_DF = (probe_radius_pixels + expand_BF, 1e3)

# datacube.get_virtual_image(
#     mode = 'circle',
#     geometry = (center,radius_BF),
#     name = 'bright_field',
#     shift_center = False,
# )
# datacube.get_virtual_image(
#     mode = 'annulus',
#     geometry = (center,radii_DF),
#     name = 'dark_field',
#     shift_center = False,
# );

# # plot the virtual images
# py4DSTEM.show(
#     [
#         datacube.tree('bright_field'),
#         datacube.tree('dark_field'),               
#     ],
#     cmap='viridis',
#     ticks = False,
#     axsize=(3,3),
#     title=['Bright Field','Dark Field'],
# )

In [None]:
## Initialize py4dstem ptycho instance
# The project helper creates the ptycho object from the datacube and the
# experimental parameters.
ptycho = init_ptycho(datacube, exp_params)

# This is a required step. Both diagnostic plots are switched off here; the
# trailing ';' keeps the notebook from echoing the call's return value.
ptycho.preprocess(
    plot_center_of_mass=False,
    plot_rotation=False,
    # force_com_rotation=93,
);

In [None]:
## Reconstruct py4dstem ptycho and visualize the result

# Folder that receives the intermediate results saved during reconstruction.
output_path = make_output_folder(exp_params, recon_params)

# Keyword arguments for ptycho.reconstruct(), assembled from recon_params.
reconstruct_kwargs = {
    'num_iter': recon_params['NITER'],
    'reconstruction_method': 'gradient-descent',
    'max_batch_size': recon_params['BATCH_SIZE'],
    'step_size': recon_params['update_step_size'],  # Update step size, default is 0.5 but 0.1 is numerically more stable for multislice
    'reset': True,  # If True, previous reconstructions are ignored
    'progress_bar': False,  # If True, reconstruction progress is displayed
    'store_iterations': False,  # If True, reconstructed objects and probes are stored at each iteration.
    'save_iters': recon_params['SAVE_ITERS'],  # Added by CHL to save intermediate results
    'save_result': recon_params['save_result'],  # Added by CHL to save intermediate results
    'result_modes': recon_params['result_modes'],  # Added by CHL to save intermediate results
    'output_path': output_path,
    'kz_regularization_gamma': recon_params['kz_regularization_gamma'],
}

# 'recon_kwargs' is documented as the way to pass more detailed args to
# ptycho.reconstruct(), but it was never forwarded. Merge it here: None (the
# default) or a missing key adds nothing, and user-supplied entries override
# the values assembled above instead of raising a duplicate-keyword TypeError.
reconstruct_kwargs.update(recon_params.get('recon_kwargs') or {})

solver_start_t = time()

ptycho.reconstruct(**reconstruct_kwargs).visualize(
    # iterations_grid = 'auto'
);

solver_end_t = time()
print(f"py4DSTEM ptycho solver is finished in {parse_sec_to_time_str(solver_end_t - solver_start_t)}")

In [None]:
# Show the reconstructed object phase with one probe position marked on top.
plt.figure()

# Phase (complex angle) of the object, summed along its first axis
# (presumably the slice axis of the multislice object -- confirm).
object_phase_sum = np.angle(ptycho.object).sum(0)
plt.imshow(object_phase_sum, cmap='magma')

# Mark the position stored at index 0 (presumably the first scan position):
# column 1 is used as the horizontal coordinate, column 0 as the vertical one.
marker_x = ptycho._positions_px[0, 1]
marker_y = ptycho._positions_px[0, 0]
plt.scatter(x=marker_x, y=marker_y)

plt.show()