In [1]:
import torch

print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))


Torch version: 2.9.0+cu126
CUDA available: True
GPU: Tesla T4


In [2]:
!pip install transformers torchvision einops



In [3]:
! pip install datasets



In [4]:
from datasets import load_dataset

dataset = load_dataset("daniel3303/StoryReasoning")
print(dataset)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00002.parquet:   0%|          | 0.00/327M [00:00<?, ?B/s]

data/train-00001-of-00002.parquet:   0%|          | 0.00/331M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/115M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3552 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/626 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['story_id', 'images', 'frame_count', 'chain_of_thought', 'story'],
        num_rows: 3552
    })
    test: Dataset({
        features: ['story_id', 'images', 'frame_count', 'chain_of_thought', 'story'],
        num_rows: 626
    })
})


In [5]:
print("A sample story entry:\n")

sample = dataset["train"][0]
print("Story ID:", sample["story_id"])
print("Frame count:", sample["frame_count"])
print("Images (list of URLs):", sample["images"])
print("Text:", sample["story"])
print("Chain-of-thought:", sample["chain_of_thought"])

A sample story entry:

Story ID: 3920
Frame count: 17
Images (list of URLs): [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7D9AA5BA24E0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7D9AA5BA2810>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7D9AA5BA2900>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7D9AA5BA29F0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7D9AA5BA2990>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7D9AA5BA2BD0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7D9AA5BA2B40>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7D9AA5BA2DB0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7D9AA5BA2F60>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7D9AA5BA30B0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7D9AA5BA31D0>, <PIL.Jpeg

## Dataset Pre-Processing

In [6]:
# IMPORTANT: limit dataset size to avoid RAM crash
small_train_dataset = dataset["train"].select(range(200))
print("Using", len(small_train_dataset), "stories for training")


Using 200 stories for training


In [7]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import re

def clean_story_text(text):
    text = re.sub(r"<.*?>", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text


class StoryReasoningDataset(Dataset):
    def __init__(self, hf_dataset, K=4):
        self.dataset = hf_dataset
        self.K = K
        self.transform = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor()
        ])

        # build index map (story_idx, window_start)
        self.index_map = []
        for i, ex in enumerate(self.dataset):
            frame_count = ex["frame_count"]
            for j in range(frame_count - K):
                self.index_map.append((i, j))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        story_idx, start = self.index_map[idx]
        example = self.dataset[story_idx]

        images = example["images"]
        story_text = clean_story_text(example["story"])

        input_images = images[start:start+self.K]
        target_image = images[start+self.K]

        imgs = [self.transform(img) for img in input_images]
        tgt_img = self.transform(target_image)

        return (
            torch.stack(imgs),     # (K, 3, 224, 224)
            story_text,
            tgt_img,
            story_text
        )


In [8]:
from torch.utils.data import DataLoader

train_dataset = StoryReasoningDataset(small_train_dataset, K=4)

train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

print("Total training samples:", len(train_dataset))

# Test batch
images, text, tgt_img, tgt_text = next(iter(train_loader))
print("Images:", images.shape)
print("Target image:", tgt_img.shape)
print("Text sample:", text[0][:150])


Total training samples: 1742
Images: torch.Size([2, 4, 3, 224, 224])
Target image: torch.Size([2, 3, 224, 224])
Text sample: The day was overcast, but Mr. Thompson felt the weight of his mission as he approached the quaint house in the suburban area. Adjusting his tie, he st


### Image Encoder (CNN with Transfer Learning)

#### Importing Required Libraries

In [9]:
import torch
import torch.nn as nn
from torchvision import models

#### Define Image Encoder (ResNet18)

