**Run this cell to download the COCO 2017 dataset if you are not running on Kaggle.**

In [None]:
import os

# CHANGE THIS to where you want the dataset
COCO_ROOT = "/kaggle/input/coco-2017-dataset"  

IMAGES_DIR = COCO_ROOT  # train2017/ and val2017/ will live directly under this
ANN_DIR = os.path.join(COCO_ROOT, "annotations")

os.makedirs(COCO_ROOT, exist_ok=True)
os.makedirs(ANN_DIR, exist_ok=True)

COCO_ROOT, IMAGES_DIR, ANN_DIR

# This cell uses IPython's ! to run shell commands.
# It will:
#  - download train2017.zip
#  - download val2017.zip
#  - download annotations_trainval2017.zip

print("Downloading MS COCO 2017 train/val + annotations to", COCO_ROOT)

# Train images
!cd "$COCO_ROOT" && wget -c http://images.cocodataset.org/zips/train2017.zip

# Val images
!cd "$COCO_ROOT" && wget -c http://images.cocodataset.org/zips/val2017.zip

# Train/Val annotations (includes captions)
!cd "$COCO_ROOT" && wget -c http://images.cocodataset.org/annotations/annotations_trainval2017.zip


# Unzip train and val images into COCO_ROOT
!cd "$COCO_ROOT" && unzip -q train2017.zip
!cd "$COCO_ROOT" && unzip -q val2017.zip

# Unzip annotations into COCO_ROOT/annotations
!cd "$COCO_ROOT" && unzip -q annotations_trainval2017.zip -d "$ANN_DIR"


import glob
for z in glob.glob(os.path.join(COCO_ROOT, "*.zip")):
    print("Removing", z)
    os.remove(z)



# === Cell 1: Install dependencies and set global config ===


In [1]:

!pip install timm open_clip_torch nltk pycocotools -q

import os, json, re, random
from collections import Counter, defaultdict
from typing import List, Dict, Any, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from PIL import Image
from torchvision import transforms
import timm
import open_clip
import nltk
nltk.download('punkt', quiet=True)

print("Torch:", torch.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ==== REPRODUCIBILITY ====
# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) so that
# weight initialisation and DataLoader shuffling repeat across
# "Restart Kernel -> Run All".
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ==== PATHS & CONFIG (adjust as needed) ====
# COCO 2017 dataset root on Kaggle (change if different)
COCO_ROOT    = "/kaggle/input/coco-2017-dataset/coco2017"
PROCESSED_DIR = "/kaggle/working/processed"
os.makedirs(PROCESSED_DIR, exist_ok=True)

# Cached vocabulary file; delete it to force a rebuild (e.g. after changing FREQ_THRESH).
VOCAB_PATH   = os.path.join(PROCESSED_DIR, "vocab.json")

MAX_LEN      = 30          # max caption length (incl <bos>/<eos>)
BATCH_SIZE   = 64
FREQ_THRESH  = 4           # min word frequency to keep in vocab
DEBUG_LIMIT  = 100000       # number of train captions to use; set to None for full train
VAL_DEBUG    = 5000        # small val subset for faster eval

# Beam search config
BEAM_SIZE    = 3           # you can increase to 5 for better quality (but slower)
NO_REPEAT_NGRAM_SIZE = 3   # e.g., 3 to avoid repeating same 3-gram (0 disables)




Torch: 2.6.0+cu124
Using device: cuda


# === Cell 2: Vocabulary, tokenization, COCO dataset, collate ===


In [2]:

# Reserved tokens; their position in this list is their id (pad=0, bos=1, eos=2, unk=3).
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]

class Vocabulary:
    """Word <-> id mapping with the special tokens pinned to the first ids.

    Attributes:
        freq_threshold: minimum count a word needs to enter the vocabulary.
        itos: list mapping id -> token.
        stoi: dict mapping token -> id.
    """

    def __init__(self, freq_threshold: int = 5):
        self.freq_threshold = freq_threshold
        self.itos: List[str] = []
        self.stoi: Dict[str, int] = {}

    def __len__(self):
        return len(self.itos)

    def build_vocab(self, counter: Counter):
        """Populate ``itos``/``stoi`` from a word-frequency counter.

        Words are appended in the counter's iteration order when their count
        reaches ``freq_threshold``. Words that collide with a special token are
        skipped: adding them again would create a duplicate entry and move the
        special token's id in ``stoi`` away from its reserved position.
        """
        self.itos = SPECIAL_TOKENS.copy()
        reserved = set(SPECIAL_TOKENS)
        for word, freq in counter.items():
            if freq >= self.freq_threshold and word not in reserved:
                self.itos.append(word)
        self.stoi = {w: i for i, w in enumerate(self.itos)}
        print(f"Vocab built: {len(self.itos)} tokens (freq ≥ {self.freq_threshold})")

    def numericalize(self, tokens: List[str]) -> List[int]:
        """Map tokens to ids; unknown tokens map to the ``<unk>`` id."""
        unk_id = self.stoi["<unk>"]
        return [self.stoi.get(tok, unk_id) for tok in tokens]

    def save(self, path: str):
        """Write ``itos`` and ``freq_threshold`` to ``path`` as JSON."""
        obj = {"itos": self.itos, "freq_threshold": self.freq_threshold}
        with open(path, "w", encoding="utf-8") as f:
            json.dump(obj, f)
        print(f"Saved vocab to {path}")

    @classmethod
    def load(cls, path: str) -> "Vocabulary":
        """Rebuild a vocabulary from a JSON file written by :meth:`save`."""
        with open(path, "r", encoding="utf-8") as f:
            obj = json.load(f)
        vocab = cls(freq_threshold=obj.get("freq_threshold", 5))
        vocab.itos = obj["itos"]
        vocab.stoi = {w: i for i, w in enumerate(vocab.itos)}
        print(f"Loaded vocab from {path}, size={len(vocab)}")
        return vocab

