# Bittle Quadruped Locomotion Training

Train a locomotion policy for the Bittle quadruped robot with Brax PPO, then export the policy to ONNX and record a video of it in simulation.

**Modes:**
- **Test mode** (`TEST_MODE = True`): ~6 min on A100, minimal training for pipeline verification
- **Full mode** (`TEST_MODE = False`): ~30 min on A100, production-quality policy

**Outputs:** `outputs/policy.onnx` + `outputs/videos/latest_video.mp4`

**Requirements:** Google Colab with a GPU runtime (A100 recommended). To select one, go to Runtime > Change runtime type and choose a GPU.

In [None]:
# GPU check & install dependencies
import subprocess, os

result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
if result.returncode != 0:
    raise RuntimeError(
        "No GPU detected! Go to Runtime > Change runtime type > GPU, then restart."
    )
print(result.stdout)

GITHUB_TOKEN = ""  # @param {type:"string"}
REPO = "triton-droids/pupper-simulations"

if not os.path.exists("pupper-simulations"):
    if GITHUB_TOKEN:
        !git clone https://{GITHUB_TOKEN}@github.com/{REPO}.git
    else:
        !git clone https://github.com/{REPO}.git
    if not os.path.exists("pupper-simulations"):
        raise RuntimeError(
            "Clone failed. If the repo is private, set GITHUB_TOKEN above to a "
            "GitHub personal access token with repo scope."
        )
else:
    print("Repository already cloned, skipping.")

%cd pupper-simulations
!pip install -e .

print("\n--- Installation complete ---")
print("If you see dependency errors above, restart the runtime (Runtime > Restart runtime) and re-run this cell.")

In [None]:
# Environment setup
import os, sys, warnings

# Select MuJoCo's EGL (headless) rendering backend. Set it before anything
# that may import MuJoCo, since the backend is presumably read at import time.
os.environ["MUJOCO_GL"] = "egl"
# Suppress RuntimeWarnings whose message matches "overflow encountered in cast".
warnings.filterwarnings("ignore", category=RuntimeWarning, message="overflow encountered in cast")

# Add locomotion/ to sys.path for bare imports (e.g. bittle_env, training_config)
locomotion_dir = os.path.join(os.getcwd(), "locomotion")
if locomotion_dir not in sys.path:
    sys.path.insert(0, locomotion_dir)

import jax

devices = jax.devices()
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {devices}")
# Raise explicitly (consistent with the nvidia-smi check in the setup cell)
# rather than using `assert`, which is skipped when Python runs with -O.
if not any(d.platform == "gpu" for d in devices):
    raise RuntimeError("No GPU visible to JAX!")

## Training Configuration

In [None]:
# Colab form field. True: short run for pipeline verification; False: full
# training run. Passed to TrainingConfig(test_mode=...) in the training cell.
TEST_MODE = True  # @param {type:"boolean"}

# Paths are relative to the repository root (the setup cell `%cd`s into
# pupper-simulations).
ONNX_OUTPUT = "outputs/policy.onnx"
VIDEO_OUTPUT = "outputs/videos/latest_video.mp4"
XML_PATH = "locomotion/bittle_adapted_scene.xml"  # MuJoCo scene passed to BittleEnv

## Train the Policy

In [None]:
from brax import envs
from brax.training.agents.ppo import train as ppo
from bittle_env import BittleEnv
from training_config import TrainingConfig

config = TrainingConfig(test_mode=TEST_MODE)
mode = "TEST" if TEST_MODE else "FULL"
print(f"Training Bittle ({mode} mode) | {config.to_dict()}")

# Register the environment on every run instead of peeking at the private
# `envs._envs` registry. brax's register_environment overwrites the entry for
# the name, so this stays idempotent. It also picks up a reloaded BittleEnv
# class rather than silently keeping a stale one.
envs.register_environment("bittle", BittleEnv)

env = envs.get_environment("bittle", xml_path=XML_PATH)

