# Joystick Policy Training Notebook

In [1]:
# Automatically reload edited modules (e.g. track_mjx) before executing code.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Default runtime settings; each one respects a value already exported in the
# shell so the notebook can be tuned without editing this cell.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = os.environ.get(
    "XLA_PYTHON_CLIENT_MEM_FRACTION", "0.9"
)
os.environ["MUJOCO_GL"] = os.environ.get("MUJOCO_GL", "egl")
os.environ["PYOPENGL_PLATFORM"] = os.environ.get("PYOPENGL_PLATFORM", "egl")

# Append the notebook's XLA flags to any user-provided XLA_FLAGS instead of
# overwriting them. A flag whose name is already present is left untouched, so
# re-running this cell (or pre-setting a flag to a different value) does not
# produce duplicate or conflicting entries.
# NOTE(review): `--xla_gpu_enable_triton_softmax_fusion` may be rejected as an
# unknown flag by newer XLA releases — confirm against the installed jaxlib.
_NOTEBOOK_XLA_FLAGS = (
    "--xla_gpu_enable_triton_softmax_fusion=true",
    "--xla_gpu_triton_gemm_any=True",
)
_user_xla_flags = os.environ.get("XLA_FLAGS", "").split()
_user_xla_flag_names = {flag.split("=", 1)[0] for flag in _user_xla_flags}
os.environ["XLA_FLAGS"] = " ".join(
    _user_xla_flags
    + [
        flag
        for flag in _NOTEBOOK_XLA_FLAGS
        if flag.split("=", 1)[0] not in _user_xla_flag_names
    ]
)

# The persistent JAX compilation cache is configured through jax.config in the
# import cell.

In [None]:
import functools
import json
from datetime import datetime

import jax
import jax.numpy as jp
import matplotlib.pyplot as plt
import mediapy as media
import mujoco
import numpy as np
import wandb
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
from IPython.display import clear_output, display
from orbax import checkpoint as ocp

from mujoco_playground import locomotion, wrapper
from mujoco_playground.config import locomotion_params

# # Enable persistent compilation cache.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update(
    "jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir"
)

In [4]:
# Borrow the Go1 rough-terrain joystick env config and its Brax PPO
# hyperparameters; the PPO params are later used to train RodentJoystick.
# NOTE(review): `randomizer` is not passed to the training call below, and
# `env_cfg` describes the Go1 env rather than RodentJoystick — confirm intended.
env_name = "Go1JoystickRoughTerrain"
env_cfg = locomotion.get_default_config(env_name)
randomizer = locomotion.get_domain_randomizer(env_name)
ppo_params = locomotion_params.brax_ppo_config(env_name)

In [5]:
from track_mjx.environment.task import joysticks

# Instantiate the rodent joystick env (it is re-created right before training).
env = joysticks.RodentJoystick()

In [6]:
from pprint import pprint

# Override the borrowed PPO defaults for this run: 1B environment steps with
# 500 evaluations, 2048 parallel envs and 500-step episodes.
ppo_params.num_evals = 500
ppo_params.num_timesteps = 1_000_000_000
ppo_params.num_envs = 2048
ppo_params.episode_length = 500
pprint(ppo_params)

action_repeat: 1
batch_size: 256
discounting: 0.97
entropy_cost: 0.01
episode_length: 500
learning_rate: 0.0003
max_grad_norm: 1.0
network_factory:
  policy_hidden_layer_sizes: &id001 !!python/tuple
  - 512
  - 256
  - 128
  policy_obs_key: state
  value_hidden_layer_sizes: *id001
  value_obs_key: privileged_state
normalize_observations: true
num_envs: 2048
num_evals: 500
num_minibatches: 32
num_resets_per_eval: 1
num_timesteps: 1000000000
num_updates_per_batch: 4
reward_scaling: 1.0
unroll_length: 20



In [7]:
# Setup wandb logging; set USE_WANDB = False to skip every wandb call.
USE_WANDB = True

