
# Live LLaMA Architecture Sandbox

This notebook instantiates the exact `LiveLlamaForCausalLM` architecture from the repository and exercises it with dummy multimodal batches. The goal is to trace how text tokens, frame placeholders, and dense frame descriptors flow through the connector, the joint embedding, and the causal LM head without relying on any auxiliary wrappers.



## 1. Environment setup

We import the production modules directly from the repository and seed both Python's `random` module and PyTorch for reproducibility. The notebook prints liberally throughout so intermediate tensors can be inspected.


In [None]:

import math
import random
from dataclasses import dataclass

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from models.live_llama import LiveLlamaConfig, LiveLlamaForCausalLM

SEED = 17

# Seed PyTorch's and Python's global RNGs so later random draws (dummy data,
# weight initialisation) are the same on every run.
torch.manual_seed(SEED)
random.seed(SEED)

print(
    f"Torch version: {torch.__version__}\n"
    f"Running on device: {torch.device('cpu')}"
)



## 2. Build a tiny but faithful Live LLaMA

We reuse the production configuration class and only shrink hidden sizes / vocabulary to keep the demo lightweight.  All architectural pieces (token embeddings, rotary attention, connector MLP, etc.) remain exactly the same as the real model.


In [None]:

# The config mirrors the real Live LLaMA setup while shrinking dimensions for CPU-friendly experimentation.
# Each frame is represented by an optional CLS token plus a pooled spatial grid. The number of
# placeholder tokens per frame is derived from those two knobs so the three settings cannot drift apart.
FRAME_TOKEN_CLS = True
FRAME_TOKEN_POOLED = [2, 2]
FRAME_NUM_TOKENS = int(FRAME_TOKEN_CLS) + math.prod(FRAME_TOKEN_POOLED)  # 1 cls + 2x2 spatial => 5

config = LiveLlamaConfig(
    vocab_size=160,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    max_position_embeddings=512,
    bos_token_id=1,
    eos_token_id=2,
    pad_token_id=0,
    # Live-specific knobs
    vision_hidden_size=96,
    frame_token_cls=FRAME_TOKEN_CLS,
    frame_token_pooled=list(FRAME_TOKEN_POOLED),
    frame_num_tokens=FRAME_NUM_TOKENS,
    # The two special ids occupy the top of the 160-entry vocabulary.
    v_placeholder_id=158,
    frame_token_interval_id=159,
    stream_loss_weight=3.0,
)
model = LiveLlamaForCausalLM(config)
print(model)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")



## 3. Dummy dataset that matches the production contract

Each sample returns:

* `input_ids`: token ids containing regular text, interval separators, and visual placeholders.
* `frames`: dense frame descriptors with shape `(num_frames, frame_num_tokens, vision_hidden_size)`.
* `labels`: targets for causal LM training where system / user prompt tokens are masked out with `-100`.

The logic mirrors how the training pipeline expects data to be structured.


In [None]:

@dataclass
class DummyConversation:
    """Recipe for one synthetic sample consumed by ``LiveLikeDataset``.

    Raises:
        ValueError: if ``text_span`` is empty or ``num_frames`` is negative.
    """

    # Pool of token ids that prompt and response tokens are drawn from. It must be
    # non-empty because the dataset always draws response tokens from it.
    text_span: list[int]
    # Number of video frames. Each one becomes an interval token plus placeholder tokens.
    num_frames: int

    def __post_init__(self):
        # Fail at construction time with a clear message rather than later inside
        # the dataset's random sampling or tensor allocation.
        if not self.text_span:
            raise ValueError("text_span must contain at least one token id")
        if self.num_frames < 0:
            raise ValueError(f"num_frames must be non-negative, got {self.num_frames}")


