In [2]:
import os
import json
import re
from collections import Counter
from typing import List, Dict, Any

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm.auto import tqdm

import torch_xla
import torch_xla.core.xla_model as xm

# ---- Device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = xm.xla_device()
# torch.set_default_tensor_type('torch.FloatTensor')
print("Using device:", device)

# ---- Paths (change COCO_ROOT if needed) ----
COCO_ROOT = "/kaggle/input/coco-2017-dataset/coco2017"  # <- adjust if your dataset is elsewhere
ANN_DIR   = os.path.join(COCO_ROOT, "annotations")

TRAIN_IMAGES_DIR = os.path.join(COCO_ROOT, "train2017")
VAL_IMAGES_DIR   = os.path.join(COCO_ROOT, "val2017")
TRAIN_JSON       = os.path.join(ANN_DIR, "captions_train2017.json")
VAL_JSON         = os.path.join(ANN_DIR, "captions_val2017.json")

print("Train images dir:", TRAIN_IMAGES_DIR)
print("Val images dir  :", VAL_IMAGES_DIR)
print("Train captions  :", TRAIN_JSON)
print("Val captions    :", VAL_JSON)

# ---- Training config ----
MAX_LEN      = 30
FREQ_THRESH  = 5
BATCH_SIZE   = 32
DEBUG_LIMIT  = 10000     # use a subset for speed; set to None for full
VAL_DEBUG    = 1000

WORK_DIR   = "/kaggle/working"
VOCAB_PATH = os.path.join(WORK_DIR, "vocab_sat.json")
os.makedirs(WORK_DIR, exist_ok=True)




Using device: cpu
Train images dir: /kaggle/input/coco-2017-dataset/coco2017/train2017
Val images dir  : /kaggle/input/coco-2017-dataset/coco2017/val2017
Train captions  : /kaggle/input/coco-2017-dataset/coco2017/annotations/captions_train2017.json
Val captions    : /kaggle/input/coco-2017-dataset/coco2017/annotations/captions_val2017.json


In [3]:
# === Cell B: Vocabulary, tokenization, and build vocab ===

SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]

class Vocabulary:
    def __init__(self, freq_threshold: int = 5):
        self.freq_threshold = freq_threshold
        self.itos: List[str] = []
        self.stoi: Dict[str, int] = {}

    def __len__(self):
        return len(self.itos)

    def build_vocab(self, counter: Counter):
        self.itos = SPECIAL_TOKENS.copy()
        for word, freq in counter.items():
            if freq >= self.freq_threshold:
                self.itos.append(word)
        self.stoi = {w: i for i, w in enumerate(self.itos)}
        print(f"Vocab built: {len(self.itos)} tokens (freq ≥ {self.freq_threshold})")

    def numericalize(self, tokens: List[str]) -> List[int]:
        return [self.stoi.get(tok, self.stoi["<unk>"]) for tok in tokens]

    def save(self, path: str):
        obj = {"itos": self.itos, "freq_threshold": self.freq_threshold}
        with open(path, "w") as f:
            json.dump(obj, f)
        print(f"Saved vocab to {path}")

    @classmethod
    def load(cls, path: str) -> "Vocabulary":
        with open(path, "r") as f:
            obj = json.load(f)
        vocab = cls(freq_threshold=obj.get("freq_threshold", 5))
        vocab.itos = obj["itos"]
        vocab.stoi = {w: i for i, w in enumerate(vocab.itos)}
        print(f"Loaded vocab from {path}, size={len(vocab)}")
        return vocab

def tokenize_caption(text: str) -> List[str]:
    text = text.lower().strip()
    text = re.sub(r"[.?!]+$", "", text)
    return text.split()

# ---- Build or load vocab ----
if os.path.exists(VOCAB_PATH):
    vocab = Vocabulary.load(VOCAB_PATH)
else:
    with open(TRAIN_JSON, "r") as f:
        train_ann = json.load(f)
    print("Num training captions:", len(train_ann["annotations"]))

    counter = Counter()
    for ann in tqdm(train_ann["annotations"], desc="Counting words for vocab"):
        tokens = tokenize_caption(ann["caption"])
        counter.update(tokens)

    vocab = Vocabulary(freq_threshold=FREQ_THRESH)
    vocab.build_vocab(counter)
    vocab.save(VOCAB_PATH)

pad_id = vocab.stoi["<pad>"]
bos_id = vocab.stoi["<bos>"]
eos_id = vocab.stoi["<eos>"]

print("pad/bos/eos ids:", pad_id, bos_id, eos_id)
print("Vocab size:", len(vocab))


Num training captions: 591753


Counting words for vocab:   0%|          | 0/591753 [00:00<?, ?it/s]

Vocab built: 11405 tokens (freq ≥ 5)
Saved vocab to /kaggle/working/vocab_sat.json
pad/bos/eos ids: 0 1 2
Vocab size: 11405


In [4]:
# === Cell C: Dataset and DataLoaders ===

class COCODataset2017(Dataset):
    def __init__(
        self,
        images_root: str,
        captions_json: str,
        vocab: Vocabulary,
        max_len: int = 30,
        transform=None,
        debug_limit: int = None,
    ):
        self.images_root = images_root
        self.vocab = vocab
        self.max_len = max_len
        self.transform = transform or self._default_transform()

        with open(captions_json, "r") as f:
            ann = json.load(f)

        self.imgs = {img["id"]: img for img in ann["images"]}

        self.samples = []
        for a in ann["annotations"]:
            img_id = a["image_id"]
            caption = a["caption"]
            tokens = tokenize_caption(caption)
            self.samples.append((self.imgs[img_id]["file_name"], tokens))

        if debug_limit is not None:
            self.samples = self.samples[:debug_limit]
            print(f"[COCODataset2017] Debug limit: {len(self.samples)} samples")

        print(f"Loaded {len(self.samples)} (image, caption) pairs from {captions_json}")

    def _default_transform(self):
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        file_name, tokens = self.samples[idx]
        img_path = os.path.join(self.images_root, file_name)

        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)

        seq_ids = [bos_id] + self.vocab.numericalize(tokens) + [eos_id]
        if len(seq_ids) < self.max_len:
            seq_ids += [pad_id] * (self.max_len - len(seq_ids))
        else:
            seq_ids = seq_ids[:self.max_len]

        caption = torch.tensor(seq_ids, dtype=torch.long)

        return {
            "image": image,
            "caption": caption,
            "file_name": file_name,
        }

