In [None]:
import os

from google.colab import drive

# Make Google Drive visible to this runtime under /content/drive.
drive.mount("/content/drive")

# Everything this project writes lives under PROJECT_ROOT on Drive;
# the Hugging Face download cache gets its own sub-folder there.
PROJECT_ROOT = "/content/drive/MyDrive/co_attention_flickr30k_new"
HF_CACHE_DIR = os.path.join(PROJECT_ROOT, "hf_cache")

for _directory in (PROJECT_ROOT, HF_CACHE_DIR):
    os.makedirs(_directory, exist_ok=True)


In [None]:
!pip install -q datasets transformers torchvision tqdm


In [None]:
from datasets import load_dataset

# The Parquet export of nlphuji/flickr30k is sharded into numbered files
# (0000.parquet, 0001.parquet, ...) under TEST/test. Build the URLs from the
# shard count instead of spelling each one out, so there is a single place to
# edit if the number of shards changes.
PARQUET_BASE_URL = (
    "https://huggingface.co/datasets/nlphuji/flickr30k/resolve/"
    "refs%2Fconvert%2Fparquet/TEST/test"
)
NUM_SHARDS = 9  # shards 0000 .. 0008
DATA_FILES = [
    f"{PARQUET_BASE_URL}/{shard:04d}.parquet" for shard in range(NUM_SHARDS)
]

# When you pass a *list* of files to the "parquet" builder, it makes a single
# split called "train", so that is the key we index with below.
flickr_all = load_dataset(
    "parquet",
    data_files=DATA_FILES,
    cache_dir=HF_CACHE_DIR,
)["train"]

# Quick look at the schema and at one raw record.
print(flickr_all)
print(flickr_all[0])


In [None]:
from datasets import DatasetDict

def is_split(example, name):
    """Return True when `example["split"]` equals `name`.

    Kept for callers that filter on whole example dicts; the split selection
    below uses a column-only predicate instead.
    """
    return example["split"] == name


def _select_split(dataset, name):
    """Return the rows of `dataset` whose "split" column equals `name`.

    `input_columns=["split"]` makes `filter` pass only that column's value to
    the predicate. Without it the predicate receives a dict built from every
    column of the row, which means the image column is read for each example
    just to look at a short string.
    """
    return dataset.filter(
        lambda split_value: split_value == name,
        input_columns=["split"],
    )


flickr_train = _select_split(flickr_all, "train")
flickr_val   = _select_split(flickr_all, "val")
flickr_test  = _select_split(flickr_all, "test")

# The source column uses "val"; the DatasetDict key is spelled "validation".
flickr = DatasetDict({
    "train": flickr_train,
    "validation": flickr_val,
    "test": flickr_test,
})

print(flickr)


In [None]:
# Row counts for the three splits, printed on one line.
split_sizes = [len(flickr[split_key]) for split_key in ("train", "validation", "test")]
print(*split_sizes)

# Field names and the caption list of the first training record.
first_train_record = flickr["train"][0]
print(first_train_record.keys())
print(first_train_record["caption"])  # list of 5 captions


In [None]:
import matplotlib.pyplot as plt

# Show the first training image with its first caption as the title.
example = flickr["train"][0]
img, captions = example["image"], example["caption"]  # PIL image, list of 5 strings

fig, ax = plt.subplots()
ax.imshow(img)
ax.axis("off")
ax.set_title(captions[0])
plt.show()


In [None]:
from transformers import AutoTokenizer

# Tokenizer loaded from the "bert-base-uncased" checkpoint; it is handed to the
# dataset classes further down, which call it on one caption at a time.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Token length passed as `max_length` when the datasets are built; captions are
# tokenized there with padding="max_length" and truncation=True.
MAX_LEN = 32


In [None]:
import torch
import torch.nn as nn
from torchvision import models

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# ViT-B/16 with the IMAGENET1K_V1 weights, used purely as a frozen feature
# extractor: the classification head becomes an identity so vit(x) returns
# features, and no parameter tracks gradients.
vit_weights = models.ViT_B_16_Weights.IMAGENET1K_V1
vit = models.vit_b_16(weights=vit_weights)
vit.heads = nn.Identity()
vit.requires_grad_(False)
vit = vit.to(device).eval()

# Preprocessing pipeline that ships with (and therefore matches) these weights.
vit_preprocess = vit_weights.transforms()


In [None]:
import random
from torch.utils.data import Dataset, DataLoader

class Flickr30kDataset(Dataset):
    """Pairs each image of a Flickr30k split with one tokenized caption.

    Each record of `hf_dataset` is expected to provide an "image" entry with a
    `.convert("RGB")` method and a "caption" entry holding a list of strings.

    Args:
        hf_dataset: indexable collection of records (supports len() and [idx]).
        image_transform: callable applied to the RGB image.
        tokenizer: callable invoked as
            tokenizer(text, padding=..., truncation=..., max_length=..., return_tensors="pt").
        max_length: token length used for padding/truncation.
        random_caption: when True a caption is drawn with `random.choice`;
            otherwise the first caption of the record is used.
    """

    def __init__(self, hf_dataset, image_transform, tokenizer, max_length=32,
                 random_caption=False):
        self.ds = hf_dataset
        self.image_transform = image_transform
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.random_caption = random_caption

    def __len__(self):
        return len(self.ds)

    def _pick_caption(self, candidates):
        """Return a random caption or the first one, depending on the flag."""
        return random.choice(candidates) if self.random_caption else candidates[0]

    def __getitem__(self, idx):
        record = self.ds[idx]

        # Image first: force three channels, then apply the supplied transform.
        pixel_values = self.image_transform(record["image"].convert("RGB"))

        # Then the text: one caption, padded/truncated to max_length tokens.
        encoded = self.tokenizer(
            self._pick_caption(record["caption"]),
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer returns a leading batch axis of size 1; drop it so the
        # DataLoader's collation adds the real batch axis.
        return {
            "pixel_values": pixel_values,
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
        }


In [None]:
BATCH_SIZE = 32

# Settings shared by all three image/caption datasets.
_image_ds_kwargs = dict(
    image_transform=vit_preprocess,
    tokenizer=tokenizer,
    max_length=MAX_LEN,
)

# Training draws a random caption per access; validation and test always use
# the first caption of each record.
train_pt = Flickr30kDataset(flickr["train"], random_caption=True, **_image_ds_kwargs)
val_pt = Flickr30kDataset(flickr["validation"], random_caption=False, **_image_ds_kwargs)
test_pt = Flickr30kDataset(flickr["test"], random_caption=False, **_image_ds_kwargs)

# Only the training loader shuffles.
_image_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_pt, shuffle=True, **_image_loader_kwargs)
val_loader = DataLoader(val_pt, shuffle=False, **_image_loader_kwargs)


In [None]:
# Pull a single training batch and print the shape of each field:
# pixel_values [B, 3, 224, 224] (or similar), input_ids [B, MAX_LEN],
# attention_mask [B, MAX_LEN].
batch = next(iter(train_loader))
for field in ("pixel_values", "input_ids", "attention_mask"):
    print(batch[field].shape)


In [None]:
from torch.utils.data import DataLoader

BATCH_SIZE = 64

# Loaders used for feature extraction: no shuffling, so batches follow the
# dataset's own row order for every split.
_feat_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

train_loader_feats = DataLoader(train_pt, **_feat_loader_kwargs)
val_loader_feats = DataLoader(val_pt, **_feat_loader_kwargs)
test_loader_feats = DataLoader(test_pt, **_feat_loader_kwargs)


In [None]:
from torchvision import models
import torch.nn as nn
import torch

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Frozen ViT-B/16 (IMAGENET1K_V1 weights) with the classification head
# replaced by an identity.
vit_weights = models.ViT_B_16_Weights.IMAGENET1K_V1
vit = models.vit_b_16(weights=vit_weights)
vit.heads = nn.Identity()
vit.requires_grad_(False)

# Left as the cell's final expression so the module summary is still displayed.
vit.to(device).eval()


In [None]:
from google.colab import drive
# NOTE(review): the first cell of this notebook already mounts Drive at this
# same path, so on a top-to-bottom run this call looks redundant — confirm
# before removing.
drive.mount('/content/drive')

In [None]:
import torch
import os
from tqdm.auto import tqdm

# Folder under PROJECT_ROOT that receives the saved feature tensors; it is
# created on first use and left untouched if it already exists.
feat_dir = os.path.join(PROJECT_ROOT, "features_vit_b16")
os.makedirs(feat_dir, exist_ok=True)

def extract_and_save_features(dataloader, split_name):
    """Run the frozen ViT over `dataloader` and save its token features.

    Each batch's ``batch["pixel_values"]`` is passed through ``vit.conv_proj``,
    prefixed with the class token and fed to ``vit.encoder``. Two tensors are
    written into ``feat_dir``:

    * ``flickr30k_{split_name}_global.pt`` - the first (class) token of every
      image, in the encoder's output dtype.
    * ``flickr30k_{split_name}_patch.pt`` - every encoder token of every image
      (class token included), stored as float16.

    The output buffers are allocated once, sized from
    ``len(dataloader.dataset)``, and filled batch by batch. Collecting
    per-batch chunks in lists and concatenating them afterwards would keep two
    full copies of the patch tensor in host memory at the same time.

    Relies on the notebook-level names ``vit``, ``device``, ``feat_dir`` and
    ``tqdm``.

    Args:
        dataloader: DataLoader over a map-style dataset (``len()`` must work on
            ``dataloader.dataset``) that yields dicts with a "pixel_values"
            tensor and visits each item at most once per pass.
        split_name: label used in the progress bar and in the output filenames.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    print(f"Starting extraction for: {split_name}")

    vit.eval()

    capacity = len(dataloader.dataset)
    all_global = None   # [capacity, D], allocated on the first batch
    all_patches = None  # [capacity, tokens, D] in float16
    filled = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Extracting {split_name}"):
            imgs = batch["pixel_values"].to(device, non_blocking=True)

            # 1. Patch embedding: [B, 3, H, W] -> [B, D, h, w] -> [B, h*w, D]
            x = vit.conv_proj(imgs)
            x = x.flatten(2).transpose(1, 2)

            # 2. Prepend the class token: [B, 1 + h*w, D]
            batch_class_token = vit.class_token.expand(x.shape[0], -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)

            # 3. Encoder pass (transformer layers; self-attention across tokens)
            x = vit.encoder(x)

            if all_global is None:
                # Token count and width are only known once a batch has been
                # encoded, so the buffers are created lazily here.
                all_global = torch.empty((capacity, x.shape[-1]), dtype=x.dtype)
                all_patches = torch.empty((capacity, *x.shape[1:]), dtype=torch.float16)

            end = filled + x.shape[0]
            # Global feature: just the first token [CLS] -> [B, D]
            all_global[filled:end] = x[:, 0].cpu()
            # Patch features: all tokens (including CLS). Cast to half on the
            # device first so only half the bytes cross to the host.
            all_patches[filled:end] = x.half().cpu()
            filled = end

    if all_global is None:
        raise ValueError(
            f"No batches produced for split {split_name!r}; nothing to save."
        )

    if filled < capacity:
        # Fewer rows than the dataset length (e.g. drop_last=True). Clone the
        # used part so torch.save does not serialize the whole backing buffer.
        all_global = all_global[:filled].clone()
        all_patches = all_patches[:filled].clone()

    # Output filenames
    global_path = os.path.join(feat_dir, f"flickr30k_{split_name}_global.pt")
    patch_path = os.path.join(feat_dir, f"flickr30k_{split_name}_patch.pt")

    print(f"Saving Global Features {all_global.shape} to {global_path}")
    torch.save(all_global, global_path)

    print(f"Saving Patch Features {all_patches.shape} to {patch_path}")
    torch.save(all_patches, patch_path)
    print("-" * 40)

# Extract features split by split, pairing each existing feature loader with
# the label that ends up in its output filenames.
for _loader, _split_label in (
    (train_loader_feats, "train"),
    (val_loader_feats, "val"),
    (test_loader_feats, "test"),
):
    extract_and_save_features(_loader, _split_label)

print("All features extracted successfully!")

In [None]:
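"""
six files in Drive (in features_vit_b16/), two per split:

flickr30k_train_global.pt – shape [29000, D]
flickr30k_train_patch.pt – shape [29000, 197, D], float16

flickr30k_val_global.pt – shape [1014, D]
flickr30k_val_patch.pt – shape [1014, 197, D], float16

flickr30k_test_global.pt – shape [1000, D]
flickr30k_test_patch.pt – shape [1000, 197, D], float16

where D is the ViT embedding size (e.g. 768 for vit_b_16); the *_global files
hold the CLS token per image, the *_patch files hold all 197 tokens (CLS included)
"""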



In [None]:
import torch
import os

# Folder the extraction step wrote its tensors to.
feat_dir = os.path.join(PROJECT_ROOT, "features_vit_b16")


def _load_global_feats(split_label):
    """Load the saved global-feature tensor of one split from feat_dir."""
    return torch.load(os.path.join(feat_dir, f"flickr30k_{split_label}_global.pt"))


# One tensor of global (CLS) features per split.
img_feats_train = _load_global_feats("train")
img_feats_val   = _load_global_feats("val")
img_feats_test  = _load_global_feats("test")

print(f"Train feats shape: {img_feats_train.shape}")
print(f"Val feats shape:   {img_feats_val.shape}")
print(f"Test feats shape:  {img_feats_test.shape}")

In [None]:
from torch.utils.data import Dataset, DataLoader
import random

class Flickr30kRetrievalDataset(Dataset):
    """Pairs a precomputed image feature with one tokenized caption.

    Only the captions are needed from `hf_dataset`, so the "caption" column is
    read once in `__init__` and items are served from that list. Indexing the
    dataset row by row would materialize every column of the row on each
    access, including the image that the precomputed features replace.

    Args:
        hf_dataset: source of captions. Column access (`hf_dataset["caption"]`)
            is used when available; otherwise rows are read with
            `hf_dataset[idx]["caption"]`.
        img_feats: tensor whose first dimension matches `len(hf_dataset)`;
            row `idx` is the feature of example `idx`.
        tokenizer: callable invoked as
            tokenizer(text, padding=..., truncation=..., max_length=..., return_tensors="pt").
        max_length: token length used for padding/truncation.
        random_caption: when True a caption is drawn with `random.choice`;
            otherwise the first caption is used.

    Raises:
        AssertionError: if `hf_dataset` and `img_feats` differ in length.
    """

    def __init__(self, hf_dataset, img_feats, tokenizer, max_length=32,
                 random_caption=True):
        assert len(hf_dataset) == img_feats.size(0), (
            f"dataset has {len(hf_dataset)} examples but img_feats has "
            f"{img_feats.size(0)} rows"
        )
        self.ds = hf_dataset
        self.img_feats = img_feats
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.random_caption = random_caption

        try:
            # Column access reads just the captions.
            self._captions = list(hf_dataset["caption"])
        except (TypeError, KeyError, IndexError):
            # Plain sequences of dicts cannot be indexed by column name;
            # fall back to per-row access in __getitem__.
            self._captions = None

    def __len__(self):
        return len(self.ds)

    def _captions_for(self, idx):
        """Return the list of captions belonging to example `idx`."""
        if self._captions is not None:
            return self._captions[idx]
        return self.ds[idx]["caption"]

    def __getitem__(self, idx):
        # precomputed image feature
        img_feat = self.img_feats[idx]

        # pick a caption
        captions = self._captions_for(idx)
        if self.random_caption:
            caption = random.choice(captions)
        else:
            caption = captions[0]

        tok = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer output carries a leading batch axis of size 1; drop it
        # so the DataLoader's collation adds the real batch axis.
        return {
            "img_feat": img_feat,
            "input_ids": tok["input_ids"].squeeze(0),
            "attention_mask": tok["attention_mask"].squeeze(0),
            "caption": caption,
        }


In [None]:
BATCH_SIZE = 128  # can be larger now because we use features, not images

# Tokenizer settings shared by both retrieval datasets.
_ret_ds_kwargs = dict(tokenizer=tokenizer, max_length=MAX_LEN)

# Training samples a random caption per access; validation uses the first one.
train_ret = Flickr30kRetrievalDataset(
    flickr["train"], img_feats_train, random_caption=True, **_ret_ds_kwargs
)
val_ret = Flickr30kRetrievalDataset(
    flickr["validation"], img_feats_val, random_caption=False, **_ret_ds_kwargs
)

_ret_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader_ret = DataLoader(train_ret, shuffle=True, **_ret_loader_kwargs)
val_loader_ret = DataLoader(val_ret, shuffle=False, **_ret_loader_kwargs)

# Shape check on one training batch:
# img_feat [B, 768], input_ids [B, MAX_LEN], attention_mask [B, MAX_LEN].
batch = next(iter(train_loader_ret))
for field in ("img_feat", "input_ids", "attention_mask"):
    print(batch[field].shape)
