In [None]:
import os

from google.colab import drive
# Mount Google Drive at /content/drive so the paths below resolve.
drive.mount("/content/drive")

# Feature files go under PROJECT_ROOT; the Hugging Face download cache sits inside it.
PROJECT_ROOT = "/content/drive/MyDrive/co_attention_flickr30k_new/features_vit_b16"
HF_CACHE_DIR = os.path.join(PROJECT_ROOT, "hf_cache")

for _directory in (PROJECT_ROOT, HF_CACHE_DIR):
    os.makedirs(_directory, exist_ok=True)


In [None]:
# Install the packages used by the feature-extraction cells below (quiet mode).
# NOTE(review): versions are unpinned, so a re-run may pick up different releases.
!pip install -q datasets transformers torchvision tqdm


In [None]:
from datasets import load_dataset

# List of all Parquet shards for the TEST subset (shards 0000 through 0008).
_PARQUET_SHARD_BASE = (
    "https://huggingface.co/datasets/nlphuji/flickr30k/resolve/"
    "refs%2Fconvert%2Fparquet/TEST/test"
)
DATA_FILES = [f"{_PARQUET_SHARD_BASE}/{shard:04d}.parquet" for shard in range(9)]


# Read every shard through the generic "parquet" builder; the combined table
# is taken from the "train" key of the returned mapping.
_loaded_splits = load_dataset(
    "parquet",
    data_files=DATA_FILES,
    cache_dir=HF_CACHE_DIR,
)
flickr_all = _loaded_splits["train"]

# Quick look at the schema and the first record.
print(flickr_all)
print(flickr_all[0])


In [None]:
from datasets import DatasetDict

def is_split(example, name):
    """Return True when the example's "split" field equals ``name``.

    Args:
        example: Mapping for one dataset row; must contain a "split" key.
        name: Split label to compare against (e.g. "train", "val", "test").
    """
    split_label = example["split"]
    return split_label == name

# Partition the combined table using each row's "split" label.
flickr_train, flickr_val, flickr_test = (
    flickr_all.filter(lambda ex, wanted=label: is_split(ex, wanted))
    for label in ("train", "val", "test")
)

# The "val" rows are exposed under the "validation" key.
flickr = DatasetDict({
    "train": flickr_train,
    "validation": flickr_val,
    "test": flickr_test,
})

print(flickr)


In [None]:
# Sanity check: split sizes, then the fields and captions of the first training row.
print(*(len(flickr[split_key]) for split_key in ("train", "validation", "test")))
first_train_row = flickr["train"][0]
print(first_train_row.keys())
print(first_train_row["caption"])  # list of 5 captions


In [None]:
import matplotlib.pyplot as plt

# Show the first training image with its first caption as the title.
example = flickr["train"][0]
img, captions = example["image"], example["caption"]  # PIL Image, list of 5 strings

fig, ax = plt.subplots()
ax.imshow(img)
ax.axis("off")
ax.set_title(captions[0])
plt.show()


In [None]:
from transformers import AutoTokenizer

# Captions are padded/truncated to this many token ids by Flickr30kDataset.
MAX_LEN = 32

# Tokenizer matching the "bert-base-uncased" text encoder used later for training.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


In [None]:
import torch
import torch.nn as nn
from torchvision import models

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

vit_weights = models.ViT_B_16_Weights.IMAGENET1K_V1

# Preprocessing pipeline that matches the ViT weights
vit_preprocess = vit_weights.transforms()

# Frozen ViT-B/16 backbone: the classification head becomes an identity so
# vit(x) returns features, and no parameter receives gradients.
vit = models.vit_b_16(weights=vit_weights)
vit.heads = nn.Identity()
for param in vit.parameters():
    param.requires_grad = False
vit = vit.to(device)
vit = vit.eval()


In [None]:
import random
from torch.utils.data import Dataset, DataLoader

