In [1]:
# Fetch the project sources and make the repository root the working
# directory, so the following cells can use repository-relative paths.
!git clone https://github.com/triton-droids/pupper-simulations.git
%cd pupper-simulations

Cloning into 'pupper-simulations'...
remote: Enumerating objects: 412, done.[K
remote: Counting objects: 100% (412/412), done.[K
remote: Compressing objects: 100% (290/290), done.[K
remote: Total 412 (delta 170), reused 335 (delta 99), pack-reused 0 (from 0)[K
Receiving objects: 100% (412/412), 32.08 MiB | 29.33 MiB/s, done.
Resolving deltas: 100% (170/170), done.
/content/pupper-simulations


In [2]:
!pip install -r training_requirements.txt
%cd locomotion

Collecting mujoco==3.3.4 (from -r training_requirements.txt (line 1))
  Downloading mujoco-3.3.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ml_collections==1.1.0 (from -r training_requirements.txt (line 2))
  Downloading ml_collections-1.1.0-py3-none-any.whl.metadata (22 kB)
Collecting jax==0.8.0 (from jax[cuda12]==0.8.0->-r training_requirements.txt (line 3))
  Downloading jax-0.8.0-py3-none-any.whl.metadata (13 kB)
Collecting brax==0.13.0 (from -r training_requirements.txt (line 4))
  Downloading brax-0.13.0-py3-none-any.whl.metadata (20 kB)
Collecting mediapy==1.2.4 (from -r training_requirements.txt (line 5))
  Downloading mediapy-1.2.4-py3-none-any.whl.metadata (4.8 kB)
Collecting glfw (from mujoco==3.

In [3]:
import os

# Select EGL as MuJoCo's rendering backend. This cell runs before the cell
# that imports the simulation packages, so the variable is already set when
# they are loaded.
os.environ.update(MUJOCO_GL='egl')

In [4]:
from bittle_env import BittleEnv

from brax import envs
import functools
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from brax.io import model

from datetime import datetime
import sys
from etils import epath
from matplotlib import pyplot as plt

from flax.training import orbax_utils
from orbax import checkpoint as ocp

import jax
from jax import numpy as jp

import mediapy as media

import numpy as np

In [5]:
#Helper functions
def policy_params_fn(current_step, make_policy, params):
  """Write `params` to an Orbax checkpoint named after the current step.

  The checkpoint is stored under the module-level `ckpt_path` directory, in a
  sub-directory whose name is `current_step`. An existing checkpoint at that
  location is overwritten (`force=True`).

  Args:
    current_step: Step counter used as the checkpoint directory name.
    make_policy: Unused; accepted so the signature matches the callback shape.
    params: PyTree of parameters to save.
  """
  # `ckpt_path` is a notebook-level global assigned in the training cell.
  target_dir = ckpt_path / f'{current_step}'
  checkpointer = ocp.PyTreeCheckpointer()
  checkpointer.save(
      target_dir,
      params,
      force=True,
      save_args=orbax_utils.save_args_from_target(params),
  )

# Training-progress history shared with the progress_x callback below.
x_data = []  # environment-step count recorded at each evaluation
y_data = []  # 'eval/episode_reward' recorded at each evaluation
ydataerr = []  # 'eval/episode_reward_std' recorded at each evaluation
# Wall-clock timestamps: seeded with the time this cell runs; progress_x
# appends one entry per call.
times = [datetime.now()]

# NOTE(review): max_y / min_y are not read by any code visible in this
# notebook (progress_x derives its y-limits from the data) — confirm before
# removing.
max_y, min_y = 40, 0

# Callback function to plot training progress.
def progress_x(num_steps, metrics):
    """Print evaluation metrics and draw the training-progress figure.

    Each call appends to the module-level history lists (`times`, `x_data`,
    `y_data`, `ydataerr`), prints a text summary, and shows a two-panel
    figure: reward vs. environment steps on the left and every `eval/*`
    metric whose name contains "reward" on the right.

    Args:
        num_steps: Number of environment steps completed so far.
        metrics: Mapping from metric name to value. Must contain
            'eval/episode_reward' and 'eval/episode_reward_std'.

    Raises:
        KeyError: If either required metric is missing from `metrics`.
    """
    print("\n" + "="*60)
    print(f"EVALUATION AT STEP {num_steps}")
    print("="*60)

    # Collect timing info
    times.append(datetime.now())
    time_delta = (times[-1] - times[-2]).total_seconds() if len(times) > 1 else 0

    # Extract key metrics. Converting to Python floats up front keeps the
    # formatting below independent of the array type the trainer hands over.
    episode_reward = float(metrics['eval/episode_reward'])
    episode_reward_std = float(metrics['eval/episode_reward_std'])

    # Print detailed metrics
    print(f"Episode Reward:     {episode_reward:.4f} ± {episode_reward_std:.4f}")
    print(f"Time since last:    {time_delta:.2f}s")

    # Print ALL available metrics for debugging
    print("\nAll available metrics:")
    for key, value in sorted(metrics.items()):
        # Accept anything that converts to a single numeric value (Python
        # numbers, NumPy scalars/arrays, JAX arrays). An isinstance check
        # against (int, float, np.ndarray) would silently skip array types
        # that are not NumPy ndarrays, and float() on a multi-element array
        # would raise.
        try:
            as_array = np.asarray(value)
        except (TypeError, ValueError):
            continue
        if as_array.size != 1 or as_array.dtype.kind not in 'biuf':
            continue
        print(f"  {key:30s}: {float(as_array.reshape(-1)[0]):.6f}")

    # Store data
    x_data.append(num_steps)
    y_data.append(episode_reward)
    ydataerr.append(episode_reward_std)

    # Calculate statistics
    if len(y_data) > 1:
        improvement = y_data[-1] - y_data[-2]
        print(f"\nReward change: {improvement:+.4f}")
        print(f"Best so far:   {max(y_data):.4f}")
        print(f"Worst so far:  {min(y_data):.4f}")

    print("="*60 + "\n")

    # A fresh figure is created on every call. No plt.clf() beforehand: it
    # acts on (or creates) an unrelated "current" figure, which then gets
    # rendered as an extra empty figure next to the real plot.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Left plot: Reward over time
    ax1.errorbar(x_data, y_data, yerr=ydataerr,
                 marker='o', capsize=5, capthick=2,
                 linewidth=2, markersize=8)
    ax1.axhline(y=0, color='r', linestyle='--', alpha=0.3, label='Zero reward')
    # `train_fn` is the notebook-level functools.partial built in the
    # training cell; its keywords hold the configured step budget.
    ax1.set_xlim([0, train_fn.keywords['num_timesteps'] * 1.1])

    # Auto-adjust y-limits based on data. Skip when the range is degenerate
    # (e.g. a single point with zero spread) and let matplotlib autoscale,
    # instead of passing identical lower/upper limits.
    pad = max(ydataerr) * 1.2
    y_min = min(y_data) - pad
    y_max = max(y_data) + pad
    if y_max > y_min:
        ax1.set_ylim([y_min, y_max])

    ax1.set_xlabel('# environment steps', fontsize=12)
    ax1.set_ylabel('Reward per episode', fontsize=12)
    ax1.set_title(f'Current: {y_data[-1]:.3f} ± {ydataerr[-1]:.3f}', fontsize=14)
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    # Right plot: every eval metric whose name mentions "reward". This also
    # matches the aggregate 'eval/episode_reward' and its std, which are
    # shown alongside any per-term entries.
    reward_components = {k: v for k, v in metrics.items()
                         if k.startswith('eval/') and 'reward' in k.lower()}

    if reward_components:
        names = [k.replace('eval/episode_reward/', '').replace('eval/', '')
                 for k in reward_components.keys()]
        values = [float(v) for v in reward_components.values()]

        colors = ['green' if v > 0 else 'red' for v in values]
        ax2.barh(range(len(names)), values, color=colors, alpha=0.6)
        ax2.set_yticks(range(len(names)))
        ax2.set_yticklabels(names, fontsize=9)
        ax2.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
        ax2.set_xlabel('Reward contribution', fontsize=12)
        ax2.set_title('Reward Components', fontsize=12)
        ax2.grid(True, alpha=0.3, axis='x')
    else:
        ax2.text(0.5, 0.5, 'No component\nbreakdown available',
                 ha='center', va='center', fontsize=12)
        ax2.set_xlim([0, 1])
        ax2.set_ylim([0, 1])

    fig.tight_layout()
    plt.show()
    # Release the figure explicitly so repeated evaluations do not pile up
    # open figures on backends that keep them alive after show().
    plt.close(fig)

'''
Domain randomization for accurate sim-to-real transfer.
'''
def domain_randomize(sys, rng):
  """Build a batch of randomized model parameters, one set per key in `rng`.

  For every key, one friction value is drawn uniformly from [0.6, 1.4] and
  written to column 0 of `geom_friction` for all rows, and one offset is
  drawn uniformly from [-5, 5] and added to column 0 of `actuator_gainprm`;
  the negated result is written to column 1 of `actuator_biasprm`.

  Args:
    sys: Model object exposing `geom_friction`, `actuator_gainprm`,
      `actuator_biasprm` and a `tree_replace` method.
    rng: Batch of PRNG keys; the leading axis is the randomization batch.

  Returns:
    A `(sys, in_axes)` pair: the model with the three fields replaced by
    their batched versions, and a matching tree that is 0 for those three
    fields and None for every other leaf.
  """

  def _sample_one(key):
    # Keep only the second half of each split, then reuse it as the parent
    # of the next split.
    friction_key = jax.random.split(key, 2)[1]
    friction_draw = jax.random.uniform(
        friction_key, (1,), minval=0.6, maxval=1.4
    )
    friction_table = sys.geom_friction.at[:, 0].set(friction_draw)

    gain_key = jax.random.split(friction_key, 2)[1]
    offset_lo, offset_hi = -5, 5
    gain_offset = jax.random.uniform(
        gain_key, (1,), minval=offset_lo, maxval=offset_hi
    )
    shifted_gain = gain_offset + sys.actuator_gainprm[:, 0]
    gain_table = sys.actuator_gainprm.at[:, 0].set(shifted_gain)
    # Bias column 1 mirrors the gain with opposite sign.
    bias_table = sys.actuator_biasprm.at[:, 1].set(-shifted_gain)
    return friction_table, gain_table, bias_table

  batched_friction, batched_gain, batched_bias = jax.vmap(_sample_one)(rng)

  randomized_fields = ('geom_friction', 'actuator_gainprm', 'actuator_biasprm')

  # Start from "no leaf is batched", then mark only the randomized fields as
  # batched along axis 0.
  in_axes = jax.tree_util.tree_map(lambda _: None, sys)
  in_axes = in_axes.tree_replace({field: 0 for field in randomized_fields})

  batched_sys = sys.tree_replace(
      dict(zip(randomized_fields,
               (batched_friction, batched_gain, batched_bias)))
  )
  return batched_sys, in_axes

# **Train Policy**

In [None]:
# Register the Bittle environment under a name and build the training instance.
envs.register_environment('bittle', BittleEnv)

env_name = 'bittle'
xml_path = 'bittle_adapted_scene.xml'
env = envs.get_environment(env_name, xml_path = xml_path)

# Checkpoint directory; read by policy_params_fn through the `ckpt_path`
# global.
ckpt_path = epath.Path('/tmp/quadrupred_joystick/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

# NOTE(review): policy_params_fn and domain_randomize are defined in the
# helper cell but neither is passed to the trainer below, so this run writes
# no intermediate checkpoints and applies no domain randomization — confirm
# whether that is intended.

#Whether to run testing config for fast iteration or actual training config
TEST = False
train_fn = None

if TEST:

  # Minimal training config (est. 15 min on A100 GPU)
  train_fn = functools.partial(
      ppo.train,
      num_timesteps=10_000,
      num_evals=2,
      episode_length=100,
      num_envs=4,
      batch_size=4,
      unroll_length=5,
      num_minibatches=2,
      num_updates_per_batch=1,
  )

  print("Starting training with testing config")

else:
  # Heavier training config (est. 30 min on A100 GPU)
  train_fn = functools.partial(
    ppo.train,
    num_timesteps=10_000_000,
    num_evals=10,
    episode_length=1000,
    num_envs=4096,
    batch_size=512,
    unroll_length=20,
    num_minibatches=8,
    num_updates_per_batch=1,
  )

  print("Starting training with training config")


# Reset the progress history in place (progress_x reads these globals) so
# that re-running this cell starts a fresh curve, and so that the timings
# printed below are measured from the start of this run rather than from
# whenever the helper cell happened to be executed.
x_data.clear()
y_data.clear()
ydataerr.clear()
times[:] = [datetime.now()]

make_inference_fn, params, _ = train_fn(
      environment=env,
      progress_fn=progress_x,
)


# times[0] is the start of this run and every later entry is one progress_x
# call, so "time to jit" is the span up to the first progress callback.
# Guarded so a run that never invoked the callback does not raise IndexError.
if len(times) > 1:
  print(f'time to jit: {times[1] - times[0]}')
  print(f'time to train: {times[-1] - times[1]}')

# Save and reload params.
model_path = '/tmp/mjx_brax_quadruped_policy'
model.save_params(model_path, params)
params = model.load_params(model_path)

# Build the policy from the reloaded parameters and compile it once for the
# rollout cells below.
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)



Bittle has 9 actuators
Bittle has 16 position DOFs
Bittle has 15 velocity DOFs
Joint positions in q: indices [7:16]
Joint velocities in qd: indices [6:15]
Found lower leg body: servos_rf_1 (id=3)
Found lower leg body: servos_rr_1 (id=5)
Found lower leg body: servos_lf_1 (id=8)
Found lower leg body: servos_lr_1 (id=10)


  return literals.TypedNdArray(np.asarray(x, dtype), weak_type=False)


# **Visualize Policy**

In [None]:
# Separate environment instance used only for evaluation rollouts.
eval_env = envs.get_environment(env_name, xml_path=xml_path)

# Compile reset/step once so the rollout loop reuses the jitted callables.
jit_reset, jit_step = jax.jit(eval_env.reset), jax.jit(eval_env.step)

In [None]:

# @markdown Commands **only used for Barkour Env**:
x_vel = 1.0  #@param {type: "number"}
y_vel = 0.0  #@param {type: "number"}
ang_vel = -0.5  #@param {type: "number"}

the_command = jp.array([x_vel, y_vel, ang_vel])

# Rollout settings: number of policy steps, and keep every Nth state for
# the rendered video.
n_steps = 500
render_every = 2

# Reset from a fixed seed and store the command in the state's info dict.
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
state.info['command'] = the_command
rollout = [state.pipeline_state]

# Step the policy, recording the pipeline state after every step.
for step_idx in range(n_steps):
  if not step_idx % 50:
    print(f"Step: {step_idx}")
  act_rng, rng = jax.random.split(rng)
  # The inference function returns (action, extras); only the action is used.
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

# Render the subsampled trajectory; the fps is divided by the subsampling
# factor so playback speed matches the simulation step `eval_env.dt`.
frames = eval_env.render(rollout[::render_every], camera='track')
media.show_video(frames, fps=1.0 / eval_env.dt / render_every)