# VLA Fine-Tuning: SmolVLA on Meta-World

This notebook fine-tunes [SmolVLA](https://huggingface.co/lerobot/smolvla_base), a 450M-parameter Vision-Language-Action (VLA) model, on [Meta-World](https://github.com/Farama-Foundation/Metaworld) manipulation tasks.

We run a **3-way comparison**:
1. **Random init** — train SmolVLA from scratch (architecture only, no pretrained weights)
2. **Pretrained zero-shot** — evaluate the pretrained SmolVLA directly (no fine-tuning)
3. **Fine-tuned** — adapt pretrained SmolVLA to Meta-World tasks

**Requirements**: Colab Pro+ with an A100 GPU. Each training run takes about 5 hours.

## 0. Clone project repo

In [None]:
!git clone https://github.com/vivpra89/VLA_BD.git /content/VLA_BD

## 1. Setup: Install Conda + LeRobot + SmolVLA

In [None]:
!pip install -q condacolab
import condacolab
condacolab.install()
# NOTE: Runtime will restart automatically after this cell. That's expected.

In [None]:
!git clone https://github.com/huggingface/lerobot.git /content/lerobot
!conda install ffmpeg=7.1.1 -c conda-forge -y
%cd /content/lerobot
!pip install -e ".[smolvla]"
!pip install "gymnasium==1.1.0" metaworld wandb matplotlib
%cd /content

## 2. (Optional) Login to Weights & Biases for training curves

In [None]:
import wandb
wandb.login()

## 3. Pick Meta-World tasks

We select 3 diverse tasks from the MT50 suite:
- `assembly-v3` — insert a peg into a hole
- `dial-turn-v3` — turn a dial
- `handle-press-side-v3` — press a handle from the side

These test different manipulation skills: insertion, rotation, and pressing.

In [None]:
import os
os.environ["TASKS"] = "assembly-v3,dial-turn-v3,handle-press-side-v3"
os.environ["DATASET"] = "lerobot/metaworld_mt50"
os.environ["OUTPUT_BASE"] = "/content/outputs"

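# Mirror the settings as Python variables. The training cells below hard-code
# the same values, so keep them in sync if you change anything here.
# Note: section 6 evaluates with 10 episodes per task, not EVAL_EPISODES.
TASKS = os.environ["TASKS"]
DATASET = os.environ["DATASET"]
OUTPUT_BASE = os.environ["OUTPUT_BASE"]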
STEPS_FINETUNE = 20000
STEPS_SCRATCH = 20000
BATCH_SIZE = 64
EVAL_FREQ = 2000
EVAL_EPISODES = 5
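
The configuration values above can also be turned into a command line programmatically, which avoids repeating literals in every training cell. The following is a minimal sketch, not part of LeRobot: `build_train_args` and its `extra` parameter are hypothetical names introduced here, and the flags are simply the ones this notebook passes to `lerobot-train`.

```python
import shlex


def build_train_args(policy_flag, output_dir, job_name, *, tasks, dataset,
                     steps, batch_size, eval_freq, eval_episodes, extra=()):
    """Return a lerobot-train command as a list of arguments.

    Hypothetical helper: it only assembles the flags used in this notebook.
    """
    args = [
        "lerobot-train",
        policy_flag,  # e.g. --policy.path=... or --policy.type=...
        f"--dataset.repo_id={dataset}",
        "--env.type=metaworld",
        f"--env.task={tasks}",
        f"--batch_size={batch_size}",
        f"--steps={steps}",
        f"--eval.n_episodes={eval_episodes}",
        f"--eval_freq={eval_freq}",
        f"--save_freq={eval_freq}",  # save a checkpoint at every evaluation
        f"--output_dir={output_dir}",
        f"--job_name={job_name}",
        "--policy.device=cuda",
        "--wandb.enable=true",
    ]
    args.extend(extra)  # any additional flags, appended unchanged
    return args


finetune_cmd = build_train_args(
    "--policy.path=lerobot/smolvla_base",
    "/content/outputs/finetuned",
    "smolvla_finetuned",
    tasks="assembly-v3,dial-turn-v3,handle-press-side-v3",
    dataset="lerobot/metaworld_mt50",
    steps=20000,
    batch_size=64,
    eval_freq=2000,
    eval_episodes=5,
)

# Print a shell-safe string for inspection before running anything
print(shlex.join(finetune_cmd))
```

The resulting list could be passed to `subprocess.run(finetune_cmd, cwd="/content/lerobot")` in place of a `!lerobot-train` cell.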

## 4. Run 1: Fine-tune pretrained SmolVLA

Start from `lerobot/smolvla_base` (pretrained on real SO100 robot data) and adapt it to Meta-World.
This is the standard VLA workflow: take a foundation model and fine-tune it on the target domain.

In [None]:
%cd /content/lerobot
!lerobot-train \
    --policy.path=lerobot/smolvla_base \
    --policy.push_to_hub=false \
    --dataset.repo_id=lerobot/metaworld_mt50 \
    --env.type=metaworld \
    --env.task=assembly-v3,dial-turn-v3,handle-press-side-v3 \
    --batch_size=64 \
    --steps=20000 \
    --eval.batch_size=1 \
    --eval.n_episodes=5 \
    --eval_freq=2000 \
    --save_freq=2000 \
    --output_dir=/content/outputs/finetuned \
    --job_name=smolvla_finetuned \
    --policy.device=cuda \
    --wandb.enable=true

## 5. Run 2: Train from scratch (random init) — OPTIONAL

Same architecture, same training, but **no pretrained weights**. This shows how much the pretrained VLM features help.

**Skip this cell if you want results faster.** You still get a 2-way comparison.

In [None]:
%cd /content/lerobot
!lerobot-train \
    --policy.type=smolvla \
    --policy.push_to_hub=false \
    --dataset.repo_id=lerobot/metaworld_mt50 \
    --env.type=metaworld \
    --env.task=assembly-v3,dial-turn-v3,handle-press-side-v3 \
    --batch_size=64 \
    --steps=20000 \
    --eval.batch_size=1 \
    --eval.n_episodes=5 \
    --eval_freq=2000 \
    --save_freq=2000 \
    --output_dir=/content/outputs/scratch \
    --job_name=smolvla_scratch \
    --policy.device=cuda \
    --wandb.enable=true

## 6. Evaluate all conditions

Run evaluation on each task for:
- **Pretrained zero-shot** (no training on Meta-World)
- **Fine-tuned** (last checkpoint)
- **From scratch** (last checkpoint, if trained)
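
The cell below evaluates the `last` checkpoint of each run. To evaluate an earlier checkpoint instead (for example the one with the best success rate on the W&B curves), the saved checkpoints can be listed first. This is a small sketch that assumes checkpoints are stored as step-numbered directories under `checkpoints/`, next to `last`; `list_step_checkpoints` is a hypothetical helper, not a LeRobot function.

```python
from pathlib import Path


def list_step_checkpoints(output_dir):
    """Return step-numbered checkpoint directories, sorted by step.

    Assumes a layout like <output_dir>/checkpoints/<step>/, where <step>
    is all digits. Non-numeric entries such as "last" are skipped.
    """
    ckpt_root = Path(output_dir) / "checkpoints"
    if not ckpt_root.is_dir():
        return []
    step_dirs = [p for p in ckpt_root.iterdir() if p.is_dir() and p.name.isdigit()]
    return sorted(step_dirs, key=lambda p: int(p.name))


# Print candidate policy paths for the fine-tuned run (prints nothing if absent)
for ckpt in list_step_checkpoints("/content/outputs/finetuned"):
    print(ckpt / "pretrained_model")
```

Any printed path can replace the `finetuned` entry in the `conditions` dictionary.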

In [None]:
import subprocess, os

conditions = {
    "pretrained_zeroshot": "lerobot/smolvla_base",
    "finetuned": "/content/outputs/finetuned/checkpoints/last/pretrained_model",
}

scratch_path = "/content/outputs/scratch/checkpoints/last/pretrained_model"
if os.path.exists(scratch_path):
    conditions["from_scratch"] = scratch_path

tasks = ["assembly-v3", "dial-turn-v3", "handle-press-side-v3"]
results = {}

for cond_name, policy_path in conditions.items():
    results[cond_name] = {}
    for task in tasks:
        print(f"\n--- Evaluating {cond_name} on {task} ---")
        eval_dir = f"/content/outputs/eval/{cond_name}/{task}"
        cmd = f"""cd /content/lerobot && lerobot-eval \
            --policy.path={policy_path} \
            --env.type=metaworld \
            --env.task={task} \
            --eval.batch_size=1 \
            --eval.n_episodes=10 \
            --output_dir={eval_dir}"""
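        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        # Show only the tail of the output to keep the cell readable
        print(result.stdout[-500:] if result.stdout else "No stdout")
        if result.returncode != 0:
            print(f"Error: {result.stderr[-500:]}")
            # Record a failed run as None so it is not mistaken for a 0% success rate
            results[cond_name][task] = None
        else:
            # Metrics are written under eval_dir (passed as --output_dir); keep the path
            results[cond_name][task] = eval_dir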

print("\n=== All Results ===")
for cond, task_results in results.items():
    print(f"\n{cond}:")
    for task, score in task_results.items():
        print(f"  {task}: {score}")
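
The loop above only records where each evaluation wrote its output. As a hedged sketch, the success rate could be read back from disk like this. The file name `eval_info.json` and the key `pc_success` (a percentage) are assumptions about what `lerobot-eval` writes and may differ between LeRobot versions, so check one output directory before relying on them. `find_key` and `read_success_rate` are hypothetical helpers.

```python
import json
from pathlib import Path


def find_key(obj, key):
    """Depth-first search for the first non-None value stored under `key`."""
    if isinstance(obj, dict):
        if obj.get(key) is not None:
            return obj[key]
        children = obj.values()
    elif isinstance(obj, list):
        children = obj
    else:
        return None
    for child in children:
        found = find_key(child, key)
        if found is not None:
            return found
    return None


def read_success_rate(eval_dir, filename="eval_info.json", key="pc_success"):
    """Return the success rate in [0, 1], or None if it cannot be found.

    Assumption: the metrics file stores success as a percentage under `key`.
    """
    path = Path(eval_dir) / filename
    if not path.is_file():
        return None
    with open(path) as f:
        info = json.load(f)
    value = find_key(info, key)
    return None if value is None else float(value) / 100.0


# Prints None until the corresponding evaluation has been run
print(read_success_rate("/content/outputs/eval/finetuned/assembly-v3"))
```

Because each evaluation here covers a single task, taking the first matching value is enough; a multi-task output would need a more specific lookup.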

## 7. Visualize training curves

In [None]:
import matplotlib.pyplot as plt
import json, glob, os
import numpy as np

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for idx, (label, output_dir) in enumerate([
    ("Fine-tuned (pretrained init)", "/content/outputs/finetuned"),
    ("From scratch (random init)", "/content/outputs/scratch"),
]):
    log_files = glob.glob(os.path.join(output_dir, "**/*.json"), recursive=True)
    if not log_files:
        print(f"No log files found for {label} in {output_dir}")
        continue

    print(f"Log files for {label}: {log_files}")
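    for log_file in log_files:
        try:
            # Treat each file as JSON Lines: one JSON object per non-empty line
            with open(log_file) as f:
                logs = [json.loads(line) for line in f if line.strip()]
            # Keep only dict records that carry both a step and a loss value
            records = [r for r in logs if isinstance(r, dict) and "step" in r and "loss" in r]
            steps = [r["step"] for r in records]
            losses = [r["loss"] for r in records]
            if steps:
                color = "tab:blue" if idx == 0 else "tab:orange"
                axes[0].plot(steps, losses, label=label, color=color, alpha=0.8)
                break  # the first file with loss records is enough
        except Exception as e:
            print(f"  Could not parse {log_file}: {e}")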

axes[0].set_xlabel("Training Steps")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training Loss: Fine-tuned vs From Scratch")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

ax2 = axes[1]
task_names = ["assembly", "dial-turn", "handle-press"]
x = np.arange(len(task_names))
width = 0.25

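# Placeholder values: replace these with the success rates measured in
# section 6 before interpreting the chart.
scratch_rates = [0.0, 0.0, 0.0]
zeroshot_rates = [0.0, 0.0, 0.0]
finetuned_rates = [0.0, 0.0, 0.0]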

ax2.bar(x - width, scratch_rates, width, label="From scratch", color="tab:red", alpha=0.8)
ax2.bar(x, zeroshot_rates, width, label="Pretrained (zero-shot)", color="tab:orange", alpha=0.8)
ax2.bar(x + width, finetuned_rates, width, label="Fine-tuned", color="tab:green", alpha=0.8)

ax2.set_xlabel("Task")
ax2.set_ylabel("Success Rate")
ax2.set_title("3-Way Comparison: Success Rate by Task")
ax2.set_xticks(x)
ax2.set_xticklabels(task_names)
ax2.legend()
ax2.set_ylim(0, 1.0)
ax2.grid(True, alpha=0.3, axis="y")

plt.tight_layout()
os.makedirs("/content/outputs", exist_ok=True)
plt.savefig("/content/outputs/training_comparison.png", dpi=150, bbox_inches="tight")
plt.show()
print("Saved to /content/outputs/training_comparison.png")
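
The three rate lists in the chart are filled in by hand. One possible way to derive them from a results dictionary is sketched below, assuming the dictionary maps condition to task to a numeric success rate in [0, 1]. `rates_for_plot` is a hypothetical helper, and the numbers in `example_results` are made up purely to show the shape of the data.

```python
import math


def rates_for_plot(results, condition, tasks):
    """Return one success rate per task, in task order.

    Entries that are missing or non-numeric become NaN, so an unmeasured
    condition is not mistaken for a measured 0% success rate.
    """
    task_scores = results.get(condition, {})
    rates = []
    for task in tasks:
        score = task_scores.get(task)
        is_number = isinstance(score, (int, float)) and not isinstance(score, bool)
        rates.append(float(score) if is_number else math.nan)
    return rates


tasks = ["assembly-v3", "dial-turn-v3", "handle-press-side-v3"]

# Illustrative numbers only, NOT measured results
example_results = {"finetuned": {"assembly-v3": 0.5, "dial-turn-v3": None}}

print(rates_for_plot(example_results, "finetuned", tasks))
print(rates_for_plot(example_results, "from_scratch", tasks))
```

The returned lists have the same order as `tasks`, so they line up with the `x` positions used for the bars.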

## 8. Upload the fine-tuned checkpoint to the Hugging Face Hub

Save the fine-tuned model so we can load it in the introspection notebook.

In [None]:
!huggingface-cli login

In [None]:
import os
HF_USER = os.environ.get("HF_USER", "YOUR_USERNAME")

!huggingface-cli upload {HF_USER}/smolvla-metaworld-finetuned \
    /content/outputs/finetuned/checkpoints/last/pretrained_model

print("\nModel uploaded to:")
print(f"  https://huggingface.co/{HF_USER}/smolvla-metaworld-finetuned")
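
If `HF_USER` is not set, the cell above tries to upload to the placeholder account `YOUR_USERNAME`. A small guard can catch that before any upload starts. This is a sketch only; `resolve_repo_id` is a hypothetical helper and is not part of `huggingface_hub`.

```python
import os

PLACEHOLDER = "YOUR_USERNAME"


def resolve_repo_id(model_name, user=None):
    """Build "<user>/<model_name>", refusing the unset or placeholder user."""
    user = user or os.environ.get("HF_USER", PLACEHOLDER)
    if not user or user == PLACEHOLDER:
        raise ValueError("Set the HF_USER environment variable to your Hub username.")
    return f"{user}/{model_name}"


try:
    print(resolve_repo_id("smolvla-metaworld-finetuned"))
except ValueError as err:
    # Reached when HF_USER is unset or still the placeholder
    print(f"Upload skipped: {err}")
```

Running this check before the upload cell turns a confusing Hub error into an explicit message.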