In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
import os

# Configure XLA / MuJoCo through environment variables. This cell runs before
# the cell that imports jax and mujoco.
_TRITON_GEMM_FLAG = "--xla_gpu_triton_gemm_any=True"

xla_flags = os.environ.get("XLA_FLAGS", "")
# Append the flag only when it is missing: os.environ keeps the value written
# below, so an unconditional append would add the flag again on every re-run
# of this cell.
if _TRITON_GEMM_FLAG not in xla_flags.split():
  xla_flags += f" {_TRITON_GEMM_FLAG}"
os.environ["XLA_FLAGS"] = xla_flags
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["MUJOCO_GL"] = "egl"

In [3]:
import functools
import json
from datetime import datetime

import jax
import jax.numpy as jp
import matplotlib.pyplot as plt
import mediapy as media
import mujoco
import numpy as np
import wandb
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
from IPython.display import clear_output, display
from orbax import checkpoint as ocp

from mujoco_playground import BraxEnvWrapper, locomotion
from mujoco_playground.utils.vis_utils import draw_joystick_command

# Enable persistent compilation cache: store entries under /tmp/jax_cache and
# set both minimum-size and minimum-compile-time thresholds to their lowest
# values.
for _option, _value in (
    ("jax_compilation_cache_dir", "/tmp/jax_cache"),
    ("jax_persistent_cache_min_entry_size_bytes", -1),
    ("jax_persistent_cache_min_compile_time_secs", 0),
):
  jax.config.update(_option, _value)

In [None]:
# Pick the task, then look up its default config and its domain randomizer
# from the locomotion registry.
env_name = "BerkeleyHumanoidJoystickGaitTracking"
env_cfg, randomizer = (
    locomotion.get_default_config(env_name),
    locomotion.get_domain_randomizer(env_name),
)

print(env_cfg)

In [5]:
# Optional Weights & Biases logging; set USE_WANDB to True to start a run.
USE_WANDB = False

if USE_WANDB:
  # The environment config is attached to the run, plus the environment name.
  wandb.init(config=env_cfg, project="mjxrl")
  wandb.config.update(dict(env_name=env_name))

In [None]:
# Optional suffix appended to the experiment name.
SUFFIX = None
# Directory of a previous run to fine-tune from, or None to train from scratch.
FINETUNE_PATH = None

# Generate unique experiment name.
now = datetime.now()
timestamp = now.strftime("%Y%m%d-%H%M%S")
exp_name = f"{env_name}-{timestamp}"
if SUFFIX is not None:
  exp_name += f"-{SUFFIX}"
print(f"Experiment name: {exp_name}")

# Possibly restore from the latest checkpoint.
if FINETUNE_PATH is not None:
  FINETUNE_PATH = epath.Path(FINETUNE_PATH)
  # Checkpoints are written to sub-directories named after the training step,
  # next to a config.json file, so keep only directories with a numeric name.
  latest_ckpts = [
      ckpt
      for ckpt in FINETUNE_PATH.glob("*")
      if ckpt.is_dir() and ckpt.name.isdigit()
  ]
  if not latest_ckpts:
    raise FileNotFoundError(f"No checkpoints found in: {FINETUNE_PATH}")
  # Sort by step number: sorting the paths themselves is lexicographic
  # ("1000" < "200"), and the latest checkpoint is the last element, not the
  # first.
  latest_ckpts.sort(key=lambda ckpt: int(ckpt.name))
  latest_ckpt = latest_ckpts[-1]
  restore_checkpoint_path = latest_ckpt
  print(f"Restoring from: {restore_checkpoint_path}")
else:
  restore_checkpoint_path = None

In [None]:
# Checkpoints for this run go to <resolved "checkpoints" dir>/<exp_name>.
ckpt_path = epath.Path("checkpoints").resolve().joinpath(exp_name)
ckpt_path.mkdir(exist_ok=True, parents=True)
print(f"Checkpoint path: {ckpt_path}")

# Store the environment config as JSON alongside the checkpoints.
with open(ckpt_path.joinpath("config.json"), "w") as config_file:
  json.dump(env_cfg.to_dict(), config_file, indent=4)


def policy_params_fn(current_step, make_policy, params):
  """Save `params` as an Orbax checkpoint under `ckpt_path / <current_step>`.

  Handed to the PPO trainer as its `policy_params_fn` callback.

  Args:
    current_step: Training step; used as the checkpoint directory name.
    make_policy: Part of the callback signature; not used here.
    params: Parameter pytree to save.
  """
  del make_policy  # Unused.
  step_dir = ckpt_path / f"{current_step}"
  # force=True is passed through to Orbax (presumably so an existing
  # directory for the same step can be overwritten - confirm in Orbax docs).
  ocp.PyTreeCheckpointer().save(
      step_dir,
      params,
      force=True,
      save_args=orbax_utils.save_args_from_target(params),
  )


