In [1]:
# Run this cell to install required libraries
!pip install transformers torch pandas pillow scikit-learn bert_score tqdm


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import torch
import pandas as pd
import numpy as np
import os
from PIL import Image
from tqdm.notebook import tqdm
from transformers import CLIPModel, CLIPProcessor
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import average_precision_score
from bert_score import score as bert_scorer
import warnings

# Suppress warnings
warnings.filterwarnings("ignore")

In [3]:
# --- Configuration ---
# The paths default to the original PSC directory structure. Each one can be
# overridden through an environment variable so the notebook can run on another
# machine without editing this cell.
IMAGE_DIR = os.environ.get(
    'CLIP_EVAL_IMAGE_DIR',
    '/ocean/projects/cis250019p/gandotra/11785-gp-eeg/images',
)
CAPTIONS_FILE = os.environ.get(
    'CLIP_EVAL_CAPTIONS_FILE',
    '/ocean/projects/cis250019p/gandotra/11785-gp-eeg/captions.txt',
)
MODEL_ID = 'openai/clip-vit-base-patch32'
BATCH_SIZE = 128  # Adjust based on your GPU memory

# The tokenizer is used before the image DataLoader forks its worker processes,
# which makes huggingface/tokenizers print a "process just got forked" warning
# for every worker. Setting the variable explicitly (as that warning suggests)
# avoids it; setdefault keeps any value the user has already exported.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# --- Device Setup ---
# Prefer CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
# The processor bundles text tokenization and image preprocessing for MODEL_ID
processor = CLIPProcessor.from_pretrained(MODEL_ID)

# Fetch the pretrained weights, place them on the selected device, and switch
# to evaluation mode (disables dropout, etc.)
model = CLIPModel.from_pretrained(MODEL_ID)
model = model.to(device)
model = model.eval()

print("CLIP model and processor loaded.")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


CLIP model and processor loaded.


In [5]:
print(f"Loading captions from: {CAPTIONS_FILE}")