In [10]:
class ImageEncoder(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()

        # Load pretrained ResNet18
        resnet = models.resnet18(pretrained=True)

        # Remove classification head
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Project CNN features into shared embedding space
        self.fc = nn.Linear(512, embed_dim)

    def forward(self, x):
        """
        x shape: (B, K, 3, 224, 224)
        """
        B, K, C, H, W = x.shape

        # Merge batch and time for CNN processing
        x = x.view(B * K, C, H, W)

        feats = self.backbone(x)           # (B*K, 512, 1, 1)
        feats = feats.squeeze(-1).squeeze(-1)  # (B*K, 512)

        feats = self.fc(feats)             # (B*K, embed_dim)

        # Restore temporal dimension
        feats = feats.view(B, K, -1)       # (B, K, embed_dim)

        return feats


In [11]:
# Create encoder
image_encoder = ImageEncoder(embed_dim=512)

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
image_encoder = image_encoder.to(device)

# Get one batch from DataLoader
images, text, tgt_img, tgt_text = next(iter(train_loader))
images = images.to(device)

# Forward pass
with torch.no_grad():
    image_features = image_encoder(images)

print("Image features shape:", image_features.shape)




Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 105MB/s]


Image features shape: torch.Size([2, 4, 512])


### Text Encoder (Transformer / BERT)

In [12]:
from transformers import BertTokenizer, BertModel

#### Load Tokenizer + BERT

In [13]:
# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")

# Move to GPU if available
bert_model = bert_model.to(device)
bert_model.eval()


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

#### Text Encoder Module

