In [2]:
import torch

# Report the installed PyTorch build and whether a CUDA device is usable.
_cuda_ok = torch.cuda.is_available()

print("Torch version:", torch.__version__)
print("CUDA available:", _cuda_ok)

# Only query the device name when CUDA is present; the call fails otherwise.
if _cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))


Torch version: 2.9.0+cu126
CUDA available: True
GPU: Tesla T4


In [3]:
!pip install transformers torchvision einops



In [4]:
! pip install datasets



In [5]:
from datasets import load_dataset

# Fetch the StoryReasoning dataset from the Hugging Face Hub (cached after the
# first download) and show its splits and features.
dataset = load_dataset(path="daniel3303/StoryReasoning")
print(dataset)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


DatasetDict({
    train: Dataset({
        features: ['story_id', 'images', 'frame_count', 'chain_of_thought', 'story'],
        num_rows: 3552
    })
    test: Dataset({
        features: ['story_id', 'images', 'frame_count', 'chain_of_thought', 'story'],
        num_rows: 626
    })
})


In [6]:
print("A sample story entry:\n")

# Look at the first training story to understand the record layout.
sample = dataset["train"][0]
sample_frames = sample["images"]

print("Story ID:", sample["story_id"])
print("Frame count:", sample["frame_count"])
# The "images" field holds decoded PIL images, not URLs. Print a count and a
# single representative instead of dumping every object repr into the output.
print(
    "Images (PIL):", len(sample_frames), "frames; first:",
    sample_frames[0] if sample_frames else None,
)
print("Text:", sample["story"])
print("Chain-of-thought:", sample["chain_of_thought"])

A sample story entry:

Story ID: 3920
Frame count: 17
Images (list of URLs): [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7EB01E57A660>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7EB01C2623C0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7EB00BF4DE50>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7EB00BF4DDC0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7EB00BF4E2A0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7EB00BF4E360>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7EB00BF4E420>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7EB00BF4E4E0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7EB00BF4E5A0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7EB00BF4E660>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=575x240 at 0x7EB00BF4E720>, <PIL.Jpeg

## Dataset Pre-Processing

In [9]:
# IMPORTANT: limit dataset size to avoid RAM crash
MAX_TRAIN_STORIES = 200  # upper bound on the number of stories kept for training

# Clamp to the split size so the selection cannot index past the end when the
# split holds fewer stories than the cap.
num_train_stories = min(MAX_TRAIN_STORIES, len(dataset["train"]))
small_train_dataset = dataset["train"].select(range(num_train_stories))
print("Using", len(small_train_dataset), "stories for training")


Using 200 stories for training


In [8]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import re

def clean_story_text(text):
    """Strip angle-bracket tags from ``text`` and normalise its whitespace.

    Every ``<...>`` span that sits on a single line is removed (the shortest
    match is taken, and a span broken by a newline is left alone). Afterwards
    each run of whitespace becomes one space and the ends are trimmed.

    Args:
        text: Raw story string, possibly containing markup tags.

    Returns:
        The cleaned string.
    """
    tag_pattern = re.compile(r"<.*?>")
    whitespace_pattern = re.compile(r"\s+")

    without_tags = tag_pattern.sub("", text)
    single_spaced = whitespace_pattern.sub(" ", without_tags)
    return single_spaced.strip()


class StoryReasoningDataset(Dataset):
    """Sliding-window view over multi-frame stories for next-frame prediction.

    Every item pairs ``K`` consecutive frames of one story with the frame that
    immediately follows them, together with the story's cleaned text.

    Args:
        hf_dataset: Indexable collection of story records. Each record must
            provide ``"frame_count"`` (int), ``"images"`` (sequence of PIL
            images) and ``"story"`` (str).
        K: Number of context frames per window; must be at least 1.

    Raises:
        ValueError: If ``K`` is smaller than 1.
    """

    def __init__(self, hf_dataset, K=4):
        if K < 1:
            # K == 0 would only fail later inside torch.stack on an empty list,
            # and a negative K would silently produce nonsensical windows.
            raise ValueError(f"K must be >= 1, got {K}")

        self.dataset = hf_dataset
        self.K = K
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        # Only the frame counts are needed to enumerate windows. Reading the
        # column directly avoids materialising (and decoding the images of)
        # every record; fall back to record iteration for containers that do
        # not support column access, such as a plain list of dicts.
        try:
            frame_counts = hf_dataset["frame_count"]
        except (TypeError, KeyError, IndexError):
            frame_counts = [record["frame_count"] for record in hf_dataset]

        # build index map (story_idx, window_start); a story with
        # frame_count <= K contributes no windows.
        self.index_map = [
            (story_idx, window_start)
            for story_idx, frame_count in enumerate(frame_counts)
            for window_start in range(frame_count - K)
        ]

    def __len__(self):
        """Return the total number of (context window, target frame) samples."""
        return len(self.index_map)

    def _to_tensor(self, image):
        """Convert one PIL frame to a (3, 224, 224) float tensor.

        The frame is forced to RGB first so that grayscale, palette or RGBA
        frames yield the same channel count as the rest and can be stacked.
        """
        return self.transform(image.convert("RGB"))

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            A 4-tuple ``(context, text, target, text)`` where ``context`` is a
            ``(K, 3, 224, 224)`` tensor of the window's frames, ``target`` is a
            ``(3, 224, 224)`` tensor of the frame right after the window, and
            ``text`` is the cleaned story text (returned in both text slots).
        """
        story_idx, start = self.index_map[idx]
        example = self.dataset[story_idx]

        frames = example["images"]
        story_text = clean_story_text(example["story"])

        stop = start + self.K
        context = torch.stack([self._to_tensor(frame) for frame in frames[start:stop]])
        target = self._to_tensor(frames[stop])

        return (
            context,       # (K, 3, 224, 224)
            story_text,
            target,
            story_text,
        )


In [10]:
from torch.utils.data import DataLoader

# Wrap the reduced story subset as sliding windows of 4 context frames each.
train_dataset = StoryReasoningDataset(hf_dataset=small_train_dataset, K=4)

# Small shuffled batches, loaded by two worker processes into pinned memory.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=2,
    shuffle=True,
    pin_memory=True,
    num_workers=2,
)

print("Total training samples:", len(train_dataset))

# Test batch: pull one batch to confirm tensor shapes and text content.
batch_iterator = iter(train_loader)
images, text, tgt_img, tgt_text = next(batch_iterator)
print("Images:", images.shape)
print("Target image:", tgt_img.shape)
print("Text sample:", text[0][:150])


Total training samples: 1742
Images: torch.Size([2, 4, 3, 224, 224])
Target image: torch.Size([2, 3, 224, 224])
Text sample: In the heart of a dense forest, the man in the blue shirt and the man in the white shirt ran desperately, their hearts pounding with fear. The man in 