def coco_collate_fn(batch):
    images = torch.stack([b["image"] for b in batch], dim=0)
    captions = torch.stack([b["caption"] for b in batch], dim=0)
    file_names = [b["file_name"] for b in batch]
    return {
        "image": images,
        "caption": captions,
        "file_name": file_names,
    }

train_dataset = COCODataset2017(
    images_root=TRAIN_IMAGES_DIR,
    captions_json=TRAIN_JSON,
    vocab=vocab,
    max_len=MAX_LEN,
    debug_limit=DEBUG_LIMIT,
)

val_dataset = COCODataset2017(
    images_root=VAL_IMAGES_DIR,
    captions_json=VAL_JSON,
    vocab=vocab,
    max_len=MAX_LEN,
    debug_limit=VAL_DEBUG,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=coco_collate_fn,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=coco_collate_fn,
)

print("Train samples:", len(train_dataset))
print("Val samples  :", len(val_dataset))


[COCODataset2017] Debug limit: 10000 samples
Loaded 10000 (image, caption) pairs from /kaggle/input/coco-2017-dataset/coco2017/annotations/captions_train2017.json
[COCODataset2017] Debug limit: 1000 samples
Loaded 1000 (image, caption) pairs from /kaggle/input/coco-2017-dataset/coco2017/annotations/captions_val2017.json
Train samples: 10000
Val samples  : 1000


In [5]:
# === Cell 2: EncoderCNN (ResNet-based encoder) ===

class EncoderCNN(nn.Module):
    """
    CNN encoder for Show, Attend and Tell:
    - Use a pretrained ResNet (e.g., ResNet-50)
    - Remove the final pooling + fc
    - Apply AdaptiveAvgPool2d to get a fixed spatial size
    - Output shape: [B, enc_image_size, enc_image_size, encoder_dim]
    """
    def __init__(self, encoded_image_size=14):
        super().__init__()
        self.enc_image_size = encoded_image_size

        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        # Remove the final fully connected layer & pooling
        modules = list(resnet.children())[:-2]  # everything until last conv
        self.cnn = nn.Sequential(*modules)

        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        self.fine_tune(False)

    def forward(self, images):
        """
        images: [B, 3, 224, 224]
        returns: [B, enc_image_size, enc_image_size, encoder_dim]
        """
        out = self.cnn(images)                      # [B, 2048, H, W]
        out = self.adaptive_pool(out)               # [B, 2048, enc_image_size, enc_image_size]
        out = out.permute(0, 2, 3, 1)               # [B, enc_image_size, enc_image_size, 2048]
        return out

    def fine_tune(self, fine_tune=True):
        """
        Allow or prevent the computation of gradients for convolutional blocks.
        By default we freeze everything for stability; you can unfreeze later.
        """
        for p in self.cnn.parameters():
            p.requires_grad = False

        # Unfreeze some layers if fine_tune=True (e.g., last 2 blocks)
        if fine_tune:
            for c in list(self.cnn.children())[-2:]:
                for p in c.parameters():
                    p.requires_grad = True


In [6]:
# === Cell 3: Attention module and DecoderWithAttention ===

class Attention(nn.Module):
    """
    Soft 'additive' attention:
    Given encoder_out (image features) and decoder hidden state,
    produce attention weights over image locations and a context vector.
    """
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # encoder features -> att
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # decoder hidden -> att
        self.full_att = nn.Linear(attention_dim, 1)               # combine and score
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)                          # over pixels

    def forward(self, encoder_out, decoder_hidden):
        """
        encoder_out: [B, num_pixels, encoder_dim]
        decoder_hidden: [B, decoder_dim]
        returns:
            attention_weighted_encoding: [B, encoder_dim]
            alpha: [B, num_pixels]
        """
        att1 = self.encoder_att(encoder_out)                       # [B, num_pixels, att_dim]
        att2 = self.decoder_att(decoder_hidden).unsqueeze(1)       # [B, 1, att_dim]
        att = self.full_att(self.relu(att1 + att2)).squeeze(2)     # [B, num_pixels]
        alpha = self.softmax(att)                                  # [B, num_pixels]
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # [B, enc_dim]
        return attention_weighted_encoding, alpha


class DecoderWithAttention(nn.Module):
    """
    LSTM decoder with attention (Show, Attend and Tell style).
    We use fixed-length captions with BOS/EOS/PAD.
    """
    def __init__(
        self,
        attention_dim,
        embed_dim,
        decoder_dim,
        vocab_size,
        encoder_dim=2048,
        dropout=0.5,
        max_len=30,
    ):
        super().__init__()

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # gating scalar
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)

        self.max_len = max_len
        self.vocab_size = vocab_size
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def init_hidden_state(self, encoder_out):
        """
        encoder_out: [B, num_pixels, encoder_dim]
        """
        mean_encoder_out = encoder_out.mean(dim=1)   # [B, encoder_dim]
        h = self.init_h(mean_encoder_out)            # [B, decoder_dim]
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions):
        """
        Forward pass during training (teacher forcing).
        encoder_out: [B, enc_image_size, enc_image_size, encoder_dim]
        encoded_captions: [B, T] (with BOS, EOS, PAD)
        Returns:
            predictions: [B, T-1, vocab_size]
            alphas: [B, T-1, num_pixels]
        We predict tokens for positions 1..T-1 (targets are 1..T, i.e. shifted).
        """
        B = encoder_out.size(0)
        enc_image_size = encoder_out.size(1)
        encoder_dim = encoder_out.size(-1)
        num_pixels = enc_image_size * enc_image_size

        # Flatten spatial dims
        encoder_out = encoder_out.view(B, -1, encoder_dim)  # [B, num_pixels, encoder_dim]

        # Prepare embeddings (we'll ignore the last token when feeding)
        embeddings = self.embedding(encoded_captions)       # [B, T, embed_dim]
        T = encoded_captions.size(1)

        h, c = self.init_hidden_state(encoder_out)

        preds = []
        alphas = []

        # Decode from t=0..T-2 (We use caption[t] as input, target is caption[t+1])
        for t in range(T - 1):
            batch_emb_t = embeddings[:, t, :]                # [B, embed_dim]
            context, alpha = self.attention(encoder_out, h)  # [B, enc_dim], [B, num_pixels]
            gate = self.sigmoid(self.f_beta(h))              # gating [B, enc_dim]
            context = gate * context

            # LSTMCell input is [embed, context]
            lstm_input = torch.cat([batch_emb_t, context], dim=1)  # [B, embed_dim+enc_dim]
            h, c = self.decode_step(lstm_input, (h, c))            # both [B, decoder_dim]

            output = self.fc(self.dropout(h))                      # [B, vocab_size]
            preds.append(output)
            alphas.append(alpha)

        preds = torch.stack(preds, dim=1)   # [B, T-1, vocab_size]
        alphas = torch.stack(alphas, dim=1) # [B, T-1, num_pixels]

        return preds, alphas


