In [11]:
import torch
import torch.nn as nn
import clip
from torchvision import datasets
from torch.utils.data import DataLoader, SubsetRandomSampler
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
import random
from collections import OrderedDict # To handle DataParallel state dict keys

In [47]:
# --- Parameters to set ---
# Path to the saved checkpoint file
# Make sure this matches the output of your training script (e.g., ending in _linear.pth or _finetuned.pth)
CHECKPOINT_PATH = '/fast/pmayilvahanan/aesthetic-predictor/clip_model_trained_finetuned.pth' # OR './clip_model_trained_finetuned.pth'
#CHECKPOINT_PATH = '/fast/pmayilvahanan/aesthetic-predictor/clip_model_trained_linear_readout.pth' # OR './clip_model_trained_finetuned.pth'

# CLIP model architecture used during training (must match the checkpoint)
CLIP_MODEL_NAME = "ViT-L/14" # e.g., "ViT-B/32", "ViT-L/14"

# Directory containing the test images (organized by class: good, neutral, bad)
DATA_DIR = '/fast/pmayilvahanan/aesthetic-predictor/data/images_by_rating'
TEST_DIR = os.path.join(DATA_DIR, 'test')

# Number of classes
NUM_CLASSES = 3

# Number of random images to visualize
NUM_IMAGES_TO_PLOT = 75

# Batch size for inference (can be larger than training)
BATCH_SIZE = 64

# Device to use
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ---

## Load CLIP Model and Preprocessor

In [48]:
print(f"Loading CLIP model: {CLIP_MODEL_NAME} on device: {DEVICE}")
try:
    # Load base CLIP model - always load this to the target device
    base_model, preprocess = clip.load(CLIP_MODEL_NAME, device=DEVICE)
except Exception as e:
    print(f"Error loading CLIP model: {e}")
    print("Please ensure internet connectivity and 'clip-by-openai' installation.")
    # Stop execution if CLIP fails to load
    raise

print("CLIP model loaded successfully.")
feature_dim = base_model.visual.output_dim
print(f"CLIP image feature dimension: {feature_dim}")

Loading CLIP model: ViT-L/14 on device: cuda
CLIP model loaded successfully.
CLIP image feature dimension: 768


In [49]:
# Determine if the checkpoint is from full fine-tuning or linear probe
is_finetuned = '_finetuned' in os.path.basename(CHECKPOINT_PATH)

# Define the structure where weights will be loaded
if is_finetuned:
    print("Loading checkpoint for a fully fine-tuned model.")
    # Recreate the combined model structure used during fine-tuning
    class CLIPWithClassifier(nn.Module):
        def __init__(self, clip_model, num_classes):
            super().__init__()
            self.clip_model = clip_model
            # Recreate the classifier head
            self.classifier = nn.Linear(clip_model.visual.output_dim, num_classes)

        def forward(self, images):
            # Ensure model components are on the correct device inside forward if needed
            # Although they should be moved beforehand
            features = self.clip_model.encode_image(images).float()
            return self.classifier(features)

    # Instantiate the combined model structure, loading the base CLIP model into it
    eval_model = CLIPWithClassifier(base_model, NUM_CLASSES)
else:
    print("Loading checkpoint for a linear readout model.")
    # Only need the linear classifier head
    eval_model = nn.Linear(feature_dim, NUM_CLASSES)
    # Base model will be used separately to get features
    base_model.eval() # Ensure base model is in eval mode

# Load the state dictionary
if not os.path.exists(CHECKPOINT_PATH):
    print(f"Error: Checkpoint file not found at {CHECKPOINT_PATH}")
    raise FileNotFoundError

print(f"Loading state dict from: {CHECKPOINT_PATH}")
try:
    state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu') # Load to CPU first

    # Handle potential DataParallel prefix ('module.')
    if all(key.startswith('module.') for key in state_dict.keys()):
        print("Detected DataParallel prefix in state_dict keys. Removing 'module.' prefix.")
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] # remove `module.`
            new_state_dict[name] = v
        state_dict = new_state_dict

    eval_model.load_state_dict(state_dict)
    print("State dict loaded successfully.")
except Exception as e:
    print(f"Error loading state dict: {e}")
    print("Ensure the checkpoint matches the specified CLIP model and training mode (linear/finetuned).")
    raise

# Move the evaluation model to the target device and set to eval mode
eval_model.to(DEVICE)
eval_model.eval()

Loading checkpoint for a fully fine-tuned model.
Loading state dict from: /fast/pmayilvahanan/aesthetic-predictor/clip_model_trained_finetuned.pth


  state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu') # Load to CPU first


State dict loaded successfully.


CLIPWithClassifier(
  (clip_model): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
            )
            (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
            )
            (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          )
          (1): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonD

