# VidChapter-7M: Chapter Generation with TinyLlama

This notebook implements the training and inference pipeline for video chapter generation using the `lucas-ventura/chapter-llama` dataset and `TinyLlama/TinyLlama-1.1B-Chat-v1.0` model.

In [1]:
import os
import json
from dataclasses import dataclass
from typing import List, Dict, Any

import torch
from torch.utils.data import Dataset
from huggingface_hub import hf_hub_download
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    set_seed,
)




In [2]:
DATASET_REPO = "lucas-ventura/chapter-llama"

# Small model for local experimentation
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# IMPORTANT: must be <= model.config.max_position_embeddings (2048 for TinyLlama)
MAX_INPUT_TOKENS = 1000   # total sequence length (prompt + targets)

# Keep transcripts modest so chapters fit too
MAX_TRANSCRIPT_CHARS = 1500

SEED = 42
set_seed(SEED)

In [3]:
def seconds_to_hhmmss(t: float) -> str:
    """Convert seconds to HH:MM:SS (floor)."""
    t = max(0, int(t))
    h = t // 3600
    m = (t % 3600) // 60
    s = t % 60
    return f"{h:02d}:{m:02d}:{s:02d}"


def safe_get(d: Dict, key: str, default=None):
    v = d.get(key, default)
    return v if v is not None else default


def load_json_from_hf(path_in_repo: str) -> Dict[str, Any]:
    """Download a JSON file from the HF dataset repo and load it."""
    local_path = hf_hub_download(repo_id=DATASET_REPO, filename=path_in_repo, repo_type="dataset")
    with open(local_path, "r", encoding="utf-8") as f:
        return json.load(f)