def tokenize_caption(text: str) -> List[str]:
    """Lower-case a caption, drop trailing '.', '?' or '!' runs, and split on whitespace.

    Punctuation elsewhere in the caption is left attached to its word.
    """
    normalized = text.lower().strip()
    return re.sub(r"[.?!]+$", "", normalized).split()

class COCODataset(Dataset):
    """(image, caption) pairs read from a COCO captions JSON file.

    Every caption annotation becomes one sample, so an image with several
    captions appears several times.

    Args:
        images_root: directory that contains the image files.
        captions_json: path to a COCO ``captions_*.json`` file.
        vocab: vocabulary used to turn caption tokens into ids.
        max_len: fixed length of the returned caption tensor (incl. <bos>/<eos>).
        transform: image transform; defaults to resize-224 + ImageNet normalisation.
        debug_limit: if given, only the first ``debug_limit`` annotations are used.
    """

    def __init__(
        self,
        images_root: str,
        captions_json: str,
        vocab: "Vocabulary",
        max_len: int = 30,
        transform=None,
        debug_limit: int = None,
    ):
        self.images_root = images_root
        self.vocab = vocab
        self.max_len = max_len
        self.transform = transform or self._default_transform()

        with open(captions_json, "r") as f:
            ann = json.load(f)

        self.imgs = {img["id"]: img for img in ann["images"]}

        # Apply the debug limit before tokenizing so that only the captions
        # that are actually kept get processed.
        annotations = ann["annotations"]
        if debug_limit is not None:
            annotations = annotations[:debug_limit]

        self.samples = [
            (self.imgs[a["image_id"]]["file_name"], tokenize_caption(a["caption"]))
            for a in annotations
        ]

        if debug_limit is not None:
            print(f"[COCODataset] Debug limit: {len(self.samples)} samples")

        print(f"Loaded {len(self.samples)} (image, caption) pairs from {captions_json}")

    def _default_transform(self):
        """Resize to 224x224, convert to tensor, normalise with ImageNet statistics."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def __len__(self) -> int:
        return len(self.samples)

    def _encode_caption(self, tokens: List[str]) -> List[int]:
        """Return exactly ``max_len`` ids: <bos> + tokens + <eos>, padded with <pad>.

        Over-long captions are truncated but still end in <eos>; otherwise the
        model would never see a stop token for long captions.
        """
        bos_id = self.vocab.stoi["<bos>"]
        eos_id = self.vocab.stoi["<eos>"]
        pad_id = self.vocab.stoi["<pad>"]

        seq_ids = [bos_id] + self.vocab.numericalize(tokens) + [eos_id]
        if len(seq_ids) > self.max_len:
            seq_ids = seq_ids[: self.max_len - 1] + [eos_id]
        seq_ids += [pad_id] * (self.max_len - len(seq_ids))
        return seq_ids

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return a dict with the transformed ``image``, ``caption`` ids and ``file_name``."""
        file_name, tokens = self.samples[idx]
        img_path = os.path.join(self.images_root, file_name)

        # Context manager closes the file handle as soon as the pixels are copied.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        caption = torch.tensor(self._encode_caption(tokens), dtype=torch.long)

        return {
            "image": image,
            "caption": caption,
            "file_name": file_name,
        }

def coco_collate_fn(batch):
    """Collate dataset samples into one batch dict.

    ``image`` and ``caption`` tensors are stacked along a new leading
    dimension; ``file_name`` values are gathered into a plain list.
    """
    image_list, caption_list, name_list = [], [], []
    for sample in batch:
        image_list.append(sample["image"])
        caption_list.append(sample["caption"])
        name_list.append(sample["file_name"])

    return {
        "image": torch.stack(image_list, dim=0),
        "caption": torch.stack(caption_list, dim=0),
        "file_name": name_list,
    }

def simple_tokenize(text: str):
    """Return the lower-cased runs of ASCII letters found in ``text``, in order."""
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r"[a-z]+", lowered)]


# === Cell 3: Build or load vocabulary from COCO train captions ===


In [3]:

TRAIN_CAPTIONS_JSON = os.path.join(COCO_ROOT, "annotations", "captions_train2017.json")

# The vocabulary is cached at VOCAB_PATH: build and save it only when the
# cache file is missing, otherwise reuse the saved one.
if not os.path.exists(VOCAB_PATH):
    with open(TRAIN_CAPTIONS_JSON, "r") as f:
        train_ann = json.load(f)
    print("Num captions:", len(train_ann["annotations"]))

    # Word frequencies over every training caption.
    counter = Counter()
    for ann in tqdm(train_ann["annotations"], desc="Counting words for vocab"):
        tokens = tokenize_caption(ann["caption"])
        counter.update(tokens)

    vocab = Vocabulary(freq_threshold=FREQ_THRESH)
    vocab.build_vocab(counter)
    vocab.save(VOCAB_PATH)
else:
    vocab = Vocabulary.load(VOCAB_PATH)

