In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
import os

# Append the Triton GEMM flag to whatever XLA flags are already set, and select
# EGL as MuJoCo's GL backend. These are set here, before the cell that imports
# jax and mujoco_playground.
xla_flags = f'{os.environ.get("XLA_FLAGS", "")} --xla_gpu_triton_gemm_any=True'
os.environ.update(XLA_FLAGS=xla_flags, MUJOCO_GL="egl")

In [12]:
import functools
import json
from datetime import datetime

import jax
import matplotlib.pyplot as plt
import mediapy as media
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from IPython.display import clear_output, display
import numpy as np

from mujoco_playground import BraxEnvWrapper, manipulation

# Turn on JAX's persistent compilation cache under /tmp/jax_cache.
# NOTE(review): the two thresholds (-1 bytes, 0 seconds) are presumably meant to
# cache every compiled program regardless of size or compile time; confirm
# against the JAX docs for the installed version.
for _cache_option, _cache_value in (
    ("jax_compilation_cache_dir", "/tmp/jax_cache"),
    ("jax_persistent_cache_min_entry_size_bytes", -1),
    ("jax_persistent_cache_min_compile_time_secs", 0),
):
  jax.config.update(_cache_option, _cache_value)

In [9]:
# Select the manipulation task by name, fetch its default config, and build the
# environment from that config. `env_name`, `env_cfg` and `env` are all read by
# later cells (PPO config lookup, training, rollout and rendering).
env_name = "AlohaSinglePegInsertion"
env_cfg = manipulation.get_default_config(env_name)
env = manipulation.load(env_name, config=env_cfg)

In [5]:
from mujoco_playground.learning import manipulation_params

# Look up the Brax PPO config for the selected environment. Later cells read
# `num_timesteps` and `network_factory` from it and forward the remaining
# entries to `ppo.train` as keyword arguments.
ppo_params = manipulation_params.brax_ppo_config(env_name)

In [7]:
# Training-curve history; `progress` appends one entry to each list per call.
x_data = []
y_data = []
y_dataerr = []
# The first timestamp is taken when this cell runs; `progress` adds one per call.
times = [datetime.now()]


def progress(num_steps, metrics):
  """Record one evaluation result and redraw the reward curve in place.

  Args:
    num_steps: Step count reported with this evaluation; used as the x value.
    metrics: Mapping that must contain "eval/episode_reward" and
      "eval/episode_reward_std".

  Side effects:
    Appends to the module-level `times`, `x_data`, `y_data` and `y_dataerr`
    lists, clears the cell output, and displays the current pyplot figure.
  """
  # wait=True defers clearing until the new figure is ready to be displayed.
  clear_output(wait=True)

  times.append(datetime.now())
  x_data.append(num_steps)
  for history, metric_key in (
      (y_data, "eval/episode_reward"),
      (y_dataerr, "eval/episode_reward_std"),
  ):
    history.append(metrics[metric_key])

  # Leave 25% headroom on the x axis beyond the configured number of timesteps.
  x_limit = ppo_params["num_timesteps"] * 1.25
  plt.xlim([0, x_limit])
  plt.ylim([0, 15_000])
  plt.xlabel("# environment steps")
  plt.ylabel("reward per episode")
  latest_reward = y_data[-1]
  plt.title(f"y={latest_reward:.3f}")
  plt.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")

  display(plt.gcf())


# Copy the PPO config into a plain dict and drop its `network_factory` entry so
# that entry is not forwarded to `ppo.train` as-is. Everything else becomes a
# keyword argument, with `progress` as the progress callback.
training_params = dict(ppo_params)
del training_params["network_factory"]
train_fn = functools.partial(
    ppo.train,
    progress_fn=progress,
    **training_params,
)

In [None]:
# Build the PPO network factory with the policy hidden layer sizes from the
# config.
network_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=ppo_params.network_factory.policy_hidden_layer_sizes,
)
# `network_factory` was removed from the kwargs bound into `train_fn`, so it must
# be passed explicitly here; otherwise the factory above is never used and
# training runs with ppo.train's own default for that argument.
make_inference_fn, params, metrics = train_fn(
    environment=BraxEnvWrapper(env),
    network_factory=network_factory,
)
# times[0] was recorded when the progress cell ran and times[1] on the first
# progress callback, so the first interval is reported as jit time and the
# rest as training time.
print(f"time to jit: {times[1] - times[0]}")
print(f"time to train: {times[-1] - times[1]}")

In [11]:
# Jit the environment's reset and step functions, and the deterministic
# inference function built from the trained parameters.
jit_reset, jit_step = (jax.jit(env_fn) for env_fn in (env.reset, env.step))
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))

In [None]:
# Roll out the trained policy. This loop must run before rendering: it defines
# `rollout`, which the render call and the reward plot below both read.
rng = jax.random.PRNGKey(42)
rollout = []
n_episodes = 1

for _ in range(n_episodes):
  state = jit_reset(rng)
  rollout.append(state)
  for i in range(env_cfg.episode_length):
    # Split off a fresh key for each action so the carried key is never reused.
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)

# Render every `render_every`-th state; the video fps is scaled by the same
# factor so playback speed still follows env.dt.
render_every = 1
frames = env.render(
    rollout[::render_every],
    camera="teleoperator_pov",
    height=480 * 2,
    width=640 * 2,
)
rewards = [s.reward for s in rollout]
media.show_video(frames, fps=1.0 / env.dt / render_every)

# Plot the per-step reward smoothed with a 100-step moving average.
plt.plot(np.convolve(rewards, np.ones(100) / 100, mode="valid"))
plt.xlabel("time step")
plt.ylabel("reward")
plt.show()