In [4]:
class VidChaptersAsrDataset(Dataset):
    """
    ASR-only dataset for Chapter-Llama style training.

    Assumptions (based on the released subset JSONs):
    - chapters_*.json:
        {
            "<video_id>": {
             "duration": float,
             "title": str,
             "description": str,
             "channel_id": str,
             "view_count": int,
             "chapters": {
                "<start_sec>": "<chapter title>"
                ...
                }
            },
          ...
        }

    - asrs_*.json: supports TWO shapes:

        (1) Dict-of-lists (as you observed):
            {
                "<video_id>": {
                "text":  [str, ...],
                "start": [float, ...],
                "end":   [float, ...]
            },
          ...
        }

        (2) List-of-dicts (original assumption):
            {
                "<video_id>": [
                    {"start": float, "end": float, "text": str},
                    ...
                    ],
                ...
            }

    Each __getitem__ returns a dict with tokenized
    input_ids, attention_mask, labels for a *single video*.
    """

    def __init__(
        self,
        tokenizer,
        chapters_json: Dict[str, Any],
        asrs_json: Dict[str, Any],
        max_input_tokens: int = 2048,
        max_transcript_chars: int = 3000,
        max_videos: int | None = None,
    ):
        self.tokenizer = tokenizer
        self.chapters_json = chapters_json
        self.asrs_json = asrs_json
        self.max_input_tokens = max_input_tokens
        self.max_transcript_chars = max_transcript_chars

        all_vids = sorted(set(chapters_json.keys()) & set(asrs_json.keys()))
        if max_videos is not None:
            all_vids = all_vids[:max_videos]
        self.video_ids = all_vids

        # Pre-tokenize static instruction skeleton to avoid recomputing its ids
        self.base_instruction = (
            "You are an expert system that segments long YouTube videos into "
            "semantically meaningful chapters and names each chapter.\n\n"
            "Given the video title, description, and a subset of the ASR transcript "
            "with timestamps, output ALL chapters for the video as lines of the form:\n"
            "HH:MM:SS - Chapter title\n\n"
            "Be faithful to the content and avoid hallucinating chapters that are "
            "not supported by the transcript.\n\n"
        )

    def __len__(self):
        return len(self.video_ids)

    def _build_transcript_text(self, asr_segments) -> str:
        """
            Build a textual representation of ASR:

            Handles:
            - dict-of-lists: {"text": [...], "start": [...], "end": [...]}
            - list-of-dicts: [{"text": ..., "start": ..., "end": ...}, ...]

            Returns a single string like:
              [HH:MM:SS] utterance
              [HH:MM:SS] next utterance
              ...
            truncated by self.max_transcript_chars.
        """

        triplets = []

        if isinstance(asr_segments, dict) and all(
            k in asr_segments for k in ("text", "start", "end")
        ):
            texts = asr_segments["text"]
            starts = asr_segments["start"]
            ends = asr_segments["end"]
            for t, s, e in zip(texts, starts, ends):
                if not t:
                    continue
                try:
                    s_float = float(s)
                except Exception:
                    s_float = 0.0
                triplets.append((s_float, e, t))
        else:
            # Assume list[dict]
            for seg in asr_segments:
                text = seg.get("text", "")
                if not text:
                    continue
                try:
                    s_float = float(seg.get("start", 0.0))
                except Exception:
                    s_float = 0.0
                triplets.append((s_float, seg.get("end", 0.0), text))

        triplets.sort(key=lambda x: x[0])  # sort by start time

        chunks = []
        total_chars = 0

        for start, end, text in triplets:
            ts = seconds_to_hhmmss(start)
            line = f"[{ts}] {text.strip()}"
            new_len = total_chars + len(line) + 1
            if new_len > self.max_transcript_chars:
                break
            chunks.append(line)
            total_chars = new_len

        return "\n".join(chunks)

    def _build_chapter_target(self, chapters_dict: Dict[str, str]) -> str:
        """
        Turn chapters dict into canonical text:
        HH:MM:SS - Title
        one per line, sorted by time.
        """
        items = []
        for start_str, title in chapters_dict.items():
            try:
                t = float(start_str)
            except Exception:
                # Sometimes keys might already be numeric-ish strings; fallback
                try:
                    t = float(str(start_str).replace(",", ""))
                except Exception:
                    t = 0.0
            items.append((t, title))
        items.sort(key=lambda x: x[0])
        lines = [f"{seconds_to_hhmmss(t)} - {title}" for t, title in items]
        return "\n".join(lines)

    def __getitem__(self, idx):
        """
        Critical invariants:
        - Total sequence length <= self.max_input_tokens
        - There is at least ONE label token != -100
        """

        # Safety loop: try a few different videos if current one is unusable
        for _attempt in range(5):
            vid = self.video_ids[idx]
            chap_entry = self.chapters_json[vid]
            asr_segments = self.asrs_json[vid]

            video_title = safe_get(chap_entry, "title", "")
            video_desc = safe_get(chap_entry, "description", "")
            chapters_dict = safe_get(chap_entry, "chapters", {})

            # Build transcript and target chapter text
            transcript_text = self._build_transcript_text(asr_segments)
            chapters_text = self._build_chapter_target(chapters_dict)

            # ---- Prompt & target strings ----
            prompt = (
                self.base_instruction
                + f"Video title: {video_title}\n"
                + f"Video description: {video_desc}\n\n"
                + "### Transcript (partial, chronologically ordered):\n"
                + transcript_text
                + "\n\n### Chapters\n"
            )
            target = chapters_text

            # ---- Separate tokenization for prompt/target ----
            max_total = self.max_input_tokens
            max_prompt = int(max_total * 0.75)   # allow prompt to take up to 75%
            if max_prompt < 1:
                max_prompt = max_total // 2

            # Prompt
            prompt_enc = self.tokenizer(
                prompt,
                add_special_tokens=False,
                truncation=True,
                max_length=max_prompt,
            )
            prompt_ids = prompt_enc["input_ids"]

            # Remaining room for target
            max_target = max_total - len(prompt_ids)
            if max_target <= 0:
                # Prompt alone too long: keep only last chunk of prompt
                prompt_ids = prompt_ids[-(max_total // 2):]
                max_target = max_total - len(prompt_ids)

            # Target (chapter lines)
            target_enc = self.tokenizer(
                target,
                add_special_tokens=False,
                truncation=True,
                max_length=max_target,
            )
            target_ids = target_enc["input_ids"]

            # If no target tokens → this sample is useless; try another video
            if len(target_ids) == 0:
                idx = (idx + 1) % len(self.video_ids)
                continue

            input_ids = prompt_ids + target_ids
            attention_mask = [1] * len(input_ids)
            labels = [-100] * len(prompt_ids) + target_ids

            # Safety: truncate everything to max_total
            if len(input_ids) > max_total:
                input_ids = input_ids[:max_total]
                attention_mask = attention_mask[:max_total]
                labels = labels[:max_total]

            # Final check: ensure at least one token has a real label
            if all(l == -100 for l in labels):
                idx = (idx + 1) % len(self.video_ids)
                continue

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
                "video_id": vid,
            }

        # If we somehow failed 5 times in a row:
        raise RuntimeError("Failed to construct a valid training example with at least one label token.")

In [5]:
@dataclass
class ChapteringCollator:
    tokenizer: Any

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Variable-length → pad
        max_len = max(len(ex["input_ids"]) for ex in batch)
        pad_id = self.tokenizer.pad_token_id

        input_ids = []
        attention_mask = []
        labels = []

        for ex in batch:
            n = len(ex["input_ids"])
            pad_len = max_len - n

            input_ids.append(ex["input_ids"] + [pad_id] * pad_len)
            attention_mask.append(ex["attention_mask"] + [0] * pad_len)
            labels.append(ex["labels"] + [-100] * pad_len)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

## Display Example Data
Here we load the training data and display one example to understand the structure.

In [6]:
print("Loading subset JSONs from Hugging Face...")

chapters_train = load_json_from_hf("docs/subset_data/chapters/chapters_sml1k_train.json")
asrs_train = load_json_from_hf("docs/subset_data/asrs/asrs_sml1k_train.json")

# Display one example
example_vid_id = list(chapters_train.keys())[0]
print(f"Example Video ID: {example_vid_id}")
print("Chapter Data:")
print(json.dumps(chapters_train[example_vid_id], indent=2))
print("\nASR Data (first 2 segments):")
if isinstance(asrs_train[example_vid_id], dict):
    print(json.dumps({k: v[:2] for k, v in asrs_train[example_vid_id].items()}, indent=2))
else:
    print(json.dumps(asrs_train[example_vid_id][:2], indent=2))

Loading subset JSONs from Hugging Face...
Example Video ID: -1vmZJ-EWtI
Chapter Data:
{
  "duration": 2595.906,
  "title": "If You Ever Experience Anxiety, Try These Tips to Overcome It | Seane Corn on Women of Impact",
  "description": "Hey guys, Lisa here! If you didn\u2019t already know, I am super frikin excited to share that I\u2019m writing a book! To be the FIRST to get sneak peeks about my book and other exclusive content go to: http://lisabilyeu.com/ and be sure to sign up for my newsletter.\n\nSeane Corn overcame OCD and severe anxiety, not to mention drug abuse, through yoga and spiritual practice. Now she is part of the revolution of the soul, practicing radical healing, conscious action, and guiding people to self-awareness. On this episode of Women of Impact with Lisa Bilyeu, Seane Corn shares her history with trauma, explains how to identify and change the stories we tell ourselves, and describes how she healed herself with yoga.\n\nSHOW NOTES:\n\nSeane shares her histor

## Model Setup and Training
Initialize the model, tokenizer, and trainer.

In [7]:
# Device (MPS on Apple Silicon if available)
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load tokenizer & model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = MAX_INPUT_TOKENS

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.resize_token_embeddings(len(tokenizer))
model.config.use_cache = False  # important for training

max_ctx = model.config.max_position_embeddings
print("Model max position embeddings:", max_ctx)
assert MAX_INPUT_TOKENS <= max_ctx

model.to(device)

Using device: cpu
Model max position embeddings: 2048


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): 

In [8]:
chapters_val = load_json_from_hf("docs/subset_data/chapters/chapters_sml300_val.json")
asrs_val = load_json_from_hf("docs/subset_data/asrs/asrs_sml300_val.json")

# Build datasets
train_dataset = VidChaptersAsrDataset(
    tokenizer=tokenizer,
    chapters_json=chapters_train,
    asrs_json=asrs_train,
    max_input_tokens=MAX_INPUT_TOKENS,
    max_transcript_chars=MAX_TRANSCRIPT_CHARS,
    max_videos=50,  # or e.g. 200 for very quick smoke test
)

val_dataset = VidChaptersAsrDataset(
    tokenizer=tokenizer,
    chapters_json=chapters_val,
    asrs_json=asrs_val,
    max_input_tokens=MAX_INPUT_TOKENS,
    max_transcript_chars=MAX_TRANSCRIPT_CHARS,
    max_videos=50,
)

print(f"#train videos: {len(train_dataset)}")
print(f"#val videos:   {len(val_dataset)}")

collator = ChapteringCollator(tokenizer=tokenizer)

#train videos: 50
#val videos:   50


In [9]:
# -------------------------
# Training args
# -------------------------
output_dir = "outputs/chapter_llama_asr_sml1k_tiny"
os.makedirs(output_dir, exist_ok=True)

# Conservative settings for stability on M4 + TinyLlama
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=1.0,
    learning_rate=2e-5,
    warmup_ratio=0.03,
    weight_decay=0.01,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=500,
    save_total_limit=2,
    report_to="none",
    fp16=False,
    bf16=False,
    gradient_checkpointing=False,  # keep off for now for stability
    remove_unused_columns=False,
    max_grad_norm=1.0,             # gradient clipping
)

