In [None]:
# ============================================================
# Mount Google Drive (Colab Only)
# ============================================================

from google.colab import drive

drive.mount('/content/drive')

# Root folder of the Flickr8k dataset on Google Drive (expects captions.txt and Images/).
# Edit this single constant if the dataset is moved.
DRIVE_ROOT = '/content/drive/MyDrive/SIH/ML/flickr8k'
print(f"Using DRIVE_ROOT = {DRIVE_ROOT}")

In [None]:
# ============================================================
# Install OpenAI CLIP (Required for Colab)
# ============================================================

!pip install --upgrade pip
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

In [None]:
# ============================================================
# Imports, Device Setup, and Configuration
# ============================================================

import csv
import math
import os
import random
import re
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import clip
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)

# ============================================================
# Configuration (Matches Your Original Notebook Exactly)
# ============================================================

@dataclass
class CFG:
    """
    Hyper-parameters shared by both training stages.

    Declared as a dataclass so every field is also stored on the instance:
    ``vars(cfg)`` then returns the full configuration. With plain class
    attributes ``vars(cfg)`` is ``{}``, which is what ended up in checkpoints.
    Fields can be overridden per run, e.g. ``CFG(lr=3e-5)``.
    """
    model_name: str = 'ViT-B/32'
    image_size: int = 224
    batch_size: int = 64
    epochs: int = 12
    lr: float = 1e-5
    weight_decay: float = 1e-4
    temperature: float = 0.07
    freeze_backbone_layers: bool = False  # NOTE(review): not read anywhere in this notebook.
    num_workers: int = 4
    pin_memory: bool = True
    seed: int = 42
    device: str = device


cfg = CFG()

# Seed every RNG the notebook may touch (torch.manual_seed also seeds all CUDA
# devices). np.random.seed is last so the cell does not echo a Generator object.
torch.manual_seed(cfg.seed)
random.seed(cfg.seed)
np.random.seed(cfg.seed)

In [None]:
# ============================================================
# Flickr8k Dataset (Robust CSV-Aware Parser)
# ============================================================

