In [None]:
# Automatically reload imported modules before executing code, so edits to
# library source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Runtime environment for XLA / MuJoCo. These variables are set before jax is
# imported further down; presumably they only take effect if that ordering holds.
_TRITON_GEMM_FLAG = "--xla_gpu_triton_gemm_any=True"
xla_flags = os.environ.get("XLA_FLAGS", "")
# Append the flag only when it is missing, so re-running this cell does not
# keep growing XLA_FLAGS with duplicate copies of the same flag.
if _TRITON_GEMM_FLAG not in xla_flags.split():
  xla_flags = f"{xla_flags} {_TRITON_GEMM_FLAG}".strip()
os.environ["XLA_FLAGS"] = xla_flags
# Do not let the JAX client grab most of the GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Headless (EGL) OpenGL backend for MuJoCo rendering.
os.environ["MUJOCO_GL"] = "egl"

In [None]:
import functools
import json
import pickle
from datetime import datetime

import jax
import jax.numpy as jp
import matplotlib.pyplot as plt
import mediapy as media
import mujoco
import wandb
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
from IPython.display import clear_output, display
from orbax import checkpoint as ocp
from scipy.spatial.transform import Rotation
from tqdm import tqdm

from mujoco_playground import locomotion, wrapper
from mujoco_playground.config import locomotion_params


# Enable JAX's persistent compilation cache on local disk. A minimum entry
# size of -1 and a minimum compile time of 0 s remove both thresholds, so
# every compiled program is eligible to be cached.
for _option, _value in (
    ("jax_compilation_cache_dir", "/tmp/jax_cache"),
    ("jax_persistent_cache_min_entry_size_bytes", -1),
    ("jax_persistent_cache_min_compile_time_secs", 0),
):
  jax.config.update(_option, _value)

In [None]:
# Task to train, plus its default environment config, domain randomizer and
# PPO hyper-parameters from mujoco_playground.
env_name = "Go1Footstand"
ppo_params = locomotion_params.brax_ppo_config(env_name)
env_cfg = locomotion.get_default_config(env_name)
randomizer = locomotion.get_domain_randomizer(env_name)

In [None]:
# Optional reward-scale overrides: uncomment any line to set that reward
# term's scale to 0.0 before training.
# env_cfg.reward_config.scales.dof_acc = 0.0
# env_cfg.reward_config.scales.energy = 0.0
# env_cfg.reward_config.scales.stay_still = 0.0

In [None]:
from pprint import pprint

# Show the PPO hyper-parameters that training will use.
pprint(object=ppo_params)

In [None]:
# Setup wandb logging (disabled by default).
USE_WANDB = False

if USE_WANDB:
  # env_cfg is a ConfigDict, not a plain dict. wandb falls back to vars() for
  # non-dict configs, which would record the ConfigDict's internal fields
  # instead of the config values. Log the same plain nested dict that is
  # written to config.json.
  wandb.init(project="mjxrl", config=env_cfg.to_dict())
  wandb.config.update({
      "env_name": env_name,
  })

In [None]:
SUFFIX = None  # Optional tag appended to the experiment name.
FINETUNE_PATH = None  # Optional directory of step-numbered checkpoints.

# Generate unique experiment name.
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{env_name}-{timestamp}"
if SUFFIX is not None:
  exp_name += f"-{SUFFIX}"
print(f"Experiment name: {exp_name}")

# Possibly restore from the latest checkpoint. Checkpoint sub-directories are
# named by the integer training step. Anything else in the directory (files,
# non-numeric dirs) is ignored instead of crashing int(). An empty directory
# raises a clear error instead of an IndexError.
restore_checkpoint_path = None
if FINETUNE_PATH is not None:
  finetune_dir = epath.Path(FINETUNE_PATH)
  step_dirs = [
      ckpt
      for ckpt in finetune_dir.glob("*")
      if ckpt.is_dir() and ckpt.name.isdecimal()
  ]
  if not step_dirs:
    raise FileNotFoundError(
        f"No step-numbered checkpoint directories found in {finetune_dir}"
    )
  restore_checkpoint_path = max(step_dirs, key=lambda ckpt: int(ckpt.name))
  print(f"Restoring from: {restore_checkpoint_path}")

In [None]:
# Per-experiment checkpoint directory; the environment config is saved next
# to the checkpoints so a run can be reproduced later.
ckpt_path = epath.Path("checkpoints").resolve() / exp_name
ckpt_path.mkdir(parents=True, exist_ok=True)
print(ckpt_path)