In [7]:
# === Cell 4: EncoderCNN, Attention, DecoderWithAttention, SAT wrapper ===

class EncoderCNN(nn.Module):
    """
    CNN encoder: ResNet-50, output spatial feature map.
    """
    def __init__(self, encoded_image_size=14):
        super().__init__()
        self.enc_image_size = encoded_image_size

        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        modules = list(resnet.children())[:-2]  # remove avgpool & fc
        self.cnn = nn.Sequential(*modules)

        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        self.fine_tune(False)

    def forward(self, images):
        """
        images: [B, 3, 224, 224]
        returns: [B, enc_image_size, enc_image_size, encoder_dim]
        """
        out = self.cnn(images)                      # [B, 2048, H, W]
        out = self.adaptive_pool(out)               # [B, 2048, enc_image_size, enc_image_size]
        out = out.permute(0, 2, 3, 1)               # [B, enc_image_size, enc_image_size, 2048]
        return out

    def fine_tune(self, fine_tune=True):
        """
        Unfreeze last conv blocks if fine_tune=True.
        """
        for p in self.cnn.parameters():
            p.requires_grad = False

        if fine_tune:
            for c in list(self.cnn.children())[-2:]:
                for p in c.parameters():
                    p.requires_grad = True

class Attention(nn.Module):
    """
    Additive attention over spatial image features.
    """
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att    = nn.Linear(attention_dim, 1)
        self.relu        = nn.ReLU()
        self.softmax     = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """
        encoder_out: [B, num_pixels, encoder_dim]
        decoder_hidden: [B, decoder_dim]
        """
        att1 = self.encoder_att(encoder_out)             # [B, num_pixels, att_dim]
        att2 = self.decoder_att(decoder_hidden).unsqueeze(1)  # [B, 1, att_dim]
        att  = self.full_att(self.relu(att1 + att2)).squeeze(2)  # [B, num_pixels]
        alpha = self.softmax(att)                        # [B, num_pixels]
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # [B, enc_dim]
        return attention_weighted_encoding, alpha

class DecoderWithAttention(nn.Module):
    """
    LSTM decoder with attention (Show, Attend and Tell).
    """
    def __init__(
        self,
        attention_dim,
        embed_dim,
        decoder_dim,
        vocab_size,
        encoder_dim=2048,
        dropout=0.5,
        max_len=30,
    ):
        super().__init__()

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout   = nn.Dropout(dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)

        self.max_len    = max_len
        self.vocab_size = vocab_size
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def init_hidden_state(self, encoder_out):
        """
        encoder_out: [B, num_pixels, encoder_dim]
        """
        mean_encoder_out = encoder_out.mean(dim=1)  # [B, encoder_dim]
        h = self.init_h(mean_encoder_out)           # [B, decoder_dim]
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions):
        """
        encoder_out: [B, H, W, encoder_dim]
        encoded_captions: [B, T]  (with <bos>, ..., <eos>/<pad>)
        Returns:
            preds:  [B, T-1, vocab_size]
            alphas: [B, T-1, num_pixels]
        """
        B = encoder_out.size(0)
        enc_image_size = encoder_out.size(1)
        encoder_dim = encoder_out.size(-1)
        num_pixels = enc_image_size * enc_image_size

        # Flatten spatial dims
        encoder_out = encoder_out.view(B, -1, encoder_dim)  # [B, num_pixels, encoder_dim]

        # Embeddings
        embeddings = self.embedding(encoded_captions)       # [B, T, embed_dim]
        T = encoded_captions.size(1)

        h, c = self.init_hidden_state(encoder_out)

        preds = []
        alphas = []

        # Teacher forcing: input caption[t], target caption[t+1]
        for t in range(T - 1):
            batch_emb_t = embeddings[:, t, :]                # [B, embed_dim]
            context, alpha = self.attention(encoder_out, h)  # [B, enc_dim], [B, num_pixels]
            gate = self.sigmoid(self.f_beta(h))              # [B, enc_dim]
            context = gate * context

            lstm_input = torch.cat([batch_emb_t, context], dim=1)  # [B, embed_dim + enc_dim]
            h, c = self.decode_step(lstm_input, (h, c))            # [B, decoder_dim]

            output = self.fc(self.dropout(h))                      # [B, vocab_size]
            preds.append(output)
            alphas.append(alpha)

        preds = torch.stack(preds, dim=1)    # [B, T-1, vocab_size]
        alphas = torch.stack(alphas, dim=1)  # [B, T-1, num_pixels]

        return preds, alphas

class ShowAttendTell(nn.Module):
    def __init__(
        self,
        vocab_size,
        attention_dim=512,
        embed_dim=512,
        decoder_dim=512,
        encoder_dim=2048,
        dropout=0.5,
        max_len=30,
    ):
        super().__init__()
        self.encoder = EncoderCNN()
        self.decoder = DecoderWithAttention(
            attention_dim=attention_dim,
            embed_dim=embed_dim,
            decoder_dim=decoder_dim,
            vocab_size=vocab_size,
            encoder_dim=encoder_dim,
            dropout=dropout,
            max_len=max_len,
        )

    def forward(self, images, captions):
        encoder_out = self.encoder(images)
        preds, alphas = self.decoder(encoder_out, captions)
        return preds, alphas

