# Figure - misleading errors

In [None]:
import os
# Move into the ptyrad workspace so relative paths below resolve against it.
# Raw string: "\w" and "\p" are not valid escape sequences (SyntaxWarning on
# Python 3.12+); the resulting path text is unchanged.
work_dir = r"H:\workspace\ptyrad"
os.chdir(work_dir)
print("Current working dir: ", os.getcwd())

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from ptyrad.data_io import load_hdf5, load_pt
import h5py

In [None]:
def center_crop(image, crop_height, crop_width):
    """
    Center crops the two trailing axes of a 2D or 3D array (e.g., an image).

    The crop is always applied to the last two axes, so a 3D input is treated
    as channel-first, (C, H, W). A channel-last (H, W, C) array would be
    cropped along W and C instead; transpose it before calling.

    Args:
        image (numpy.ndarray): The input array to crop, shaped (H, W) or (C, H, W).
        crop_height (int): The desired size along axis -2.
        crop_width (int): The desired size along axis -1.

    Returns:
        numpy.ndarray: The cropped array, shaped (crop_height, crop_width) or
        (C, crop_height, crop_width). It is produced by basic slicing, so for a
        NumPy input it is a view of ``image`` rather than a copy.

    Raises:
        ValueError: If ``image`` is not 2D/3D, if a crop size is negative, or if
            a crop size exceeds the corresponding image size.
    """
    if image.ndim not in (2, 3):
        raise ValueError("Input image must be a 2D or 3D array.")

    height, width = image.shape[-2:]

    if crop_height < 0 or crop_width < 0:
        raise ValueError("Crop size must be non-negative.")
    if crop_height > height or crop_width > width:
        # Equal sizes are allowed (the crop is then the whole image).
        raise ValueError("Crop size must not exceed the input image size.")

    # Floor division: for odd size differences the extra row/column is
    # dropped from the bottom/right side.
    start_y = (height - crop_height) // 2
    start_x = (width - crop_width) // 2

    return image[..., start_y:start_y + crop_height, start_x:start_x + crop_width]

In [None]:
def get_error_from_str(string):
    """Extract the number following ``_error_`` in a file name.

    Args:
        string (str): Text such as
            ``"objp_zsum_crop_08bit_error_0.23456_....tif"``.

    Returns:
        str | None: The number formatted with exactly 5 decimal places
        (e.g. ``"0.23456"``), or ``None`` when no ``_error_<number>`` is found
        (a "No match found" message is printed in that case).
    """
    import re

    # A well-formed float: optional sign, mantissa with optional fraction,
    # optional exponent. Unlike a loose character class, this stops before a
    # trailing "." (e.g. the one in ".tif"), which float() would reject.
    match = re.search(
        r'_error_([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', string
    )

    if match:
        return f"{float(match.group(1)):.5f}"  # Ensures 5 decimal places
    print("No match found")
    return None