# Ids of the special tokens, used by the training and decoding cells below.
pad_id, bos_id, eos_id = (vocab.stoi[tok] for tok in ("<pad>", "<bos>", "<eos>"))
print("pad/bos/eos ids:", pad_id, bos_id, eos_id)
print("Vocab size:", len(vocab))


Num captions: 591753


Counting words for vocab:   0%|          | 0/591753 [00:00<?, ?it/s]

Vocab built: 12822 tokens (freq ≥ 4)
Saved vocab to /kaggle/working/processed/vocab.json
pad/bos/eos ids: 0 1 2
Vocab size: 12822


# === Cell 4: Vision encoder (ViT), Transformer decoder, Captioner ===


In [4]:

class ViTEncoder(nn.Module):
    """timm vision backbone followed by an optional projection to ``d_model``.

    Args:
        model_name: timm model identifier.
        pretrained: forwarded to ``timm.create_model``.
        trainable: value assigned to ``requires_grad`` of every backbone parameter.
        d_model: feature width expected by the decoder; a linear projection is
            added only when the backbone width differs from it.
    """

    def __init__(
        self,
        model_name: str = "vit_base_patch16_224",
        pretrained: bool = True,
        trainable: bool = False,
        d_model: int = 512,
    ):
        super().__init__()
        self.vit = timm.create_model(model_name, pretrained=pretrained)
        # reset_classifier(0): presumably drops the classification head so the
        # backbone only yields features (see timm docs).
        self.vit.reset_classifier(0)

        backbone_dim = self.vit.num_features
        self.proj = nn.Identity() if backbone_dim == d_model else nn.Linear(backbone_dim, d_model)

        # Freeze or unfreeze the whole backbone; the projection always stays trainable.
        self.vit.requires_grad_(trainable)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return image features as a 3-D tensor [B, S, d_model]."""
        tokens = self.vit.forward_features(x)  # [B, S, C] or [B, C] depending on timm version
        if tokens.dim() == 2:
            tokens = tokens[:, None, :]        # treat a pooled vector as a single token
        return self.proj(tokens)

class TransformerCaptionDecoder(nn.Module):
    """Causal Transformer decoder that maps caption prefixes to next-token logits.

    Token embeddings (scaled by sqrt(d_model)) are summed with learned
    position embeddings, passed through ``nn.TransformerDecoder`` with
    cross-attention over the image memory, and projected to the vocabulary.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_layers: int = 6,
        num_heads: int = 8,
        dim_feedforward: int = 2048,
        max_len: int = 30,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len

        self.token_embed = nn.Embedding(vocab_size, d_model)
        # Learned positions: input sequences may not be longer than max_len.
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

        layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers=num_layers)
        self.out_proj = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, tgt_key_padding_mask=None):
        """Compute logits [B, T, vocab_size] for token ids ``tgt`` [B, T].

        ``memory`` holds the encoder features attended to by every layer;
        ``tgt_key_padding_mask`` is passed through to the decoder unchanged.
        """
        batch_size, seq_len = tgt.shape
        dev = tgt.device

        pos_ids = torch.arange(seq_len, device=dev)[None, :].expand(batch_size, seq_len)
        scaled_tokens = self.token_embed(tgt) * (self.d_model ** 0.5)
        hidden = self.dropout(scaled_tokens + self.pos_embed(pos_ids))

        # True above the diagonal: position t may not attend to positions > t.
        future_mask = torch.ones(seq_len, seq_len, device=dev, dtype=torch.bool).triu(diagonal=1)

        hidden = self.decoder(
            tgt=hidden,
            memory=memory,
            tgt_mask=future_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )
        return self.out_proj(hidden)

class Captioner(nn.Module):
    """Image captioning model: ViT encoder feeding a Transformer caption decoder."""

    def __init__(self, vocab_size: int, max_len: int = 30, d_model: int = 512, vit_trainable: bool = False):
        super().__init__()
        self.encoder = ViTEncoder(d_model=d_model, trainable=vit_trainable)
        self.decoder = TransformerCaptionDecoder(vocab_size=vocab_size, d_model=d_model, max_len=max_len)
        self.max_len = max_len

    def forward(self, images: torch.Tensor, captions_in: torch.Tensor) -> torch.Tensor:
        """Return next-token logits for ``captions_in`` conditioned on ``images``."""
        return self.decoder(captions_in, self.encoder(images))


# === Cell 5: Greedy decoding and Beam search decoding ===


In [5]:

@torch.no_grad()
def greedy_decode_batch(model, images, bos_id, eos_id, max_len, device):
    """Greedy (arg-max) caption decoding for a batch of images.

    Args:
        model: captioner exposing ``encoder`` and ``decoder`` sub-modules.
        images: image batch; moved to ``device`` before encoding.
        bos_id / eos_id: ids of the start and end tokens.
        max_len: maximum length of the returned sequences (incl. <bos>).
        device: device on which decoding runs.

    Returns:
        LongTensor [B, T] with T <= max_len. Decoding stops early once every
        sequence has produced <eos>; sequences that finished earlier keep
        receiving tokens after their <eos>, which callers are expected to cut.

    The model is switched to eval mode for decoding and then put back into
    whatever mode it was in, so calling this from inside a training loop no
    longer leaves dropout disabled for the remaining training steps.
    """
    was_training = model.training
    model.eval()
    try:
        images = images.to(device)
        B = images.size(0)

        memory = model.encoder(images)
        ys = torch.full((B, 1), bos_id, device=device, dtype=torch.long)
        finished = torch.zeros(B, dtype=torch.bool, device=device)

        for _ in range(max_len - 1):
            logits = model.decoder(ys, memory)  # [B, t, V]
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            ys = torch.cat([ys, next_token], dim=1)
            finished |= (next_token.squeeze(-1) == eos_id)
            if finished.all():
                break
    finally:
        model.train(was_training)
    return ys  # [B, T]