# PPO network factory with a 4x128 policy MLP.
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(128,) * 4,
)

# PPO training entry point with every hyperparameter bound; the environments
# and the progress callback are supplied when it is called.
train_fn = functools.partial(
    ppo.train,
    # Training budget and evaluation.
    num_timesteps=400_000_000,
    num_evals=16,
    seed=1,
    # Environment / rollout settings.
    num_envs=32768,
    episode_length=env_cfg.episode_length,
    action_repeat=1,
    unroll_length=32,
    normalize_observations=True,
    randomization_fn=None,  # randomizer,
    # Optimisation settings.
    batch_size=1024,
    num_minibatches=32,
    num_updates_per_batch=5,
    learning_rate=1e-4,
    discounting=0.98,
    reward_scaling=0.1,
    entropy_cost=0,
    clipping_epsilon=0.2,
    # max_grad_norm=1.0,
    # Networks and checkpointing.
    network_factory=make_networks_factory,
    policy_params_fn=policy_params_fn,
    restore_checkpoint_path=restore_checkpoint_path,
)

In [None]:
# Accumulators filled by the `progress` callback: step counts, mean episode
# reward and its standard deviation.
x_data = []
y_data = []
y_dataerr = []
# Wall-clock timestamps; the first entry marks when this cell started.
times = [datetime.now()]


def progress(num_steps, metrics):
  """Training progress callback: log metrics and redraw the reward curve.

  Args:
    num_steps: Number of environment steps completed so far.
    metrics: Mapping of metric names to values; must contain
      "eval/episode_reward" and "eval/episode_reward_std".
  """
  # Log to wandb.
  if USE_WANDB:
    wandb.log(metrics, step=num_steps)

  # Record the new evaluation point.
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])

  # Plot on a fresh figure each time. Drawing through the implicit pyplot
  # figure would stack one more errorbar artist onto the same figure at every
  # evaluation.
  clear_output(wait=True)
  fig, ax = plt.subplots()
  ax.set_xlim([0, train_fn.keywords["num_timesteps"] * 1.25])
  ax.set_ylim([0, 140])
  ax.set_xlabel("# environment steps")
  ax.set_ylabel("reward per episode")
  ax.set_title(f"y={y_data[-1]:.3f}")
  ax.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")

  display(fig)
  # Release the figure once shown so open figures do not accumulate.
  plt.close(fig)


# Two separately loaded instances of the same task: one for training, one for
# evaluation.
env, eval_env = (locomotion.load(env_name, config=env_cfg) for _ in range(2))

# Run PPO; both environments are wrapped for Brax. The third return value is
# discarded.
make_inference_fn, params, _ = train_fn(
    progress_fn=progress,
    environment=BraxEnvWrapper(env),
    eval_env=BraxEnvWrapper(eval_env),
)

# Timing summary; needs at least one progress callback to have fired.
if times[1:]:
  print(f"time to jit: {times[1] - times[0]}")
  print(f"time to train: {times[-1] - times[1]}")

In [9]:
# Build the policy from the trained parameters with deterministic=True, then
# jit-compile it for use in the rollout below.
inference_fn = make_inference_fn(params, deterministic=True)
jit_inference_fn = jax.jit(inference_fn)

In [10]:
# Fresh environment instance for the rollout, with jit-compiled reset and step.
eval_env = locomotion.load(env_name, config=env_cfg)
jit_reset, jit_step = map(jax.jit, (eval_env.reset, eval_env.step))

In [None]:
import numpy as np

# Bounds used by get_scl to normalise a value.
x_min, x_max = 0.0, 1.0


def get_scl(x: float) -> float:
  """Rescale `x` linearly so that `x_min` maps to 0 and `x_max` maps to 1."""
  offset = x - x_min
  span = x_max - x_min
  return offset / span


# Roll out the trained policy for 800 steps under a fixed joystick command,
# recording states, per-step diagnostics and one scene-overlay callback per
# step for rendering.

# Joystick command components: x and yaw are used as the first and last
# entries of the command vector (the middle entry is fixed at 0).
yaw = 0.0
x = 0.5
cmd = jp.array([x, 0, yaw])

rng = jax.random.PRNGKey(12345)
rng, reset_rng = jax.random.split(rng)
state = jit_reset(reset_rng)
state.info["command"] = cmd

