In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np

# Universal Device Selector: prefer CUDA, then Apple-Silicon MPS, otherwise CPU.
# On your Mac M4, this MUST print "Using Device: mps"
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using Device: {device}")

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) for reproducibility.
SEED = 42
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(SEED)

In [None]:
!pip install matplotlib

In [None]:
# Setup Paths
PROJECT_ROOT = os.getcwd()
FEAT_DIR = os.path.join(PROJECT_ROOT, "features")

print(f"Looking for features in: {FEAT_DIR}")

try:
    # Load the pre-extracted global image features, one tensor per split.
    # map_location=device places them on the device chosen in the first cell
    # (on Apple Silicon that is the MPS device).
    img_feats_train = torch.load(os.path.join(FEAT_DIR, "flickr30k_train_global.pt"), map_location=device)
    img_feats_val   = torch.load(os.path.join(FEAT_DIR, "flickr30k_val_global.pt"), map_location=device)
    img_feats_test  = torch.load(os.path.join(FEAT_DIR, "flickr30k_test_global.pt"), map_location=device)

    print("Features Loaded Successfully!")
    print(f"Train Shape: {img_feats_train.shape}")
    print(f"Val Shape:   {img_feats_val.shape}")
    print(f"Test Shape:  {img_feats_test.shape}")

except FileNotFoundError:
    print("\nERROR: Files not found.")
    print(f"I looked in: {FEAT_DIR}")
    print("Please create a folder named 'features' and put your .pt files inside it.")
    # Re-raise: every later cell needs these tensors, and swallowing the error
    # would only resurface as a confusing NameError further down the notebook.
    raise

In [None]:
print("Downloading Flickr30k text data from Hugging Face...")

# The parquet export of nlphuji/flickr30k is sharded into 0000.parquet ... 0008.parquet.
PARQUET_BASE_URL = "https://huggingface.co/datasets/nlphuji/flickr30k/resolve/refs%2Fconvert%2Fparquet/TEST/test"
DATA_FILES = [f"{PARQUET_BASE_URL}/{shard:04d}.parquet" for shard in range(9)]

# Load and Split
# Every row lands in load_dataset's default "train" split; the real train/val/test
# assignment lives in the "split" column.
raw_dataset = load_dataset("parquet", data_files=DATA_FILES, cache_dir="./hf_cache")["train"]

# input_columns="split" passes only that column to the predicate, so the other
# columns (presumably including the images in this export — verify) are not
# loaded/decoded for every row during filtering.
flickr = {
    "train": raw_dataset.filter(lambda split: split == "train", input_columns="split"),
    "validation": raw_dataset.filter(lambda split: split == "val", input_columns="split"),
    "test": raw_dataset.filter(lambda split: split == "test", input_columns="split"),
}

print(f"Text Loaded! Train: {len(flickr['train'])}, Val: {len(flickr['validation'])}, Test: {len(flickr['test'])}")

In [None]:
# Initialize Tokenizer
# Must match the "bert-base-uncased" checkpoint that the text encoders below load.
TOKENIZER_NAME = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

class FlickrDataset(Dataset):
    """Pairs pre-extracted image features with one tokenized caption per row.

    Row ``idx`` of ``hf_dataset`` is paired with ``img_feats[idx]``, so the
    feature tensor must have been written in the same row order as the split
    (TODO confirm against the feature-extraction script).

    Args:
        hf_dataset: dataset with a "caption" column holding a list of captions per row.
        img_feats: indexable features with one entry per dataset row.
        tokenizer: Hugging Face tokenizer (called with padding/truncation kwargs).
        max_len: pad/truncate length in tokens.
        random_cap: pick a random caption on each access (training) instead of
            always the first one (deterministic evaluation).

    Raises:
        ValueError: if ``img_feats`` and ``hf_dataset`` differ in length, which
            would otherwise silently misalign or crash mid-epoch.
    """

    def __init__(self, hf_dataset, img_feats, tokenizer, max_len=32, random_cap=True):
        if len(img_feats) != len(hf_dataset):
            raise ValueError(
                f"img_feats has {len(img_feats)} rows but hf_dataset has {len(hf_dataset)}; "
                "features and captions must be aligned one-to-one"
            )
        self.ds = hf_dataset
        # Read the caption column once. Indexing whole rows (self.ds[idx]) would
        # materialize every other column of the row (e.g. an image, if present) per access.
        self.captions = hf_dataset["caption"]
        self.img_feats = img_feats
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.random_cap = random_cap

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        img = self.img_feats[idx]
        captions = self.captions[idx]

        # Random caption during training helps the model learn better
        txt = random.choice(captions) if self.random_cap else captions[0]

        tok = self.tokenizer(
            txt,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )

        # Drop the leading batch dim the tokenizer adds for a single string.
        return {
            "img_feat": img,
            "input_ids": tok["input_ids"].squeeze(0),
            "attention_mask": tok["attention_mask"].squeeze(0)
        }