@torch.no_grad()
def beam_search_decode_batch(
    model,
    images,
    bos_id,
    eos_id,
    max_len,
    device,
    beam_size=3,
    no_repeat_ngram_size=0,
):
    """
    Run beam search separately for every image of the batch.

    The images are encoded once; each image's features are then handed to
    ``beam_search_single``. Returns a LongTensor [B, T] where T is the length
    of the longest decoded sequence and shorter rows are filled with ``eos_id``.
    """
    model.eval()
    images = images.to(device)
    B = images.size(0)

    memory = model.encoder(images)   # [B, S, D]

    # memory[i:i+1] keeps the batch dimension: [1, S, D]
    sequences = [
        beam_search_single(
            model,
            memory[i:i+1],
            bos_id,
            eos_id,
            max_len,
            device,
            beam_size,
            no_repeat_ngram_size,
        )
        for i in range(B)
    ]

    longest = max(map(len, sequences))
    out = torch.full((B, longest), eos_id, device=device, dtype=torch.long)
    for row, seq in enumerate(sequences):
        out[row, :len(seq)] = torch.tensor(seq, device=device, dtype=torch.long)
    return out

def _has_repeat_ngram(candidate: List[int], no_repeat_ngram_size: int) -> bool:
    if no_repeat_ngram_size <= 0:
        return False
    if len(candidate) < 2 * no_repeat_ngram_size:
        return False
    seen = set()
    for i in range(len(candidate) - no_repeat_ngram_size + 1):
        ngram = tuple(candidate[i:i+no_repeat_ngram_size])
        if ngram in seen:
            return True
        seen.add(ngram)
    return False

@torch.no_grad()
def beam_search_single(
    model,
    memory,                # [1, S, D] for a single image
    bos_id,
    eos_id,
    max_len,
    device,
    beam_size=3,
    no_repeat_ngram_size=0,
):
    """
    Beam search for a single image.

    Hypotheses are (summed log-probability, token list) pairs. Each step
    expands every unfinished hypothesis with its ``beam_size`` most likely
    next tokens, drops candidates that would repeat an n-gram (when
    ``no_repeat_ngram_size`` > 0), and keeps the ``beam_size`` best.
    Scores are not length-normalised.

    Returns a list of token ids: the best finished hypothesis if there is
    one, otherwise the best unfinished one.
    """
    beam = [(0.0, [bos_id])]
    completed = []

    for _ in range(max_len - 1):
        new_beam = []
        for logp, seq in beam:
            if seq[-1] == eos_id:
                completed.append((logp, seq))
                continue

            tgt = torch.tensor(seq, device=device, dtype=torch.long).unsqueeze(0)  # [1, t]
            logits = model.decoder(tgt, memory)  # [1, t, V]
            next_logits = logits[:, -1, :]       # [1, V]
            probs = torch.log_softmax(next_logits, dim=-1).squeeze(0)  # [V]

            topk_logp, topk_ids = probs.topk(beam_size)

            for lp, idx in zip(topk_logp.tolist(), topk_ids.tolist()):
                candidate = seq + [idx]
                if no_repeat_ngram_size > 0 and _has_repeat_ngram(candidate, no_repeat_ngram_size):
                    continue
                new_beam.append((logp + lp, candidate))

        if not new_beam:
            # Finished entries of `beam` were just moved to `completed`;
            # keep only the unfinished ones so they are not counted twice below.
            beam = [hyp for hyp in beam if hyp[1][-1] != eos_id]
            break

        new_beam.sort(key=lambda x: x[0], reverse=True)
        beam = new_beam[:beam_size]

        if len(completed) >= beam_size:
            break

    # Hypotheses that produced <eos> on the very last step are still sitting in
    # `beam`. Count them as completed, otherwise a worse, earlier-finished
    # caption would be preferred over them.
    completed.extend(hyp for hyp in beam if hyp[1][-1] == eos_id)

    if completed:
        completed.sort(key=lambda x: x[0], reverse=True)
        best_seq = completed[0][1]
    else:
        beam.sort(key=lambda x: x[0], reverse=True)
        best_seq = beam[0][1]

    return best_seq


# === Cell 6: CLIP grounding loss and ID->text helpers ===


In [6]:

