# Figure S7 - sparsity series

In [None]:
import os
# Root of the ptyrad checkout. Set the PTYRAD_WORK_DIR environment variable to run
# this notebook on another machine; the default is the original author's path.
# Raw string: "\w" and "\p" are invalid Python escape sequences (SyntaxWarning on
# Python 3.12+); the raw literal yields the same path text without the warning.
work_dir = os.environ.get("PTYRAD_WORK_DIR", r"H:\workspace\ptyrad")
os.chdir(work_dir)
print("Current working dir: ", os.getcwd())

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ptyrad.data_io import load_pt
from ptyrad.utils import center_crop, mfft2
import h5py

In [None]:
sparsity_weights = [0, 0.01, 0.03, 0.1]

# Every run lives in the same output folder and shares one naming scheme; the only
# difference is the "_spr{weight}" tag, which is absent for the weight-0 run.
# Building each path from a single template keeps the run names from drifting apart.
RUN_DIR = "H:/workspace/ptyrad/output/paper/tBL_WSe2/20250131_ptyrad_batch_sizes"
RUN_NAME_TEMPLATE = (
    "full_N16384_dp128_flipT100_random16_p6_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4"
    "_oplr5e-4_slr5e-4_orblur0.5_ozblur1_oathr0.98_opos_sng1.0{spr_tag}_aff1_0_-3_0"
)
CHECKPOINT_NAME = "model_iter0100.pt"
CROP_SIZE = 384  # px; side length of the centered crop taken from each object

ptyrad_objects = []

for weight in sparsity_weights:
    spr_tag = "" if weight == 0 else f"_spr{weight}"
    path_ptyrad = f"{RUN_DIR}/{RUN_NAME_TEMPLATE.format(spr_tag=spr_tag)}/{CHECKPOINT_NAME}"
    objp = load_pt(path_ptyrad)['optimizable_tensors']['objp']
    # Drop singleton dims, then sum over the leading axis -- presumably the 6 slices
    # ("6slice" in the run name), giving a projected 2D object; TODO confirm.
    # NOTE(review): for a single-slice run, squeeze() would also drop the slice axis
    # and sum(0) would then collapse an image axis instead.
    object_ptyrad = center_crop(objp.squeeze().cpu().numpy().sum(0), CROP_SIZE, CROP_SIZE)
    ptyrad_objects.append(object_ptyrad)

In [None]:
import numpy as np
from numpy.fft import fft2
from numpy.fft import fftshift
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm

# Panel labels: row 0 = real-space objects, row 1 = FFTs of the same objects
panel_labels = np.array([['a', 'b', 'c', 'd'], ['e', 'f', 'g', 'h']])
shadow_offset = (3, 3)  # px offset of the dark "shadow" copy of each panel label

# Zoom-in region as (x1, x2, y1, y2) in data pixels. y1 > y2 because the images are
# drawn with origin='upper', so the inset's y-limits must run top-down as well.
zoom_region = (270, 320, 65, 15)

# Scale-bar settings are identical for every panel, so build them once.
# Real space: 134 px ~ 2 nm, consistent with a pixel size of ~0.1494 A (134 * 0.1494 A ~ 20 A).
real_bar_length = 134
real_bar_label = "2 nm"
# Reciprocal space: for an N x N image, 1 k-space px = 1/(N * 0.1494) A^-1.
# With N = 384 (the center_crop size used when loading), 2 A^-1 ~ 114.7 px -> 115.
# Assumes mfft2 returns an array the same size as its input.
fft_bar_length = 115
# Raw string: "\m" and "\A" are invalid Python escapes (SyntaxWarning on 3.12+);
# the backslashes must reach matplotlib's mathtext unchanged.
fft_bar_label = r"2 $\mathrm{\AA}^{-1}$"
scale_bar_color = "white"
fontprops = fm.FontProperties(size=10)

# 2 x 4 grid: one column per sparsity weight. Negative hspace tightens the gap
# between the two rows.
fig = plt.figure(figsize=(7, 4), dpi=600)
gs = gridspec.GridSpec(2, 4, wspace=0.05, hspace=-0.2, height_ratios=[1, 1],
                       width_ratios=[1, 1, 1, 1], figure=fig)

