In [2]:
!pip -q install torch torchvision torchmetrics==1.3.2 pillow tqdm nltk einops numpy==1.26.4

In [3]:
import os, json, math, random
from dataclasses import dataclass, asdict
from collections import Counter
from typing import List, Dict, Tuple
from contextlib import nullcontext
from tqdm import tqdm

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torchmetrics.text.bleu import BLEUScore
import nltk
nltk.download('punkt', quiet=True)

True

#  Device & AMP setup  


In [5]:
from contextlib import nullcontext
import torch

def pick_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

DEVICE = pick_device()
print("Using device:", DEVICE)

# new-style AMP (CUDA only)
amp_ctx = (lambda: torch.amp.autocast(device_type="cuda")) if DEVICE.type == "cuda" else (lambda: nullcontext())
scaler = torch.amp.GradScaler(device_type="cuda") if DEVICE.type == "cuda" else None
use_scaler = scaler is not None



Using device: mps


In [6]:
@dataclass
class Config:
    DATA_ROOT: str = "data"
    IMAGES_DIR: str = "data/Images"
    CAPTIONS_FILE: str = "data/captions.txt"  
    OUTPUT_DIR: str = "artifacts"

    IMG_SIZE: int = 224
    BATCH_SIZE: int = 64
    NUM_WORKERS: int = 0  # If you see Mac issues, set to 0
    EPOCHS: int = 12
    LR: float = 3e-4

    EMBED_DIM: int = 256
    HIDDEN_DIM: int = 512
    NUM_LAYERS: int = 1
    DROPOUT: float = 0.1

    MIN_FREQ: int = 3
    MAX_LEN: int = 28
    FREEZE_CNN: bool = True

    SEED: int = 42

cfg = Config()
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
with open(os.path.join(cfg.OUTPUT_DIR, "config.json"), "w") as f:
    json.dump(asdict(cfg), f, indent=2)
cfg


Config(DATA_ROOT='data', IMAGES_DIR='data/Images', CAPTIONS_FILE='data/captions.txt', OUTPUT_DIR='artifacts', IMG_SIZE=224, BATCH_SIZE=64, NUM_WORKERS=0, EPOCHS=12, LR=0.0003, EMBED_DIM=256, HIDDEN_DIM=512, NUM_LAYERS=1, DROPOUT=0.1, MIN_FREQ=3, MAX_LEN=28, FREEZE_CNN=True, SEED=42)

# Data + Vocab

In [8]:
SPECIALS = {"<pad>":0, "<bos>":1, "<eos>":2, "<unk>":3}

def tokenize(text: str) -> List[str]:
    return nltk.word_tokenize(text.lower())

def load_captions(captions_file: str) -> Dict[str, List[str]]:
    """
    Supports:
      - Flickr8k.token.txt       -> 'image.jpg#i\\tcaption'
      - captions.csv/txt (Kaggle)-> 'image,caption' (with optional header)
    Returns: { "image.jpg": [cap1, cap2, ...], ... }
    """
    img2caps: Dict[str, List[str]] = {}
    with open(captions_file, "r", encoding="utf-8") as f:
        lines = [l.strip() for l in f if l.strip()]

    # Detect token format vs CSV and CSV header
    is_token = ("\t" in lines[0]) and ("#" in lines[0].split("\t")[0])
    has_header = (not is_token) and lines[0].lower().startswith("image,")
    start_idx = 1 if has_header else 0

    for line in lines[start_idx:]:
        if is_token:
            left, cap = line.split("\t", 1)
            img = left.split("#")[0].strip()
        else:
            # CSV-like: split only on the first comma
            parts = line.split(",", 1)
            if len(parts) != 2:
                continue
            img, cap = parts[0].strip(), parts[1].strip()

        # remove wrapping quotes if present
        if img.startswith('"') and img.endswith('"'): img = img[1:-1]
        if cap.startswith('"') and cap.endswith('"'): cap = cap[1:-1]

        if img:
            img2caps.setdefault(img, []).append(cap)
    return img2caps

def filter_existing(img2caps: Dict[str, List[str]], images_dir: str) -> Dict[str, List[str]]:
    """Keep only entries whose image file exists in images_dir."""
    out = {}
    missing = 0
    for img, caps in img2caps.items():
        path = os.path.join(images_dir, img)
        if os.path.exists(path):
            out[img] = caps
        else:
            missing += 1
    if missing:
        print(f"[warn] Skipped {missing} caption entries with missing image files.")
    return out

def build_vocab(img2caps: Dict[str, List[str]], min_freq: int):
    counter = Counter()
    for caps in img2caps.values():
        for c in caps:
            counter.update(tokenize(c))
    # specials first
    itos = [None]*len(SPECIALS)
    for tok, idx in SPECIALS.items():
        itos[idx] = tok
    # add tokens by frequency
    for tok, freq in counter.items():
        if freq >= min_freq:
            itos.append(tok)
    stoi = {tok:i for i, tok in enumerate(itos)}
    return stoi, itos

# ---- Run the pipeline ----
raw_caps = load_captions(cfg.CAPTIONS_FILE)
raw_caps = filter_existing(raw_caps, cfg.IMAGES_DIR)

# (Optional sanity checks)
print("Images dir exists:", os.path.isdir(cfg.IMAGES_DIR))
example_key = next(iter(raw_caps)) if raw_caps else None
print("Example image key:", example_key)
if example_key:
    print("Example file exists:", os.path.exists(os.path.join(cfg.IMAGES_DIR, example_key)))

stoi, itos = build_vocab(raw_caps, cfg.MIN_FREQ)

# save vocab
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
with open(os.path.join(cfg.OUTPUT_DIR, "vocab.json"), "w", encoding="utf-8") as f:
    json.dump({"stoi": stoi, "itos": itos}, f)

print(f"Vocab size: {len(stoi)}")


Images dir exists: True
Example image key: 1000268201_693b08cb0e.jpg
Example file exists: True
Vocab size: 4108


# Dataset & Loaders