# Create DataLoaders (Batch Size 128 is good for M4)
BATCH_SIZE = 128


def _make_loader(hf_split, feats, *, train):
    """Wrap one split in a DataLoader; training uses random captions and shuffling."""
    return DataLoader(
        FlickrDataset(hf_split, feats, tokenizer, random_cap=train),
        batch_size=BATCH_SIZE,
        shuffle=train,
    )


train_loader = _make_loader(flickr['train'], img_feats_train, train=True)
val_loader = _make_loader(flickr['validation'], img_feats_val, train=False)
test_loader = _make_loader(flickr['test'], img_feats_test, train=False)

print("DataLoaders Ready.")

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel

class CrossAttentionInteractionModel(nn.Module):
    """Cross-encoder baseline: scores (image, caption) pairs with image-to-text cross-attention.

    The projected image feature is a length-1 query over the frozen BERT token
    sequence; the attended text context is concatenated with the image vector
    and an MLP produces one matching logit per pair.

    Args:
        img_dim: width of the input image feature.
        hidden_dim: width of the projected image / attention space.
        num_heads: attention heads (must divide ``hidden_dim``).
    """

    def __init__(self, img_dim=768, hidden_dim=768, num_heads=4):
        super().__init__()

        # 1. Text Encoder (Frozen)
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        for p in self.bert.parameters():
            p.requires_grad = False
        txt_dim = self.bert.config.hidden_size

        # 2. Image Projection
        self.img_proj = nn.Linear(img_dim, hidden_dim)

        # 3. Cross Attention: image query -> text keys/values.
        # batch_first=True gives (Batch, Seq, Dim) tensors. kdim/vdim let hidden_dim
        # differ from BERT's width; when they are equal (the default) PyTorch keeps
        # the same packed in-projection as before.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            kdim=txt_dim,
            vdim=txt_dim,
            batch_first=True,
        )

        # 4. Classifier over [image ; text_context], hence hidden_dim * 2 inputs.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 1)
        )

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen BERT in eval mode.

        BERT is never updated and runs under no_grad, so letting ``train()``
        enable its dropout would only add noise to the text features.
        """
        super().train(mode)
        self.bert.eval()
        return self

    def forward(self, img_feat, input_ids, mask):
        """Return one matching logit per (image, caption) row, shape (Batch,).

        ``mask`` is the tokenizer attention mask (0 = padding).
        """
        # --- A. Text Embeddings (frozen) ---
        with torch.no_grad():
            txt_seq = self.bert(input_ids=input_ids, attention_mask=mask).last_hidden_state

        # --- B. Image Embeddings: (Batch, 1, hidden_dim), a sequence of length 1 ---
        img_hidden = self.img_proj(img_feat).unsqueeze(1)

        # --- C. Cross Attention ---
        # nn.MultiheadAttention ignores keys where key_padding_mask is True,
        # i.e. the padding positions where the BERT mask is 0.
        key_padding_mask = (mask == 0)
        attn_out, _ = self.cross_attn(
            query=img_hidden,
            key=txt_seq,
            value=txt_seq,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )

        # --- D. Interaction: compare the image with the text context it attended to ---
        combined = torch.cat([img_hidden, attn_out], dim=-1).squeeze(1)  # (Batch, 2 * hidden_dim)

        # squeeze(-1) (not squeeze()) keeps the batch dim, so a batch of 1 yields shape (1,).
        return self.classifier(combined).squeeze(-1)

In [None]:
from tqdm.auto import tqdm  # notebook-aware progress bars, as in the other cells
import torch.nn.functional as F

# 1. Setup
# NOTE(review): this re-binds the global `device` to CUDA-or-CPU, overriding the
# MPS-aware choice made in the first cell, so on Apple Silicon training runs on
# the CPU. The evaluation cells use the same CUDA-or-CPU rule; change them together.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CrossAttentionInteractionModel().to(device)

NUM_EPOCHS = 5
GRAD_CLIP_NORM = 1.0  # prevents exploding gradients in the attention layers

# 2. Optimizer (Fresh start): only the trainable, non-BERT parameters.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)  # slightly safer LR for attention
criterion = nn.BCEWithLogitsLoss()

print("Starting Training (Corrected Logic)...")

# 3. Training Loop: binary match (1) / mismatch (0) classification.
for epoch in range(NUM_EPOCHS):
    epoch_loss = 0
    model.train()
    # BERT is frozen and queried under no_grad; keep it in eval mode so its
    # dropout does not inject noise into the text features.
    model.bert.eval()

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for batch in progress_bar:
        img_feat = batch['img_feat'].to(device)
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        batch_size = img_feat.size(0)

        optimizer.zero_grad()

        # --- Positive pairs (target 1.0): each caption with its own image ---
        pos_scores = model(img_feat, input_ids, mask)
        loss = criterion(pos_scores, torch.ones_like(pos_scores))

        # --- Negative pairs (target 0.0): each caption with another row's image ---
        # A random non-zero cyclic shift never leaves a row in place, so no caption
        # is paired with its own image and labelled 0 (torch.randperm allows such
        # fixed points). A single-row batch has no other image, so it gets no negative.
        if batch_size > 1:
            shift = int(torch.randint(1, batch_size, (1,)))
            neg_img_feat = img_feat.roll(shifts=shift, dims=0)
            neg_scores = model(neg_img_feat, input_ids, mask)
            loss = loss + criterion(neg_scores, torch.zeros_like(neg_scores))

        # --- Update ---
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, GRAD_CLIP_NORM)
        optimizer.step()

        epoch_loss += loss.item()

        # LIVE MONITORING
        progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})

    avg_loss = epoch_loss / len(train_loader)
    print(f"✅ Epoch {epoch+1} Done. Avg Loss: {avg_loss:.4f}")

In [None]:
# --- INVESTIGATION CELL ---
# Pull one training batch and check for degenerate inputs (all-zero features or
# token ids) that would keep any model from learning.
batch = next(iter(train_loader))
img_feat, input_ids = batch["img_feat"], batch["input_ids"]

print("--- DATA INSPECTION ---")
print(
    "Image Feature Stats: "
    f"Min={img_feat.min().item():.4f}, "
    f"Max={img_feat.max().item():.4f}, "
    f"Mean={img_feat.mean().item():.4f}"
)
print(f"Are Image Features all Zero? {bool((img_feat == 0).all())}")

print(f"Input IDs Stats: Min={input_ids.min().item()}, Max={input_ids.max().item()}")
print(f"Are Input IDs all Zero? {bool((input_ids == 0).all())}")

In [None]:
def get_fixed_test_subset(loader, device, num_samples=200):
    """Collect the first ``num_samples`` (image, caption) pairs from ``loader``.

    The result is a fixed evaluation set, so every baseline is scored on the
    same pairs; row i of "img" matches row i of "ids"/"mask".

    Args:
        loader: iterable of batches, either dicts with "img_feat", "input_ids"
            and "attention_mask", or (img_feat, input_ids, attention_mask) tuples.
        device: device the collected tensors are moved to.
        num_samples: maximum number of pairs to collect.

    Returns:
        dict with "img", "ids", "mask" tensors and "N", the number of pairs
        actually collected (fewer than ``num_samples`` if the loader runs out).

    Raises:
        ValueError: if the loader yields no samples.
    """
    all_imgs, all_ids, all_masks = [], [], []
    collected = 0
    print(f" Extracting fixed subset of {num_samples} samples...")

    for batch in loader:
        # Handle Dictionary vs Tuple batches
        if isinstance(batch, dict):
            img, ids, mask = batch['img_feat'], batch['input_ids'], batch['attention_mask']
        else:
            img, ids, mask = batch

        all_imgs.append(img)
        all_ids.append(ids)
        all_masks.append(mask)

        collected += img.size(0)
        if collected >= num_samples:
            break

    if collected == 0:
        raise ValueError("loader yielded no samples; cannot build an evaluation subset")

    # N must reflect what was really collected: downstream evaluation indexes 0..N-1.
    n = min(collected, num_samples)
    if n < num_samples:
        print(f"Warning: loader only provided {n} samples (requested {num_samples}).")

    subset = {
        "img": torch.cat(all_imgs)[:n].to(device),
        "ids": torch.cat(all_ids)[:n].to(device),
        "mask": torch.cat(all_masks)[:n].to(device),
        "N": n
    }

    print(f"Fixed Test Subset Ready. (N={n})")
    return subset

# Freeze one evaluation subset so every baseline below is scored on identical pairs.
test_subset = get_fixed_test_subset(loader=test_loader, device=device, num_samples=200)

In [None]:
def evaluate_cross_encoder(model, subset, device=None):
    """Text-to-image retrieval on a fixed subset with a pairwise scoring model.

    For each caption i, the caption is paired with every image in the subset,
    all N pairs are scored in one forward pass, and the rank of the
    ground-truth image i is recorded.

    NOTE(review): a later cell redefines this function with bidirectional
    (T2I + I2T) metrics; this definition only serves the call below.

    Args:
        model: module called as ``model(img_feat, input_ids, mask)`` returning
            one score per row; it is switched to eval mode.
        subset: dict with "img", "ids", "mask" tensors and "N".
        device: optional device to move the subset tensors to; ``None`` leaves
            them where they are.

    Returns:
        dict with "R@1", "R@5", "R@10" (percent) and "MRR".
    """
    model.eval()
    imgs, ids, masks = subset["img"], subset["ids"], subset["mask"]
    if device is not None:
        imgs, ids, masks = imgs.to(device), ids.to(device), masks.to(device)
    N = subset["N"]

    print(f"\n--- EVALUATING CROSS-ENCODER (Baseline 3) ---")

    r1, r5, r10, mrr = 0, 0, 0, 0.0

    with torch.no_grad():
        # Iterate through each text query
        for i in tqdm(range(N), desc="Ranking"):
            # Pair the i-th text with every image in the subset.
            query_ids = ids[i].unsqueeze(0).repeat(N, 1)
            query_mask = masks[i].unsqueeze(0).repeat(N, 1)
            scores = model(imgs, query_ids, query_mask).reshape(-1)

            # Pessimistic 1-based rank of the correct image (index i): count every
            # image scoring >= it, itself included. This equals the argsort rank
            # when there are no ties, but tied scores (e.g. a collapsed model) can
            # no longer inflate recall. A NaN score is ranked last.
            true_score = scores[i]
            true_rank = N if torch.isnan(true_score) else int((scores >= true_score).sum())

            if true_rank == 1: r1 += 1
            if true_rank <= 5: r5 += 1
            if true_rank <= 10: r10 += 1
            mrr += 1.0 / true_rank

    metrics = {"R@1": r1 / N * 100, "R@5": r5 / N * 100, "R@10": r10 / N * 100, "MRR": mrr / N}
    print(f"RESULTS (N={N}):")
    print(f"R@1: {metrics['R@1']:.2f}% | R@5: {metrics['R@5']:.2f}% | R@10: {metrics['R@10']:.2f}% | MRR: {metrics['MRR']:.4f}")
    return metrics

# --- ACTION: Run on Baseline 3 ---
evaluate_cross_encoder(model, test_subset)

In [None]:
import torch
from tqdm.auto import tqdm
import numpy as np

def evaluate_cross_encoder(model, subset, device=None):
    """Bidirectional retrieval (T2I and I2T) on a fixed subset with a pairwise scorer.

    A cross-encoder scores one (image, text) pair per row, so each query is
    paired with all N candidates and scored in a single forward pass.
    Pair i (image i, text i) is the ground truth.

    Args:
        model: module called as ``model(img_feat, input_ids, mask)`` returning
            one score per row; it is switched to eval mode.
        subset: dict with "img", "ids", "mask" tensors and "N".
        device: device to move the subset tensors to once before scoring;
            ``None`` leaves them where they are.

    Returns:
        ``{"t2i": {...}, "i2t": {...}}``, each with "R@1", "R@5", "R@10"
        (percent) and "MRR".
    """
    model.eval()

    # Move subset tensors to device ONCE for efficiency
    imgs, ids, masks = subset["img"], subset["ids"], subset["mask"]
    if device is not None:
        imgs, ids, masks = imgs.to(device), ids.to(device), masks.to(device)
    N = subset["N"]

    print(f"\n--- EVALUATING CROSS-ENCODER (Baseline 3, N={N}) ---")

    def repeat_row(t, i):
        # Row i of t, repeated N times along the batch dimension (any trailing shape).
        return t[i:i + 1].repeat_interleave(N, dim=0)

    def rank_of_true(scores, i):
        # Pessimistic 1-based rank of candidate i: count every candidate scoring
        # >= it, itself included. Equals the argsort rank without ties, but tied
        # scores (e.g. a collapsed model) can no longer inflate recall. NaN ranks last.
        scores = scores.reshape(-1)
        true_score = scores[i]
        if torch.isnan(true_score):
            return N
        return int((scores >= true_score).sum())

    def summarize(ranks):
        return {
            "R@1": sum(r <= 1 for r in ranks) / N * 100,
            "R@5": sum(r <= 5 for r in ranks) / N * 100,
            "R@10": sum(r <= 10 for r in ranks) / N * 100,
            "MRR": sum(1.0 / r for r in ranks) / N,
        }

    t2i_ranks, i2t_ranks = [], []
    with torch.no_grad():
        # --- 1. Text-to-Image (T2I): text i against all images; correct image is i ---
        print("\n[T2I] Ranking Text queries against all Images...")
        for i in tqdm(range(N), desc="T2I Ranking"):
            scores = model(imgs, repeat_row(ids, i), repeat_row(masks, i))
            t2i_ranks.append(rank_of_true(scores, i))

        # --- 2. Image-to-Text (I2T): image i against all texts; correct text is i ---
        print("\n[I2T] Ranking Image queries against all Texts...")
        for i in tqdm(range(N), desc="I2T Ranking"):
            scores = model(repeat_row(imgs, i), ids, masks)
            i2t_ranks.append(rank_of_true(scores, i))

    t2i, i2t = summarize(t2i_ranks), summarize(i2t_ranks)

    # --- 3. Final Output ---
    print("\n" + "="*40)
    print(f"     FINAL RETRIEVAL RESULTS (N={N})")
    print("="*40)

    print("--- Text to Image (T2I) ---")
    print(f"R@1:  {t2i['R@1']:.2f}% | R@5: {t2i['R@5']:.2f}% | R@10: {t2i['R@10']:.2f}%")
    print(f"MRR:  {t2i['MRR']:.4f}")

    print("\n--- Image to Text (I2T) ---")
    print(f"R@1:  {i2t['R@1']:.2f}% | R@5: {i2t['R@5']:.2f}% | R@10: {i2t['R@10']:.2f}%")
    print(f"MRR:  {i2t['MRR']:.4f}")
    print("="*40)

    return {"t2i": t2i, "i2t": i2t}


# --- ACTION: Run on Baseline 3 ---
# Score on whatever device holds the model's weights, so inputs always match it.
DEVICE = next(model.parameters()).device
evaluate_cross_encoder(model, test_subset, DEVICE)

In [None]:
#### Baseline 2: CLIP-style dual encoder

In [None]:
import torch.nn.functional as F

class CLIPDualEncoder(nn.Module):
    """CLIP-style dual encoder: frozen BERT text tower plus trainable MLP projections.

    Both towers map into a shared ``embed_dim`` space and L2-normalize their
    output, so an image/text dot product is a cosine similarity.

    Args:
        img_dim: width of the input image feature.
        hidden_dim: hidden width of both projection MLPs.
        embed_dim: width of the shared embedding space.
    """

    def __init__(self, img_dim=768, hidden_dim=768, embed_dim=256):
        super().__init__()

        # --- TOWER 1: TEXT ---
        # Frozen BERT
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        for p in self.bert.parameters():
            p.requires_grad = False
        txt_dim = self.bert.config.hidden_size  # 768 for bert-base-uncased

        # Text Projector: BERT -> Shared Space
        self.txt_proj = nn.Sequential(
            nn.Linear(txt_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim)
        )

        # --- TOWER 2: IMAGE ---
        # Image Projector: ImgFeat -> Shared Space
        self.img_proj = nn.Sequential(
            nn.Linear(img_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim)
        )

        # Learnable log inverse temperature; 2.65 ~= ln(1/0.07), CLIP's initial value.
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.65)

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen BERT in eval mode.

        BERT is never updated and runs under no_grad, so letting ``train()``
        enable its dropout would only add noise to the text embeddings.
        """
        super().train(mode)
        self.bert.eval()
        return self

    def encode_text(self, input_ids, mask):
        """Mean-pool BERT tokens over the attention mask, project and L2-normalize."""
        # 1. BERT (frozen)
        with torch.no_grad():
            bert_out = self.bert(input_ids=input_ids, attention_mask=mask).last_hidden_state

        # 2. Masked mean pooling over real (non-padding) tokens; float weights avoid
        # integer arithmetic on the tokenizer's integer mask.
        weights = mask.unsqueeze(-1).to(bert_out.dtype)
        pooled = (bert_out * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

        # 3. Project & Normalize
        return F.normalize(self.txt_proj(pooled), dim=-1)

    def encode_image(self, img_feat):
        """Project image features into the shared space and L2-normalize."""
        return F.normalize(self.img_proj(img_feat), dim=-1)

    def forward(self, img_feat, input_ids, mask):
        """Return ``(img_emb, txt_emb)``, both L2-normalized."""
        img_emb = self.encode_image(img_feat)
        txt_emb = self.encode_text(input_ids, mask)
        return img_emb, txt_emb

print("CLIPDualEncoder Class Defined.")

In [None]:
import math

# Initialize Model
clip_model = CLIPDualEncoder().to(device)

# Optimizer: only the trainable parameters (projection layers and logit_scale).
optimizer = torch.optim.AdamW(
    [p for p in clip_model.parameters() if p.requires_grad],
    lr=1e-3
)

# Loss Function: Standard Cross Entropy (InfoNCE over the in-batch similarity matrix)
criterion = nn.CrossEntropyLoss()

# CLIP caps the learned logit scale at 100 (log-space ln(100)) to keep training stable.
MAX_LOGIT_SCALE = math.log(100)

num_epochs = 5
print("Starting CLIP Training...")

for epoch in range(num_epochs):
    epoch_loss = 0
    clip_model.train()
    # Frozen BERT runs under no_grad; keep its dropout off.
    clip_model.bert.eval()

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in progress_bar:
        # --- 1. Extract Data (dict batches, or (img_feat, input_ids, attention_mask) tuples) ---
        if isinstance(batch, dict):
            img_feat, input_ids, mask = batch['img_feat'], batch['input_ids'], batch['attention_mask']
        else:
            img_feat, input_ids, mask = batch
        img_feat = img_feat.to(device)
        input_ids = input_ids.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()

        # --- 2. Forward Pass: normalized embeddings ---
        img_emb, txt_emb = clip_model(img_feat, input_ids, mask)

        # --- 3. Contrastive Loss ---
        # [Batch, Dim] @ [Dim, Batch] -> [Batch, Batch]: a score for every pair in the batch.
        logits = torch.matmul(img_emb, txt_emb.t()) * clip_model.logit_scale.exp()

        # The correct match for Image k is Text k.
        batch_size = img_feat.size(0)
        labels = torch.arange(batch_size, device=device)

        # Symmetric Loss: (Image->Text) + (Text->Image)
        loss_i2t = criterion(logits, labels)
        loss_t2i = criterion(logits.t(), labels)
        loss = (loss_i2t + loss_t2i) / 2

        # --- 4. Backprop ---
        loss.backward()
        optimizer.step()

        # Keep the temperature in CLIP's range (scale between 1 and 100).
        with torch.no_grad():
            clip_model.logit_scale.clamp_(0, MAX_LOGIT_SCALE)

        epoch_loss += loss.item()
        progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})

    avg_loss = epoch_loss / len(train_loader)
    print(f"✅ Epoch {epoch+1} Complete. Average Loss: {avg_loss:.4f}")

