# Figure - Iteration time comparison


In [None]:
import os
# Make the PtyRAD checkout the current working directory and report it
# (presumably so the `ptyrad` imports in later cells resolve -- confirm).
# The path is a raw string: in a normal literal the backslash sequences in
# this Windows path are invalid escapes, which emit a SyntaxWarning on
# Python 3.12+ and are slated to become errors. The runtime value is unchanged.
work_dir = r"H:\workspace\ptyrad"
os.chdir(work_dir)
print("Current working dir: ", os.getcwd())

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ptyrad.data_io import load_hdf5, load_pt
import h5py

In [None]:
# Normalize error
def norm_arr(arr):
    """Rescale values so the maximum maps to 1 and the last entry maps to 0.

    Computes ``(arr - arr[-1]) / (arr.max() - arr[-1])``.

    Args:
        arr: Array or array-like (list, tuple, ...) of values. It is passed
            through ``np.asarray`` so plain Python sequences are accepted;
            existing NumPy arrays are used as-is without copying.

    Returns:
        np.ndarray: The rescaled values, same shape as the input.

    Note:
        If the maximum equals the last entry the denominator is zero and the
        result contains ``nan``/``inf`` (NumPy emits a RuntimeWarning); this
        case is deliberately not masked.
    """
    values = np.asarray(arr)
    final_value = values[-1]
    return (values - final_value) / (values.max() - final_value)

In [None]:
# Note that these data for "iter vs error" are all computed on 80GB A100
#
# The active Windows paths below are raw strings: they mix backslashes and
# forward slashes, and in a normal literal the backslash sequences are invalid
# escapes (SyntaxWarning on Python 3.12+). The runtime values are unchanged.

# --- PtyRAD ---
# #45368 # path_ptyrad = "H:/workspace/ptyrad/output/paper/tBL_WSe2/20241211_ptyrad_convergence/full_N16384_dp128_flipT100_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr5e-4_orblur0.2_ozblur1_oathr0.98_opos_sng1.0_spr0.1_aff1_0_-3_0/model_iter0200.pt"
# #46331 # path_ptyrad = 'H:\workspace\ptyrad\output\paper/tBL_WSe2/20250122_ptyrad_convergence/full_N16384_dp128_flipT100_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr5e-4_kzf1_sng1.0_aff1_0_-3_0_no_reg/model_iter0200.pt' 
path_ptyrad   = r"H:\workspace\ptyrad\output\paper/tBL_WSe2/20250131_ptyrad_convergence/full_N16384_dp128_flipT100_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr5e-4_orblur0.5_ozblur1_oathr0.98_opos_sng1.0_spr0.1_aff1_0_-3_0/model_iter0200.pt"
# Read the checkpoint once and pull both fields from it instead of loading
# the same file from disk twice.
ckpt_ptyrad = load_pt(path_ptyrad)
# Column 1 of 'loss_iters' is used as the error curve
# (presumably rows are (iteration, loss) pairs -- confirm against ptyrad).
error_ptyrad   = np.array(ckpt_ptyrad['loss_iters'])[:,1]
iter_time_ptyrad = ckpt_ptyrad['avg_iter_t']
del ckpt_ptyrad  # only the two extracted fields are needed afterwards

# --- PtyShv ---
# #45367
path_ptyshv = r"H:\workspace\ptyrad\data\paper/tBL_WSe2\Panel_g-h_Themis/10/roi10_Ndp128_step128\MLs_ptyrad_p12_g16_pc0_noModel_updW100_mm_Ns6_dz2_reg1_dpFlip_ud_T/Niter200.mat"
with h5py.File(path_ptyshv, "r") as hdf_file:
    error_ptyshv = hdf_file['outputs']['fourier_error_out'][()].squeeze()
    # The trailing [()] unwraps the squeezed value
    # (assumes the dataset squeezes to a 0-d array -- confirm).
    iter_time_ptyshv = hdf_file['outputs']['avgTimePerIter'][()].squeeze()[()]

