<a href="https://colab.research.google.com/github/ossnat/VSD_foundation_model/blob/main/VideoMAE_LoRA_Fine_Tuning_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# A Colab notebook for fine-tuning VideoMAE on a toy video dataset using LoRA.
# This script demonstrates the full pipeline from data loading to model training and evaluation.

# -----------------
# 1. Setup and Installs
# -----------------

# Install the required libraries. This will take a few minutes.
!pip install -q transformers datasets accelerate peft bitsandbytes
!pip install -q decord
!pip install -q torch torchvision torchaudio

# Import necessary libraries
import torch
from datasets import load_dataset
from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from accelerate import Accelerator
from tqdm.notebook import tqdm
import decord
import os
import io

# Set up the Accelerator for distributed training (useful for larger models/datasets)
accelerator = Accelerator()
device = accelerator.device

print(f"Using device: {device}")
print("Setup complete.")

# -----------------
# 2. Data Preparation
# -----------------

# Load a small toy video dataset from the Hugging Face Hub.
# We'll use a subset of the UCF101 dataset to keep the runtime short.
# It's important to use a small sample for this tutorial.
print("Loading toy video dataset...")
dataset = load_dataset("hf-internal-testing/mrl-test-videos-small", split="train")
print("Dataset loaded successfully.")
print(f"Number of videos in the dataset: {len(dataset)}")

# Load the VideoMAE feature extractor and a pre-trained model for pre-training (masked autoencoding)
model_name = "MCG-NJU/videomae-base"
processor = VideoMAEImageProcessor.from_pretrained(model_name)
model = VideoMAEForPreTraining.from_pretrained(model_name)

# Define the number of frames to subsample and the video path key
num_frames = 16
video_path_key = "video_file"

