# SmolVLM2 Training on Google Colab

Train SmolVLM2 256M/500M vision-language models on Google Colab.

**GPU Requirements:**
- T4 (free tier): 256M model with gradient checkpointing
- A100 (Pro/Pro+): Both 256M and 500M models

**Runtime:** Go to `Runtime > Change runtime type` and select GPU.

## 1. Setup Environment

In [None]:
# Check GPU
!nvidia-smi

import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(
        f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB"
    )

In [None]:
# Clone repository
# NOTE: YOUR_USERNAME is a placeholder - replace it with the GitHub account
# that hosts your copy of smolvlm_sandbox before running this cell.
!git clone https://github.com/YOUR_USERNAME/smolvlm_sandbox.git
# Switch into the checkout; the install cell below runs `pip install -e .`
# against the current directory.
%cd smolvlm_sandbox

In [None]:
# Install dependencies
# Quiet, editable install of the package in the current directory.
%pip install -q -e .

# Install flash attention (optional, improves speed)
# stderr is discarded; the `|| echo` is intended to print a notice instead of
# failing the cell when the install does not succeed.
%pip install -q flash-attn --no-build-isolation 2>/dev/null || echo "Flash attention not installed (optional)"

In [None]:
# Mount Google Drive for saving checkpoints
from google.colab import drive

# Drive contents become available under /content/drive.
drive.mount("/content/drive")

# Create checkpoint directory
import os

# CHECKPOINT_DIR is read by later cells: it is the Trainer output directory
# and the place the resume cell scans for "checkpoint-*" folders.
CHECKPOINT_DIR = "/content/drive/MyDrive/smolvlm2_checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # no error if it already exists
print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")

## 2. Configuration

Adjust settings based on your GPU:

In [None]:
# ============================================
# CONFIGURATION - Adjust based on your GPU
# ============================================

import torch

# Model size: "256m" or "500m"
# - T4 (16GB): Use "256m"
# - A100 (40GB): Can use "500m"
MODEL_SIZE = "256m"

# Batch size (reduce if OOM)
# - T4: batch_size=1-2
# - A100: batch_size=4-8
BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 8  # Effective batch = BATCH_SIZE * GRAD_ACCUM

# Training steps (reduce for quick test)
MAX_STEPS = 1000  # Use 50000 for full training
WARMUP_STEPS = 100
SAVE_STEPS = 250
LOGGING_STEPS = 10

# Learning rate
LEARNING_RATE = 1e-4

# Dataset samples (reduce for quick test)
MAX_TRAIN_SAMPLES = 10000  # Set to None for full dataset

# Memory optimizations
USE_GRADIENT_CHECKPOINTING = True
# bf16 mixed precision needs native hardware support, i.e. CUDA compute
# capability >= 8.0 (Ampere or newer, such as the A100). The T4 is compute
# capability 7.5, so bf16 is switched off automatically there and the model
# runs in float32. Replace the expression with True/False to force a value.
USE_BF16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

# Weights & Biases logging (optional)
USE_WANDB = False
WANDB_PROJECT = "smolvlm2-colab"

print(f"Model: SmolVLM2-{MODEL_SIZE.upper()}")
print(
    f"Batch size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION_STEPS} = {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}"
)
print(f"Max steps: {MAX_STEPS}")
print(f"bf16 mixed precision: {USE_BF16}")

## 3. Initialize Model

In [None]:
import torch

# Import from our package
from src.model import get_config, initialize_smolvlm_model

# Get config
config = get_config(MODEL_SIZE)
print(f"Vision encoder: {config.vision_encoder_name}")
print(f"Text decoder: {config.text_decoder_name}")

# Initialize model
print("\nInitializing model...")
model, processor, tokenizer = initialize_smolvlm_model(
    model_size=MODEL_SIZE,
    torch_dtype=torch.bfloat16 if USE_BF16 else torch.float32,
    use_flash_attention=True,  # Falls back to eager if not available
)

# Enable gradient checkpointing
# Prefer the model-level switch; otherwise try the text decoder. The success
# message is printed only when one of them was actually enabled.
if USE_GRADIENT_CHECKPOINTING:
    # getattr with a default avoids an AttributeError when the model has no
    # `text_decoder` attribute.
    text_decoder = getattr(model, "text_decoder", None)
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
        print("Gradient checkpointing enabled")
    elif hasattr(text_decoder, "gradient_checkpointing_enable"):
        text_decoder.gradient_checkpointing_enable()
        print("Gradient checkpointing enabled")
    else:
        print(
            "Warning: model does not expose gradient_checkpointing_enable(); "
            "continuing without gradient checkpointing"
        )

# Move to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Print model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 4. Load Dataset

In [None]:
import io

from datasets import load_dataset
from PIL import Image
from torch.utils.data import Dataset

# Load The Cauldron dataset (streaming to avoid downloading everything)
# The Cauldron is published as a collection of sub-datasets, each exposed as
# its own config, so load_dataset must be told which one to use.
CAULDRON_SUBSET = "vqav2"  # other options include "ai2d", "docvqa", "textvqa"

print("Loading dataset...")
raw_dataset = load_dataset(
    "HuggingFaceM4/the_cauldron",
    CAULDRON_SUBSET,
    split="train",
    streaming=True,
)

# Take subset for testing
# A value of None (or 0) leaves the stream untruncated.
if MAX_TRAIN_SAMPLES:
    raw_dataset = raw_dataset.take(MAX_TRAIN_SAMPLES)
    print(f"Using {MAX_TRAIN_SAMPLES} samples")

In [None]:
class SmolVLMDataset(Dataset):
    """Dataset for SmolVLM2 training.

    Materializes ``data`` into a list and, on access, converts a sample into
    ``input_ids``, ``attention_mask``, ``pixel_values`` and ``labels`` tensors.

    Supported sample layouts:
      * image: ``sample["image"]``, or the first entry of ``sample["images"]``
        (an image object, or raw bytes that PIL can open).
      * text: ``conversations`` (from/value), ``messages`` (role/content),
        ``question``/``answer``, ``texts`` (user/assistant turns, the layout
        used by The Cauldron), or a plain ``text`` field.

    Samples without an image, or that raise while being processed, are
    skipped in favour of the next index (wrapping around the dataset).

    Args:
        data: Iterable of dict-like samples; it is consumed once into a list.
        processor: Callable taking ``images=`` and ``return_tensors=`` and
            returning a mapping with a ``pixel_values`` tensor.
        tokenizer: Callable returning ``input_ids`` and ``attention_mask``.
        max_length: Length every text is truncated/padded to.
    """

    # Label value for positions that must not contribute to the loss. -100 is
    # the default ignore_index of torch.nn.CrossEntropyLoss.
    # NOTE(review): assumes the model's loss honours that default - confirm.
    IGNORE_INDEX = -100

    def __init__(self, data, processor, tokenizer, max_length=1024):
        self.data = list(data)  # Convert streaming dataset to list
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        print(f"Loaded {len(self.data)} samples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tensors for sample ``idx``, or for the next usable one.

        Raises:
            IndexError: If the dataset is empty or ``idx`` is out of range.
            RuntimeError: If every sample was tried and none could be used.
        """
        if not self.data:
            raise IndexError("SmolVLMDataset is empty")

        # Try each sample at most once. A bounded loop replaces unbounded
        # recursion, which would blow the stack if no sample is usable.
        position = idx
        for _ in range(len(self.data)):
            sample = self.data[position]
            try:
                item = self._build_item(sample)
            except Exception as e:
                # Best effort: report the bad sample and move on.
                print(f"Error processing sample {position}: {e}")
                item = None
            if item is not None:
                return item
            position = (position + 1) % len(self.data)

        raise RuntimeError("No usable sample found: every sample was skipped")

    def _build_item(self, sample):
        """Convert one raw sample to model inputs; None if it has no image."""
        image = self._load_image(sample)
        if image is None:
            return None

        text = self._format_text(sample)

        # Process image
        pixel_values = self.processor(
            images=image,
            return_tensors="pt",
        )["pixel_values"].squeeze(0)

        # Tokenize text
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Padding positions must not be trained on. They are masked via the
        # attention mask rather than the pad id, so this stays correct even
        # when the pad token is shared with another token.
        labels = input_ids.clone()
        labels[attention_mask == 0] = self.IGNORE_INDEX

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "labels": labels,
        }

    @staticmethod
    def _load_image(sample):
        """Return the sample's (first) image converted to RGB, or None."""
        if "image" in sample and sample["image"] is not None:
            image = sample["image"]
        elif "images" in sample and sample["images"]:
            image = sample["images"][0]
        else:
            return None

        if image is None:
            return None
        if isinstance(image, bytes):
            image = Image.open(io.BytesIO(image))

        # Ensure RGB
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    @staticmethod
    def _format_text(sample):
        """Flatten the sample's text into one string with an image token."""
        if "conversations" in sample:
            text = "\n".join(
                f"{c.get('from', 'user')}: {c.get('value', '')}"
                for c in sample["conversations"]
            )
        elif "messages" in sample:
            text = "\n".join(
                f"{m.get('role', 'user')}: {m.get('content', '')}"
                for m in sample["messages"]
            )
        elif "question" in sample and "answer" in sample:
            text = f"user: {sample['question']}\nassistant: {sample['answer']}"
        elif "texts" in sample and sample["texts"]:
            # The Cauldron layout: a list of {"user": ..., "assistant": ...}
            # turns, rendered like the question/answer case above.
            text = "\n".join(
                f"user: {t.get('user', '')}\nassistant: {t.get('assistant', '')}"
                for t in sample["texts"]
            )
        else:
            text = str(sample.get("text", ""))

        # Add image token
        if "<image>" not in text:
            text = "<image>" + text
        return text


# Build the training set: the streamed samples are pulled into memory here
print("Processing dataset (this may take a few minutes)...")
train_dataset = SmolVLMDataset(raw_dataset, processor, tokenizer, max_length=1024)

In [None]:
# Fetch the first item and report which tensors it holds and their shapes
sample = train_dataset[0]
ids_shape = sample["input_ids"].shape
pixels_shape = sample["pixel_values"].shape
print("Sample keys:", sample.keys())
print(f"input_ids shape: {ids_shape}")
print(f"pixel_values shape: {pixels_shape}")

## 5. Setup Training

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List

import torch
from transformers import Trainer, TrainingArguments


@dataclass
class SmolVLMDataCollator:
    """Collate per-sample tensors into batch tensors for SmolVLM2.

    Calling the collator with a list of feature dicts stacks each of the four
    expected tensors along a new leading dimension.
    """

    # Stored on the instance; __call__ does not read it.
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Stacked in this order; torch.stack requires each field to have the
        # same shape across all samples.
        field_names = ("input_ids", "attention_mask", "labels", "pixel_values")
        return {
            name: torch.stack([feature[name] for feature in features])
            for name in field_names
        }


# Collator instance handed to the Trainer below
data_collator = SmolVLMDataCollator(tokenizer)

In [None]:
# Setup Weights & Biases (optional)
if USE_WANDB:
    import os

    import wandb

    # The W&B integration in transformers takes the project name from the
    # WANDB_PROJECT environment variable. Without this line the WANDB_PROJECT
    # value from the configuration cell would never be applied.
    os.environ["WANDB_PROJECT"] = WANDB_PROJECT
    wandb.login()
    report_to = ["wandb"]
else:
    report_to = ["tensorboard"]

# Training arguments
training_args = TrainingArguments(
    output_dir=CHECKPOINT_DIR,
    # Batch size
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    # Learning rate
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_steps=WARMUP_STEPS,
    # Training duration
    max_steps=MAX_STEPS,
    # Precision
    bf16=USE_BF16,
    # Memory optimization
    gradient_checkpointing=USE_GRADIENT_CHECKPOINTING,
    optim="adamw_torch_fused",
    # Logging
    logging_steps=LOGGING_STEPS,
    logging_first_step=True,
    report_to=report_to,
    run_name=f"smolvlm2-{MODEL_SIZE}-colab",
    # Checkpointing
    save_steps=SAVE_STEPS,
    save_total_limit=2,
    # Other
    # The dataset items are plain dicts of tensors, so column pruning is off.
    remove_unused_columns=False,
    dataloader_pin_memory=True,
    dataloader_num_workers=2,
)

print(f"Output directory: {training_args.output_dir}")
print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")

In [None]:
# Assemble the Trainer from the objects built in the previous cells
# NOTE(review): newer transformers releases rename the `tokenizer` argument to
# `processing_class` - confirm against the installed version.
trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "data_collator": data_collator,
    "tokenizer": tokenizer,
}
trainer = Trainer(**trainer_kwargs)

print("Trainer ready!")

## 6. Train!

In [None]:
# Release cached GPU memory before the run starts
torch.cuda.empty_cache()

# Announce the run
print("Starting training...")
print(f"This will run for {MAX_STEPS} steps.")
print("Checkpoints will be saved to Google Drive.\n")

separator = "=" * 50
try:
    train_result = trainer.train()

    # Summary of the finished run
    print("\n" + separator)
    print("Training Complete!")
    print(separator)
    print(f"Total steps: {train_result.global_step}")
    print(f"Training loss: {train_result.training_loss:.4f}")

except KeyboardInterrupt:
    # Manual stop: save the current model before handing control back.
    # save_model() is called without a path; the Trainer was configured with
    # output_dir=CHECKPOINT_DIR.
    print("\nTraining interrupted. Saving checkpoint...")
    trainer.save_model()
    print(f"Checkpoint saved to {CHECKPOINT_DIR}")

In [None]:
# Persist the final model, then the trainer state, via the Trainer
print("Saving final model...")
for persist in (trainer.save_model, trainer.save_state):
    persist()

print(f"\nModel saved to: {CHECKPOINT_DIR}")
download_hint = (
    "\nTo download, right-click the folder in the file browser and select 'Download'."
)
print(download_hint)

## 7. Test Inference

In [None]:
# Test the trained model
from io import BytesIO

import requests

# Load a test image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
# A timeout keeps the cell from hanging forever on a stalled connection.
response = requests.get(url, timeout=30)
# Fail with a clear HTTP error instead of passing an error page to PIL.
response.raise_for_status()
# Convert to RGB so the image matches what the training dataset feeds the
# processor.
image = Image.open(BytesIO(response.content)).convert("RGB")

# Display image
display(image)

In [None]:
# Run inference
model.eval()

prompt = "<image>Describe this image in detail."

# Process inputs
inputs = processor(images=image, return_tensors="pt")
text_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = text_inputs["input_ids"]

# Move to GPU
inputs = {k: v.to(device) for k, v in inputs.items()}
input_ids = input_ids.to(device)

generate_kwargs = {
    "input_ids": input_ids,
    "max_new_tokens": 100,
    "do_sample": False,
}
# Pass the attention mask when the tokenizer provides one, so generate() does
# not have to infer it from the token ids.
if "attention_mask" in text_inputs:
    generate_kwargs["attention_mask"] = text_inputs["attention_mask"].to(device)

# NOTE(review): `inputs` (the processed image) is prepared above but is not
# passed to generate(). The call below runs the text decoder on the prompt
# alone, so the output is not conditioned on the image. Route the pixel
# values through the full model's generation entry point once its API is
# confirmed.

# Generate
with torch.no_grad():
    outputs = model.text_decoder.generate(**generate_kwargs)

# Decode
# generate() returns the prompt followed by the new tokens; slice the prompt
# off so "Response" shows only what the model produced.
generated_tokens = outputs[0][input_ids.shape[1]:]
response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
print(f"Prompt: {prompt}")
print(f"Response: {response}")

## 8. Resume Training (Optional)

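If your session disconnected, first re-run sections 1–5 to rebuild the model, dataset, and trainer. Then run the cell below to locate the last checkpoint saved to Google Drive, and uncomment its `trainer.train(...)` line to resume from it: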

In [None]:
# Find latest checkpoint
import os
import re

# Only directories named exactly "checkpoint-<step>" are considered. Stray
# files or folders such as "checkpoint-final" would otherwise make the step
# parsing below raise ValueError.
checkpoint_name = re.compile(r"checkpoint-\d+")
checkpoints = [
    d
    for d in os.listdir(CHECKPOINT_DIR)
    if checkpoint_name.fullmatch(d) and os.path.isdir(os.path.join(CHECKPOINT_DIR, d))
]
if checkpoints:
    # Compare by numeric step so "checkpoint-1000" beats "checkpoint-250".
    latest = max(checkpoints, key=lambda x: int(x.split("-")[1]))
    resume_path = os.path.join(CHECKPOINT_DIR, latest)
    print(f"Found checkpoint: {resume_path}")

    # Resume training
    # trainer.train(resume_from_checkpoint=resume_path)
else:
    print("No checkpoints found.")

## 9. Export to HuggingFace Hub (Optional)

In [None]:
# Login to HuggingFace
from huggingface_hub import notebook_login

# Needed before the (commented-out) push_to_hub call in the next cell.
notebook_login()

In [None]:
# Push to Hub
# Target repository on the Hub; "your-username" is a placeholder.
HUB_MODEL_ID = "your-username/smolvlm2-256m-finetuned"  # Change this!

# Uncomment the two lines below to upload once HUB_MODEL_ID is set.
# trainer.push_to_hub(HUB_MODEL_ID)
# print(f"Model pushed to: https://huggingface.co/{HUB_MODEL_ID}")

---

## Tips for Colab Training

### Memory Issues (OOM)
1. Reduce `BATCH_SIZE` to 1
2. Reduce `max_length` (passed to `SmolVLMDataset`) from 1024 to 512
3. Use 256M model instead of 500M
4. Ensure `gradient_checkpointing` is enabled

### Session Timeout
1. Save checkpoints frequently (`SAVE_STEPS=100`)
2. Mount Google Drive to persist checkpoints
3. Use `resume_from_checkpoint` to continue

### Speed Up Training
1. Use Colab Pro for A100 GPU
2. Increase `BATCH_SIZE` if memory allows
3. Use `USE_BF16 = True` for bf16 mixed precision (needs an Ampere-or-newer GPU such as the A100; the T4 has no native bf16 support)

### Full Training
For full 50k step training, consider:
1. Colab Pro+ for longer sessions
2. Cloud providers (RunPod, Lambda Labs, etc.)
3. Multi-GPU setup with the distributed training scripts