class CLIPGroundingLoss(nn.Module):
    """Weighted CLIP image/text dissimilarity: ``lambda_clip * mean(1 - cosine)``.

    Incoming images are expected to be normalised with ``input_mean`` /
    ``input_std``; they are mapped back to [0, 1] and re-normalised with the
    CLIP statistics before encoding. The CLIP model is put in eval mode and
    all of its parameters are frozen.

    NOTE(review): both encoders run under ``torch.no_grad()``, so the value
    returned by ``forward`` carries no gradient — it can be logged, but adding
    it to a training loss does not change the parameter updates. Confirm that
    this is intended.
    """

    def __init__(
        self,
        clip_model,
        clip_tokenizer,
        device,
        lambda_clip: float = 0.5,
        input_mean=(0.485, 0.456, 0.406),
        input_std=(0.229, 0.224, 0.225),
    ):
        super().__init__()
        self.clip_model = clip_model.eval()
        self.clip_tokenizer = clip_tokenizer
        self.device = device
        self.lambda_clip = lambda_clip

        def _per_channel(values):
            # Shape (1, 3, 1, 1) so the statistics broadcast over [B, 3, H, W].
            return torch.tensor(values).view(1, 3, 1, 1).to(device)

        self.input_mean = _per_channel(input_mean)
        self.input_std = _per_channel(input_std)
        self.clip_mean = _per_channel([0.48145466, 0.4578275, 0.40821073])
        self.clip_std = _per_channel([0.26862954, 0.26130258, 0.27577711])

        self.clip_model.requires_grad_(False)

    @torch.no_grad()
    def _encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """Return L2-normalised CLIP image features for input-normalised images."""
        pixels = (images * self.input_std + self.input_mean).clamp(0, 1)
        clip_input = (pixels - self.clip_mean) / self.clip_std
        feats = self.clip_model.encode_image(clip_input)
        return feats / feats.norm(dim=-1, keepdim=True)

    @torch.no_grad()
    def _encode_text(self, captions: List[str]) -> torch.Tensor:
        """Return L2-normalised CLIP text features for a list of captions."""
        token_ids = self.clip_tokenizer(captions).to(self.device)
        feats = self.clip_model.encode_text(token_ids)
        return feats / feats.norm(dim=-1, keepdim=True)

    def forward(self, images: torch.Tensor, captions_text: List[str]) -> torch.Tensor:
        """Return ``lambda_clip`` times the mean (1 - cosine similarity) of paired image/caption."""
        cosine = torch.sum(self._encode_image(images) * self._encode_text(captions_text), dim=-1)
        return self.lambda_clip * (1.0 - cosine).mean()

def ids_to_words(ids, vocab, pad_id, eos_id, drop_bos=True):
    """Convert a sequence of token ids to words.

    Conversion stops at the first ``pad_id`` or at the first token whose text
    is "<eos>". "<bos>" tokens are skipped when ``drop_bos`` is true.
    ``eos_id`` is accepted for signature symmetry but the end token is
    recognised by its text, not by this id.
    """
    words = []
    for token_id in ids:
        if token_id == pad_id:
            break
        word = vocab.itos[token_id]
        if word == "<eos>":
            break
        if not (drop_bos and word == "<bos>"):
            words.append(word)
    return words

def batch_ids_to_strings(batch_ids, vocab, pad_id, eos_id):
    """Turn a batch of id sequences (tensor or nested list) into space-joined sentences."""
    rows = batch_ids.tolist() if isinstance(batch_ids, torch.Tensor) else batch_ids
    return [" ".join(ids_to_words(row, vocab, pad_id, eos_id)) for row in rows]


# === Cell 7: Build train/val datasets, dataloaders, model, optimizer, CLIP ===


In [None]:

# Dataset locations, all derived from COCO_ROOT.
TRAIN_IMAGES_DIR = os.path.join(COCO_ROOT, "train2017")
VAL_IMAGES_DIR   = os.path.join(COCO_ROOT, "val2017")
ANN_DIR          = os.path.join(COCO_ROOT, "annotations")
TRAIN_JSON       = os.path.join(ANN_DIR, "captions_train2017.json")
VAL_JSON         = os.path.join(ANN_DIR, "captions_val2017.json")

# Both datasets share the vocabulary built from the training captions and are
# cut down to the first DEBUG_LIMIT / VAL_DEBUG captions.
train_dataset = COCODataset(
    images_root=TRAIN_IMAGES_DIR,
    captions_json=TRAIN_JSON,
    vocab=vocab,
    max_len=MAX_LEN,
    debug_limit=DEBUG_LIMIT,
)

val_dataset = COCODataset(
    images_root=VAL_IMAGES_DIR,
    captions_json=VAL_JSON,
    vocab=vocab,
    max_len=MAX_LEN,
    debug_limit=VAL_DEBUG,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,          # keep 0 to avoid multiprocessing issues
    collate_fn=coco_collate_fn,
)
# Validation order is fixed (shuffle=False), so evaluation cells see the same batches.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=coco_collate_fn,
)

model = Captioner(
    vocab_size=len(vocab),
    max_len=MAX_LEN,
    d_model=512,
    vit_trainable=True,    # ViT backbone is fine-tuned; set False to keep it frozen
).to(device)

# Padding positions in the targets do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

# CLIP model for grounding
# (the two transform objects returned alongside the model are discarded here)
clip_model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
clip_model = clip_model.to(device)
clip_tokenizer = open_clip.get_tokenizer("ViT-B-32")
# lambda_clip starts at 0.0; the training function overwrites it when CLIP grounding is enabled.
clip_loss_fn = CLIPGroundingLoss(clip_model, clip_tokenizer, device, lambda_clip=0.0)

print("Setup done. Train samples:", len(train_dataset), "Val samples:", len(val_dataset))


[COCODataset] Debug limit: 10000 samples
Loaded 10000 (image, caption) pairs from /kaggle/input/coco-2017-dataset/coco2017/annotations/captions_train2017.json
[COCODataset] Debug limit: 1000 samples
Loaded 1000 (image, caption) pairs from /kaggle/input/coco-2017-dataset/coco2017/annotations/captions_val2017.json


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

