In [1]:
# Load IPython's autoreload extension and set it to mode 2, so imported
# modules are reloaded before each cell runs and source edits are picked up
# without restarting the kernel.
%load_ext autoreload

%autoreload 2

In [2]:
import os

# Add the Triton GEMM flag to XLA_FLAGS while keeping any flags that are
# already set. The flag is appended only when missing, so re-running this cell
# in the same kernel does not keep growing the variable.
triton_gemm_flag = "--xla_gpu_triton_gemm_any=True"
xla_flags = os.environ.get("XLA_FLAGS", "")
if triton_gemm_flag not in xla_flags.split():
  xla_flags += f" {triton_gemm_flag}"
os.environ["XLA_FLAGS"] = xla_flags
# Turn off the XLA Python client's memory preallocation.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Select the EGL backend for MuJoCo rendering.
os.environ["MUJOCO_GL"] = "egl"

In [3]:
import functools
import json
from datetime import datetime

import jax
import jax.numpy as jp
import pickle
import matplotlib.pyplot as plt
import mediapy as media
import mujoco
import numpy as np
import wandb
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
from IPython.display import clear_output, display
from orbax import checkpoint as ocp

from mujoco_playground import locomotion, wrapper
from mujoco_playground.config import locomotion_params

# Enable the persistent compilation cache under /tmp/jax_cache. The two
# thresholds (min entry size -1, min compile time 0) are set to their most
# permissive values -- presumably so that every compilation is cached.
for cache_option, cache_value in (
    ("jax_compilation_cache_dir", "/tmp/jax_cache"),
    ("jax_persistent_cache_min_entry_size_bytes", -1),
    ("jax_persistent_cache_min_compile_time_secs", 0),
):
  jax.config.update(cache_option, cache_value)

In [4]:
# Name of the locomotion environment to train on. The commented-out lines are
# alternative environment names that can be swapped in.
# env_name = "BerkeleyHumanoidJoystickRoughTerrain"
# env_name = "BerkeleyHumanoidJoystickFlatTerrain"
env_name = "G1JoystickFlatTerrain"
# Look up, by environment name: the default environment config, the domain
# randomizer (passed to training as `randomization_fn`), and the Brax PPO
# hyperparameters.
env_cfg = locomotion.get_default_config(env_name)
randomizer = locomotion.get_domain_randomizer(env_name)
ppo_params = locomotion_params.brax_ppo_config(env_name)

In [5]:
# Zero out the `energy` and `dof_acc` reward scales for this run.
env_cfg.reward_config.scales.energy = 0.0
env_cfg.reward_config.scales.dof_acc = 0.0

# Optional overrides, left disabled: uncomment to change the training budget,
# evaluation count, learning rate, action scale, or command range.
# ppo_params.num_timesteps = 50_000_000
# ppo_params.num_evals = 5
# ppo_params.learning_rate = 1e-3
# env_cfg.action_scale = 1.0
# env_cfg.command_config.a = [1.5, 0.8, jp.pi]

In [None]:
# pprint is only used by this inspection cell.
from pprint import pprint

# Pretty-print the PPO hyperparameters that will be passed to training.
pprint(ppo_params)

In [7]:
# Weights & Biases logging is opt-in: flip USE_WANDB to True to start a run.
USE_WANDB = False

if USE_WANDB:
  # The environment config is the run config; the env name is added on top.
  wandb.init(project="mjxrl", config=env_cfg)
  wandb.config.update(dict(env_name=env_name))

In [None]:
# Optional tag appended to the experiment name; None leaves the name as is.
SUFFIX = None

# Optional directory holding step-numbered checkpoint sub-directories to
# resume from; None trains from scratch.
FINETUNE_PATH = None

# Generate unique experiment name.
now = datetime.now()
timestamp = now.strftime("%Y%m%d-%H%M%S")
exp_name = f"{env_name}-{timestamp}"
if SUFFIX is not None:
  exp_name += f"-{SUFFIX}"
print(f"{exp_name}")