# --- py4DSTEM ---
# #45369
# path_py4dstem = "H:/workspace/ptyrad/output/paper/tBL_WSe2/20241211_py4DSTEM_convergence/N16384_dp128_flipT100_random16_p12_6slice_dz2_update0.1_kzf1_archive/model_iter0200.hdf5"
# #46425, 46426 contains update_step from 0.2 to 1.0 in 0.1 increment
path_py4dstem = r"H:\workspace\ptyrad\output\paper/tBL_WSe2/20250124_py4DSTEM_convergence/N16384_dp128_flipT100_random16_p12_6slice_dz2_update0.5_kzf1/model_iter0200.hdf5"
error_py4dstem = load_hdf5(path_py4dstem, dataset_key='error_iterations')
# Average of the stored 'iter_times' values is used as the per-iteration time.
iter_time_py4dstem = load_hdf5(path_py4dstem, dataset_key='iter_times').mean()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import matplotlib.ticker as tck

# Apply the same tick styling to the x and y axes: inward-pointing ticks,
# 1-wide tick lines, major ticks of size 3.5 and minor ticks of size 2.
for tick_axis in ('xtick', 'ytick'):
    mpl.rc(tick_axis, direction='in')
    mpl.rc(f'{tick_axis}.major', width=1, size=3.5)
    mpl.rc(f'{tick_axis}.minor', width=1, size=2)

# Plot Data
# Note that these data are all computed on 20GB slice

# Convergence curves: iteration indices 0..199 and the error of each package
# rescaled with norm_arr.
iterations = np.arange(200) #200
norm_error_ptyrad, norm_error_ptyshv, norm_error_py4dstem = (
    norm_arr(raw_error)
    for raw_error in (error_ptyrad, error_ptyshv, error_py4dstem)
)

# Iteration time (sec) for each batch size
batch_sizes = np.array([16, 32, 64, 128, 256, 512, 1024])
time_ptyrad_batch = np.array(
    [19.0, 16.45, 15.18, 14.37, 13.91, 13.84, 13.71]
)
time_ptyshv_batch = np.array(
    [84.08, 42.31, 23.47, 17.18, 13.02, 11.87, 11.15]
)
time_py4dstem_batch = np.array(
    [145.57, 75.85, 43.40, 23.20, 15.47, 12.32, 11.51]
)

# Iteration time (sec) for each number of probe modes
probe_modes = np.array([1, 3, 6, 12])
time_ptyrad_modes = np.array([12.87, 14.18, 19.00, 26.46])
time_ptyshv_modes = np.array([19.07, 45.46, 84.08, 161.49])
time_py4dstem_modes = np.array([29.70, 74.71, 145.57, 282.91])

# Iteration time (sec) for each number of slices
num_slices = np.array([1, 3, 6])
time_ptyrad_slices = np.array([7.09, 11.34, 19.00])
time_ptyshv_slices = np.array([22.24, 44.38, 84.08])
time_py4dstem_slices = np.array([24.26, 82.01, 145.57])

# Speedup factor: py4DSTEM time divided by PtyRAD time.
# Batch uses the first entry (batch size 16); modes and slices use the last
# entry (12 probe modes, 6 slices).
speedup_factor_iter = iter_time_py4dstem/iter_time_ptyrad
speedup_factor_batch = time_py4dstem_batch[0] / time_ptyrad_batch[0]
speedup_factor_modes = time_py4dstem_modes[-1] / time_ptyrad_modes[-1]
speedup_factor_slices = time_py4dstem_slices[-1] / time_ptyrad_slices[-1]

# Global font/line control (shared by all four panels)
linewidth, markersize = 0.8, 4
fontsize_title = fontsize_label = 9
fontsize_subtitle = 7
fontsize_legend = 5

# Create subplots: a 2x2 grid of axes on a 7x5 inch, 300 dpi figure
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(7, 5), dpi=300)