class LiveLikeDataset(Dataset):
    """Synthetic dataset emitting Live-LLaMA-shaped samples.

    Each item is a dict with ``input_ids`` (1-D long tensor), ``labels`` (same shape,
    prompt positions set to ``-100``) and ``frames`` of shape
    ``(num_frames, config.frame_num_tokens, config.vision_hidden_size)``.

    Sampling is deterministic per index: ``dataset[i]`` returns the same tensors every
    time it is called, so samples printed for inspection match what a DataLoader yields.
    """

    def __init__(self, *, conversations, max_prompt_tokens=12, seed=None):
        """
        Args:
            conversations: sequence of ``DummyConversation`` recipes, one per sample.
            max_prompt_tokens: cap on the number of prompt tokens sampled after BOS.
            seed: base seed for the per-sample RNGs. When ``None`` it is drawn once from
                Python's global ``random`` module, so seeding that module beforehand
                still makes the whole dataset reproducible.
        """
        self.conversations = conversations
        self.max_prompt_tokens = max_prompt_tokens
        self.seed = random.randrange(2**32) if seed is None else seed

    def __len__(self):
        return len(self.conversations)

    def _rng_for(self, idx):
        """Return a ``random.Random`` seeded from (base seed, sample index)."""
        return random.Random(self.seed * 1_000_003 + idx)

    def _build_sequence(self, convo: DummyConversation, rng=None):
        """Build the flat token sequence for one conversation.

        Layout: ``[BOS, prompt..., response_prefix (3 tokens)]``, then for every frame
        ``[interval_id, v_placeholder_id * frame_num_tokens, response_token]``, then EOS.

        Args:
            convo: the conversation recipe.
            rng: object providing ``sample``/``choice``. Defaults to the global ``random``
                module.

        Returns:
            ``(sequence, prompt_tokens)`` where ``prompt_tokens`` is the BOS + prompt prefix
            that gets masked out of the labels.
        """
        rng = random if rng is None else rng
        # System prompt + user prompt tokens (masked in the loss)
        prompt_tokens = [config.bos_token_id]
        prompt_tokens += rng.sample(convo.text_span, k=min(len(convo.text_span), self.max_prompt_tokens))
        # Assistant response prefix before frames arrive (will be trained)
        response_prefix = [rng.choice(convo.text_span) for _ in range(3)]

        sequence = prompt_tokens + response_prefix
        frame_placeholder_slice = [config.v_placeholder_id] * config.frame_num_tokens
        for _ in range(convo.num_frames):
            sequence.append(config.frame_token_interval_id)
            sequence.extend(frame_placeholder_slice)
            # Exactly one response token always follows each frame chunk
            sequence.append(rng.choice(convo.text_span))
        sequence.append(config.eos_token_id)
        return sequence, prompt_tokens

    def __getitem__(self, idx):
        convo = self.conversations[idx]  # raises IndexError for out-of-range idx
        # Normalise negative indices so dataset[-1] and dataset[len - 1] share one RNG stream.
        rng = self._rng_for(idx % len(self.conversations))
        token_sequence, prompt_tokens = self._build_sequence(convo, rng)
        input_ids = torch.tensor(token_sequence, dtype=torch.long)
        labels = input_ids.clone()
        # Mask out prompt tokens so the model only learns on assistant responses + frame slots.
        # NOTE(review): labels are an unshifted copy of input_ids. Confirm whether
        # LiveLlamaForCausalLM shifts labels internally (HF convention) or expects
        # pre-shifted targets.
        labels[: len(prompt_tokens)] = -100
        # Dedicated generator so the frame features are reproducible per index as well.
        generator = torch.Generator().manual_seed(rng.randrange(2**63))
        frames = torch.randn(
            convo.num_frames,
            config.frame_num_tokens,
            config.vision_hidden_size,
            generator=generator,
        )
        return {
            "input_ids": input_ids,
            "labels": labels,
            "frames": frames,
        }


def live_like_collate(batch):
    """Collate ``LiveLikeDataset`` samples into one padded batch.

    Args:
        batch: list of sample dicts with ``input_ids``, ``labels`` and ``frames``.

    Returns:
        dict with
        * ``input_ids``: ``(B, L_max)``, padded at the end with ``config.pad_token_id``;
        * ``labels``: ``(B, L_max)``, padded with ``-100`` (the same value used to mask
          prompt positions);
        * ``frames``: every sample's frames concatenated along dim 0;
        * ``frame_counts``: ``(B,)`` number of frames contributed by each sample, so the
          flat ``frames`` tensor can be attributed back to samples.
    """
    if not batch:
        # pad_sequence raises on an empty list, so the empty batch is handled up front
        # and returns correctly shaped empty tensors.
        empty_ids = torch.zeros(0, 0, dtype=torch.long)
        return {
            "input_ids": empty_ids,
            "labels": empty_ids.clone(),
            "frames": torch.zeros(0, config.frame_num_tokens, config.vision_hidden_size),
            "frame_counts": torch.zeros(0, dtype=torch.long),
        }

    input_ids = nn.utils.rnn.pad_sequence([sample["input_ids"] for sample in batch], batch_first=True, padding_value=config.pad_token_id)
    labels = nn.utils.rnn.pad_sequence([sample["labels"] for sample in batch], batch_first=True, padding_value=-100)
    frames_per_sample = [sample["frames"] for sample in batch]
    # torch.cat also works when individual samples contribute zero frames.
    frames = torch.cat(frames_per_sample, dim=0)
    frame_counts = torch.tensor([item.shape[0] for item in frames_per_sample], dtype=torch.long)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "frames": frames,
        "frame_counts": frame_counts,
    }


# Three conversations that share one token pool and differ only in frame count.
vocab_pool = list(range(10, 150))
conversations = [
    DummyConversation(text_span=vocab_pool, num_frames=frame_count)
    for frame_count in (2, 3, 1)
]

dataset = LiveLikeDataset(conversations=conversations)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=live_like_collate)
print(f"Dataset size: {len(dataset)}")