class Flickr30kDataset(Dataset):
    """Torch dataset that yields one preprocessed image and one tokenized caption per row.

    Args:
        hf_dataset: Indexable collection of rows; each row provides an "image"
            object with a ``convert`` method and a "caption" sequence of strings.
        image_transform: Callable applied to the RGB image.
        tokenizer: Callable tokenizer invoked with ``padding="max_length"``,
            ``truncation=True`` and ``return_tensors="pt"``.
        max_length: Token length every caption is padded/truncated to.
        random_caption: If True a caption is drawn with ``random.choice`` on
            every access; otherwise the first caption is always used.
    """

    def __init__(self, hf_dataset, image_transform, tokenizer, max_length=32,
                 random_caption=False):
        self.ds = hf_dataset
        self.image_transform = image_transform
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.random_caption = random_caption

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return a dict with "pixel_values", "input_ids" and "attention_mask"."""
        record = self.ds[idx]

        # Image branch: force RGB, then apply the supplied transform.
        rgb_image = record["image"].convert("RGB")
        pixel_values = self.image_transform(rgb_image)

        # Caption branch: random pick (when enabled) or the first entry.
        candidates = record["caption"]
        chosen = random.choice(candidates) if self.random_caption else candidates[0]

        encoded = self.tokenizer(
            chosen,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer output carries a leading batch axis of size 1; drop it.
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }


In [None]:
BATCH_SIZE = 32

# Arguments shared by all three split datasets.
_dataset_kwargs = dict(
    image_transform=vit_preprocess, tokenizer=tokenizer, max_length=MAX_LEN
)

# Only the training split samples a random caption; val/test use the first one.
train_pt = Flickr30kDataset(flickr["train"], random_caption=True, **_dataset_kwargs)
val_pt = Flickr30kDataset(flickr["validation"], random_caption=False, **_dataset_kwargs)
test_pt = Flickr30kDataset(flickr["test"], random_caption=False, **_dataset_kwargs)

_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_pt, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_pt, shuffle=False, **_loader_kwargs)


In [None]:
# Pull one batch and print tensor shapes:
#   pixel_values   -> [B, 3, 224, 224] (or similar)
#   input_ids      -> [B, MAX_LEN]
#   attention_mask -> [B, MAX_LEN]
batch = next(iter(train_loader))
for field in ("pixel_values", "input_ids", "attention_mask"):
    print(batch[field].shape)



In [None]:
from torch.utils.data import DataLoader

BATCH_SIZE = 64

# Loaders for offline feature extraction: fixed order (no shuffling) for every split.
# NOTE(review): train_pt was built with random_caption=True, so the caption
# stored for each training image is a random draw made at extraction time —
# confirm this is intended.
_feature_loader_kwargs = dict(
    batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True
)
train_loader_feats = DataLoader(train_pt, **_feature_loader_kwargs)
val_loader_feats = DataLoader(val_pt, **_feature_loader_kwargs)
test_loader_feats = DataLoader(test_pt, **_feature_loader_kwargs)


In [None]:
from torchvision import models
import torch.nn as nn
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the frozen ViT-B/16 backbone (same configuration as the earlier setup cell).
vit_weights = models.ViT_B_16_Weights.IMAGENET1K_V1
vit = models.vit_b_16(weights=vit_weights)
vit.heads = nn.Identity()        # removing classification head

for param in vit.parameters():
    param.requires_grad = False

vit = vit.to(device)
vit.eval()


In [None]:


import torch
import os
from tqdm.auto import tqdm


# Destination folder on Drive for the per-split feature dictionaries.
OUTPUT_DIR = '/content/drive/MyDrive/co_attention_flickr30k_new/features_vit_b16'
os.makedirs(OUTPUT_DIR, exist_ok=True)

def extract_and_save_new_patch_features(dataloader, split_name):
    """Run the frozen ViT over a loader and save token features plus caption ids.

    Uses the module-level ``vit`` and ``device``. For every batch the images are
    patchified, the class token is prepended, and the sequence goes through
    ``vit.encoder``. Encoder outputs are stored as float16 on the CPU.

    Args:
        dataloader: Iterable of batches with "pixel_values", "input_ids" and
            "attention_mask" tensors.
        split_name: Label used in log messages and in the output file name.

    Returns:
        Path of the saved ``flickr30k_{split_name}_FINAL.pt`` file. The file
        holds a dict with keys "img", "ids", "mask" and "N".

    Raises:
        ValueError: If the loader yields no batches, if the number of feature
            rows differs from the number of caption rows, or if the token
            dimension is not 197.
    """
    print(f"\nStarting extraction for: {split_name} (Patch Features and Text IDs)")

    patch_feats_list = []
    input_ids_list = []
    attention_mask_list = []

    vit.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Extracting {split_name}"):
            imgs = batch["pixel_values"].to(device)

            # 1. Patchify (Conv Projection)
            x = vit.conv_proj(imgs)
            x = x.flatten(2).transpose(1, 2)  # [B, 196, D]

            # 2. Prepend CLS Token
            batch_class_token = vit.class_token.expand(x.shape[0], -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)  # [B, 197, D]

            # 3. Run through the Encoder (Transformer Layers).
            # Positional embeddings are NOT added here: in torchvision's
            # VisionTransformer the parameter lives in vit.encoder, and
            # Encoder.forward adds it itself. Adding it here too would apply
            # positions twice.
            patch_tokens = vit.encoder(x)

            patch_feats_list.append(patch_tokens.cpu().half())

            # 4. Collect Text Data
            input_ids_list.append(batch["input_ids"].cpu())
            attention_mask_list.append(batch["attention_mask"].cpu())

    if not patch_feats_list:
        # torch.cat on an empty list fails with an unhelpful message.
        raise ValueError(f"No batches were produced for split '{split_name}'; nothing to save.")

    # Concatenate Results
    all_patches = torch.cat(patch_feats_list, dim=0)
    all_ids = torch.cat(input_ids_list, dim=0)
    all_masks = torch.cat(attention_mask_list, dim=0)
    N = all_patches.shape[0]

    # Validate before anything is written to disk.
    if N != all_ids.shape[0] or all_patches.shape[1] != 197:
        raise ValueError(f"CRITICAL ERROR: Data mismatch. N={N}, Patches={all_patches.shape[1]}. Expected 197.")

    data_to_save = {
        "img": all_patches,
        "ids": all_ids,
        "mask": all_masks,
        "N": N
    }

    patch_path_new = os.path.join(OUTPUT_DIR, f"flickr30k_{split_name}_FINAL.pt")

    print(f"Saving FINAL Data Dictionary (N={N}) to {patch_path_new}")
    torch.save(data_to_save, patch_path_new)
    print("-" * 40)

    return patch_path_new

# Run for all 3 splits
new_train_path, new_val_path, new_test_path = (
    extract_and_save_new_patch_features(split_loader, split_label)
    for split_loader, split_label in (
        (train_loader_feats, "train"),
        (val_loader_feats, "val"),
        (test_loader_feats, "test"),
    )
)

print("\n✅ All FINAL data dictionaries extracted successfully and saved!")

### CO-ATTENTION TRAINING

In [None]:
# Capture the output lines of `nvidia-smi` (IPython shell capture) and join
# them into one string.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# The substring 'failed' in the output is treated as "no GPU attached".
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU instance.')
else:
  print('GPU is available:\n' + gpu_info)

# Install necessary libraries
# NOTE(review): versions are unpinned, so a re-run may pick up different releases.
!pip install -q transformers torch numpy tqdm

# Mount Google Drive so the saved feature files under /content/drive are reachable.
from google.colab import drive
drive.mount('/content/drive')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from tqdm.auto import tqdm
import numpy as np
import os

# Set the device: CUDA when torch reports it as available, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

In [None]:

# Folder containing the per-split feature dictionaries (trailing slash is
# required because the file names are appended directly).
BASE_DATA_PATH = '/content/drive/MyDrive/co_attention_flickr30k_new/features_vit_b16/'

TEST_FILE = f"{BASE_DATA_PATH}flickr30k_test_FINAL.pt"
TRAIN_FILE = f"{BASE_DATA_PATH}flickr30k_train_FINAL.pt"
VAL_FILE = f"{BASE_DATA_PATH}flickr30k_val_FINAL.pt"

def load_and_verify_data(file_path, split_name):
    """Load a saved feature dictionary onto the CPU and check the image tensor rank.

    Args:
        file_path: Path of the ``.pt`` file to read with ``torch.load``.
        split_name: Human-readable split label used in log and error messages.

    Returns:
        The loaded dictionary (expected to contain at least "img" and "N").

    Raises:
        FileNotFoundError: If ``file_path`` does not exist (re-raised after logging).
        ValueError: If ``data_dict["img"]`` is not 3-dimensional.
        Exception: Any other loading error is logged and re-raised unchanged.
    """
    try:
        payload = torch.load(file_path, map_location='cpu')
        feature_shape = payload["img"].shape

        if len(feature_shape) == 3:
            print(f"{split_name} data loaded. N={payload['N']}, Image Patches Shape: {feature_shape}")
            return payload

        raise ValueError(f"Image shape for {split_name} is {feature_shape}. Expected 3 dimensions [N, Patches, Dim] for Co-Attention.")
    except FileNotFoundError:
        print(f"ERROR: File not found at {file_path}. Please check the path and file names.")
        raise
    except Exception as e:
        print(f"An error occurred loading {split_name} data: {e}")
        raise

# Load the 3 Splits
# Failures are deliberately NOT swallowed here. load_and_verify_data already
# logs and re-raises, and letting the exception propagate stops the notebook
# at the real cause instead of a confusing NameError in a later cell.
train_data = load_and_verify_data(TRAIN_FILE, "Train")
val_data = load_and_verify_data(VAL_FILE, "Validation")
test_data = load_and_verify_data(TEST_FILE, "Test")


class FixedSubsetDataset(Dataset):
    """Dataset view over a pre-extracted feature dictionary.

    Args:
        data_dict: Dict with "img", "ids" and "mask" indexable along their
            first axis, plus "N", the number of samples reported by __len__.
    """

    def __init__(self, data_dict):
        self.imgs = data_dict["img"]
        self.ids = data_dict["ids"]
        self.masks = data_dict["mask"]
        self.N = data_dict["N"]

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        """Return the idx-th sample as a dict of "img_feat", "input_ids", "attention_mask"."""
        return {
            "img_feat": self.imgs[idx],
            "input_ids": self.ids[idx],
            "attention_mask": self.masks[idx],
        }


# Create datasets
train_dataset = FixedSubsetDataset(train_data)
val_dataset = FixedSubsetDataset(val_data)

# batch size for training
BATCH_SIZE = 32

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)


# Creating the 'test_subset' for N x N Evaluation
TEST_SUBSET_SIZE = 200

test_subset = {
    "img": test_data["img"][:TEST_SUBSET_SIZE].to(device),
    "ids": test_data["ids"][:TEST_SUBSET_SIZE].to(device),
    "mask": test_data["mask"][:TEST_SUBSET_SIZE].to(device),
}
# Record the number of rows actually present: slicing silently returns fewer
# than TEST_SUBSET_SIZE rows when the test file is smaller.
test_subset["N"] = test_subset["img"].shape[0]

print(f"\n{'='*60}")
print(f"Training Loader created: {len(train_loader)} batches")
print(f"Test Subset for Evaluation: N={test_subset['N']} (first {test_subset['N']} samples)")
print(f"This matches your other baseline models (Cross-Attention, CLIP)")
print(f"{'='*60}")

In [None]:
class BiDirectionalCoAttentionModel(nn.Module):
    """Scores an (image, caption) pair with bidirectional co-attention.

    Image token features are linearly projected, caption tokens are encoded by
    a frozen "bert-base-uncased" model, and each modality attends to the other
    through its own multi-head attention layer (residual + LayerNorm). Both
    attended sequences are pooled, concatenated and mapped to one scalar score.

    Args:
        patch_dim: Feature size of the incoming image tokens.
        hidden_dim: Working width of the attention layers. It must equal the
            BERT hidden size, because text states are used without projection.
        num_heads: Number of heads in both attention layers.
        dropout: Dropout probability for the attention, pooling and scoring layers.
    """

    def __init__(self, patch_dim=768, hidden_dim=768, num_heads=8, dropout=0.1):
        super().__init__()

        # Text Encoder (Frozen BERT)
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        for p in self.bert.parameters():
            p.requires_grad = False

        # Image Projection
        self.patch_proj = nn.Linear(patch_dim, hidden_dim)

        # Co-Attention Layers
        # Image attends to Text (ALL patches attend to ALL text tokens)
        self.i2t_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Text attends to Image (ALL text tokens attend to ALL patches)
        self.t2i_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Layer Normalization (for residual connections)
        self.ln_img = nn.LayerNorm(hidden_dim)
        self.ln_txt = nn.LayerNorm(hidden_dim)

        # Pooling projections (AFTER co-attention)
        self.img_pool = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        self.txt_pool = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Final similarity head
        self.similarity_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )

    def train(self, mode=True):
        """Set train/eval mode for the trainable layers; frozen BERT always stays in eval.

        Without this override ``model.train()`` would switch BERT's dropout on,
        so the frozen text features would be stochastic during training and
        would not match what the model sees at evaluation time.
        """
        super().train(mode)
        self.bert.eval()
        return self

    def forward(self, patch_feat, input_ids, attention_mask):
        """Return one similarity score per (image, caption) pair.

        Args:
            patch_feat: Image token features ``[B, num_patches, patch_dim]``;
                float16 input is upcast to float32.
            input_ids: Caption token ids ``[B, seq_len]``.
            attention_mask: ``[B, seq_len]`` mask, non-zero for real tokens.

        Returns:
            Tensor of shape ``[B]`` with unnormalized scores.
        """
        if patch_feat.dtype == torch.float16:
            patch_feat = patch_feat.float()

        # Get text embeddings: [B, seq_len, hidden_dim]
        with torch.no_grad():
            txt_seq = self.bert(
                input_ids=input_ids, attention_mask=attention_mask
            ).last_hidden_state

        # Project image patches: [B, num_patches, hidden_dim]
        img_seq = self.patch_proj(patch_feat)

        # BIDIRECTIONAL CO-ATTENTION (Full Sequence)

        # Image-to-Text Attention (padding tokens are excluded as keys)
        key_padding_mask_txt = (attention_mask == 0)
        img_attended, _ = self.i2t_attn(
            query=img_seq,           # [B, num_patches, D]
            key=txt_seq,             # [B, seq_len, D]
            value=txt_seq,           # [B, seq_len, D]
            key_padding_mask=key_padding_mask_txt
        )
        # Residual connection + LayerNorm
        img_attended = self.ln_img(img_attended + img_seq)

        # Text-to-Image Attention
        txt_attended, _ = self.t2i_attn(
            query=txt_seq,           # [B, seq_len, D]
            key=img_seq,             # [B, num_patches, D]
            value=img_seq            # [B, num_patches, D]
        )
        # Residual connection + LayerNorm
        txt_attended = self.ln_txt(txt_attended + txt_seq)

        # POOLING (After Co-Attention to preserve information)

        # Pool image patches (mean pooling)
        img_pooled = img_attended.mean(dim=1)  # [B, D]
        img_pooled = self.img_pool(img_pooled)

        # Pool text tokens (masked mean pooling; padded positions contribute nothing)
        mask_expanded = attention_mask.unsqueeze(-1).float()  # [B, seq_len, 1]
        txt_sum = (txt_attended * mask_expanded).sum(dim=1)   # [B, D]
        txt_count = mask_expanded.sum(dim=1).clamp(min=1)     # [B, 1], avoids division by zero
        txt_pooled = txt_sum / txt_count
        txt_pooled = self.txt_pool(txt_pooled)

        # FUSION AND SIMILARITY

        # Concatenate modalities
        fused = torch.cat([img_pooled, txt_pooled], dim=-1)  # [B, 2D]

        # Compute similarity score
        score = self.similarity_head(fused).squeeze(-1)  # [B]

        return score

print("✓ CORRECTED Co-Attention Model defined (Full Sequence Attention)")

In [None]:


import torch.nn as nn
from tqdm.auto import tqdm

# InfoNCE Loss 
class InfoNCELoss(nn.Module):
    """Symmetric InfoNCE loss over a square score matrix.

    Entry ``[i, j]`` scores item ``i`` of one modality against item ``j`` of
    the other; matching pairs sit on the diagonal.

    Args:
        temperature: Divisor applied to the scores before the cross-entropy.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, scores_matrix):
        """Return the mean of the row-wise and column-wise cross-entropy losses."""
        scaled = scores_matrix / self.temperature

        # Row i (and column i) must select index i.
        targets = torch.arange(scores_matrix.size(0), device=scores_matrix.device)

        row_loss = self.cross_entropy(scaled, targets)
        col_loss = self.cross_entropy(scaled.t(), targets)
        return (row_loss + col_loss) / 2

