In [16]:
# ───────────────────────────── 00‑setup ─────────────────────────────
%pip install -r requirements.txt --quiet
import sys, pathlib, os
sys.path.append(str(pathlib.Path.cwd() / "code"))

PROJECT_DIR   = pathlib.Path().cwd()
DATA_DIR      = PROJECT_DIR / "data"         
NLTK_DIR      = DATA_DIR / "nltk_data"
KAGGLE_DIR    = DATA_DIR / "kaggle_cache"

os.environ["NLTK_DATA"] = str(NLTK_DIR)
os.environ["KAGGLEHUB_CACHE"] = str(KAGGLE_DIR)

NLTK_DIR.mkdir(parents=True, exist_ok=True)
KAGGLE_DIR.mkdir(parents=True, exist_ok=True)

import nltk
if str(NLTK_DIR) not in nltk.data.path:
    nltk.data.path.append(str(NLTK_DIR))

Note: you may need to restart the kernel to use updated packages.


In [10]:
# ───────────────────────────── 01‑imports ────────────────────────────
from models import ShowAttendTell
from datasets import get_flickr8k_splits, collate_fn
from training import TrainingConfig, train_and_evaluate
import torch, torchvision.transforms as T

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

Device: cpu


In [19]:
# ───────────────────────────── 02‑data ──────────────────────────────
import kagglehub, nltk

# Download the dataset through kagglehub; the returned location is used as the
# dataset root by the split builder below.
data_dir = kagglehub.dataset_download("ashish2001/original-flickr8k-dataset")

# NLTK resources are stored in the project-local NLTK folder.
for nltk_resource in ("punkt_tab", "wordnet"):
    nltk.download(nltk_resource, download_dir=str(NLTK_DIR), quiet=True)

# Per-channel normalisation constants (these match the commonly used ImageNet
# statistics).
channel_mean = [0.485,0.456,0.406]
channel_std = [0.229,0.224,0.225]

# Resize, take a random 224 crop, randomly flip, then tensorise and normalise.
# NOTE(review): this single transform (including the random crop / flip) is
# handed to get_flickr8k_splits for all splits — presumably it is applied to
# val/test as well; confirm, since evaluation usually wants a deterministic
# pipeline.
preprocessing_steps = [
    T.Resize((256,256)),
    T.RandomCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=channel_mean, std=channel_std),
]
transform = T.Compose(preprocessing_steps)

# Build the three dataset splits together with the vocabulary.
splits = get_flickr8k_splits(data_dir, transform)
train_ds, val_ds, test_ds, vocab = splits
print("Vocab size:", len(vocab))

[Info] Skipped 5 captions with missing images.
Vocabulary size: 8921
Train: 30000, Val: 5000, Test: 5000
Vocab size: 8921


In [20]:
# ───────────────────────────── 03‑model ─────────────────────────────
# Architecture hyper-parameters, gathered in one place so they are easy to tune.
model_kwargs = dict(
    backbone="resnet50",
    finetune=True,
    embed_dim=512,
    hidden_dim=512,
    attn_dim=512,
    dropout=0.2,
    use_double_attention=True,
    use_hard_attention=False,
)
model = ShowAttendTell(vocab_size=len(vocab), **model_kwargs)

# Label-smoothed cross-entropy; targets equal to 0 are excluded from the loss
# (presumably index 0 is the padding token — confirm against the vocabulary).
criterion = torch.nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)

# Short run: 5 epochs is a quick smoke‑test, not a full training schedule.
cfg = TrainingConfig(
    epochs=5,
    batch_size=64,
    lr=1e-4,
    save_every=2,
)

In [None]:
# ───────────────────────────── 04‑train ─────────────────────────────
# Run training + evaluation. Only the first returned item is kept and rebound
# to `model`; anything else the call returns is discarded. Artifacts are
# written under "results" via save_dir.
model, *_ = train_and_evaluate(
    model,
    train_ds,
    val_ds,
    criterion,
    cfg,
    len(vocab),
    save_dir="results",
)

In [None]:
# ───────────────────────────── 05‑visualise ─────────────────────────
from visualize import generate_sample_captions, visualize_attention_paper_style
import random, os

NUM_SAMPLES = 3   # how many test items to caption and plot
SAMPLE_SEED = 0   # fixed seed so Restart & Run All picks the same items

# Draw the indices from a dedicated, seeded generator instead of the global
# `random` state: the selection is reproducible and does not disturb other
# code that relies on the global RNG. The sample size is clamped so a test
# split with fewer than NUM_SAMPLES items does not raise ValueError.
rng = random.Random(SAMPLE_SEED)
n_samples = min(NUM_SAMPLES, len(test_ds))
indices = rng.sample(range(len(test_ds)), n_samples)

samples = generate_sample_captions(model, test_ds, vocab, device,
                                   indices, beam_size=8)
os.makedirs("results/vis", exist_ok=True)

# Folder holding the raw images; loop-invariant, so it is built once.
# The nested folder names are kept exactly as spelled in the original path
# (presumably mirroring the downloaded archive layout — do not "correct" them).
image_dir = os.path.join(data_dir, "Flickr8k_Dataset", "Flicker8k_Dataset")

for i, s in enumerate(samples, start=1):
    # Assumes test_ds is a Subset-like wrapper exposing `.dataset` and
    # `.indices` (TODO confirm): map the position inside the split back to the
    # underlying dataset to look up the image file name.
    base_idx = test_ds.indices[s['dataset_idx']]
    img_path = os.path.join(image_dir, test_ds.dataset.image_names[base_idx])
    visualize_attention_paper_style(
        img_path, s['prediction_indices'], s['attention_maps'],
        vocab, reference_caption=s['reference'],
        save_path=f"results/vis/sample_{i}.png")