# Figure - Batch size series

In [None]:
import os
# Raw string: "\w" and "\p" are not valid escape sequences in a normal string
# literal (SyntaxWarning on Python 3.12+); the resulting path value is unchanged.
work_dir = r"H:\workspace\ptyrad"
os.chdir(work_dir)
print("Current working dir: ", os.getcwd())

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ptyrad.data_io import load_hdf5, load_pt
import h5py

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from ptyrad.data_io import load_hdf5, load_pt
import h5py

In [None]:
def center_crop(image, crop_height, crop_width):
    """
    Center crops the last two axes of a 2D or 3D array (e.g., an image).

    The two trailing axes are treated as (height, width). A 3D input is
    therefore expected as (C, H, W), e.g. a stack of 2D images, and its
    leading axis is kept unchanged.

    Args:
        image (numpy.ndarray): The input array to crop. Can be 2D (H, W) or 3D (C, H, W).
        crop_height (int): The desired height of the crop, 0 <= crop_height <= H.
        crop_width (int): The desired width of the crop, 0 <= crop_width <= W.

    Returns:
        numpy.ndarray: ``image[..., y0:y0 + crop_height, x0:x0 + crop_width]``.
        This is a basic slice, so for NumPy input it is a view, not a copy.
        When the margin (H - crop_height or W - crop_width) is odd, one more
        row/column is dropped from the bottom/right than from the top/left.

    Raises:
        ValueError: If ``image`` is not 2D/3D, a crop size is negative, or a
            crop size exceeds the corresponding image size.
    """
    if len(image.shape) not in [2, 3]:
        raise ValueError("Input image must be a 2D or 3D array.")

    # Negative sizes would otherwise silently produce an empty slice.
    if crop_height < 0 or crop_width < 0:
        raise ValueError(f"Crop size must be non-negative, got ({crop_height}, {crop_width}).")

    height, width = image.shape[-2:]

    # Equal sizes are allowed (the crop is then the whole image).
    if crop_height > height or crop_width > width:
        raise ValueError(
            f"Crop size ({crop_height}, {crop_width}) must not exceed "
            f"the input image size ({height}, {width})."
        )

    # Floor division puts any odd leftover pixel on the bottom/right side.
    start_y = (height - crop_height) // 2
    start_x = (width - crop_width) // 2

    return image[..., start_y:start_y + crop_height, start_x:start_x + crop_width]

In [None]:
batch_sizes = [16, 64, 256, 1024]
crop_size = 384  # side length (px) of the centered square kept from every reconstruction

ptyrad_objects = []
ptyshv_objects = []
py4dstem_objects = []

# Note that py4DSTEM has the 3 deg rotation, while these ptyrad and ptyshv results do not have the scan rotation as I did for the convergence test
# NOTE(review): the active PtyRAD path below contains `aff1_0_-3_0`, which may encode a -3 deg affine
# transform; confirm the note above still applies to this 20250131 run.
for batch_size in batch_sizes:
    path_ptyrad   = f"H:/workspace/ptyrad/output/paper/tBL_WSe2/20250131_ptyrad_batch_sizes/full_N16384_dp128_flipT100_random{batch_size}_p6_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr5e-4_orblur0.5_ozblur1_oathr0.98_opos_sng1.0_spr0.1_aff1_0_-3_0/model_iter0100.pt"
    # Alternate (earlier) PtyRAD run:
    # path_ptyrad   = f"H:/workspace/ptyrad/output/paper/tBL_WSe2/20241202_ptyrad_benchmark_6slice/full_N16384_dp128_flipT100_random{batch_size}_p6_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr5e-4_orblur0.2_ozblur1_oathr0.98_opos_sng1.0_spr0.1/model_iter0100.pt"
    path_ptyshv   = f"H:/workspace/ptyrad/data/paper/tBL_WSe2/Panel_g-h_Themis/11/roi11_Ndp128_step128/MLs_L1_p6_g{batch_size}_pc0_noModel_mm_Ns6_dz2_reg1_dpFlip_ud_T/Niter100.mat"
    path_py4dstem = f"H:/workspace/ptyrad/output/paper/tBL_WSe2/20250129_py4DSTEM_batch_sizes/N16384_dp128_flipT100_random{batch_size}_p6_6slice_dz2_update0.5_kzf1/model_iter0100.hdf5"

    # PtyRAD: 'objp' (presumably the object phase) is squeezed, moved to CPU,
    # summed over its leading axis (presumably the slices), then center-cropped.
    objp = load_pt(path_ptyrad)['optimizable_tensors']['objp']
    object_ptyrad = center_crop(objp.squeeze().cpu().numpy().sum(0), crop_size, crop_size)

    # PtychoShelves: only the raw dataset is read while the file is open.
    # The stored dtype is reinterpreted as complex128 (presumably a real/imag
    # compound type — confirm), phase-summed over axis 0, and transposed
    # (presumably to match PtyRAD's orientation).
    with h5py.File(path_ptyshv, "r") as hdf_file:
        raw_ptyshv = np.array(hdf_file['object'])
    object_ptyshv = center_crop(np.angle(raw_ptyshv.view('complex128')).sum(0).T, crop_size, crop_size)

    # py4DSTEM: complex object -> phase, summed over axis 0, center-cropped.
    object_py4dstem = center_crop(np.angle(load_hdf5(path_py4dstem, dataset_key='object')).sum(0), crop_size, crop_size)

    ptyrad_objects.append(object_ptyrad)
    ptyshv_objects.append(object_ptyshv)
    py4dstem_objects.append(object_py4dstem)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm

