In [None]:
# Sanity check: list any .pt feature files sitting in the Colab VM's local /content/ directory.
# NOTE(review): the feature-loading cell reads the .pt files from PROJECT_ROOT on Google Drive,
# not from /content/, so this listing does not confirm the files that are actually loaded.
print("\nVerifying uploaded files in /content/:")
!ls -lh /content/*.pt

In [None]:
import os

from google.colab import drive
drive.mount("/content/drive")

# Drive folder that holds the precomputed ViT-B/16 feature files (.pt) loaded later;
# created on first run so the path always exists.
PROJECT_ROOT = os.path.join(
    "/content/drive/MyDrive", "co_attention_flickr30k_new", "features_vit_b16"
)
os.makedirs(PROJECT_ROOT, exist_ok=True)


In [None]:
import torch
import os
from datasets import load_dataset
from transformers import AutoTokenizer

# Loading Captions using Direct Parquet Links
print("Downloading Flickr30k data directly from Parquet files...")

# URLs point to the preprocessed Parquet shards directly, bypassing the broken loading script
PARQUET_BASE_URL = (
    "https://huggingface.co/datasets/nlphuji/flickr30k/resolve/"
    "refs%2Fconvert%2Fparquet/TEST/test"
)
NUM_PARQUET_SHARDS = 9
DATA_FILES = [f"{PARQUET_BASE_URL}/{shard:04d}.parquet" for shard in range(NUM_PARQUET_SHARDS)]

# Loading all data as one big chunk first (the parquet loader puts every row in "train")
raw_dataset = load_dataset(
    "parquet",
    data_files=DATA_FILES,
    cache_dir="./hf_cache"
)["train"]

# Filtering into splits using the 'split' column provided in the data.
# input_columns="split" passes only that column to the predicate, so the image
# column does not have to be decoded for every row just to read its split label.
flickr = {
    "train": raw_dataset.filter(lambda split: split == "train", input_columns="split"),
    "validation": raw_dataset.filter(lambda split: split == "val", input_columns="split"),
    "test": raw_dataset.filter(lambda split: split == "test", input_columns="split"),
}

print(f"Data Loaded! Train: {len(flickr['train'])}, Val: {len(flickr['validation'])}, Test: {len(flickr['test'])}")

# Loading Image Features from PROJECT_ROOT on Google Drive (mounted in the previous cell)
print("\nLoading Image Features into RAM...")

try:
    print("Loading Global Features...")
    # map_location="cpu": load onto CPU even if the tensors were saved from a GPU session
    img_feats_train = torch.load(os.path.join(PROJECT_ROOT, "flickr30k_train_global.pt"), map_location="cpu")
    img_feats_val   = torch.load(os.path.join(PROJECT_ROOT, "flickr30k_val_global.pt"), map_location="cpu")
    img_feats_test  = torch.load(os.path.join(PROJECT_ROOT, "flickr30k_test_global.pt"), map_location="cpu")

    print(f"Train: {img_feats_train.shape}") # expected [29000, 768]
    print(f"Val:   {img_feats_val.shape}")   # expected [1014, 768]
    print(f"Test:  {img_feats_test.shape}")  # expected [1000, 768]

    print(f"Features Loaded. Train: {img_feats_train.shape}, Val: {img_feats_val.shape}")

except FileNotFoundError as err:
    # Stop here: continuing would only surface later as a NameError on img_feats_*
    print(f"ERROR: Could not find the .pt feature files in {PROJECT_ROOT}: {err}")
    raise

# Row i of each feature tensor is paired with row i of the matching split
# (by index, in Flickr30kRetrievalDataset), so the counts must agree exactly.
for split_name, split_feats in (
    ("train", img_feats_train),
    ("validation", img_feats_val),
    ("test", img_feats_test),
):
    if len(split_feats) != len(flickr[split_name]):
        raise ValueError(
            f"{split_name}: {len(split_feats)} image feature rows but "
            f"{len(flickr[split_name])} dataset rows"
        )

# Preparing tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
MAX_LEN = 32

In [None]:
import random
from torch.utils.data import Dataset, DataLoader

class Flickr30kRetrievalDataset(Dataset):
    """Pairs each precomputed image feature vector with one tokenized caption.

    Row ``idx`` of ``img_feats`` is paired with row ``idx`` of ``hf_dataset``.

    Args:
        hf_dataset: indexable dataset whose rows have a ``"caption"`` list.
        img_feats: indexable of per-image feature vectors, one per dataset row.
        tokenizer: HF-style tokenizer callable (called with ``return_tensors="pt"``).
        max_length: pad/truncate length for the tokenized caption.
        random_caption: if True, pick a random caption on every access;
            otherwise always use the first caption.

    Raises:
        ValueError: if ``img_feats`` and ``hf_dataset`` have different lengths.
    """

    def __init__(self, hf_dataset, img_feats, tokenizer, max_length=32, random_caption=True):
        # Features are looked up by position, so a length mismatch would silently
        # pair images with other images' captions.
        if len(img_feats) != len(hf_dataset):
            raise ValueError(
                f"img_feats has {len(img_feats)} rows but hf_dataset has {len(hf_dataset)}"
            )

        self.ds = hf_dataset
        self.img_feats = img_feats
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.random_caption = random_caption

        # For a Hugging Face Dataset, read the caption column once: indexing a full
        # row (self.ds[idx]) also returns the decoded "image" column, which
        # __getitem__ never needs. Plain sequences of dicts fall back to row access.
        self._captions = hf_dataset["caption"] if hasattr(hf_dataset, "column_names") else None

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # Image feature for this row (presumably [768] for the ViT-B/16 global features)
        img_feat = self.img_feats[idx]

        if self._captions is not None:
            captions = self._captions[idx]
        else:
            captions = self.ds[idx]["caption"]

        caption = random.choice(captions) if self.random_caption else captions[0]

        # Tokenizing to a fixed length; the tokenizer returns a leading batch dim of 1
        tok = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            "img_feat": img_feat,
            "input_ids": tok["input_ids"].squeeze(0),
            "attention_mask": tok["attention_mask"].squeeze(0),
        }

# Creating Loaders
BATCH_SIZE = 128

# Training samples a random caption per access as a light augmentation
train_ret = Flickr30kRetrievalDataset(
    hf_dataset=flickr["train"],
    img_feats=img_feats_train,
    tokenizer=tokenizer,
    max_length=MAX_LEN,
    random_caption=True,
)

val_ret = Flickr30kRetrievalDataset(
    hf_dataset=flickr["validation"],  # Passing object directly
    img_feats=img_feats_val,          # Must match size of flickr["validation"]
    tokenizer=tokenizer,
    max_length=MAX_LEN,
    random_caption=False,
)

# Evaluation must be repeatable: always use caption[0] for each test image.
# With random_caption=True the scored caption (and therefore Recall@K) changed
# on every run and no longer matched the caption[0] shown by visualize_retrieval.
test_ret = Flickr30kRetrievalDataset(
    hf_dataset=flickr["test"],
    img_feats=img_feats_test,
    tokenizer=tokenizer,
    max_length=MAX_LEN,
    random_caption=False,
)

# Expected split sizes: train=29000, val=1014, test=1000
print(f"Dataset Size: {len(train_ret)}")

train_loader = DataLoader(train_ret, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ret, batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_ret, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
import torch.nn as nn
import torch.optim as optim
from transformers import BertModel
from tqdm.auto import tqdm

# defining model
class Baseline1_Concat(nn.Module):
    """Concatenation baseline: [image vector ; frozen-BERT text vector] -> MLP -> score.

    Args:
        hidden_dim: width of the fusion MLP's hidden layer.
        img_dim: size of each precomputed image feature vector (768 for the
            ViT-B/16 global features loaded in this notebook).
        dropout: dropout probability inside the fusion MLP.
    """

    def __init__(self, hidden_dim=256, img_dim=768, dropout=0.2):
        super().__init__()

        # Frozen BERT, used purely as a fixed text encoder
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        for p in self.bert.parameters():
            p.requires_grad = False
        text_dim = self.bert.config.hidden_size  # 768 for bert-base-uncased

        # Fusion MLP: (img_dim + text_dim) --> 1 score
        self.fusion_mlp = nn.Sequential(
            nn.Linear(img_dim + text_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def train(self, mode=True):
        """Set train/eval mode for the trainable parts while keeping BERT in eval mode.

        nn.Module.train() recurses into every submodule, which would switch the
        frozen BERT's dropout back on during training. Text features would then be
        noisy while training but deterministic at evaluation. eval() also routes
        through here via train(False).
        """
        super().train(mode)
        self.bert.eval()
        return self

    def forward(self, img_feats, input_ids, attention_mask):
        """Score each (image, caption) pair in the batch; returns shape [B]."""
        # Text Embeddings [B, text_dim]; no graph is needed through frozen BERT
        with torch.no_grad():
            text_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            text_vec = text_out.pooler_output

        # Concat [B, img_dim + text_dim]
        fused = torch.cat((img_feats, text_vec), dim=1)

        # Score [B]
        return self.fusion_mlp(fused).squeeze(1)

import random

# Setting up training
# Seed every RNG the run touches: random.choice (caption sampling in the dataset),
# torch (MLP init, DataLoader shuffling, dropout masks).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on: {DEVICE}")

model = Baseline1_Concat().to(DEVICE)
# Only the fusion MLP has requires_grad=True; keep the frozen BERT weights out of the optimizer
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)
criterion = nn.MarginRankingLoss(margin=0.2)

EPOCHS = 5

# Running Loop
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        # Moving to GPU
        img = batch["img_feat"].to(DEVICE)
        txt = batch["input_ids"].to(DEVICE)
        msk = batch["attention_mask"].to(DEVICE)

        # Positive Scores: each image with its own caption
        pos_scores = model(img, txt, msk)

        # Negative Scores: rolling captions by 1 pairs image i with the caption of
        # sample i-1 (wrapping around), giving an in-batch mismatch
        txt_neg = torch.roll(txt, shifts=1, dims=0)
        msk_neg = torch.roll(msk, shifts=1, dims=0)
        neg_scores = model(img, txt_neg, msk_neg)

        # target = 1: pos_scores should exceed neg_scores by at least the margin
        target = torch.ones_like(pos_scores)
        loss = criterion(pos_scores, neg_scores, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {total_loss/len(train_loader):.4f}")

print("Training Complete!")

In [None]:
# Device for the evaluation cells below (same rule as DEVICE in the training cell).
# Print the variable defined here so this cell also works on its own.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on: {device}")

In [None]:
def get_fixed_test_subset(loader, device, num_samples=200):
    """Collect the first ``num_samples`` examples of ``loader`` into one fixed evaluation set.

    Args:
        loader: iterable of batches, each either a dict with ``img_feat``,
            ``input_ids`` and ``attention_mask`` or an ``(img, ids, mask)`` tuple.
        device: device the stacked tensors are moved to.
        num_samples: maximum number of examples to keep.

    Returns:
        dict with ``img``, ``ids`` and ``mask`` tensors (first dim = N) and ``N``,
        the number of examples actually collected. This can be smaller than
        ``num_samples`` when the loader runs out first.

    Raises:
        ValueError: if the loader yields no batches.
    """
    all_imgs, all_ids, all_masks = [], [], []

    collected = 0
    print(f" Extracting fixed subset of {num_samples} samples...")

    with torch.no_grad():
        for batch in loader:
            # Handling Dictionary vs Tuple batches
            if isinstance(batch, dict):
                img, ids, mask = batch["img_feat"], batch["input_ids"], batch["attention_mask"]
            else:
                img, ids, mask = batch

            all_imgs.append(img)
            all_ids.append(ids)
            all_masks.append(mask)

            collected += img.size(0)
            if collected >= num_samples:
                break

    if not all_imgs:
        raise ValueError("loader yielded no batches; cannot build a test subset")

    # Concatenating, then trimming to at most num_samples
    img_all = torch.cat(all_imgs)[:num_samples].to(device)
    ids_all = torch.cat(all_ids)[:num_samples].to(device)
    mask_all = torch.cat(all_masks)[:num_samples].to(device)
    n_kept = img_all.size(0)

    subset = {"img": img_all, "ids": ids_all, "mask": mask_all, "N": n_kept}

    print(f"Fixed Test Subset Ready. (N={n_kept})")
    return subset

test_subset = get_fixed_test_subset(loader=test_loader, device=device, num_samples=200)  # first 200 test examples (loader is unshuffled)

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm.auto import tqdm

def _iter_eval_batches(test_subset, chunk_size=128):
    """Yield ``(img, input_ids, attention_mask)`` triples from a loader or a fixed-subset dict."""
    if isinstance(test_subset, dict) and all(k in test_subset for k in ("img", "ids", "mask")):
        # Output of get_fixed_test_subset: already-stacked tensors, sliced into chunks
        total = test_subset["img"].size(0)
        for start in range(0, total, chunk_size):
            stop = start + chunk_size
            yield (
                test_subset["img"][start:stop],
                test_subset["ids"][start:stop],
                test_subset["mask"][start:stop],
            )
        return

    for batch in test_subset:
        if isinstance(batch, dict):
            yield batch["img_feat"], batch["input_ids"], batch["attention_mask"]
        else:
            img, ids, mask = batch
            yield img, ids, mask


def _score_all_pairs(model, img_tensor, txt_tensor):
    """Return a CPU [N_images, N_texts] matrix of fusion-MLP scores for every image/text pair."""
    n_img, n_txt = img_tensor.size(0), txt_tensor.size(0)
    sim_matrix = torch.zeros((n_img, n_txt))

    with torch.no_grad():
        # One image (query) at a time: fusing all pairs at once would need an
        # [N_images * N_texts, D_img + D_txt] input.
        for i in tqdm(range(n_img), desc="Scoring All Pairs"):
            # [D_img] -> [N_texts, D_img] broadcast view; torch.cat below copies it
            img_repeated = img_tensor[i].unsqueeze(0).expand(n_txt, -1)

            # Pair this image with every caption: [N_texts, D_img + D_txt]
            fused_input = torch.cat((img_repeated, txt_tensor), dim=1)

            # Calling fusion_mlp directly: text features are already BERT outputs
            sim_matrix[i] = model.fusion_mlp(fused_input).squeeze(1).cpu()

    return sim_matrix


def evaluate_baseline_concat(model, test_subset, device):
    """Score every image against every caption with the concat baseline and print Recall@K.

    Args:
        model: a Baseline1_Concat-style model exposing ``bert`` and ``fusion_mlp``.
        test_subset: either a DataLoader-style iterable of batches (dicts with
            ``img_feat``, ``input_ids``, ``attention_mask``, or ``(img, ids, mask)``
            tuples) or the dict returned by ``get_fixed_test_subset``.
        device: device BERT and the fusion MLP run on.

    Returns:
        CPU tensor [N_images, N_texts] of scores; entry (i, i) is the matching pair.

    Raises:
        ValueError: if ``test_subset`` yields no examples.
    """
    model.eval()

    print("--- Step 1: Extracting all features ---")
    all_img_feats = []
    all_text_feats = []

    with torch.no_grad():
        for img, ids, mask in tqdm(_iter_eval_batches(test_subset), desc="Extracting Embeddings"):
            # Image features are precomputed (ViT), so they are only collected
            all_img_feats.append(img.cpu())

            # Text features come from the model's own frozen BERT
            text_out = model.bert(input_ids=ids.to(device), attention_mask=mask.to(device))
            all_text_feats.append(text_out.pooler_output.cpu())

    if not all_img_feats:
        raise ValueError("test_subset yielded no examples to evaluate")

    img_tensor = torch.cat(all_img_feats, dim=0).to(device)
    txt_tensor = torch.cat(all_text_feats, dim=0).to(device)
    print(f"\nEvaluated sizes: Images {img_tensor.shape}, Texts {txt_tensor.shape}")

    n_pairs = img_tensor.size(0) * txt_tensor.size(0)
    print(f"\n--- Step 2: Computing {n_pairs:,} Scores (The Slow Part) ---")
    sim_matrix = _score_all_pairs(model, img_tensor, txt_tensor)

    print("\n--- Step 3: Calculating Metrics ---")
    # Recall for image to text
    print(">> Image-to-Text (Given Image, find Caption)")
    i2t_r1, i2t_r5, i2t_r10 = calculate_recall(sim_matrix, direction="i2t")
    print(f"R@1: {i2t_r1:.2f}% | R@5: {i2t_r5:.2f}% | R@10: {i2t_r10:.2f}%")

    # Recall for text to image
    print("\n>> Text-to-Image (Given Caption, find Image)")
    t2i_r1, t2i_r5, t2i_r10 = calculate_recall(sim_matrix, direction="t2i")
    print(f"R@1: {t2i_r1:.2f}% | R@5: {t2i_r5:.2f}% | R@10: {t2i_r10:.2f}%")

    return sim_matrix

def calculate_recall(sim_matrix, direction="i2t"):
    """Recall@1/5/10 (percent) for a score matrix whose diagonal holds the true pairs.

    Rows are queries and columns are candidates. ``direction="t2i"`` transposes the
    matrix so captions become the queries. A query's rank is 1 plus the number of
    candidates that score strictly higher than its true match.
    """
    scores = sim_matrix.t() if direction == "t2i" else sim_matrix
    num_queries = scores.size(0)

    # True-pair score per query, shaped [n, 1] so it broadcasts across that query's row
    true_scores = scores.diagonal().unsqueeze(1)
    ranks = ((scores > true_scores).sum(dim=1) + 1).cpu().numpy()

    r1, r5, r10 = (100.0 * np.sum(ranks <= k) / num_queries for k in (1, 5, 10))
    return r1, r5, r10


sim_matrix_test = evaluate_baseline_concat(model=model, test_subset=test_loader, device=DEVICE)  # full test split

In [None]:
import matplotlib.pyplot as plt
import textwrap

def visualize_retrieval(dataset, sim_matrix_test, idx=0, top_k=3):
    """Show a query image next to its ground-truth caption and its top-k retrieved captions.

    Args:
        dataset: the Flickr30kRetrievalDataset (NOT the DataLoader) that
            ``sim_matrix_test`` was computed from. Row/column i of the matrix must
            refer to ``dataset.ds[i]``, e.g. ``test_ret`` for a matrix built from
            ``test_loader``.
        sim_matrix_test: [N_images, N_texts] score matrix from evaluate_baseline_concat.
        idx: non-negative row of the matrix, i.e. the query image's index in ``dataset``.
        top_k: number of highest-scoring captions to list.

    Raises:
        IndexError: if ``idx`` is outside ``[0, N_images)``.
    """
    n_rows = sim_matrix_test.size(0)
    if not 0 <= idx < n_rows:
        raise IndexError(f"idx={idx} is out of range for a similarity matrix with {n_rows} rows")

    # Raw PIL image and captions come from the underlying HF dataset, not the tensors
    example = dataset.ds[idx]
    pil_img = example["image"]
    # NOTE: caption[0] is the caption that was scored only if the dataset was
    # built with random_caption=False; otherwise a random caption was scored.
    gt_caption = example["caption"][0]

    # Scores of this image against every caption, best first
    scores = sim_matrix_test[idx]
    topk_indices = torch.argsort(scores, descending=True)[:top_k]

    fig, (ax_img, ax_txt) = plt.subplots(1, 2, figsize=(12, 6))

    ax_img.imshow(pil_img)
    ax_img.axis("off")
    ax_img.set_title(f"Query Image #{idx}")

    ax_txt.axis("off")

    text_parts = [
        f"GROUND TRUTH:\n{textwrap.fill(gt_caption, width=50)}\n\n",
        "PREDICTIONS:\n",
    ]
    for rank, pred_idx in enumerate(topk_indices.tolist(), start=1):
        pred_caption = dataset.ds[pred_idx]["caption"][0]
        # Caption j belongs to image j, so a hit is a retrieved index equal to the query's
        if pred_idx == idx:
            marker = "[ CORRECT ]"
        else:
            marker = f"[ WRONG - Img #{pred_idx} ]"
        text_parts.append(f"{rank}. {marker} {textwrap.fill(pred_caption, width=50)}\n")
        text_parts.append(f"   (Score: {scores[pred_idx].item():.4f})\n\n")

    ax_txt.text(0, 0.5, "".join(text_parts), fontsize=12, va="center", transform=ax_txt.transAxes)
    plt.show()

# Executing visualization.
# sim_matrix_test was computed from test_loader, so the matching Dataset is
# test_ret: its rows/columns index test_ret.ds. Passing val_ret would pair test
# scores with unrelated validation images and captions.
# Edit the tuple to see different examples.
for query_idx in (5, 7, 63):
    visualize_retrieval(test_ret, sim_matrix_test, idx=query_idx)