In [1]:
# Re-import edited project modules (e.g. mujoco_playground) automatically
# before each cell executes, so library edits are picked up without a restart.
%load_ext autoreload

%autoreload 2

In [2]:
import os

# Must run before JAX initializes its backend: XLA_FLAGS is read from the
# environment at that point. Only append the Triton GEMM flag if it is not
# already present, so re-running this cell does not keep growing XLA_FLAGS.
_TRITON_GEMM_FLAG = "--xla_gpu_triton_gemm_any=True"
xla_flags = os.environ.get("XLA_FLAGS", "")
if _TRITON_GEMM_FLAG not in xla_flags.split():
  xla_flags = f"{xla_flags} {_TRITON_GEMM_FLAG}".strip()
os.environ["XLA_FLAGS"] = xla_flags
# Headless OpenGL backend for MuJoCo rendering.
os.environ["MUJOCO_GL"] = "egl"

In [3]:
import functools
import re
from datetime import datetime
from typing import Tuple

import jax
import jax.numpy as jp
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from brax.training.agents.sac import train as sac
from IPython.display import clear_output, display

from mujoco_playground import dm_control_suite as suite
from mujoco_playground import wrapper

# Enable persistent compilation cache so re-running the notebook can reuse
# previously compiled XLA executables. The size/compile-time thresholds are
# set as permissively as the options allow so that every entry is eligible.
for _option, _value in (
    ("jax_compilation_cache_dir", "/tmp/jax_cache"),
    ("jax_persistent_cache_min_entry_size_bytes", -1),
    ("jax_persistent_cache_min_compile_time_secs", 0),
):
  jax.config.update(_option, _value)

In [4]:
# Camera used when rendering each environment (same name is passed to the MJX
# renderer and to dm_control's `physics.render(camera_id=...)`).
CAMERAS = dict(
    AcrobotSwingup="fixed",
    BallInCup="cam0",
    CartpoleSwingup="fixed",
    CartpoleBalance="fixed",
    CheetahRun="side",
    HumanoidRun="side",
    HumanoidStand="side",
    HumanoidWalk="side",
    PointMass="cam0",
    WalkerRun="side",
    WalkerWalk="side",
    WalkerStand="side",
    HopperHop="cam0",
    HopperStand="cam0",
    FishSwim="fixed_top",
    ReacherEasy="fixed",
    ReacherHard="fixed",
    Swimmer6="tracking1",
    BarkourJoystick="track",
    PendulumSwingup="fixed",
    FingerSpin="cam0",
    FingerTurnEasy="cam0",
    FingerTurnHard="cam0",
    DogStand="y-axis",
)

# Per-environment SAC discount factor; environments not listed here fall back
# to the default used in the training cell.
DISCOUNTS = dict(
    HopperHop=0.99,
    CartpoleBalance=0.99,
    CartpoleSwingup=0.99,
    HumanoidRun=0.95,
    AcrobotSwingup=0.99,
)

In [None]:
# Environment under test. Every later cell reads `env_name`, `env_cfg`, `env`.
env_name = "WalkerWalk"

env_cfg = suite.get_default_config(env_name)
env = suite.load(env_name, config=env_cfg)

print(env_cfg)

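## Environment sanity checks

Step the MJX environment with open-loop sinusoidal actions and compare its rollout and initial states visually against the reference dm_control suite.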

In [None]:
# Compile reset/step once; the sanity-check cells below reuse these handles.
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

In [None]:
# Smoke test: one reset followed by one zero-action step should trace and run.
key = jax.random.PRNGKey(0)
action = jp.zeros((env.action_size,))
state = jit_step(jit_reset(key), action)

In [None]:
# Open-loop rollout: drive every actuator with a sinusoid of frequency `f`
# (in cycles per unit of simulation time), with phases spread evenly across
# actuators, then render the trajectory.
key = jax.random.PRNGKey(12345)
key, key_reset = jax.random.split(key)
state = jit_reset(key_reset)
states = [state]
actions = []
f = 0.5
# Per-actuator phase offsets are loop-invariant, so compute them once.
phases = jp.arange(env.action_size) * 2 * jp.pi / env.action_size
for _ in range(env_cfg.episode_length):
  action = jp.sin(state.data.time * 2 * jp.pi * f + phases)
  actions.append(action)
  state = jit_step(state, action)
  states.append(state)
frames = env.render(states, camera=CAMERAS[env_name])
media.show_video(frames, fps=1.0 / env.dt)

In [None]:
from dm_control import suite as dmc_suite


# Playground env names whose dm_control (domain, task) pair cannot be derived
# by splitting the CamelCase name: multi-word domains, and tasks whose
# dm_control name differs from the playground suffix.
_DMC_NAME_OVERRIDES = {
    "BallInCup": ("ball_in_cup", "catch"),
    "PointMass": ("point_mass", "easy"),
    "Swimmer6": ("swimmer", "swimmer6"),
}


def env_name_to_domain_and_task_name(env_name: str) -> Tuple[str, str]:
  """Map a playground env name to a dm_control ``(domain, task)`` pair.

  The first CamelCase word is the domain. The remaining words, joined with
  underscores, are the task. For example, ``"FingerTurnEasy"`` maps to
  ``("finger", "turn_easy")``. Names listed in ``_DMC_NAME_OVERRIDES`` are
  looked up directly.

  Args:
    env_name: Playground environment name, e.g. ``"WalkerWalk"``.

  Returns:
    Lower-cased ``(domain_name, task_name)``. ``task_name`` is ``""`` when the
    name has a single CamelCase word and no override.
  """
  if env_name in _DMC_NAME_OVERRIDES:
    return _DMC_NAME_OVERRIDES[env_name]
  # Zero-width split before every uppercase letter except the first character.
  domain_name, *task_words = re.split("(?<=.)(?=[A-Z])", env_name)
  return domain_name.lower(), "_".join(task_words).lower()


# Reference rollout in the original dm_control suite, using the same
# open-loop sinusoidal controller as the MJX rollout for visual comparison.
domain_name, task_name = env_name_to_domain_and_task_name(env_name)
dmc_env = dmc_suite.load(domain_name, task_name, task_kwargs={"random": 0})

action_spec = dmc_env.action_spec()
num_actuators = action_spec.shape[0]
# Loop-invariant per-actuator phase offsets (same formula as the MJX cell).
phases = np.arange(num_actuators) * 2 * np.pi / num_actuators

timestep = dmc_env.reset()
# Include the post-reset frame, mirroring `states = [state]` in the MJX
# rollout, so both videos start from the initial state.
frames = [dmc_env.physics.render(camera_id=CAMERAS[env_name])]
f = 0.5
while not timestep.last():
  action = np.sin(dmc_env.physics.time() * 2 * np.pi * f + phases)
  timestep = dmc_env.step(action)
  frames.append(dmc_env.physics.render(camera_id=CAMERAS[env_name]))
media.show_video(frames, fps=1.0 / dmc_env.control_timestep())

In [None]:
# Render five initial states, each from its own reset key, side by side to
# eyeball the reset distribution.
key = jax.random.PRNGKey(0)
frames = []
while len(frames) < 5:
  key, key_reset = jax.random.split(key)
  state = jit_reset(key_reset)
  frames.append(env.render(state, camera=CAMERAS[env_name]))
media.show_image(np.hstack(frames))

In [None]:
# dm_control counterpart of the MJX reset gallery. Reset before each render so
# the panels show five different initial states. Without the reset they would
# be five identical copies of the state the previous rollout ended in.
frames = []
for _ in range(5):
  dmc_env.reset()
  frames.append(dmc_env.physics.render(camera_id=CAMERAS[env_name]))
media.show_image(np.hstack(frames))

## Train

In [6]:
# SAC hyperparameters forwarded to `sac.train`. `discounting` falls back to
# 0.95 for environments without an entry in DISCOUNTS.
kwargs = dict(
    num_timesteps=4_000_000,
    num_evals=10,
    reward_scaling=1.0,
    episode_length=env_cfg.episode_length,
    normalize_observations=True,
    action_repeat=1,
    discounting=DISCOUNTS.get(env_name, 0.95),
    learning_rate=1e-3,
    num_envs=128,
    batch_size=512,
    grad_updates_per_step=8,
    max_replay_size=4 * 1048576,
    min_replay_size=8192,
    seed=0,
    max_devices_per_host=1,
)

# Evaluation history filled in by `progress`. times[0] is the start time.
x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]


def progress(num_steps, metrics):
  """Record one evaluation point and redraw the learning curve.

  Passed to the trainer as ``progress_fn``. Appends to the module-level
  ``times``, ``x_data``, ``y_data`` and ``y_dataerr`` lists. It then replaces
  the cell output with an updated reward-vs-steps plot.

  Args:
    num_steps: Environment steps taken so far.
    metrics: Mapping containing at least ``"eval/episode_reward"`` and
      ``"eval/episode_reward_std"``.
  """
  clear_output(wait=True)

  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])

  # Draw on a fresh figure and close it afterwards. Reusing the implicit
  # pyplot figure piles another error-bar artist onto it on every call. An
  # unclosed figure is also shown a second time by the inline backend when
  # the training cell finishes.
  fig, ax = plt.subplots()
  ax.set_xlim([0, kwargs["num_timesteps"] * 1.25])
  ax.set_ylim([0, 1100])
  ax.set_xlabel("# environment steps")
  ax.set_ylabel("reward per episode")
  ax.set_title(f"y={y_data[-1]:.3f}")
  ax.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")

  display(fig)
  plt.close(fig)


# Bind all hyperparameters, the progress callback and the playground->brax
# env wrapper, so training only needs the environment.
train_fn = functools.partial(
    sac.train,
    progress_fn=progress,
    wrap_env_fn=wrapper.wrap_for_brax_training,
    **kwargs,
)

In [None]:
make_inference_fn, params, metrics = train_fn(environment=env)

# times[0] is recorded before training and later entries by each `progress`
# call. The gap to the first callback is reported as "jit" time (presumably
# dominated by compilation).
jit_duration = times[1] - times[0]
train_duration = times[-1] - times[1]
print(f"time to jit: {jit_duration}")
print(f"time to train: {train_duration}")

In [None]:
# Dump the final metrics returned by training, one "name: value" per line.
for metric_name, metric_value in metrics.items():
  print(f"{metric_name}: {metric_value}")

In [9]:
# Env handles are redefined here, presumably so evaluation also works when the
# sanity-check cells above were skipped. The policy is the deterministic
# variant.
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)
jit_inference_fn = jax.jit(
    make_inference_fn(params, deterministic=True)
)

In [10]:
# Roll out the trained policy for `n_episodes` episodes, storing every state
# (including each post-reset state) in `rollout` for rendering.
rng = jax.random.PRNGKey(2)
rollout = []
n_episodes = 1

for _ in range(n_episodes):
  # Split off a dedicated reset key. JAX keys must not be reused: passing
  # `rng` to reset and then splitting the same `rng` for actions can hand
  # identical subkeys to both consumers.
  rng, reset_rng = jax.random.split(rng)
  state = jit_reset(reset_rng)
  rollout.append(state)
  for _ in range(env_cfg.episode_length):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)

In [None]:
# Render every `render_every`-th state, show it, and save it as <env_name>.mp4.
# The playback fps is divided by the same factor so the video plays in real time.
render_every = 2
frames = env.render(rollout[::render_every], camera=CAMERAS[env_name])
rewards = [s.reward for s in rollout]
media.show_video(frames, fps=1.0 / env.dt / render_every)
media.write_video(f"{env_name}.mp4", frames, fps=1.0 / env.dt / render_every)

# Moving-average reward. Clamp the window to the rollout length: with
# mode="valid", np.convolve swaps its arguments when the signal is shorter
# than the kernel. It then returns a flat line (sum(rewards) / window)
# instead of a moving average.
smoothing_window = min(100, len(rewards))
smoothed_rewards = np.convolve(
    rewards, np.ones(smoothing_window) / smoothing_window, mode="valid"
)
fig, ax = plt.subplots()
ax.plot(smoothed_rewards)
ax.set_xlabel("time step")
ax.set_ylabel("reward")
ax.set_title(f"{smoothing_window}-step moving average of reward")
plt.show()