# Titles for the columns (one column per reconstruction package)
titles = ['PtyRAD', 'PtychoShelves', 'py4DSTEM']
method_objects = [ptyrad_objects, ptyshv_objects, py4dstem_objects]

# One row per batch size, one column per method
n_rows = len(batch_sizes)
n_cols = len(titles)

# Panel labels 'a', 'b', 'c', ... assigned row-major (left to right, then top to bottom);
# for 4 rows x 3 columns this is [['a','b','c'], ['d','e','f'], ['g','h','i'], ['j','k','l']]
panel_labels = np.array([chr(ord('a') + k) for k in range(n_rows * n_cols)]).reshape(n_rows, n_cols)
shadow_offset = (3, 3)

# Zoom-in region as (x1, x2, y1, y2) in pixels; y1 > y2 keeps the inset in the
# same top-down orientation as the origin='upper' main image
zoom_region = (270, 320, 65, 15)
x1, x2, y1, y2 = zoom_region

# Scale bar settings (identical for every panel, so built once)
scale_bar_length = 134  # Length in pixels
scale_bar_label = "2 nm"
scale_bar_color = "white"
fontprops = fm.FontProperties(size=10)

# Define figure and GridSpec layout; 1.8 in per row gives the original 6 x 7.2 in figure for 4 batch sizes
fig = plt.figure(figsize=(6, 1.8 * n_rows), dpi=300)
gs = gridspec.GridSpec(n_rows, n_cols, wspace=-0.25, hspace=0.05,
                       height_ratios=[1] * n_rows, width_ratios=[1] * n_cols, figure=fig)

# Iterate over columns (methods) and rows (batch sizes)
for i, (objects, title) in enumerate(zip(method_objects, titles)):
    for j, (image, batch_size) in enumerate(zip(objects, batch_sizes)):
        ax = fig.add_subplot(gs[j, i])

        # Main image, contrast clipped to the 0.5-99.9 percentile range of this panel
        vmin, vmax = np.percentile(image, [0.5, 99.9])
        ax.imshow(image, cmap='gray', origin='upper', vmin=vmin, vmax=vmax)
        ax.axis('off')

        # Titles for the top row
        if j == 0:
            ax.set_title(title, fontsize=14)

        # Add label with text shadow (semi-transparent black copy drawn offset behind the white label)
        label = panel_labels[j, i]
        ax.text(9 + shadow_offset[0], 9 + shadow_offset[1], label, color='black', fontsize=16, fontweight='bold', va='top', ha='left', alpha=0.6)
        ax.text(9, 9, label, color='white', fontsize=16, fontweight='bold', va='top', ha='left')

        # Add scale bar
        scalebar = AnchoredSizeBar(ax.transData, scale_bar_length, scale_bar_label, loc='lower right',
                                   pad=0.5, color=scale_bar_color, frameon=False, size_vertical=3, label_top=True,
                                   fontproperties=fontprops)
        ax.add_artist(scalebar)

        # Zoom-in inset, sharing the main image's contrast limits
        axins = ax.inset_axes([0.6, 0.45, 0.35, 0.35], xlim=(x1, x2), ylim=(y1, y2))
        axins.imshow(image, cmap='gray', origin='upper', vmin=vmin, vmax=vmax)

        # Remove ticks and labels from inset
        axins.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        for spine in axins.spines.values():
            spine.set_edgecolor('darkred')

        # Indicate zoom region with a red edge
        mark_inset(ax, axins, loc1=1, loc2=2, fc="none", ec="darkred")

        # Add row labels on the leftmost column
        if i == 0:
            ax.annotate(f'Batch size\n{batch_size}', xy=(-0.3, 0.5), xycoords='axes fraction', fontsize=10, ha='center', va='center', rotation=0)

# Show the plot
plt.show()