# Dump every sample so token layout, label masking and frame shapes can be eyeballed.
for idx in range(len(dataset)):
    sample = dataset[idx]
    tokens = sample['input_ids'].tolist()
    targets = sample['labels'].tolist()
    print(
        f"Sample {idx} -> tokens: {tokens}\n"
        f"Sample {idx} -> labels: {targets}\n"
        f"Sample {idx} -> frames shape: {tuple(sample['frames'].shape)}"
    )



## 4. Inspect a batch flowing through the connector and joint embedding

We take the first batch from the loader, push it through `joint_embed`, and show every mask / tensor shape the real training step would touch.


In [None]:

batch = next(iter(dataloader))
print(f"input_ids shape: {tuple(batch['input_ids'].shape)}")
print(f"labels shape: {tuple(batch['labels'].shape)}")
print(f"frames shape: {tuple(batch['frames'].shape)} (total_frames, tokens_per_frame, hidden)")
print(f"frame_counts per sample: {batch['frame_counts'].tolist()}")
print(f"Total frames represented: {int(batch['frame_counts'].sum())}")

v_mask = batch['input_ids'] == config.v_placeholder_id
print(f"Visual placeholder mask (per token): {v_mask}")
num_visual_tokens = int(v_mask.sum())
num_frame_tokens = batch['frames'].shape[0] * config.frame_num_tokens
print(f"Total visual tokens expected from placeholders: {num_visual_tokens}")
print(f"Total tokens supplied by frames tensor: {num_frame_tokens}")
# The two counts printed above should agree. Presumably joint_embed fills one frame token
# per placeholder (confirm). Failing here gives a readable message instead of an error
# from deep inside the embedding code.
if num_visual_tokens != num_frame_tokens:
    raise ValueError(
        f"placeholder/frame mismatch: {num_visual_tokens} placeholders vs "
        f"{num_frame_tokens} frame tokens"
    )

with torch.no_grad():
    embedded = model.joint_embed(batch['input_ids'], batch['frames'])
print(f"joint_embed output shape: {tuple(embedded.shape)}")
# Norm over the hidden dim, then keep the first 5 positions of the last remaining dim.
# This is the token dim, assuming joint_embed returns (batch, seq, hidden). A plain [:5]
# would slice the batch dimension instead.
print(f"First 5 embedded token norms (per sample): {embedded.norm(dim=-1)[..., :5]}")

reshaped_frames = batch['frames'].view(-1, config.frame_num_tokens, config.vision_hidden_size)
print(f"Connector input view shape (per frame): {tuple(reshaped_frames.shape)}")
print(f"Connector MLP weights summary: {[tuple(layer.weight.shape) for layer in model.connector if hasattr(layer, 'weight')]}")



## 5. Verbose training loop

The loop runs at most two optimisation steps and, for each one, prints:

* The loss and the number of learnable text vs. stream (visual placeholder) tokens
* The causal LM logits tensor shape
* Gradient norms for the connector and the language-model head


In [None]:

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
MAX_STEPS = 2  # the demo only needs a couple of optimisation steps

for step, batch in enumerate(dataloader):
    # frame_counts is bookkeeping for this printout only, so it is removed before
    # the batch is splatted into model(...).
    frame_counts = batch.pop('frame_counts')
    total_frame_tokens = int(frame_counts.sum().item() * config.frame_num_tokens) if frame_counts.numel() else 0
    print(f"\n=== Step {step} ===")
    print(f"Frame counts in batch: {frame_counts.tolist()} -> {total_frame_tokens} frame placeholders")

    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    logits = outputs.logits

    # Positions that contribute to the loss (labels != -100), split into visual
    # placeholder ("stream") positions and ordinary text positions.
    v_mask = (batch['input_ids'] == config.v_placeholder_id)
    learn_mask = batch['labels'] != -100
    stream_mask = v_mask & learn_mask
    text_mask = learn_mask & ~v_mask

    print(f"Loss: {loss.item():.4f}")
    print(f"Logits shape: {tuple(logits.shape)} (batch, seq, vocab)")
    print(f"Num learnable tokens -> text: {int(text_mask.sum())}, stream: {int(stream_mask.sum())}")
    if stream_mask.any():
        print(f"Example stream token logits (first stream token): {logits[stream_mask][0][:5]}")

    loss.backward()
    connector_grad_norms = [p.grad.norm() for p in model.connector.parameters() if p.grad is not None]
    # torch.stack raises on an empty list, which happens whenever no connector
    # parameter received a gradient in this step.
    if connector_grad_norms:
        connector_grad_norm = torch.norm(torch.stack(connector_grad_norms))
        print(f"Connector grad norm: {connector_grad_norm:.4f}")
    else:
        print("Connector grad norm: n/a (no connector gradients this step)")
    lm_head_grad_norm = model.lm_head.weight.grad.norm()
    print(f"LM head grad norm: {lm_head_grad_norm:.4f}")

    optimizer.step()

    if step + 1 >= MAX_STEPS:
        break