# === Cell 8: Training loop with optional CLIP grounding (CE + lambda * CLIP) ===


In [None]:

def train_one_epoch_with_clip(
    model,
    loader,
    optimizer,
    criterion,
    device,
    vocab,
    clip_loss_fn: CLIPGroundingLoss = None,
    lambda_clip_current: float = 0.0,
    max_len: int = 30,
):
    """Train ``model`` for one pass over ``loader`` with teacher forcing.

    The loss is cross-entropy on next-token prediction plus, when
    ``clip_loss_fn`` is given and ``lambda_clip_current`` > 0, the CLIP
    grounding term computed on greedily decoded captions.

    NOTE(review): the CLIP term is produced from decoded text under
    ``torch.no_grad()`` (see ``CLIPGroundingLoss``), so it is a constant with
    respect to the model parameters: it shows up in the reported loss but
    does not influence the gradients.

    Returns:
        (mean total loss, mean CE loss, mean CLIP loss) averaged over batches.
    """
    model.train()
    total_loss = total_ce = total_clip = 0.0

    bos_id = vocab.stoi["<bos>"]
    eos_id = vocab.stoi["<eos>"]
    pad_id = vocab.stoi["<pad>"]

    use_clip = clip_loss_fn is not None and lambda_clip_current > 0
    if use_clip:
        # The weight is constant for the whole epoch, so set it once.
        clip_loss_fn.lambda_clip = lambda_clip_current

    for batch in tqdm(loader, desc="Train"):
        images = batch["image"].to(device)
        captions = batch["caption"].to(device)

        # Teacher forcing: predict token t+1 from tokens <= t.
        inputs = captions[:, :-1]
        targets = captions[:, 1:]

        logits = model(images, inputs)
        B, T, V = logits.shape
        ce_loss = criterion(
            logits.reshape(B * T, V),
            targets.reshape(B * T),
        )

        if use_clip:
            with torch.no_grad():
                pred_ids_batch = greedy_decode_batch(
                    model, images, bos_id, eos_id, max_len, device
                )
            # Decoding switches the model to eval mode; make sure the
            # following backward pass and all later batches run in train mode.
            model.train()
            captions_text = batch_ids_to_strings(pred_ids_batch, vocab, pad_id, eos_id)
            clip_loss = clip_loss_fn(images, captions_text)
        else:
            clip_loss = torch.tensor(0.0, device=device)

        loss = ce_loss + clip_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        total_ce += ce_loss.item()
        total_clip += clip_loss.item()

    # Guard against an empty loader so the averages do not divide by zero.
    n = max(len(loader), 1)
    return total_loss / n, total_ce / n, total_clip / n

# === Run training ===
EPOCHS = 5
LAMBDA_CLIP_TARGET = 0.0   # adjust; try 0.3–0.5
WARMUP_EPOCHS = 2          # ramp CLIP from 0 to target over first few epochs

for epoch in range(EPOCHS):
    # CLIP weight schedule: off when the target is not positive, otherwise a
    # linear ramp during the warm-up epochs followed by the constant target.
    if not LAMBDA_CLIP_TARGET > 0:
        lambda_clip_current = 0.0
    elif epoch < WARMUP_EPOCHS:
        lambda_clip_current = LAMBDA_CLIP_TARGET * (epoch + 1) / WARMUP_EPOCHS
    else:
        lambda_clip_current = LAMBDA_CLIP_TARGET

    print(f"\n=== Epoch {epoch+1}/{EPOCHS} | lambda_clip={lambda_clip_current:.3f} ===")
    avg_loss, avg_ce, avg_clip = train_one_epoch_with_clip(
        model,
        train_loader,
        optimizer,
        criterion,
        device,
        vocab,
        clip_loss_fn,
        lambda_clip_current,
        MAX_LEN,
    )
    print(f"Train: total={avg_loss:.4f} | CE={avg_ce:.4f} | CLIP={avg_clip:.4f}")

    # One checkpoint (weights only) per epoch.
    torch.save(model.state_dict(), f"/kaggle/working/model_epoch{epoch+1}.pt")
    print("Saved checkpoint for epoch", epoch+1)


# === Cell 9: Evaluation – BLEU (beam search) and CLIPScore (beam search) ===


In [None]:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def evaluate_bleu_on_val_beam(model, val_loader, vocab, device, max_batches=30):
    """Mean smoothed sentence-level BLEU-4 of beam-search captions.

    Each prediction is scored against the single reference caption of its
    sample, over at most ``max_batches`` batches. Samples whose prediction or
    reference is empty are skipped. Returns 0.0 when nothing was scored.

    Relies on the notebook globals ``bos_id``, ``eos_id``, ``pad_id``,
    ``MAX_LEN``, ``BEAM_SIZE`` and ``NO_REPEAT_NGRAM_SIZE``.
    """
    model.eval()
    smoothie = SmoothingFunction().method4
    scores = []

    for b_idx, batch in enumerate(tqdm(val_loader, desc="BLEU eval (beam)")):
        if b_idx >= max_batches:
            break

        captions = batch["caption"]
        pred_ids_batch = beam_search_decode_batch(
            model, batch["image"], bos_id, eos_id, MAX_LEN, device,
            beam_size=BEAM_SIZE,
            no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
        )

        for gt_ids, pred_ids in zip(captions.tolist(), pred_ids_batch.tolist()):
            gt_words = ids_to_words(gt_ids, vocab, pad_id, eos_id)
            pred_words = ids_to_words(pred_ids, vocab, pad_id, eos_id)

            if pred_words and gt_words:
                scores.append(
                    sentence_bleu(
                        [gt_words],
                        pred_words,
                        smoothing_function=smoothie,
                        weights=(0.25, 0.25, 0.25, 0.25),
                    )
                )

    if not scores:
        return 0.0
    return float(np.mean(scores))

