In [1]:
import os

# Must run before JAX initializes its backend: XLA reads these env vars at startup.
# Append the Triton GEMM flag to any user-provided XLA flags, but only once, so
# re-running this cell does not keep growing XLA_FLAGS with duplicates.
_TRITON_GEMM_FLAG = "--xla_gpu_triton_gemm_any=True"
xla_flags = os.environ.get("XLA_FLAGS", "")
if _TRITON_GEMM_FLAG not in xla_flags.split():
  xla_flags = f"{xla_flags} {_TRITON_GEMM_FLAG}".strip()
os.environ["XLA_FLAGS"] = xla_flags
# Allocate GPU memory on demand instead of grabbing a large block up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Headless OpenGL backend for MuJoCo rendering.
os.environ["MUJOCO_GL"] = "egl"

In [2]:
import functools
import json
import pickle
from datetime import datetime

import jax
import mediapy as media
import mujoco
import numpy as np
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
from orbax import checkpoint as ocp

from mujoco_playground import registry, wrapper
from mujoco_playground.config import locomotion_params
from mujoco_playground.experimental.utils.plotting import TrainingPlotter

# Enable the persistent compilation cache so re-running the notebook can reuse
# compiled XLA executables from /tmp/jax_cache instead of recompiling.
# The size/time thresholds below are set to -1 / 0 so that (per the JAX docs)
# entries are cached regardless of their size or compile time.
for _cache_option, _cache_value in (
  ("jax_compilation_cache_dir", "/tmp/jax_cache"),
  ("jax_persistent_cache_min_entry_size_bytes", -1),
  ("jax_persistent_cache_min_compile_time_secs", 0),
):
  jax.config.update(_cache_option, _cache_value)

In [None]:
env_name = "ApolloJoystickFlatTerrain"
# Registry defaults for this env: its config and the domain-randomization fn
# (later passed to ppo.train as `randomization_fn`).
env_cfg, randomizer = (
  registry.get_default_config(env_name),
  registry.get_domain_randomizer(env_name),
)
# Default PPO hyper-parameters tuned for this locomotion env.
ppo_params = locomotion_params.brax_ppo_config(env_name)

In [5]:
# Reward-shaping overrides (negative scales act as penalties).
reward_scales = env_cfg.reward_config.scales
reward_scales.energy = -1e-5
reward_scales.action_rate = -1e-3
reward_scales.torques = 0.0

# Observation noise disabled for this run; set the level back to 1.0 to
# re-enable it (presumably a multiplier on the configured noise — confirm).
env_cfg.noise_config.level = 0.0

# Random pushes during training, magnitude sampled from this range.
push_cfg = env_cfg.push_config
push_cfg.enable = True
push_cfg.magnitude_range = [0.1, 2.0]

# Training budget and number of evaluations spread over it.
ppo_params.num_timesteps = 150_000_000
ppo_params.num_evals = 15

In [None]:
SUFFIX = None  # Optional tag appended to the experiment name.
FINETUNE_PATH = None  # Optional checkpoint directory to resume training from.

# Generate unique experiment name.
now = datetime.now()
timestamp = now.strftime("%Y%m%d-%H%M%S")
exp_name = f"{env_name}-{timestamp}"
if SUFFIX is not None:
  exp_name += f"-{SUFFIX}"
print(f"{exp_name}")

# Possibly restore from the latest checkpoint.
if FINETUNE_PATH is not None:
  FINETUNE_PATH = epath.Path(FINETUNE_PATH)
  # The checkpoint callback in this notebook writes one sub-directory per
  # training step, named by the step count. Ignore anything else (e.g. a stray
  # non-numeric directory) instead of crashing inside int().
  latest_ckpts = [
    ckpt
    for ckpt in FINETUNE_PATH.glob("*")
    if ckpt.is_dir() and ckpt.name.isdecimal()
  ]
  if not latest_ckpts:
    raise FileNotFoundError(
      f"No step-numbered checkpoint directories found in {FINETUNE_PATH}"
    )
  latest_ckpts.sort(key=lambda x: int(x.name))
  latest_ckpt = latest_ckpts[-1]
  restore_checkpoint_path = latest_ckpt
  print(f"Restoring from: {restore_checkpoint_path}")
else:
  restore_checkpoint_path = None

In [None]:
# Per-experiment output directory (absolute, since Orbax expects absolute paths).
ckpt_path = epath.Path("checkpoints").resolve() / exp_name
ckpt_path.mkdir(parents=True, exist_ok=True)
print(f"{ckpt_path}")

# ConfigDict.to_json() returns a JSON *string*. Passing it straight to
# json.dump would write one quoted string literal (and ignore indent=4).
# Parse it back first so config.json holds a real, pretty-printed JSON object.
with open(ckpt_path / "config.json", "w") as fp:
  json.dump(json.loads(env_cfg.to_json()), fp, indent=4)

In [None]:
plotter = TrainingPlotter(max_timesteps=ppo_params.num_timesteps, figsize=(15, 10))