# Custom dataset class to handle video loading and processing
class VideoDataset(Dataset):
    """Map-style dataset that yields processor-ready pixel tensors, one video per item.

    Each item is ``{"pixel_values": tensor}``. The tensor is the processor output
    for ``num_frames`` evenly spaced frames of one video, with the processor's
    leading batch dimension removed.

    Args:
        dataset: Indexable dataset with one row per video (e.g. a Hugging Face
            ``datasets.Dataset``) supporting column access ``dataset[column]``.
        processor: Image processor called as ``processor(list_of_frames,
            return_tensors="pt")`` and returning an object with ``pixel_values``.
        num_frames: Number of frames to sample from each video (must be >= 1).
        video_path_key: Column holding each video's file path (used in error messages).
        video_data_key: Column holding each video's encoded bytes.
    """

    def __init__(self, dataset, processor, num_frames, video_path_key, video_data_key="video_data"):
        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}.")
        self.dataset = dataset
        self.processor = processor
        self.num_frames = num_frames
        self.video_path_key = video_path_key
        self.video_data_key = video_data_key
        self.video_paths = list(dataset[video_path_key])
        # Open one decoder per row up front so __getitem__ does not have to decode
        # containers again. The decoders are indexed by row position, not by file name,
        # so videos that share a basename in different directories can never collide.
        self.video_readers = [
            decord.VideoReader(io.BytesIO(video_bytes))
            for video_bytes in dataset[video_data_key]
        ]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": tensor}`` for the video at row ``idx``.

        Raises:
            ValueError: if the video has no decodable frames.
        """
        video_reader = self.video_readers[idx]

        total_frames = len(video_reader)
        if total_frames == 0:
            raise ValueError(f"Video {self.video_paths[idx]!r} contains no frames.")

        # Select num_frames evenly spaced frame indices (repeats frames for short videos).
        frame_indices = np.linspace(0, total_frames - 1, self.num_frames, dtype=int)
        frames = video_reader.get_batch(frame_indices).asnumpy().astype(np.float32)

        # The processor returns a batch of one video. Drop only that leading axis;
        # a bare squeeze() would also drop the frame axis when num_frames == 1.
        pixel_values = self.processor(list(frames), return_tensors="pt").pixel_values[0]

        # Masking is not done here; the collator adds the patch mask per batch.
        return {"pixel_values": pixel_values}

# Data collator for the VideoMAE masked-autoencoding (pre-training) task.
def data_collator(examples, *, mask_ratio=0.9, patch_size=16, tubelet_size=2):
    """Stack per-video tensors into a batch and draw a random patch mask per video.

    ``VideoMAEForPreTraining`` does not generate its own mask. Its forward pass
    needs a ``bool_masked_pos`` tensor marking which patch tokens are hidden from
    the encoder. Every row masks exactly the same number of tokens, because the
    model drops masked tokens from a batched tensor and therefore needs an equal
    count per video.

    Args:
        examples: Non-empty sequence of dicts, each with a ``"pixel_values"`` tensor
            shaped (num_frames, channels, height, width).
        mask_ratio: Fraction of patch tokens to mask, strictly between 0 and 1.
            0.9 follows the VideoMAE paper's masking ratio.
        patch_size: Spatial patch size; should equal ``model.config.patch_size``
            (16 for MCG-NJU/videomae-base — confirm).
        tubelet_size: Frames per temporal tubelet; should equal
            ``model.config.tubelet_size`` (2 for MCG-NJU/videomae-base — confirm).

    Returns:
        Dict with ``"pixel_values"`` of shape (batch, num_frames, C, H, W) and
        ``"bool_masked_pos"``, a bool tensor of shape (batch, num_tokens) where
        True marks a masked token.

    Raises:
        ValueError: if ``examples`` is empty or ``mask_ratio`` is not in (0, 1).
    """
    if not examples:
        raise ValueError("data_collator received an empty batch.")
    if not 0.0 < mask_ratio < 1.0:
        raise ValueError(f"mask_ratio must be in (0, 1), got {mask_ratio}.")

    # Stack all videos from the batch: (batch, num_frames, C, H, W).
    pixel_values = torch.stack([e["pixel_values"] for e in examples])
    batch_size, num_frames, _, height, width = pixel_values.shape

    # One token per (tubelet x patch x patch) cube of the video.
    seq_length = (num_frames // tubelet_size) * (height // patch_size) * (width // patch_size)
    num_masked = int(mask_ratio * seq_length)

    # Rank i.i.d. noise within each row. Masking the num_masked lowest ranks gives a
    # uniformly random mask with exactly num_masked True entries per row.
    ranks = torch.rand(batch_size, seq_length).argsort(dim=1).argsort(dim=1)
    bool_masked_pos = ranks < num_masked

    return {"pixel_values": pixel_values, "bool_masked_pos": bool_masked_pos}

# Create the dataset and dataloader.
# Shuffle so that each epoch visits the videos in a different order.
train_dataset = VideoDataset(dataset, processor, num_frames, video_path_key)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)

# -----------------
# 3. LoRA Configuration
# -----------------

# Set up PEFT (Parameter-Efficient Fine-Tuning) with LoRA.
# This drastically reduces the number of trainable parameters.
#
# Why not "query"/"value"? VideoMAE's self-attention does not call those
# nn.Linear modules. It reads their `.weight` and applies
# `nn.functional.linear` with separate q/v bias parameters. LoRA wrappers on
# them would never run, no trainable parameter would reach the loss, and
# backward() would fail. The "dense" layers (attention output projection and the
# MLP projections in each block) are called normally, so LoRA on them takes effect.
# NOTE(review): re-check this against the installed transformers version.
peft_config = LoraConfig(
    r=16,  # Rank of the update matrices. A lower rank means fewer trainable parameters.
    lora_alpha=32,  # LoRA scaling factor.
    target_modules=["dense"],  # Modules to apply LoRA to (matched by module-name suffix).
    lora_dropout=0.1,
    bias="none",
)

# prepare_model_for_kbit_training targets quantized (8/4-bit) models. The model
# above is loaded in full precision, so this call is kept only for parity with
# k-bit workflows.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

# Print the number of trainable parameters to show the efficiency of LoRA
model.print_trainable_parameters()

# -----------------
# 4. Training Loop
# -----------------

# We'll use a simple manual training loop here to demonstrate the process clearly.
# For a full-scale project, you would use Hugging Face's Trainer API for more features.

# Define training hyperparameters.
# Only the LoRA parameters require grad; give the optimizer just those.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=5e-5)
num_epochs = 3

# Prepare model, optimizer, and dataloader with the accelerator
model, optimizer, train_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader
)

# Training loop
model.train()
print("Starting training...")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    progress_bar = tqdm(
        train_dataloader,
        desc=f"Epoch {epoch+1}/{num_epochs}",
        disable=not accelerator.is_local_main_process,  # one progress bar per machine
    )
    for batch in progress_bar:
        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        step_loss = loss.item()
        epoch_loss += step_loss
        progress_bar.set_postfix(loss=step_loss)

    avg_loss = epoch_loss / len(train_dataloader)
    print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")

# -----------------
# 5. Visualization and Assessment (Simplified)
# -----------------

print("Training finished. Saving model...")

# save_pretrained on a PEFT model writes only the (small) LoRA adapter weights and
# config, not the full base model. accelerator.prepare may have wrapped the model
# (e.g. DDP), so unwrap it first, and write from a single process only.
peft_model_path = "./videomae_lora_adapters"
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    accelerator.unwrap_model(model).save_pretrained(peft_model_path)
    print(f"LoRA adapters saved to: {peft_model_path}")

# To visualize the results, we would normally use a validation set
# and a downstream task like video classification.
# For this tutorial, we will simply demonstrate a qualitative assessment
# of the model's ability to reconstruct a masked video.

print("\n--- Model Assessment (Qualitative) ---")

# The model's primary objective is to reconstruct masked patches.
# A lower loss during training indicates the model is getting better at this.

# Let's take one batch from the dataset and run it through the model.
model.eval()
sample_batch = next(iter(train_dataloader))
# The prepared dataloader should already place tensors on `device`; this move is a safeguard.
sample_batch = {k: v.to(device) for k, v in sample_batch.items()}

# Forward pass to get the predicted (reconstructed) patches
with torch.no_grad():
    outputs = model(**sample_batch)

# `outputs.logits` holds the decoder's predictions for the *masked* patches only
# (expected shape: batch, num_masked_patches, values_per_patch), not whole frames.
# NOTE(review): if model.config.norm_pix_loss is enabled, these predictions are in
# per-patch normalized space and must be de-normalized before display.
reconstructed_pixels = outputs.logits
print(f"Shape of reconstructed pixels: {reconstructed_pixels.shape}")
print("The reconstructed pixels represent the model's guess for the masked video content.")
print("A well-trained model would produce a plausible reconstruction.")

# To visualize an original input frame, undo the processor's normalization, e.g.:
#
# import matplotlib.pyplot as plt
#
# mean = torch.tensor(processor.image_mean).view(3, 1, 1)
# std = torch.tensor(processor.image_std).view(3, 1, 1)
# frame = sample_batch["pixel_values"][0, 0].cpu() * std + mean  # first frame, (C, H, W)
# plt.imshow(frame.permute(1, 2, 0).clamp(0, 1))
# plt.title("Original Frame")
# plt.show()
#
# Showing the reconstruction also requires scattering `reconstructed_pixels` back
# into the masked token positions given by `sample_batch["bool_masked_pos"]` and
# un-patchifying them into frames.
# NOTE(review): VideoMAEImageProcessor does not appear to provide a helper for this
# (an earlier sketch called a nonexistent `processor.post_process_video_output`).