# Simple metric: validation loss (Trainer logs eval_loss) + we can add perplexity later
def compute_metrics(eval_pred):
    # logits, labels = eval_pred
    return {}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

print("Starting training...")
trainer.train()
print("Done.")

# Save final model & tokenizer
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Model saved to: {output_dir}")

Starting training...


  trainer = Trainer(


Step,Training Loss,Validation Loss


Done.
Model saved to: outputs/chapter_llama_asr_sml1k_tiny


## Inference
Run inference on a held-out video.

In [10]:
def run_single_video_inference(
    model_dir: str,
    subset_tag: str = "s",
    max_new_tokens: int = 256,
):
    """
    Load a fine-tuned checkpoint and run chaptering on ONE held-out video
    from chapters_{subset_tag}_test.json + asrs_{subset_tag}_test.json.

    subset_tag="s" → use:
      docs/subset_data/chapters/chapters_s_test.json
      docs/subset_data/asrs/asrs_s_test.json
    """

    # --------------------------
    # Device & model
    # --------------------------
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"[INFER] Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_dir)
    model.to(device)
    model.eval()

    # --------------------------
    # Load test JSONs
    # --------------------------
    chapters_test = load_json_from_hf(
        f"docs/subset_data/chapters/chapters_{subset_tag}_test.json"
    )
    asrs_test = load_json_from_hf(
        f"docs/subset_data/asrs/asrs_{subset_tag}_test.json"
    )

    # Tiny dataset wrapper to reuse formatting helpers
    test_dataset = VidChaptersAsrDataset(
        tokenizer=tokenizer,
        chapters_json=chapters_test,
        asrs_json=asrs_test,
        max_input_tokens=MAX_INPUT_TOKENS,
        max_transcript_chars=MAX_TRANSCRIPT_CHARS,
        max_videos=None,
    )

    if len(test_dataset) == 0:
        raise ValueError("[INFER] No test videos found.")

    vid0 = test_dataset.video_ids[0]
    chap_entry = chapters_test[vid0]
    asr_segments = asrs_test[vid0]

    video_title = chap_entry.get("title", "")
    video_desc = chap_entry.get("description", "")
    chapters_dict = chap_entry.get("chapters", {})

    # Reuse dataset helpers for text format
    transcript_text = test_dataset._build_transcript_text(asr_segments)
    gt_chapters_text = test_dataset._build_chapter_target(chapters_dict)
    base_instruction = test_dataset.base_instruction

    # Build prompt exactly like during training (minus target)
    prompt = (
        base_instruction
        + f"Video title: {video_title}\n"
        + f"Video description: {video_desc}\n\n"
        + "### Transcript (partial, chronologically ordered):\n"
        + transcript_text
        + "\n\n### Chapters\n"
    )

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        gen_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Cut off the prompt, keep only new tokens
    input_len = inputs["input_ids"].shape[1]
    gen_only = gen_ids[0][input_len:]
    pred_text = tokenizer.decode(gen_only, skip_special_tokens=True)

    # Pretty-print
    print("=" * 80)
    print(f"[INFER] VIDEO ID: {vid0}")
    print("=" * 80)

    print("\n[INFER] TRANSCRIPT (first ~15 lines):")
    t_lines = transcript_text.splitlines()
    for line in t_lines[:15]:
        print(line)
    if len(t_lines) > 15:
        print("... (transcript truncated) ...")

    print("\n[INFER] GROUND-TRUTH CHAPTERS:")
    print(gt_chapters_text)

    print("\n[INFER] PREDICTED CHAPTERS:")
    print(pred_text)
    print("=" * 80)

# Run inference
print("\n=== Running quick inference on one held-out test video ===")
run_single_video_inference(
    model_dir=output_dir,
    subset_tag="s",         # uses chapters_s_test / asrs_s_test
    max_new_tokens=256,
)


=== Running quick inference on one held-out test video ===
[INFER] Using device: cpu


docs/subset_data/chapters/chapters_s_tes(…):   0%|          | 0.00/12.0M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


docs/subset_data/asrs/asrs_s_test.json:   0%|          | 0.00/50.4M [00:00<?, ?B/s]

[INFER] VIDEO ID: --7febZVhIs

[INFER] TRANSCRIPT (first ~15 lines):
[00:00:02] Hey, what's up everybody?
[00:00:07] My name is Ray and I'm the creator of Ray So Silly, a YouTube channel dedicated to creating videos with green screen technology.
[00:00:14] I'm excited to be here on KineMaster's channel to show you how to use the KineMaster app
[00:00:19] to achieve the green screen effect.
[00:00:21] This is a popular effect used in many television shows, movies, and more.
[00:00:26] First, I'm going to show you how to achieve this effect, and then you can use your imagination to create a world of your own.
[00:00:32] Let's open up the app and get started.
[00:00:34] Once you're in KineMaster, tap the plus icon here to create new project.
[00:00:39] In here, choose the aspect ratio.
[00:00:41] I'm going to choose 16 by 9.
[00:00:44] Now click on media to go into your camera roll to select the video.
[00:00:48] Now from the video, switch to photo.
[00:00:50] And here we need a clear gre