In [50]:
if not os.path.isdir(TEST_DIR):
    print(f"Error: Test directory not found at {TEST_DIR}")
    # You might want to stop execution here
    raise FileNotFoundError

# We need two versions of the dataset:
# 1. With CLIP preprocessing for model inference
# 2. Without preprocessing (or just basic ToTensor) for visualization

# Dataset for inference
test_dataset_processed = datasets.ImageFolder(TEST_DIR, transform=preprocess)

# Dataset for visualization (load original images)
# We'll load images directly by path later for plotting
test_dataset_viz = datasets.ImageFolder(TEST_DIR, transform=None) # No transform initially

# Check if dataset loaded correctly
if len(test_dataset_processed) == 0:
    print(f"Error: No images found in {TEST_DIR}")
    raise ValueError

print(f"Loaded {len(test_dataset_processed)} images from {TEST_DIR}.")
# Get class names
class_names = test_dataset_processed.classes
print(f"Classes: {class_names}")

Loaded 141 images from /fast/pmayilvahanan/aesthetic-predictor/data/images_by_rating/test.
Classes: ['bad', 'good', 'neutral']


In [51]:
# Select random indices
num_test_images = len(test_dataset_processed)
num_to_select = min(NUM_IMAGES_TO_PLOT, num_test_images)

random_indices = random.sample(range(num_test_images), num_to_select)

selected_images_processed = []
selected_images_original = []
true_labels = []
image_paths = []

print(f"Selecting {num_to_select} random images for evaluation...")
for idx in random_indices:
    # Get processed image and label for model input
    processed_img, label_idx = test_dataset_processed[idx]
    selected_images_processed.append(processed_img)
    true_labels.append(class_names[label_idx])

    # Get original image path for visualization
    original_path, _ = test_dataset_viz.samples[idx]
    image_paths.append(original_path)
    try:
        # Load the original image using PIL
        original_img = Image.open(original_path).convert("RGB")
        selected_images_original.append(original_img)
    except Exception as e:
        print(f"Warning: Could not load original image {original_path} for visualization: {e}")
        selected_images_original.append(None) # Placeholder if loading fails

# Stack processed images into a batch
image_batch = torch.stack(selected_images_processed).to(DEVICE)

predicted_labels = []
print("Running inference...")
with torch.no_grad():
    if is_finetuned:
        # Pass directly through the combined fine-tuned model
        outputs = eval_model(image_batch)
    else:
        # Linear probe: get features first, then classify
        image_features = base_model.encode_image(image_batch).float()
        outputs = eval_model(image_features)

    # Get predicted class indices
    _, predicted_indices = torch.max(outputs, 1)
    predicted_labels = [class_names[i] for i in predicted_indices.cpu().numpy()]

print("Inference complete.")

Selecting 75 random images for evaluation...
Running inference...
Inference complete.


In [52]:
# Determine grid size for plotting
num_cols = 5
num_rows = (num_to_select + num_cols - 1) // num_cols

# Create the figure object
fig = plt.figure(figsize=(num_cols * 3, num_rows * 3.5)) # Adjust figsize as needed

for i in range(num_to_select):
    plt.subplot(num_rows, num_cols, i + 1)
    img_to_show = selected_images_original[i]

    if img_to_show:
        plt.imshow(img_to_show)
        title = f"True: {true_labels[i]}\nPred: {predicted_labels[i]}"
        # Color title red if prediction is wrong
        title_color = 'green' if true_labels[i] == predicted_labels[i] else 'red'
        plt.title(title, color=title_color, fontsize=9)
    else:
        # Handle cases where original image failed to load
        plt.text(0.5, 0.5, 'Image Load Failed', horizontalalignment='center', verticalalignment='center')
        plt.title(f"True: {true_labels[i]}\nPred: {predicted_labels[i]}", fontsize=9)

    plt.xticks([])
    plt.yticks([])

plt.tight_layout()

# --- Save the figure ---
output_filename = "/fast/pmayilvahanan/aesthetic-predictor/predictions_visualization_finetuned.png"
try:
    plt.savefig(output_filename, dpi=150) # Save the figure to a file (adjust dpi as needed)
    print(f"Plot saved to {output_filename}")
except Exception as e:
    print(f"Error saving plot: {e}")

# --- Close the plot figure ---
plt.close(fig) # Close the figure to free memory

# If you still want to display it inline in a notebook after saving, uncomment the next line
# plt.show()

Plot saved to /fast/pmayilvahanan/aesthetic-predictor/predictions_visualization_finetuned.png


In [53]:
np.sum(np.array(predicted_labels) == np.array(true_labels)) / len(true_labels)

0.7866666666666666