In [None]:
# Reload edited Python modules automatically before each cell executes, so
# changes to library code are picked up without restarting the kernel.
%load_ext autoreload

%autoreload 2

In [None]:
import os

# XLA reads XLA_FLAGS when JAX initializes its backend, so this must be set
# before JAX is used (JAX itself is imported in the next cell).
# Append the flag only when it is missing so re-running this cell is idempotent
# instead of duplicating the flag on every run.
xla_flags = os.environ.get("XLA_FLAGS", "")
if "--xla_gpu_triton_gemm_any=True" not in xla_flags.split():
  xla_flags = f"{xla_flags} --xla_gpu_triton_gemm_any=True".strip()
os.environ["XLA_FLAGS"] = xla_flags
# Select MuJoCo's EGL (headless) rendering backend for env.render further down.
os.environ["MUJOCO_GL"] = "egl"

In [None]:
import functools
import json
from datetime import datetime

import jax
import matplotlib.pyplot as plt
import mediapy as media
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
from IPython.display import clear_output, display
from orbax import checkpoint as ocp

from mujoco_playground import manipulation, wrapper
from mujoco_playground.config import manipulation_params

# Persist compiled XLA programs on disk so later runs (including fresh
# kernels) can reuse them: cache every entry regardless of size or how long
# it took to compile.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update(
    "jax_persistent_cache_min_entry_size_bytes", -1
)
jax.config.update(
    "jax_persistent_cache_min_compile_time_secs", 0
)

In [None]:
# Task to train. Its default config is used to build the environment and is
# also written next to the checkpoints.
env_name = "PandaOpenCabinet"
env_cfg = manipulation.get_default_config(env_name)

In [None]:
# Optional tag appended to the run name; None means no tag.
SUFFIX = None
# NOTE(review): nothing else in this notebook reads FINETUNE_PATH; confirm
# whether fine-tuning from a checkpoint is meant to be wired into train_fn.
FINETUNE_PATH = None

# Unique run name: "<env_name>/<YYYYmmdd-HHMMSS>", plus "-<SUFFIX>" if set.
now = datetime.now()
timestamp = now.strftime("%Y%m%d-%H%M%S")
exp_name = (
    f"{env_name}/{timestamp}"
    if SUFFIX is None
    else f"{env_name}/{timestamp}-{SUFFIX}"
)
print(f"Experiment name: {exp_name}")

In [None]:
# NOTE(review): this factory is not passed to train_fn below, which builds its
# own network_factory from the environment's PPO config.
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(32,) * 4,
)

In [None]:
# Per-run output directory: <cwd>/checkpoints/<exp_name>, created if missing.
ckpt_path = epath.Path("checkpoints").resolve().joinpath(exp_name)
ckpt_path.mkdir(exist_ok=True, parents=True)
print(f"Checkpoint path: {ckpt_path}")

# Snapshot the environment config next to the checkpoints.
with open(ckpt_path / "config.json", "w") as fp:
  json.dump(
      env_cfg.to_dict(),
      fp,
      indent=4,
  )


def policy_params_fn(current_step, make_policy, params):
  """Saves `params` as an Orbax PyTree checkpoint under ckpt_path/<current_step>.

  Args:
    current_step: Training step used as the checkpoint directory name.
    make_policy: Unused here; presumably present to match the training-loop
      callback signature (current_step, make_policy, params) — confirm
      against the brax version in use.
    params: PyTree of parameters to save; an existing checkpoint at the
      same step is overwritten (force=True).
  """
  step_dir = ckpt_path / f"{current_step}"
  ocp.PyTreeCheckpointer().save(
      step_dir,
      params,
      force=True,
      save_args=orbax_utils.save_args_from_target(params),
  )

# PPO hyperparameters that mujoco_playground ships for this environment.
_train_params = manipulation_params.brax_ppo_config(env_name)
train_params = dict(_train_params)
train_params["seed"] = 1  # Fixed seed so runs are reproducible.

# The config stores network settings as a nested entry, not a callable, so
# turn all of them (not just the policy layer sizes) into a network_factory.
# Guard the delete so configs without this entry don't raise KeyError.
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in train_params:
  del train_params["network_factory"]
  network_factory = functools.partial(
      ppo_networks.make_ppo_networks, **_train_params.network_factory
  )

train_fn = functools.partial(
    ppo.train,
    **train_params,
    network_factory=network_factory,
    # Without this, policy_params_fn is never called and nothing is saved to
    # ckpt_path beyond config.json.
    policy_params_fn=policy_params_fn,
)


In [None]:
# Learning-curve history, appended to by `progress`. times[0] marks the start
# so the first interval measures compilation/setup time.
x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]


def progress(num_steps, metrics):
  """Training callback: records eval reward and redraws the learning curve.

  Args:
    num_steps: Environment steps taken so far.
    metrics: Mapping that must contain "eval/episode_reward" and
      "eval/episode_reward_std".
  """
  # Replace the previous plot instead of stacking a new one below it.
  clear_output(wait=True)
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])

  # Throughput between the last two evaluations.
  if len(x_data) >= 2:
    num = x_data[-1] - x_data[-2]
    denom = (times[-1] - times[-2]).total_seconds()
    if denom > 0:
      fps = num / denom
      print(f"Training at {fps} FPS")

  # A fresh figure per call: drawing on pyplot's global figure would add
  # another full errorbar series on every callback.
  fig, ax = plt.subplots()
  ax.set_xlim([0, train_fn.keywords["num_timesteps"] * 1.25])
  ax.set_xlabel("# environment steps")
  ax.set_ylabel("reward per episode")
  ax.set_title(f"y={y_data[-1]:.3f}")
  ax.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")
  display(fig)
  # Close so figures don't accumulate and the inline backend doesn't render
  # the last one a second time when the cell finishes.
  plt.close(fig)


# Build the env and wrap it for brax training here, so ppo.train must not
# wrap it again (wrap_env=False).
env = manipulation.load(env_name, config=env_cfg)
brax_env = wrapper.wrap_for_brax_training(
    env,
    episode_length=env_cfg.episode_length,
    action_repeat=env_cfg.action_repeat,
)
make_inference_fn, params, _ = train_fn(
    environment=brax_env,
    wrap_env=False,
    progress_fn=progress,
)
# times[1] is the first progress callback; everything before it is setup/jit.
print(f"time to jit: {times[1] - times[0]}")
print(f"time to train: {times[-1] - times[1]}")

In [None]:
# Fresh env for evaluation (not wrapped for brax training); jit its reset/step
# and the trained policy in deterministic mode.
env = manipulation.load(env_name, config=env_cfg)

jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)
policy = make_inference_fn(params, deterministic=True)
jit_inference_fn = jax.jit(policy)
del policy

In [None]:
# Roll out the deterministic policy for a fixed number of control steps.
# Keep the initial state plus every `render_every`-th state for the video; the
# control rate is 1 / env.dt, which the playback fps accounts for.
NUM_ROLLOUT_STEPS = 125
render_every = 2

key = jax.random.PRNGKey(5)
key, key_reset = jax.random.split(key)
state = jit_reset(key_reset)
states = [state]

for i in range(NUM_ROLLOUT_STEPS):
  act_rng, key = jax.random.split(key)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  if i % render_every == 0:
    states.append(state)



In [None]:
# Only every `render_every`-th control step was kept, so divide the control
# rate (1 / env.dt) by it to play back in real time.
media.show_video(env.render(states, width=640, height=480),
                 fps=1.0 / env.dt / render_every)