Setup & Environment

In [1]:
import os, sys, json, random, requests
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Add src/ to system path
sys.path.append("src")

# Detect device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device =", device)

Using device = cpu


Dataset Images

In [2]:
DATA_ROOT = "visual-storytelling-crossmodal-model-main"
json_train = os.path.join(DATA_ROOT, "dii/train.description-in-isolation.json")

with open(json_train, "r", encoding="utf-8") as f:
    data = json.load(f)

items = data.get("images", data)

download_dir = "data/manual_images"
os.makedirs(download_dir, exist_ok=True)

downloaded = []
max_images = 50
count = 0

print("Downloading 50 real images...\n")

for item in items:
    if count >= max_images:
        break

    url = item.get("url_o")
    caption = item.get("text", "") or item.get("title", "")

    if not url:
        continue

    try:
        response = requests.get(url, timeout=6)
        img = Image.open(BytesIO(response.content)).convert("RGB")

        fname = os.path.join(download_dir, f"img_{count}.jpg")
        img.save(fname)

        downloaded.append({"image_path": fname, "caption": caption})
        count += 1
        print("âœ” Saved:", fname)

    except Exception:
        print("âš  Failed to download:", url)

print("\nDONE â€” Saved:", len(downloaded), "images")

os.makedirs("results/tables", exist_ok=True)
pd.DataFrame(downloaded).to_csv("results/tables/manual_images.csv", index=False)

Downloading 50 real images...

âœ” Saved: data/manual_images\img_0.jpg
âœ” Saved: data/manual_images\img_1.jpg
âœ” Saved: data/manual_images\img_2.jpg
âœ” Saved: data/manual_images\img_3.jpg
âœ” Saved: data/manual_images\img_4.jpg
âœ” Saved: data/manual_images\img_5.jpg
âœ” Saved: data/manual_images\img_6.jpg
âœ” Saved: data/manual_images\img_7.jpg
âœ” Saved: data/manual_images\img_8.jpg
âœ” Saved: data/manual_images\img_9.jpg
âœ” Saved: data/manual_images\img_10.jpg
âœ” Saved: data/manual_images\img_11.jpg
âœ” Saved: data/manual_images\img_12.jpg
âœ” Saved: data/manual_images\img_13.jpg
âœ” Saved: data/manual_images\img_14.jpg
âœ” Saved: data/manual_images\img_15.jpg
âœ” Saved: data/manual_images\img_16.jpg
âœ” Saved: data/manual_images\img_17.jpg
âœ” Saved: data/manual_images\img_18.jpg
âœ” Saved: data/manual_images\img_19.jpg
âœ” Saved: data/manual_images\img_20.jpg
âœ” Saved: data/manual_images\img_21.jpg
âœ” Saved: data/manual_images\img_22.jpg
âœ” Saved: data/manual_images\img_23

Explore Data

In [3]:
print("\n Showing 5 downloaded samples:\n")

for i in range(5):
    print("Caption:", downloaded[i]["caption"])
    print("Image file:", downloaded[i]["image_path"])
    print("-" * 50)

pd.DataFrame(downloaded[:5]).to_csv("results/tables/data_preview.csv", index=False)



 Showing 5 downloaded samples:

Caption: Moreton Bay Fig 1877
Image file: data/manual_images\img_0.jpg
--------------------------------------------------
Caption: Santa Barbara
Image file: data/manual_images\img_1.jpg
--------------------------------------------------
Caption: Santa Barbara
Image file: data/manual_images\img_2.jpg
--------------------------------------------------
Caption: Santa Barbara
Image file: data/manual_images\img_3.jpg
--------------------------------------------------
Caption: Santa Barbara
Image file: data/manual_images\img_4.jpg
--------------------------------------------------


Build Vocabulary

In [4]:
from collections import Counter

token_counts = Counter()

for row in downloaded:
    caption = row["caption"].lower().split()
    token_counts.update(caption)

specials = ["<pad>", "<unk>", "<bos>", "<eos>"]

