# In [None]:
"""Script to compile training images (input, output and predicted) into a PDF for selected global steps."""

import os
import re
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from matplotlib.backends.backend_pdf import PdfPages
import numpy as np


# Folder containing the PNGs written by the training image logger.
# NOTE(review): hard-coded cluster path — update it for each run.
image_folder = '/hpc/dctrl/ks723/Huggingface_repos/ControlNet_repo/controlnet_repo/image_logs/train/20250130_215944/image_log/train'  # Update this

# Regex patterns that capture the global step (`gs-XXXXXX`) from each kind of filename.
pattern_control = re.compile(r"control_gs-(\d+)_e-\d+_b-\d+\.png")
pattern_target = re.compile(r"reconstruction_gs-(\d+)_e-\d+_b-\d+\.png")
pattern_predicted = re.compile(r"samples_cfg_scale_.*?_gs-(\d+)_e-\d+_b-\d+\.png")

# Images keyed first by category, then by global step.
image_dict = {"control": {}, "target": {}, "predicted": {}}

# (category, pattern) pairs, tried in this order; the first full match wins.
_category_patterns = (
    ("control", pattern_control),
    ("target", pattern_target),
    ("predicted", pattern_predicted),
)


def _load_image(path):
    """Read the image at *path* fully into memory and close its file handle."""
    # Image.open is lazy and keeps the file open until the pixels are read;
    # holding one open handle per logged PNG can hit the OS open-file limit.
    with Image.open(path) as im:
        return im.copy()


# Load and categorize images
for filename in sorted(os.listdir(image_folder)):
    for category, pattern in _category_patterns:
        # fullmatch (not match) so names with trailing junk, e.g. "….png.bak", are skipped.
        found = pattern.fullmatch(filename)
        if found is None:
            continue
        gs_step = int(found.group(1))  # Extract global step
        # If several files share a category and step (e.g. different CFG scales),
        # the one that sorts last overwrites the earlier ones.
        image_dict[category][gs_step] = _load_image(os.path.join(image_folder, filename))
        break

# Global steps for which all three image kinds exist, in ascending order.
sorted_gs_steps = sorted(
    image_dict["control"].keys()
    & image_dict["target"].keys()
    & image_dict["predicted"].keys()
)

# Group images into trios (Control, Target, Prediction)
trios = [(image_dict["control"][gs], image_dict["target"][gs], image_dict["predicted"][gs])
         for gs in sorted_gs_steps]

# Debugging: check that trios were found
print(f"Total matched trios: {len(trios)}")

# Batch size for plotting (10 trios per plot)
batch_size = 10
num_batches = (len(trios) + batch_size - 1) // batch_size  # Number of plots needed (ceiling division)





# --- Get sorted global steps (as before) ---
# Recomputed here so this section can be re-run on its own (e.g. as a separate notebook cell).
sorted_gs_steps = sorted(set(image_dict["control"].keys()) &
                         set(image_dict["target"].keys()) &
                         set(image_dict["predicted"].keys()))

# --- Select evenly spaced global steps (always including the first and last) ---
# A plain `len // k` stride is not evenly spaced: with 11-19 steps the stride is 1
# and only the first ten steps would be kept, and the final step is usually missed.
num_selected_steps = 10
n_steps = len(sorted_gs_steps)
if n_steps <= num_selected_steps:
    selected_steps = sorted_gs_steps
elif num_selected_steps == 1:
    selected_steps = sorted_gs_steps[:1]
else:
    # Index i * (n - 1) // (k - 1) is 0 for i = 0 and n - 1 for i = k - 1; since
    # n > k the spacing exceeds 1, so the indices are strictly increasing (distinct).
    selected_steps = [
        sorted_gs_steps[i * (n_steps - 1) // (num_selected_steps - 1)]
        for i in range(num_selected_steps)
    ]

# --- Helper: cut a side-by-side composite into equal-width strips ---
def split_image(img, parts=4):
    """Cut *img* into ``parts`` equal-width strips, ordered left to right.

    Each strip spans the full image height and is ``width // parts`` pixels
    wide. When the width is not divisible by ``parts``, the leftover columns
    on the right edge are dropped.

    Args:
        img: PIL-style image exposing ``size`` as ``(width, height)`` and
            ``crop(box)`` taking a ``(left, upper, right, lower)`` tuple.
        parts: Number of strips to produce.

    Returns:
        A list with ``parts`` cropped images.
    """
    full_w, full_h = img.size
    strip_w = full_w // parts
    strips = []
    for idx in range(parts):
        left = idx * strip_w
        strips.append(img.crop((left, 0, left + strip_w, full_h)))
    return strips

# --- Create PDF and plot each selected global step ---
pdf_out = "training_images_simtoexp.pdf"
# Number of samples tiled side by side in each logged composite image.
# NOTE(review): assumes the logger wrote 4 samples per composite — confirm against the run config.
parts_per_image = 4
column_titles = ("Input", "Predicted Output", "Ground Truth")


def _show_panel(ax, part):
    """Draw one image strip on *ax* and hide its ticks and frame."""
    arr = np.array(part)
    imshow_kwargs = {"cmap": "gray"}
    # For single-channel 8-bit data, imshow would otherwise stretch each panel
    # to its own min/max, making brightness incomparable between columns.
    if arr.ndim == 2 and arr.dtype == np.uint8:
        imshow_kwargs.update(vmin=0, vmax=255)
    ax.imshow(arr, **imshow_kwargs)
    ax.axis("off")


if not selected_steps:
    print("No global steps with control, target and predicted images; nothing to plot.")

with PdfPages(pdf_out) as pdf:
    for gs in selected_steps:
        # Composite images for this step, in the same order as column_titles.
        composites = (
            image_dict["control"][gs],
            image_dict["predicted"][gs],
            image_dict["target"][gs],
        )
        # Each composite is split into its side-by-side samples; one sample per row.
        columns = [split_image(img, parts=parts_per_image) for img in composites]

        # squeeze=False keeps `axes` 2-D even if parts_per_image is 1.
        fig, axes = plt.subplots(parts_per_image, len(columns), figsize=(10, 12), squeeze=False)

        # Column titles only on the top row.
        for col, title in enumerate(column_titles):
            axes[0, col].set_title(title)

        for col, strips in enumerate(columns):
            for row, strip in enumerate(strips):
                _show_panel(axes[row, col], strip)

        fig.suptitle(f"Global Step: {gs}", fontsize=16)
        # Leave the top 5% of the page free for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        pdf.savefig(fig)
        plt.close(fig)
