# Video Tokenization Using [NVIDIA Cosmos Tokenizer](https://github.com/NVIDIA/Cosmos-Tokenizer) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nvidia/Cosmos-Tokenizer/blob/main/notebook/Video_Tokenization.ipynb)

The Jupyter Notebook example utilizes the **Cosmos-Tokenizer** pretrained models, which include Continuous Video (CV) tokenizers that transform videos into continuous spatio-temporal latents and Discrete Video (DI) tokenizers that transform videos into discrete tokens. Both CV and DV tokenizers are available with compression rates of (`TxHxW` format) 4x8x8 and 8x8x8, and 8x16x16. For instance, **CV4x8x8** effectively downsizes the number of frames by a factor of 4 and both height and width by a factor of 8.

Within the notebook, the `VideoTokenizer` class from the `cosmos_tokenizer.video_lib` module is employed to manage the encoder and decoder components of this model. The encoder compresses the input video into a condensed latent representation or discrete integers, while the decoder reconstructs the video from this latent representation or discrete integers.

This instance of the Cosmos Tokenizer demonstrates its autoencoding capability: compressing a video into a smaller latent space and subsequently reconstructing it to its original form. This showcases the efficiency of video tokenization for tasks involving significant spatial compression during video reconstruction, a highly desirable feature for generative modeling.


This tutorial follows a simple, step-by-step approach, making it easy to understand and adapt.

## Step 1: Clone the Cosmos Tokenizer Repository

In [138]:
!git clone https://github.com/NVIDIA/Cosmos-Tokenizer.git