def compute_clipscore_on_val_beam(
    model,
    val_loader,
    clip_model,
    clip_tokenizer,
    vocab,
    device,
    max_batches=30,
    input_mean=(0.485, 0.456, 0.406),
    input_std=(0.229, 0.224, 0.225),
):
    """CLIP cosine similarity between images and their beam-search captions.

    Args:
        model: captioner to evaluate.
        val_loader: loader yielding dicts with an ``image`` tensor batch.
        clip_model / clip_tokenizer: CLIP model and tokenizer used for scoring.
        vocab: vocabulary used to turn predicted ids into text.
        device: device for decoding and CLIP encoding.
        max_batches: number of batches to evaluate at most.
        input_mean / input_std: normalisation that was applied to the images
            by the dataset transform (defaults match the ImageNet statistics
            of ``COCODataset._default_transform``).

    Returns:
        dict with ``mean``, ``median``, ``std`` of the similarities and
        ``low_clip_rate`` (share of similarities below 0.2). All values are
        0.0 when no sample was scored.

    Relies on the notebook globals ``bos_id``, ``eos_id``, ``pad_id``,
    ``MAX_LEN``, ``BEAM_SIZE`` and ``NO_REPEAT_NGRAM_SIZE``.
    """
    model.eval()
    clip_model.eval()

    # The CLIP preprocess pipeline was not kept at setup time. Re-create it
    # once here (not once per batch) and drop the throw-away model right away.
    _throwaway_model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
    del _throwaway_model

    to_pil = transforms.ToPILImage()
    mean = torch.tensor(input_mean).view(3, 1, 1)
    std = torch.tensor(input_std).view(3, 1, 1)

    sims = []

    for b_idx, batch in enumerate(tqdm(val_loader, desc="CLIPScore eval (beam)")):
        if b_idx >= max_batches:
            break

        images = batch["image"].to(device)

        pred_ids_batch = beam_search_decode_batch(
            model, images, bos_id, eos_id, MAX_LEN, device,
            beam_size=BEAM_SIZE,
            no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
        )
        captions_text = batch_ids_to_strings(pred_ids_batch, vocab, pad_id, eos_id)

        with torch.no_grad():
            for img, cap in zip(batch["image"], captions_text):
                # The dataset delivers mean/std-normalised tensors. ToPILImage
                # expects float values in [0, 1], so undo the normalisation
                # first; otherwise the pixel values are corrupted.
                pixels = (img.cpu() * std + mean).clamp(0, 1)
                image = preprocess(to_pil(pixels)).unsqueeze(0).to(device)
                tokens = clip_tokenizer([cap]).to(device)

                img_feat = clip_model.encode_image(image)
                txt_feat = clip_model.encode_text(tokens)

                img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
                txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

                sim = (img_feat * txt_feat).sum(dim=-1).item()
                sims.append(sim)

    if not sims:
        # Nothing evaluated (empty loader or max_batches <= 0): avoid NaN statistics.
        return {"mean": 0.0, "median": 0.0, "std": 0.0, "low_clip_rate": 0.0}

    sims = np.array(sims)
    return {
        "mean": float(sims.mean()),
        "median": float(np.median(sims)),
        "std": float(sims.std()),
        "low_clip_rate": float((sims < 0.2).mean()),
    }

# Both metrics are computed on at most the first 32 batches of val_loader,
# with captions decoded by beam search.
bleu_val = evaluate_bleu_on_val_beam(model, val_loader, vocab, device, max_batches=32)
clip_stats = compute_clipscore_on_val_beam(model, val_loader, clip_model, clip_tokenizer, vocab, device, max_batches=32)

print("BLEU-4 (beam) approx:", bleu_val)
print("CLIPScore stats (beam):", clip_stats)


# === Cell 10: CHAIR hallucination metrics (beam captions) ===


In [None]:

INST_VAL_JSON = os.path.join(ANN_DIR, "instances_val2017.json")
with open(INST_VAL_JSON, "r") as f:
    inst_val = json.load(f)

# image id -> file name
imgid_to_file = {}
for img in inst_val["images"]:
    imgid_to_file[img["id"]] = img["file_name"]

# category id -> lower-cased category name
catid_to_name = {}
for c in inst_val["categories"]:
    catid_to_name[c["id"]] = c["name"].lower()

# file name -> set of category names annotated in that image
file_to_objects = defaultdict(set)
for ann in inst_val["annotations"]:
    file_to_objects[imgid_to_file[ann["image_id"]]].add(catid_to_name[ann["category_id"]])

# (category name, its word tokens), sorted by name; used to spot mentions in captions
category_vocab = [
    (name, simple_tokenize(name))
    for name in sorted(set(catid_to_name.values()))
]

def find_mentioned_objects(caption_text: str, category_vocab):
    """Return the set of category names whose tokens occur contiguously in the caption.

    ``category_vocab`` is an iterable of (category name, token list) pairs.
    Matching is exact on tokens, so inflected forms (e.g. plurals) do not match.
    """
    tokens = simple_tokenize(caption_text)

    def _occurs(phrase):
        # Slide a window of the phrase's length over the caption tokens.
        width = len(phrase)
        return any(
            tokens[start:start + width] == phrase
            for start in range(len(tokens) - width + 1)
        )

    return {name for name, phrase in category_vocab if _occurs(phrase)}

