In [None]:
%load_ext autoreload
# Re-import edited project modules (e.g. the `src` package imported below) before
# each cell executes, so library changes show up without restarting the kernel.
%autoreload 2

import os
from pathlib import Path

# Change to project root so relative paths ("output/runs/...") and the `src`
# package resolve. The guard keeps this cell idempotent: a bare os.chdir("..")
# would move one level further up every time the cell is re-run.
# NOTE(review): assumes the project root is the directory containing `src/`.
if not Path("src").is_dir():
    os.chdir("..")


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from dotenv import load_dotenv

from src.train_state import TrainState

load_dotenv()

# ============ CONFIGURATION ============
# Change this to your run directory (relative to the project root)
RUN_DIRECTORY = "output/runs/simple_cnn@SimpleCNN_CIFAR10_20251201_214440"
# ======================================

# Fail fast with an actionable message rather than deep inside checkpoint loading.
if not Path(RUN_DIRECTORY).is_dir():
    raise FileNotFoundError(
        f"Run directory not found: {Path(RUN_DIRECTORY).resolve()} "
        "(edit RUN_DIRECTORY in this cell)"
    )

print(f"Loading model from: {RUN_DIRECTORY}")

In [None]:
# Restore the trained model from the run directory. epoch=None selects the best
# checkpoint (best_model.pt by default) instead of a specific epoch.
train_state = TrainState.from_checkpoint_dir(
    run_directory=Path(RUN_DIRECTORY),
    epoch=None,
    rank=0,
    local_rank=0,
    world_size=1,  # single-process evaluation
)

model = train_state.model
model.eval()  # inference behavior for layers such as dropout / batch-norm, if present
val_loader = train_state.val_loader

print("Model loaded successfully!")
print(f"Dataset: {train_state.config.dataset.name}")
print(f"Architecture: {train_state.config.arch.model}")

In [None]:
# Run the model once over the whole validation set, keeping (on the CPU) every
# input image, its label, the predicted class and that class's softmax
# probability. The cells below index into these flat tensors.
print("Getting predictions on validation set...")

device = train_state.device
image_batches, label_batches, prediction_batches, confidence_batches = [], [], [], []

with torch.no_grad():
    for batch_images, batch_labels in val_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        class_probs = F.softmax(model(batch_images), dim=1)
        # Largest probability per sample, and the class index it belongs to.
        batch_confidences, batch_predictions = class_probs.max(dim=1)

        image_batches.append(batch_images.cpu())
        label_batches.append(batch_labels.cpu())
        prediction_batches.append(batch_predictions.cpu())
        confidence_batches.append(batch_confidences.cpu())

all_images, all_labels, all_predictions, all_confidences = (
    torch.cat(batches, dim=0)
    for batches in (image_batches, label_batches, prediction_batches, confidence_batches)
)
# Drop the per-batch lists so the validation data is not held in memory twice.
del image_batches, label_batches, prediction_batches, confidence_batches

# Top-1 accuracy over the full validation set.
correct = (all_predictions == all_labels).sum().item()
total = len(all_labels)
accuracy = 100.0 * correct / total

print(f"Validation Accuracy: {accuracy:.2f}%")
print(f"Total samples: {total}")

In [None]:
def most_confident(indices, confidences, k=5):
    """Return the entries of ``indices`` paired with the ``k`` largest ``confidences``.

    ``indices`` and ``confidences`` are position-aligned tensors. The result is
    ordered from most to least confident and is shorter than ``k`` when fewer
    entries are available.
    """
    ranking = torch.argsort(confidences, descending=True)
    return indices[ranking[:k]]


# Partition validation samples by whether the model classified them correctly.
correct_mask = all_predictions == all_labels
incorrect_mask = ~correct_mask

correct_indices = correct_mask.nonzero(as_tuple=True)[0]
incorrect_indices = incorrect_mask.nonzero(as_tuple=True)[0]

correct_confidences = all_confidences[correct_indices]
incorrect_confidences = all_confidences[incorrect_indices]

# The 5 most confident samples in each group, highest confidence first.
top5_correct_indices = most_confident(correct_indices, correct_confidences)
top5_incorrect_indices = most_confident(incorrect_indices, incorrect_confidences)

print(f"Found {len(correct_indices)} correct predictions")
print(f"Found {len(incorrect_indices)} incorrect predictions")

In [None]:
# Human-readable label names taken from the validation dataset; an integer
# class id is used as a position into this sequence when plotting.
class_names = train_state.val_dataset.classes
num_classes = len(class_names)
print(f"Dataset has {num_classes} classes")