class Flickr8kDataset(Dataset):
    """
    Robust parser for Flickr8k captions.txt.
    Handles:
    - CSV format with header
    - .jpg#0 format
    - Tab separated
    - Space separated fallback

    All ``*.jpg`` files directly inside ``<root>/Images`` are sorted by path and
    split deterministically 80/10/10 into ``'train'`` / ``'val'`` / ``'test'``
    (any other ``split`` value selects every image). Each image contributes up
    to ``max_captions_per_image`` ``(image_path, caption)`` pairs; images with
    no caption contribute none.

    Raises:
        FileNotFoundError: if ``captions.txt`` or the ``Images`` folder is missing.
    """

    def __init__(self, root: str, split='train', transform=None,
                 max_captions_per_image=5, verbose=True):

        self.root = Path(root)
        captions_path = self.root / 'captions.txt'
        images_dir = self.root / 'Images'

        if not captions_path.exists():
            raise FileNotFoundError(f"Captions file not found at {captions_path}")

        if not images_dir.exists():
            raise FileNotFoundError(f"Images folder not found at {images_dir}")

        self.image2caps = {}
        self.transform = transform

        # ---------- Caption parsing ----------
        for imgname, caption in self._read_caption_rows(captions_path):
            self.image2caps.setdefault(imgname, []).append(caption)

        # ---------- Load images ----------
        self.images = sorted(images_dir.glob('*.jpg'))

        # ---------- Deterministic split ----------
        n = len(self.images)
        train_end = int(0.8 * n)
        val_end = int(0.9 * n)

        if split == 'train':
            images_subset = self.images[:train_end]
        elif split == 'val':
            images_subset = self.images[train_end:val_end]
        elif split == 'test':
            images_subset = self.images[val_end:]
        else:
            images_subset = self.images

        # ---------- Expand into (image, caption) pairs ----------
        self.pairs = []
        for p in images_subset:
            caps = self.image2caps.get(p.name, [])
            for c in caps[:max_captions_per_image]:
                self.pairs.append((str(p), c))

        if verbose:
            print(f"{split} pairs:", len(self.pairs))

    @staticmethod
    def _read_caption_rows(captions_path):
        """Yield ``(image_filename, caption)`` for every usable line of captions.txt."""
        with open(captions_path, 'r', encoding='utf-8', newline='') as f:
            first_line = f.readline()
            f.seek(0)
            if '\t' in first_line:
                # Tab-separated: "<image>#<n>\t<caption>" (commas belong to the caption).
                rows = [line.rstrip('\r\n').split('\t', 1) for line in f]
            else:
                rows = list(csv.reader(f))

        # Skip a header row such as "image,caption" (guard against an empty first row).
        if rows and rows[0] and rows[0][0].strip().lower() in ('image', 'filename'):
            rows = rows[1:]

        for row in rows:
            if len(row) == 1:
                # Space-separated fallback: "<image>[#n] <caption>" with no comma.
                row = row[0].split(None, 1)
            if len(row) < 2:
                continue
            imgname = row[0].split('#')[0].strip()
            # Re-join in case an unquoted caption contained commas.
            caption = ','.join(row[1:]).strip()
            yield imgname, caption

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(image, caption, image_filename)``; image is transformed if a transform is set."""
        img_path, caption = self.pairs[idx]
        # Close the file handle deterministically; convert() returns a loaded copy.
        with Image.open(img_path) as im:
            img = im.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, caption, os.path.basename(img_path)

In [None]:
# ============================================================
# Image Preprocessing and DataLoaders
# ============================================================

# Square resize to the model input size, then CLIP's published RGB mean/std.
preprocess = transforms.Compose([
    transforms.Resize((cfg.image_size, cfg.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])

train_ds = Flickr8kDataset(DRIVE_ROOT, split='train', transform=preprocess)
val_ds = Flickr8kDataset(DRIVE_ROOT, split='val', transform=preprocess)

# Identical loader settings for both splits; only training data is shuffled.
train_loader, val_loader = (
    DataLoader(split_ds,
               batch_size=cfg.batch_size,
               shuffle=is_training,
               num_workers=cfg.num_workers,
               pin_memory=cfg.pin_memory)
    for split_ds, is_training in ((train_ds, True), (val_ds, False))
)

In [None]:
# ============================================================
# CLIP Wrapper with Projection Heads
# ============================================================

class FineTuneCLIP(nn.Module):
    """
    OpenAI CLIP backbone plus trainable projection heads on its global embeddings.

    Each head is ``Linear -> LayerNorm -> GELU -> Linear`` and maps a CLIP
    image / text embedding into a shared ``embed_dim``-wide space.

    Args:
        clip_model_name: model name passed to ``clip.load`` (e.g. ``'ViT-B/32'``).
        embed_dim: output width of both projection heads.
        proj_dim: hidden width inside each projection head.
    """

    def __init__(self, clip_model_name='ViT-B/32',
                 embed_dim=512, proj_dim=256):
        super().__init__()

        # Loaded on CPU; the caller moves the whole wrapper with .to(device).
        self.clip_model, _ = clip.load(clip_model_name, device='cpu')
        self.embed_dim = embed_dim

        # Head input widths must equal what encode_image / encode_text return.
        # CLIP's encode_text finishes with a matmul by text_projection, so its
        # output width is text_projection.shape[1]; transformer.width is the
        # pre-projection width and only matches for some backbones.
        img_feat_dim = self.clip_model.visual.output_dim
        txt_feat_dim = self.clip_model.text_projection.shape[1]

        # Projection head for image embeddings
        self.img_proj = self._make_head(img_feat_dim, proj_dim, embed_dim)

        # Projection head for text embeddings
        self.txt_proj = self._make_head(txt_feat_dim, proj_dim, embed_dim)

    @staticmethod
    def _make_head(in_dim, hidden_dim, out_dim):
        """Build one ``Linear -> LayerNorm -> GELU -> Linear`` projection head."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, images, tokenized_text):
        """
        Encode a batch of images and token ids into the shared space.

        Returns:
            ``(img_emb, txt_emb)``, each L2-normalized along the last dimension.
        """
        img_features = self.clip_model.encode_image(images)
        txt_features = self.clip_model.encode_text(tokenized_text)

        img_emb = F.normalize(self.img_proj(img_features), dim=-1)
        txt_emb = F.normalize(self.txt_proj(txt_features), dim=-1)

        return img_emb, txt_emb