# Possibly restore from the latest checkpoint.
if FINETUNE_PATH is not None:
  FINETUNE_PATH = epath.Path(FINETUNE_PATH)
  # Checkpoints are written to directories named after the training step
  # (see `policy_params_fn`). Ignore anything that is not a numerically named
  # directory so that int() below cannot fail on stray entries.
  latest_ckpts = [
      ckpt
      for ckpt in FINETUNE_PATH.glob("*")
      if ckpt.is_dir() and ckpt.name.isdigit()
  ]
  if not latest_ckpts:
    raise FileNotFoundError(
        f"No step-numbered checkpoint directories found in: {FINETUNE_PATH}"
    )
  # Sort by step number (not lexicographically) and take the highest step.
  latest_ckpts.sort(key=lambda x: int(x.name))
  latest_ckpt = latest_ckpts[-1]
  restore_checkpoint_path = latest_ckpt
  print(f"Restoring from: {restore_checkpoint_path}")
else:
  restore_checkpoint_path = None

In [None]:
# Absolute checkpoint directory for this experiment; created if missing.
ckpt_path = epath.Path("checkpoints").resolve() / exp_name
ckpt_path.mkdir(parents=True, exist_ok=True)
print(f"{ckpt_path}")

# Store the environment config used for this run next to its checkpoints.
config_file = ckpt_path / "config.json"
with open(config_file, "w") as fp:
  fp.write(json.dumps(env_cfg.to_dict(), indent=4))

In [None]:
# Buffers filled by `progress` on every evaluation: environment step count,
# eval episode reward, and eval episode reward std.
x_data, y_data, y_dataerr = [], [], []
# NOTE(review): these two lists are never written to in the visible cells.
ep_length_mean, ep_length_std = [], []
# Wall-clock timestamps. The first entry is taken here, before training
# starts; `progress` appends one more per call.
times = [datetime.now()]


def progress(num_steps, metrics):
  """Training progress callback: logs metrics and redraws the reward curve.

  Appends to the module-level buffers `times`, `x_data`, `y_data` and
  `y_dataerr`, then replaces the cell output with an updated plot.

  Args:
    num_steps: Number of environment steps so far; used as the x value and as
      the wandb step.
    metrics: Mapping of metric names to values. Must contain
      "eval/episode_reward" and "eval/episode_reward_std".
  """
  # Log to wandb.
  if USE_WANDB:
    wandb.log(metrics, step=num_steps)

  # Record this evaluation.
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])

  # Plot.
  clear_output(wait=True)
  # This callback runs many times on the same live figure. Clear it first so
  # each call draws one curve instead of stacking a new errorbar artist on top
  # of all the previous ones.
  plt.clf()

  plt.xlim([0, ppo_params["num_timesteps"] * 1.25])
  plt.xlabel("# environment steps")
  plt.ylabel("reward per episode")
  plt.title(f"y={y_data[-1]:.3f}")
  plt.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")

  display(plt.gcf())


def policy_params_fn(current_step, make_policy, params):
  """Checkpoint callback: saves `params` under `ckpt_path / <current_step>`.

  Args:
    current_step: Training step; used as the checkpoint directory name.
    make_policy: Unused.
    params: Parameter pytree to save.
  """
  del make_policy  # Unused.
  step_dir = ckpt_path / f"{current_step}"
  # force=True is passed so an existing directory for this step does not block
  # the save.
  ocp.PyTreeCheckpointer().save(
      step_dir,
      params,
      force=True,
      save_args=orbax_utils.save_args_from_target(params),
  )

# `ppo.train` receives `network_factory` as a callable, while the config
# stores the keyword arguments for it. Split it out of the plain training
# kwargs and bind it onto `make_ppo_networks`. If the config has no such
# entry, fall back to the unparameterized factory instead of raising KeyError.
training_params = dict(ppo_params)
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in training_params:
  del training_params["network_factory"]
  network_factory = functools.partial(
      ppo_networks.make_ppo_networks,
      **ppo_params.network_factory
  )

train_fn = functools.partial(
  ppo.train,
  **training_params,
  network_factory=network_factory,
  restore_checkpoint_path=restore_checkpoint_path,
  progress_fn=progress,
  wrap_env_fn=wrapper.wrap_for_brax_training,
  policy_params_fn=policy_params_fn,
  randomization_fn=randomizer,
)