# Panel 1: Error vs. Iteration, drawn against elapsed reconstruction time
step = 10  # Select a subset of the data for plotting
ax_a = axes[0, 0]
sampled_iters = iterations[::step]
# x = iteration index * average time per iteration of each package.
# (Pass `sampled_iters` alone as x to plot against the iteration index instead.)
for curve_label, curve_marker, sec_per_iter, curve_error in (
    ('PtyRAD', 'o', iter_time_ptyrad, norm_error_ptyrad),
    ('PtyShv', 's', iter_time_ptyshv, norm_error_ptyshv),
    ('py4DSTEM', '^', iter_time_py4dstem, norm_error_py4dstem),
):
    ax_a.plot(
        sampled_iters * sec_per_iter,
        curve_error[::step],
        label=curve_label,
        marker=curve_marker,
        linewidth=linewidth,
        markersize=markersize,
    )

ax_a.set_title('Normalized Data Error vs. Reconstruction Time', fontsize=fontsize_title)
# Subtitle with the reconstruction settings, placed in axes coordinates
ax_a.text(
    0.45, 0.95, '200 iterations, batch size 16 \n12 probes, 6 slices',
    transform=ax_a.transAxes, ha='center', va='top',
    fontsize=fontsize_subtitle, color='k',
)
ax_a.set_xlabel('Reconstruction Time (sec)', fontsize=fontsize_label)
ax_a.set_ylabel('Normalized Data Error', fontsize=fontsize_label)
# Panel letter "a" just outside the top-left corner
ax_a.text(-0.2, 1.08, 'a', transform=ax_a.transAxes, fontsize=16, fontweight='bold')
# Speedup annotation, drawn in the first cycle color 'C0'
speedup_text_a = f'{np.round(speedup_factor_iter,1)}x faster'
ax_a.text(0.11, 0.125, speedup_text_a, transform=ax_a.transAxes, fontsize=7, fontweight='bold', c='C0')
ax_a.legend(fontsize=fontsize_legend)
ax_a.set_ylim(-0.07, 1.1)
# A log y-axis is available via ax_a.set_yscale('log'); it is left disabled.
ax_a.yaxis.set_minor_locator(tck.AutoMinorLocator())

# Panel 2: Iter Time vs. Batch Sizes
ax_b = axes[0, 1]
# Batch sizes are drawn at evenly spaced positions 0..N-1 and labelled with
# their actual values.
batch_positions = np.arange(len(batch_sizes))
for pkg_label, pkg_marker, pkg_times in (
    ('PtyRAD', 'o', time_ptyrad_batch),
    ('PtyShv', 's', time_ptyshv_batch),
    ('py4DSTEM', '^', time_py4dstem_batch),
):
    ax_b.plot(
        batch_positions,
        pkg_times,
        label=pkg_label,
        marker=pkg_marker,
        linewidth=linewidth,
        markersize=markersize,
    )

ax_b.set_title('Iteration Time vs. Batch Sizes', fontsize=fontsize_title)
ax_b.text(
    0.5, 0.95, '6 probes, 6 slices',
    transform=ax_b.transAxes, ha='center', va='top',
    fontsize=fontsize_subtitle, color='k',
)
ax_b.set_xlabel('Batch Sizes', fontsize=fontsize_label)
ax_b.set_ylabel('Iteration Time (sec)', fontsize=fontsize_label)
ax_b.set_xticks(np.arange(len(batch_sizes)))
batch_tick_labels = [str(int(size)) for size in batch_sizes]
ax_b.set_xticklabels(batch_tick_labels, fontsize=fontsize_label)
# Panel letter "b" just outside the top-left corner
ax_b.text(-0.1, 1.08, 'b', transform=ax_b.transAxes, fontsize=16, fontweight='bold')
# Speedup annotation, drawn in the first cycle color 'C0'
speedup_text_b = f'{np.round(speedup_factor_batch,1)}x faster'
ax_b.text(0.01, 0.08, speedup_text_b, transform=ax_b.transAxes, fontsize=7, fontweight='bold', c='C0')
ax_b.legend(fontsize=fontsize_legend)
ax_b.set_ylim(-10, 165)
ax_b.yaxis.set_minor_locator(tck.AutoMinorLocator())