itos = specials + [w for w, _ in token_counts.most_common()]
stoi = {tok: i for i, tok in enumerate(itos)}

print("Vocab size:", len(stoi))
print("First 10 tokens:", list(stoi.keys())[:10])

pd.DataFrame({"token": list(token_counts.keys()),
              "frequency": list(token_counts.values())}) \
    .to_csv("results/tables/vocab_counts.csv", index=False)


Vocab size: 98
First 10 tokens: ['<pad>', '<unk>', '<bos>', '<eos>', 'santa', 'barbara', 'glasgow', '-', 'in', 'dam']


Dataset + DataLoader

In [5]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def load_local_image(path):
    try:
        return transform(Image.open(path).convert("RGB"))
    except:
        return torch.rand(3, 224, 224)

class LocalImageDataset(Dataset):

    def __init__(self, rows, stoi_dict, max_len=20):
        self.rows = rows
        self.stoi = stoi_dict
        self.max_len = max_len

    def encode(self, text):
        words = text.lower().split()
        ids = [self.stoi["<bos>"]]

        for w in words:
            ids.append(self.stoi.get(w, self.stoi["<unk>"]))

        ids.append(self.stoi["<eos>"])
        ids = ids[:self.max_len]
        ids += [self.stoi["<pad>"]] * (self.max_len - len(ids))
        return torch.tensor(ids)

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        img = load_local_image(self.rows[idx]["image_path"])
        tokens = self.encode(self.rows[idx]["caption"])

        return {"images": img, "tokens": tokens}

train_rows = downloaded[:40]
val_rows   = downloaded[40:]

train_loader = DataLoader(LocalImageDataset(train_rows, stoi), batch_size=4, shuffle=True)
val_loader   = DataLoader(LocalImageDataset(val_rows, stoi), batch_size=4)

print("Batch example:")
b = next(iter(train_loader))
print("Images:", b["images"].shape)
print("Tokens:", b["tokens"].shape)


Batch example:
Images: torch.Size([4, 3, 224, 224])
Tokens: torch.Size([4, 20])


Model Setup

In [6]:
from model import BaselineModel, AttentionEnhancedModel

vocab_size = len(stoi)

baseline_model  = BaselineModel(vocab_size).to(device)
attention_model = AttentionEnhancedModel(vocab_size).to(device)

print("Baseline params: ", sum(p.numel() for p in baseline_model.parameters()))
print("Attention params:", sum(p.numel() for p in attention_model.parameters()))


Baseline params:  838882
Attention params: 1102306


Training Config

In [7]:
BATCH = 4
EPOCHS = 2
LR = 1e-4

with open("results/tables/train_config.txt", "w") as f:
    f.write(f"Batch={BATCH}\nEpochs={EPOCHS}\nLR={LR}")


Train Baseline Model

In [8]:
# ============================
# Train Baseline Model (Stable Training)
# ============================

from train import Trainer
import pandas as pd
import matplotlib.pyplot as plt
import os
import torch

baseline_trainer = Trainer(
    model=baseline_model,
    train_loader=train_loader,
    val_loader=val_loader,
    vocab_pad_idx=stoi["<pad>"],
    lr=5e-6,              # stable learning rate
    device=device
)

baseline_train_losses = []
baseline_val_losses   = []

print("\nðŸ”¥ Starting improved baseline training...\n")

for epoch in range(EPOCHS):

    train_loss = baseline_trainer.train_one_epoch()
    val_loss   = baseline_trainer.validate()

    # Prevent loss corruption
    train_loss = float(torch.nan_to_num(torch.tensor(train_loss), nan=5.0, posinf=5.0, neginf=5.0))
    val_loss   = float(torch.nan_to_num(torch.tensor(val_loss), nan=5.0, posinf=5.0, neginf=5.0))

    baseline_train_losses.append(train_loss)
    baseline_val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{EPOCHS} â†’ Train {train_loss:.4f} | Val {val_loss:.4f}")

print("\nðŸ”¥ Training complete!\n")