# Top row: real-space object for each sparsity weight, with a zoom-in inset
for i, (weight, obj) in enumerate(zip(sparsity_weights, ptyrad_objects)):
    ax = fig.add_subplot(gs[0, i])

    # Clip to robust percentiles so a few outlier pixels don't wash out the contrast
    vmin, vmax = np.percentile(obj, [0.5, 99.8])
    ax.imshow(obj, cmap='gray', origin='upper', vmin=vmin, vmax=vmax)
    ax.axis('off')
    ax.set_title(f'Sparsity = {weight}', fontsize=10)

    # Panel label with a semi-transparent drop shadow for legibility
    ax.text(9 + shadow_offset[0], 9 + shadow_offset[1], panel_labels[0, i], color='black',
            fontsize=16, fontweight='bold', va='top', ha='left', alpha=0.6)
    ax.text(9, 9, panel_labels[0, i], color='white', fontsize=16, fontweight='bold',
            va='top', ha='left')

    scalebar = AnchoredSizeBar(ax.transData, real_bar_length, real_bar_label, loc='lower right',
                               pad=0.5, color=scale_bar_color, frameon=False, size_vertical=3,
                               label_top=True, fontproperties=fontprops)
    ax.add_artist(scalebar)

    # Zoom-in inset using the same contrast limits as the main panel
    x1, x2, y1, y2 = zoom_region
    axins = ax.inset_axes([0.6, 0.45, 0.35, 0.35], xlim=(x1, x2), ylim=(y1, y2))
    axins.imshow(obj, cmap='gray', origin='upper', vmin=vmin, vmax=vmax)
    axins.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    for spine in axins.spines.values():
        spine.set_edgecolor('darkred')

    # Outline the zoomed region on the main image and connect it to the inset
    mark_inset(ax, axins, loc1=1, loc2=2, fc="none", ec="darkred")

# Bottom row: log-magnitude spectrum of each object, computed with ptyrad's mfft2.
# NOTE(review): the first element of mfft2's output is used as the spectrum
# (presumably the periodic component of a P+S decomposition) -- confirm in ptyrad.utils.
for i, obj in enumerate(ptyrad_objects):
    ax = fig.add_subplot(gs[1, i])

    log_fft = np.log(np.abs(fftshift(mfft2(obj)[0])))
    vmin, vmax = np.percentile(log_fft, [5, 99.5])
    ax.imshow(log_fft, cmap='magma', vmin=vmin, vmax=vmax)
    ax.axis('off')

    # Panel label with a semi-transparent drop shadow for legibility
    ax.text(9 + shadow_offset[0], 9 + shadow_offset[1], panel_labels[1, i], color='black',
            fontsize=16, fontweight='bold', va='top', ha='left', alpha=0.6)
    ax.text(9, 9, panel_labels[1, i], color='white', fontsize=16, fontweight='bold',
            va='top', ha='left')

    scalebar = AnchoredSizeBar(ax.transData, fft_bar_length, fft_bar_label, loc='lower right',
                               pad=0.5, color=scale_bar_color, frameon=False, size_vertical=3,
                               label_top=True, fontproperties=fontprops)
    ax.add_artist(scalebar)

# Save through the explicit Figure handle rather than pyplot's "current figure"
fig.savefig("Fig_S07_sparsity_series.pdf", bbox_inches="tight")

plt.show()


## Comparison of the normal FFT and the P+S-decomposition FFT

In [None]:
import numpy as np
from numpy.fft import fft2
from numpy.fft import fftshift
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm

# Panel labels: row 0 = real-space objects, row 1 = normal FFT, row 2 = P+S-decomposition FFT
panel_labels = np.array([['a', 'b', 'c', 'd'], ['e', 'f', 'g', 'h'], ['i', 'j', 'k', 'l']])
shadow_offset = (3, 3)  # px offset of the dark "shadow" copy of each panel label

# Zoom-in region as (x1, x2, y1, y2) in data pixels. y1 > y2 because the images are
# drawn with origin='upper', so the inset's y-limits must run top-down as well.
zoom_region = (270, 320, 65, 15)

# Scale-bar settings are identical for every panel, so build them once.
# Real space: 134 px ~ 2 nm, consistent with a pixel size of ~0.1494 A (134 * 0.1494 A ~ 20 A).
real_bar_length = 134
real_bar_label = "2 nm"
# Reciprocal space: for an N x N image, 1 k-space px = 1/(N * 0.1494) A^-1.
# With N = 384 (the center_crop size used when loading), 2 A^-1 ~ 114.7 px -> 115.
# Assumes both FFT variants return an array the same size as their input.
fft_bar_length = 115
# Raw string: "\m" and "\A" are invalid Python escapes (SyntaxWarning on 3.12+);
# the backslashes must reach matplotlib's mathtext unchanged.
fft_bar_label = r"2 $\mathrm{\AA}^{-1}$"
scale_bar_color = "white"
fontprops = fm.FontProperties(size=10)