def progress(num_steps, metrics):
  """Progress callback for ppo.train: feed each eval's metrics to the live plot."""
  plotter.update(num_steps, metrics)


def policy_params_fn(current_step, make_policy, params):
  """Checkpoint callback for ppo.train.

  Saves `params` with Orbax into `ckpt_path/<current_step>`, overwriting any
  existing checkpoint for that step. `make_policy` is part of the callback
  signature but is not needed here.
  """
  del make_policy  # Unused.
  checkpointer = ocp.PyTreeCheckpointer()
  checkpointer.save(
    ckpt_path / f"{current_step}",
    params,
    force=True,
    save_args=orbax_utils.save_args_from_target(params),
  )


# ppo_params holds both ppo.train kwargs and the network-factory kwargs; the
# latter are bound into make_ppo_networks instead of being passed directly.
training_params = dict(ppo_params)
training_params.pop("network_factory")

train_fn = functools.partial(
  ppo.train,
  **training_params,
  network_factory=functools.partial(
    ppo_networks.make_ppo_networks, **ppo_params.network_factory
  ),
  wrap_env_fn=wrapper.wrap_for_brax_training,
  randomization_fn=randomizer,
  restore_checkpoint_path=restore_checkpoint_path,
  # Callbacks: live plotting and per-eval checkpointing.
  progress_fn=progress,
  policy_params_fn=policy_params_fn,
)

# Separate instances for training and evaluation, built from the same config.
env, eval_env = (registry.load(env_name, config=env_cfg) for _ in range(2))
make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)

# The first progress callback arrives after compilation, so the first interval
# approximates JIT time and the rest is training time.
times = plotter.times
if len(times) > 1:
  print(f"time to jit: {times[1] - times[0]}")
  print(f"time to train: {times[-1] - times[1]}")

In [10]:
# Build the trained policy with deterministic=True (no action sampling —
# presumably returns the distribution's mode; confirm in brax) and JIT-compile
# it once so the per-step rollout loop below reuses the compiled function.
inference_fn = make_inference_fn(params, deterministic=True)
jit_inference_fn = jax.jit(inference_fn)

In [11]:
# Save normalizer and policy params to the checkpoint dir.
normalizer_params, policy_params, value_params = params
with open(ckpt_path / "params.pkl", "wb") as f:
  data = {
    "normalizer_params": normalizer_params,
    "policy_params": policy_params,
    "value_params": value_params,
  }
  pickle.dump(data, f)

In [12]:
from mujoco_playground._src.gait import draw_joystick_command

# Fresh env for visualization; reset/step are compiled once and reused below.
eval_env = registry.load(env_name, config=env_cfg)
jit_reset, jit_step = jax.jit(eval_env.reset), jax.jit(eval_env.step)

In [None]:
# Roll out the deterministic policy for one episode, recording every state and
# a scene-overlay callback that draws the joystick command at the torso.
rng = jax.random.PRNGKey(12345)
rollout = []
modify_scene_fns = []
# The body id is fixed for the model, so look it up once instead of every step.
# NOTE(review): position uses body "torso_link" while heading uses
# eval_env._torso_body_id; confirm both refer to the same body.
torso_link_id = eval_env.mj_model.body("torso_link").id
state = jit_reset(rng)
for _step in range(env_cfg.episode_length):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  if state.done:
    print("something bad happened")
    break
  rollout.append(state)
  xyz = np.array(state.data.xpos[torso_link_id])
  # Row 0 of the torso rotation matrix. For a pure yaw rotation by psi it is
  # (cos psi, -sin psi, 0), hence the minus sign to recover psi.
  x_axis = state.data.xmat[eval_env._torso_body_id, 0]
  yaw = -np.arctan2(x_axis[1], x_axis[0])
  modify_scene_fns.append(
    functools.partial(
      draw_joystick_command,
      cmd=state.info["command"],
      xyz=xyz,
      theta=yaw,
      scl=np.linalg.norm(state.info["command"]),
    )
  )

In [None]:
# Render every other state; fps is scaled so playback stays at real time.
render_every = 2
fps = 1.0 / eval_env.dt / render_every
print(f"fps: {fps}")
traj, mod_fns = rollout[::render_every], modify_scene_fns[::render_every]

scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True
scene_option.geomgroup[3] = False
# Show contact points only; hide contact/perturbation forces and transparency.
vis_flags = {
  mujoco.mjtVisFlag.mjVIS_CONTACTPOINT: True,
  mujoco.mjtVisFlag.mjVIS_CONTACTFORCE: False,
  mujoco.mjtVisFlag.mjVIS_TRANSPARENT: False,
  mujoco.mjtVisFlag.mjVIS_PERTFORCE: False,
}
for vis_flag, enabled in vis_flags.items():
  scene_option.flags[vis_flag] = enabled

frames = eval_env.render(
  traj,
  camera="track",
  width=640,
  height=480,
  scene_option=scene_option,
  modify_scene_fns=mod_fns,
)
media.show_video(frames, fps=fps, loop=False)