def compute_chair_on_val_beam(
    model,
    val_loader,
    file_to_objects,
    category_vocab,
    vocab,
    device,
    max_batches: int = 30,
):
    """Object-hallucination statistics for beam-search captions.

    A mentioned category counts as hallucinated when it is not among the
    annotated objects of the image (``file_to_objects``).

    Returns a dict with:
        CHAIRs: captions with at least one hallucinated category / all captions
        CHAIRi: hallucinated mentions / all mentions
        plus the raw counts ``total_captions``, ``total_mentions`` and
        ``hallucinated_mentions``.

    Relies on the notebook globals ``bos_id``, ``eos_id``, ``pad_id``,
    ``MAX_LEN``, ``BEAM_SIZE`` and ``NO_REPEAT_NGRAM_SIZE``.
    """
    model.eval()
    all_caps = caps_with_hallucination = 0
    total_mentions = hallucinated_mentions = 0

    for b_idx, batch in enumerate(tqdm(val_loader, desc="CHAIR eval (beam)")):
        if b_idx >= max_batches:
            break

        file_names = batch["file_name"]
        all_caps += len(file_names)

        pred_ids_batch = beam_search_decode_batch(
            model, batch["image"], bos_id, eos_id, MAX_LEN, device,
            beam_size=BEAM_SIZE,
            no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
        )
        pred_texts = batch_ids_to_strings(pred_ids_batch, vocab, pad_id, eos_id)

        for fn, pred_text in zip(file_names, pred_texts):
            mentioned = find_mentioned_objects(pred_text, category_vocab)
            hallucinated = mentioned - file_to_objects.get(fn, set())
            # Captions without any mention add zero to every counter.
            total_mentions += len(mentioned)
            hallucinated_mentions += len(hallucinated)
            caps_with_hallucination += bool(hallucinated)

    return {
        "CHAIRs": caps_with_hallucination / all_caps if all_caps else 0.0,
        "CHAIRi": hallucinated_mentions / total_mentions if total_mentions else 0.0,
        "total_captions": all_caps,
        "total_mentions": total_mentions,
        "hallucinated_mentions": hallucinated_mentions,
    }

# Hallucination statistics on at most the first 32 batches of val_loader.
chair_stats = compute_chair_on_val_beam(
    model, val_loader, file_to_objects, category_vocab, vocab, device, max_batches=32
)
print("CHAIR stats (beam):", chair_stats)


# === Cell 11: Show a few beam-search predictions vs ground truth ===


In [None]:

# Decode the first validation batch and print predictions next to the references.
model.eval()
batch = next(iter(val_loader))
images = batch["image"]
captions = batch["caption"]
file_names = batch["file_name"]

pred_ids_batch = beam_search_decode_batch(
    model, images, bos_id, eos_id, MAX_LEN, device,
    beam_size=BEAM_SIZE,
    no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
)

# Show up to 10 examples; min() avoids an IndexError when the batch is smaller.
for i in range(min(10, len(file_names))):
    gt_ids   = captions[i].tolist()
    pred_ids = pred_ids_batch[i].tolist()

    gt_words   = ids_to_words(gt_ids, vocab, pad_id, eos_id)
    pred_words = ids_to_words(pred_ids, vocab, pad_id, eos_id)

    print("\nFile:", file_names[i])
    print("PRED:", " ".join(pred_words))
    print("GT  :", " ".join(gt_words))


In [None]:
# Debug view of the token ids behind the predictions decoded in the previous cell.
print("Raw predicted ID sequences for first few examples:")
# Up to 5 rows; min() avoids an IndexError when fewer predictions are available.
for i in range(min(5, len(pred_ids_batch))):
    print(pred_ids_batch[i].tolist())


# === Cell 11 Updated: Show image + predictions + GT ===


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Decode the first validation batch and display each image with its captions.
model.eval()
batch = next(iter(val_loader))
images = batch["image"]
captions = batch["caption"]
file_names = batch["file_name"]

pred_ids_batch = beam_search_decode_batch(
    model, images, bos_id, eos_id, MAX_LEN, device,
    beam_size=BEAM_SIZE,
    no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
)

# ImageNet statistics used by the dataset transform; built once, outside the loop.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
std  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)

# Show up to 10 results with images (fewer if the batch is smaller)
for i in range(min(10, len(file_names))):
    img_tensor = images[i].cpu()

    # Un-normalize for viewing (reverse ImageNet normalization)
    img_np = (img_tensor * std + mean).clamp(0,1).permute(1,2,0).numpy()

    gt_ids   = captions[i].tolist()
    pred_ids = pred_ids_batch[i].tolist()

    gt_words   = ids_to_words(gt_ids, vocab, pad_id, eos_id)
    pred_words = ids_to_words(pred_ids, vocab, pad_id, eos_id)

    fig, ax = plt.subplots(figsize=(5,5))
    ax.imshow(img_np)
    ax.axis("off")

    title_str = f"PRED: {' '.join(pred_words)}\nGT: {' '.join(gt_words)}"
    ax.set_title(title_str, fontsize=9)
    plt.show()
    plt.close(fig)  # release the figure so repeated runs do not accumulate open figures