# 3 x 4 grid: one column per sparsity weight. Negative hspace tightens the gaps
# between rows.
fig = plt.figure(figsize=(7, 6), dpi=300)
gs = gridspec.GridSpec(3, 4, wspace=0.05, hspace=-0.2, height_ratios=[1, 1, 1],
                       width_ratios=[1, 1, 1, 1], figure=fig)

# Row 0: real-space object for each sparsity weight, with a zoom-in inset
for i, (weight, obj) in enumerate(zip(sparsity_weights, ptyrad_objects)):
    ax = fig.add_subplot(gs[0, i])

    # Clip to robust percentiles so a few outlier pixels don't wash out the contrast
    vmin, vmax = np.percentile(obj, [0.5, 99.8])
    ax.imshow(obj, cmap='gray', origin='upper', vmin=vmin, vmax=vmax)
    ax.axis('off')
    ax.set_title(f'Sparsity = {weight}', fontsize=10)

    # Panel label with a semi-transparent drop shadow for legibility
    ax.text(9 + shadow_offset[0], 9 + shadow_offset[1], panel_labels[0, i], color='black',
            fontsize=16, fontweight='bold', va='top', ha='left', alpha=0.6)
    ax.text(9, 9, panel_labels[0, i], color='white', fontsize=16, fontweight='bold',
            va='top', ha='left')

    scalebar = AnchoredSizeBar(ax.transData, real_bar_length, real_bar_label, loc='lower right',
                               pad=0.5, color=scale_bar_color, frameon=False, size_vertical=3,
                               label_top=True, fontproperties=fontprops)
    ax.add_artist(scalebar)

    # Zoom-in inset using the same contrast limits as the main panel
    x1, x2, y1, y2 = zoom_region
    axins = ax.inset_axes([0.6, 0.45, 0.35, 0.35], xlim=(x1, x2), ylim=(y1, y2))
    axins.imshow(obj, cmap='gray', origin='upper', vmin=vmin, vmax=vmax)
    axins.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    for spine in axins.spines.values():
        spine.set_edgecolor('darkred')

    # Outline the zoomed region on the main image and connect it to the inset
    mark_inset(ax, axins, loc1=1, loc2=2, fc="none", ec="darkred")

    if i == 0:
        ax.annotate('Image', xy=(-0.25, 0.5), xycoords='axes fraction', fontsize=7,
                    ha='center', va='center', rotation=0)

# Rows 1-2: log-magnitude spectra, differing only in how the spectrum is computed.
# Each entry is (grid row, row title, function mapping an object to its unshifted FFT).
# NOTE(review): for row 2 the first element of mfft2's output is used (presumably the
# periodic component of the P+S decomposition) -- confirm in ptyrad.utils.
fft_rows = [
    (1, 'Normal\nFFT', fft2),
    (2, 'P+S\ndecomp.\nFFT', lambda o: mfft2(o)[0]),
]
for row, row_title, spectrum_fn in fft_rows:
    for i, obj in enumerate(ptyrad_objects):
        ax = fig.add_subplot(gs[row, i])

        log_fft = np.log(np.abs(fftshift(spectrum_fn(obj))))
        vmin, vmax = np.percentile(log_fft, [5, 99.5])
        ax.imshow(log_fft, cmap='magma', vmin=vmin, vmax=vmax)
        ax.axis('off')

        # Panel label with a semi-transparent drop shadow for legibility
        ax.text(9 + shadow_offset[0], 9 + shadow_offset[1], panel_labels[row, i], color='black',
                fontsize=16, fontweight='bold', va='top', ha='left', alpha=0.6)
        ax.text(9, 9, panel_labels[row, i], color='white', fontsize=16, fontweight='bold',
                va='top', ha='left')

        scalebar = AnchoredSizeBar(ax.transData, fft_bar_length, fft_bar_label, loc='lower right',
                                   pad=0.5, color=scale_bar_color, frameon=False, size_vertical=3,
                                   label_top=True, fontproperties=fontprops)
        ax.add_artist(scalebar)

        if i == 0:
            ax.annotate(row_title, xy=(-0.25, 0.5), xycoords='axes fraction', fontsize=7,
                        ha='center', va='center', rotation=0)

plt.show()