# Panel 3: Iter Time vs. Probe Modes
ax_c = axes[1, 0]
for pkg_label, pkg_marker, pkg_times in (
    ('PtyRAD', 'o', time_ptyrad_modes),
    ('PtyShv', 's', time_ptyshv_modes),
    ('py4DSTEM', '^', time_py4dstem_modes),
):
    ax_c.plot(
        probe_modes,
        pkg_times,
        label=pkg_label,
        marker=pkg_marker,
        linewidth=linewidth,
        markersize=markersize,
    )

ax_c.set_title('Iteration Time vs. Probe Modes', fontsize=fontsize_title)
ax_c.text(
    0.5, 0.95, 'Batch size 16, 6 slices',
    transform=ax_c.transAxes, ha='center', va='top',
    fontsize=fontsize_subtitle, color='k',
)
ax_c.set_xlabel('Number of Probe Modes', fontsize=fontsize_label)
ax_c.set_ylabel('Iteration Time (sec)', fontsize=fontsize_label)
# Ticks sit exactly at the measured mode counts
ax_c.set_xticks(probe_modes)
mode_tick_labels = [str(int(n_modes)) for n_modes in probe_modes]
ax_c.set_xticklabels(mode_tick_labels, fontsize=fontsize_label)
# Panel letter "c" just outside the top-left corner
ax_c.text(-0.2, 1.08, 'c', transform=ax_c.transAxes, fontsize=16, fontweight='bold')
# Speedup annotation, drawn in the first cycle color 'C0'
speedup_text_c = f'{np.round(speedup_factor_modes,1)}x faster'
ax_c.text(0.75, 0.15, speedup_text_c, transform=ax_c.transAxes, fontsize=7, fontweight='bold', c='C0')
ax_c.legend(fontsize=fontsize_legend)
ax_c.set_ylim(-15, 325)
ax_c.yaxis.set_minor_locator(tck.AutoMinorLocator())

# Panel 4: Iter Time vs. Slices
ax_d = axes[1, 1]
for pkg_label, pkg_marker, pkg_times in (
    ('PtyRAD', 'o', time_ptyrad_slices),
    ('PtyShv', 's', time_ptyshv_slices),
    ('py4DSTEM', '^', time_py4dstem_slices),
):
    ax_d.plot(
        num_slices,
        pkg_times,
        label=pkg_label,
        marker=pkg_marker,
        linewidth=linewidth,
        markersize=markersize,
    )

ax_d.set_title('Iteration Time vs. Slices', fontsize=fontsize_title)
ax_d.text(
    0.5, 0.95, 'Batch size 16, 6 probes',
    transform=ax_d.transAxes, ha='center', va='top',
    fontsize=fontsize_subtitle, color='k',
)
ax_d.set_xlabel('Number of Slices', fontsize=fontsize_label)
ax_d.set_ylabel('Iteration Time (sec)', fontsize=fontsize_label)
# Ticks sit exactly at the measured slice counts
ax_d.set_xticks(num_slices)
ax_d.set_xticklabels(num_slices, fontsize=fontsize_label)
# Panel letter "d" just outside the top-left corner; this panel carries no
# speedup annotation.
ax_d.text(-0.1, 1.08, 'd', transform=ax_d.transAxes, fontsize=16, fontweight='bold')
ax_d.legend(fontsize=fontsize_legend)
ax_d.set_ylim(-10, 165)
ax_d.yaxis.set_minor_locator(tck.AutoMinorLocator())

# Adjust layout: tighten the subplot spacing of the figure created above
# (it is the current pyplot figure), then display it.
fig.tight_layout()
plt.show()