# .to() on the wrapper already moves every registered submodule, clip_model
# included, so no separate clip_model.to(device) is needed. Ending the cell with
# an assignment also keeps Jupyter from printing the whole module repr.
model = FineTuneCLIP(cfg.model_name).to(device)

In [None]:
# ============================================================
# Contrastive Loss (NT-Xent)
# ============================================================

class NTXentLoss(nn.Module):
    """
    Symmetric contrastive (CLIP-style InfoNCE) loss over a batch of pairs.

    Row ``i`` of ``img_emb`` and row ``i`` of ``txt_emb`` are the positive pair;
    every other row in the batch is a negative. The result is the mean of the
    image->text and text->image cross-entropies over the similarity logits
    divided by ``temperature``.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, img_emb, txt_emb):
        scaled_sim = torch.matmul(img_emb, txt_emb.t()) / self.temperature
        # The matching item for row k sits on the diagonal, i.e. at column k.
        targets = torch.arange(scaled_sim.shape[0], device=scaled_sim.device)

        image_side = F.cross_entropy(scaled_sim, targets)
        text_side = F.cross_entropy(scaled_sim.t(), targets)
        return (image_side + text_side) / 2


# Loss and optimizer for Stage 1 (all parameters registered on `model` right now).
criterion = NTXentLoss(temperature=cfg.temperature)
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
)

In [None]:
# ============================================================
# Stage 1: Global Contrastive Fine-Tuning
# (Exactly same NT-Xent training as your original notebook)
# ============================================================

epochs = cfg.epochs
best_val_r1 = -1.0  # NOTE(review): never updated or read; no per-epoch validation runs here.

# Plain-dict snapshot of the config stored in each checkpoint. vars(cfg) is empty
# when CFG only defines class attributes, so collect public, non-callable
# attributes via getattr (covers class defaults and instance overrides alike).
cfg_snapshot = {name: getattr(cfg, name) for name in dir(cfg)
                if not name.startswith('_') and not callable(getattr(cfg, name))}

for epoch in range(1, epochs + 1):

    model.train()
    epoch_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")

    # Count steps ourselves: tqdm only refreshes pbar.n when it redraws, so
    # pbar.n + 1 can lag the true step count and inflate the running mean.
    for step, (images, captions, _) in enumerate(pbar, start=1):

        images = images.to(device)
        tokenized = clip.tokenize(list(captions), truncate=True).to(device)

        img_emb, txt_emb = model(images, tokenized)
        loss = criterion(img_emb, txt_emb)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip total grad norm at 1.0
        optimizer.step()

        epoch_loss += loss.item()
        pbar.set_postfix({'loss': f"{epoch_loss / step:.4f}"})

    print(f"Epoch {epoch} Avg Loss:", epoch_loss / max(len(train_loader), 1))

    # One checkpoint per epoch in the current working directory.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'cfg': cfg_snapshot,
    }, f"clip_finetuned_epoch{epoch}.pth")

print("Stage 1 training complete.")

In [None]:
# ============================================================
# Retrieval Evaluation (Exact Same Logic as Original)
# ============================================================

def compute_retrieval(model, dataset, batch_size=64, ks=(1, 5, 10)):
    """
    Evaluate global image<->text retrieval Recall@K on ``dataset``.

    Every distinct image in ``dataset.pairs`` is one gallery item; every caption
    is a query whose ground truth is the image it was paired with.

    Args:
        model: module exposing ``clip_model.encode_image``, ``clip_model.encode_text``,
            ``img_proj`` and ``txt_proj`` (e.g. FineTuneCLIP).
        dataset: object with a ``pairs`` list of ``(image_path, caption)``.
        batch_size: number of images / captions encoded per forward pass.
        ks: cut-offs at which recall is reported.

    Returns:
        ``{'i2t': {k: recall_percent}, 't2i': {k: recall_percent}}`` (also
        printed), or ``{}`` when ``dataset`` has no pairs.

    Uses the notebook-level ``preprocess`` and ``device``.
    """
    model.eval()

    image_paths, captions, caption_image_idx = [], [], []
    img_name_to_idx = {}

    # Build mapping: unique images in first-seen order, each caption -> its image index.
    for img_path, cap in dataset.pairs:
        name = os.path.basename(img_path)
        if name not in img_name_to_idx:
            img_name_to_idx[name] = len(image_paths)
            image_paths.append(img_path)
        captions.append(cap)
        caption_image_idx.append(img_name_to_idx[name])

    if not image_paths:
        print("No (image, caption) pairs to evaluate.")
        return {}

    image_embs, text_embs = [], []
    with torch.no_grad():
        for start in range(0, len(image_paths), batch_size):
            batch = torch.stack([
                preprocess(Image.open(p).convert('RGB'))
                for p in image_paths[start:start + batch_size]
            ]).to(device)
            feats = model.img_proj(model.clip_model.encode_image(batch))
            image_embs.append(F.normalize(feats, dim=-1).cpu())

        for start in range(0, len(captions), batch_size):
            # truncate=True matches training; without it a caption longer than
            # CLIP's context length raises instead of being clipped.
            tokenized = clip.tokenize(captions[start:start + batch_size],
                                      truncate=True).to(device)
            feats = model.txt_proj(model.clip_model.encode_text(tokenized))
            text_embs.append(F.normalize(feats, dim=-1).cpu())

    # (num_images, num_captions) cosine similarities.
    sims = (torch.cat(image_embs) @ torch.cat(text_embs).t()).numpy()
    cap_to_img = np.asarray(caption_image_idx)
    img_ids = np.arange(len(image_paths))

    # Rank once and reuse the orderings for every K.
    i2t_rank = np.argsort(-sims, axis=1)  # per image: caption indices, best first
    t2i_rank = np.argsort(-sims, axis=0)  # per caption (column): image indices, best first

    def recall_i2t(k):
        # An image is a hit if any of its own captions is among its top-k captions.
        hits = (cap_to_img[i2t_rank[:, :k]] == img_ids[:, None]).any(axis=1)
        return float(hits.mean() * 100)

    def recall_t2i(k):
        # A caption is a hit if its own image is among its top-k images.
        hits = (t2i_rank[:k, :] == cap_to_img[None, :]).any(axis=0)
        return float(hits.mean() * 100)

    results = {
        'i2t': {k: recall_i2t(k) for k in ks},
        't2i': {k: recall_t2i(k) for k in ks},
    }

    label = "R@" + "/".join(str(k) for k in ks) + ":"
    print("Image->Text " + label, *results['i2t'].values())
    print("Text->Image " + label, *results['t2i'].values())
    return results


compute_retrieval(model=model, dataset=val_ds)

In [None]:
# ============================================================
# Add Fine-Grained Projection Layers (Patch + Token)
# ============================================================

model.eval()

# Feature widths of the per-patch / per-token outputs used for fine-grained
# alignment. Both extractors finish with a LayerNorm over the last dimension
# (visual.ln_post for patches, ln_final for text tokens), so that LayerNorm's
# normalized_shape is the width we need -- no DataLoader batch or forward pass.
D_patch = model.clip_model.visual.ln_post.normalized_shape[0]
D_text = model.clip_model.ln_final.normalized_shape[0]

proj_dim = model.embed_dim

# Create each head only once: re-running this cell must not silently replace
# already-trained heads with freshly initialised ones.
if not hasattr(model, 'patch_proj'):
    model.patch_proj = nn.Linear(D_patch, proj_dim).to(device)
if not hasattr(model, 'token_proj'):
    model.token_proj = nn.Linear(D_text, proj_dim).to(device)

print("Fine-grained projection heads added.")

In [None]:
# ============================================================
# Fine-Grained Similarity Matrix (Exact Original Logic)
# ============================================================

def extract_patch_tokens(clip_model, images):
    """
    Run CLIP's vision transformer and return the per-patch hidden states.

    Replays the ViT forward pass (patch conv, class token, positional
    embedding, ln_pre, transformer, ln_post) but applies ``ln_post`` to every
    token and drops the leading class token instead of pooling it.

    Returns:
        ``(batch, num_patches, width)`` tensor.
    """
    vit = clip_model.visual

    patches = vit.conv1(images)
    batch_size, width = patches.shape[0], patches.shape[1]
    # Flatten the spatial grid into a sequence: (B, width, grid*grid) -> (B, grid*grid, width).
    patches = patches.reshape(batch_size, width, -1).permute(0, 2, 1)

    cls_token = vit.class_embedding.to(patches.dtype)[None, None].expand(batch_size, -1, -1)
    seq = torch.cat([cls_token, patches], dim=1)
    seq = seq + vit.positional_embedding.to(seq.dtype)
    seq = vit.ln_pre(seq)

    # The CLIP transformer consumes (seq, batch, width); swap in and back out.
    seq = vit.transformer(seq.permute(1, 0, 2)).permute(1, 0, 2)
    seq = vit.ln_post(seq)

    return seq[:, 1:, :]


def extract_text_tokens(clip_model, tokenized):
    """
    Run CLIP's text transformer and return every per-position hidden state.

    Follows the token-embedding / positional-embedding / transformer /
    ``ln_final`` path (computed in float32) but returns all positions rather
    than pooling a single token. Padding positions are included.

    Returns:
        ``(batch, seq_len, width)`` tensor.
    """
    hidden = clip_model.token_embedding(tokenized).type(torch.float32)
    hidden = hidden + clip_model.positional_embedding.to(hidden.dtype)
    # The CLIP transformer consumes (seq, batch, width); swap in and back out.
    hidden = clip_model.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
    return clip_model.ln_final(hidden)


def compute_fine_sim_matrix(patch_tokens, token_tokens, token_mask=None, proj_model=None):
    """
    Token-to-patch image/caption similarity matrix.

    For each (image i, caption j): project and L2-normalize the patch and token
    features, take for every caption token its best-matching patch (max cosine
    similarity), then average those maxima over the caption's tokens.

    Args:
        patch_tokens: ``(Bp, P, Dp)`` patch features (e.g. from extract_patch_tokens).
        token_tokens: ``(Bt, T, Dt)`` token features (e.g. from extract_text_tokens).
        token_mask: optional ``(Bt, T)`` bool / 0-1 tensor. When given, only
            positions where it is true / non-zero are averaged -- e.g.
            ``tokenized != 0`` to drop CLIP's padding (assumes pad id 0; confirm).
            ``None`` averages over all ``T`` positions.
        proj_model: object exposing ``patch_proj`` / ``token_proj``; defaults to
            the notebook-global ``model``.

    Returns:
        ``(Bp, Bt)`` similarity matrix on the patch features' device.
    """
    heads = model if proj_model is None else proj_model

    Bp, P, Dp = patch_tokens.shape
    Bt, T, Dt = token_tokens.shape

    # Project into the shared space, then normalize so dot products are cosines.
    patch_p = heads.patch_proj(patch_tokens.reshape(-1, Dp)).reshape(Bp, P, -1)
    token_p = heads.token_proj(token_tokens.reshape(-1, Dt)).reshape(Bt, T, -1)
    patch_p = F.normalize(patch_p, dim=-1)
    token_p = F.normalize(token_p, dim=-1)

    token_weights = None
    if token_mask is not None:
        mask = token_mask.to(device=token_p.device, dtype=token_p.dtype)
        # Per-caption averaging weights; clamp avoids 0/0 for an all-masked caption.
        token_weights = mask / mask.sum(dim=1, keepdim=True).clamp_min(1.0)

    sims = torch.zeros((Bp, Bt), device=patch_p.device)

    # One image at a time keeps the intermediate at (Bt, T, P) instead of (Bp, Bt, T, P).
    for i in range(Bp):
        sim_token_patch = torch.einsum('pd,btd->btp', patch_p[i], token_p)  # (Bt, T, P)
        max_over_patches = sim_token_patch.max(dim=2).values               # (Bt, T)
        if token_weights is None:
            sims[i] = max_over_patches.mean(dim=1)
        else:
            sims[i] = (max_over_patches * token_weights).sum(dim=1)

    return sims

In [None]:
# ============================================================
# Stage 2: Fine-Grained Alignment Training
# (Exact same λ * fine_loss combination)
# ============================================================

lambda_fine = 1.0
fine_epochs = 6

# patch_proj / token_proj were attached to `model` after `optimizer` was built,
# so their parameters belong to no param group and optimizer.step() would leave
# them at their random initialisation. Register whatever the optimizer does not
# track yet (a no-op on re-runs); new groups use the optimizer's default
# hyper-parameters (cfg.lr / cfg.weight_decay).
_tracked_ids = {id(p) for group in optimizer.param_groups for p in group['params']}
_untracked = [p for p in model.parameters() if id(p) not in _tracked_ids]
if _untracked:
    optimizer.add_param_group({'params': _untracked})
    print(f"Added {len(_untracked)} untracked parameter tensors to the optimizer.")
del _tracked_ids, _untracked

# Plain-dict config snapshot for the checkpoint (vars(cfg) may be empty).
cfg_snapshot = {name: getattr(cfg, name) for name in dir(cfg)
                if not name.startswith('_') and not callable(getattr(cfg, name))}

for epoch in range(1, fine_epochs + 1):

    model.train()
    epoch_loss = 0.0
    pbar = tqdm(train_loader, desc=f"FineEpoch {epoch}")

    # Explicit step count: pbar.n only refreshes when tqdm redraws.
    for step, (images, captions, _) in enumerate(pbar, start=1):

        images = images.to(device)
        tokenized = clip.tokenize(list(captions), truncate=True).to(device)

        # Global loss (same projection path as model.forward / Stage 1)
        img_proj, txt_proj = model(images, tokenized)
        global_loss = criterion(img_proj, txt_proj)

        # Fine-grain loss: symmetric cross-entropy over the token/patch similarity
        # matrix. Numerically identical to criterion(fine_sim, I), since
        # fine_sim @ I.T == fine_sim, but without abusing the embedding API.
        patch_tokens = extract_patch_tokens(model.clip_model, images)
        token_tokens = extract_text_tokens(model.clip_model, tokenized)

        fine_sim = compute_fine_sim_matrix(patch_tokens, token_tokens)
        fine_logits = fine_sim / criterion.temperature
        fine_labels = torch.arange(fine_logits.size(0), device=fine_logits.device)
        fine_loss = (F.cross_entropy(fine_logits, fine_labels)
                     + F.cross_entropy(fine_logits.t(), fine_labels)) / 2

        loss = global_loss + lambda_fine * fine_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()
        pbar.set_postfix({'loss': f"{epoch_loss / step:.4f}"})

    print("Fine Epoch", epoch, "Loss:", epoch_loss / max(len(train_loader), 1))

# Persist the Stage-2 model; otherwise it is lost when the runtime resets.
torch.save({
    'epoch': fine_epochs,
    'lambda_fine': lambda_fine,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'cfg': cfg_snapshot,
}, "clip_finegrained_final.pth")

print("Fine-grain training complete.")

In [None]:
# ============================================================
# Final Retrieval Evaluation After Fine-Grain Training
# ============================================================

compute_retrieval(model=model, dataset=val_ds)

In [None]:
# ============================================================
# Build and Cache Dataset Embeddings (for Fast Retrieval)
# ============================================================

dataset_for_index = val_ds   # change to train_ds if needed
index_batch_size = cfg.batch_size  # images / captions encoded per forward pass

image_paths = []
captions = []
caption_image_idx = []
img_name_to_idx = {}

# Unique images in first-seen order; caption_image_idx[j] is caption j's image index.
for img_path, cap in dataset_for_index.pairs:
    name = os.path.basename(img_path)
    if name not in img_name_to_idx:
        img_name_to_idx[name] = len(image_paths)
        image_paths.append(img_path)
    captions.append(cap)
    caption_image_idx.append(img_name_to_idx[name])

num_images = len(image_paths)
num_captions = len(captions)

print("Index built:")
print("Images:", num_images)
print("Captions:", num_captions)

if num_images == 0:
    raise ValueError("dataset_for_index has no (image, caption) pairs to index.")

model.eval()
image_embs = []
text_embs = []

# Batched encoding: a handful of forward passes instead of one per item.
with torch.no_grad():
    for start in tqdm(range(0, num_images, index_batch_size), desc="Encoding Images"):
        batch = torch.stack([
            preprocess(Image.open(p).convert('RGB'))
            for p in image_paths[start:start + index_batch_size]
        ]).to(device)
        img_feat = model.img_proj(model.clip_model.encode_image(batch))
        image_embs.append(F.normalize(img_feat, dim=-1).cpu())

    for start in tqdm(range(0, num_captions, index_batch_size), desc="Encoding Captions"):
        tokenized = clip.tokenize(captions[start:start + index_batch_size],
                                  truncate=True).to(device)
        txt_feat = model.txt_proj(model.clip_model.encode_text(tokenized))
        text_embs.append(F.normalize(txt_feat, dim=-1).cpu())

image_embs = torch.cat(image_embs, dim=0)
text_embs = torch.cat(text_embs, dim=0)

# NumPy copies used by the retrieval helpers below.
image_embs_np = image_embs.numpy()
text_embs_np = text_embs.numpy()

print("Embedding cache ready.")

In [None]:
# ============================================================
# Retrieval Helper Functions
# ============================================================

from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline


def show_image(path, title=None, figsize=(6,6)):
    """Display the image stored at ``path`` in its own figure, axes hidden, with an optional title."""
    picture = Image.open(path).convert('RGB')
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(picture)
    ax.axis('off')
    if title:
        ax.set_title(title)
    plt.show()


def get_caption(j):
    """Return caption number ``j`` from the cached retrieval index."""
    caption_text = captions[j]
    return caption_text


def get_image_path(i):
    """Return the file path of gallery image number ``i`` in the cached retrieval index."""
    path_for_index = image_paths[i]
    return path_for_index


# ------------------------------------------------------------
# Image → Top-K Captions
# ------------------------------------------------------------
def image_to_topk_captions(image_input, k=5, return_scores=False):
    """
    Rank the cached captions against one image.

    Args:
        image_input: path to an image file (``str`` or ``os.PathLike``) or an
            already-loaded PIL image.
        k: number of captions to return.
        return_scores: if True return ``[(caption, cosine_similarity), ...]``,
            otherwise only the captions, best first.

    Uses the module-level ``text_embs_np`` cache, so the embedding-index cell
    must have been run first.
    """
    if isinstance(image_input, (str, os.PathLike)):
        image_input = Image.open(image_input)
    # Convert in both cases: preprocess normalizes exactly three channels, so an
    # RGBA / grayscale PIL image passed in directly would otherwise fail.
    im_t = preprocess(image_input.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        img_feat = model.img_proj(model.clip_model.encode_image(im_t))
        img_feat = F.normalize(img_feat, dim=-1).cpu().numpy()[0]

    sims = text_embs_np @ img_feat
    topk_idx = np.argsort(-sims)[:k]

    results = [(get_caption(j), float(sims[j])) for j in topk_idx]

    if return_scores:
        return results
    return [caption for caption, _ in results]


# ------------------------------------------------------------
# Text → Top-K Images
# ------------------------------------------------------------
def text_to_topk_images(text, k=5, return_scores=False):
    """
    Rank the cached gallery images against a free-text query.

    Args:
        text: query string (truncated to CLIP's context length).
        k: number of images to return.
        return_scores: if True return ``[(image_path, cosine_similarity), ...]``,
            otherwise only the image paths, best first.

    Uses the module-level ``image_embs_np`` cache.
    """
    with torch.no_grad():
        query_tokens = clip.tokenize([text], truncate=True).to(device)
        query_vec = model.txt_proj(model.clip_model.encode_text(query_tokens))
        query_vec = F.normalize(query_vec, dim=-1).cpu().numpy()[0]

    scores = image_embs_np @ query_vec
    best_first = np.argsort(-scores)[:k]

    if not return_scores:
        return [get_image_path(idx) for idx in best_first]
    return [(get_image_path(idx), float(scores[idx])) for idx in best_first]

In [None]:
# ============================================================
# Retrieval-Based VQA (Caption Matching)
# ============================================================

def vqa_via_captions(image_path, question, topk_caps=10):
    """
    Retrieval-based "VQA": pick the retrieved caption closest to the question.

    The image's top ``topk_caps`` captions are retrieved, then the one whose
    text embedding has the highest cosine similarity to the question's embedding
    is returned as the answer. No answer text is generated.

    Args:
        image_path: anything ``image_to_topk_captions`` accepts.
        question: free-text question.
        topk_caps: number of candidate captions to retrieve.

    Returns:
        dict with ``question``, ``answer_caption``, ``score`` (question/caption
        cosine similarity) and ``top_candidate_captions`` (``[(caption, score)]``).
        When no captions are retrieved, ``answer_caption`` is
        ``"No captions available."``, ``score`` is None and the candidate list is
        empty -- the same keys, so callers can always index the result.
    """
    # Step 1: Retrieve top-k captions for the image
    top_caps = image_to_topk_captions(image_path, k=topk_caps, return_scores=False)

    if not top_caps:
        return {
            "question": question,
            "answer_caption": "No captions available.",
            "score": None,
            "top_candidate_captions": [],
        }

    # Step 2: Embed the question together with every candidate caption
    tokenized = clip.tokenize([question] + top_caps, truncate=True).to(device)

    with torch.no_grad():
        feats = model.txt_proj(model.clip_model.encode_text(tokenized))
        feats = F.normalize(feats, dim=-1).cpu().numpy()

    q_feat, cap_feats = feats[0], feats[1:]
    sims = cap_feats @ q_feat
    best_idx = int(np.argmax(sims))

    return {
        "question": question,
        "answer_caption": top_caps[best_idx],
        "score": float(sims[best_idx]),
        "top_candidate_captions": list(zip(top_caps, sims.tolist())),
    }

In [None]:
# ============================================================
# Demo 1: Image → Text Retrieval
# ============================================================

# Pick one gallery image from the cached index and list its closest captions.
example_index = 100
example_image = image_paths[example_index]

print("Image → Text Retrieval Example")
show_image(example_image)

topcaps = image_to_topk_captions(example_image, k=5, return_scores=True)

for caption_text, caption_score in topcaps:
    print(f"score={caption_score:.4f} -> {caption_text}")

In [None]:
# ============================================================
# Demo 2: Text → Image Retrieval
# ============================================================

# Free-text query against the cached image gallery; show each hit with its rank.
query_text = "person with red hat dancing"
print("Text Query:", query_text)

top_images = text_to_topk_images(query_text, k=5, return_scores=True)

for rank, (img_path, score) in enumerate(top_images, start=1):
    print(f"Rank {rank} | score={score:.4f}")
    show_image(img_path)

In [None]:
# ============================================================
# Demo 3: VQA
# ============================================================

# Retrieval-based VQA on the Demo 1 image: the "answer" is the best-matching caption.
question = "What is the child doing?"

show_image(example_image)

vqa_res = vqa_via_captions(example_image, question, topk_caps=10)

print(f"Question: {question}")
print(f"Answer: {vqa_res['answer_caption']}")
print(f"Score: {vqa_res['score']}")