def denormalize_image(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)):
    """Undo per-channel normalization so an image tensor can be displayed.

    Computes ``img[c] * std[c] + mean[c]`` for every channel ``c`` on a copy of
    ``img`` and clamps the result to ``[0, 1]``; the input is not modified.
    Defaults are tuples (not lists) to avoid the shared-mutable-default pitfall.

    Args:
        img: Channel-first image tensor (``img[c]`` is channel ``c``), e.g. one
            entry of ``all_images`` before it is permuted to HWC for plotting.
        mean: Per-channel means used when the image was normalized. The
            defaults look like commonly used CIFAR-10 statistics — TODO confirm
            they match the dataset's validation transform.
        std: Per-channel standard deviations; must have the same length as ``mean``.

    Returns:
        A new tensor with the same shape as ``img``, clamped to ``[0, 1]``.

    Raises:
        ValueError: if ``mean`` and ``std`` differ in length, or ``img`` has more
            channels than statistics were supplied (``zip`` would otherwise
            silently leave the extra channels un-denormalized).
    """
    if len(mean) != len(std):
        raise ValueError(
            f"mean and std must have the same length, got {len(mean)} and {len(std)}"
        )
    if img.shape[0] > len(mean):
        raise ValueError(
            f"image has {img.shape[0]} channels but only {len(mean)} mean/std values were given"
        )

    img = img.clone()
    for channel, m, s in zip(img, mean, std):
        channel.mul_(s).add_(m)  # in place on a view of the clone, not the caller's tensor
    return torch.clamp(img, 0, 1)


def plot_predictions(indices, title, max_images=5):
    """Plot up to ``max_images`` validation images with their predictions.

    Args:
        indices: 1-D tensor (or sequence of ints / 0-d tensors) of positions into
            the notebook-level ``all_images`` / ``all_labels`` /
            ``all_predictions`` / ``all_confidences`` tensors. Only the first
            ``max_images`` entries are shown.
        title: Figure-level title.
        max_images: Maximum number of panels (default 5, the previous fixed count).

    Correct predictions get a green panel title with the class and confidence;
    incorrect ones a red title with true class, predicted class and confidence.
    Fewer than ``max_images`` indices produce fewer panels instead of an
    ``IndexError``; an empty ``indices`` prints a message and plots nothing.
    """
    n_panels = min(len(indices), max_images)
    if n_panels == 0:
        print(f"{title}: no samples to plot")
        return

    # squeeze=False keeps `axes` 2-D even for a single panel, so it is always iterable.
    fig, axes = plt.subplots(1, n_panels, figsize=(3 * n_panels, 3), squeeze=False)
    fig.suptitle(title, fontsize=16, y=1.05)

    for ax, raw_idx in zip(axes[0], indices[:n_panels]):
        img_idx = int(raw_idx)

        # Undo normalization and convert CHW -> HWC for imshow.
        img = denormalize_image(all_images[img_idx]).permute(1, 2, 0).numpy()

        true_label = all_labels[img_idx].item()
        pred_label = all_predictions[img_idx].item()
        confidence = all_confidences[img_idx].item()

        ax.imshow(img)
        ax.axis("off")

        true_name = class_names[true_label]
        pred_name = class_names[pred_label]

        if true_label == pred_label:
            color = "green"
            title_text = f"✓ {true_name}\nConf: {confidence:.2%}"
        else:
            color = "red"
            title_text = (
                f"✗ True: {true_name}\nPred: {pred_name}\nConf: {confidence:.2%}"
            )

        ax.set_title(title_text, fontsize=10, color=color)

    plt.tight_layout()
    plt.show()

In [None]:
# Gallery of the five correct predictions the model was most sure about.
plot_predictions(
    top5_correct_indices,
    title="5 Most Confident CORRECT Predictions",
)

In [None]:
# Gallery of the five misclassifications the model was most sure about.
plot_predictions(
    top5_incorrect_indices,
    title="5 Most Confident INCORRECT Predictions",
)

In [None]:
def print_confidence_stats(label, confidences: torch.Tensor) -> None:
    """Print count and mean/max/min confidence for one group of predictions.

    An empty group (e.g. a model with no misclassifications) only reports its
    count, because ``Tensor.max()`` / ``Tensor.min()`` raise on empty tensors.
    """
    print(f"\n{label} Predictions:")
    print(f"  Count: {len(confidences)}")
    if len(confidences) == 0:
        print("  (no samples)")
        return
    print(f"  Mean Confidence: {confidences.mean().item():.2%}")
    print(f"  Max Confidence: {confidences.max().item():.2%}")
    print(f"  Min Confidence: {confidences.min().item():.2%}")


def plot_confidence_hist(ax, confidences: torch.Tensor, label, color, mean_color) -> None:
    """Draw a 50-bin confidence histogram on ``ax`` with a dashed line at the mean.

    An empty group gets a "No samples" note instead of a histogram.
    """
    ax.set_title(f"Confidence Distribution - {label} Predictions")
    ax.set_xlabel("Confidence")
    ax.set_ylabel("Count")
    if len(confidences) == 0:
        ax.text(0.5, 0.5, "No samples", ha="center", va="center", transform=ax.transAxes)
        return

    mean_confidence = confidences.mean().item()
    ax.hist(confidences.numpy(), bins=50, alpha=0.7, color=color, edgecolor="black")
    ax.axvline(
        mean_confidence,
        color=mean_color,
        linestyle="--",
        linewidth=2,
        label=f"Mean: {mean_confidence:.2%}",
    )
    ax.legend()


# Confidence statistics
print("=" * 60)
print("CONFIDENCE STATISTICS")
print("=" * 60)

print_confidence_stats("Correct", correct_confidences)
print_confidence_stats("Incorrect", incorrect_confidences)

# Plot confidence distributions side by side
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
plot_confidence_hist(axes[0], correct_confidences, "Correct", "green", "darkgreen")
plot_confidence_hist(axes[1], incorrect_confidences, "Incorrect", "red", "darkred")

plt.tight_layout()
plt.show()