# Change gait type.
# PHASES = jp.array(
#     [
#         [0, jp.pi, jp.pi, 0],  # trot
#         [0, 0.5 * jp.pi, jp.pi, 1.5 * jp.pi],  # walk
#         [0, jp.pi, 0, jp.pi],  # pace
#         [0, 0, jp.pi, jp.pi],  # bound
#         [0, 0, 0, 0],  # pronk
#     ]
# )
PHASES = jp.array([
    [0, jp.pi],  # walk
    [0, 0],  # jump
])
gait = jp.array(0)
state.info["gait"] = gait
state.info["phase"] = PHASES[gait]

# Per-step recordings.
modify_scene_fns = []
rollout = []
swing_peak = []
air_time = []
lin_vel_x = []
lin_vel_y = []
ang_vel_z = []
foot_height = 0.1  # [0.08, 0.4]
gait_freq = 1.0  # [0.5, 4.0]
phase_dt = 2 * jp.pi * env.dt * jp.array(gait_freq)
state.info["phase_dt"] = phase_dt
state.info["gait_freq"] = jp.array(gait_freq)

# The torso body id does not change during the rollout, so look it up once
# instead of on every step.
torso_id = env.mj_model.body("torso").id

for i in range(800):
  if i % 200 == 0:
    # foot_height = 0.2  # 0.15 # np.random.uniform(0.08, 0.3)
    state.info["foot_height"] = jp.array(foot_height)
    # x += 0.2
    # yaw = np.random.uniform(-0.7, 0.7)
    cmd = jp.array([x, 0, yaw])
    state.info["command"] = cmd
    # Scale frequency proportionally to speed.
    # gait_freq = 1.5 + 1.5 * x
    # phase_dt = 2 * jp.pi * env.dt * jp.array(gait_freq)
    # state.info["phase_dt"] = phase_dt
    # state.info["gait_freq"] = jp.array(gait_freq)

  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  # Re-apply the command and gait timing to the state returned by the step.
  state.info["command"] = cmd
  state.info["phase_dt"] = phase_dt
  state.info["gait_freq"] = jp.array(gait_freq)
  rollout.append(state)
  swing_peak.append(state.info["swing_peak"])
  air_time.append(state.info["feet_air_time"])
  lin_vel_x.append(state.info["lin_vel"][0])
  lin_vel_y.append(state.info["lin_vel"][1])
  ang_vel_z.append(state.info["ang_vel"][2])

  # Anchor the command overlay at the torso position, offset by 0.2 along x.
  xyz = np.array(state.data.xpos[torso_id])
  xyz += np.array([0.2, 0, 0])
  x_axis = state.data.xmat[torso_id, 0]
  modify_scene_fns.append(
      functools.partial(
          draw_joystick_command,
          cmd=cmd,
          xyz=xyz,
          # foot_height=foot_height,
          # foot_pos=state.data.site_xpos[env._feet_site_id],
          theta=-np.arctan2(x_axis[1], x_axis[0]),
          scl=get_scl(x),
      )
  )

In [None]:
# Render every `render_every`-th recorded state and show the result as a video.
render_every = 2
fps = 1.0 / eval_env.dt / render_every
print(f"fps: {fps}")

# Subsample into new names. Overwriting `modify_scene_fns` with its own slice
# would shrink it again on every re-run of this cell, while `traj` is always
# derived from the full `rollout`, leaving the two lists with different
# lengths.
traj = rollout[::render_every]
render_scene_fns = modify_scene_fns[::render_every]

# Visualisation options: show geom group 2, hide geom group 3, and disable
# contact point / contact force rendering.
scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True
scene_option.geomgroup[3] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False

frames = eval_env.render(
    traj,
    camera="track",
    scene_option=scene_option,
    modify_scene_fns=render_scene_fns,
    height=480,
    width=640,
)
media.show_video(frames, fps=fps)

In [None]:
# Plot the recorded swing peak of each foot, one panel per foot. The dashed
# line marks `foot_height`.
swing_peak = jp.array(swing_peak)
# Quadruped variant: names = ["FR", "FL", "HR", "HL"]
names = ["left_foot", "right_foot"]
colors = ["r", "g", "b", "y"]
fig, axs = plt.subplots(1, 2)
# zip stops at the shortest input, so only the first two colors are used.
for i, (ax, name, color) in enumerate(zip(axs.flat, names, colors)):
  ax.plot(swing_peak[:, i], color=color)
  ax.axhline(foot_height, color="k", linestyle="--")
  ax.set(title=name, xlabel="time", ylabel="height")
plt.tight_layout()
plt.show()