In [None]:
%load_ext autoreload
%autoreload 2

import py4DSTEM
import numpy as np
import matplotlib.pyplot as plt
import h5py
import os

In [None]:
# Move into the PtyRAD workspace so relative paths used later (e.g. recon_params['output_dir']) resolve there.
# Raw string: in a normal literal "\w" and "\p" are invalid escape sequences (SyntaxWarning on Python >= 3.12).
work_dir = r"H:\workspace\ptyrad"
os.chdir(work_dir)
print("Current working dir: ", os.getcwd())

In [None]:
def initialize_ptycho(datacube, exp_params, device='gpu', storage='cpu'):
    """Initialize a py4DSTEM mixed-state (optionally multislice) ptychography object.

    The object is padded so that its field of view is 1.2x (scan extent + one
    diffraction-pattern width); the padding is computed per axis in object pixels.

    Parameters
    ----------
    datacube : py4DSTEM.DataCube
        Calibrated 4D-STEM datacube.
    exp_params : dict
        Experimental parameters. Keys read here: 'Npix', 'N_scan_fast',
        'N_scan_slow', 'dx_spec', 'scan_step_size', 'Nlayer', 'pmode_max',
        'kv' (kV), 'defocus' (Ang), 'conv_angle' (mrad), and 'z_distance'
        (Ang, only when Nlayer != 1).
    device : str, optional
        Compute device forwarded to py4DSTEM (default 'gpu').
    storage : str, optional
        Storage device forwarded to py4DSTEM (default 'cpu').

    Returns
    -------
    py4DSTEM.process.phase.MixedstatePtychography when exp_params['Nlayer'] == 1,
    otherwise py4DSTEM.process.phase.MixedstateMultislicePtychography.
    The returned object has not been preprocessed yet.
    """
    Npix = exp_params['Npix']
    N_scan_fast = exp_params['N_scan_fast']
    N_scan_slow = exp_params['N_scan_slow']
    dx_spec = exp_params['dx_spec']
    scan_step_size = exp_params['scan_step_size']

    # Scan extent in object pixels, ordered (slow, fast).
    pos_extent = np.array([N_scan_slow, N_scan_fast]) * scan_step_size / dx_spec
    # 20% margin over the scanned area plus one probe window.
    object_extent = 1.2 * (pos_extent + Npix)
    # Floor division already drops the fraction; cast to plain ints for py4DSTEM.
    object_padding_px = tuple(int(p) for p in (object_extent - pos_extent) // 2)
    print(f"pos_extent = {pos_extent} px, object_extent = {object_extent}, object_padding_px = {object_padding_px}")

    # Arguments shared by the single-slice and multislice constructors.
    common_kwargs = dict(
        datacube=datacube,
        num_probes=exp_params['pmode_max'],
        verbose=True,
        energy=exp_params['kv'] * 1e3,  # energy in eV
        defocus=exp_params['defocus'],  # defocus guess in A
        semiangle_cutoff=exp_params['conv_angle'],
        object_padding_px=object_padding_px,
        device=device,
        storage=storage,
    )

    if exp_params['Nlayer'] == 1:
        print("Initializing MixedstatePtychography")
        return py4DSTEM.process.phase.MixedstatePtychography(**common_kwargs)

    print("Initializing MixedstateMultislicePtychography")
    return py4DSTEM.process.phase.MixedstateMultislicePtychography(
        num_slices=exp_params['Nlayer'],
        slice_thicknesses=exp_params['z_distance'],
        **common_kwargs,
    )

In [None]:
# Note: py4DSTEM does not accept a user-supplied scan affine transformation; when enabled it estimates one itself,
# so 'scan_affine' (and 'scan_flipT', 'c3') below are kept for reference only and are not read by this notebook.

exp_params = {
    'kv'                : 80, # type: float, unit: kV. Acceleration voltage for relativistic electron wavelength calculation
    'conv_angle'        : 24.9, # type: float, unit: mrad. Semi-convergence angle for probe-forming aperture
    'Npix'              : 128, # type: integer, unit: px (k-space). Detector pixel number, EMPAD is 128. Only supports square detector for simplicity
    'dx_spec'           : 0.1494, # type: float, unit: Ang. Real space pixel size calibration at specimen plane (object, probe, and probe positions share the same pixel size)
    'defocus'           : 0, # type: float, unit: Ang. Defocus (-C1) aberration coefficient for the probe. Positive defocus here refers to actual underfocus or weaker lens strength following Kirkland/abtem/ptychoshelves convention
    'c3'                : 0, # type: float, unit: Ang. Spherical aberration coefficient (Cs) for the probe. Not passed to py4DSTEM by this notebook
    'z_distance'        : 12, # type: float, unit: Ang. Slice thickness for multislice ptychography, only used when Nlayer != 1. Typical values are between 1 to 20 Ang
    'Nlayer'            : 1, # type: int, unit: #. Number of slices for multislice object
    'N_scans'           : 16384, # type: int, unit: #. Number of probe positions (or equivalently diffraction patterns since 1 DP / position). Must equal N_scan_slow * N_scan_fast
    'N_scan_slow'       : 128, # type: int, unit: #. Number of scan positions along slow scan direction. Usually it's the vertical direction of acquisition GUI
    'N_scan_fast'       : 128, # type: int, unit: #. Number of scan positions along fast scan direction. Usually it's the horizontal direction of acquisition GUI
    'scan_step_size'    : 0.4290, # type: float, unit: Ang. Step size between probe positions in a rectangular raster scan pattern
    'scan_flipT'        : None, # type: None or list of 3 binary booleans (0 or 1) as [flipup, fliplr, transpose] just like PtychoShelves. Default value is None or equivalently [0,0,0]. This applies additional flip and transpose to initialized scan patterns. Note that modifying 'scan_flipT' would change the image orientation, so it's recommended to set this to None, and only use 'meas_flipT' to get the orientation correct. Not read by this notebook
    'scan_affine'       : None, # type: None or list of 4 floats as [scale, asymmetry, rotation, shear] just like PtychoShelves. Default is None or equivalently [1,0,0,0], rotation and shear are in unit of degree. This applies additional affine transformation to initialized scan patterns to correct sample drift and imperfect scan coils. Not read by this notebook
        # Note: in py4dstem by default global_affine_transformation = False; if set True, the transform is estimated using RANSAC instead of taking user input.
        # The rotation angle alone can be either estimated by COM, or set by force_com_rotation = [].
    'pmode_max'         : 6, # type: int, unit: #. Maximum number of mixed probe modes. Set to pmode_max = 1 for single probe state, pmode_max > 1 for mixed-state probe during initialization. For simulated initial object, it'll be generated with the specified number of probe modes. For loaded probe, the pmode dimension would be capped at this number
    'meas_flipT'        : [1,0,0], # type: None or list of 3 binary booleans (0 or 1) as [flipup, fliplr, transpose] just like PtychoShelves. Default is None or [0,0,0] but you may need to find the correct flip and transpose to match your dataset configuration. This applies additional flip and transpose to initialized diffraction patterns. It's suggested to use 'meas_flipT' to correct the dataset orientation and this is the only orientation-related value attached to output reconstruction folder name
}

recon_params = {
    'NITER': 50, # type: int. Total number of reconstruction iterations. 1 iteration means a full pass of all selected diffraction patterns. Usually 20-50 iterations can get 90% of the work done with a proper learning rate between 1e-3 to 1e-4. For faster trials in hypertune mode, set 'NITER' to a smaller number than your typical reconstruction to save time. Usually 10-20 iterations are enough for the hypertune parameters to show their relative performance. 
    'BATCH_SIZE': 32, # type: int. Number of diffraction patterns processed simultaneously to get the gradient update. "Batch size" is commonly used in machine learning community, while it's called "grouping" in PtychoShelves. Batch size has an effect on both convergence speed and final quality, usually smaller batch size leads to better final quality for iterative gradient descent, but smaller batch size would also lead to longer computation time per iteration because the GPU isn't as utilized as large batch sizes (due to less GPU parallelism). Generally batch size of 32 to 128 is used, although certain algorithms (like ePIE) would prefer a large batch size that is equal to the dataset size for robustness. For extremely large object (or with a lot of object modes), you'll need to reduce batch_size to save GPU memory as well.
    'SAVE_ITERS': 10,  # type: None or int. Number of completed iterations before saving the current reconstruction results (model, probe, object) and summary figures. If 'SAVE_ITERS' is 50, it'll create an output reconstruction folder and save the results and figures into it every 50 iterations. If None, the output reconstruction folder would not be created and no reconstruction results or summary figures would be saved. If 'SAVE_ITERS' > 'NITER', it'll create the output reconstruction folder but no results / figs would be saved. Typically we set 'SAVE_ITERS' to 50 for reconstruction mode with 'NITER' around 200 to 500. For hypertune mode, it's suggested to set 'SAVE_ITERS' to None and set 'collate_results' to true to save the disk space, while also providing a convenient way to check the hypertune performance by the collated results.
    'output_dir': 'output/paper/tBL_WSe2', # type: str. Path and name of the main output directory. Ideally the 'output_dir' keeps a series of reconstructions of the same materials system or project. The PtyRAD results and figs will be saved into a reconstruction-specific folder under 'output_dir'. The 'output_dir' folder will be automatically created if it doesn't exist.
    'prefix_date': True, # type: boolean. Whether to prefix a date str to the reconstruction folder or not. Set to true to automatically prefix a date str like '20240903_' in front of the reconstruction folder name. Suggested value is true for both reconstruction and hypertune modes. In hypertune mode, the date string would be applied on the hypertune folder instead of the reconstruction folder. 
    'prefix': '', # type: str. Prefix this string to the reconstruction folder name. Note that "_" will be automatically generated, and the attached str would be after the date str if 'prefix_date' is true. In hypertune mode, the prefix string would be applied on the hypertune folder instead of the reconstruction folder.  
    'postfix': '', # type: str. Postfix this string to the reconstruction folder name. Note that "_" will be automatically generated. In hypertune mode, the postfix string would be applied on the hypertune folder instead of the reconstruction folder.  
}

# N_scans goes into the output folder name; it must describe the same raster used to reshape the data.
assert exp_params['N_scans'] == exp_params['N_scan_slow'] * exp_params['N_scan_fast'], \
    "exp_params['N_scans'] must equal N_scan_slow * N_scan_fast"

In [None]:
# Use forward slashes only: a backslash such as "\p" starts an invalid escape sequence (SyntaxWarning on Python >= 3.12).
data_dir = 'H:/workspace/ptyrad/data/'
data_path = os.path.join(data_dir, 'tBL_WSe2/Fig_1h_24.9mrad_Themis/1/data_roi1_Ndp128_step128_dp.hdf5')

# Read the diffraction patterns into memory and release the HDF5 handle immediately;
# everything below works on the in-memory array.
with h5py.File(data_path, 'r') as f:
    meas = np.array(f['dp'])

# Flip the measurements. meas_flipT = None is documented as equivalent to [0,0,0].
# Axes 1 and 2 are the detector axes (they become the last two axes of the 4D reshape below),
# so flipup acts on axis 1, fliplr on axis 2, and transpose swaps them.
flipT_axes = exp_params['meas_flipT']
if flipT_axes is None:
    flipT_axes = [0, 0, 0]
print(f"Flipping measurements with [flipup, fliplr, transpose] = {flipT_axes}")
if flipT_axes[0] != 0:
    meas = np.flip(meas, 1)
if flipT_axes[1] != 0:
    meas = np.flip(meas, 2)
if flipT_axes[2] != 0:
    meas = np.transpose(meas, (0, 2, 1))

# Reshape to (N_scan_slow, N_scan_fast, Npix, Npix)
dataset = np.reshape(meas, [exp_params['N_scan_slow'], exp_params['N_scan_fast'], exp_params['Npix'], exp_params['Npix']])

# Calibrate py4DSTEM datacube: real-space pixel = scan step (A),
# reciprocal-space pixel = 1 / (dx_spec * Npix) (A^-1).
datacube = py4DSTEM.DataCube(dataset)
datacube.calibration.set_R_pixel_size(exp_params['scan_step_size'])
datacube.calibration.set_R_pixel_units('A')
datacube.calibration.set_Q_pixel_size(1/(exp_params['dx_spec']*exp_params['Npix']))
datacube.calibration.set_Q_pixel_units('A^-1')

### (Optional) Visualize the data

In [None]:
# py4DSTEM.show(
#     datacube.get_dp_mean(),
#     cmap = 'magma',
#     scaling = 'log',
#     figsize = (3,3),
# )

In [None]:
# probe_radius_pixels, probe_qx0, probe_qy0 = datacube.get_probe_size(
#     datacube.tree('dp_mean').data,
#     plot = True,
#     figsize = (3,3),
# )

# # Print the estimated center and probe radius
# print('Estimated probe center =', 'qx = %.2f, qy = %.2f' % (probe_qx0, probe_qy0), 'pixels')
# print('Estimated probe radius =', '%.2f' % probe_radius_pixels, 'pixels')

In [None]:
# # Make a virtual bright field and dark field image
# expand_BF = 2.0  # expand radius by 2 pixels to encompass the full center disk

# center = (probe_qx0, probe_qy0)
# radius_BF = probe_radius_pixels + expand_BF
# radii_DF = (probe_radius_pixels + expand_BF, 1e3)

# datacube.get_virtual_image(
#     mode = 'circle',
#     geometry = (center,radius_BF),
#     name = 'bright_field',
#     shift_center = False,
# )
# datacube.get_virtual_image(
#     mode = 'annulus',
#     geometry = (center,radii_DF),
#     name = 'dark_field',
#     shift_center = False,
# );

# # plot the virtual images
# py4DSTEM.show(
#     [
#         datacube.tree('bright_field'),
#         datacube.tree('dark_field'),               
#     ],
#     cmap='viridis',
#     ticks = False,
#     axsize=(3,3),
#     title=['Bright Field','Dark Field'],
# )

In [None]:
## Initialize py4dstem ptycho instance
ptycho = initialize_ptycho(datacube, exp_params)
ptycho.preprocess(
    plot_center_of_mass = False,
    plot_rotation=False,
    # force_com_rotation = 93,
);

In [None]:
def get_date(date_format = '%Y%m%d'):
    """Return today's local date as a string.

    Parameters
    ----------
    date_format : str, optional
        ``strftime`` format string; the default produces e.g. '20240903'.

    Returns
    -------
    str
        ``datetime.date.today()`` formatted with ``date_format``.
    """
    from datetime import date
    return date.today().strftime(date_format)

# Build the reconstruction output folder:
#   output_dir/[date_][prefix_]N{N_scans}_dp{Npix}[_flipT{abc}]_random{batch}_p{pmode}_{Nlayer}slice[_dz{dz}][_postfix]

# Preprocess prefix and postfix: non-empty values get a "_" separator.
prefix = recon_params['prefix']
postfix = recon_params['postfix']
prefix = prefix + '_' if prefix != '' else ''
postfix = '_' + postfix if postfix != '' else ''

if recon_params['prefix_date']:
    prefix = get_date() + '_' + prefix

# Append basic parameters to folder name
output_dir = recon_params['output_dir']
meas_flipT = exp_params['meas_flipT']
folder_str = prefix + f"N{(exp_params['N_scans'])}_dp{exp_params['Npix']}"

if meas_flipT is not None:
    folder_str = folder_str + '_flipT' + ''.join(str(x) for x in meas_flipT)

folder_str += f"_random{recon_params['BATCH_SIZE']}_p{exp_params['pmode_max']}_{exp_params['Nlayer']}slice"

if exp_params['Nlayer'] != 1:
    # Built-in round() handles plain int/float as well as numpy scalars;
    # int and float have no `.round()` method.
    z_distance = round(exp_params['z_distance'], 2)
    folder_str += f"_dz{z_distance:.3g}"

output_path = os.path.join(output_dir, folder_str)
output_path += postfix
os.makedirs(output_path, exist_ok=True)
print(f"output_path = '{output_path}' is generated!")

In [None]:
## Reconstruct py4dstem ptycho and visualize the result
ptycho.reconstruct(
    num_iter = 1, #recon_params['NITER'],
    reconstruction_method = 'gradient-descent',
    max_batch_size = recon_params['BATCH_SIZE'],
    step_size = 0.8, # Update step size, default is 0.5
    reset = True, # If True, previous reconstructions are ignored
    progress_bar = False, # If True, reconstruction progress is displayed
    store_iterations = False, # If True, reconstructed objects and probes are stored at each iteration.
    save_iters = recon_params['SAVE_ITERS'], # Added by CHL to save intermediate results
    output_path = output_path
).visualize(
    # iterations_grid = 'auto'
);

In [None]:
# Quick look at the reconstructed object phase with the first probe position overlaid.
obj_phase = np.angle(ptycho.object)
if obj_phase.ndim == 3:
    # Multislice object: assumed (num_slices, H, W) — TODO confirm; show the summed (projected) phase.
    obj_phase = obj_phase.sum(axis=0)

fig, ax = plt.subplots()
ax.imshow(obj_phase, cmap='magma')
# Column 0 of _positions_px is plotted as the vertical (row) coordinate to match imshow.
ax.scatter(x=ptycho._positions_px[0, 1], y=ptycho._positions_px[0, 0])
ax.set(title='Object phase (first probe position marked)', xlabel='x (px)', ylabel='y (px)')
plt.show()