# VideoLLM Architecture Playground

This notebook reconstructs the high-level components of the VideoLLM live streaming architecture in a lightweight, fully self-contained form. It mirrors the key classes in [`models/modeling_live.py`](../models/modeling_live.py) and [`models/live_llama/modeling_live_llama.py`](../models/live_llama/modeling_live_llama.py) by providing:

* a joint text-video embedding step with a connector that projects vision features into the language model hidden size,
* a minimal causal language model core that consumes the fused embeddings, and
* a training loss that up-weights the contribution of video placeholder tokens, mirroring the original implementation's `stream_loss_weight`.

To make the inner workings visible, the notebook prints tensor shapes and intermediate values at every major step.

In [None]:
import math
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(42)


## 1. Configuration scaffold

The real project stores configuration values in `LiveConfigMixin` and `LiveLlamaConfig`.
Here we define a tiny dataclass-like container that captures only the attributes that the model relies on during forward passes.
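
Because the config class only stores attributes, the same container could also be written with the standard-library `dataclasses` module. The sketch below assumes its defaults are kept in sync with the class in the next cell. `TinyLiveConfigDC` is a hypothetical name used only for illustration.

```python
from dataclasses import dataclass, asdict


@dataclass
class TinyLiveConfigDC:
    """Hypothetical dataclass equivalent of TinyLiveConfig (same fields and defaults)."""
    vocab_size: int = 64
    hidden_size: int = 32
    vision_hidden_size: int = 24
    frame_num_tokens: int = 1
    stream_loss_weight: float = 3.0
    v_placeholder_id: int = 63  # last vocabulary ID, reserved for the video placeholder


# Override a single field; the others keep their defaults.
cfg = TinyLiveConfigDC(stream_loss_weight=1.0)
print(asdict(cfg))
```

The dataclass form also provides a readable `__repr__` and field-wise equality without extra code.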

In [None]:
class TinyLiveConfig:
    def __init__(
        self,
        vocab_size: int = 64,
        hidden_size: int = 32,
        vision_hidden_size: int = 24,
        frame_num_tokens: int = 1,
        stream_loss_weight: float = 3.0,
        v_placeholder_id: int = 63,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.vision_hidden_size = vision_hidden_size
        self.frame_num_tokens = frame_num_tokens
        self.stream_loss_weight = stream_loss_weight
        self.v_placeholder_id = v_placeholder_id


## 2. Minimal causal language backbone

The production model subclasses `LlamaForCausalLM`. To keep things lightweight, we implement a tiny GRU-based causal language model that exposes the two pieces the fusion module relies on: an embedding lookup (`get_input_embeddings`) and a linear language head (`lm_head`), which `forward` applies to produce logits.
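
A unidirectional GRU is causal by construction: the output at position *t* depends only on inputs `0..t`. The sketch below checks this property by perturbing only the last time step and confirming that earlier outputs are unchanged. The sizes match the default config. The check is an illustration and is not part of the original project.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size, seq_len = 32, 8
rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)

x = torch.randn(1, seq_len, hidden_size)
x_perturbed = x.clone()
x_perturbed[:, -1] += 10.0  # change only the final time step

with torch.no_grad():
    out, _ = rnn(x)
    out_perturbed, _ = rnn(x_perturbed)

# Outputs before the last step must be identical; only the last one may differ.
prefix_equal = torch.allclose(out[:, :-1], out_perturbed[:, :-1])
last_changed = not torch.allclose(out[:, -1], out_perturbed[:, -1])
print(prefix_equal, last_changed)
```

A Transformer backbone such as Llama obtains the same guarantee from its causal attention mask instead of from recurrence.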

In [None]:
class TinyCausalBackbone(nn.Module):
    def __init__(self, config: TinyLiveConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.rnn = nn.GRU(config.hidden_size, config.hidden_size, batch_first=True)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, inputs_embeds: torch.Tensor):
        print(f"[TinyCausalBackbone] inputs_embeds shape: {inputs_embeds.shape}")
        hidden_states, _ = self.rnn(inputs_embeds)
        print(f"[TinyCausalBackbone] hidden_states shape: {hidden_states.shape}")
        logits = self.lm_head(hidden_states)
        print(f"[TinyCausalBackbone] logits shape: {logits.shape}")
        return logits

    def get_input_embeddings(self):
        return self.embed_tokens


## 3. Live-style fusion module

The heart of the architecture lives in `LiveMixin` and `LiveLlamaForCausalLM`.
The next cell recreates their behavior:

* `visual_embed` passes video frame features through a connector (two linear layers with a GELU in between) and reshapes the result into `frame_num_tokens` token embeddings per frame, each of size `hidden_size`.
* `joint_embed` splices the resulting video embeddings into the token stream wherever the video placeholder token appears.
* `forward` runs the fused embeddings through the backbone (whose `lm_head` produces the logits) and computes the weighted cross-entropy loss that the original model uses when `stream_loss_weight != 1`.
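
The splice in `joint_embed` relies on boolean-mask assignment, which fills the masked positions in row-major order. All placeholders of sample 0 are filled first, then those of sample 1, and so on. For this reason the video tokens are flattened batch-first before the assignment. The toy sketch below uses a hidden size of 1 so the values are easy to read. All tensors are made up for illustration.

```python
import torch

v_placeholder_id = 63
input_ids = torch.tensor([[5, 63, 7, 63],
                          [63, 9, 63, 4]])
hidden = 1
text_embeds = torch.zeros(2, 4, hidden)           # stand-in for text embeddings
video_tokens = torch.tensor([[[1.0], [2.0]],       # sample 0: frame tokens 1, 2
                             [[3.0], [4.0]]])      # sample 1: frame tokens 3, 4

v_mask = input_ids == v_placeholder_id             # (batch, seq_len) boolean mask
assert torch.all(v_mask.sum(dim=1) == video_tokens.size(1))

fused = text_embeds.clone()
fused[v_mask] = video_tokens.reshape(-1, hidden)   # row-major fill
print(fused.squeeze(-1))
# tensor([[0., 1., 0., 2.],
#         [3., 0., 4., 0.]])
```

If samples could contain different numbers of placeholders, this flat assignment would misalign frames across samples. That is why the class checks the per-sample placeholder count before splicing.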

In [None]:
class DummyLiveLlama(nn.Module):
    def __init__(self, config: TinyLiveConfig):
        super().__init__()
        self.config = config
        self.backbone = TinyCausalBackbone(config)
        self.connector = nn.Sequential(
            nn.Linear(config.vision_hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )

    def visual_embed(self, frames: torch.Tensor) -> torch.Tensor:
        # frames: (batch, num_frames, vision_hidden_size), or
        # (batch, num_frames, frame_num_tokens, vision_hidden_size) when frame_num_tokens > 1
        print(f"[visual_embed] raw frames shape: {frames.shape}")
        batch, feat_dim = frames.shape[0], frames.shape[-1]
        flat_frames = frames.reshape(-1, feat_dim)
        projected = self.connector(flat_frames)  # one hidden_size vector per frame token
        print(f"[visual_embed] projected shape: {projected.shape}")
        token_embeddings = projected.view(batch, -1, self.config.hidden_size)
        print(f"[visual_embed] token_embeddings shape: {token_embeddings.shape}")
        return token_embeddings

    def joint_embed(self, input_ids: torch.Tensor, frames: torch.Tensor | None) -> torch.Tensor:
        text_embeds = self.backbone.get_input_embeddings()(input_ids)
        print(f"[joint_embed] text_embeds shape: {text_embeds.shape}")
        if frames is None:
            return text_embeds
        video_tokens = self.visual_embed(frames)
        batch, seq_len, hidden = text_embeds.shape
        video_tokens = video_tokens.view(batch, -1, hidden)
        v_mask = input_ids == self.config.v_placeholder_id
        print(f"[joint_embed] video token count per sample: {v_mask.sum(dim=1)}")
        if not torch.all(v_mask.sum(dim=1) == video_tokens.size(1)):
            raise ValueError("Number of video placeholders must match provided frame tokens")
        expanded = text_embeds.clone()
        expanded[v_mask] = video_tokens.view(-1, hidden)
        print(f"[joint_embed] expanded embeddings shape: {expanded.shape}")
        return expanded

    def forward(self, input_ids: torch.Tensor, frames: torch.Tensor | None = None, labels: torch.Tensor | None = None):
        print("\n========== Forward Pass ==========")
        print(f"input_ids shape: {input_ids.shape}")
        if frames is not None:
            print(f"frames shape: {frames.shape}")
        embeddings = self.joint_embed(input_ids, frames)
        logits = self.backbone(embeddings)

        loss = None
        if labels is not None:
            print(f"labels shape: {labels.shape}")
            flat_logits = logits.view(-1, self.config.vocab_size)
            flat_labels = labels.view(-1)
            valid_mask = flat_labels != -100
            per_token_loss = nn.functional.cross_entropy(
                flat_logits, flat_labels, ignore_index=-100, reduction='none'
            )
            weight = torch.ones_like(per_token_loss)
            placeholder_mask = (input_ids.view(-1) == self.config.v_placeholder_id) & valid_mask
            weight[placeholder_mask] = self.config.stream_loss_weight
            weighted_loss = (per_token_loss * weight * valid_mask.float()).sum() / valid_mask.sum().clamp_min(1)
            print(f"[loss] weighted loss: {weighted_loss.item():.4f}")
            print(f"[loss] placeholder_mask sum: {placeholder_mask.sum().item()}")
            loss = weighted_loss

        return {"logits": logits, "loss": loss}
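
The loss in `forward` is a weighted sum of per-token cross-entropy divided by the number of non-ignored labels. Each term is multiplied by `stream_loss_weight` when the *input* token at that position is the video placeholder, and by 1 otherwise. The denominator is a token count, not the sum of the weights, so raising `stream_loss_weight` increases the total loss rather than only rebalancing it. The sketch below reproduces that computation on random logits. It also checks that, with a weight of 1, the result reduces to PyTorch's default mean cross-entropy. `weighted_stream_loss` is a hypothetical helper name.

```python
import torch
from torch import nn


def weighted_stream_loss(logits, labels, input_ids, placeholder_id, stream_loss_weight):
    """Hypothetical helper reproducing the weighting used in DummyLiveLlama.forward."""
    vocab_size = logits.size(-1)
    flat_labels = labels.reshape(-1)
    valid = flat_labels != -100
    per_token = nn.functional.cross_entropy(
        logits.reshape(-1, vocab_size), flat_labels, ignore_index=-100, reduction="none"
    )  # ignored positions contribute 0
    weight = torch.ones_like(per_token)
    weight[(input_ids.reshape(-1) == placeholder_id) & valid] = stream_loss_weight
    return (per_token * weight).sum() / valid.sum().clamp_min(1)


torch.manual_seed(0)
logits = torch.randn(2, 5, 64)
input_ids = torch.tensor([[3, 63, 8, 63, 9], [63, 4, 5, 6, 63]])
labels = input_ids.roll(-1, dims=1)  # next-token targets
labels[:, -1] = -100                 # last position has no successor

plain = nn.functional.cross_entropy(logits.reshape(-1, 64), labels.reshape(-1), ignore_index=-100)
w1 = weighted_stream_loss(logits, labels, input_ids, 63, 1.0)
w3 = weighted_stream_loss(logits, labels, input_ids, 63, 3.0)
print(torch.allclose(plain, w1), bool(w3 > w1))
```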


## 4. Dummy dataset and dataloader

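We now create synthetic samples that contain both text token IDs and slots for video frames.
Each sample is a sequence of 8 random text token IDs in which two randomly chosen positions are replaced by the `<v>` placeholder (for example `[text, <v>, text, ..., <v>, text]`). Each sample is paired with next-token labels and two frame feature vectors.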
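
The labels are next-token targets built with `roll(-1)`. Position *t* is trained to predict token *t + 1*. The final position has no successor, so it is masked with `-100`, which `cross_entropy(..., ignore_index=-100)` skips. The toy sketch below uses hand-picked IDs.

```python
import torch

input_ids = torch.tensor([5, 63, 63, 7])  # 63 = video placeholder in the default config
labels = input_ids.roll(-1)               # shift left; the first token wraps to the end
labels[-1] = -100                         # discard the wrapped-around target
print(labels.tolist())                    # [63, 63, 7, -100]
```

Because of the shift, the position just before each placeholder is trained to predict the placeholder ID itself.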

In [None]:
class DummyVideoTextDataset(Dataset):
    def __init__(self, config: TinyLiveConfig, num_samples: int = 4, seq_len: int = 8, num_frames: int = 2):
        self.config = config
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.num_frames = num_frames

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Random text IDs in [3, vocab_size - 3], so they never collide with the placeholder ID
        input_ids = torch.randint(3, self.config.vocab_size - 2, (self.seq_len,), dtype=torch.long)
        placeholder_positions = torch.randperm(self.seq_len)[: self.num_frames]
        input_ids[placeholder_positions] = self.config.v_placeholder_id

        labels = input_ids.roll(-1)
        labels[-1] = -100

        frames = torch.randn(self.num_frames, self.config.vision_hidden_size)
        sample = {
            "input_ids": input_ids,
            "frames": frames,
            "labels": labels,
        }
        print(f"[Dataset] Sample {idx}: placeholder positions {placeholder_positions.tolist()}")
        print(f"[Dataset] input_ids: {input_ids.tolist()}")
        print(f"[Dataset] labels: {labels.tolist()}")
        print(f"[Dataset] frames shape: {frames.shape}")
        return sample

def collate_batch(examples):
    input_ids = torch.stack([ex["input_ids"] for ex in examples])
    frames = torch.stack([ex["frames"] for ex in examples])
    labels = torch.stack([ex["labels"] for ex in examples])
    print(f"[collate] input_ids batch shape: {input_ids.shape}")
    print(f"[collate] frames batch shape: {frames.shape}")
    print(f"[collate] labels batch shape: {labels.shape}")
    return {"input_ids": input_ids, "frames": frames, "labels": labels}
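
Every sample has identical tensor shapes, so PyTorch's built-in `default_collate` would stack the dictionary fields the same way. `collate_batch` appears to be written out mainly so that its shapes can be printed. The sketch below compares the two on synthetic samples. The sample contents are made up, and the shapes match the default config.

```python
import torch
from torch.utils.data import default_collate

examples = [
    {"input_ids": torch.randint(0, 64, (8,)),
     "frames": torch.randn(2, 24),
     "labels": torch.randint(0, 64, (8,))}
    for _ in range(2)
]

batch = default_collate(examples)  # dict of stacked tensors
manual = {key: torch.stack([ex[key] for ex in examples]) for key in examples[0]}

for key in batch:
    print(key, tuple(batch[key].shape), torch.equal(batch[key], manual[key]))
```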


## 5. Instantiate model and dataloader

In [None]:
config = TinyLiveConfig()
model = DummyLiveLlama(config)

dataset = DummyVideoTextDataset(config)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_batch, shuffle=False)


## 6. Single forward pass

We run one batch through the model to observe how the embeddings and loss behave.

In [None]:
batch = next(iter(dataloader))
outputs = model(**batch)
print(f"Loss from single batch: {outputs['loss'].item():.4f}")


## 7. Simple training loop

Finally, we perform a short training loop over the dummy data.
Gradient norms and loss values are printed each step to make the optimization dynamics explicit.
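
The training loop computes the global L2 norm over all parameter gradients: the square root of the sum of squared per-tensor norms. `torch.nn.utils.clip_grad_norm_` returns the same quantity. Passing `max_norm=float("inf")` turns it into a pure measurement, because the clipping coefficient is then clamped to 1. The sketch below uses a throwaway linear layer, and the layer and data are illustrative only.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)
loss = layer(torch.randn(5, 4)).pow(2).mean()
loss.backward()

# Manual global norm, mirroring the training loop.
manual = math.sqrt(sum(p.grad.detach().norm(2).item() ** 2
                       for p in layer.parameters() if p.grad is not None))

# Built-in equivalent; max_norm=inf measures without rescaling the gradients.
builtin = torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=float("inf")).item()
print(f"manual={manual:.6f} builtin={builtin:.6f}")
```

A finite `max_norm` would additionally clip the gradients, which the loop in this notebook does not do.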

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)
num_epochs = 2

for epoch in range(num_epochs):
    print(f"\n===== Epoch {epoch + 1}/{num_epochs} =====")
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs["loss"]
        loss.backward()
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = math.sqrt(total_norm)
        print(f"[train] step={step} loss={loss.item():.4f} grad_norm={total_norm:.4f}")
        optimizer.step()