In [14]:
class TextEncoder(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        self.bert = bert_model
        self.fc = nn.Linear(768, embed_dim)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():  # freeze BERT for stability
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        pooled = outputs.pooler_output    # (B, 768)
        return self.fc(pooled)            # (B, 512)


In [15]:
# Take text batch from DataLoader
_, text_batch, _, _ = next(iter(train_loader))

# Tokenize
encoded = tokenizer(
    list(text_batch),
    padding=True,
    truncation=True,
    max_length=128,
    return_tensors="pt"
)

input_ids = encoded["input_ids"].to(device)
attention_mask = encoded["attention_mask"].to(device)

print("Input IDs shape:", input_ids.shape)
print("Attention mask shape:", attention_mask.shape)


Input IDs shape: torch.Size([2, 128])
Attention mask shape: torch.Size([2, 128])


In [16]:
text_encoder = TextEncoder(embed_dim=512).to(device)

with torch.no_grad():
    text_features = text_encoder(input_ids, attention_mask)

print("Text features shape:", text_features.shape)


Text features shape: torch.Size([2, 512])


### Multimodal Fusion with Cross-Modal Attention

#### Cross-Modal Attention Module

In [17]:
class CrossModalAttention(nn.Module):
    def __init__(self, embed_dim=512, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True
        )

    def forward(self, image_feats, text_feats):
        """
        image_feats: (B, K, 512)
        text_feats: (B, 512)
        """
        # Expand text to act as key & value for each timestep
        text_feats = text_feats.unsqueeze(1)  # (B, 1, 512)

        # Image queries attend over text
        fused, attn_weights = self.attn(
            query=image_feats,
            key=text_feats,
            value=text_feats
        )

        return fused, attn_weights


#### Test Cross-Modal Attention

In [18]:
cross_attn = CrossModalAttention(embed_dim=512).to(device)

with torch.no_grad():
    fused_feats, attn_weights = cross_attn(image_features, text_features)

print("Fused features shape:", fused_feats.shape)
print("Attention weights shape:", attn_weights.shape)

Fused features shape: torch.Size([2, 4, 512])
Attention weights shape: torch.Size([2, 4, 1])


### Temporal Sequence Model (LSTM)

#### Define Sequence Model (LSTM)

In [19]:
class TemporalSequenceModel(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=embed_dim,
            num_layers=1,
            batch_first=True
        )

    def forward(self, x):
        """
        x: (B, K, 512)
        """
        outputs, (h_n, c_n) = self.lstm(x)

        # Use last hidden state as story representation
        story_repr = h_n[-1]  # (B, 512)

        return story_repr


In [20]:
temporal_model = TemporalSequenceModel(embed_dim=512).to(device)

with torch.no_grad():
    story_representation = temporal_model(fused_feats)

print("Story representation shape:", story_representation.shape)


Story representation shape: torch.Size([2, 512])


### Dual Decoders (Image + Text)

#### Image Decoder (Embedding Prediction)

In [21]:
class ImageDecoder(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """
        x: (B, 512)
        """
        return self.fc(x)   # (B, 512)


#### Text Decoder (Vocabulary Prediction)

In [22]:
class TextDecoder(nn.Module):
    def __init__(self, embed_dim=512, vocab_size=30522):
        super().__init__()
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """
        x: (B, 512)
        """
        return self.fc(x)   # (B, vocab_size)


#### Test Both Decoders

In [23]:
image_decoder = ImageDecoder(embed_dim=512).to(device)
text_decoder = TextDecoder(embed_dim=512).to(device)

with torch.no_grad():
    predicted_image_embed = image_decoder(story_representation)
    predicted_text_logits = text_decoder(story_representation)

print("Predicted image embedding shape:", predicted_image_embed.shape)
print("Predicted text logits shape:", predicted_text_logits.shape)


Predicted image embedding shape: torch.Size([2, 512])
Predicted text logits shape: torch.Size([2, 30522])


### Full End-to-End Model

#### Define the Full Model

In [24]:
class VisualStoryModel(nn.Module):
    def __init__(self, embed_dim=512, vocab_size=30522):
        super().__init__()

        self.image_encoder = ImageEncoder(embed_dim)
        self.text_encoder = TextEncoder(embed_dim)
        self.cross_attention = CrossModalAttention(embed_dim)
        self.temporal_model = TemporalSequenceModel(embed_dim)
        self.image_decoder = ImageDecoder(embed_dim)
        self.text_decoder = TextDecoder(embed_dim, vocab_size)

    def forward(self, images, input_ids, attention_mask):
        # Encode modalities
        img_feats = self.image_encoder(images)                     # (B, K, 512)
        txt_feats = self.text_encoder(input_ids, attention_mask)   # (B, 512)

        # Fuse modalities
        fused_feats, _ = self.cross_attention(img_feats, txt_feats)

        # Temporal modeling
        story_repr = self.temporal_model(fused_feats)              # (B, 512)

        # Decode outputs
        pred_img_embed = self.image_decoder(story_repr)
        pred_txt_logits = self.text_decoder(story_repr)

        return pred_img_embed, pred_txt_logits


#### Instantiate Full Model

In [25]:
model = VisualStoryModel(embed_dim=512, vocab_size=30522).to(device)
model.eval()



VisualStoryModel(
  (image_encoder): ImageEncoder(
    (backbone): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): Batc

#### Full Forward Pass Test

In [26]:
# Get one batch
images, text_batch, tgt_img, tgt_text = next(iter(train_loader))
images = images.to(device)

# Tokenize text
encoded = tokenizer(
    list(text_batch),
    padding=True,
    truncation=True,
    max_length=128,
    return_tensors="pt"
)

input_ids = encoded["input_ids"].to(device)
attention_mask = encoded["attention_mask"].to(device)

# Forward pass
with torch.no_grad():
    pred_img_embed, pred_txt_logits = model(
        images, input_ids, attention_mask
    )

print("Final predicted image embedding:", pred_img_embed.shape)
print("Final predicted text logits:", pred_txt_logits.shape)


Final predicted image embedding: torch.Size([2, 512])
Final predicted text logits: torch.Size([2, 30522])


### Training Objective, Loss Functions & Evaluation

#### Define Loss Functions

In [27]:
image_loss_fn = nn.MSELoss()
text_loss_fn = nn.CrossEntropyLoss()


#### Optimizer

In [28]:
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4
)

In [29]:
%%writefile visualize.py
# ==========================================
# Visualization Utilities for Results
# ==========================================

import os
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
import pandas as pd


# --------------------------------------------------
# Ensure result directories exist
# --------------------------------------------------

def ensure_result_dirs():
    os.makedirs("results/figures", exist_ok=True)
    os.makedirs("results/tables", exist_ok=True)


# --------------------------------------------------
# Plot 1: Training Loss Curve
# --------------------------------------------------

def plot_training_loss(epochs, losses):
    ensure_result_dirs()

    plt.figure(figsize=(6, 4))
    plt.plot(epochs, losses, marker="o")
    plt.xlabel("Epoch")
    plt.ylabel("Average Training Loss")
    plt.title("Training Loss Across Epochs")
    plt.grid(True)

    plt.savefig(
        "results/figures/training_loss.png",
        dpi=300,
        bbox_inches="tight"
    )
    plt.close()


# --------------------------------------------------
# Plot 2: Cross-Modal Attention Heatmap
# --------------------------------------------------

def plot_attention_heatmap(attn_weights):
    """
    attn_weights: (B, K, 1)
    """
    ensure_result_dirs()

    attn = attn_weights[0].cpu().numpy()

    plt.figure(figsize=(4, 3))
    sns.heatmap(
        attn,
        annot=True,
        cmap="viridis",
        cbar=True,
        yticklabels=[f"Timestep {i+1}" for i in range(attn.shape[0])],
        xticklabels=["Text"]
    )

    plt.title("Cross-Modal Attention Weights")

    plt.savefig(
        "results/figures/cross_modal_attention.png",
        dpi=300,
        bbox_inches="tight"
    )
    plt.close()


# --------------------------------------------------
# Plot 3: Image Embedding Similarity Histogram
# --------------------------------------------------

def plot_image_similarity(pred_embed, target_embed):
    ensure_result_dirs()

    similarity = F.cosine_similarity(pred_embed, target_embed).cpu().numpy()

    plt.figure(figsize=(6, 4))
    plt.hist(similarity, bins=20)
    plt.xlabel("Cosine Similarity")
    plt.ylabel("Frequency")
    plt.title("Image Embedding Similarity Distribution")

    plt.savefig(
        "results/figures/image_embedding_similarity.png",
        dpi=300,
        bbox_inches="tight"
    )
    plt.close()


# --------------------------------------------------
# Plot 4: Text Top-5 Confidence
# --------------------------------------------------

def plot_text_topk_confidence(logits, k=5):
    ensure_result_dirs()

    probs = torch.softmax(logits, dim=1)
    topk_probs, topk_indices = probs[0].topk(k)

    plt.figure(figsize=(6, 4))
    plt.bar(range(k), topk_probs.cpu().numpy())
    plt.xticks(range(k), [f"Token {i}" for i in topk_indices.cpu().numpy()])
    plt.ylabel("Probability")
    plt.title(f"Top-{k} Text Prediction Confidence")

    plt.savefig(
        "results/figures/text_top5_confidence.png",
        dpi=300,
        bbox_inches="tight"
    )
    plt.close()


# --------------------------------------------------
# Save Metrics Table
# --------------------------------------------------

def save_metrics_table(metrics_dict):
    ensure_result_dirs()

    df = pd.DataFrame([metrics_dict])
    df.to_csv(
        "results/tables/metrics_summary.csv",
        index=False
    )


Writing visualize.py


### Model Training

In [30]:
num_epochs = 20
epoch_losses = []   # 🔹 for plotting

model.train()

for epoch in range(num_epochs):
    total_loss = 0.0

    for images, text_batch, tgt_img, tgt_text in train_loader:
        images = images.to(device)
        tgt_img = tgt_img.to(device)

        # Tokenize input text
        encoded = tokenizer(
            list(text_batch),
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)

        # Forward pass (UNCHANGED)
        pred_img_embed, pred_txt_logits = model(
            images, input_ids, attention_mask
        )

        # Target image embedding (reuse model encoder)
        with torch.no_grad():
            target_img_embed = model.image_encoder(
                tgt_img.unsqueeze(1)
            ).squeeze(1)

        # Target text token (first token)
        target_token = input_ids[:, 0]

        # Loss
        loss_img = image_loss_fn(pred_img_embed, target_img_embed)
        loss_txt = text_loss_fn(pred_txt_logits, target_token)
        loss = loss_img + loss_txt

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    epoch_losses.append(avg_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Loss: {avg_loss:.4f}")


Epoch [1/20] - Avg Loss: 0.3154
Epoch [2/20] - Avg Loss: 0.0677
Epoch [3/20] - Avg Loss: 0.0677
Epoch [4/20] - Avg Loss: 0.0677
Epoch [5/20] - Avg Loss: 0.0671
Epoch [6/20] - Avg Loss: 0.0674
Epoch [7/20] - Avg Loss: 0.0673
Epoch [8/20] - Avg Loss: 0.0677
Epoch [9/20] - Avg Loss: 0.0675
Epoch [10/20] - Avg Loss: 0.0675
Epoch [11/20] - Avg Loss: 0.0675
Epoch [12/20] - Avg Loss: 0.0675
Epoch [13/20] - Avg Loss: 0.0674
Epoch [14/20] - Avg Loss: 0.0673
Epoch [15/20] - Avg Loss: 0.0672
Epoch [16/20] - Avg Loss: 0.0675
Epoch [17/20] - Avg Loss: 0.0723
Epoch [18/20] - Avg Loss: 0.2802
Epoch [19/20] - Avg Loss: 0.0677
Epoch [20/20] - Avg Loss: 0.0675


### Saving Visualizations and Metrix

In [34]:
%%writefile utils.py
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import random
import numpy as np

# Disable EXIF handling (Windows-safe)
Image.MAX_IMAGE_PIXELS = None


# -----------------------------
# Reproducibility
# -----------------------------

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# -----------------------------
# Text cleaning
# -----------------------------

def clean_story_text(text):
    return text.replace("<gdi>", "").replace("</gdi>", "")


# -----------------------------
# Dataset
# -----------------------------

class StoryReasoningDataset(Dataset):
    def __init__(
        self,
        hf_dataset_name,
        split="train",
        max_stories=200,
        sequence_length=4,
        image_size=(224, 224),
        text_max_length=128
    ):
        self.sequence_length = sequence_length

        self.dataset = load_dataset(hf_dataset_name, split=split)
        if max_stories:
            self.dataset = self.dataset.select(range(max_stories))

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor()
        ])

        self.index_map = []
        for i in range(len(self.dataset)):
            frame_count = self.dataset[i]["frame_count"]
            for j in range(frame_count - sequence_length):
                self.index_map.append((i, j))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        story_idx, start = self.index_map[idx]
        example = self.dataset[story_idx]

        images = [img.convert("RGB") for img in example["images"]]
        story_text = clean_story_text(example["story"])

        input_imgs = images[start:start+self.sequence_length]
        target_img = images[start+self.sequence_length]

        input_imgs = [self.transform(img) for img in input_imgs]
        target_img = self.transform(target_img)

        return (
            torch.stack(input_imgs),
            story_text,
            target_img,
            story_text
        )


# -----------------------------
# Metrics
# -----------------------------

def text_topk_accuracy(logits, targets, k=1):
    _, topk = logits.topk(k, dim=1)
    correct = topk.eq(targets.unsqueeze(1)).any(dim=1)
    return correct.float().mean().item()


def image_cosine_similarity(pred, target):
    return F.cosine_similarity(pred, target).mean().item()


Writing utils.py


In [35]:
from utils import text_topk_accuracy, image_cosine_similarity

top1_acc = text_topk_accuracy(pred_txt_logits, input_ids[:, 0], k=1)
top5_acc = text_topk_accuracy(pred_txt_logits, input_ids[:, 0], k=5)
img_similarity = image_cosine_similarity(pred_img_embed, target_img_embed)

print("Text Top-1 Accuracy:", top1_acc)
print("Text Top-5 Accuracy:", top5_acc)
print("Image Cosine Similarity:", img_similarity)


Text Top-1 Accuracy: 1.0
Text Top-5 Accuracy: 1.0
Image Cosine Similarity: 0.7830475568771362


In [36]:
from visualize import (
    plot_training_loss,
    plot_image_similarity,
    plot_text_topk_confidence,
    save_metrics_table
)

# Training loss curve
plot_training_loss(
    epochs=list(range(1, num_epochs + 1)),
    losses=epoch_losses
)

# Image embedding similarity
plot_image_similarity(pred_img_embed, target_img_embed)

# Text prediction confidence
plot_text_topk_confidence(pred_txt_logits, k=5)

# Save metrics table
save_metrics_table({
    "epochs": num_epochs,
    "final_training_loss": epoch_losses[-1],
    "text_top1_accuracy": top1_acc,
    "text_top5_accuracy": top5_acc,
    "image_cosine_similarity": img_similarity
})