Cloning into 'Cosmos-Tokenizer'...
remote: Enumerating objects: 197, done.[K
remote: Counting objects: 100% (76/76), done.[K
remote: Compressing objects: 100% (59/59), done.[K
remote: Total 197 (delta 37), reused 33 (delta 17), pack-reused 121 (from 1)[K
Receiving objects: 100% (197/197), 11.37 MiB | 6.33 MiB/s, done.
Resolving deltas: 100% (99/99), done.


## Step 2: Install **Cosmos-Tokenizer**
Before proceeding, ensure you have the **Cosmos Tokenizer** installed. If you cloned the repository in Step 1, use the following command to install it in editable mode:

In [139]:
# Step 2: # Install Cosmos-Tokenizer and its Python dependencies.
import os
if os.path.exists("Cosmos-Tokenizer"):
    os.chdir("Cosmos-Tokenizer")
    !apt-get update
    !apt-get install -y git-lfs
    !git lfs pull
    %pip install -e .
else:
    print('Cosmos-Tokenizer is already installed.')

Reading package lists... Done
E: Could not open lock file /var/lib/apt/lists/lock - open (13: Permission denied)
E: Unable to lock directory /var/lib/apt/lists/
W: Problem unlinking the file /var/cache/apt/pkgcache.bin - RemoveCaches (13: Permission denied)
W: Problem unlinking the file /var/cache/apt/srcpkgcache.bin - RemoveCaches (13: Permission denied)
E: Could not open lock file /var/lib/dpkg/lock-frontend - open (13: Permission denied)
E: Unable to acquire the dpkg frontend lock (/var/lib/dpkg/lock-frontend), are you root?
Obtaining file:///home/jason/Desktop/chrono-world-model/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer
  Installing build dependencies ... [?25ldone
[?25h  Checking if build backend supports build_editable ... [?25ldone
[?25h  Getting requirements to build editable ... [?25ldone
[?25h  Preparing editab

## Step 3: Set Up Hugging Face API Token and Download Pretrained Models

In this step, you'll configure the Hugging Face API token and download the pretrained model weights required for the **Cosmos Tokenizer**.

1. **Ensure You Have a Hugging Face Account**  
   If you do not already have a Hugging Face account, follow these steps to create one and generate an API token:
   - Go to the [Hugging Face website](https://huggingface.co/) and sign up for a free account.
   - After logging in, navigate to your [Settings → Access Tokens](https://huggingface.co/settings/tokens).
   - Click on "New Token" to generate an API token with the required permissions.

2. **Set the Hugging Face Token**  
   Check if the Hugging Face token is already set in the environment variables. If not, you will be prompted to enter it manually. The token is essential to authenticate and access the Hugging Face models.



In [140]:
# Check if the token is already set
if "HUGGINGFACE_TOKEN" not in os.environ:
    os.environ["HUGGINGFACE_TOKEN"] = input("Please enter your Hugging Face API token: ")
!git config --global credential.helper store

In [141]:
from huggingface_hub import login, snapshot_download
import os
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
login(token=HUGGINGFACE_TOKEN, add_to_git_credential=True)
model_names = [
        "Cosmos-0.1-Tokenizer-CV4x8x8",
        "Cosmos-0.1-Tokenizer-CV8x8x8",
        "Cosmos-0.1-Tokenizer-CV8x16x16",
        #"Cosmos-0.1-Tokenizer-DV4x8x8",
        #"Cosmos-0.1-Tokenizer-DV8x8x8",
        #"Cosmos-0.1-Tokenizer-DV8x16x16",
]
for model_name in model_names:
    hf_repo = "nvidia/" + model_name
    local_dir = "pretrained_ckpts/" + model_name
    os.makedirs(local_dir, exist_ok=True)
    print(f"downloading {model_name}...")
    snapshot_download(repo_id=hf_repo, local_dir=local_dir)

downloading Cosmos-0.1-Tokenizer-CV4x8x8...


Fetching 8 files: 100%|██████████| 8/8 [00:01<00:00,  4.69it/s]


downloading Cosmos-0.1-Tokenizer-CV8x8x8...


Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00,  7.01it/s]


downloading Cosmos-0.1-Tokenizer-CV8x16x16...


Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00,  7.45it/s]


In [142]:
!pip install opencv-python
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126

Looking in indexes: https://download.pytorch.org/whl/cu126


## Step 4: Use Cosmos Tokenizer for Video Reconstruction



In [143]:
# @title In this step, load the required checkpoints, and perform video reconstruction. {"run":"auto"}
import cv2
import numpy as np
import torch

import importlib
import cosmos_tokenizer.video_lib
import mediapy as media

importlib.reload(cosmos_tokenizer.video_lib)
from cosmos_tokenizer.video_lib import CausalVideoTokenizer

# 1) Specify the model name, and the paths to the encoder/decoder checkpoints.
model_name = 'Cosmos-0.1-Tokenizer-CV4x8x8' # @param ["Cosmos-0.1-Tokenizer-CV4x8x8", "Cosmos-0.1-Tokenizer-CV8x8x8", "Cosmos-0.1-Tokenizer-CV8x16x16", "Cosmos-0.1-Tokenizer-DV4x8x8", "Cosmos-0.1-Tokenizer-DV8x8x8", "Cosmos-0.1-Tokenizer-DV8x16x16", "Cosmos-1.0-Tokenizer-CV8x8x8", "Cosmos-1.0-Tokenizer-DV8x16x16"]
temporal_window = 49 # @param {type:"slider", min:1, max:121, step:8}

encoder_ckpt = f"pretrained_ckpts/{model_name}/encoder.jit"
decoder_ckpt = f"pretrained_ckpts/{model_name}/decoder.jit"

# 2) Load or provide the video filename you want to tokenize & reconstruct.
input_filepath = "/home/jason/Desktop/chrono-world-model/ezgif-resize.mp4"

# 3) Read the video from disk (shape = T x H x W x 3 in BGR).
input_video = media.read_video(input_filepath)[..., :3]
assert input_video.ndim == 4 and input_video.shape[-1] == 3, "Frames must have shape T x H x W x 3"

# 4) Expand dimensions to B x T x H x W x C, since the CausalVideoTokenizer expects a batch dimension
#    in the input. (Batch size = 1 in this example.)
batched_input_video = np.expand_dims(input_video, axis=0)
# 4) Expand dimensions to B x T x H x W x C, since the CausalVideoTokenizer expects a batch dimension
#    in the input. (Batch size = 1 in this example.)
batched_input_video = np.expand_dims(input_video, axis=0)

# Convert to tensor and rearrange dimensions from [B, T, H, W, C] to [B, C, T, H, W]
batched_input_tensor = torch.from_numpy(np.array(batched_input_video)).float()
batched_input_tensor = batched_input_tensor / 255.0  # Normalize to [0, 1]

# Rearrange dimensions: [B, T, H, W, C] -> [B, C, T, H, W]
batched_input_tensor = batched_input_tensor.permute(0, 4, 1, 2, 3)

# Move to GPU and convert to bfloat16
batched_input_tensor = batched_input_tensor.to(device="cuda", dtype=torch.bfloat16)

print(f"Input tensor shape: {batched_input_tensor.shape}")  # Should be [1, 3, T, H, W]
print(f"Input tensor dtype: {batched_input_tensor.dtype}")

# 5) Create the CausalVideoTokenizer instance with the encoder & decoder.
tokenizer = CausalVideoTokenizer(
    checkpoint_enc=encoder_ckpt,
    checkpoint_dec=decoder_ckpt,
    device="cuda",
    dtype="bfloat16",
)



# 6) Use the tokenizer to autoencode (encode & decode) the video.
#    The output is a NumPy array with shape = B x T x H x W x C, range [0..255].
batched_output_video = tokenizer(batched_input_video,
                                 temporal_window=temporal_window)

# 7) Extract the single video from the batch (index 0).
output_video = batched_output_video[0]

# 9) Save the reconstructed video to disk.
input_dir, input_filename = os.path.split(input_filepath)
filename, ext = os.path.splitext(input_filename)
output_filepath = f"{input_dir}/{filename}_{model_name.split('-')[-1]}{ext}"
media.write_video(output_filepath, output_video)
print("Input video read from:\t", f"{os.getcwd()}/{input_filepath}")
print("Reconstruction saved:\t", f"{os.getcwd()}/{output_filepath}")

# 10) Visualization of the input video (left) and the reconstruction (right).
media.show_videos([input_video, output_video], ["Input Video", "Reconstructed Video"], height=720)

Input tensor shape: torch.Size([1, 3, 304, 480, 640])
Input tensor dtype: torch.bfloat16


100%|██████████| 7/7 [00:08<00:00,  1.21s/it]


Input video read from:	 /home/jason/Desktop/chrono-world-model/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer//home/jason/Desktop/chrono-world-model/ezgif-resize.mp4
Reconstruction saved:	 /home/jason/Desktop/chrono-world-model/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer/Cosmos-Tokenizer//home/jason/Desktop/chrono-world-model/ezgif-resize_CV4x8x8.mp4


0,1
Input Video  This browser does not support the video tag.,Reconstructed Video  This browser does not support the video tag.


In [144]:
# Now use the properly formatted tensor for encoding
(latent, ) = tokenizer.encode(batched_input_tensor)
print(f"Latent shape: {latent.shape}")

# 7) Save the latents for later use
torch.save(latent, f"{input_dir}/{filename}_latents_{model_name.split('-')[-1]}.pt")
print(f"Latents saved to: {input_dir}/{filename}_latents_{model_name.split('-')[-1]}.pt")


Latent shape: torch.Size([1, 16, 77, 60, 80])
Latents saved to: /home/jason/Desktop/chrono-world-model/ezgif-resize_latents_CV4x8x8.pt


In [145]:
# Load the saved latents instead of using the current latent variable
latent_filepath = f"{input_dir}/{filename}_latents_{model_name.split('-')[-1]}.pt"
print(f"Loading latents from: {latent_filepath}")
loaded_latent = torch.load(latent_filepath)

print(f"Loaded latent shape: {loaded_latent.shape}")
print(f"Loaded latent dtype: {loaded_latent.dtype}")

# Ensure the loaded latent is on the correct device
loaded_latent = loaded_latent.to(device="cuda", dtype=torch.bfloat16)

with torch.no_grad():
    reconstructed_tensor = tokenizer.decode(loaded_latent)

print(f"Original tensor shape: {batched_input_tensor.shape}")
print(f"Reconstructed tensor shape: {reconstructed_tensor.shape}")

# Handle shape mismatch by trimming the reconstructed tensor if needed
if reconstructed_tensor.shape != batched_input_tensor.shape:
    print(f"⚠️  Shape mismatch detected, trimming reconstructed tensor...")
    # Trim to match original shape (usually just the temporal dimension)
    min_frames = min(reconstructed_tensor.shape[2], batched_input_tensor.shape[2])
    reconstructed_tensor = reconstructed_tensor[:, :, :min_frames, :, :]
    batched_input_tensor_trimmed = batched_input_tensor[:, :, :min_frames, :, :]
    print(f"Trimmed shapes - Original: {batched_input_tensor_trimmed.shape}, Reconstructed: {reconstructed_tensor.shape}")
else:
    batched_input_tensor_trimmed = batched_input_tensor

print("✓ Reconstruction completed!")

# 9) Convert reconstructed tensor back to numpy video format for visualization
# Convert from [B, C, T, H, W] to [B, T, H, W, C] and denormalize
reconstructed_numpy = reconstructed_tensor.permute(0, 2, 3, 4, 1).cpu().float().numpy()
reconstructed_numpy = (reconstructed_numpy * 255.0).clip(0, 255).astype(np.uint8)

# Extract single video from batch
reconstructed_video = reconstructed_numpy[0]  # Shape: [T, H, W, C]

# Also convert original for comparison (trim if needed)
original_numpy = batched_input_tensor_trimmed.permute(0, 2, 3, 4, 1).cpu().float().numpy()
original_numpy = (original_numpy * 255.0).clip(0, 255).astype(np.uint8)
original_video_trimmed = original_numpy[0]

print(f"Original video shape: {original_video_trimmed.shape}")
print(f"Reconstructed video shape: {reconstructed_video.shape}")

# 10) Save reconstructed video
reconstructed_filepath = f"{input_dir}/{filename}_reconstructed_{model_name.split('-')[-1]}{ext}"
media.write_video(reconstructed_filepath, reconstructed_video)
print(f"Reconstructed video saved to: {reconstructed_filepath}")

# 11) Display side-by-side comparison
print("Displaying original vs reconstructed video...")
media.show_videos(
    [original_video_trimmed, reconstructed_video], 
    ["Original Video", "Reconstructed Video"], 
    height=400
)


Loading latents from: /home/jason/Desktop/chrono-world-model/ezgif-resize_latents_CV4x8x8.pt
Loaded latent shape: torch.Size([1, 16, 77, 60, 80])
Loaded latent dtype: torch.bfloat16
Original tensor shape: torch.Size([1, 3, 304, 480, 640])
Reconstructed tensor shape: torch.Size([1, 3, 305, 480, 640])
⚠️  Shape mismatch detected, trimming reconstructed tensor...
Trimmed shapes - Original: torch.Size([1, 3, 304, 480, 640]), Reconstructed: torch.Size([1, 3, 304, 480, 640])
✓ Reconstruction completed!
Original video shape: (304, 480, 640, 3)
Reconstructed video shape: (304, 480, 640, 3)
Reconstructed video saved to: /home/jason/Desktop/chrono-world-model/ezgif-resize_reconstructed_CV4x8x8.mp4
Displaying original vs reconstructed video...


0,1
Original Video  This browser does not support the video tag.,Reconstructed Video  This browser does not support the video tag.


In [146]:
# 12) Now let's prepare for transformer training on the latent space
print("\n" + "="*50)
print("TRANSFORMER TRAINING PREPARATION")
print("="*50)

# Analyze the latent space for transformer training
print(f"Latent tensor shape: {loaded_latent.shape}")  # [B, C, T, H, W] = [1, 16, 39, 30, 40]
print(f"Compression ratio - Temporal: {batched_input_tensor.shape[2] / loaded_latent.shape[2]:.1f}x")
print(f"Compression ratio - Spatial: {batched_input_tensor.shape[3] / loaded_latent.shape[3]:.1f}x")

# Reshape latent for sequence modeling: [B, C, T, H, W] -> [B, T, C*H*W]
B, C, T, H, W = loaded_latent.shape
latent_flattened = loaded_latent.permute(0, 2, 1, 3, 4).reshape(B, T, C*H*W)
print(f"Flattened latent for transformer: {latent_flattened.shape}")  # [1, 39, 19200]

# Define sequence parameters for autoregressive training
context_length = 8  # Number of past frames to condition on
prediction_length = 4  # Number of future frames to predict
sequence_dim = C * H * W  # 16 * 30 * 40 = 19200

print(f"Context length: {context_length} frames")
print(f"Prediction length: {prediction_length} frames")
print(f"Sequence dimension: {sequence_dim}")

# Create training sequences (context -> target)
if T >= context_length + prediction_length:
    num_sequences = T - context_length - prediction_length + 1
    print(f"Can create {num_sequences} training sequences from this video")
    
    # Example: create one training sample
    start_idx = 0
    context_frames = latent_flattened[:, start_idx:start_idx+context_length, :]  # [1, 8, 19200]
    target_frames = latent_flattened[:, start_idx+context_length:start_idx+context_length+prediction_length, :]  # [1, 4, 19200]
    
    print(f"Context frames shape: {context_frames.shape}")
    print(f"Target frames shape: {target_frames.shape}")
    
    # Save training data
    training_data = {
        'context': context_frames.cpu(),
        'target': target_frames.cpu(),
        'latent_shape': (C, H, W),
        'original_shape': batched_input_tensor.shape[2:],  # (T, H, W)
    }
    
    torch.save(training_data, f"{input_dir}/{filename}_training_data_{model_name.split('-')[-1]}.pt")
    print(f"Training data saved to: {input_dir}/{filename}_training_data_{model_name.split('-')[-1]}.pt")
else:
    print(f"⚠️  Video too short for training sequences. Need at least {context_length + prediction_length} frames.")

print("\n✓ Ready for transformer training!")
print("Next steps:")
print("1. Collect more videos and extract their latents")
print("2. Create a dataset class for loading training sequences")
print("3. Build a transformer model for autoregressive prediction")
print("4. Train the model to predict future latent frames")


TRANSFORMER TRAINING PREPARATION
Latent tensor shape: torch.Size([1, 16, 77, 60, 80])
Compression ratio - Temporal: 3.9x
Compression ratio - Spatial: 8.0x
Flattened latent for transformer: torch.Size([1, 77, 76800])
Context length: 8 frames
Prediction length: 4 frames
Sequence dimension: 76800
Can create 66 training sequences from this video
Context frames shape: torch.Size([1, 8, 76800])
Target frames shape: torch.Size([1, 4, 76800])
Training data saved to: /home/jason/Desktop/chrono-world-model/ezgif-resize_training_data_CV4x8x8.pt

✓ Ready for transformer training!
Next steps:
1. Collect more videos and extract their latents
2. Create a dataset class for loading training sequences
3. Build a transformer model for autoregressive prediction
4. Train the model to predict future latent frames


In [147]:
# 13) Build and train a simple transformer for autoregressive video prediction
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import math

# Simple Transformer Model for Video Prediction
class VideoTransformer(nn.Module):
    def __init__(self, d_model, nhead, num_layers, sequence_dim, max_seq_length=512):
        super(VideoTransformer, self).__init__()
        self.d_model = d_model
        self.sequence_dim = sequence_dim
        
        # Input projection
        self.input_projection = nn.Linear(sequence_dim, d_model)
        
        # Positional encoding
        self.pos_encoding = nn.Parameter(torch.randn(max_seq_length, d_model))
        
        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=d_model*4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Output projection
        self.output_projection = nn.Linear(d_model, sequence_dim)
        
    def forward(self, x):
        # x shape: [batch_size, seq_len, sequence_dim]
        seq_len = x.size(1)
        
        # Project input to model dimension
        x = self.input_projection(x)  # [batch_size, seq_len, d_model]
        
        # Add positional encoding
        x = x + self.pos_encoding[:seq_len, :].unsqueeze(0)
        
        # Create causal mask
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
        
        # Apply transformer
        x = self.transformer(x, mask=mask)
        
        # Project back to sequence dimension
        x = self.output_projection(x)
        
        return x

# Create model
d_model = 512
nhead = 8
num_layers = 6

model = VideoTransformer(
    d_model=d_model,
    nhead=nhead, 
    num_layers=num_layers,
    sequence_dim=sequence_dim
).cuda()

# Convert model to bfloat16 to match input dtype
model = model.to(dtype=torch.bfloat16)

print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Prepare training data from multiple sequences
all_contexts = []
all_targets = []

# Generate all possible training sequences from the video
for start_idx in range(num_sequences):
    context = latent_flattened[:, start_idx:start_idx+context_length, :]
    target = latent_flattened[:, start_idx+context_length:start_idx+context_length+prediction_length, :]
    all_contexts.append(context)
    all_targets.append(target)

# Combine into tensors
train_contexts = torch.cat(all_contexts, dim=0)  # [num_sequences, context_length, sequence_dim]
train_targets = torch.cat(all_targets, dim=0)    # [num_sequences, prediction_length, sequence_dim]

print(f"Training contexts shape: {train_contexts.shape}")
print(f"Training targets shape: {train_targets.shape}")

# Training setup
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
criterion = nn.MSELoss()
num_epochs = 100

# Training loop
# Training loop - Pure autoregressive, predicting 4 frames per step
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    
    optimizer.zero_grad()
    
    # Create longer sequences for multi-step training
    num_rollout_steps = 10  # 10 steps of 4 frames each = 40 frames total
    horizon_length = num_rollout_steps * prediction_length  # 40 frames
    
    # Generate training sequences with longer horizons
    multi_step_losses = []
    
    for start_idx in range(min(num_sequences, train_contexts.shape[0])):
        if start_idx + horizon_length <= latent_flattened.shape[1] - context_length:
            # Get initial context
            initial_context = latent_flattened[:, start_idx:start_idx+context_length, :].cuda().to(dtype=torch.bfloat16)
            
            # Get full target sequence (40 frames)
            full_targets = latent_flattened[:, start_idx+context_length:start_idx+context_length+horizon_length, :].cuda().to(dtype=torch.bfloat16)
            
            current_sequence = initial_context.clone()
            step_losses = []
            
            # Pure autoregressive unroll - predict 4 frames at a time
            for step in range(num_rollout_steps):
                # Current context (last 8 frames)
                context = current_sequence[:, -context_length:, :]
                
                # Ground truth for this step (next 4 frames)
                gt_start = step * prediction_length
                gt_end = gt_start + prediction_length
                ground_truth = full_targets[:, gt_start:gt_end, :]  # [1, 4, 19200]
                
                # Pure autoregressive: only use context, no teacher forcing
                output = model(context)
                predicted = output[:, -prediction_length:, :]  # Take last 4 predictions [1, 4, 19200]
                
                # Compute loss for these 4 frames
                step_loss = criterion(predicted, ground_truth)
                step_losses.append(step_loss)
                
                # Use predicted frames for next iteration (pure autoregressive)
                current_sequence = torch.cat([current_sequence, predicted.detach()], dim=1)
            
            # Average loss over all steps
            if step_losses:
                sequence_loss = torch.stack(step_losses).mean()
                multi_step_losses.append(sequence_loss)
    
    # Average loss over all sequences
    if multi_step_losses:
        total_loss = torch.stack(multi_step_losses).mean()
        
        # Backward pass
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Pure Autoregressive Loss (4 frames/step): {total_loss.item():.6f}")

print("✓ Pure autoregressive training completed!")

# 14) Autoregressive prediction rollout
model.eval()
print("\n" + "="*50)
print("AUTOREGRESSIVE PREDICTION ROLLOUT")
print("="*50)

# Use first context_length frames as seed
seed_frames = latent_flattened[:, :context_length, :].cuda().to(dtype=torch.bfloat16)  # [1, 8, 19200]
rollout_length = 16  # Generate 16 frames into the future

print(f"Seed frames shape: {seed_frames.shape}")
print(f"Generating {rollout_length} frames autoregressively...")

# Autoregressive generation - predict 4 frames at a time (matching training)
with torch.no_grad():
    generated_sequence = seed_frames.clone()
    
    # Generate in chunks of 4 frames
    num_generation_steps = rollout_length // 4  # 16 frames / 4 = 4 steps
    
    for step in range(num_generation_steps):
        # Take last context_length frames as input
        current_context = generated_sequence[:, -context_length:, :]
        
        # Predict next 4 frames (matching training)
        output = model(current_context)
        next_4_frames = output[:, -prediction_length:, :]  # Last 4 predictions
        
        # Append to sequence
        generated_sequence = torch.cat([generated_sequence, next_4_frames], dim=1)
        
        print(f"Generated 4 frames, step {step+1}/{num_generation_steps}")

print(f"Final generated sequence shape: {generated_sequence.shape}")

# 15) Convert predictions back to video format and visualize
print("\n" + "="*50)
print("CONVERTING PREDICTIONS TO VIDEO")
print("="*50)

# Reshape back to latent format: [B, T, C*H*W] -> [B, C, T, H, W]
generated_latent = generated_sequence.reshape(1, -1, C, H, W).permute(0, 2, 1, 3, 4)
print(f"Generated latent shape: {generated_latent.shape}")

# Decode to video using the tokenizer
with torch.no_grad():
    predicted_video_tensor = tokenizer.decode(generated_latent)

print(f"Predicted video tensor shape: {predicted_video_tensor.shape}")

# Convert to numpy for visualization
predicted_numpy = predicted_video_tensor.permute(0, 2, 3, 4, 1).cpu().float().numpy()
predicted_numpy = (predicted_numpy * 255.0).clip(0, 255).astype(np.uint8)
predicted_video = predicted_numpy[0]

print(f"Predicted video shape: {predicted_video.shape}")

# Split into seed and generated parts
seed_length = context_length
total_frames = predicted_video.shape[0]

# Account for potential padding in decoder output
actual_seed_frames = min(seed_length, total_frames)
actual_generated_frames = total_frames - actual_seed_frames

seed_video = predicted_video[:actual_seed_frames]
generated_video = predicted_video[actual_seed_frames:]

print(f"Seed video frames: {seed_video.shape[0]}")
print(f"Generated video frames: {generated_video.shape[0]}")

# Save videos
seed_filepath = f"{input_dir}/{filename}_seed_{model_name.split('-')[-1]}.mp4"
generated_filepath = f"{input_dir}/{filename}_generated_{model_name.split('-')[-1]}.mp4"
full_prediction_filepath = f"{input_dir}/{filename}_full_prediction_{model_name.split('-')[-1]}.mp4"

media.write_video(seed_filepath, seed_video)
media.write_video(generated_filepath, generated_video)
media.write_video(full_prediction_filepath, predicted_video)

print(f"Seed video saved to: {seed_filepath}")
print(f"Generated video saved to: {generated_filepath}")
print(f"Full prediction saved to: {full_prediction_filepath}")

# Display comparison
print("\nDisplaying seed vs generated video...")
if generated_video.shape[0] > 0:
    media.show_videos(
        [seed_video, generated_video], 
        [f"Seed ({seed_video.shape[0]} frames)", f"Generated ({generated_video.shape[0]} frames)"], 
        height=300
    )
else:
    print("⚠️  No generated frames to display")

print("\n✓ Autoregressive video prediction completed!")
print("The model learned to predict future video frames in the compressed latent space!")

Model parameters: 97,896,960
Training contexts shape: torch.Size([66, 8, 76800])
Training targets shape: torch.Size([66, 4, 76800])
Epoch 0, Pure Autoregressive Loss (4 frames/step): 2.156250
Epoch 20, Pure Autoregressive Loss (4 frames/step): 1.015625
Epoch 40, Pure Autoregressive Loss (4 frames/step): 0.474609
Epoch 60, Pure Autoregressive Loss (4 frames/step): 0.253906
Epoch 80, Pure Autoregressive Loss (4 frames/step): 0.162109
✓ Pure autoregressive training completed!

AUTOREGRESSIVE PREDICTION ROLLOUT
Seed frames shape: torch.Size([1, 8, 76800])
Generating 16 frames autoregressively...
Generated 4 frames, step 1/4
Generated 4 frames, step 2/4
Generated 4 frames, step 3/4
Generated 4 frames, step 4/4
Final generated sequence shape: torch.Size([1, 24, 76800])

CONVERTING PREDICTIONS TO VIDEO
Generated latent shape: torch.Size([1, 16, 24, 60, 80])
Predicted video tensor shape: torch.Size([1, 3, 93, 480, 640])
Predicted video shape: (93, 480, 640, 3)
Seed video frames: 8
Generated vi

0,1
Seed (8 frames)  This browser does not support the video tag.,Generated (85 frames)  This browser does not support the video tag.



✓ Autoregressive video prediction completed!
The model learned to predict future video frames in the compressed latent space!