# Separate environment instances for training and evaluation, built from the
# same config.
env = locomotion.load(env_name, config=env_cfg)
eval_env = locomotion.load(env_name, config=env_cfg)
make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)
# times[0] was taken before training; times[1] is the first `progress` call.
if len(times) > 1:
  print(f"time to jit: {times[1] - times[0]}")
  print(f"time to train: {times[-1] - times[1]}")

# Make a final plot of reward and success vs WALLCLOCK time.
# Skip it when no evaluation was reported, because y_data[-1] would raise.
if y_data:
  plt.figure()
  plt.xlabel("wallclock time (s)")
  plt.ylabel("reward per episode")
  plt.title(f"y={y_data[-1]:.3f}")
  plt.errorbar(
      # y_data[i] was recorded at times[i + 1] (times[0] is the pre-training
      # timestamp), so use times[1:] to pair each reward with its own time.
      [(t - times[0]).total_seconds() for t in times[1:]],
      y_data,
      yerr=y_dataerr,
      color="blue",
  )
  plt.show()

In [11]:
# Build the policy from the trained params with deterministic=True, then
# JIT-compile it for the rollout below.
inference_fn = make_inference_fn(params, deterministic=True)
jit_inference_fn = jax.jit(inference_fn)

In [12]:
# Save the normalizer, policy and value params to the checkpoint dir.
normalizer_params, policy_params, value_params = params
data = dict(
    normalizer_params=normalizer_params,
    policy_params=policy_params,
    value_params=value_params,
)
with open(ckpt_path / "params.pkl", "wb") as f:
  pickle.dump(data, f)

In [13]:
from mujoco_playground._src.gait import draw_joystick_command

# Enable perturbation in the eval env.
# NOTE(review): nothing here visibly enables perturbation. The code rebinds
# `env_cfg` to the default config, which drops the reward-scale overrides
# applied before training, and rebuilds `eval_env` from it. Presumably the
# default config has perturbation enabled -- confirm.
env_cfg = locomotion.get_default_config(env_name)
eval_env = locomotion.load(env_name, config=env_cfg)

# JIT-compiled reset/step of the rebuilt eval env, used by the rollout loop.
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

In [None]:
# Roll out the trained policy in the eval env and record per-step diagnostics
# for the plotting and rendering cells below.
rng = jax.random.PRNGKey(1)
# Recorded states and the matching per-frame scene-overlay callbacks.
rollout = []
modify_scene_fns = []

# Per-step diagnostic buffers.
swing_peak = []
# `rewards` and `rews` both hold the "reward/..." metrics; `rewards` strips
# the "reward/" prefix from the keys, `rews` keeps it.
rewards = []
linvel = []
angvel = []
track = []
# These two are only filled by the commented-out lines in the loop, so they
# stay empty.
hand_positions = []
pelvis_position = []
foot_vel = []
rews = []
qvels = []

# Fixed joystick command (x, y, yaw) written into state.info after every step.
x_cmd = 0.0
y_cmd = 0.0
yaw_cmd = 0
command = jp.array([x_cmd, y_cmd, yaw_cmd])

# power1: signed sum of actuator force * joint velocity.
# power2: sum of the absolute values of the same products.
power1 = []
power2 = []
qpos_cost = []

# One list of per-step rewards per episode.
rews_ep = []

# Gait phase overrides written into state.info after reset.
# NOTE(review): the 1.5 factor is presumably a gait frequency in Hz, and the
# [0, pi] phases presumably put the two feet half a cycle apart -- confirm.
phase_dt = 2 * jp.pi * eval_env.dt * 1.5
phase = jp.array([0, jp.pi])

