# Train Diffusion Policy for SO-101 Robot

This notebook trains a **Diffusion Policy** on the SO-101 robot arm for pick-and-place tasks using [LeRobot](https://github.com/huggingface/lerobot).

## What is Diffusion Policy?
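Diffusion Policy is an imitation learning method that generates robot actions with a diffusion model: starting from random noise, it iteratively denoises a short sequence of future actions, conditioned on recent observations. The original paper (see References) reports strong results on a range of manipulation tasks.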
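
Diffusion models produce a sample by starting from noise and removing it step by step. The toy sketch below shows that loop for a single scalar "action" using DDPM-style sampling. It is an illustration under stated assumptions, not LeRobot's implementation: the linear noise schedule, the function names, and especially `oracle_noise` are hypothetical. `oracle_noise` stands in for the trained network by computing the exact noise for a dataset that contains only one target action.

```python
import math
import random


def make_schedule(num_steps=50, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule (an assumption) and the derived alpha terms."""
    betas = [
        beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        for i in range(num_steps)
    ]
    alphas = [1.0 - b for b in betas]
    alpha_bars, running = [], 1.0
    for a in alphas:
        running *= a
        alpha_bars.append(running)
    return betas, alphas, alpha_bars


def oracle_noise(x_t, t, alpha_bars, target):
    """Stand-in for the learned noise predictor.

    For a dataset with a single action `target`, the noise that produced
    x_t from target at step t can be recovered in closed form.
    """
    return (x_t - math.sqrt(alpha_bars[t]) * target) / math.sqrt(1.0 - alpha_bars[t])


def sample_action(target, num_steps=50, seed=0):
    """Run the reverse (denoising) process from pure noise down to step 0."""
    rng = random.Random(seed)
    betas, alphas, alpha_bars = make_schedule(num_steps)
    x = rng.gauss(0.0, 1.0)  # start from random noise
    for t in reversed(range(num_steps)):
        eps = oracle_noise(x, t, alpha_bars, target)
        # DDPM posterior mean: remove the predicted noise for this step.
        mean = (x - betas[t] / math.sqrt(1.0 - alpha_bars[t]) * eps) / math.sqrt(alphas[t])
        # Fresh noise is added at every step except the last one.
        noise = rng.gauss(0.0, 1.0) if t > 0 else 0.0
        x = mean + math.sqrt(betas[t]) * noise
    return x


print(f"Denoised action: {sample_action(target=0.5):.4f}")
```

In a real policy the oracle is replaced by a neural network that sees camera images and robot state, and `x` is a whole sequence of future actions rather than one number.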

## Dataset
We use the official `lerobot/svla_so101_pickplace` dataset, which contains 50 episodes (11,939 frames) of pick-and-place demonstrations on the SO-101 arm with **dual cameras**:
- `observation.images.up` — global/top-down view (480x640)
- `observation.images.side` — side view (480x640)

The policy auto-adapts to the dataset's camera setup — both cameras will be used as visual inputs.
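
A quick back-of-the-envelope calculation helps put these numbers in perspective. The sketch below uses only the figures quoted above; the assumption that images have 3 colour channels is mine, and the helper names are made up for illustration.

```python
NUM_EPISODES = 50
NUM_FRAMES = 11_939
IMAGE_HEIGHT, IMAGE_WIDTH = 480, 640
CHANNELS = 3      # assumption: RGB images
NUM_CAMERAS = 2   # observation.images.up and observation.images.side


def average_episode_length(num_frames, num_episodes):
    """Mean number of frames per demonstration episode."""
    return num_frames / num_episodes


def image_values_per_timestep(height, width, channels, num_cameras):
    """Raw image values the policy receives at one timestep, across cameras."""
    return height * width * channels * num_cameras


avg_len = average_episode_length(NUM_FRAMES, NUM_EPISODES)
values = image_values_per_timestep(IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS, NUM_CAMERAS)
print(f"Average episode length: {avg_len:.1f} frames")
print(f"Image values per timestep (both cameras): {values:,}")
```

Episodes average a little under 240 frames each, so this is a small dataset by deep-learning standards.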

## Requirements
- Google Colab with GPU runtime (T4 or better)
- Hugging Face account for uploading trained models

## 1. Install Dependencies

In [None]:
!pip install -U lerobot
!pip install wandb  # optional, for logging

# Reinstall torch+torchvision that match Colab's CUDA version
# (lerobot can pull in an incompatible torchvision)
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124

# Fix transformers <-> huggingface_hub version mismatch
# (lerobot + torch reinstall can leave incompatible versions)
!pip install -U transformers huggingface_hub
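
Because several packages are reinstalled in sequence, it can be worth confirming which versions actually ended up in the environment. The following is a minimal sketch using only the standard library; `installed_versions` is a hypothetical helper name, and the package list simply mirrors the install commands above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each package name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


packages = ["lerobot", "torch", "torchvision", "transformers", "huggingface_hub"]
for name, ver in installed_versions(packages).items():
    print(f"{name}: {ver or 'not installed'}")
```

If you run this in the same session as the installs, restart the runtime first so that already-imported modules do not mask the new versions.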

## 2. Login to Hugging Face

This cell reads your Colab secrets `HF_USER` and `HF_TOKEN` (set them via the key icon in the left sidebar).

In [None]:
import os
from google.colab import userdata

HF_USER = userdata.get('HF_USER')
HF_TOKEN = userdata.get('HF_TOKEN')

os.environ["HF_USER"] = HF_USER
os.environ["HF_TOKEN"] = HF_TOKEN

# Log in to the Hugging Face CLI
!huggingface-cli login --token {HF_TOKEN}

print(f"Logged in as: {HF_USER}")
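
If a secret is missing or empty, the cell above fails with an error that does not say which one. A small guard can make that failure clearer. This is a sketch under assumptions: `require_secrets` is a hypothetical helper, and it takes the lookup function as an argument so it can be tried outside Colab with a plain dictionary.

```python
def require_secrets(getter, names):
    """Return {name: value}; raise a clear error listing missing or empty secrets.

    `getter` is any callable that maps a secret name to its value,
    for example `userdata.get` in Colab or `some_dict.get` in a test.
    """
    values, missing = {}, []
    for name in names:
        try:
            value = getter(name)
        except Exception:
            # Treat any lookup failure as "secret not available".
            value = None
        if value:
            values[name] = value
        else:
            missing.append(name)
    if missing:
        raise RuntimeError(f"Missing or empty secrets: {', '.join(missing)}")
    return values


# Demonstration with a plain dict standing in for the Colab secrets store.
demo_store = {"HF_USER": "example-user", "HF_TOKEN": "example-token"}
print(sorted(require_secrets(demo_store.get, ["HF_USER", "HF_TOKEN"])))

# In Colab, the equivalent call would be:
#   from google.colab import userdata
#   secrets = require_secrets(userdata.get, ["HF_USER", "HF_TOKEN"])
```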

## 3. Verify GPU Availability

In [None]:
import torch

if torch.cuda.is_available():
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected! Go to Runtime > Change runtime type > GPU")

## 4. Explore the Dataset

Let's inspect the dataset to confirm the camera setup and data shape before training.

In [None]:
from lerobot.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/svla_so101_pickplace")
print(f"Number of episodes: {dataset.num_episodes}")
print(f"Number of frames: {dataset.num_frames}")
print(f"FPS: {dataset.fps}")
print("\nFeatures:")
for key, feat in dataset.features.items():
    print(f"  {key}: {feat}")

# Show a sample frame
sample = dataset[0]
print(f"\nSample keys: {list(sample.keys())}")
for k, v in sample.items():
    if hasattr(v, 'shape'):
        print(f"  {k}: shape={v.shape}, dtype={v.dtype}")
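
To confirm the camera setup at a glance, you can filter the feature names for image keys. The sketch below assumes only that `dataset.features` behaves like a dictionary keyed by feature name (as the loop above relies on) and that camera features share the `observation.images.` prefix seen in this dataset. `camera_keys` and the example dictionary are made up for illustration.

```python
def camera_keys(features, prefix="observation.images."):
    """Return the sorted feature names that look like camera streams."""
    return sorted(key for key in features if key.startswith(prefix))


# Stand-in for dataset.features; the values are placeholders, not real metadata.
example_features = {
    "action": {},
    "observation.state": {},
    "observation.images.up": {},
    "observation.images.side": {},
}

print(camera_keys(example_features))
# With the real dataset loaded: print(camera_keys(dataset.features))
```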

## 5. Test Run (100 steps)

Quick sanity check to make sure everything works before committing to a full training run.

In [None]:
# Short test run — 100 steps, batch_size=8 (safe for T4 16GB)
!python -m lerobot.scripts.lerobot_train \
    --policy.type=diffusion \
    --dataset.repo_id=lerobot/svla_so101_pickplace \
    --output_dir=outputs/train/diffusion_so101_test \
    --job_name=diffusion_so101_test \
    --policy.device=cuda \
    --policy.push_to_hub=false \
    --batch_size=8 \
    --steps=100 \
    --save_freq=50 \
    --log_freq=10 \
    --wandb.enable=false
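
After the test run finishes, you can check that checkpoints were written before moving on. This sketch assumes the layout used later in this notebook, where checkpoints live under `<output_dir>/checkpoints/`. The helper name `list_checkpoints` is hypothetical, and the exact sub-folder names depend on your LeRobot version.

```python
from pathlib import Path


def list_checkpoints(output_dir):
    """Return the names of checkpoint folders under <output_dir>/checkpoints."""
    ckpt_root = Path(output_dir) / "checkpoints"
    if not ckpt_root.is_dir():
        return []  # training has not saved anything yet
    return sorted(p.name for p in ckpt_root.iterdir() if p.is_dir())


found = list_checkpoints("outputs/train/diffusion_so101_test")
print(found if found else "No checkpoints found yet.")
```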

## 6. Full Training Run

Once the test run completes successfully, run the full training. Adjust `steps` and `batch_size` based on your Colab GPU.

| GPU | Recommended batch_size | ~Time for 100k steps |
|-----|----------------------|---------------------|
| T4 (16GB) | 8 | ~8-10 hours |
| A100 (40GB) | 32 | ~2-3 hours |
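
When adjusting `steps` and `batch_size`, it helps to know roughly how many passes over the dataset a run makes. The sketch below is an approximation that treats each of the 11,939 frames as one training sample, which is my simplifying assumption. `approx_epochs` is a hypothetical helper.

```python
def approx_epochs(steps, batch_size, num_samples):
    """Approximate number of passes over the dataset for a training run."""
    return steps * batch_size / num_samples


NUM_FRAMES = 11_939  # from the dataset description

for gpu, batch_size in [("T4", 8), ("A100", 32)]:
    epochs = approx_epochs(100_000, batch_size, NUM_FRAMES)
    print(f"{gpu}: batch_size={batch_size}, 100k steps is about {epochs:.0f} passes")
```

At `batch_size=8`, 100k steps is roughly 67 passes over the data. If you raise the batch size on a larger GPU, you may be able to reduce `steps` and still see a comparable number of samples.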

In [None]:
!python -m lerobot.scripts.lerobot_train \
    --policy.type=diffusion \
    --dataset.repo_id=lerobot/svla_so101_pickplace \
    --output_dir=outputs/train/diffusion_so101_pickplace \
    --job_name=diffusion_so101_pickplace \
    --policy.device=cuda \
    --policy.push_to_hub=false \
    --batch_size=8 \
    --steps=100000 \
    --save_freq=10000 \
    --log_freq=100 \
    --wandb.enable=false

## 7. (Optional) Train on Your Own Dataset

If you have recorded your own SO-101 dataset with wrist and global cameras, you can train on it instead. The policy will auto-detect your camera names from the dataset.

In [None]:
# Uncomment and modify to train on your own dataset
# Your dataset should have camera keys like:
#   observation.images.wrist  (wrist-mounted camera)
#   observation.images.top    (global/top-down camera)
# The policy will auto-adapt to whatever camera names your dataset uses.

# CUSTOM_DATASET = f"{HF_USER}/your_so101_dataset"

# !python -m lerobot.scripts.lerobot_train \
#     --policy.type=diffusion \
#     --dataset.repo_id={CUSTOM_DATASET} \
#     --output_dir=outputs/train/diffusion_so101_custom \
#     --job_name=diffusion_so101_custom \
#     --policy.device=cuda \
#     --policy.push_to_hub=false \
#     --batch_size=8 \
#     --steps=100000 \
#     --save_freq=10000 \
#     --log_freq=100 \
#     --wandb.enable=false

## 8. Upload Trained Model to Hugging Face Hub

After training completes, upload your model to share it or use it for inference on your robot.

In [None]:
MODEL_NAME = "diffusion_so101_pickplace"
CHECKPOINT_PATH = "outputs/train/diffusion_so101_pickplace/checkpoints/last/pretrained_model"

print(f"Uploading model to: {HF_USER}/{MODEL_NAME}")
!huggingface-cli upload {HF_USER}/{MODEL_NAME} {CHECKPOINT_PATH}

## 9. (Optional) Test Inference

Quick check that the trained model loads correctly.

In [None]:
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy

checkpoint_path = "outputs/train/diffusion_so101_pickplace/checkpoints/last/pretrained_model"

try:
    policy = DiffusionPolicy.from_pretrained(checkpoint_path)
    policy.eval()
    print("Model loaded successfully!")
    print(f"Input features: {policy.config.input_features}")
    print(f"Output features: {policy.config.output_features}")
except Exception as e:
    print(f"Could not load model: {e}")
    print("This is expected if training hasn't completed yet.")

## References

- [LeRobot GitHub](https://github.com/huggingface/lerobot)
- [Diffusion Policy Paper](https://arxiv.org/abs/2303.04137)
- [SO-101 Pick-Place Dataset](https://huggingface.co/datasets/lerobot/svla_so101_pickplace)
- [LeRobot Colab Notebooks](https://huggingface.co/docs/lerobot/notebooks)
- [LeRobot Docs — Training on Real Robots](https://huggingface.co/docs/lerobot/en/il_robots)