config_file = ckpt_path / "config.json"
config_file.write_text(json.dumps(env_cfg.to_dict(), indent=4))

In [None]:
# Learning-curve buffers appended to by `progress`; times[0] is the start time.
x_data = []
y_data = []
y_dataerr = []
times = [datetime.now()]


def progress(num_steps, metrics):
  """Training callback: logs metrics and redraws the eval-reward curve.

  Passed to `ppo.train` as `progress_fn`. Records the wall-clock time in
  `times` and the eval reward mean/std in `y_data`/`y_dataerr`, then replaces
  the cell output with an updated learning curve.

  Args:
    num_steps: Environment steps taken so far (x coordinate of the point).
    metrics: Metrics mapping; must contain "eval/episode_reward" and
      "eval/episode_reward_std".
  """
  # Log to wandb.
  if USE_WANDB:
    wandb.log(metrics, step=num_steps)

  clear_output(wait=True)
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])

  # Draw on a fresh figure and close it once displayed. Reusing plt.gcf()
  # stacked one more full errorbar artist onto the same axes on every call,
  # and it left the figure open, so pyplot could display it again later.
  fig, ax = plt.subplots()
  ax.set_xlim([0, ppo_params.num_timesteps * 1.25])
  ax.set_ylim([0, 40])
  ax.set_xlabel("# environment steps")
  ax.set_ylabel("reward per episode")
  ax.set_title(f"y={y_data[-1]:.3f}")
  ax.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")
  display(fig)
  plt.close(fig)


def policy_params_fn(current_step, make_policy, params):
  """Saves `params` as an Orbax PyTree checkpoint in `ckpt_path/<current_step>`.

  Passed to `ppo.train` as `policy_params_fn`. `make_policy` is not needed
  here. `force=True` lets Orbax overwrite an existing checkpoint at that path.
  """
  del make_policy  # Unused.
  step_dir = ckpt_path / f"{current_step}"
  ocp.PyTreeCheckpointer().save(
      step_dir,
      params,
      force=True,
      save_args=orbax_utils.save_args_from_target(params),
  )

# Forward every PPO hyper-parameter except the network config, which is bound
# into the network factory below instead.
training_params = dict(ppo_params)
training_params.pop("network_factory")

train_fn = functools.partial(
    ppo.train,
    **training_params,
    network_factory=functools.partial(
        ppo_networks.make_ppo_networks, **ppo_params.network_factory
    ),
    restore_checkpoint_path=restore_checkpoint_path,
    progress_fn=progress,
    wrap_env_fn=wrapper.wrap_for_brax_training,
    policy_params_fn=policy_params_fn,
    randomization_fn=randomizer,
)

# Separate, identically configured env instances for training and evaluation.
env = locomotion.load(env_name, config=env_cfg)
eval_env = locomotion.load(env_name, config=env_cfg)
make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)

# times[0] is taken before training; times[1] by the first `progress` call.
if times[1:]:
  print(f"time to jit: {times[1] - times[0]}")
  print(f"time to train: {times[-1] - times[1]}")

# Final plot: eval reward vs. WALLCLOCK time since training started.
# times[0] is taken before training, and times[k + 1] is appended by the same
# `progress` call that appends y_data[k]. So the x coordinates come from
# times[1:]; times[:-1] would shift every point one evaluation earlier.
if y_data:
  fig, ax = plt.subplots()
  ax.set_ylim([0, 40])
  ax.set_xlabel("wallclock time (s)")
  ax.set_ylabel("reward per episode")
  ax.set_title(f"y={y_data[-1]:.3f}")
  ax.errorbar(
      [(t - times[0]).total_seconds() for t in times[1:]],
      y_data,
      yerr=y_dataerr,
      color="blue",
  )
  plt.show()
else:
  print("No evaluation results were reported; skipping wallclock plot.")

In [None]:
# Deterministic (no action noise) policy, JIT-compiled for fast rollouts.
inference_fn = make_inference_fn(params, deterministic=True)
jit_inference_fn = jax.jit(inference_fn)

# Fresh evaluation environment for the rollout below. Edit env_cfg before this
# point to evaluate under a different setting, e.g.:
# env_cfg.init_from_crouch = 0.0
eval_env = locomotion.load(env_name, config=env_cfg)
jit_reset, jit_step = jax.jit(eval_env.reset), jax.jit(eval_env.step)