In [None]:
def get_total_variation(img):
    """Return the per-pixel anisotropic total variation of a 2D image.

    Computes ``(sum|d/dx| + sum|d/dy|) / img.size`` using forward differences
    along axis 1 (x) and axis 0 (y).

    Integer and boolean inputs (e.g. 8-bit TIFFs) are promoted to float64
    first. For unsigned dtypes, ``np.diff`` would otherwise wrap around
    (3 - 5 -> 254 for ``uint8``) and inflate the result.

    Args:
        img: 2D array-like image.

    Returns:
        float: Mean absolute neighbour difference per pixel.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.inexact):
        img = img.astype(np.float64)

    # Compute the horizontal and vertical differences
    diff_x = np.diff(img, axis=1)
    diff_y = np.diff(img, axis=0)

    # Calculate the total variation
    tv = np.sum(np.abs(diff_x)) + np.sum(np.abs(diff_y))

    return tv / img.size
    
def get_std(img):
    """Return the standard deviation over every value of ``img``.

    Uses NumPy's default population estimator (``ddof=0``) on the flattened
    input.
    """
    spread = np.std(img)
    return spread

def get_contrast(img):
    """Return the coefficient of variation of ``img`` (std divided by mean).

    A zero mean yields NumPy's inf/nan result rather than an exception.
    """
    mean_val = np.mean(img)
    std_val = np.std(img)
    return std_val / mean_val

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import re
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm
from tifffile import imread

# Set output directory and grid size
output_dir = 'H:/workspace/ptyrad/output/paper/tBL_WSe2/20250125_hypertune_GridSampler_convergence_optimization'
rows, cols = 5, 5  # Specify number of rows and columns

# Files whose names start with the prefix below (error value beginning "0.2"),
# sorted by file name.
file_names = sorted(
    name for name in os.listdir(output_dir)
    if name.startswith('objp_zsum_crop_08bit_error_0.2')
)

# Create figure with GridSpec to adjust spacing
fig = plt.figure(figsize=(7, 7), dpi=300)
gs = gridspec.GridSpec(rows, cols, figure=fig, wspace=0.03, hspace=0.03)  # Adjust wspace & hspace

# One Axes per grid cell, listed in row-major order
axs = [fig.add_subplot(gs[r, c]) for r in range(rows) for c in range(cols)]

k = 3  # Select every k-th image
shadow_offset = [2, 2]  # Drop-shadow displacement, in data (pixel) coordinates
text_offset = [12, 9]  # Label position, in data (pixel) coordinates

# Panel indices that receive a colored edge
box_idx0 = 10  # dodgerblue edge
box_idx1 = 21  # orange ('C1') edge

# Scale bar settings
scale_bar_length = 20.08  # Length of the scale bar in pixels (1 px = 0.1494 Ang)
# Raw string: "\m" and "\A" are invalid Python escapes (SyntaxWarning on
# Python 3.12+); the resulting label text is unchanged.
scale_bar_label = r"3 $\mathrm{\AA}$"  # Label for the scale bar
scale_bar_color = "white"
fontprops = fm.FontProperties(size=10)

imgs = []  # cropped images, kept for the alternative-metric cells below
errors = []  # error values parsed from the file names, same order as imgs

# Every k-th file, truncated to the number of available panels
selected_files = file_names[::k][:len(axs)]
n_plotted = len(selected_files)

# Plot images
for i, file_name in enumerate(selected_files):
    ax = axs[i]
    img = center_crop(imread(os.path.join(output_dir, file_name)), 96, 96)
    error = get_error_from_str(file_name)
    if error is None:
        raise ValueError(f"Could not parse an error value from {file_name!r}")

    imgs.append(img)  # save it for later evaluation
    errors.append(float(error))  # save it for later evaluation

    ax.imshow(img, cmap='gray')
    ax.axis('off')

    # Error label: a semi-transparent black copy, offset behind the white
    # text, acts as a drop shadow.
    ax.text(
        text_offset[0] + shadow_offset[0], text_offset[1] + shadow_offset[1], error,
        color='black', fontsize=12, fontweight='bold',
        va='top', ha='left', alpha=0.6,
    )
    ax.text(
        text_offset[0], text_offset[1], error,
        color='white', fontsize=12, fontweight='bold',
        va='top', ha='left',
    )

    # Scale bar on every panel; only the first panel carries the text label.
    scalebar = AnchoredSizeBar(
        ax.transData, scale_bar_length, scale_bar_label if i == 0 else '',
        loc='lower right', pad=0.5, color=scale_bar_color, frameon=False,
        size_vertical=3, label_top=True, fontproperties=fontprops,
    )
    ax.add_artist(scalebar)

# Outline the highlighted panels, but only those that received an image.
for box_idx, edge_color in ((box_idx0, 'dodgerblue'), (box_idx1, 'C1')):
    if box_idx >= n_plotted:
        continue
    box_ax = axs[box_idx]
    box_ax.axis('on')  # spines are not drawn while the axis is off
    box_ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    for spine in box_ax.spines.values():
        spine.set_edgecolor(edge_color)
        spine.set_linewidth(2)

# Hide unused panels (also safe when no file matched)
for ax in axs[n_plotted:]:
    ax.axis('off')

# Add arrows and text labels using fig.text and FancyArrowPatch
fig.text(0.5, 0.91, "Increasing Errors", ha='center', va='bottom', fontsize=14, fontweight='bold')
fig.text(0.08, 0.5, "Increasing Errors", ha='right', va='center', rotation=90, fontsize=14, fontweight='bold')

# Top arrow (from left to right)
arrow_top = patches.FancyArrowPatch((0.2, 0.9), (0.84, 0.9), transform=fig.transFigure,
                                    arrowstyle="->", lw=3, color="black", mutation_scale=20)
fig.patches.append(arrow_top)

# Left arrow (from top to bottom)
arrow_left = patches.FancyArrowPatch((0.1, 0.8), (0.1, 0.17), transform=fig.transFigure,
                                     arrowstyle="->", lw=3, color="black", mutation_scale=20)
fig.patches.append(arrow_left)

plt.show()


# Try other error metrics

In [None]:
# Score every cropped image with three alternative quality metrics.
error_std = [get_std(img) for img in imgs]
error_contrast = [get_contrast(img) for img in imgs]
error_total_variation = [get_total_variation(img) for img in imgs]

# Panel orderings, from the largest to the smallest metric value.
idx_std, idx_contrast, idx_total_variation = (
    np.argsort(scores)[::-1]
    for scores in (error_std, error_contrast, error_total_variation)
)

In [None]:
# Re-plot the same cropped images, ordered by each alternative metric.
# NOTE(review): idx_* are descending argsorts, so panels run from the largest
# to the smallest metric value while the arrows read "Increasing Errors".
# This presumes a larger std/contrast/TV means a better image; confirm.

# Layout constants shared by every metric figure
shadow_offset = [2, 2]  # Drop-shadow displacement, in data (pixel) coordinates
text_offset = [12, 9]  # Label position, in data (pixel) coordinates

# Scale bar settings
scale_bar_length = 20.08  # Length of the scale bar in pixels (1 px = 0.1494 Ang)
# Raw string avoids the invalid "\m"/"\A" escapes; the label text is unchanged.
scale_bar_label = r"3 $\mathrm{\AA}$"  # Label for the scale bar
scale_bar_color = "white"
fontprops = fm.FontProperties(size=10)

imgs_stack = np.asarray(imgs)  # stacked once instead of once per metric

metric_sets = zip(
    [idx_std, idx_contrast, idx_total_variation],
    [error_std, error_contrast, error_total_variation],
    ['std', 'contrast', 'total variation'],
)
for sort_idx, metric_values, title in metric_sets:

    # Create figure with GridSpec to adjust spacing
    fig = plt.figure(figsize=(7, 7), dpi=300)
    fig.suptitle(title, fontsize=20, y=1)
    gs = gridspec.GridSpec(rows, cols, figure=fig, wspace=0.03, hspace=0.03)  # Adjust wspace & hspace

    # One Axes per grid cell, listed in row-major order
    axs = [fig.add_subplot(gs[r, c]) for r in range(rows) for c in range(cols)]

    # Reorder images and metric values. A dedicated name is used so the
    # module-level `errors` list from the first figure is not overwritten.
    sorted_imgs = imgs_stack[sort_idx]
    sorted_values = np.asarray(metric_values)[sort_idx]
    n_plotted = min(len(axs), len(sorted_imgs))

    # Plot images (zip stops at the shorter of panels / images)
    for i, (ax, img, value) in enumerate(zip(axs, sorted_imgs, sorted_values)):
        ax.imshow(img, cmap='gray')
        ax.axis('off')

        value_label = f"{value:.3f}"

        # Drop-shadow label: offset semi-transparent black text under white text
        ax.text(
            text_offset[0] + shadow_offset[0], text_offset[1] + shadow_offset[1], value_label,
            color='black', fontsize=12, fontweight='bold',
            va='top', ha='left', alpha=0.6,
        )
        ax.text(
            text_offset[0], text_offset[1], value_label,
            color='white', fontsize=12, fontweight='bold',
            va='top', ha='left',
        )

        # Scale bar on every panel; only the first panel carries the text label.
        scalebar = AnchoredSizeBar(
            ax.transData, scale_bar_length, scale_bar_label if i == 0 else '',
            loc='lower right', pad=0.5, color=scale_bar_color, frameon=False,
            size_vertical=3, label_top=True, fontproperties=fontprops,
        )
        ax.add_artist(scalebar)

    # Hide unused panels (also safe when `imgs` is empty)
    for ax in axs[n_plotted:]:
        ax.axis('off')

    # Add arrows and text labels using fig.text and FancyArrowPatch
    fig.text(0.5, 0.91, "Increasing Errors", ha='center', va='bottom', fontsize=14, fontweight='bold')
    fig.text(0.08, 0.5, "Increasing Errors", ha='right', va='center', rotation=90, fontsize=14, fontweight='bold')

    # Top arrow (from left to right)
    arrow_top = patches.FancyArrowPatch((0.2, 0.9), (0.84, 0.9), transform=fig.transFigure,
                                        arrowstyle="->", lw=3, color="black", mutation_scale=20)
    fig.patches.append(arrow_top)

    # Left arrow (from top to bottom)
    arrow_left = patches.FancyArrowPatch((0.1, 0.8), (0.1, 0.17), transform=fig.transFigure,
                                         arrowstyle="->", lw=3, color="black", mutation_scale=20)
    fig.patches.append(arrow_left)

    plt.show()