for j in range(1):
  print(f"episode {j}")
  state = jit_reset(rng)
  # The command override at reset is disabled, so the first step runs with
  # whatever command reset produced; `command` is applied from the second
  # step on.
  # state.info["command"] = command
  state.info["phase_dt"] = phase_dt
  state.info["phase"] = phase
  ep_rews = []
  for i in range(env_cfg.episode_length):
    # Disabled experiment: flip the sign of the command every 200 steps.
    # if i % 200 == 0:
    #   x_cmd = -x_cmd
    #   y_cmd = -y_cmd
    #   yaw_cmd = -yaw_cmd
    #   command = jp.array([x_cmd, y_cmd, yaw_cmd])
    #   print(command)
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    ep_rews.append(state.reward)
    # Stop on termination. The reward of the terminal step is kept in ep_rews,
    # but the terminal state is not added to any other buffer.
    if state.done:
      print("something bad happened")
      break
    # Force the fixed command for the next step.
    state.info["command"] = command
    rews.append(
        {k: v for k, v in state.metrics.items() if k.startswith("reward/")}
    )
    rollout.append(state)
    swing_peak.append(state.info["swing_peak"])
    # k[7:] drops the 7-character "reward/" prefix.
    rewards.append(
        {k[7:]: v for k, v in state.metrics.items() if k.startswith("reward/")}
    )
    linvel.append(eval_env.get_local_linvel(state.data, "pelvis"))
    angvel.append(eval_env.get_gyro(state.data))
    # Recompute the linear-velocity tracking reward term via the env's private
    # helper.
    track.append(
        eval_env._reward_tracking_lin_vel(
            state.info["command"], eval_env.get_local_linvel(state.data, "pelvis")
        )
    )

    # hand_positions.append(state.data.site_xpos[eval_env._hands_site_id])
    # pelvis_position.append(state.data.site_xpos[eval_env._pelvis_imu_site_id])

    # Foot linear-velocity sensor readings, reduced to the first two
    # components (presumably x/y).
    # NOTE(review): this stores the square root of the norm, not the norm
    # itself -- confirm that is intended.
    feet_vel = state.data.sensordata[eval_env._foot_linvel_sensor_adr]
    vel_xy = feet_vel[..., :2]
    vel_norm = jp.sqrt(jp.linalg.norm(vel_xy, axis=-1))
    foot_vel.append(vel_norm)

    # The qvel[6:] / qpos[7:] slices presumably skip the floating-base DoFs
    # (6 velocity / 7 position entries) -- TODO confirm.
    frc = state.data.actuator_force
    qvel = state.data.qvel[6:]
    power1.append(jp.sum(frc * qvel))
    power2.append(jp.sum(jp.abs(frc * qvel)))

    qvels.append(qvel)
    # Squared deviation of the joint positions from the env's default pose.
    qpos_cost.append(jp.sum(jp.square(state.data.qpos[7:] - eval_env._default_pose)))

    # Anchor the joystick-command overlay at the torso body position. The
    # added offset is all zeros, so it currently has no effect.
    xyz = np.array(state.data.xpos[eval_env.mj_model.body("torso_link").id])
    xyz += np.array([0, 0.0, 0])
    # First row of the torso rotation matrix; the overlay angle is the negated
    # heading computed from its first two entries.
    x_axis = state.data.xmat[eval_env._torso_body_id, 0]
    yaw = -np.arctan2(x_axis[1], x_axis[0])
    # Bind this step's command, position and heading into a callback that is
    # passed to the renderer as one of `modify_scene_fns`.
    modify_scene_fns.append(
        functools.partial(
            draw_joystick_command,
            cmd=state.info["command"],
            xyz=xyz,
            theta=yaw,
            scl=np.linalg.norm(state.info["command"]),
        )
    )
  rews_ep.append(ep_rews)

In [None]:
# Smooth the signed power trace with a 10-sample moving average; 'valid' mode
# keeps only the samples where the window fully overlaps the data.
smoothing_window = 10
power = jp.convolve(
    jp.array(power1), jp.ones(smoothing_window) / smoothing_window, mode='valid'
)
mean_power = jp.mean(jp.array(power))

plt.plot(power)
# Reference lines: mean of the smoothed trace (red) and zero (dashed black).
plt.axhline(y=mean_power, color='r', linestyle='-', label='Mean')
plt.axhline(y=0, color='k', linestyle='--')
plt.legend()
plt.show()
# Summary statistics of the smoothed trace.
print(f"Min power: {jp.min(power)}, Max power: {jp.max(power)}", "Mean power: ", mean_power)