def _read_unique_captions(path):
    """Parse the tab-separated captions file into ``{image_name: (caption, category)}``.

    The first line is treated as a header and skipped. Each remaining line is
    split on tabs; lines with fewer than four fields are ignored. Field 1 is
    the category, field 2 the image name, and fields 3 onward (re-joined with
    tabs) form the caption. When an image name occurs more than once, the last
    occurrence wins, so the result holds one entry per unique image.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    unique_rows = {}
    # Explicit encoding so parsing does not depend on the machine's locale.
    with open(path, 'r', encoding='utf-8') as f:
        next(f, None)  # Skip header; the default avoids StopIteration on an empty file

        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 4:
                continue
            category = parts[1]
            img_name = parts[2]
            caption = '\t'.join(parts[3:]).strip()
            unique_rows[img_name] = (caption, category)
    return unique_rows


try:
    captions_data = _read_unique_captions(CAPTIONS_FILE)
except Exception as e:
    print(f"Error loading {CAPTIONS_FILE}: {e}")
    print("Please ensure the file path is correct and it is a TAB-separated file.")
    # Re-raise: continuing would only fail later with an unrelated NameError on N.
    raise

if not captions_data:
    raise ValueError(
        f"No caption rows with at least 4 tab-separated fields were found in {CAPTIONS_FILE}"
    )

# Convert the unique dictionary items into parallel lists.
# Dict iteration order is fixed, so index i refers to the same image in all three.
all_image_names = list(captions_data)
all_captions = [caption for caption, _ in captions_data.values()]
all_categories = [category for _, category in captions_data.values()]

N = len(all_captions)
unique_categories = sorted(set(all_categories))

# Create mappings for class-aware metrics
category_to_idx = {cat: i for i, cat in enumerate(unique_categories)}
all_category_indices = [category_to_idx[cat] for cat in all_categories]

print(f"Loaded {N} unique image/caption pairs.")
print(f"Found {len(unique_categories)} unique categories.")

Loading captions from: /ocean/projects/cis250019p/gandotra/11785-gp-eeg/captions.txt
Loaded 9825 unique image/caption pairs.
Found 20 unique categories.


In [6]:
all_text_embeds = []
batch_starts = range(0, N, BATCH_SIZE)

with torch.no_grad():
    # Walk the caption list one batch at a time
    for start in tqdm(batch_starts, desc="Generating text embeddings"):
        caption_batch = all_captions[start : start + BATCH_SIZE]

        # Tokenize (pad to the longest caption in the batch, truncate at 77 tokens)
        encoded = processor(
            text=caption_batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,
        )
        encoded = encoded.to(device)

        # Project into the text embedding space, then scale each row to unit L2 norm
        raw_text_feats = model.get_text_features(**encoded)
        unit_text_feats = F.normalize(raw_text_feats, p=2, dim=-1)

        # Keep results on the CPU so GPU memory is not held across batches
        all_text_embeds.append(unit_text_feats.cpu())

# Stack the per-batch results into one tensor
text_embeds = torch.cat(all_text_embeds, dim=0)
print(f"Text embeddings shape: {text_embeds.shape}")

Generating text embeddings:   0%|          | 0/77 [00:00<?, ?it/s]

Text embeddings shape: torch.Size([9825, 512])


In [7]:
class ImageDataset(Dataset):
    """Dataset that yields one preprocessed image tensor per image name.

    Args:
        image_names: sequence of image base names (with or without extension).
        image_dir: directory that is searched for the image files.
        processor: callable invoked as ``processor(images=..., return_tensors="pt")``
            whose result is indexed with ``'pixel_values'``.

    Items that cannot be found or loaded are replaced by an all-zero tensor so
    that batch order and length are preserved.
    NOTE(review): such placeholders are embedded like real images by whatever
    consumes this dataset; watch the printed warnings when reading the metrics.
    """

    # Shape of the placeholder returned for missing/unreadable images.
    # Presumably matches the processor's output size for the model in use —
    # TODO confirm if the model is changed.
    _PLACEHOLDER_SHAPE = (3, 224, 224)

    def __init__(self, image_names, image_dir, processor):
        self.image_names = image_names
        self.image_dir = image_dir
        self.processor = processor
        # Extensions tried, in order, when resolving a base name to a file
        self.extensions = ['.jpg', '.jpeg', '.png', '.JPEG']

    def __len__(self):
        """Return the number of image names."""
        return len(self.image_names)

    def _placeholder(self):
        """Return the zero tensor used in place of a missing or unreadable image."""
        return torch.zeros(self._PLACEHOLDER_SHAPE)

    def _find_image_path(self, img_base_name):
        """Resolve a base name to an existing file path, or None.

        Tries ``image_dir/img_base_name + ext`` for each known extension first,
        then ``image_dir/img_base_name`` itself (for names that already carry
        an extension).
        """
        for ext in self.extensions:
            img_path = os.path.join(self.image_dir, img_base_name + ext)
            if os.path.exists(img_path):
                return img_path
        img_path = os.path.join(self.image_dir, img_base_name)
        if os.path.exists(img_path):
            return img_path
        return None

    def __getitem__(self, idx):
        """Return the processor's pixel values for item ``idx``.

        The leading dimension of the processor output is squeezed away. If the
        file is missing, or opening/processing raises, a warning is printed and
        a zero tensor of shape ``_PLACEHOLDER_SHAPE`` is returned instead.
        """
        img_base_name = self.image_names[idx]
        img_path = self._find_image_path(img_base_name)

        if img_path is None:
            print(f"Warning: Could not find image file for {img_base_name}")
            return self._placeholder()

        try:
            # The context manager closes the underlying file deterministically;
            # convert() returns an independent copy that outlives the handle.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
            processed = self.processor(images=image, return_tensors="pt")
            return processed['pixel_values'].squeeze(0)
        except Exception as e:
            # Deliberately best-effort: one corrupt file should not abort the run.
            print(f"Warning: Could not load image {img_path}. Error: {e}")
            return self._placeholder()

# Build the dataset and its loader. Shuffling stays off so that embedding row i
# corresponds to all_image_names[i].
image_dataset = ImageDataset(all_image_names, IMAGE_DIR, processor)
image_loader = DataLoader(
    image_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

all_image_embeds = []
with torch.no_grad():
    for pixel_batch in tqdm(image_loader, desc="Generating image embeddings"):
        pixel_batch = pixel_batch.to(device)

        # Project into the image embedding space
        raw_image_feats = model.get_image_features(pixel_values=pixel_batch)

        # Scale each row to unit L2 norm and move it back to the CPU
        unit_image_feats = F.normalize(raw_image_feats, p=2, dim=-1)
        all_image_embeds.append(unit_image_feats.cpu())

# Stack the per-batch results into one tensor
image_embeds = torch.cat(all_image_embeds, dim=0)
print(f"Image embeddings shape: {image_embeds.shape}")

Generating image embeddings:   0%|          | 0/77 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Image embeddings shape: torch.Size([9825, 512])


In [8]:
# Place both embedding sets on the compute device for the matrix product
image_embeds_gpu = image_embeds.to(device)
text_embeds_gpu = text_embeds.to(device)

print("Calculating similarity matrix...")
# Rows index images, columns index captions. Both inputs were L2-normalised
# when they were built, so each dot product is a cosine similarity.
similarity_matrix = torch.matmul(image_embeds_gpu, text_embeds_gpu.t()).cpu()

print(f"Similarity matrix shape: {similarity_matrix.shape}")

Calculating similarity matrix...
Similarity matrix shape: torch.Size([9825, 9825])


In [9]:
print("--- ðŸ“ˆ Recall@K Evaluation ---")


def _hits_within(match_matrix, k):
    """Count the rows that contain at least one True in their first k columns."""
    return match_matrix[:, :k].any(dim=1).float().sum()


# For every image (row), the column indices of its 5 highest-scoring captions
top_k_indices = torch.topk(similarity_matrix, k=5, dim=1).indices

# Image i is paired with caption i, so the ground-truth column equals the row index
gt_indices = torch.arange(N).unsqueeze(1)

# --- Instance-Level Recall@K ---
# True where a retrieved caption is exactly the image's own caption
instance_match = top_k_indices == gt_indices
r1_hits, r3_hits, r5_hits = (_hits_within(instance_match, k) for k in (1, 3, 5))

print(f"\nInstance-Level Recall (Exact Match):")
for k, hits in zip((1, 3, 5), (r1_hits, r3_hits, r5_hits)):
    print(f"Recall@{k}: {100 * hits / N:.2f}%")

# --- Class-Aware Recall@K ---
# A retrieval counts as a hit when the retrieved caption's category equals the
# category of the query image.
gt_categories = torch.tensor(all_category_indices, dtype=torch.long)
retrieved_categories = gt_categories[top_k_indices]

class_match = retrieved_categories == gt_categories.unsqueeze(1)
class_r1_hits, class_r3_hits, class_r5_hits = (
    _hits_within(class_match, k) for k in (1, 3, 5)
)

print(f"\nClass-Aware Recall (Correct Category):")
for k, hits in zip((1, 3, 5), (class_r1_hits, class_r3_hits, class_r5_hits)):
    print(f"Recall@{k}: {100 * hits / N:.2f}%")

--- ðŸ“ˆ Recall@K Evaluation ---

Instance-Level Recall (Exact Match):
Recall@1: 19.77%
Recall@3: 32.71%
Recall@5: 39.49%

Class-Aware Recall (Correct Category):
Recall@1: 96.99%
Recall@3: 98.25%
Recall@5: 98.51%


In [10]:
print("\n--- ðŸ“ˆ Mean Average Precision (MAP) Evaluation ---")

# NumPy views of the inputs, taken once instead of once per row
similarity_np = similarity_matrix.numpy()
gt_categories_np = gt_categories.numpy()

# Relevance labels are built one row at a time. A dense N x N integer label
# matrix would need N*N*8 bytes per matrix and is only ever read row by row.
aps_instance = []
aps_class = []
for i in range(N):
    y_score_i = similarity_np[i]

    # Instance level: the only relevant caption for image i is caption i
    y_true_instance_i = np.zeros(N, dtype=int)
    y_true_instance_i[i] = 1
    aps_instance.append(average_precision_score(y_true_instance_i, y_score_i))

    # Class level: every caption sharing image i's category is relevant
    y_true_class_i = (gt_categories_np == gt_categories_np[i]).astype(int)
    aps_class.append(average_precision_score(y_true_class_i, y_score_i))

# --- Caption-Level MAP (Instance) ---
map_instance = np.mean(aps_instance)
print(f"Caption-Level MAP (Instance): {map_instance:.4f}")

# --- Class-Aware MAP ---
map_class_aware = np.mean(aps_class)
print(f"Class-Aware MAP: {map_class_aware:.4f}")

# --- Per-Class MAP ---
# Average the class-aware AP of the images within each category, then report
# both the per-category values and their unweighted mean.
df_results = pd.DataFrame({
    'category': all_categories,
    'ap_class': aps_class
})
map_per_class = df_results.groupby('category')['ap_class'].mean()
print(f"\nPer-Class MAP (mean of class MAPs): {map_per_class.mean():.4f}")
print("Per-Class MAP (details):")
print(map_per_class)


--- ðŸ“ˆ Mean Average Precision (MAP) Evaluation ---
Caption-Level MAP (Instance): 0.2978
Class-Aware MAP: 0.8305

Per-Class MAP (mean of class MAPs): 0.8308
Per-Class MAP (details):
category
aeroplane      0.952461
bicycle        0.923959
bird           0.835170
boat           0.871740
bottle         0.906255
bus            0.927476
car            0.690873
cat            0.816477
chair          0.727544
cow            0.760897
diningtable    0.932633
dog            0.721748
flower         0.965426
horse          0.943866
motorbike      0.768518
person         0.532131
sheep          0.721595
sofa           0.741440
train          0.920891
tvmonitor      0.954820
Name: ap_class, dtype: float64


In [11]:
print("\n--- ðŸ“ˆ BERTScore Evaluation ---")

# Candidate = each image's top-ranked caption; reference = the image's own caption.
# Both lists follow the same image order.
top_1_indices = top_k_indices[:, 0].tolist()
retrieved_captions = [all_captions[idx] for idx in top_1_indices]
ground_truth_captions = all_captions

# The scoring model is named explicitly (original author's note: DistilBERT was
# chosen as a smaller, faster download than the default model).
bertscore_options = dict(
    model_type='distilbert-base-uncased',
    lang='en',
    verbose=True,
)
P, R, F1 = bert_scorer(retrieved_captions, ground_truth_captions, **bertscore_options)

mean_f1 = F1.mean()
print(f"\nAverage BERTScore F1 (Top-1 vs. GT): {mean_f1:.4f}")

# Fraction of images whose top-1 caption scores F1 above 0.7
success_rate = F1.gt(0.7).float().mean()
print(f"Success Rate (F1 > 0.7): {100 * success_rate:.2f}%")


--- ðŸ“ˆ BERTScore Evaluation ---


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

calculating scores...
computing bert embedding.


  0%|          | 0/138 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/154 [00:00<?, ?it/s]

done in 4.86 seconds, 2021.29 sentences/sec

Average BERTScore F1 (Top-1 vs. GT): 0.8874
Success Rate (F1 > 0.7): 99.97%


In [12]:
print("\n--- ðŸ“ˆ CLIPScore Distribution Analysis ---")


def _print_score_stats(scores):
    """Print mean, standard deviation, minimum and maximum of a score tensor."""
    print(f"  Mean: {scores.mean():.4f}")
    print(f"  Std:  {scores.std():.4f}")
    print(f"  Min:  {scores.min():.4f}")
    print(f"  Max:  {scores.max():.4f}")


# Matched pairs sit on the diagonal (image i with caption i)
matched_scores = torch.diag(similarity_matrix)

# Every off-diagonal entry is a mismatched pair; select them with a boolean mask
mask = torch.ones_like(similarity_matrix, dtype=torch.bool)
mask.fill_diagonal_(False)
mismatched_scores = similarity_matrix[mask]

for heading, scores in (
    (f"\nMatched (Correct) Pairs:", matched_scores),
    (f"\nMismatched (Incorrect) Pairs:", mismatched_scores),
):
    print(heading)
    _print_score_stats(scores)

print("\nAnalysis: A good separation between the matched and mismatched mean scores indicates")
print("that CLIP can effectively distinguish correct pairs from incorrect ones.")


--- ðŸ“ˆ CLIPScore Distribution Analysis ---

Matched (Correct) Pairs:
  Mean: 0.3021
  Std:  0.0304
  Min:  0.1528
  Max:  0.4579

Mismatched (Incorrect) Pairs:
  Mean: 0.1446
  Std:  0.0430
  Min:  -0.0691
  Max:  0.4197

Analysis: A good separation between the matched and mismatched mean scores indicates
that CLIP can effectively distinguish correct pairs from incorrect ones.