In [10]:
class Flickr8kDataset(Dataset):
    def __init__(self, images_dir, img2caps, transform, stoi, max_len=28, split="train", seed=42):
        self.images_dir = images_dir
        self.transform = transform
        self.stoi = stoi
        self.max_len = max_len

        imgs = list(img2caps.keys())
        random.Random(seed).shuffle(imgs)
        n = len(imgs)
        self.imgs = imgs[:int(0.8*n)] if split=="train" else imgs[int(0.8*n):int(0.9*n)] if split=="val" else imgs[int(0.9*n):]
        self.img2caps = img2caps

    def __len__(self): return len(self.imgs)

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        image = Image.open(os.path.join(self.images_dir, img_name)).convert("RGB")
        cap = random.choice(self.img2caps[img_name])

        image = self.transform(image) if self.transform else image
        tokens = ["<bos>"] + tokenize(cap)[: self.max_len - 2] + ["<eos>"]
        ids = [self.stoi.get(t, SPECIALS["<unk>"]) for t in tokens]
        return image, torch.tensor(ids, dtype=torch.long), img_name

def collate_fn(batch):
    imgs, seqs, names = zip(*batch)
    imgs = torch.stack(imgs)
    lengths = torch.tensor([len(s) for s in seqs], dtype=torch.long)
    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=SPECIALS["<pad>"])
    return imgs, seqs_padded, lengths, names