In [None]:
import torch
from tqdm.auto import tqdm
import numpy as np

def evaluate_clip_on_fixed_set(model, subset, device=None):
    """Bidirectional retrieval metrics for a dual-encoder model on a fixed subset.

    Images and captions are encoded once each and a similarity matrix
    ``S[i, j] = <text_i, image_j>`` is formed; the ground truth for text i is
    image i, and vice versa.

    Args:
        model: object exposing ``eval()``, ``encode_image(img)`` and
            ``encode_text(ids, mask)``.
        subset: dict with "img", "ids", "mask" tensors and "N" (the same subset
            used for the cross-encoder).
        device: optional device to move the subset tensors to; ``None`` leaves
            them where they are.

    Returns:
        ``{"t2i": {...}, "i2t": {...}}``, each with "R@1", "R@5", "R@10"
        (percent) and "MRR".
    """
    model.eval()

    imgs, ids, masks = subset["img"], subset["ids"], subset["mask"]
    if device is not None:
        imgs, ids, masks = imgs.to(device), ids.to(device), masks.to(device)
    N = subset["N"]

    print(f"\n--- EVALUATING CLIP (Baseline 2) on Fixed Subset (N={N}) ---")

    with torch.no_grad():
        img_embs = model.encode_image(imgs)        # [N, Dim]
        txt_embs = model.encode_text(ids, masks)   # [N, Dim]
        sim_matrix = torch.matmul(txt_embs, img_embs.T)  # [N, N]: row = text, column = image

    # Pessimistic 1-based ranks: count candidates scoring >= the true pair (itself
    # included). Equal to the argsort rank without ties, but ties (e.g. collapsed
    # embeddings) can no longer inflate recall. A NaN true score is ranked last.
    true_scores = sim_matrix.diagonal()
    invalid = torch.isnan(true_scores)
    # T2I: text i (row i) against all images; I2T: image j (column j) against all texts.
    t2i_ranks = (sim_matrix >= true_scores.unsqueeze(1)).sum(dim=1).masked_fill(invalid, N)
    i2t_ranks = (sim_matrix >= true_scores.unsqueeze(0)).sum(dim=0).masked_fill(invalid, N)

    def summarize(ranks):
        ranks = ranks.double()
        return {
            "R@1": 100.0 * (ranks <= 1).double().mean().item(),
            "R@5": 100.0 * (ranks <= 5).double().mean().item(),
            "R@10": 100.0 * (ranks <= 10).double().mean().item(),
            "MRR": ranks.reciprocal().mean().item(),
        }

    t2i, i2t = summarize(t2i_ranks), summarize(i2t_ranks)

    print("\n" + "="*40)
    print(f"     CLIP RETRIEVAL RESULTS (N={N})")
    print("="*40)

    print("--- Text to Image (T2I) ---")
    print(f"R@1:  {t2i['R@1']:.2f}% | R@5: {t2i['R@5']:.2f}% | R@10: {t2i['R@10']:.2f}%")
    print(f"MRR:  {t2i['MRR']:.4f}")

    print("\n--- Image to Text (I2T) ---")
    print(f"R@1:  {i2t['R@1']:.2f}% | R@5: {i2t['R@5']:.2f}% | R@10: {i2t['R@10']:.2f}%")
    print(f"MRR:  {i2t['MRR']:.4f}")
    print("="*40)

    return {"t2i": t2i, "i2t": i2t}

# --- RUN EVALUATION ---
# Evaluate on the device that holds clip_model's weights, so the subset always
# matches the model regardless of which device selection rule built it.
DEVICE = next(clip_model.parameters()).device
evaluate_clip_on_fixed_set(clip_model, test_subset, DEVICE)

In [None]:
### Testing: re-inspect a training batch

In [None]:
# --- INVESTIGATION CELL ---
# NOTE(review): repeats the earlier inspection; train_loader shuffles, so this
# draws a different batch than the first run.
batch = next(iter(train_loader))
img_feat = batch["img_feat"]
input_ids = batch["input_ids"]

print("--- DATA INSPECTION ---")
print("Image Feature Stats: " + ", ".join(
    f"{stat}={reduce_fn().item():.4f}"
    for stat, reduce_fn in (("Min", img_feat.min), ("Max", img_feat.max), ("Mean", img_feat.mean))
))
print(f"Are Image Features all Zero? {bool(img_feat.eq(0).all())}")

print(f"Input IDs Stats: Min={input_ids.min().item()}, Max={input_ids.max().item()}")
print(f"Are Input IDs all Zero? {bool(input_ids.eq(0).all())}")