In [None]:
# Enable auto-reloading of modules
%load_ext autoreload
%autoreload 2

# Enable inline plotting for Jupyter notebooks
%matplotlib inline


import math
import sys

import matplotlib.pyplot as plt
import torch
import torchvision.transforms
from icecream import ic
from PIL import Image
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm

sys.path.append("../../..")

from src.annotated.croco.croco import AnnotatedCroCo
from src.annotated.losses.masked_mse import AnnotatedMaskedMSE

ic.disable()


def visualize_croco_output(model, out, mask, image1, image2, imagenet_std_tensor, imagenet_mean_tensor):
    """
    Build a side-by-side visualization of a CroCo reconstruction.

    Each batch element becomes one row with, left to right: the reference image
    (``image2``), the masked input (``image1`` with masked patches zeroed), the
    de-normalized model reconstruction, and the unmasked input (``image1``).
    Rows of successive batch elements are stacked vertically.

    Args:
        model: The CroCo model instance; only its ``patchify``/``unpatchify``
            methods are used here.
        out: Model prediction in patchified form, normalized per patch. The
            per-patch mean/variance of ``image1`` are used to map it back.
        mask: Patch mask returned by the model, broadcast per patch via
            ``mask[:, :, None]``. Presumably ``(B, num_patches)`` with 1 marking
            masked patches (those are zeroed in the masked panel) — TODO confirm.
        image1: ImageNet-normalized input image, ``(B, 3, H, W)``.
        image2: ImageNet-normalized reference image, ``(B, 3, H, W)``.
        imagenet_std_tensor: Per-channel ImageNet std, broadcastable to ``(B, 3, H, W)``.
        imagenet_mean_tensor: Per-channel ImageNet mean, broadcastable to ``(B, 3, H, W)``.

    Returns:
        PIL Image of width ``4 * W`` and height ``B * H``, clamped to [0, 1] before conversion.
    """
    # Pure post-processing: no gradients are needed. Running without autograd also
    # lets callers pass tensors straight from a training forward pass;
    # to_pil_image goes through NumPy, which rejects tensors that require grad.
    with torch.no_grad():
        # The output is normalized per patch, so use the mean/std of the actual
        # image's patches to go back to (ImageNet-normalized) pixel space.
        patchified = model.patchify(image1)
        mean = patchified.mean(dim=-1, keepdim=True)
        var = patchified.var(dim=-1, keepdim=True)
        decoded_image = model.unpatchify(out * (var + 1.0e-6) ** 0.5 + mean)

        # Undo ImageNet normalization and build the masked input image.
        decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor
        input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor
        ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor
        # Expand the per-patch mask to a per-pixel mask with the image's shape.
        image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:, :, None])
        masked_input_image = (1 - image_masks) * input_image

        # 4 * (B, 3, H, W) -> (B, 3, H, 4 * W)
        visualization = torch.cat((ref_image, masked_input_image, decoded_image, input_image), dim=3)
        B, C, H, W = visualization.shape
        # Stack batch rows vertically: (B, C, H, W') -> (C, B * H, W')
        visualization = visualization.permute(1, 0, 2, 3).reshape(C, B * H, W)
        visualization = torch.clamp(visualization, 0, 1)

    return torchvision.transforms.functional.to_pil_image(visualization)


def display(img, *, figsize=(12, 6)):
    """
    Show an image inline in the notebook using matplotlib, without axes.

    Note: in an IPython/Jupyter session this definition shadows IPython's
    built-in ``display`` helper for the rest of the notebook.

    Args:
        img: Image to show; anything ``plt.imshow`` accepts (callers in this
            notebook pass a PIL Image).
        figsize: Figure size in inches as ``(width, height)``. Defaults to
            ``(12, 6)``, the previously hard-coded value.
    """
    plt.figure(figsize=figsize)
    plt.imshow(img)
    plt.axis("off")
    plt.tight_layout()
    plt.show()


# Device selection. The notebook is pinned to CPU by default; set USE_ACCELERATOR
# to True to prefer Apple-Silicon MPS, then CUDA, falling back to CPU.
USE_ACCELERATOR = False
if USE_ACCELERATOR and torch.backends.mps.is_available():
    device = "mps"
elif USE_ACCELERATOR and torch.cuda.is_available() and torch.cuda.device_count() > 0:
    device = "cuda:0"
else:
    device = "cpu"
print(f"Using device: {device}")  # Debug output to confirm device selection

model = AnnotatedCroCo(img_size=224, patch_size=16, pos_embed="RoPE100").to(device)
loss_fn = AnnotatedMaskedMSE(norm_pix_loss=True, masked=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

# ImageNet normalization constants; the (1, 3, 1, 1) views broadcast over (B, 3, H, W) batches.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1, 3, 1, 1).to(device, non_blocking=True)
imagenet_std = [0.229, 0.224, 0.225]
imagenet_std_tensor = torch.tensor(imagenet_std).view(1, 3, 1, 1).to(device, non_blocking=True)
trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)])


def _load_image(path):
    """Load an image as RGB, ImageNet-normalize it, move it to `device`, and add a batch dim -> (1, 3, H, W)."""
    # `with` releases the file handle deterministically; convert() returns an
    # independent, fully loaded copy, so it stays valid after the file is closed.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return trfs(rgb).to(device, non_blocking=True).unsqueeze(0)


# Sample inputs: image1 is the view that gets masked and reconstructed, image2 the reference view.
image1 = _load_image("data/Chateau1.png")
image2 = _load_image("data/Chateau2.png")


# Overfit the model on the single image pair until the training loss is no longer
# above `loss_threshold` (this also stops on a NaN loss) or `max_steps` optimizer
# steps have been taken. Every `plot_every` steps, show a reconstruction.
max_steps = 5000
plot_every = 10
loss_threshold = 0.01
loss_value = math.inf

# The context manager closes the progress bar even if an exception interrupts training.
with tqdm(desc="Training", leave=True, total=max_steps) as pbar:
    # range(max_steps) gives exactly max_steps iterations; the previous
    # `step = -1; while step < max_steps` ran one extra step and overran `total`.
    for step in range(max_steps):
        # Forward pass
        out, mask, target = model(image1, image2)
        loss = loss_fn(out, mask, target)

        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Update loss value and progress bar
        loss_value = loss.item()
        pbar.set_postfix({"loss": f"{loss_value:.4f}"})
        pbar.update(1)

        # Generate and display a visualization every `plot_every` steps (step 0 included)
        if step % plot_every == 0:
            # Fresh forward pass without autograd, used only for visualization
            with torch.inference_mode():
                out, mask, target = model(image1, image2)

            step_visualization = visualize_croco_output(
                model, out, mask, image1, image2, imagenet_std_tensor, imagenet_mean_tensor
            )
            display(step_visualization)
            print(f"Step {step}, Loss: {loss_value:.4f}")

        # Same stopping rule as `while loss_value > threshold`: a NaN loss also stops training.
        if not loss_value > loss_threshold:
            break