train_tfms = transforms.Compose([
    transforms.Resize((cfg.IMG_SIZE, cfg.IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])
eval_tfms = transforms.Compose([
    transforms.Resize((cfg.IMG_SIZE, cfg.IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

train_ds = Flickr8kDataset(cfg.IMAGES_DIR, raw_caps, train_tfms, stoi, cfg.MAX_LEN, "train", cfg.SEED)
val_ds   = Flickr8kDataset(cfg.IMAGES_DIR, raw_caps, eval_tfms,  stoi, cfg.MAX_LEN, "val",   cfg.SEED)
test_ds  = Flickr8kDataset(cfg.IMAGES_DIR, raw_caps, eval_tfms,  stoi, cfg.MAX_LEN, "test",  cfg.SEED)

train_loader = DataLoader(train_ds, batch_size=cfg.BATCH_SIZE, shuffle=True,  num_workers=cfg.NUM_WORKERS, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=cfg.BATCH_SIZE, shuffle=False, num_workers=cfg.NUM_WORKERS, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=cfg.BATCH_SIZE, shuffle=False, num_workers=cfg.NUM_WORKERS, collate_fn=collate_fn)

len(train_ds), len(val_ds), len(test_ds)


(6472, 809, 810)

In [11]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_dim=256, freeze=True):
        super().__init__()
        base = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.backbone = nn.Sequential(*list(base.children())[:-1])  # remove FC
        self.proj = nn.Linear(base.fc.in_features, embed_dim)
        if freeze:
            for p in self.backbone.parameters(): p.requires_grad = False

    def forward(self, x):
        feats = self.backbone(x).flatten(1)  # (B,C)
        return self.proj(feats)              # (B,E)

class DecoderRNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=SPECIALS["<pad>"])
        self.lstm  = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True,
                             dropout=(dropout if num_layers>1 else 0.0))
        self.fc    = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions, lengths):
        x_tok = self.embed(captions)                    # (B,T,E)
        x_img = features.unsqueeze(1)                   # (B,1,E)
        x     = torch.cat([x_img, x_tok[:, :-1, :]], 1) # prepend image, shift caps
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        out, _ = self.lstm(packed)
        logits = self.fc(out.data)
        return logits

    @torch.no_grad()
    def greedy(self, features, max_len=28):
        B = features.size(0)
        h = None
        token = torch.full((B,1), SPECIALS["<bos>"], dtype=torch.long, device=features.device)
        x = self.embed(token)                           # (B,1,E)
        x = torch.cat([features.unsqueeze(1), x], dim=1)[:, -1:, :]
        outs = []
        for _ in range(max_len):
            y, h = self.lstm(x, h)
            logits = self.fc(y.squeeze(1))
            nxt = torch.argmax(logits, -1, keepdim=True)  # (B,1)
            outs.append(nxt)
            x = self.embed(nxt)
        return torch.cat(outs, 1)

class CaptionNet(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout, freeze_cnn=True):
        super().__init__()
        self.enc = EncoderCNN(embed_dim, freeze=freeze_cnn)
        self.dec = DecoderRNN(vocab_size, embed_dim, hidden_dim, num_layers, dropout)

    def forward(self, images, captions, lengths):
        feats = self.enc(images)
        return self.dec(feats, captions, lengths)

    @torch.no_grad()
    def generate(self, images, max_len=28):
        feats = self.enc(images)
        return self.dec.greedy(feats, max_len)


In [12]:
def ids_to_text(ids, itos):
    words = []
    for i in ids:
        w = itos[int(i)]
        if w == "<eos>": break
        if w in {"<bos>","<pad>"}: continue
        words.append(w)
    return " ".join(words)
    
def compute_bleu(preds: list[str], refs: list[list[str]]) -> float:
    """
    preds: ['a dog runs...', ...]
    refs:  [['a dog running...', 'a canine...', ...], [...], ...]
    BLEUScore will tokenize internally, so pass strings.
    """
    bleu = BLEUScore(n_gram=4, smooth=True)
    score = bleu(preds, refs)  # returns a 0-dim tensor
    return float(score.cpu().item())

def evaluate(model, loader, criterion, itos, device, max_len):
    model.eval()
    total_loss, n_tok = 0.0, 0
    preds, refs = [], []

    with torch.no_grad():
        for images, caps, lengths, names in loader:
            images, caps = images.to(device), caps.to(device)

            # compute loss
            logits = model(images, caps, lengths)
            targets = pack_padded_sequence(caps, lengths, batch_first=True, enforce_sorted=False)[0]
            loss = criterion(logits, targets)
            total_loss += loss.item() * targets.size(0)
            n_tok += targets.size(0)

            # decode predictions
            gen = model.generate(images, max_len=max_len)
            for i in range(images.size(0)):
                pred_txt = ids_to_text(gen[i], itos)      # string
                # use ALL refs for this image as strings (no tokenization here)
                all_refs = raw_caps[names[i]]             # list[str], e.g., 5 captions
                preds.append(pred_txt)
                refs.append(all_refs)

    bleu = compute_bleu(preds, refs) if preds else 0.0
    ppl  = math.exp(total_loss / max(1, n_tok))
    return total_loss/max(1,n_tok), bleu, ppl


In [21]:
torch.manual_seed(cfg.SEED); random.seed(cfg.SEED); np.random.seed(cfg.SEED)

vocab_size = len(stoi)
model = CaptionNet(
    vocab_size=vocab_size,
    embed_dim=cfg.EMBED_DIM,
    hidden_dim=cfg.HIDDEN_DIM,
    num_layers=cfg.NUM_LAYERS,
    dropout=cfg.DROPOUT,
    freeze_cnn=cfg.FREEZE_CNN,
).to(DEVICE)

criterion = nn.CrossEntropyLoss(ignore_index=SPECIALS["<pad>"]).to(DEVICE)
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.LR)
scaler = torch.amp.GradScaler(device_type="cuda") if DEVICE.type == "cuda" else None
use_scaler = scaler is not None

best_bleu, save_path = 0.0, os.path.join(cfg.OUTPUT_DIR, "weights.pt")

for epoch in range(1, cfg.EPOCHS+1):
    print(f"\n=== Epoch {epoch}/{cfg.EPOCHS} ===")
    model.train()
    pbar = tqdm(train_loader, desc=f"Training")
    run_loss, tok = 0.0, 0

    for batch_idx, (images, caps, lengths, names) in enumerate(pbar, start=1):
        images, caps = images.to(DEVICE), caps.to(DEVICE)
        optimizer.zero_grad(set_to_none=True)

        with amp_ctx():
            logits = model(images, caps, lengths)
            targets = pack_padded_sequence(caps, lengths, batch_first=True, enforce_sorted=False)[0]
            loss = criterion(logits, targets)

        if use_scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        run_loss += loss.item() * targets.size(0)
        tok += targets.size(0)
        avg_loss = run_loss / max(1, tok)

        # update tqdm bar
        pbar.set_postfix(loss=f"{avg_loss:.4f}")

        # print batch info every 10 batches
        if batch_idx % 10 == 0 or batch_idx == 1:
            print(f"  [Epoch {epoch} | Batch {batch_idx}/{len(train_loader)}] "
                  f"Loss: {loss.item():.4f} | Running Avg: {avg_loss:.4f}")

    # ---- End of epoch validation ----
    vloss, vbleu, vppl = evaluate(model, val_loader, criterion, itos, DEVICE, cfg.MAX_LEN)
    print(f"End of Epoch {epoch}: Val loss {vloss:.4f} | BLEU4 {vbleu:.4f} | PPL {vppl:.2f}")

    if vbleu > best_bleu:
        best_bleu = vbleu
        torch.save({"model": model.state_dict(), "vocab_size": vocab_size}, save_path)
        print(f"  ✓ New best BLEU {best_bleu:.4f} — saved checkpoint to {save_path}")



=== Epoch 1/12 ===


Training:   1%|▏                   | 1/102 [00:00<01:17,  1.30it/s, loss=8.3245]

  [Epoch 1 | Batch 1/102] Loss: 8.3245 | Running Avg: 8.3245


Training:  10%|█▊                 | 10/102 [00:05<00:41,  2.24it/s, loss=8.1278]

  [Epoch 1 | Batch 10/102] Loss: 7.8802 | Running Avg: 8.1278


Training:  20%|███▋               | 20/102 [00:09<00:32,  2.51it/s, loss=7.5662]

  [Epoch 1 | Batch 20/102] Loss: 5.8884 | Running Avg: 7.5662


Training:  29%|█████▌             | 30/102 [00:13<00:28,  2.53it/s, loss=6.7380]

  [Epoch 1 | Batch 30/102] Loss: 4.7872 | Running Avg: 6.7380


Training:  39%|███████▍           | 40/102 [00:17<00:24,  2.48it/s, loss=6.1945]

  [Epoch 1 | Batch 40/102] Loss: 4.5575 | Running Avg: 6.1945


Training:  49%|█████████▎         | 50/102 [00:21<00:20,  2.52it/s, loss=5.8564]

  [Epoch 1 | Batch 50/102] Loss: 4.4175 | Running Avg: 5.8564


Training:  59%|███████████▏       | 60/102 [00:24<00:16,  2.58it/s, loss=5.6118]

  [Epoch 1 | Batch 60/102] Loss: 4.4641 | Running Avg: 5.6118


Training:  69%|█████████████      | 70/102 [00:28<00:12,  2.50it/s, loss=5.4288]

  [Epoch 1 | Batch 70/102] Loss: 4.4855 | Running Avg: 5.4288


Training:  78%|██████████████▉    | 80/102 [00:32<00:08,  2.48it/s, loss=5.2824]

  [Epoch 1 | Batch 80/102] Loss: 4.2060 | Running Avg: 5.2824


Training:  88%|████████████████▊  | 90/102 [00:36<00:04,  2.57it/s, loss=5.1621]

  [Epoch 1 | Batch 90/102] Loss: 4.2277 | Running Avg: 5.1621


Training:  98%|█████████████████▋| 100/102 [00:40<00:00,  2.50it/s, loss=5.0554]

  [Epoch 1 | Batch 100/102] Loss: 4.0849 | Running Avg: 5.0554


Training: 100%|██████████████████| 102/102 [00:41<00:00,  2.47it/s, loss=5.0450]


End of Epoch 1: Val loss 4.0776 | BLEU4 0.0359 | PPL 59.01
  ✓ New best BLEU 0.0359 — saved checkpoint to artifacts/weights.pt

=== Epoch 2/12 ===


Training:   1%|▏                   | 1/102 [00:00<00:57,  1.76it/s, loss=3.9850]

  [Epoch 2 | Batch 1/102] Loss: 3.9850 | Running Avg: 3.9850


Training:  10%|█▊                 | 10/102 [00:04<00:35,  2.59it/s, loss=4.0531]

  [Epoch 2 | Batch 10/102] Loss: 4.0116 | Running Avg: 4.0531


Training:  20%|███▋               | 20/102 [00:08<00:32,  2.56it/s, loss=4.0245]

  [Epoch 2 | Batch 20/102] Loss: 3.7658 | Running Avg: 4.0245


Training:  29%|█████▌             | 30/102 [00:11<00:27,  2.60it/s, loss=3.9867]

  [Epoch 2 | Batch 30/102] Loss: 3.8778 | Running Avg: 3.9867


Training:  39%|███████▍           | 40/102 [00:15<00:24,  2.58it/s, loss=3.9788]

  [Epoch 2 | Batch 40/102] Loss: 3.8914 | Running Avg: 3.9788


Training:  49%|█████████▎         | 50/102 [00:19<00:20,  2.55it/s, loss=3.9551]

  [Epoch 2 | Batch 50/102] Loss: 3.8512 | Running Avg: 3.9551


Training:  59%|███████████▏       | 60/102 [00:23<00:16,  2.51it/s, loss=3.9417]

  [Epoch 2 | Batch 60/102] Loss: 3.9173 | Running Avg: 3.9417


Training:  69%|█████████████      | 70/102 [00:27<00:12,  2.56it/s, loss=3.9250]

  [Epoch 2 | Batch 70/102] Loss: 3.7833 | Running Avg: 3.9250


Training:  78%|██████████████▉    | 80/102 [00:31<00:08,  2.50it/s, loss=3.9065]

  [Epoch 2 | Batch 80/102] Loss: 3.7383 | Running Avg: 3.9065


Training:  88%|████████████████▊  | 90/102 [00:35<00:04,  2.55it/s, loss=3.8921]

  [Epoch 2 | Batch 90/102] Loss: 3.7605 | Running Avg: 3.8921


Training:  98%|█████████████████▋| 100/102 [00:39<00:00,  2.54it/s, loss=3.8778]

  [Epoch 2 | Batch 100/102] Loss: 3.6084 | Running Avg: 3.8778


Training: 100%|██████████████████| 102/102 [00:39<00:00,  2.56it/s, loss=3.8769]


End of Epoch 2: Val loss 3.7562 | BLEU4 0.0481 | PPL 42.78
  ✓ New best BLEU 0.0481 — saved checkpoint to artifacts/weights.pt

=== Epoch 3/12 ===


Training:   1%|▏                   | 1/102 [00:00<00:41,  2.41it/s, loss=3.6639]

  [Epoch 3 | Batch 1/102] Loss: 3.6639 | Running Avg: 3.6639


Training:  10%|█▊                 | 10/102 [00:03<00:35,  2.57it/s, loss=3.7019]

  [Epoch 3 | Batch 10/102] Loss: 3.6934 | Running Avg: 3.7019


Training:  20%|███▋               | 20/102 [00:07<00:31,  2.61it/s, loss=3.6712]

  [Epoch 3 | Batch 20/102] Loss: 3.6372 | Running Avg: 3.6712


Training:  29%|█████▌             | 30/102 [00:11<00:27,  2.61it/s, loss=3.6600]

  [Epoch 3 | Batch 30/102] Loss: 3.5983 | Running Avg: 3.6600


Training:  39%|███████▍           | 40/102 [00:15<00:23,  2.63it/s, loss=3.6568]

  [Epoch 3 | Batch 40/102] Loss: 3.7621 | Running Avg: 3.6568


Training:  49%|█████████▎         | 50/102 [00:19<00:20,  2.59it/s, loss=3.6529]

  [Epoch 3 | Batch 50/102] Loss: 3.6939 | Running Avg: 3.6529


Training:  59%|███████████▏       | 60/102 [00:23<00:16,  2.55it/s, loss=3.6516]

  [Epoch 3 | Batch 60/102] Loss: 3.5665 | Running Avg: 3.6516


Training:  69%|█████████████      | 70/102 [00:27<00:12,  2.63it/s, loss=3.6353]

  [Epoch 3 | Batch 70/102] Loss: 3.4560 | Running Avg: 3.6353


Training:  78%|██████████████▉    | 80/102 [00:30<00:08,  2.65it/s, loss=3.6319]

  [Epoch 3 | Batch 80/102] Loss: 3.5718 | Running Avg: 3.6319


Training:  88%|████████████████▊  | 90/102 [00:35<00:04,  2.43it/s, loss=3.6259]

  [Epoch 3 | Batch 90/102] Loss: 3.5147 | Running Avg: 3.6259


Training:  98%|█████████████████▋| 100/102 [00:38<00:00,  2.58it/s, loss=3.6129]

  [Epoch 3 | Batch 100/102] Loss: 3.4630 | Running Avg: 3.6129


Training: 100%|██████████████████| 102/102 [00:39<00:00,  2.59it/s, loss=3.6116]


End of Epoch 3: Val loss 3.5206 | BLEU4 0.0654 | PPL 33.81
  ✓ New best BLEU 0.0654 — saved checkpoint to artifacts/weights.pt

=== Epoch 4/12 ===


Training:   1%|▏                   | 1/102 [00:00<00:40,  2.46it/s, loss=3.5093]

  [Epoch 4 | Batch 1/102] Loss: 3.5093 | Running Avg: 3.5093


Training:  10%|█▊                 | 10/102 [00:03<00:36,  2.52it/s, loss=3.5000]

  [Epoch 4 | Batch 10/102] Loss: 3.3191 | Running Avg: 3.5000


Training:  20%|███▋               | 20/102 [00:07<00:32,  2.50it/s, loss=3.4979]

  [Epoch 4 | Batch 20/102] Loss: 3.5673 | Running Avg: 3.4979


Training:  29%|█████▌             | 30/102 [00:11<00:28,  2.57it/s, loss=3.4858]

  [Epoch 4 | Batch 30/102] Loss: 3.6449 | Running Avg: 3.4858


Training:  39%|███████▍           | 40/102 [00:15<00:24,  2.50it/s, loss=3.4700]

  [Epoch 4 | Batch 40/102] Loss: 3.2420 | Running Avg: 3.4700


Training:  49%|█████████▎         | 50/102 [00:19<00:20,  2.57it/s, loss=3.4801]

  [Epoch 4 | Batch 50/102] Loss: 3.6864 | Running Avg: 3.4801


Training:  59%|███████████▏       | 60/102 [00:23<00:16,  2.49it/s, loss=3.4808]

  [Epoch 4 | Batch 60/102] Loss: 3.4715 | Running Avg: 3.4808


Training:  69%|█████████████      | 70/102 [00:27<00:12,  2.55it/s, loss=3.4747]

  [Epoch 4 | Batch 70/102] Loss: 3.4067 | Running Avg: 3.4747


Training:  78%|██████████████▉    | 80/102 [00:31<00:08,  2.47it/s, loss=3.4670]

  [Epoch 4 | Batch 80/102] Loss: 3.3247 | Running Avg: 3.4670


Training:  88%|████████████████▊  | 90/102 [00:35<00:04,  2.54it/s, loss=3.4634]

  [Epoch 4 | Batch 90/102] Loss: 3.3499 | Running Avg: 3.4634


Training:  98%|█████████████████▋| 100/102 [00:39<00:00,  2.56it/s, loss=3.4578]

  [Epoch 4 | Batch 100/102] Loss: 3.5435 | Running Avg: 3.4578


Training: 100%|██████████████████| 102/102 [00:39<00:00,  2.56it/s, loss=3.4574]


End of Epoch 4: Val loss 3.4270 | BLEU4 0.0474 | PPL 30.78

=== Epoch 5/12 ===


Training:   1%|▏                   | 1/102 [00:00<00:40,  2.51it/s, loss=3.2801]

  [Epoch 5 | Batch 1/102] Loss: 3.2801 | Running Avg: 3.2801


Training:  10%|█▊                 | 10/102 [00:03<00:35,  2.58it/s, loss=3.3902]

  [Epoch 5 | Batch 10/102] Loss: 3.4541 | Running Avg: 3.3902


Training:  20%|███▋               | 20/102 [00:08<00:33,  2.44it/s, loss=3.3933]

  [Epoch 5 | Batch 20/102] Loss: 3.2866 | Running Avg: 3.3933


Training:  29%|█████▌             | 30/102 [00:11<00:28,  2.49it/s, loss=3.3840]

  [Epoch 5 | Batch 30/102] Loss: 3.4399 | Running Avg: 3.3840


Training:  39%|███████▍           | 40/102 [00:15<00:24,  2.54it/s, loss=3.3703]

  [Epoch 5 | Batch 40/102] Loss: 3.2036 | Running Avg: 3.3703


Training:  49%|█████████▎         | 50/102 [00:19<00:20,  2.52it/s, loss=3.3711]

  [Epoch 5 | Batch 50/102] Loss: 3.6476 | Running Avg: 3.3711


Training:  59%|███████████▏       | 60/102 [00:23<00:16,  2.52it/s, loss=3.3612]

  [Epoch 5 | Batch 60/102] Loss: 3.2489 | Running Avg: 3.3612


Training:  69%|█████████████      | 70/102 [00:27<00:12,  2.49it/s, loss=3.3611]

  [Epoch 5 | Batch 70/102] Loss: 3.3968 | Running Avg: 3.3611


Training:  78%|██████████████▉    | 80/102 [00:31<00:08,  2.48it/s, loss=3.3638]

  [Epoch 5 | Batch 80/102] Loss: 3.3850 | Running Avg: 3.3638


Training:  88%|████████████████▊  | 90/102 [00:35<00:04,  2.48it/s, loss=3.3551]

  [Epoch 5 | Batch 90/102] Loss: 3.2976 | Running Avg: 3.3551


Training:  98%|█████████████████▋| 100/102 [00:39<00:00,  2.56it/s, loss=3.3440]

  [Epoch 5 | Batch 100/102] Loss: 3.2123 | Running Avg: 3.3440


Training: 100%|██████████████████| 102/102 [00:40<00:00,  2.52it/s, loss=3.3440]


End of Epoch 5: Val loss 3.3522 | BLEU4 0.0391 | PPL 28.57

=== Epoch 6/12 ===


Training:   1%|▏                   | 1/102 [00:00<00:41,  2.44it/s, loss=3.2002]

  [Epoch 6 | Batch 1/102] Loss: 3.2002 | Running Avg: 3.2002


Training:  10%|█▊                 | 10/102 [00:03<00:35,  2.56it/s, loss=3.2979]

  [Epoch 6 | Batch 10/102] Loss: 3.1405 | Running Avg: 3.2979


Training:  20%|███▋               | 20/102 [00:07<00:31,  2.60it/s, loss=3.2747]

  [Epoch 6 | Batch 20/102] Loss: 3.3313 | Running Avg: 3.2747


Training:  29%|█████▌             | 30/102 [00:11<00:27,  2.63it/s, loss=3.2703]

  [Epoch 6 | Batch 30/102] Loss: 3.2264 | Running Avg: 3.2703


Training:  39%|███████▍           | 40/102 [00:15<00:24,  2.56it/s, loss=3.2571]

  [Epoch 6 | Batch 40/102] Loss: 3.0995 | Running Avg: 3.2571


Training:  49%|█████████▎         | 50/102 [00:19<00:20,  2.50it/s, loss=3.2747]

  [Epoch 6 | Batch 50/102] Loss: 3.2133 | Running Avg: 3.2747


Training:  59%|███████████▏       | 60/102 [00:23<00:16,  2.52it/s, loss=3.2683]

  [Epoch 6 | Batch 60/102] Loss: 3.3507 | Running Avg: 3.2683


Training:  69%|█████████████      | 70/102 [00:27<00:12,  2.52it/s, loss=3.2662]

  [Epoch 6 | Batch 70/102] Loss: 3.1943 | Running Avg: 3.2662


Training:  78%|██████████████▉    | 80/102 [00:31<00:08,  2.55it/s, loss=3.2643]

  [Epoch 6 | Batch 80/102] Loss: 3.2940 | Running Avg: 3.2643


Training:  88%|████████████████▊  | 90/102 [00:35<00:04,  2.46it/s, loss=3.2596]

  [Epoch 6 | Batch 90/102] Loss: 3.3877 | Running Avg: 3.2596


Training:  98%|█████████████████▋| 100/102 [00:39<00:00,  2.55it/s, loss=3.2542]

  [Epoch 6 | Batch 100/102] Loss: 3.2737 | Running Avg: 3.2542


Training: 100%|██████████████████| 102/102 [00:39<00:00,  2.55it/s, loss=3.2536]


End of Epoch 6: Val loss 3.2617 | BLEU4 0.0505 | PPL 26.09

=== Epoch 7/12 ===


Training:   1%|▏                   | 1/102 [00:00<00:39,  2.58it/s, loss=2.8940]

  [Epoch 7 | Batch 1/102] Loss: 2.8940 | Running Avg: 2.8940


Training:  10%|█▊                 | 10/102 [00:03<00:36,  2.52it/s, loss=3.1512]

  [Epoch 7 | Batch 10/102] Loss: 3.0166 | Running Avg: 3.1512


Training:  20%|███▋               | 20/102 [00:07<00:32,  2.54it/s, loss=3.1723]

  [Epoch 7 | Batch 20/102] Loss: 3.2056 | Running Avg: 3.1723


Training:  29%|█████▌             | 30/102 [00:11<00:29,  2.48it/s, loss=3.1710]

  [Epoch 7 | Batch 30/102] Loss: 3.3074 | Running Avg: 3.1710


Training:  39%|███████▍           | 40/102 [00:15<00:24,  2.52it/s, loss=3.1725]

  [Epoch 7 | Batch 40/102] Loss: 3.2005 | Running Avg: 3.1725


Training:  49%|█████████▎         | 50/102 [00:19<00:20,  2.49it/s, loss=3.1729]

  [Epoch 7 | Batch 50/102] Loss: 3.1904 | Running Avg: 3.1729


Training:  59%|███████████▏       | 60/102 [00:23<00:16,  2.51it/s, loss=3.1649]

  [Epoch 7 | Batch 60/102] Loss: 3.1339 | Running Avg: 3.1649


Training:  69%|█████████████      | 70/102 [00:27<00:12,  2.51it/s, loss=3.1638]

  [Epoch 7 | Batch 70/102] Loss: 3.2508 | Running Avg: 3.1638


Training:  78%|██████████████▉    | 80/102 [00:31<00:08,  2.57it/s, loss=3.1691]

  [Epoch 7 | Batch 80/102] Loss: 3.1659 | Running Avg: 3.1691


Training:  88%|████████████████▊  | 90/102 [00:35<00:04,  2.53it/s, loss=3.1677]

  [Epoch 7 | Batch 90/102] Loss: 3.2558 | Running Avg: 3.1677


Training:  98%|█████████████████▋| 100/102 [00:39<00:00,  2.48it/s, loss=3.1644]

  [Epoch 7 | Batch 100/102] Loss: 2.9845 | Running Avg: 3.1644


Training: 100%|██████████████████| 102/102 [00:40<00:00,  2.53it/s, loss=3.1647]


End of Epoch 7: Val loss 3.2018 | BLEU4 0.0485 | PPL 24.58

=== Epoch 8/12 ===


Training:   1%|▏                   | 1/102 [00:00<00:40,  2.52it/s, loss=3.1193]

  [Epoch 8 | Batch 1/102] Loss: 3.1193 | Running Avg: 3.1193


Training:  10%|█▊                 | 10/102 [00:04<00:36,  2.51it/s, loss=3.0723]

  [Epoch 8 | Batch 10/102] Loss: 3.0346 | Running Avg: 3.0723


Training:  20%|███▋               | 20/102 [00:07<00:33,  2.48it/s, loss=3.1075]

  [Epoch 8 | Batch 20/102] Loss: 3.0115 | Running Avg: 3.1075


Training:  29%|█████▌             | 30/102 [00:11<00:28,  2.51it/s, loss=3.1072]

  [Epoch 8 | Batch 30/102] Loss: 3.1228 | Running Avg: 3.1072


Training:  39%|███████▍           | 40/102 [00:15<00:24,  2.50it/s, loss=3.1038]

  [Epoch 8 | Batch 40/102] Loss: 3.0543 | Running Avg: 3.1038


Training:  49%|█████████▎         | 50/102 [00:19<00:20,  2.53it/s, loss=3.1026]

  [Epoch 8 | Batch 50/102] Loss: 2.9450 | Running Avg: 3.1026


Training:  59%|███████████▏       | 60/102 [00:24<00:17,  2.40it/s, loss=3.1152]

  [Epoch 8 | Batch 60/102] Loss: 3.2507 | Running Avg: 3.1152


Training:  69%|█████████████      | 70/102 [00:28<00:12,  2.52it/s, loss=3.1139]

  [Epoch 8 | Batch 70/102] Loss: 3.1347 | Running Avg: 3.1139


Training:  78%|██████████████▉    | 80/102 [00:31<00:08,  2.51it/s, loss=3.1044]

  [Epoch 8 | Batch 80/102] Loss: 3.0376 | Running Avg: 3.1044


Training:  88%|████████████████▊  | 90/102 [00:35<00:04,  2.57it/s, loss=3.1087]

  [Epoch 8 | Batch 90/102] Loss: 3.3125 | Running Avg: 3.1087


Training:  98%|█████████████████▋| 100/102 [00:39<00:00,  2.56it/s, loss=3.1102]

  [Epoch 8 | Batch 100/102] Loss: 3.4386 | Running Avg: 3.1102


Training: 100%|██████████████████| 102/102 [00:40<00:00,  2.53it/s, loss=3.1092]


End of Epoch 8: Val loss 3.1278 | BLEU4 0.0514 | PPL 22.82

=== Epoch 9/12 ===


Training:   1%|▏                   | 1/102 [00:00<00:45,  2.22it/s, loss=3.1618]

  [Epoch 9 | Batch 1/102] Loss: 3.1618 | Running Avg: 3.1618


Training:  10%|█▊                 | 10/102 [00:04<00:36,  2.49it/s, loss=3.0603]

  [Epoch 9 | Batch 10/102] Loss: 3.1859 | Running Avg: 3.0603


Training:  20%|███▋               | 20/102 [00:08<00:33,  2.42it/s, loss=3.0423]

  [Epoch 9 | Batch 20/102] Loss: 3.2167 | Running Avg: 3.0423


Training:  29%|█████▌             | 30/102 [00:12<00:28,  2.52it/s, loss=3.0241]

  [Epoch 9 | Batch 30/102] Loss: 2.9616 | Running Avg: 3.0241


Training:  39%|███████▍           | 40/102 [00:16<00:25,  2.47it/s, loss=3.0284]

  [Epoch 9 | Batch 40/102] Loss: 3.2114 | Running Avg: 3.0284


Training:  49%|█████████▎         | 50/102 [00:20<00:20,  2.50it/s, loss=3.0308]

  [Epoch 9 | Batch 50/102] Loss: 2.9776 | Running Avg: 3.0308


Training:  59%|███████████▏       | 60/102 [00:24<00:17,  2.45it/s, loss=3.0295]

  [Epoch 9 | Batch 60/102] Loss: 3.0259 | Running Avg: 3.0295


Training:  69%|█████████████      | 70/102 [00:28<00:12,  2.52it/s, loss=3.0351]

  [Epoch 9 | Batch 70/102] Loss: 3.1070 | Running Avg: 3.0351


Training:  78%|██████████████▉    | 80/102 [00:32<00:08,  2.52it/s, loss=3.0406]

  [Epoch 9 | Batch 80/102] Loss: 2.9553 | Running Avg: 3.0406


Training:  88%|████████████████▊  | 90/102 [00:36<00:04,  2.50it/s, loss=3.0375]

  [Epoch 9 | Batch 90/102] Loss: 3.1133 | Running Avg: 3.0375


Training:  98%|█████████████████▋| 100/102 [00:40<00:00,  2.41it/s, loss=3.0300]

  [Epoch 9 | Batch 100/102] Loss: 2.8439 | Running Avg: 3.0300


Training: 100%|██████████████████| 102/102 [00:40<00:00,  2.50it/s, loss=3.0310]


End of Epoch 9: Val loss 3.0694 | BLEU4 0.0514 | PPL 21.53

=== Epoch 10/12 ===


Training:   1%|▏                   | 1/102 [00:00<00:39,  2.56it/s, loss=3.0183]

  [Epoch 10 | Batch 1/102] Loss: 3.0183 | Running Avg: 3.0183


Training:  10%|█▊                 | 10/102 [00:04<00:36,  2.49it/s, loss=2.9980]

  [Epoch 10 | Batch 10/102] Loss: 3.1161 | Running Avg: 2.9980


Training:  20%|███▋               | 20/102 [00:08<00:33,  2.44it/s, loss=3.0090]

  [Epoch 10 | Batch 20/102] Loss: 3.1196 | Running Avg: 3.0090


Training:  29%|█████▌             | 30/102 [00:12<00:29,  2.43it/s, loss=2.9896]

  [Epoch 10 | Batch 30/102] Loss: 3.0048 | Running Avg: 2.9896


Training:  39%|███████▍           | 40/102 [00:16<00:24,  2.50it/s, loss=2.9969]

  [Epoch 10 | Batch 40/102] Loss: 2.9627 | Running Avg: 2.9969


Training:  49%|█████████▎         | 50/102 [00:20<00:20,  2.53it/s, loss=3.0027]

  [Epoch 10 | Batch 50/102] Loss: 3.0380 | Running Avg: 3.0027


Training:  59%|███████████▏       | 60/102 [00:24<00:16,  2.48it/s, loss=2.9991]

  [Epoch 10 | Batch 60/102] Loss: 3.0322 | Running Avg: 2.9991


Training:  69%|█████████████      | 70/102 [00:28<00:12,  2.49it/s, loss=2.9969]

  [Epoch 10 | Batch 70/102] Loss: 2.9550 | Running Avg: 2.9969


Training:  78%|██████████████▉    | 80/102 [00:32<00:08,  2.48it/s, loss=2.9922]

  [Epoch 10 | Batch 80/102] Loss: 2.8746 | Running Avg: 2.9922


Training:  88%|████████████████▊  | 90/102 [00:36<00:04,  2.51it/s, loss=2.9926]

  [Epoch 10 | Batch 90/102] Loss: 3.0959 | Running Avg: 2.9926


Training:  98%|█████████████████▋| 100/102 [00:40<00:00,  2.47it/s, loss=2.9901]

  [Epoch 10 | Batch 100/102] Loss: 2.9959 | Running Avg: 2.9901


Training: 100%|██████████████████| 102/102 [00:40<00:00,  2.50it/s, loss=2.9898]


End of Epoch 10: Val loss 3.0603 | BLEU4 0.0481 | PPL 21.33

=== Epoch 11/12 ===


Training:   1%|▏                   | 1/102 [00:00<00:40,  2.47it/s, loss=3.0383]

  [Epoch 11 | Batch 1/102] Loss: 3.0383 | Running Avg: 3.0383


Training:  10%|█▊                 | 10/102 [00:04<00:37,  2.47it/s, loss=2.9433]

  [Epoch 11 | Batch 10/102] Loss: 2.8808 | Running Avg: 2.9433


Training:  20%|███▋               | 20/102 [00:08<00:32,  2.52it/s, loss=2.9299]

  [Epoch 11 | Batch 20/102] Loss: 2.8584 | Running Avg: 2.9299


Training:  29%|█████▌             | 30/102 [00:11<00:28,  2.51it/s, loss=2.9284]

  [Epoch 11 | Batch 30/102] Loss: 2.9897 | Running Avg: 2.9284


Training:  39%|███████▍           | 40/102 [00:15<00:24,  2.55it/s, loss=2.9400]

  [Epoch 11 | Batch 40/102] Loss: 3.1202 | Running Avg: 2.9400


Training:  49%|█████████▎         | 50/102 [00:19<00:21,  2.46it/s, loss=2.9377]

  [Epoch 11 | Batch 50/102] Loss: 2.8239 | Running Avg: 2.9377


Training:  59%|███████████▏       | 60/102 [00:24<00:17,  2.44it/s, loss=2.9547]

  [Epoch 11 | Batch 60/102] Loss: 3.1331 | Running Avg: 2.9547


Training:  69%|█████████████      | 70/102 [00:28<00:14,  2.25it/s, loss=2.9566]

  [Epoch 11 | Batch 70/102] Loss: 2.8860 | Running Avg: 2.9566


Training:  78%|██████████████▉    | 80/102 [00:32<00:09,  2.36it/s, loss=2.9604]

  [Epoch 11 | Batch 80/102] Loss: 3.0205 | Running Avg: 2.9604


Training:  88%|████████████████▊  | 90/102 [00:36<00:04,  2.49it/s, loss=2.9581]

  [Epoch 11 | Batch 90/102] Loss: 2.8759 | Running Avg: 2.9581


Training:  98%|█████████████████▋| 100/102 [00:40<00:00,  2.47it/s, loss=2.9546]

  [Epoch 11 | Batch 100/102] Loss: 2.8944 | Running Avg: 2.9546


Training: 100%|██████████████████| 102/102 [00:41<00:00,  2.47it/s, loss=2.9536]


End of Epoch 11: Val loss 2.9915 | BLEU4 0.0366 | PPL 19.92

=== Epoch 12/12 ===


Training:   1%|▏                   | 1/102 [00:00<00:41,  2.41it/s, loss=2.9326]

  [Epoch 12 | Batch 1/102] Loss: 2.9326 | Running Avg: 2.9326


Training:  10%|█▊                 | 10/102 [00:04<00:36,  2.49it/s, loss=2.9017]

  [Epoch 12 | Batch 10/102] Loss: 3.0069 | Running Avg: 2.9017


Training:  20%|███▋               | 20/102 [00:08<00:32,  2.49it/s, loss=2.8904]

  [Epoch 12 | Batch 20/102] Loss: 2.7609 | Running Avg: 2.8904


Training:  29%|█████▌             | 30/102 [00:12<00:29,  2.45it/s, loss=2.8930]

  [Epoch 12 | Batch 30/102] Loss: 2.9255 | Running Avg: 2.8930


Training:  39%|███████▍           | 40/102 [00:16<00:25,  2.41it/s, loss=2.8864]

  [Epoch 12 | Batch 40/102] Loss: 2.8397 | Running Avg: 2.8864


Training:  49%|█████████▎         | 50/102 [00:20<00:21,  2.45it/s, loss=2.8846]

  [Epoch 12 | Batch 50/102] Loss: 2.9755 | Running Avg: 2.8846


Training:  59%|███████████▏       | 60/102 [00:24<00:17,  2.47it/s, loss=2.8880]

  [Epoch 12 | Batch 60/102] Loss: 3.0080 | Running Avg: 2.8880


Training:  69%|█████████████      | 70/102 [00:28<00:12,  2.53it/s, loss=2.8919]

  [Epoch 12 | Batch 70/102] Loss: 2.7469 | Running Avg: 2.8919


Training:  78%|██████████████▉    | 80/102 [00:32<00:08,  2.49it/s, loss=2.8918]

  [Epoch 12 | Batch 80/102] Loss: 2.9285 | Running Avg: 2.8918


Training:  88%|████████████████▊  | 90/102 [00:36<00:04,  2.42it/s, loss=2.8862]

  [Epoch 12 | Batch 90/102] Loss: 3.0035 | Running Avg: 2.8862


Training:  98%|█████████████████▋| 100/102 [00:40<00:00,  2.43it/s, loss=2.8907]

  [Epoch 12 | Batch 100/102] Loss: 3.0938 | Running Avg: 2.8907


Training: 100%|██████████████████| 102/102 [00:41<00:00,  2.48it/s, loss=2.8908]


End of Epoch 12: Val loss 2.9781 | BLEU4 0.0576 | PPL 19.65


In [23]:
# ---- Load best checkpoint and evaluate on test set ----
ckpt = torch.load(save_path, map_location="cpu")
model.load_state_dict(ckpt["model"])
model = model.to(DEVICE).eval()

t_loss, t_bleu, t_ppl = evaluate(model, test_loader, criterion, itos, DEVICE, cfg.MAX_LEN)
print(f"\n=== TEST RESULTS ===")
print(f"Loss: {t_loss:.4f} | BLEU4: {t_bleu:.4f} | Perplexity: {t_ppl:.2f}")

# ---- Show a few generated samples ----
model.eval()
with torch.no_grad():
    for images, caps, lengths, names in test_loader:
        images, caps = images.to(DEVICE), caps.to(DEVICE)
        gen = model.generate(images, max_len=cfg.MAX_LEN)

        for i in range(min(5, images.size(0))):
            pred_caption = ids_to_text(gen[i], itos)
            ref_caption  = ids_to_text(caps[i], itos)
            print(f"[{names[i]}]\n  pred: {pred_caption}\n  ref:  {ref_caption}\n")
        break


  ckpt = torch.load(save_path, map_location="cpu")



=== TEST RESULTS ===
Loss: 3.5325 | BLEU4: 0.0705 | Perplexity: 34.21
[3153067758_53f003b1df.jpg]
  pred: a man in a red shirt is running in the grass .
  ref:  a woman sitting on a bus with a paper bag hanging on a carrier

[3449170348_34dac4a380.jpg]
  pred: a man in a red shirt is running in the grass .
  ref:  a girl dances on a sidewalk .

[3626964430_cb5c7e5acc.jpg]
  pred: a man in a red shirt is running in the grass .
  ref:  people playing cricket in the park , pine trees in the back .

[2286823363_7d554ea740.jpg]
  pred: a man in a red shirt is running in the grass .
  ref:  a young boy jumping from one chair to another in his house

[241347204_007d83e252.jpg]
  pred: a man in a red shirt is running in the grass .
  ref:  football players gather around the <unk> .



In [25]:
# ---- Reload model on CPU for export ----
export_model = CaptionNet(
    vocab_size=vocab_size,
    embed_dim=cfg.EMBED_DIM,
    hidden_dim=cfg.HIDDEN_DIM,
    num_layers=cfg.NUM_LAYERS,
    dropout=cfg.DROPOUT,
    freeze_cnn=cfg.FREEZE_CNN,
).cpu()
export_model.load_state_dict(torch.load(save_path, map_location="cpu")["model"])
export_model.eval()

class InferenceWrapper(nn.Module):
    def __init__(self, net: CaptionNet, max_len: int):
        super().__init__()
        self.net = net
        self.max_len = max_len
    def forward(self, images: torch.Tensor):
        return self.net.generate(images, max_len=self.max_len)

wrapper = InferenceWrapper(export_model, cfg.MAX_LEN).cpu().eval()
dummy = torch.randn(1,3,cfg.IMG_SIZE,cfg.IMG_SIZE, device="cpu")

# ---- TorchScript export ----
ts_path = os.path.join(cfg.OUTPUT_DIR, "model_ts.pt")
ts = torch.jit.trace(wrapper, dummy)
torch.jit.save(ts, ts_path)
print(f"Saved TorchScript: {ts_path}")

# ---- ONNX export ----
onnx_path = os.path.join(cfg.OUTPUT_DIR, "model.onnx")
try:
    torch.onnx.export(
        wrapper, dummy, onnx_path,
        input_names=["images"], output_names=["token_ids"],
        opset_version=14,
        dynamic_axes={"images": {0: "batch"}, "token_ids": {0: "batch"}}
    )
    print(f"Saved ONNX: {onnx_path}")
except Exception as e:
    print("ONNX export skipped:", e)

# ---- Write artifact readme ----
with open(os.path.join(cfg.OUTPUT_DIR, "artifact_README.txt"), "w") as f:
    f.write("Input: float32 (B,3,224,224) normalized to ImageNet mean/std.\n")
    f.write("Output: int64 token IDs (B,T). Decode with vocab.json (itos).\n")

print("Artifacts:", os.listdir(cfg.OUTPUT_DIR))


  export_model.load_state_dict(torch.load(save_path, map_location="cpu")["model"])


Saved TorchScript: artifacts/model_ts.pt




ONNX export skipped: Module onnx is not installed!
Artifacts: ['model_ts.pt', 'config.json', 'weights.pt', 'vocab.json', 'artifact_README.txt']