# Initialize Model, Loss, and Optimizer
model = BiDirectionalCoAttentionModel(
    patch_dim=768, hidden_dim=768, num_heads=8, dropout=0.1
).to(device)

# Only parameters that still require gradients are optimized (BERT is frozen).
optimizer = torch.optim.AdamW(
    (param for param in model.parameters() if param.requires_grad),
    lr=5e-5,  # Conservative learning rate
    weight_decay=0.01,
    betas=(0.9, 0.98),
    eps=1e-6,
)

# Learning rate scheduler: cosine decay over 10 scheduler steps down to 1e-6.
from torch.optim.lr_scheduler import CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

criterion = InfoNCELoss(temperature=0.07)

# Maximum gradient norm used by train_loop.
GRAD_CLIP_VALUE = 1.0

def train_loop(dataloader, model, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader`` with in-batch negatives.

    For a batch of B pairs, every image is scored against every caption
    (B*B model evaluations) and ``criterion`` is applied to the resulting
    B x B matrix. Batches with a NaN/Inf loss or gradient norm are skipped.

    Args:
        dataloader: Yields dicts with "img_feat", "input_ids", "attention_mask".
        model: Callable ``model(img, ids, mask)`` returning one score per row.
        criterion: Loss taking the B x B score matrix.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batch tensors are moved to.

    Returns:
        Mean loss over the batches that produced an optimizer step, or
        ``float('inf')`` when no batch did.
    """
    model.train()
    running_loss = 0
    steps_taken = 0

    for step, batch in enumerate(tqdm(dataloader, desc="Training")):
        patches = batch["img_feat"].to(device)
        token_ids = batch["input_ids"].to(device)
        token_mask = batch["attention_mask"].to(device)
        n = patches.size(0)

        # Build all n*n (image, caption) pairs in row-major order:
        # images repeat consecutively, captions cycle.
        paired_patches = patches.repeat_interleave(n, dim=0)   # [n*n, P, D]
        paired_ids = token_ids.repeat(n, 1)                    # [n*n, L]
        paired_mask = token_mask.repeat(n, 1)                  # [n*n, L]

        # Score every pair and fold back into an [n, n] matrix.
        pair_scores = model(paired_patches, paired_ids, paired_mask)
        loss = criterion(pair_scores.reshape(n, n))

        # Check for NaN/Inf
        loss_is_bad = torch.isnan(loss) or torch.isinf(loss)
        if loss_is_bad:
            print(f"\nWARNING: NaN/Inf loss at batch {step}, skipping...")
            continue

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRAD_CLIP_VALUE)

        grad_is_bad = torch.isnan(grad_norm) or torch.isinf(grad_norm)
        if grad_is_bad:
            print(f"\nWARNING: Invalid gradients at batch {step}, skipping...")
            continue

        optimizer.step()

        running_loss += loss.item()
        steps_taken += 1

    if steps_taken > 0:
        avg_loss = running_loss / steps_taken
    else:
        avg_loss = float('inf')
    print(f"Epoch Loss: {avg_loss:.4f}")
    return avg_loss

print("✓ InfoNCE Loss and Training Loop defined")
print(f"✓ Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters")

In [None]:


import torch
import os

# Configuration
NUM_EPOCHS = 10  # Start with 10 epochs
SAVE_DIR = '/content/drive/MyDrive/co_attention_flickr30k_new/'
SAVE_PATH = os.path.join(SAVE_DIR, 'co_attention_model_corrected.pth')
os.makedirs(SAVE_DIR, exist_ok=True)

_RULE = '=' * 60

print(_RULE)
print("Starting Training: CORRECTED Co-Attention Model")
print(_RULE)
print("Architecture: Full Sequence Bidirectional Co-Attention")
print("Loss: InfoNCE (in-batch negatives, like CLIP)")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
print(_RULE + "\n")

# NOTE(review): the "best" checkpoint is selected on the training loss; the
# validation loader built earlier is not used here.
best_loss = float('inf')

for epoch in range(1, NUM_EPOCHS + 1):
    print("\n" + _RULE)
    print(f"Epoch {epoch}/{NUM_EPOCHS}")
    print(_RULE)

    # Train
    train_loss = train_loop(train_loader, model, criterion, optimizer, device)

    # Step scheduler (once per epoch)
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Learning Rate: {current_lr:.6f}")

    # Save best model
    improved = train_loss < best_loss
    if improved:
        best_loss = train_loss
        torch.save(model.state_dict(), SAVE_PATH)
        print(f"✓ Best model saved (loss: {best_loss:.4f})")

    # Save checkpoint every 3 epochs
    if epoch % 3 == 0:
        checkpoint_path = SAVE_PATH.replace('.pth', f'_epoch{epoch}.pth')
        torch.save(model.state_dict(), checkpoint_path)
        print(f"✓ Checkpoint saved: epoch {epoch}")

print("\n" + _RULE)
print("Training Complete!")
print(f"Best Loss: {best_loss:.4f}")
print(f"Model saved to: {SAVE_PATH}")
print(_RULE)

In [None]:

def get_fixed_test_subset(loader, device, num_samples=200):
    """Collect the first ``num_samples`` samples of ``loader`` as a fixed evaluation set.

    Args:
        loader: Iterable of batches. Each batch is either a dict with keys
            "img_feat", "input_ids" and "attention_mask", or an
            ``(img, ids, mask)`` tuple.
        device: Device the concatenated tensors are moved to.
        num_samples: Maximum number of samples to keep.

    Returns:
        Dict with "img", "ids" and "mask" tensors and "N", the number of
        samples actually collected. "N" is smaller than ``num_samples`` when
        the loader holds fewer samples.

    Raises:
        ValueError: If the loader yields no batches.
    """
    img_chunks = []
    id_chunks = []
    mask_chunks = []

    collected = 0
    print(f"Extracting fixed subset of {num_samples} samples...")

    for batch in loader:
        # Handle Dictionary vs Tuple
        if isinstance(batch, dict):
            img = batch['img_feat']
            ids = batch['input_ids']
            mask = batch['attention_mask']
        else:
            img, ids, mask = batch

        img_chunks.append(img)
        id_chunks.append(ids)
        mask_chunks.append(mask)

        collected += img.size(0)
        if collected >= num_samples:
            break

    if not img_chunks:
        raise ValueError("Cannot build a test subset: the loader produced no batches.")

    # Concatenate and trim to at most num_samples
    imgs = torch.cat(img_chunks)[:num_samples].to(device)
    actual_n = imgs.size(0)
    subset = {
        "img": imgs,
        "ids": torch.cat(id_chunks)[:num_samples].to(device),
        "mask": torch.cat(mask_chunks)[:num_samples].to(device),
        # Report what was really collected so N x N evaluation never indexes
        # past the end of the tensors.
        "N": actual_n,
    }

    if actual_n < num_samples:
        print(f"WARNING: only {actual_n} samples available (requested {num_samples}).")
    print(f"Fixed Test Subset Ready. (N={actual_n})")
    return subset

# Create test_loader if it doesn't exist (this cell runs at module scope, so
# the notebook's global namespace is the one being checked).
if 'test_loader' not in globals():
    print("Creating test_loader...")
    test_dataset = FixedSubsetDataset(test_data)
    test_loader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True
    )

# Create the N=200 test subset (same as other baseline models)
test_subset = get_fixed_test_subset(test_loader, device, num_samples=200)

_rule = '=' * 60
print("\n" + _rule)
print(f"Test subset created with N={test_subset['N']}")
print("Ready for evaluation - matches your other baseline models!")
print(_rule)

In [None]:
# Report the sample count stored under "N" in the evaluation subset.
print(f"Test subset size N = {test_subset['N']}")

In [None]:
import torch

# Load the saved model
# The training cell writes its best weights to co_attention_model_corrected.pth,
# so that is the checkpoint loaded here.
SAVE_PATH = '/content/drive/MyDrive/co_attention_flickr30k_new/co_attention_model_corrected.pth'

# Recreate the model architecture. The class defined in this notebook is
# BiDirectionalCoAttentionModel; the previous name
# BiDirectionalCrossAttentionModel is not defined and raised NameError.
model = BiDirectionalCoAttentionModel(patch_dim=768, hidden_dim=768).to(device)

# Load the saved weights
model.load_state_dict(torch.load(SAVE_PATH, map_location=device))

# Convert entire model to float32
model = model.float()

# Verify all parameters are float32
print("Checking model dtypes...")
for name, param in model.named_parameters():
    if param.dtype != torch.float32:
        print(f"WARNING: {name} is {param.dtype}")

print(f"Model loaded and converted to float32")
print(f"Model is on device: {next(model.parameters()).device}")

# Set to eval mode
model.eval()

In [None]:
# Clear CUDA memory
import gc
import torch

# Drop any leftover tiled tensors from the notebook namespace, if present.
for _stale_name in ('img_tiled', 'ids_tiled', 'mask_tiled'):
    globals().pop(_stale_name, None)

# Release cached CUDA blocks, then run the Python garbage collector.
torch.cuda.empty_cache()
gc.collect()

print("✓ Memory cleared")
print(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB allocated")

In [None]:
# Cell 9: Memory-Efficient Evaluation Script

import numpy as np
from tqdm.auto import tqdm

def calculate_recall_matrix(sim_matrix, direction="i2t"):
    """Compute Recall@1/5/10 from a similarity matrix with matches on the diagonal.

    Args:
        sim_matrix: Torch tensor whose entry ``[i, j]`` scores query ``i``
            against candidate ``j``.
        direction: "i2t" uses the matrix as given; "t2i" transposes it first.

    Returns:
        Tuple ``(r1, r5, r10)`` of percentages. A query's rank is one plus the
        number of candidates scoring strictly higher than its match, so ties
        count in the query's favour.
    """
    oriented = sim_matrix.t() if direction == "t2i" else sim_matrix

    n = oriented.size(0)
    scores = oriented.cpu().numpy()

    # Compare every row against its own diagonal entry in one shot.
    matched = scores.diagonal()[:, None]
    ranks = (scores > matched).sum(axis=1) + 1

    r1, r5, r10 = (100.0 * np.sum(ranks <= k) / n for k in (1, 5, 10))
    return r1, r5, r10


def evaluate_co_attention_model(model, subset, device):
    """Score every image against every caption in ``subset`` and report retrieval recall.

    The N x N similarity matrix is filled block by block (8 images x 20
    captions per model call) to bound GPU memory, then Recall@1/5/10 is
    computed in both directions.

    Args:
        model: Callable ``model(img, ids, mask)`` returning one score per row.
        subset: Dict with "img", "ids", "mask" tensors and the sample count "N".
        device: Device on which the model is evaluated.

    Returns:
        Dict with "i2t" and "t2i" tuples ``(R@1, R@5, R@10)`` and "avg_recall",
        the mean of the six values.
    """
    model.eval()

    # Keep the full tensors on the CPU; only the current block goes to `device`.
    patches_cpu = subset["img"].cpu()
    ids_cpu = subset["ids"].cpu()
    mask_cpu = subset["mask"].cpu()
    N = subset["N"]

    rule = '=' * 60
    print(f"\n{rule}")
    print(f"EVALUATING Co-Attention Model (N={N})")
    print(rule)

    SCORING_BATCH_SIZE = 8
    TEXT_CHUNK_SIZE = 20  # Process 20 texts at a time
    similarity_matrix = torch.zeros((N, N), device='cpu')

    print(f"\nComputing {N}x{N} Similarity Matrix...")

    with torch.no_grad():
        for row_start in tqdm(range(0, N, SCORING_BATCH_SIZE), desc="Scoring Matrix"):
            img_batch = patches_cpu[row_start:row_start + SCORING_BATCH_SIZE].to(device)
            B = img_batch.size(0)

            chunk_scores = []

            for col_start in range(0, N, TEXT_CHUNK_SIZE):
                col_end = min(col_start + TEXT_CHUNK_SIZE, N)
                num_texts = col_end - col_start

                # Pair each of the B images with each of the num_texts captions
                # (images repeat consecutively, captions cycle).
                img_tiled = img_batch.repeat_interleave(num_texts, dim=0)
                ids_tiled = ids_cpu[col_start:col_end].to(device).repeat(B, 1)
                mask_tiled = mask_cpu[col_start:col_end].to(device).repeat(B, 1)

                scores = model(img_tiled, ids_tiled, mask_tiled)
                chunk_scores.append(scores.reshape(B, num_texts).cpu())

                del img_tiled, ids_tiled, mask_tiled, scores
                torch.cuda.empty_cache()

            # Stitch the caption chunks together into complete rows: [B, N]
            similarity_matrix[row_start:row_start + B, :] = torch.cat(chunk_scores, dim=1)

            del img_batch, chunk_scores
            torch.cuda.empty_cache()

    # Calculate metrics
    print(f"\n{rule}")
    print(f"FINAL RETRIEVAL RESULTS (N={N})")
    print(rule)

    sim_matrix_gpu = similarity_matrix.to(device)

    report = {}
    headings = (
        ("i2t", "--- Image to Text (I2T) ---"),
        ("t2i", "--- Text to Image (T2I) ---"),
    )
    for direction, heading in headings:
        r1, r5, r10 = calculate_recall_matrix(sim_matrix_gpu, direction=direction)
        print(f"\n{heading}")
        print(f"R@1:  {r1:.2f}%")
        print(f"R@5:  {r5:.2f}%")
        print(f"R@10: {r10:.2f}%")
        report[direction] = (r1, r5, r10)

    # Average of the six recall values
    avg_recall = sum(report["i2t"] + report["t2i"]) / 6
    print(f"\nAverage Recall: {avg_recall:.2f}%")
    print(rule)

    report["avg_recall"] = avg_recall
    return report


# Clear memory first
import gc
torch.cuda.empty_cache()
gc.collect()

# Load model
SAVE_PATH = '/content/drive/MyDrive/co_attention_flickr30k_new/co_attention_model_corrected.pth'

_model_config = dict(patch_dim=768, hidden_dim=768, num_heads=8, dropout=0.1)
model = BiDirectionalCoAttentionModel(**_model_config).to(device)

# Restore the trained weights, force float32, and switch to inference mode.
_state_dict = torch.load(SAVE_PATH, map_location=device)
model.load_state_dict(_state_dict)
model = model.float()
model.eval()

print(f"✓ Model loaded from: {SAVE_PATH}\n")

results = evaluate_co_attention_model(model, test_subset, device)