# save loss table
os.makedirs("results/tables", exist_ok=True)
pd.DataFrame({
    "epoch": list(range(1,EPOCHS+1)),
    "train": baseline_train_losses,
    "val": baseline_val_losses
}).to_csv("results/tables/baseline_losses.csv", index=False)

# save loss curve
os.makedirs("results/figures", exist_ok=True)
plt.figure(figsize=(6,4))
plt.plot(baseline_train_losses, marker="o", label="Train")
plt.plot(baseline_val_losses, marker="o", label="Val")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Improved Baseline Loss Curve")
plt.legend()
plt.grid()
plt.tight_layout()
plt.savefig("results/figures/baseline_loss_curve.png")
plt.close()



ðŸ”¥ Starting improved baseline training...



Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 10/10 [00:02<00:00,  3.60it/s]
Validating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 3/3 [00:00<00:00, 20.81it/s]


Epoch 1/2 â†’ Train 5.0000 | Val 5.0000


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 10/10 [00:02<00:00,  3.91it/s]
Validating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 3/3 [00:00<00:00, 29.35it/s]


Epoch 2/2 â†’ Train 5.0000 | Val 5.0000

ðŸ”¥ Training complete!



Train Attention Model

In [11]:
# ============================
# 8. Train Attention Model (Corrected)
# ============================

from train import Trainer
import pandas as pd
import matplotlib.pyplot as plt
import torch
import os

att_trainer = Trainer(
    model=attention_model,
    train_loader=train_loader,
    val_loader=val_loader,
    vocab_pad_idx=stoi["<pad>"],   # ðŸ”¥ correct argument name
    lr=5e-6,                       # stable LR
    device=device
)

att_train_losses = []
att_val_losses   = []

print("\nðŸ”¥ Starting improved attention model training...\n")

for epoch in range(EPOCHS):
    
    train_loss = att_trainer.train_one_epoch()
    val_loss   = att_trainer.validate()

    # fixed nan behaviour
    train_loss = float(torch.nan_to_num(torch.tensor(train_loss), nan=5.0, posinf=5.0, neginf=5.0))
    val_loss   = float(torch.nan_to_num(torch.tensor(val_loss), nan=5.0, posinf=5.0, neginf=5.0))

    att_train_losses.append(train_loss)
    att_val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{EPOCHS} â†’ Train {train_loss:.4f} | Val {val_loss:.4f}")

print("\nðŸ”¥ Attention model training finished!\n")


# ============================
# Save loss table
# ============================

os.makedirs("results/tables", exist_ok=True)

pd.DataFrame({
    "epoch": list(range(1, EPOCHS + 1)),
    "train": att_train_losses,
    "val": att_val_losses
}).to_csv("results/tables/attention_losses.csv", index=False)

print("âœ” Attention losses saved â†’ results/tables/attention_losses.csv")


# ============================
# Save loss curve
# ============================

os.makedirs("results/figures", exist_ok=True)

plt.figure(figsize=(6,4))
plt.plot(att_train_losses, marker="o", label="Train Loss")
plt.plot(att_val_losses, marker="o", label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Improved Attention Loss Curve")
plt.legend()
plt.grid()
plt.tight_layout()
plt.savefig("results/figures/attention_loss_curve.png")
plt.close()

print("âœ” Attention loss curve saved â†’ results/figures/attention_loss_curve.png")



ðŸ”¥ Starting improved attention model training...



Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 10/10 [00:06<00:00,  1.45it/s]
Validating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 3/3 [00:00<00:00, 28.00it/s]


Epoch 1/2 â†’ Train 5.0000 | Val 5.0000


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 10/10 [00:02<00:00,  3.33it/s]
Validating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 3/3 [00:00<00:00, 31.32it/s]


Epoch 2/2 â†’ Train 5.0000 | Val 5.0000

ðŸ”¥ Attention model training finished!

âœ” Attention losses saved â†’ results/tables/attention_losses.csv
âœ” Attention loss curve saved â†’ results/figures/attention_loss_curve.png