# Evaluation history filled in by progress() and drawn by the plotting cell.
training_rewards = []
training_steps = []

def progress(step, metrics):
    """Record one PPO evaluation point and echo it to the cell output.

    Passed to ``ppo.train`` as ``progress_fn``. Appends ``step`` to the
    notebook-level ``training_steps`` list and the evaluation episode reward
    (as a Python float) to ``training_rewards``.

    Args:
        step: Step count reported by the trainer for this evaluation.
        metrics: Mapping that must contain ``"eval/episode_reward"``.
    """
    eval_reward = float(metrics["eval/episode_reward"])
    training_steps.append(step)
    training_rewards.append(eval_reward)
    print(f"  Step {step:>10,} | Reward: {eval_reward:.4f}")

# Run Brax PPO with the hyperparameters from TrainingConfig. progress() is
# invoked by the trainer for each evaluation and fills training_rewards.
make_policy, params, _ = ppo.train(
    environment=env,
    progress_fn=progress,
    num_timesteps=config.num_timesteps,
    num_evals=config.num_evals,
    episode_length=config.episode_length,
    num_envs=config.num_envs,
    batch_size=config.batch_size,
    unroll_length=config.unroll_length,
    num_minibatches=config.num_minibatches,
    num_updates_per_batch=config.num_updates_per_batch,
)

# Guard against no evaluation having been reported, so a finished run does
# not end in an IndexError (params/make_policy are already valid here).
if training_rewards:
    print(f"\nTraining complete! Final reward: {training_rewards[-1]:.4f}")
else:
    print("\nTraining complete! (no evaluation results were reported)")

In [None]:
# Plot training curve
import matplotlib.pyplot as plt

# Draw evaluation reward against timesteps using the object-oriented API.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(training_steps, training_rewards, marker="o")
ax.set(
    xlabel="Timesteps",
    ylabel="Episode Reward",
    title=f"Training Curve ({mode} mode)",
)
ax.grid(True)
fig.tight_layout()
plt.show()

## Export to ONNX

In [None]:
from onnx_export import export_policy_to_onnx

# Ensure the target directory exists, then write the trained policy to ONNX.
onnx_dir = os.path.dirname(ONNX_OUTPUT)
os.makedirs(onnx_dir, exist_ok=True)
export_policy_to_onnx(params, ONNX_OUTPUT)

# Report the size of the written file in kilobytes.
size_kb = os.path.getsize(ONNX_OUTPUT) / 1024.0
print(f"ONNX model saved to {ONNX_OUTPUT} ({size_kb:.1f} KB)")

## Record Video

In [None]:
from video_recorder import record_video

# Create outputs/videos/ up front, as the ONNX cell does for its directory.
# NOTE(review): not visible from here whether record_video creates missing
# parent directories itself; exist_ok makes this a no-op if it already exists.
video_dir = os.path.dirname(VIDEO_OUTPUT)
if video_dir:
    os.makedirs(video_dir, exist_ok=True)
record_video(env, make_policy, params, VIDEO_OUTPUT)

size_kb = os.path.getsize(VIDEO_OUTPUT) / 1024
print(f"Video saved to {VIDEO_OUTPUT} ({size_kb:.1f} KB)")

## View Results

In [None]:
# Display video inline
import base64
from IPython.display import HTML

# Embed the MP4 as a base64 data URI so the notebook plays it inline.
with open(VIDEO_OUTPUT, "rb") as video_file:
    video_b64 = base64.b64encode(video_file.read()).decode("utf-8")

video_html = f"""
<video width="640" autoplay loop controls>
  <source src="data:video/mp4;base64,{video_b64}" type="video/mp4">
</video>
"""
# Last expression of the cell, so the notebook renders it.
HTML(video_html)

## Download Files

In [None]:
from google.colab import files

# Trigger browser downloads of both artifacts via Colab's files helper.
for artifact_path in (ONNX_OUTPUT, VIDEO_OUTPUT):
    files.download(artifact_path)