In [None]:
# Render every `render_every`-th recorded state; the playback fps is scaled to
# match.
render_every = 1
fps = 1.0 / eval_env.dt / render_every
print(f"fps: {fps}")
traj = rollout[::render_every]
mod_fns = modify_scene_fns[::render_every]

# Visualization options: show geom group 2, hide group 3, draw contact points,
# and keep transparency and perturbation-force rendering off.
scene_option = mujoco.MjvOption()
for geom_group, visible in ((2, True), (3, False)):
  scene_option.geomgroup[geom_group] = visible
for vis_flag, enabled in (
    (mujoco.mjtVisFlag.mjVIS_CONTACTPOINT, True),
    (mujoco.mjtVisFlag.mjVIS_TRANSPARENT, False),
    (mujoco.mjtVisFlag.mjVIS_PERTFORCE, False),
):
  scene_option.flags[vis_flag] = enabled

frames = eval_env.render(
    traj,
    camera="track",
    scene_option=scene_option,
    width=1280,
    height=480,
    modify_scene_fns=mod_fns,
)
media.show_video(frames, fps=fps, loop=False)

In [None]:
# Smooth both tracking-reward traces with the same 10-sample moving average.
smoothing_kernel = jp.ones(10) / 10
angvel_reward = jp.convolve(
    jp.array([r["tracking_ang_vel"] for r in rewards]),
    smoothing_kernel,
    mode="same",
)
linvel_reward = jp.convolve(
    jp.array([r["tracking_lin_vel"] for r in rewards]),
    smoothing_kernel,
    mode="same",
)

plt.plot(angvel_reward, label="angvel")
plt.plot(linvel_reward, label="linvel")
# Dashed reference lines at the configured scale of each tracking reward.
reward_scales = env_cfg.reward_config.scales
for reward_scale in (
    reward_scales.tracking_ang_vel,
    reward_scales.tracking_lin_vel,
):
  plt.axhline(reward_scale, color="k", linestyle="--")
plt.legend()

In [None]:
# Plot the swing peak of each foot side by side (1x2 grid), with the
# configured max foot height as a dashed reference line.
swing_peak = jp.array(swing_peak)
names = ["r_foot", "l_foot"]
colors = ["r", "g"]
max_foot_height = env_cfg.reward_config.max_foot_height
fig, axs = plt.subplots(1, 2)
for foot_idx, (ax, foot_name, foot_color) in enumerate(
    zip(axs.flat, names, colors)
):
  ax.plot(swing_peak[:, foot_idx], color=foot_color)
  ax.set_ylim([0, max_foot_height * 1.25])
  ax.axhline(max_foot_height, color="k", linestyle="--")
  ax.set_title(foot_name)
  ax.set_xlabel("time")
  ax.set_ylabel("height")
plt.tight_layout()
plt.show()

In [None]:
# Extract the x / y linear velocity and the yaw rate from the rollout, then
# smooth each with a 10-sample moving average.
smoothing_kernel = jp.ones(10) / 10
linvel_history = jp.array(linvel)
linvel_x = jp.convolve(linvel_history[:, 0], smoothing_kernel, mode="same")
linvel_y = jp.convolve(linvel_history[:, 1], smoothing_kernel, mode="same")
angvel_yaw = jp.convolve(jp.array(angvel)[:, 2], smoothing_kernel, mode="same")

# Plot whether velocity is within the command range: one panel per component.
fig, axes = plt.subplots(3, 1, figsize=(10, 10))
labels = ["dx", "dy", "dyaw"]
for i, (ax, velocity) in enumerate(zip(axes, (linvel_x, linvel_y, angvel_yaw))):
  ax.plot(velocity)
  # Command limits as dashed black lines.
  command_limit = env_cfg.command_config.a[i]
  ax.axhline(-command_limit, color="k", linestyle="--")
  ax.axhline(command_limit, color="k", linestyle="--")
  # Command held in the last rollout state as a dashed red line.
  ax.axhline(state.info["command"][i], color="r", linestyle="--")
  ax.set_ylabel(labels[i])