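# Figure - Convergence_same_iters

Compares the tBL-WSe2 reconstructions from PtyRAD, PtychoShelves, and py4DSTEM after the same number of iterations (200). The top row shows the slice-summed object phase, with zoom-in insets and the approximate total run time. The bottom row shows the log-magnitude FFT of the periodic (P+S) component of each phase image.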

In [None]:
import os

# Raw string: without it the Windows backslashes ("\w", "\p") are invalid escape
# sequences, which Python 3.12+ reports as a SyntaxWarning.
work_dir = r"H:\workspace\ptyrad"
os.chdir(work_dir)
print("Current working dir: ", os.getcwd())

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from numpy.fft import fft2
from numpy.fft import fftshift
import h5py

from ptyrad.data_io import load_hdf5, load_pt

In [None]:
def center_crop(image, crop_height, crop_width):
    """
    Center crops the last two axes of a 2D or 3D array (e.g., an image or an image stack).

    The spatial axes are always the LAST two axes, i.e. the input is (H, W) or (C, H, W);
    a leading axis is kept untouched.

    Args:
        image (numpy.ndarray): The input array to crop, shaped (H, W) or (C, H, W).
        crop_height (int): The desired height of the crop, 1 <= crop_height <= H.
        crop_width (int): The desired width of the crop, 1 <= crop_width <= W.

    Returns:
        numpy.ndarray: The cropped array (a slice of `image`), shaped (crop_height, crop_width)
        or (C, crop_height, crop_width). When (H - crop_height) is odd, the extra row is
        dropped from the bottom; likewise the extra column is dropped from the right.

    Raises:
        ValueError: If `image` is not 2D/3D, or a crop size is non-positive or larger than
            the corresponding input size.
    """
    if len(image.shape) not in (2, 3):
        raise ValueError("Input image must be a 2D or 3D array.")

    height, width = image.shape[-2:]

    # Negative sizes would otherwise yield a silently empty / meaningless slice
    if crop_height < 1 or crop_width < 1:
        raise ValueError("Crop size must be positive.")
    # Equal sizes are allowed (the crop is then the full extent along that axis)
    if crop_height > height or crop_width > width:
        raise ValueError("Crop size must not exceed the input image size.")

    # Floor division: any odd leftover pixel goes to the bottom/right margin
    start_y = (height - crop_height) // 2
    start_x = (width - crop_width) // 2

    return image[..., start_y:start_y + crop_height, start_x:start_x + crop_width]

In [None]:
def mfft2(im):
    """
    Periodic-plus-smooth (P+S) decomposition of the 2D FFT of an image.

    Splits fft2(im) into a periodic component P and a smooth component S such that
    P + S == fft2(im). S is obtained from a Poisson equation driven by the jumps between
    opposite image borders; P is used for artifact-reduced FFT display.

    Reference: Periodic Artifact Reduction in Fourier Transforms of Full Field Atomic
    Resolution Images, https://doi.org/10.1017/S1431927614014639

    Args:
        im (numpy.ndarray): 2D image of shape (rows, cols). Non-floating input (int/uint/bool)
            is promoted to float64 so the border differences cannot wrap around or truncate.

    Returns:
        tuple[numpy.ndarray, numpy.ndarray]: (P, S), complex arrays of shape (rows, cols) in
        unshifted FFT layout (apply fftshift for display). S[0, 0] is 0.

    Raises:
        ValueError: If `im` is not 2D.
    """
    im = np.asarray(im)
    if im.ndim != 2:
        raise ValueError(f"mfft2 expects a 2D image, got shape {im.shape}.")
    # Integer (especially unsigned) dtypes would make the differences below wrap around
    if not np.issubdtype(im.dtype, np.inexact):
        im = im.astype(np.float64)

    rows, cols = im.shape

    # Boundary image: the jumps across opposite edges (the source of non-periodicity)
    row_jump = im[0, :] - im[-1, :]
    col_jump = im[:, 0] - im[:, -1]
    s = np.zeros_like(im)
    s[0, :] = row_jump
    s[-1, :] = -row_jump
    s[:, 0] += col_jump
    s[:, -1] -= col_jump

    # (Negated) eigenvalues of the periodic 5-point discrete Laplacian, built by
    # broadcasting a column of row frequencies against a row of column frequencies
    freq_y = 2 * np.pi * np.arange(rows)[:, None] / rows
    freq_x = 2 * np.pi * np.arange(cols)[None, :] / cols
    D = 2 * (2 - np.cos(freq_x) - np.cos(freq_y))
    D[0, 0] = np.inf  # zero-frequency eigenvalue is 0: force S[0, 0] = 0 (zero mean) and avoid 0/0

    # Smooth component from the Poisson equation with the boundary image as source
    S = np.fft.fft2(s) / D

    P = np.fft.fft2(im) - S  # FFT of periodic component
    return P, S

In [None]:
# Paths to the 200-iteration reconstructions from each package. Raw strings keep the
# Windows backslashes literal; otherwise "\w", "\p", "\o", ... are invalid escape
# sequences (SyntaxWarning on Python 3.12+).
path_ptyrad   = r"H:\workspace\ptyrad\output\paper/tBL_WSe2/20250131_ptyrad_convergence/full_N16384_dp128_flipT100_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr5e-4_orblur0.5_ozblur1_oathr0.98_opos_sng1.0_spr0.1_aff1_0_-3_0/model_iter0200.pt"
path_ptyshv   = r"H:\workspace\ptyrad\data\paper/tBL_WSe2\Panel_g-h_Themis/10/roi10_Ndp128_step128\MLs_ptyrad_p12_g16_pc0_noModel_updW100_mm_Ns6_dz2_reg1_dpFlip_ud_T/Niter200.mat"
path_py4dstem = r"H:\workspace\ptyrad\output\paper/tBL_WSe2/20250124_py4DSTEM_convergence/N16384_dp128_flipT100_random16_p12_6slice_dz2_update0.5_kzf1/model_iter0200.hdf5"

crop_size = 384  # side length (px) of the centered field of view compared across packages

# PtyRAD: load the checkpoint once and reuse it for both the object and the timing.
# The phase 'objp' is summed over axis 0 (presumably the slice axis — TODO confirm).
ckpt_ptyrad = load_pt(path_ptyrad)
obj_ptyrad = center_crop(ckpt_ptyrad['optimizable_tensors']['objp'].squeeze().cpu().numpy().sum(0), crop_size, crop_size)
iter_time_ptyrad = ckpt_ptyrad['avg_iter_t'] # Note that the iter_t on c0001 for ptyrad seems to fluctuate quite alot from 9-13sec, this one is 12.6sec
del ckpt_ptyrad  # drop the reference to the rest of the checkpoint

# PtychoShelves: 'object' is reinterpreted as complex128 (presumably stored as float64
# real/imag pairs); the summed phase is transposed, presumably to undo MATLAB's
# column-major layout — TODO confirm orientation against the other packages.
with h5py.File(path_ptyshv, "r") as hdf_file:
    obj_ptyshv = center_crop(np.angle(np.array(hdf_file['object']).view('complex128')).sum(0).T, crop_size, crop_size)
    iter_time_ptyshv = hdf_file['outputs']['avgTimePerIter'][()].squeeze()[()]

# py4DSTEM: 'object' reinterpreted as complex64; iteration time is the mean of 'iter_times'
with h5py.File(path_py4dstem, "r") as hdf_file:
    obj_py4dstem = center_crop(np.angle(np.array(hdf_file['object']).view('complex64')).sum(0), crop_size, crop_size)
    iter_time_py4dstem = load_hdf5(path_py4dstem, dataset_key='iter_times').mean()


def log_fft_magnitude(img):
    """Return log|FFT| of the periodic component of `img` (mfft2 P+S decomposition), fftshift-ed for display."""
    return np.log(np.abs(fftshift(mfft2(img)[0])))


# P+S decomposition FFT, note that this requires a float32 dtype.
# For the plain FFT instead, use np.log(np.abs(fftshift(fft2(img)))).
fft_ptyrad   = log_fft_magnitude(obj_ptyrad)
fft_ptyshv   = log_fft_magnitude(obj_ptyshv)
fft_py4dstem = log_fft_magnitude(obj_py4dstem)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm

# Create figure with GridSpec: top row = slice-summed object phase, bottom row = log|FFT|
fig = plt.figure(figsize=(7, 5), dpi=300)
gs = GridSpec(2, 3, figure=fig, wspace=0.025, hspace=-0.095)  # Control spacing (negative hspace pulls the rows together)

# Per-column inputs, in column order
titles = ['PtyRAD', 'PtychoShelves', 'py4DSTEM']
objs = [obj_ptyrad, obj_ptyshv, obj_py4dstem]
ffts = [fft_ptyrad, fft_ptyshv, fft_py4dstem]
iter_times = [iter_time_ptyrad, iter_time_ptyshv, iter_time_py4dstem]  # average time per iteration (sec)

# Approximate total time = average time per iteration x number of iterations, rounded to the
# nearest second. float() also accepts NumPy scalars, 0-d arrays, and 0-d torch tensors.
iter_value = 200
iterations_labels = [f'{iter_value} iters\n' + r'$\approx$' + f'{float(t) * iter_value:.0f} sec'
                     for t in iter_times]

panel_labels = ['a', 'b', 'c', 'd', 'e', 'f']
shadow_offset = (3, 3)  # drop-shadow offset in data (pixel) units
zoom_region = (270, 320, 65, 15)  # (x1, x2, y1, y2) in pixels, shared by all three panels; y1 > y2 matches origin='upper'
fontprops = fm.FontProperties(size=10)  # scale-bar label font, used by both rows

# Scale bar lengths in pixels. 134 px = 2 nm implies ~0.149 Å/px; for the 384-px crop the FFT
# sampling is then ~0.0175 Å^-1/px, so 115 px ≈ 2 Å^-1 (consistent). Update both together if the
# pixel size or crop size changes. Raw string: "\m" and "\A" are not valid Python escapes.
real_bar_px, real_bar_label = 134, "2 nm"
fft_bar_px, fft_bar_label = 115, r"2 $\mathrm{\AA}^{-1}$"


def add_shadowed_text(ax, x, y, text, fontsize, ha, va):
    """Draw bold white `text` at (x, y) on top of a semi-transparent black copy offset by `shadow_offset`."""
    ax.text(x + shadow_offset[0], y + shadow_offset[1], text, color='black', fontsize=fontsize, fontweight='bold', va=va, ha=ha, alpha=0.6)
    ax.text(x, y, text, color='white', fontsize=fontsize, fontweight='bold', va=va, ha=ha)


def add_scale_bar(ax, length_px, label):
    """Attach a white, frameless scale bar of `length_px` data units, labelled on top, in the lower-right corner."""
    scalebar = AnchoredSizeBar(ax.transData, length_px, label,
                               loc='lower right', pad=0.5, color="white", frameon=False, size_vertical=3, label_top=True,
                               fontproperties=fontprops)
    ax.add_artist(scalebar)


# Top row (real-space images)
for i, (obj, title, label, iters) in enumerate(zip(objs, titles, panel_labels[:3], iterations_labels)):
    ax = fig.add_subplot(gs[0, i])

    # Main image, contrast clipped to the 1st-99.95th percentile
    vmin, vmax = np.percentile(obj, [1, 99.95])
    ax.imshow(obj, cmap='gray', origin='upper', vmin=vmin, vmax=vmax)
    ax.set_title(title, fontsize=12)
    ax.axis('off')

    add_shadowed_text(ax, 5, 5, label, fontsize=16, ha='left', va='top')        # panel label
    add_shadowed_text(ax, 102, 310, iters, fontsize=10, ha='center', va='top')  # iterations / time
    add_scale_bar(ax, real_bar_px, real_bar_label)

    # Zoom-in inset, same display range as the main image
    x1, x2, y1, y2 = zoom_region
    axins = ax.inset_axes([0.6, 0.45, 0.35, 0.35], xlim=(x1, x2), ylim=(y1, y2))
    axins.imshow(obj, cmap='gray', origin='upper', vmin=vmin, vmax=vmax)
    axins.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    for spine in axins.spines.values():
        spine.set_edgecolor('darkred')
    mark_inset(ax, axins, loc1=1, loc2=2, fc="none", ec="darkred")

# Bottom row (FFTs)
for i, (fft_img, label) in enumerate(zip(ffts, panel_labels[3:])):
    ax = fig.add_subplot(gs[1, i])

    # Main FFT image, contrast clipped to the 5th-99.5th percentile
    vmin, vmax = np.percentile(fft_img, [5, 99.5])
    ax.imshow(fft_img, cmap='magma', vmin=vmin, vmax=vmax)
    ax.axis('off')

    add_shadowed_text(ax, 5, 5, label, fontsize=16, ha='left', va='top')  # panel label
    add_scale_bar(ax, fft_bar_px, fft_bar_label)

# Show the plot
plt.show()