if USE_WANDB:
    wandb.init(project="rodent_joysticks", config=ppo_params.to_dict())
    wandb.config.update({"env_name": "rodent_joysticks"})
    # NOTE(review): `env_name` still holds the Go1 config name at this point, so
    # the run is named "rodent_joysticks_Go1JoystickRoughTerrain_<time>" — confirm.
    run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    wandb.run.name = f"rodent_joysticks_{env_name}_{run_timestamp}"

[34m[1mwandb[0m: Currently logged in as: [33myuy004[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [8]:
SUFFIX = None  # Optional tag appended to the experiment name.
FINETUNE_PATH = None  # Directory of per-step checkpoints from a previous run to resume from.
# NOTE: this overrides the `env_name` set in the config cell; later cells that
# call `locomotion.get_default_config(env_name)` / `locomotion.load(env_name, ...)`
# see this value.
env_name = "rodent_joysticks"

# Generate unique experiment name.
now = datetime.now()
timestamp = now.strftime("%Y%m%d-%H%M%S")
exp_name = f"{env_name}-{timestamp}"
if SUFFIX is not None:
    exp_name += f"-{SUFFIX}"
print(f"Experiment name: {exp_name}")


def find_latest_checkpoint(ckpt_dir):
    """Return the highest-numbered step subdirectory of ``ckpt_dir``.

    Checkpoints are written to directories named after the training step, so
    only all-digit directory names are considered. Other entries (config files,
    stray folders) are ignored instead of crashing ``int()``.

    Args:
        ckpt_dir: path object supporting ``glob``, ``is_dir`` and ``name``
            (e.g. ``epath.Path`` or ``pathlib.Path``).

    Returns:
        The path of the step directory with the largest step number.

    Raises:
        FileNotFoundError: if ``ckpt_dir`` contains no step directories.
    """
    step_dirs = [p for p in ckpt_dir.glob("*") if p.is_dir() and p.name.isdigit()]
    if not step_dirs:
        raise FileNotFoundError(f"No step checkpoint directories found in {ckpt_dir}")
    return max(step_dirs, key=lambda p: int(p.name))


# Possibly restore from the latest checkpoint.
if FINETUNE_PATH is not None:
    FINETUNE_PATH = epath.Path(FINETUNE_PATH)
    restore_checkpoint_path = find_latest_checkpoint(FINETUNE_PATH)
    print(f"Restoring from: {restore_checkpoint_path}")
else:
    restore_checkpoint_path = None

Experiment name: rodent_joysticks-20250315-003516


In [9]:
# Create a per-experiment checkpoint directory and record the env config there.
checkpoint_root = epath.Path("checkpoints").resolve()
ckpt_path = checkpoint_root / exp_name
ckpt_path.mkdir(parents=True, exist_ok=True)
print(f"Checkpoint path: {ckpt_path}")

# NOTE(review): `env_cfg` is the config loaded for "Go1JoystickRoughTerrain" in
# the config cell, not a RodentJoystick config — confirm this is what should be
# saved next to the rodent checkpoints.
config_file_path = ckpt_path / "config.json"
with open(config_file_path, "w") as config_file:
    json.dump(env_cfg.to_dict(), config_file, indent=4)

Checkpoint path: /root/vast/scott-yang/track-mjx/notebooks/checkpoints/rodent_joysticks-20250315-003516


In [None]:
# Histories filled by `progress`: eval step counts, mean/std eval episode
# reward, and one wallclock timestamp per callback (times[0] is the start).
x_data = []
y_data = []
y_dataerr = []
times = [datetime.now()]

def progress(num_steps, metrics):
    """Progress callback for ppo.train: log metrics and redraw the reward curve.

    Args:
        num_steps: number of environment steps taken so far.
        metrics: metrics dict from the trainer; must contain
            "eval/episode_reward" and "eval/episode_reward_std".
    """
    # Log to wandb.
    if USE_WANDB:
        wandb.log(metrics, step=num_steps)

    clear_output(wait=True)
    times.append(datetime.now())
    x_data.append(num_steps)
    y_data.append(metrics["eval/episode_reward"])
    y_dataerr.append(metrics["eval/episode_reward_std"])

    # Draw on a fresh figure each call and close it after displaying. Reusing
    # the implicit current figure would stack one more errorbar artist per call
    # and keep the figure alive for the whole run.
    fig, ax = plt.subplots()
    ax.set_xlim([0, ppo_params.num_timesteps * 1.25])
    ax.set_ylim([0, 75])
    ax.set_xlabel("# environment steps")
    ax.set_ylabel("reward per episode")
    ax.set_title(f"y={y_data[-1]:.3f}")
    ax.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")
    display(fig)
    plt.close(fig)


def policy_params_fn(current_step, make_policy, params):
    """Checkpoint hook: save ``params`` with Orbax under ``ckpt_path/<current_step>``.

    ``make_policy`` is part of the callback signature but is not needed here.
    An existing checkpoint for the same step is overwritten (``force=True``).
    """
    del make_policy  # Unused.
    step_dir = ckpt_path / f"{current_step}"
    checkpointer = ocp.PyTreeCheckpointer()
    checkpointer.save(
        step_dir,
        params,
        force=True,
        save_args=orbax_utils.save_args_from_target(params),
    )


# Every PPO config entry except the nested network settings is passed to
# ppo.train as a keyword argument; the network settings become a factory partial.
training_params = dict(ppo_params)
del training_params["network_factory"]

train_fn = functools.partial(
    ppo.train,
    **training_params,
    network_factory=functools.partial(
        ppo_networks.make_ppo_networks, **ppo_params.network_factory
    ),
    restore_checkpoint_path=restore_checkpoint_path,
    progress_fn=progress,
    wrap_env_fn=wrapper.wrap_for_brax_training,
    policy_params_fn=policy_params_fn,
)

env = joysticks.RodentJoystick()
eval_env = joysticks.RodentJoystick()

make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)
if len(times) > 1:
    print(f"time to jit: {times[1] - times[0]}")
    print(f"time to train: {times[-1] - times[1]}")

# Final plot of eval reward vs WALLCLOCK time. times[0] is the timestamp taken
# before training starts, and progress() appends one timestamp per reward, so
# y_data[k] was recorded at times[k + 1]: pair the rewards with times[1:].
if y_data:
    elapsed_seconds = [(t - times[0]).total_seconds() for t in times[1:]]
    fig, ax = plt.subplots()
    ax.set_ylim([0, 75])
    ax.set_xlabel("wallclock time (s)")
    ax.set_ylabel("reward per episode")
    ax.set_title(f"y={y_data[-1]:.3f}")
    ax.errorbar(elapsed_seconds, y_data, yerr=y_dataerr, color="blue")
    plt.show()
else:
    print("No evaluation results were reported; skipping the wallclock plot.")

2025-03-15 00:39:08.377130: E external/xla/xla/service/slow_operation_alarm.cc:73] 
********************************
[Compiling module jit_reset] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2025-03-15 00:41:06.616463: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 3m58.239574255s

********************************
[Compiling module jit_reset] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


In [10]:
# Build the policy from the trained params with deterministic=True and jit it
# for the step-by-step Python rollouts below.
inference_fn = make_inference_fn(params, deterministic=True)
jit_inference_fn = jax.jit(inference_fn)

In [None]:
import pickle

# Save normalizer and policy params to the checkpoint dir.
# Also dump the final parameters as a plain pickle next to the per-step Orbax
# checkpoints. `params` unpacks into (normalizer, policy, value) parameters.
# Only unpickle this file from trusted sources (pickle can execute code).
normalizer_params, policy_params, value_params = params
params_payload = dict(
    normalizer_params=normalizer_params,
    policy_params=policy_params,
    value_params=value_params,
)
with open(ckpt_path / "params.pkl", "wb") as params_file:
    pickle.dump(params_payload, params_file)

In [None]:
from mujoco_playground._src.gait import draw_joystick_command

# Enable perturbation in the eval env and widen the command range used for the
# rollouts and plots below.
# NOTE(review): `env_name` was reassigned to "rodent_joysticks" in the
# experiment-naming cell, so both `get_default_config` and `load` are called
# with that name, not "Go1JoystickRoughTerrain". If it is not a registered
# locomotion env this cell raises. If it is registered, the resulting env must
# still match the RodentJoystick env the policy was trained on, because the
# rollout cells take body/sensor ids from `env` but index `state.data`
# produced by this env. Confirm which environment is intended here.
env_cfg = locomotion.get_default_config(env_name)
# env_cfg.episode_length = 500  # Shorten episode length so we don't go out of bounds of the hfield.
env_cfg.pert_config.enable = True
env_cfg.pert_config.velocity_kick = [3.0, 6.0]
env_cfg.pert_config.kick_wait_times = [5.0, 15.0]
# Per-axis command limits (x, y, yaw). These are also used below as plot
# y-limits and to scale the joystick arrow.
env_cfg.command_config.a = [1.5, 0.8, 2 * jp.pi]
eval_env = locomotion.load(env_name, config=env_cfg)

# Jit-compile reset/step so the Python-loop rollouts below run compiled code.
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

In [None]:
# Sampling ranges for `sample_pert`: perturbation magnitude and duration (s).
velocity_kick_range = [0.0, 0.0]  # Disable velocity kick.
kick_duration_range = [0.05, 0.2]


def sample_pert(rng, pert_state=None):
    """Sample a perturbation magnitude and duration into a state's ``info``.

    Writes ``info["pert_mag"]``, ``info["pert_duration"]`` (int32 steps) and
    ``info["pert_duration_seconds"]`` on the target state.

    Args:
        rng: JAX PRNG key. It is split, and the advanced key is returned.
        pert_state: state whose ``info`` receives the sampled values. Defaults
            to the notebook-level ``state`` of the current rollout, which is
            what the existing ``sample_pert(rng)`` callers rely on.

    Returns:
        The advanced PRNG key.
    """
    if pert_state is None:
        # Backward-compatible fallback: rollout cells call sample_pert(rng) and
        # expect the global `state` to be updated in place.
        pert_state = state
    rng, key1, key2 = jax.random.split(rng, 3)
    pert_mag = jax.random.uniform(
        key1, minval=velocity_kick_range[0], maxval=velocity_kick_range[1]
    )
    duration_seconds = jax.random.uniform(
        key2, minval=kick_duration_range[0], maxval=kick_duration_range[1]
    )
    # Convert the sampled duration into a whole number of env steps.
    duration_steps = jp.round(duration_seconds / eval_env.dt).astype(jp.int32)
    pert_state.info["pert_mag"] = pert_mag
    pert_state.info["pert_duration"] = duration_steps
    pert_state.info["pert_duration_seconds"] = duration_seconds
    return rng


# Roll out the trained policy with a fixed joystick command for one episode,
# recording per-step diagnostics (reward terms, velocities, swing peaks,
# contacts) and a draw_joystick_command scene callback per step.
rng = jax.random.PRNGKey(0)
rollout = []
modify_scene_fns = []

swing_peak = []
rewards = []  # reward terms keyed without the "reward/" prefix
linvel = []
angvel = []
track = []
foot_vel = []
rews = []  # same reward terms as `rewards`, keyed by the full "reward/..." name
contact = []

# Command components (x, y, yaw); here a pure yaw command of pi.
x = 0.0
y = 0.0
yaw = jp.pi
command = jp.array([x, y, yaw])

state = jit_reset(rng)
# sample_pert writes into the global `state.info` and returns an advanced key.
# NOTE(review): this resamples on every step while steps_since_last_pert <
# steps_until_next_pert — confirm that is the intended schedule.
if state.info["steps_since_last_pert"] < state.info["steps_until_next_pert"]:
    rng = sample_pert(rng)
state.info["command"] = command
for i in range(env_cfg.episode_length):
    if state.info["steps_since_last_pert"] < state.info["steps_until_next_pert"]:
        rng = sample_pert(rng)
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    # Re-apply the fixed command after every step (presumably so any command
    # resampling inside the env step is overridden).
    state.info["command"] = command
    rews.append({k: v for k, v in state.metrics.items() if k.startswith("reward/")})
    rollout.append(state)
    swing_peak.append(state.info["swing_peak"])
    # k[7:] strips the 7-character "reward/" prefix.
    rewards.append(
        {k[7:]: v for k, v in state.metrics.items() if k.startswith("reward/")}
    )
    # NOTE(review): `state` comes from `eval_env`, but these helpers and the
    # body/sensor ids below come from `env`; both must describe the same model.
    linvel.append(env.get_global_linvel(state.data))
    angvel.append(env.get_gyro(state.data))
    track.append(
        env._reward_tracking_lin_vel(
            state.info["command"], env.get_local_linvel(state.data)
        )
    )

    feet_vel = state.data.sensordata[env._foot_linvel_sensor_adr]
    vel_xy = feet_vel[..., :2]
    # NOTE(review): this is sqrt of an L2 norm (norm**0.5), not the plain
    # norm — confirm intended. Only stored in `foot_vel`.
    vel_norm = jp.sqrt(jp.linalg.norm(vel_xy, axis=-1))
    foot_vel.append(vel_norm)

    contact.append(state.info["last_contact"])

    # Joystick arrow anchor: 0.2 above the torso position. Heading comes from
    # the torso's orientation matrix (presumably its first row — confirm).
    # `yaw` is reused here; `command` was already built from its old value.
    xyz = np.array(state.data.xpos[env._torso_body_id])
    xyz += np.array([0, 0, 0.2])
    x_axis = state.data.xmat[env._torso_body_id, 0]
    yaw = -np.arctan2(x_axis[1], x_axis[0])
    modify_scene_fns.append(
        functools.partial(
            draw_joystick_command,
            cmd=state.info["command"],
            xyz=xyz,
            theta=yaw,
            scl=abs(state.info["command"][0]) / env_cfg.command_config.a[0],
        )
    )

In [None]:
# Render every `render_every`-th state and play it back at real-time speed.
render_every = 2
fps = 1.0 / eval_env.dt / render_every
traj = rollout[::render_every]
mod_fns = modify_scene_fns[::render_every]

scene_option = mujoco.MjvOption()
# Show geom group 2, hide geom group 3.
scene_option.geomgroup[2] = True
scene_option.geomgroup[3] = False
for vis_flag, enabled in (
    (mujoco.mjtVisFlag.mjVIS_CONTACTPOINT, True),
    (mujoco.mjtVisFlag.mjVIS_TRANSPARENT, False),
    (mujoco.mjtVisFlag.mjVIS_PERTFORCE, True),
):
    scene_option.flags[vis_flag] = enabled

frames = eval_env.render(
    traj,
    camera="track",
    scene_option=scene_option,
    width=640,
    height=480,
    modify_scene_fns=mod_fns,
)
media.show_video(frames, fps=fps, loop=False)

In [None]:
# Per-foot swing peak height in a 2x2 grid, against the configured max height.
swing_peak = jp.array(swing_peak)
max_foot_height = env_cfg.reward_config.max_foot_height
fig, axs = plt.subplots(2, 2)
for foot_idx, (ax, foot_name, foot_color) in enumerate(
    zip(axs.flat, ["FR", "FL", "RR", "RL"], ["r", "g", "b", "y"])
):
    ax.plot(swing_peak[:, foot_idx], color=foot_color)
    ax.set_ylim([0, max_foot_height * 1.25])
    ax.axhline(max_foot_height, color="k", linestyle="--")
    ax.set_title(foot_name)
    ax.set_xlabel("time")
    ax.set_ylabel("height")
plt.tight_layout()
plt.show()

# Compare 10-step moving averages of the measured velocities with the command.
smoothing_kernel = jp.ones(10) / 10
linvel_arr = jp.array(linvel)
angvel_arr = jp.array(angvel)
linvel_x = jp.convolve(linvel_arr[:, 0], smoothing_kernel, mode="same")
linvel_y = jp.convolve(linvel_arr[:, 1], smoothing_kernel, mode="same")
angvel_yaw = jp.convolve(angvel_arr[:, 2], smoothing_kernel, mode="same")

fig, axes = plt.subplots(3, 1, figsize=(10, 10))
command_bounds = env_cfg.command_config.a
for axis_idx, (ax, series, label) in enumerate(
    zip(axes, (linvel_x, linvel_y, angvel_yaw), ("dx", "dy", "dyaw"))
):
    ax.plot(series)
    ax.set_ylim(-command_bounds[axis_idx], command_bounds[axis_idx])
    ax.axhline(state.info["command"][axis_idx], color="red", linestyle="--")
    ax.set_ylabel(label)

### Run at increasingly faster speeds

In [None]:
# Roll out the policy while ramping the forward command: x starts at -0.25 and
# gains 0.25 every 200 steps, giving 0.0, 0.25, ..., 1.5 over 1400 steps.
rng = jax.random.PRNGKey(0)
rollout = []
modify_scene_fns = []
swing_peak = []
linvel = []
angvel = []

x = -0.25
command = jp.array([x, 0, 0])
state = jit_reset(rng)
for i in range(1_400):
    # Increase the commanded forward velocity by 0.25 every 200 steps.
    if i % 200 == 0:
        x += 0.25
        print(f"Setting x to {x}")
        command = jp.array([x, 0, 0])
    state.info["command"] = command
    # sample_pert writes into the global `state.info` and advances rng.
    if state.info["steps_since_last_pert"] < state.info["steps_until_next_pert"]:
        rng = sample_pert(rng)
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)
    swing_peak.append(state.info["swing_peak"])
    # NOTE(review): helpers and body ids come from `env` while `state` comes
    # from `eval_env`; both must describe the same model.
    linvel.append(env.get_global_linvel(state.data))
    angvel.append(env.get_gyro(state.data))
    # Joystick arrow anchored 0.2 above the torso, rotated by the torso heading.
    xyz = np.array(state.data.xpos[env._torso_body_id])
    xyz += np.array([0, 0, 0.2])
    x_axis = state.data.xmat[env._torso_body_id, 0]
    yaw = -np.arctan2(x_axis[1], x_axis[0])
    modify_scene_fns.append(
        functools.partial(
            draw_joystick_command,
            cmd=command,
            xyz=xyz,
            theta=yaw,
            scl=abs(command[0]) / env_cfg.command_config.a[0],
        )
    )

In [None]:
# Plot each foot in a 2x2 grid.
swing_peak = jp.array(swing_peak)
names = ["FR", "FL", "RR", "RL"]
colors = ["r", "g", "b", "y"]
fig, axs = plt.subplots(2, 2)
for i, ax in enumerate(axs.flat):
    ax.plot(swing_peak[:, i], color=colors[i])
    ax.set_ylim([0, env_cfg.reward_config.max_foot_height * 1.25])
    ax.axhline(env_cfg.reward_config.max_foot_height, color="k", linestyle="--")
    ax.set_title(names[i])
    ax.set_xlabel("time")
    ax.set_ylabel("height")
plt.tight_layout()
plt.show()

linvel_x = jp.array(linvel)[:, 0]
linvel_y = jp.array(linvel)[:, 1]
angvel_yaw = jp.array(angvel)[:, 2]

# Smooth with a 10-step moving average before comparing to the command.
linvel_x = jp.convolve(linvel_x, jp.ones(10) / 10, mode="same")
linvel_y = jp.convolve(linvel_y, jp.ones(10) / 10, mode="same")
angvel_yaw = jp.convolve(angvel_yaw, jp.ones(10) / 10, mode="same")

# Commanded (x, y, yaw) at every step. The command ramps during this rollout,
# so plot the full schedule rather than one horizontal line at the final
# `state.info["command"]`. Each per-step draw_joystick_command partial in
# `modify_scene_fns` was bound with the command used at that step (`cmd=`).
commands = jp.stack([fn.keywords["cmd"] for fn in modify_scene_fns])

fig, axes = plt.subplots(3, 1, figsize=(10, 10))
labels = ["dx", "dy", "dyaw"]
for i, (ax, measured) in enumerate(zip(axes, (linvel_x, linvel_y, angvel_yaw))):
    ax.plot(measured, label="measured (smoothed)")
    ax.plot(commands[:, i], color="red", linestyle="--", label="command")
    ax.set_ylim(-env_cfg.command_config.a[i], env_cfg.command_config.a[i])
    ax.set_ylabel(labels[i])
axes[-1].set_xlabel("step")
axes[0].legend(loc="upper left")
plt.show()

In [None]:
# Render every other state of the speed-ramp rollout at real-time speed.
render_every = 2
fps = 1.0 / eval_env.dt / render_every
print(f"fps: {fps}")

traj, mod_fns = rollout[::render_every], modify_scene_fns[::render_every]
assert len(traj) == len(mod_fns)

scene_option = mujoco.MjvOption()
scene_option.geomgroup[2], scene_option.geomgroup[3] = True, False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = True

render_kwargs = dict(
    camera="track",
    height=480,
    width=640,
    modify_scene_fns=mod_fns,
    scene_option=scene_option,
)
frames = eval_env.render(traj, **render_kwargs)
media.show_video(frames, fps=fps, loop=False)
# media.write_video(f"{env_name}_faster_commands.mp4", frames, fps=fps, qp=18)

### Resist perturbation

In [None]:
# Reverse the joystick command every 200 steps. No sample_pert call is made
# here; any kicks presumably come from the env's own pert_config.
rng = jax.random.PRNGKey(12345)
rollout = []
modify_scene_fns = []
n_episodes = 1

# Starting command. Every component is negated at step 0 and then every 200
# steps, so the first segment runs at (-0.8, -0.0, -0.25).
x_vel, y_vel, ang_vel = 0.8, 0.0, 0.25
command = jp.array([x_vel, y_vel, ang_vel])

for _ in range(n_episodes):
    state = jit_reset(rng)
    for step_idx in range(1_600):
        if step_idx % 200 == 0:
            x_vel, y_vel, ang_vel = -x_vel, -y_vel, -ang_vel
            command = jp.array([x_vel, y_vel, ang_vel])
        state.info["command"] = command

        act_rng, rng = jax.random.split(rng)
        ctrl, _ = jit_inference_fn(state.obs, act_rng)
        state = jit_step(state, ctrl)
        rollout.append(state)

        # Joystick arrow: anchored 0.2 above the torso, rotated by its heading.
        torso_id = env._torso_body_id
        arrow_pos = np.array(state.data.xpos[torso_id])
        arrow_pos += np.array([0, 0, 0.2])
        heading_vec = state.data.xmat[torso_id, 0]
        heading = -np.arctan2(heading_vec[1], heading_vec[0])
        modify_scene_fns.append(
            functools.partial(
                draw_joystick_command,
                cmd=command,
                xyz=arrow_pos,
                theta=heading,
                scl=abs(command[0]) / env_cfg.command_config.a[0],
            )
        )

In [None]:
# Render every state of the perturbation rollout at real-time speed.
render_every = 1
fps = 1.0 / eval_env.dt / render_every
print(f"fps: {fps}")

traj = rollout[::render_every]
mod_fns = modify_scene_fns[::render_every]
assert len(traj) == len(mod_fns)

scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True
scene_option.geomgroup[3] = False
for vis_flag, enabled in (
    (mujoco.mjtVisFlag.mjVIS_CONTACTPOINT, True),
    (mujoco.mjtVisFlag.mjVIS_PERTFORCE, True),
    # Set to True if you can't see the perturbation force.
    (mujoco.mjtVisFlag.mjVIS_TRANSPARENT, False),
):
    scene_option.flags[vis_flag] = enabled

frames = eval_env.render(
    traj,
    camera="track",
    height=480,
    width=640,
    modify_scene_fns=mod_fns,
    scene_option=scene_option,
)
media.show_video(frames, fps=fps, loop=False)
# media.write_video(f"{env_name}_resist_perturbation.mp4", frames, fps=fps, qp=18)