In [None]:
# Save normalizer, policy and value params to the checkpoint dir.
# NOTE: pickle can execute code on load; only unpickle this file if trusted.
normalizer_params, policy_params, value_params = params
params_blob = dict(
    normalizer_params=normalizer_params,
    policy_params=policy_params,
    value_params=value_params,
)
with open(ckpt_path / "params.pkl", "wb") as f:
  pickle.dump(params_blob, f)

In [None]:
# Roll out the trained policy for one episode in `eval_env` and record
# per-step diagnostics for the plots below.
rng = jax.random.PRNGKey(123)
rollout = []
rewards = []
height = []
linvel = []
angvel = []
accel = []
z_rot = []
power1 = []  # Net actuator power per step: sum(actuator_force * qvel[6:]).
power2 = []  # Absolute power per step: sum(|actuator_force * qvel[6:]|).
gravity = []

# Query the environment that is actually being stepped (jit_step wraps
# eval_env.step), not the training `env`. The diagnostics then stay correct
# if eval_env is built from a different config.
imu_site_id = eval_env._imu_site_id

for _ in tqdm(range(1)):
  rng, reset_rng = jax.random.split(rng)
  state = jit_reset(reset_rng)
  for _step in range(env_cfg.episode_length):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)

    sim = state.data
    height.append(sim.site_xpos[imu_site_id][2])
    linvel.append(eval_env.get_local_linvel(sim))
    accel.append(sim.qacc[:3])
    angvel.append(eval_env.get_global_angvel(sim))
    yaw = Rotation.from_matrix(sim.site_xmat[imu_site_id]).as_euler("xyz")[2]
    z_rot.append(yaw)

    # Assumes qvel[6:] (velocities after the floating-base DoFs) lines up
    # one-to-one with the actuators -- TODO confirm for this model.
    joint_power = sim.actuator_force * sim.qvel[6:]
    power1.append(jp.sum(joint_power))
    power2.append(jp.sum(jp.abs(joint_power)))

    gravity.append(eval_env.get_gravity(sim))

    rewards.append({
        k.removeprefix("reward/"): v
        for k, v in state.metrics.items()
        if k.startswith("reward/")
    })
    if state.done:
      print("oops")
      break

In [None]:
# Net actuator power over the rollout, smoothed with a moving average.
SMOOTHING_WINDOW = 10
power = jp.array(power1)
# Shrink the window for short rollouts (e.g. early termination). A 'valid'
# convolution with a kernel longer than the signal is not a moving average
# of that signal.
window = max(1, min(SMOOTHING_WINDOW, power.shape[0]))
power = jp.convolve(power, jp.ones(window) / window, mode="valid")

fig, ax = plt.subplots()
ax.plot(power, label="Power (smoothed)")
# Horizontal lines at the mean of the smoothed power and at 0.0.
ax.axhline(y=jp.mean(power), color="r", linestyle="-", label="Mean")
ax.axhline(y=0, color="k", linestyle="--")
ax.set_xlabel("step")
ax.set_ylabel("net power: sum(actuator_force * qvel)")
ax.set_title(f"Net actuator power ({window}-step moving average)")
ax.legend()
plt.show()
# Print min and max.
print(f"Min power: {jp.min(power)}, Max power: {jp.max(power)}")

In [None]:
# Per-step height and orientation reward terms from the rollout.
for term in ("height", "orientation"):
  plt.plot([step_rewards[term] for step_rewards in rewards], label=f"r_{term}")
plt.legend()

In [None]:
# Render every `render_every`-th recorded state at a frame rate that matches
# simulated time (one frame per render_every * dt seconds).
render_every = 1
traj = rollout[::render_every]
fps = 1.0 / eval_env.dt / render_every

scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True
scene_option.geomgroup[3] = False
for vis_flag, enabled in (
    (mujoco.mjtVisFlag.mjVIS_CONTACTPOINT, True),
    (mujoco.mjtVisFlag.mjVIS_CONTACTFORCE, False),
    (mujoco.mjtVisFlag.mjVIS_TRANSPARENT, False),
):
  scene_option.flags[vis_flag] = enabled

frames = eval_env.render(
    traj,
    camera="track",
    scene_option=scene_option,
    height=480,
    width=640,
)
media.show_video(frames, fps=fps, loop=False)
# media.write_video(f"{env_name}.mp4", frames, fps=fps, qp=18)