model_sat = ShowAttendTell(
    vocab_size=len(vocab),
    attention_dim=512,
    embed_dim=512,
    decoder_dim=512,
    encoder_dim=2048,
    dropout=0.5,
    max_len=MAX_LEN,
).to(device)

criterion_sat = nn.CrossEntropyLoss(ignore_index=pad_id)
optimizer_sat = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model_sat.parameters()),
    lr=3e-4,
    weight_decay=1e-4,
)

print("Trainable params:",
      sum(p.numel() for p in model_sat.parameters() if p.requires_grad))


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 232MB/s] 


Trainable params: 22446734


In [8]:
# === Cell 5: Training loop for SAT (CE only) ===

def train_one_epoch_sat(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0

    for batch in tqdm(loader, desc="SAT train (1 epoch)"):
        images = batch["image"].to(device)
        captions = batch["caption"].to(device)

        preds, alphas = model(images, captions)      # preds: [B, T-1, V]
        B, Tm1, V = preds.size()
        targets = captions[:, 1:]                    # [B, T-1]

        loss = criterion(
            preds.reshape(B * Tm1, V),
            targets.reshape(B * Tm1),
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)

EPOCHS_SAT = 5  # start small; you can increase later

for e in range(EPOCHS_SAT):
    avg_loss = train_one_epoch_sat(model_sat, train_loader, optimizer_sat, criterion_sat, device)
    print(f"[SAT] Epoch {e+1}/{EPOCHS_SAT} - loss: {avg_loss:.4f}")
    torch.save(model_sat.state_dict(), os.path.join(WORK_DIR, f"sat_epoch{e+1}.pt"))
    print("Saved checkpoint:", f"sat_epoch{e+1}.pt")


SAT train (1 epoch):   0%|          | 0/313 [00:00<?, ?it/s]

[SAT] Epoch 1/5 - loss: 4.8408
Saved checkpoint: sat_epoch1.pt


SAT train (1 epoch):   0%|          | 0/313 [00:00<?, ?it/s]

[SAT] Epoch 2/5 - loss: 3.8074
Saved checkpoint: sat_epoch2.pt


SAT train (1 epoch):   0%|          | 0/313 [00:00<?, ?it/s]

[SAT] Epoch 3/5 - loss: 3.4414
Saved checkpoint: sat_epoch3.pt


SAT train (1 epoch):   0%|          | 0/313 [00:00<?, ?it/s]

[SAT] Epoch 4/5 - loss: 3.1788
Saved checkpoint: sat_epoch4.pt


SAT train (1 epoch):   0%|          | 0/313 [00:00<?, ?it/s]

[SAT] Epoch 5/5 - loss: 2.9697
Saved checkpoint: sat_epoch5.pt


In [17]:
# === Cell 6: Greedy decoding for qualitative check ===

def ids_to_words(ids, vocab, pad_id, eos_id, drop_bos=True):
    words = []
    for idx in ids:
        if idx == pad_id:
            break
        tok = vocab.itos[idx]
        if drop_bos and tok == "<bos>":
            continue
        if tok == "<eos>":
            break
        words.append(tok)
    return words

@torch.no_grad()
def sat_greedy_decode(model, image, max_len=30):
    """
    Greedy decoding for a single image tensor [3, 224, 224].
    Returns list of token ids (includes <bos>, ...).
    """
    model.eval()
    image = image.unsqueeze(0).to(device)    # [1, 3, 224, 224]

    encoder_out = model.encoder(image)       # [1, H, W, enc_dim]
    B, H, W, enc_dim = encoder_out.shape
    num_pixels = H * W
    encoder_out = encoder_out.view(1, num_pixels, enc_dim)   # [1, num_pixels, enc_dim]

    h, c = model.decoder.init_hidden_state(encoder_out)

    seq = [bos_id]
    prev_word = torch.tensor([bos_id], device=device, dtype=torch.long)
    prev_emb = model.decoder.embedding(prev_word)            # [1, embed_dim]

    for _ in range(max_len - 1):
        context, alpha = model.decoder.attention(encoder_out, h)
        gate = model.decoder.sigmoid(model.decoder.f_beta(h))
        context = gate * context

        lstm_input = torch.cat([prev_emb, context], dim=1)
        h, c = model.decoder.decode_step(lstm_input, (h, c))

        output = model.decoder.fc(model.decoder.dropout(h))  # [1, vocab_size]
        _, next_word = output.max(dim=1)
        next_id = next_word.item()
        seq.append(next_id)

        if next_id == eos_id:
            break

        prev_emb = model.decoder.embedding(next_word)

    return seq

# --- Test on a small batch from val_loader ---
batch = next(iter(val_loader))
images = batch["image"]
captions = batch["caption"]
file_names = batch["file_name"]

for i in range(10):
    img = images[i]
    gt_ids = captions[i].tolist()
    pred_ids = sat_greedy_decode(model_sat, img, max_len=MAX_LEN)

    gt_words   = ids_to_words(gt_ids, vocab, pad_id, eos_id)
    pred_words = ids_to_words(pred_ids, vocab, pad_id, eos_id)

    print("\nFile:", file_names[i])
    print("PRED:", " ".join(pred_words) if pred_words else "[EMPTY]")
    print("GT  :", " ".join(gt_words))



File: 000000179765.jpg
PRED: a motorcycle parked in a parking lot
GT  : a black honda motorcycle parked in front of a garage

File: 000000179765.jpg
PRED: a motorcycle parked in a parking lot
GT  : a honda motorcycle parked in a grass driveway

File: 000000190236.jpg
PRED: a kitchen with a stove and a computer
GT  : an office cubicle with four different types of computers

File: 000000331352.jpg
PRED: a bathroom with a toilet and a sink
GT  : a small closed toilet in a cramped space

File: 000000517069.jpg
PRED: a man sitting on a street with a city street
GT  : two women waiting at a bench next to a street

File: 000000179765.jpg
PRED: a motorcycle parked in a parking lot
GT  : a black honda motorcycle with a dark burgundy seat

File: 000000331352.jpg
PRED: a bathroom with a toilet and a sink
GT  : a tan toilet and sink combination in a small room

File: 000000190236.jpg
PRED: a kitchen with a stove and a computer
GT  : the home office space seems to be very cluttered

File: 00000018

In [12]:
# === Cell 7: Beam search decoding for Show, Attend and Tell ===

import math

@torch.no_grad()
def sat_beam_search_single(
    model,
    image,          # [3, 224, 224]
    bos_id,
    eos_id,
    pad_id,
    beam_size=3,
    max_len=30,
):
    """
    Beam search for a single image.
    Returns: list of token ids (including <bos>, ..., <eos>).
    """
    model.eval()

    # Encode image
    img = image.unsqueeze(0).to(device)           # [1, 3, H, W]
    encoder_out = model.encoder(img)              # [1, H, W, enc_dim]
    B, H, W, enc_dim = encoder_out.shape
    num_pixels = H * W
    encoder_out = encoder_out.view(1, num_pixels, enc_dim)  # [1, num_pixels, enc_dim]

    # Init hidden state
    h, c = model.decoder.init_hidden_state(encoder_out)     # [1, dec_dim]

    # Beam: list of (log_prob, seq, h, c)
    # seq is a list of token ids
    start_seq = [bos_id]
    beam = [(0.0, start_seq, h, c)]   # log_prob = 0

    completed = []

    for _ in range(max_len - 1):
        new_beam = []

        for log_p, seq, h_prev, c_prev in beam:
            # If already ended, keep in completed
            if seq[-1] == eos_id:
                completed.append((log_p, seq))
                continue

            prev_word_id = seq[-1]
            prev_word = torch.tensor([prev_word_id], device=device, dtype=torch.long)
            prev_emb = model.decoder.embedding(prev_word)             # [1, embed_dim]

            # Attention
            context, alpha = model.decoder.attention(encoder_out, h_prev)
            gate = model.decoder.sigmoid(model.decoder.f_beta(h_prev))
            context = gate * context

            lstm_input = torch.cat([prev_emb, context], dim=1)        # [1, emb+enc]
            h_new, c_new = model.decoder.decode_step(lstm_input, (h_prev, c_prev))

            output = model.decoder.fc(model.decoder.dropout(h_new))   # [1, vocab_size]
            log_probs = torch.log_softmax(output, dim=-1).squeeze(0)  # [vocab_size]

            # Top-k expansions
            topk_logp, topk_ids = torch.topk(log_probs, beam_size)

            for lp, idx in zip(topk_logp.tolist(), topk_ids.tolist()):
                new_seq = seq + [idx]
                new_log_p = log_p + lp
                new_beam.append((new_log_p, new_seq, h_new, c_new))

        if not new_beam:
            break

        # Keep top beam_size
        new_beam.sort(key=lambda x: x[0], reverse=True)
        beam = new_beam[:beam_size]

        # If we already have enough completed sequences and all beams ended, break
        all_ended = all(seq[-1] == eos_id for _, seq, _, _ in beam)
        if all_ended:
            break

    # If we found completed sequences, pick the best
    if completed:
        completed.sort(key=lambda x: x[0], reverse=True)
        best_logp, best_seq = completed[0]
    else:
        # Otherwise, use the best in the current beam
        beam.sort(key=lambda x: x[0], reverse=True)
        best_logp, best_seq, _, _ = beam[0]

    return best_seq

@torch.no_grad()
def sat_beam_search_batch(
    model,
    images,     # [B, 3, H, W]
    bos_id,
    eos_id,
    pad_id,
    beam_size=3,
    max_len=30,
):
    """
    Run beam search for each image in a batch independently.
    Returns: list of lists of ids.
    """
    preds = []
    B = images.size(0)
    for i in range(B):
        seq = sat_beam_search_single(
            model,
            images[i],
            bos_id,
            eos_id,
            pad_id,
            beam_size=beam_size,
            max_len=max_len,
        )
        preds.append(seq)
    return preds


In [18]:
# Quick sanity: compare greedy vs beam for a few val images
batch = next(iter(val_loader))
images = batch["image"]
captions = batch["caption"]
file_names = batch["file_name"]

for i in range(10):
    img = images[i]
    gt_ids = captions[i].tolist()

    greedy_ids = sat_greedy_decode(model_sat, img, max_len=MAX_LEN)
    beam_ids   = sat_beam_search_single(model_sat, img, bos_id, eos_id, pad_id, beam_size=3, max_len=MAX_LEN)

    greedy_words = ids_to_words(greedy_ids, vocab, pad_id, eos_id)
    beam_words   = ids_to_words(beam_ids, vocab, pad_id, eos_id)
    gt_words     = ids_to_words(gt_ids, vocab, pad_id, eos_id)

    print("\nFile:", file_names[i])
    print("GREEDY:", " ".join(greedy_words))
    print("BEAM  :", " ".join(beam_words))
    print("GT    :", " ".join(gt_words))



File: 000000179765.jpg
GREEDY: a motorcycle parked in a parking lot
BEAM  : a motorcycle parked next to a motorcycle
GT    : a black honda motorcycle parked in front of a garage

File: 000000179765.jpg
GREEDY: a motorcycle parked in a parking lot
BEAM  : a motorcycle parked next to a motorcycle
GT    : a honda motorcycle parked in a grass driveway

File: 000000190236.jpg
GREEDY: a kitchen with a stove and a computer
BEAM  : a kitchen with a stove and a computer
GT    : an office cubicle with four different types of computers

File: 000000331352.jpg
GREEDY: a bathroom with a toilet and a sink
BEAM  : a bathroom with a toilet and a sink
GT    : a small closed toilet in a cramped space

File: 000000517069.jpg
GREEDY: a man sitting on a street with a city street
BEAM  : a group of people sitting on the street
GT    : two women waiting at a bench next to a street

File: 000000179765.jpg
GREEDY: a motorcycle parked in a parking lot
BEAM  : a motorcycle parked next to a motorcycle
GT    : a 

In [14]:
# === Cell 8: BLEU-4 evaluation (beam search) ===

!pip install nltk -q
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
nltk.download('punkt', quiet=True)

def evaluate_bleu_sat_beam(
    model,
    loader,
    vocab,
    pad_id,
    eos_id,
    bos_id,
    device,
    beam_size=3,
    max_batches=50,   # for speed; set to len(loader) for full
):
    model.eval()
    smoothie = SmoothingFunction().method4
    scores = []

    for b_idx, batch in enumerate(tqdm(loader, desc="BLEU eval (SAT beam)")):
        if b_idx >= max_batches:
            break

        images = batch["image"]
        captions = batch["caption"]

        pred_ids_batch = sat_beam_search_batch(
            model,
            images,
            bos_id,
            eos_id,
            pad_id,
            beam_size=beam_size,
            max_len=MAX_LEN,
        )

        B = captions.size(0)
        for i in range(B):
            gt_ids   = captions[i].tolist()
            pred_ids = pred_ids_batch[i]

            gt_words   = ids_to_words(gt_ids, vocab, pad_id, eos_id)
            pred_words = ids_to_words(pred_ids, vocab, pad_id, eos_id)

            if len(pred_words) == 0 or len(gt_words) == 0:
                continue

            score = sentence_bleu(
                [gt_words],
                pred_words,
                smoothing_function=smoothie,
                weights=(0.25, 0.25, 0.25, 0.25),
            )
            scores.append(score)

    if not scores:
        return 0.0
    return float(np.mean(scores))

bleu_beam = evaluate_bleu_sat_beam(
    model_sat,
    val_loader,
    vocab,
    pad_id,
    eos_id,
    bos_id,
    device,
    beam_size=3,
    max_batches=300,   # ~30 batches for a quick estimate
)

print("Approx BLEU-4 (SAT + beam=3):", bleu_beam)


[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


BLEU eval (SAT beam):   0%|          | 0/32 [00:00<?, ?it/s]

Approx BLEU-4 (SAT + beam=3): 0.07269644203025069


In [15]:
# === Cell 9: CLIPScore-style evaluation for SAT (beam captions) ===

!pip install open_clip_torch -q
import open_clip
from PIL import Image

# Load CLIP model + preprocess once
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="openai",
)
clip_model = clip_model.to(device).eval()
clip_tokenizer = open_clip.get_tokenizer("ViT-B-32")

@torch.no_grad()
def compute_clipscore_sat_beam(
    model,
    loader,
    vocab,
    pad_id,
    eos_id,
    bos_id,
    device,
    beam_size=3,
    max_batches=10,  # small for speed; increase if you want
):
    sims = []

    for b_idx, batch in enumerate(tqdm(loader, desc="CLIPScore eval (SAT beam)")):
        if b_idx >= max_batches:
            break

        images = batch["image"]   # [B, 3, H, W]
        file_names = batch["file_name"]

        # Decode captions with beam search
        pred_ids_batch = sat_beam_search_batch(
            model,
            images,
            bos_id,
            eos_id,
            pad_id,
            beam_size=beam_size,
            max_len=MAX_LEN,
        )
        pred_texts = []
        for seq in pred_ids_batch:
            pred_texts.append(" ".join(ids_to_words(seq, vocab, pad_id, eos_id)))

        # For each image–caption pair, compute CLIP similarity
        for img_tensor, cap in zip(images, pred_texts):
            # Convert tensor -> PIL, apply CLIP preprocess
            pil = transforms.ToPILImage()(img_tensor)
            image_input = clip_preprocess(pil).unsqueeze(0).to(device)
            text_tokens = clip_tokenizer([cap]).to(device)

            img_feat = clip_model.encode_image(image_input)
            txt_feat = clip_model.encode_text(text_tokens)

            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

            sim = (img_feat * txt_feat).sum(dim=-1).item()
            sims.append(sim)

    sims = np.array(sims)
    stats = {
        "mean": float(sims.mean()),
        "median": float(np.median(sims)),
        "std": float(sims.std()),
        "low_clip_rate": float((sims < 0.2).mean()),
    }
    return stats

clip_stats_sat = compute_clipscore_sat_beam(
    model_sat,
    val_loader,
    vocab,
    pad_id,
    eos_id,
    bos_id,
    device,
    beam_size=3,
    max_batches=10,  # just a sample for speed
)
print("CLIPScore stats (SAT + beam=3):", clip_stats_sat)


[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m




open_clip_model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]



CLIPScore eval (SAT beam):   0%|          | 0/32 [00:00<?, ?it/s]

CLIPScore stats (SAT + beam=3): {'mean': 0.22763739149086176, 'median': 0.22704041749238968, 'std': 0.023720598154584698, 'low_clip_rate': 0.096875}


In [1]:
# === Cell 10: CHAIR hallucination metrics for SAT (beam captions) ===

# 1) Load instance annotations (COCO 2017 val)
INST_VAL_JSON = os.path.join(ANN_DIR, "instances_val2017.json")
with open(INST_VAL_JSON, "r") as f:
    inst_val = json.load(f)

# Map image_id -> file_name and category_id -> name
imgid_to_file = {img["id"]: img["file_name"] for img in inst_val["images"]}
catid_to_name = {c["id"]: c["name"].lower() for c in inst_val["categories"]}

# Build mapping from file_name -> set of ground-truth object classes
from collections import defaultdict
file_to_objects = defaultdict(set)
for ann in inst_val["annotations"]:
    img_id = ann["image_id"]
    cat_id = ann["category_id"]
    file_name = imgid_to_file[img_id]
    cat_name = catid_to_name[cat_id]
    file_to_objects[file_name].add(cat_name)

# Category vocab for matching words in captions
def simple_tokenize(text: str):
    return re.findall(r"[a-z]+", text.lower())

category_vocab = []
for cat_name in sorted(set(catid_to_name.values())):
    tokens = simple_tokenize(cat_name)
    category_vocab.append((cat_name, tokens))

def find_mentioned_objects(caption_text: str, category_vocab):
    tokens = simple_tokenize(caption_text)
    mentioned = set()
    for cat_name, cat_tokens in category_vocab:
        L = len(cat_tokens)
        if L == 1:
            if cat_tokens[0] in tokens:
                mentioned.add(cat_name)
        else:
            # multi-word categories like "traffic light"
            for i in range(len(tokens) - L + 1):
                if tokens[i:i+L] == cat_tokens:
                    mentioned.add(cat_name)
                    break
    return mentioned

@torch.no_grad()
def compute_chair_sat_beam(
    model,
    loader,
    vocab,
    pad_id,
    eos_id,
    bos_id,
    device,
    file_to_objects,
    category_vocab,
    beam_size=3,
    max_batches=30,
):
    model.eval()
    all_caps = 0
    caps_with_hallucination = 0
    total_mentions = 0
    hallucinated_mentions = 0

    for b_idx, batch in enumerate(tqdm(loader, desc="CHAIR eval (SAT beam)")):
        if b_idx >= max_batches:
            break

        images = batch["image"]
        file_names = batch["file_name"]

        pred_ids_batch = sat_beam_search_batch(
            model,
            images,
            bos_id,
            eos_id,
            pad_id,
            beam_size=beam_size,
            max_len=MAX_LEN,
        )
        pred_texts = [
            " ".join(ids_to_words(seq, vocab, pad_id, eos_id))
            for seq in pred_ids_batch
        ]

        for fn, cap in zip(file_names, pred_texts):
            all_caps += 1
            gt_objects = file_to_objects.get(fn, set())
            mentioned = find_mentioned_objects(cap, category_vocab)
            if not mentioned:
                continue

            total_mentions += len(mentioned)
            hallucinated = mentioned - gt_objects
            if hallucinated:
                caps_with_hallucination += 1
                hallucinated_mentions += len(hallucinated)

    chair_s = caps_with_hallucination / all_caps if all_caps else 0.0
    chair_i = hallucinated_mentions / total_mentions if total_mentions else 0.0
    return {
        "CHAIRs": chair_s,
        "CHAIRi": chair_i,
        "total_captions": all_caps,
        "total_mentions": total_mentions,
        "hallucinated_mentions": hallucinated_mentions,
    }

chair_stats_sat = compute_chair_sat_beam(
    model_sat,
    val_loader,
    vocab,
    pad_id,
    eos_id,
    bos_id,
    device,
    file_to_objects,
    category_vocab,
    beam_size=3,
    max_batches=32,
)
print("CHAIR stats (SAT + beam=3):", chair_stats_sat)


NameError: name 'os' is not defined

In [19]:
# === NEW CELL: BLEU-4 evaluation (SAT + greedy) ===

def evaluate_bleu_sat_greedy(
    model,
    loader,
    vocab,
    pad_id,
    eos_id,
    device,
    max_batches=50,   # for speed; set to len(loader) for full
):
    model.eval()
    smoothie = SmoothingFunction().method4
    scores = []

    for b_idx, batch in enumerate(tqdm(loader, desc="BLEU eval (SAT greedy)")):
        if b_idx >= max_batches:
            break

        images = batch["image"]
        captions = batch["caption"]

        B = images.size(0)
        for i in range(B):
            img = images[i]
            gt_ids = captions[i].tolist()

            # Greedy decoding for this image
            pred_ids = sat_greedy_decode(model, img, max_len=MAX_LEN)

            gt_words   = ids_to_words(gt_ids, vocab, pad_id, eos_id)
            pred_words = ids_to_words(pred_ids, vocab, pad_id, eos_id)

            if len(pred_words) == 0 or len(gt_words) == 0:
                continue

            score = sentence_bleu(
                [gt_words],
                pred_words,
                smoothing_function=smoothie,
                weights=(0.25, 0.25, 0.25, 0.25),
            )
            scores.append(score)

    if not scores:
        return 0.0
    return float(np.mean(scores))

bleu_greedy = evaluate_bleu_sat_greedy(
    model_sat,
    val_loader,
    vocab,
    pad_id,
    eos_id,
    device,
    max_batches=300,  # adjust as you like
)

print("Approx BLEU-4 (SAT + greedy):", bleu_greedy)


BLEU eval (SAT greedy):   0%|          | 0/32 [00:00<?, ?it/s]

Approx BLEU-4 (SAT + greedy): 0.07183989813031917


In [22]:
# === NEW CELL: CLIPScore evaluation (SAT + greedy) ===

@torch.no_grad()
def compute_clipscore_sat_greedy(
    model,
    loader,
    vocab,
    pad_id,
    eos_id,
    device,
    max_batches=10,  # small for speed; increase if needed
):
    model.eval()
    sims = []

    for b_idx, batch in enumerate(tqdm(loader, desc="CLIPScore eval (SAT greedy)")):
        if b_idx >= max_batches:
            break

        images = batch["image"]   # [B, 3, H, W]

        B = images.size(0)
        for i in range(B):
            img_tensor = images[i]

            # Greedy caption
            pred_ids = sat_greedy_decode(model, img_tensor, max_len=MAX_LEN)
            caption  = " ".join(ids_to_words(pred_ids, vocab, pad_id, eos_id))

            # Convert tensor -> PIL, apply CLIP preprocess
            pil = transforms.ToPILImage()(img_tensor)
            image_input = clip_preprocess(pil).unsqueeze(0).to(device)
            text_tokens = clip_tokenizer([caption]).to(device)

            img_feat = clip_model.encode_image(image_input)
            txt_feat = clip_model.encode_text(text_tokens)

            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

            sim = (img_feat * txt_feat).sum(dim=-1).item()
            sims.append(sim)

    sims = np.array(sims)
    stats = {
        "mean": float(sims.mean()),
        "median": float(np.median(sims)),
        "std": float(sims.std()),
        "low_clip_rate": float((sims < 0.2).mean()),
    }
    return stats

clip_stats_sat_greedy = compute_clipscore_sat_greedy(
    model_sat,
    val_loader,
    vocab,
    pad_id,
    eos_id,
    device,
    max_batches=32,
)
print("CLIPScore stats (SAT + greedy):", clip_stats_sat_greedy)


CLIPScore eval (SAT greedy):   0%|          | 0/32 [00:00<?, ?it/s]

CLIPScore stats (SAT + greedy): {'mean': 0.22726542752981185, 'median': 0.22602428495883942, 'std': 0.02360034516301436, 'low_clip_rate': 0.122}


In [23]:
# === NEW CELL: CHAIR hallucination metrics (SAT + greedy) ===

@torch.no_grad()
def compute_chair_sat_greedy(
    model,
    loader,
    vocab,
    pad_id,
    eos_id,
    device,
    file_to_objects,
    category_vocab,
    max_batches=30,
):
    model.eval()
    all_caps = 0
    caps_with_hallucination = 0
    total_mentions = 0
    hallucinated_mentions = 0

    for b_idx, batch in enumerate(tqdm(loader, desc="CHAIR eval (SAT greedy)")):
        if b_idx >= max_batches:
            break

        images = batch["image"]
        file_names = batch["file_name"]

        B = images.size(0)
        for i in range(B):
            img = images[i]
            fn  = file_names[i]

            pred_ids = sat_greedy_decode(model, img, max_len=MAX_LEN)
            cap_text = " ".join(ids_to_words(pred_ids, vocab, pad_id, eos_id))

            all_caps += 1
            gt_objects = file_to_objects.get(fn, set())
            mentioned = find_mentioned_objects(cap_text, category_vocab)
            if not mentioned:
                continue

            total_mentions += len(mentioned)
            hallucinated = mentioned - gt_objects
            if hallucinated:
                caps_with_hallucination += 1
                hallucinated_mentions += len(hallucinated)

    chair_s = caps_with_hallucination / all_caps if all_caps else 0.0
    chair_i = hallucinated_mentions / total_mentions if total_mentions else 0.0
    return {
        "CHAIRs": chair_s,
        "CHAIRi": chair_i,
        "total_captions": all_caps,
        "total_mentions": total_mentions,
        "hallucinated_mentions": hallucinated_mentions,
    }

chair_stats_sat_greedy = compute_chair_sat_greedy(
    model_sat,
    val_loader,
    vocab,
    pad_id,
    eos_id,
    device,
    file_to_objects,
    category_vocab,
    max_batches=32,
)
print("CHAIR stats (SAT + greedy):", chair_stats_sat_greedy)


CHAIR eval (SAT greedy):   0%|          | 0/32 [00:00<?, ?it/s]

CHAIR stats (SAT + greedy): {'CHAIRs': 0.257, 'CHAIRi': 0.3397727272727273, 'total_captions': 1000, 'total_mentions': 880, 'hallucinated_mentions': 299}


In [9]:
# === Cell: Visualize SAT predictions (greedy + beam) with images ===

import matplotlib.pyplot as plt
import numpy as np

# If you normalized images with ImageNet stats, define unnormalize:
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD  = np.array([0.229, 0.224, 0.225])

def unnormalize(img_tensor):
    """
    img_tensor: [3, H, W] tensor in normalized space.
    Returns HxWx3 numpy array in [0,1].
    """
    img = img_tensor.cpu().numpy()
    img = (img * IMAGENET_STD[:, None, None]) + IMAGENET_MEAN[:, None, None]
    img = np.clip(img, 0.0, 1.0)
    img = np.transpose(img, (1, 2, 0))  # CHW -> HWC
    return img

def decode_ids_to_text(ids, vocab, pad_id, eos_id):
    return " ".join(ids_to_words(ids, vocab, pad_id, eos_id))

@torch.no_grad()
def show_sat_examples(
    model,
    loader,
    vocab,
    pad_id,
    eos_id,
    bos_id,
    device,
    num_examples=5,
    use_beam=True,
    beam_size=3,
    max_len=30,
):
    """
    Show a few validation images with:
      - ground truth caption
      - SAT greedy prediction
      - SAT beam prediction (optional)
    """
    model.eval()
    
    batch = next(iter(loader))
    images = batch["image"]
    captions = batch["caption"]
    file_names = batch.get("file_name", None)

    n = min(num_examples, images.size(0))

    for i in range(n):
        img_tensor = images[i]
        gt_ids = captions[i].tolist()

        # Greedy prediction
        pred_ids_greedy = sat_greedy_decode(
            model,
            img_tensor,
            max_len=max_len
        )

        # Beam prediction (optional)
        if use_beam:
            pred_ids_beam = sat_beam_search_single(
                model,
                img_tensor,
                bos_id,
                eos_id,
                pad_id,
                beam_size=beam_size,
                max_len=max_len,
            )
        else:
            pred_ids_beam = None

        gt_text       = decode_ids_to_text(gt_ids, vocab, pad_id, eos_id)
        pred_text_gr  = decode_ids_to_text(pred_ids_greedy, vocab, pad_id, eos_id)
        pred_text_beam = decode_ids_to_text(pred_ids_beam, vocab, pad_id, eos_id) if pred_ids_beam else "[disabled]"

        # Plot image
        plt.figure(figsize=(6, 6))
        plt.imshow(unnormalize(img_tensor))
        plt.axis("off")

        title = f"Example {i+1}"
        if file_names is not None:
            title += f"  ({file_names[i]})"
        plt.title(title, fontsize=12)

        # Print captions under the image
        print("=" * 80)
        print(title)
        print("GT      :", gt_text)
        print("Greedy  :", pred_text_gr if pred_text_gr else "[EMPTY]")
        if use_beam:
            print(f"Beam({beam_size}):", pred_text_beam if pred_text_beam else "[EMPTY]")

        plt.show()


# --- Call it (SAT model) ---
show_sat_examples(
    model=model_sat,
    loader=val_loader,
    vocab=vocab,
    pad_id=pad_id,
    eos_id=eos_id,
    bos_id=bos_id,
    device=device,
    num_examples=10,
    use_beam=True,   # set False if you only want greedy
    beam_size=3,
    max_len=MAX_LEN,
)


NameError: name 'sat_greedy_decode' is not defined