# Training Apptronik Apollo using MuJoCo Warp

In [None]:
# you may need some extra deps for this colab:
# pip install "jax[cuda12_local]"
# pip install playground
# pip install matplotlib
# if you want to run this on your local machine you can do like so:
# pip install jupyter
# jupyter notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8888 --NotebookApp.port_retries=0 --no-browser

import dataclasses
import datetime
import functools
import os
import time
from typing import Any, Dict, Optional, Union

import jax
import mediapy as media
import mujoco
import numpy as np
import warp as wp
from etils import epath
from jax import numpy as jp
from ml_collections import config_dict
from mujoco import mjx
from mujoco_playground._src import mjx_env
from mujoco_playground._src import reward
from mujoco_playground._src.dm_control_suite import common
from warp.jax_experimental.ffi import jax_callable

import mujoco_warp as mjwarp

# this ensures JAX embeds Warp kernels into its own computation graph.
# Any XLA flags already present in the environment are preserved; only a
# previous value of the graph-size flag itself is replaced, so the result is
# identical to a plain assignment when XLA_FLAGS was unset.
_XLA_GRAPH_FLAG = "--xla_gpu_graph_min_graph_size"
_xla_flags = [f for f in os.environ.get("XLA_FLAGS", "").split() if not f.startswith(_XLA_GRAPH_FLAG)]
os.environ["XLA_FLAGS"] = " ".join([*_xla_flags, f"{_XLA_GRAPH_FLAG}=1"])

In [None]:
# We'll grab the Apptronik model from MuJoCo Menagerie, then remove some
# MJX-specific changes that exist in the XML that MJWarp doesn't need
# (such as explicit contacts, really tight ls_iterations etc.)

# Make sure the MuJoCo Menagerie assets are available locally before we copy
# files into them (see mjx_env.ensure_menagerie_exists).
mjx_env.ensure_menagerie_exists()

# Source directory: the contrib/xml folder that sits next to the installed
# mujoco_warp package.
contrib_xml_dir = epath.resource_path("mujoco_warp").parent / "contrib/xml"
# Destination directory: the Apptronik Apollo model inside the Menagerie copy
# that MuJoCo Playground keeps under EXTERNAL_DEPS_PATH.
apptronik_dir = mjx_env.EXTERNAL_DEPS_PATH / "mujoco_menagerie/apptronik_apollo/"

# IPython shell escapes (notebook only): copy the MJWarp variants of the robot
# and scene XML into the Menagerie directory.
! cp {contrib_xml_dir / 'apptronik_apollo.xml'} {apptronik_dir}
! cp {contrib_xml_dir / 'scene.xml'} {apptronik_dir}

In [None]:
# MJWarp is not yet fully connected to JAX.  For now we will use the function
# `mjwarp_step` below, which we wrap in Warp's handy `jax_callable` function
# that converts Warp kernels to JAX operations.
#
# After we build a proper JAX wrapper for MJWarp, this code will disappear.

# Number of worlds simulated in parallel.  It is passed to put_data as
# `nworld`, fixes the leading dimension of every array returned by the step
# callable below, and is reused as the PPO num_envs / num_eval_envs.
NWORLD = 8192
# Passed to put_data as `nconmax` / `njmax`.
# NOTE(review): presumably capacity limits for contacts and constraint rows -
# confirm the exact semantics against the mujoco_warp.put_data documentation.
NCONMAX = 81920
NJMAX = NCONMAX * 4

xml_path = apptronik_dir / "scene.xml"
mjm = mujoco.MjModel.from_xml_path(xml_path.as_posix())
# Override the solver iteration and line-search iteration counts that were
# loaded from the XML.
mjm.opt.iterations = 5
mjm.opt.ls_iterations = 10
mjd = mujoco.MjData(mjm)
# Start from keyframe 0 and run forward once so that mjd holds a consistent
# state (mjd.qacc_warmstart is read later by init()).
mujoco.mj_resetDataKeyframe(mjm, mjd, 0)
mujoco.mj_forward(mjm, mjd)
# MJWarp counterparts of the MuJoCo model and data; `d` is the single shared
# buffer that mjwarp_step reads from and writes to.
m = mjwarp.put_model(mjm)
d = mjwarp.put_data(mjm, mjd, nworld=NWORLD, nconmax=NCONMAX, njmax=NJMAX)


def mjwarp_step(
  ctrl: wp.array(dtype=wp.float32, ndim=2),
  qpos_in: wp.array(dtype=wp.float32, ndim=2),
  qvel_in: wp.array(dtype=wp.float32, ndim=2),
  qacc_warmstart_in: wp.array(dtype=wp.float32, ndim=2),
  qpos_out: wp.array(dtype=wp.float32, ndim=2),
  qvel_out: wp.array(dtype=wp.float32, ndim=2),
  xpos_out: wp.array(dtype=wp.vec3, ndim=2),
  xmat_out: wp.array(dtype=wp.mat33, ndim=2),
  qacc_warmstart_out: wp.array(dtype=wp.float32, ndim=2),
  subtree_com_out: wp.array(dtype=wp.vec3, ndim=2),
  cvel_out: wp.array(dtype=wp.spatial_vector, ndim=2),
  site_xpos_out: wp.array(dtype=wp.vec3, ndim=2),
):
  """Advance every MJWarp world by one control step.

  The four input arrays are copied into the module-level MJWarp data `d`,
  `mjwarp.step` is run four times on (`m`, `d`), and the resulting state is
  copied into the eight `*_out` arrays.  The function returns nothing; all
  results are delivered through the output arrays.
  """
  # Load the caller's state into the shared MJWarp buffers.
  inputs = (
    (d.ctrl, ctrl),
    (d.qpos, qpos_in),
    (d.qvel, qvel_in),
    (d.qacc_warmstart, qacc_warmstart_in),
  )
  for dst, src in inputs:
    wp.copy(dst, src)

  # TODO(team): remove this hard coding substeps
  # ctrl_dt / sim_dt == 4
  for _ in range(4):
    mjwarp.step(m, d)

  # Export the stepped state.  The table is built after stepping so that the
  # attributes of `d` are looked up at the same point as before.
  outputs = (
    (qpos_out, d.qpos),
    (qvel_out, d.qvel),
    (xpos_out, d.xpos),
    (xmat_out, d.xmat),
    (qacc_warmstart_out, d.qacc_warmstart),
    (subtree_com_out, d.subtree_com),
    (cvel_out, d.cvel),
    (site_xpos_out, d.site_xpos),
  )
  for dst, src in outputs:
    wp.copy(dst, src)


# Shape of each output array of mjwarp_step, keyed by parameter name and
# listed in the same order as the `*_out` parameters.
_MJWARP_STEP_OUTPUT_DIMS = {
  "qpos_out": (NWORLD, mjm.nq),
  "qvel_out": (NWORLD, mjm.nv),
  "xpos_out": (NWORLD, mjm.nbody, 3),
  "xmat_out": (NWORLD, mjm.nbody, 3, 3),
  "qacc_warmstart_out": (NWORLD, mjm.nv),
  "subtree_com_out": (NWORLD, mjm.nbody, 3),
  "cvel_out": (NWORLD, mjm.nbody, 6),
  "site_xpos_out": (NWORLD, mjm.nsite, 3),
}

# JAX-callable wrapper around mjwarp_step: takes the four input arrays and
# returns the eight outputs described above.
jax_mjwarp_step = jax_callable(
  mjwarp_step,
  num_outputs=len(_MJWARP_STEP_OUTPUT_DIMS),
  output_dims=_MJWARP_STEP_OUTPUT_DIMS,
)

# the functions below allow us to call MJWarp step inside jax vmap:


@jax.custom_batching.custom_vmap
def step(d: mjx.Data) -> mjx.Data:
  """Physics step entry point with a custom vmap rule.

  The un-batched implementation is the identity: it returns `d` unchanged.
  The actual simulation runs in the rule registered with `step.def_vmap`,
  i.e. only when this function is called under `jax.vmap`.
  """
  return d


@step.def_vmap
def step_vmap_rule(axis_size, in_batched, d: mjx.Data):
  """Batched implementation of `step`: run MJWarp on the whole batch at once.

  Args:
    axis_size: size of the mapped axis.
    in_batched: per-argument batching flags; only the `ctrl` flag of the first
      (and only) argument is inspected.
    d: data whose ctrl, qpos, qvel and qacc_warmstart are fed to MJWarp.

  Returns:
    A pair `(data, out_batched)`: the data with the stepped fields replaced,
    and an `mjx.Data` of flags marking which fields carry a batch axis.
  """
  if not in_batched[0].ctrl:
    # ctrl was not mapped: replicate it along a new leading axis.
    d = d.replace(ctrl=jp.tile(d.ctrl, (axis_size, 1)))
  else:
    assert d.ctrl.shape[0] == axis_size

  # Fields written by the MJWarp callable, in the order it returns them.
  stepped_fields = (
    "qpos",
    "qvel",
    "xpos",
    "xmat",
    "qacc_warmstart",
    "subtree_com",
    "cvel",
    "site_xpos",
  )

  # Every field defaults to None; ctrl and the stepped fields are batched.
  batched_flags = {f.name: None for f in dataclasses.fields(mjx.Data)}
  for field_name in ("ctrl",) + stepped_fields:
    batched_flags[field_name] = True
  out_batched = mjx.Data(**batched_flags)

  results = jax_mjwarp_step(d.ctrl, d.qpos, d.qvel, d.qacc_warmstart)
  d = d.replace(**dict(zip(stepped_fields, results)))
  return d, out_batched


def init(qpos, ctrl) -> mjx.Data:
  """Create the un-batched initial `mjx.Data` for one environment.

  Only the fields used by this notebook are populated: `qpos` and `ctrl` come
  from the arguments, `qacc_warmstart` from the module-level `mjd`, and the
  remaining ones are zero-filled arrays sized from the MJWarp model `m`.
  Every other field of `mjx.Data` is left as None.
  """
  nbody = m.nbody
  populated = {
    "qpos": qpos,
    "ctrl": ctrl,
    "qvel": jp.zeros(m.nv),
    "xpos": jp.zeros((nbody, 3)),
    "xmat": jp.zeros((nbody, 3, 3)),
    "qacc_warmstart": jp.array(mjd.qacc_warmstart),
    "subtree_com": jp.zeros((nbody, 3)),
    "cvel": jp.zeros((nbody, 6)),
    "site_xpos": jp.zeros((m.nsite, 3)),
  }
  unset = {f.name: None for f in dataclasses.fields(mjx.Data)}
  return mjx.Data(**{**unset, **populated})

In [None]:
# An environment for training an Apptronik Apollo to walk.
# This is the same format as environments in MuJoCo Playground.


def default_config() -> config_dict.ConfigDict:
  """Return a freshly built default configuration for the Joystick task.

  The result holds timing (`ctrl_dt`, `sim_dt`), episode and action settings,
  the nested `reward_config` (per-term `scales` plus shared constants), and
  the sampling ranges for the commanded velocities.
  """
  # Weight applied to each reward/cost term; the keys must match the names
  # returned by Joystick._get_reward.
  reward_scales = config_dict.create(
    # Tracking related rewards.
    tracking_lin_vel=1.0,
    tracking_ang_vel=0.75,
    # Base related rewards.
    ang_vel_xy=-0.15,
    orientation=-2.0,
    # Energy related rewards.
    action_rate=0.0,
    # Feet related rewards.
    feet_air_time=2.0,
    feet_slip=-0.25,
    feet_phase=1.0,
    # Other rewards.
    termination=-5.0,
    # Pose related rewards.
    joint_deviation_knee=-0.1,
    joint_deviation_hip=-0.25,
    dof_pos_limits=-1.0,
    pose=-0.1,
  )

  reward_settings = config_dict.create(
    scales=reward_scales,
    tracking_sigma=0.25,
    max_foot_height=0.15,
    base_height_target=0.5,
  )

  return config_dict.create(
    ctrl_dt=0.02,
    sim_dt=0.005,
    episode_length=1000,
    action_repeat=1,
    action_scale=0.5,
    soft_joint_pos_limit_factor=0.95,
    reward_config=reward_settings,
    # [min, max] sampling range of each command component.
    lin_vel_x=[1.0, 1.0],
    lin_vel_y=[0.0, 0.0],
    ang_vel_yaw=[0.0, 0.0],
  )


class Joystick(mjx_env.MjxEnv):
  """Track a joystick command.

  The command is a 3-vector `[lin_vel_x, lin_vel_y, ang_vel_yaw]`.  Physics is
  advanced by the module-level `step` function (MJWarp), so `step` must be
  called under `jax.vmap`.  Model information is read from the module-level
  MuJoCo model `mjm`.
  """

  def __init__(
    self,
    config: Optional[config_dict.ConfigDict] = None,
    config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
  ):
    """Create the environment.

    Args:
      config: environment configuration; `default_config()` is used when None.
      config_overrides: optional overrides forwarded to the base class.
    """
    # Build the default lazily: a default evaluated in the signature would be
    # one mutable ConfigDict shared by every instance, so any in-place change
    # to it would leak into environments created later.
    if config is None:
      config = default_config()
    super().__init__(config, config_overrides)
    self._post_init()

  def _post_init(self) -> None:
    """Cache poses, joint limits and model ids used by reset/step/rewards."""
    self._init_q = jp.array(mjm.keyframe("stand").qpos)
    # Actuated joint positions only (the first 7 qpos entries are skipped).
    self._default_pose = self._init_q[7:]

    # Note: First joint is freejoint.
    self._lowers, self._uppers = self.mj_model.jnt_range[1:].T
    # Shrink the joint range symmetrically around its centre to obtain the
    # "soft" limits used by the dof_pos_limits cost.
    c = (self._lowers + self._uppers) / 2
    r = self._uppers - self._lowers
    self._soft_lowers = c - 0.5 * r * self._config.soft_joint_pos_limit_factor
    self._soft_uppers = c + 0.5 * r * self._config.soft_joint_pos_limit_factor

    # Indices are relative to qpos[7:], hence the "- 7".
    hip_joints = ["l_hip_ie", "l_hip_aa", "r_hip_ie", "r_hip_aa"]
    hip_indices = [mjm.joint(j).qposadr - 7 for j in hip_joints]
    self._hip_indices = jp.array(hip_indices)

    knee_joints = ["l_knee_fe", "r_knee_fe"]
    knee_indices = [mjm.joint(j).qposadr - 7 for j in knee_joints]
    self._knee_indices = jp.array(knee_indices)

    self._head_body_id = mjm.body("neck_pitch_link").id
    self._torso_id = mjm.body("torso_link").id

    # Four sites per foot: entries 0-3 are the left foot, 4-7 the right foot.
    feet_sites = ["l_foot_fr", "l_foot_br", "l_foot_fl", "l_foot_bl"]
    feet_sites += ["r_foot_fr", "r_foot_br", "r_foot_fl", "r_foot_bl"]
    feet_site_ids = [mjm.site(s).id for s in feet_sites]
    self._feet_site_id = jp.array(feet_site_ids)
    # A foot site counts as "in contact" when its z-coordinate is below this.
    self._feet_contact_z = 0.003

    self._floor_geom_id = mjm.geom("floor").id

  def reset(self, rng: jax.Array) -> mjx_env.State:
    """Return the initial state: "stand" keyframe pose and a fresh command."""
    qpos = self._init_q

    data = init(qpos=qpos, ctrl=qpos[7:])

    # Phase, freq=U(1.25, 1.5)
    rng, key = jax.random.split(rng)
    gait_freq = jax.random.uniform(key, (1,), minval=1.25, maxval=1.5)
    phase_dt = 2 * jp.pi * self.dt * gait_freq
    # The two feet start half a cycle apart.
    phase = jp.array([0, jp.pi])

    rng, cmd_rng = jax.random.split(rng)
    cmd = self.sample_command(cmd_rng)

    info = {
      "rng": rng,
      "step": 0,
      "command": cmd,
      "last_act": jp.zeros(mjm.nu),
      "last_last_act": jp.zeros(mjm.nu),
      "motor_targets": jp.zeros(mjm.nu),
      "feet_air_time": jp.zeros(2),
      "last_contact": jp.zeros(2, dtype=bool),
      # Phase related.
      "phase_dt": phase_dt,
      "phase": phase,
    }

    metrics = {}
    for k in self._config.reward_config.scales.keys():
      metrics[f"reward/{k}"] = jp.zeros(())

    # Compare only the z-coordinate of each foot site against the threshold,
    # exactly as step() does (previously x and y were compared as well).
    contact = data.site_xpos[self._feet_site_id, 2] < self._feet_contact_z
    contact = jp.array([contact[0:4].any(), contact[4:8].any()])
    obs = self._get_obs(data, info, contact)
    reward, done = jp.zeros(2)
    return mjx_env.State(data, obs, reward, done, metrics, info)

  def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
    """Apply `action`, advance the physics and compute obs, reward and done.

    `state.info` and `state.metrics` are updated in place.
    """
    state.info["rng"], _ = jax.random.split(state.info["rng"], 2)

    # Actions are offsets from the default pose, scaled by action_scale.
    ctrl = self._default_pose + action * self._config.action_scale
    data = state.data
    data = data.replace(ctrl=ctrl)
    data = step(data)
    state.info["motor_targets"] = ctrl

    # Per-foot contact: any of the foot's four sites is below the threshold.
    contact = data.site_xpos[self._feet_site_id, 2] < self._feet_contact_z
    contact = jp.array([contact[0:4].any(), contact[4:8].any()])
    contact_filt = contact | state.info["last_contact"]
    first_contact = (state.info["feet_air_time"] > 0.0) * contact_filt
    state.info["feet_air_time"] += self.dt

    obs = self._get_obs(data, state.info, contact)
    done = self._get_termination(data)

    rewards = self._get_reward(data, action, state.info, state.metrics, done, first_contact, contact)
    rewards = {k: v * self._config.reward_config.scales[k] for k, v in rewards.items()}
    reward = sum(rewards.values()) * self.dt

    state.info["step"] += 1
    # Advance the gait phase and wrap it back into [-pi, pi).
    phase_tp1 = state.info["phase"] + state.info["phase_dt"]
    state.info["phase"] = jp.fmod(phase_tp1 + jp.pi, 2 * jp.pi) - jp.pi
    state.info["last_last_act"] = state.info["last_act"]
    state.info["last_act"] = action
    state.info["rng"], cmd_rng = jax.random.split(state.info["rng"])
    # Draw a new command once more than 500 steps have passed since the last
    # one, then restart the counter (also restarted on termination).
    state.info["command"] = jp.where(
      state.info["step"] > 500,
      self.sample_command(cmd_rng),
      state.info["command"],
    )
    state.info["step"] = jp.where(
      done | (state.info["step"] > 500),
      0,
      state.info["step"],
    )
    # Air time is zeroed for feet that are currently in contact.
    state.info["feet_air_time"] *= ~contact
    state.info["last_contact"] = contact
    for k, v in rewards.items():
      state.metrics[f"reward/{k}"] = v

    done = done.astype(reward.dtype)
    state = state.replace(data=data, obs=obs, reward=reward, done=done)
    return state

  def _get_termination(self, data: mjx.Data) -> jax.Array:
    """Return True when the head body's z-position drops below 1.0."""
    fall_termination = data.xpos[self._head_body_id, 2] < 1.0
    return fall_termination

  def _get_obs(self, data: mjx.Data, info: dict[str, Any], contact: jax.Array) -> mjx_env.Observation:
    """Build the flat observation vector.

    `contact` is accepted for interface compatibility but is not part of the
    observation.
    """
    cos = jp.cos(info["phase"])
    sin = jp.sin(info["phase"])
    phase = jp.concatenate([cos, sin])

    return jp.hstack(
      [
        data.qpos,
        data.qvel,
        data.cvel.ravel(),
        data.xpos.ravel(),
        data.xmat.ravel(),
        phase,
        info["command"],
        info["last_act"],
        info["feet_air_time"],
      ]
    )

  def _get_reward(
    self,
    data: mjx.Data,
    action: jax.Array,
    info: dict[str, Any],
    metrics: dict[str, Any],
    done: jax.Array,
    first_contact: jax.Array,
    contact: jax.Array,
  ) -> dict[str, jax.Array]:
    """Return every unscaled reward/cost term, keyed like `reward_config.scales`."""
    del metrics  # Unused.
    return {
      # Tracking rewards.
      "tracking_lin_vel": self._reward_tracking_lin_vel(info["command"], self._get_global_linvel(data, self._torso_id)),
      "tracking_ang_vel": self._reward_tracking_ang_vel(info["command"], self._get_global_angvel(data, self._torso_id)),
      # Base-related rewards.
      "ang_vel_xy": self._cost_ang_vel_xy(self._get_global_angvel(data, self._torso_id)),
      "orientation": self._cost_orientation(self._get_z_frame(data, self._torso_id)),
      # Energy related rewards.
      "action_rate": self._cost_action_rate(action, info["last_act"], info["last_last_act"]),
      # Feet related rewards.
      "feet_slip": self._cost_feet_slip(data, contact, info),
      "feet_air_time": self._reward_feet_air_time(info["feet_air_time"], first_contact, info["command"]),
      "feet_phase": self._reward_feet_phase(
        data,
        info["phase"],
        self._config.reward_config.max_foot_height,
        info["command"],
      ),
      # Pose related rewards.
      "joint_deviation_hip": self._cost_joint_deviation_hip(data.qpos[7:], info["command"]),
      "joint_deviation_knee": self._cost_joint_deviation_knee(data.qpos[7:]),
      "dof_pos_limits": self._cost_joint_pos_limits(data.qpos[7:]),
      "pose": self._cost_pose(data.qpos[7:]),
      # Other rewards.
      "termination": self._cost_termination(done),
    }

  def _get_global_angvel(self, data: mjx.Data, bodyid: int):
    """Return the rotational part (first three components) of the body's cvel."""
    return data.cvel[bodyid, :3]

  def _get_global_linvel(self, data: mjx.Data, bodyid: int):
    """Return the linear velocity of the body origin.

    cvel stores [rotational; translational] velocity referred to the subtree
    centre of mass of the body's root.  Moving the reference point to the body
    origin adds omega x r, with r = xpos - subtree_com.
    """
    offset = data.xpos[bodyid] - data.subtree_com[mjm.body_rootid[bodyid]]
    xang = data.cvel[bodyid, :3]
    # v_origin = v_com + omega x r; the operand order matters, swapping the
    # operands flips the sign of the correction.
    xvel = data.cvel[bodyid, 3:] + jp.cross(xang, offset)
    return xvel

  def _get_z_frame(self, data: mjx.Data, bodyid: int):
    """Return the third column of the body's rotation matrix (its z-axis)."""
    return data.xmat[bodyid, :, 2]

  # Tracking rewards.

  def _reward_tracking_lin_vel(
    self,
    commands: jax.Array,
    local_vel: jax.Array,
  ) -> jax.Array:
    """Exponential reward on the squared xy linear-velocity tracking error."""
    lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))
    return jp.exp(-lin_vel_error / self._config.reward_config.tracking_sigma)

  def _reward_tracking_ang_vel(
    self,
    commands: jax.Array,
    ang_vel: jax.Array,
  ) -> jax.Array:
    """Exponential reward on the squared yaw-rate tracking error."""
    ang_vel_error = jp.square(commands[2] - ang_vel[2])
    return jp.exp(-ang_vel_error / self._config.reward_config.tracking_sigma)

  # Base-related rewards.

  def _cost_ang_vel_xy(self, global_angvel_torso: jax.Array) -> jax.Array:
    """Sum of squared x/y angular velocity of the torso."""
    return jp.sum(jp.square(global_angvel_torso[:2]))

  def _cost_orientation(self, torso_zaxis: jax.Array) -> jax.Array:
    """Squared distance between the torso z-axis and the world up vector."""
    return jp.sum(jp.square(torso_zaxis - jp.array([0.0, 0.0, 1.0])))

  def _cost_base_height(self, base_height: jax.Array) -> jax.Array:
    """Squared deviation from `base_height_target` (not used by `_get_reward`)."""
    return jp.square(base_height - self._config.reward_config.base_height_target)

  # Energy related rewards.

  def _cost_action_rate(self, act: jax.Array, last_act: jax.Array, last_last_act: jax.Array) -> jax.Array:
    """Sum of squared differences between the current and previous action."""
    del last_last_act  # Unused.
    return jp.sum(jp.square(act - last_act))

  # Feet related rewards.

  def _cost_feet_slip(self, data: mjx.Data, contact: jax.Array, info: dict[str, Any]) -> jax.Array:
    """Horizontal torso speed, counted once for every foot in contact."""
    del info  # Unused.
    body_vel = self._get_global_linvel(data, self._torso_id)[:2]
    reward = jp.sum(jp.linalg.norm(body_vel, axis=-1) * contact)
    return reward

  def _reward_feet_air_time(
    self,
    air_time: jax.Array,
    first_contact: jax.Array,
    commands: jax.Array,
    threshold_min: float = 0.2,
    threshold_max: float = 0.5,
  ) -> jax.Array:
    """Reward air time beyond `threshold_min` at touchdown, capped at `threshold_max`."""
    del commands  # Unused.
    air_time = (air_time - threshold_min) * first_contact
    air_time = jp.clip(air_time, max=threshold_max - threshold_min)
    reward = jp.sum(air_time)
    return reward

  @staticmethod
  def get_rz(phi: Union[jax.Array, float], swing_height: Union[jax.Array, float] = 0.08) -> jax.Array:
    """Return the reference foot height for gait phase `phi`.

    `phi` in [-pi, pi] is mapped to x in [0, 1]; the height rises from 0 to
    `swing_height` over the first half of the cycle and returns to 0 over the
    second half, both along a cubic Bezier curve.

    Declared static so that it can be called through the class (as
    `_reward_feet_phase` does) or through an instance.
    """

    def cubic_bezier_interpolation(y_start, y_end, x):
      y_diff = y_end - y_start
      bezier = x**3 + 3 * (x**2 * (1 - x))
      return y_start + y_diff * bezier

    x = (phi + jp.pi) / (2 * jp.pi)
    stance = cubic_bezier_interpolation(0, swing_height, 2 * x)
    swing = cubic_bezier_interpolation(swing_height, 0, 2 * x - 1)
    return jp.where(x <= 0.5, stance, swing)

  def _reward_feet_phase(
    self,
    data: mjx.Data,
    phase: jax.Array,
    foot_height: jax.Array,
    command: jax.Array,
  ) -> jax.Array:
    """Reward for tracking the phase-dependent reference foot height.

    The reward is zeroed when the body is nearly still and the command is
    nearly zero.
    """
    # Reward for tracking the desired foot height.
    foot_pos = data.site_xpos[self._feet_site_id]
    # Average the four sites of each foot into one position per foot.
    foot_pos = jp.array([jp.mean(foot_pos[0:4], axis=0), jp.mean(foot_pos[4:8], axis=0)])
    foot_z = foot_pos[..., -1]
    rz = Joystick.get_rz(phase, swing_height=foot_height)
    error = jp.sum(jp.square(foot_z - rz))
    reward = jp.exp(-error / 0.01)
    body_linvel = self._get_global_linvel(data, self._torso_id)[:2]
    body_angvel = self._get_global_angvel(data, self._torso_id)[2]
    linvel_mask = jp.logical_or(
      jp.linalg.norm(body_linvel) > 0.1,
      jp.abs(body_angvel) > 0.1,
    )
    mask = jp.logical_or(linvel_mask, jp.linalg.norm(command) > 0.01)
    reward *= mask
    return reward

  # Pose-related rewards.

  def _cost_joint_deviation_hip(self, qpos: jax.Array, cmd: jax.Array) -> jax.Array:
    """Weighted absolute deviation of the hip joints from the default pose."""
    error = qpos[self._hip_indices] - self._default_pose[self._hip_indices]
    # Allow roll deviation when lateral velocity is high.
    weight = jp.where(
      cmd[1] > 0.1,
      jp.array([0.0, 1.0, 0.0, 1.0]),
      jp.array([1.0, 1.0, 1.0, 1.0]),
    )
    cost = jp.sum(jp.abs(error) * weight)
    return cost

  def _cost_joint_deviation_knee(self, qpos: jax.Array) -> jax.Array:
    """Absolute deviation of the knee joints from the default pose."""
    error = qpos[self._knee_indices] - self._default_pose[self._knee_indices]
    return jp.sum(jp.abs(error))

  def _cost_joint_pos_limits(self, qpos: jax.Array) -> jax.Array:
    """Total amount by which joints exceed their soft limits."""
    out_of_limits = -jp.clip(qpos - self._soft_lowers, None, 0.0)
    out_of_limits += jp.clip(qpos - self._soft_uppers, 0.0, None)
    return jp.sum(out_of_limits)

  def _cost_pose(self, qpos: jax.Array) -> jax.Array:
    """Sum of squared deviations of all joints from the default pose."""
    return jp.sum(jp.square(qpos - self._default_pose))

  # Other rewards.

  def _cost_termination(self, done: jax.Array) -> jax.Array:
    """Return the termination flag itself (weighted by the `termination` scale)."""
    return done

  def sample_command(self, rng: jax.Array) -> jax.Array:
    """Sample `[lin_vel_x, lin_vel_y, ang_vel_yaw]` uniformly from the config ranges."""
    # The key is still split four ways (the last key is unused) so that the
    # sampled values stay identical to those of earlier runs.
    rng1, rng2, rng3, _ = jax.random.split(rng, 4)

    lin_vel_x = jax.random.uniform(rng1, minval=self._config.lin_vel_x[0], maxval=self._config.lin_vel_x[1])
    lin_vel_y = jax.random.uniform(rng2, minval=self._config.lin_vel_y[0], maxval=self._config.lin_vel_y[1])
    ang_vel_yaw = jax.random.uniform(
      rng3,
      minval=self._config.ang_vel_yaw[0],
      maxval=self._config.ang_vel_yaw[1],
    )

    return jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw])

  @property
  def xml_path(self) -> str:
    """Path of the scene XML the module-level model was loaded from."""
    return xml_path.as_posix()

  @property
  def action_size(self) -> int:
    """Number of actuators of the model."""
    return mjm.nu

  @property
  def mj_model(self) -> mujoco.MjModel:
    """The module-level MuJoCo model."""
    return mjm

  @property
  def mjx_model(self) -> mjx.Model:
    """Always None: no MJX model is built because stepping goes through MJWarp."""
    return None  # unused

In [None]:
# Train the environment using Brax PPO

from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from IPython.display import HTML
from IPython.display import clear_output
from matplotlib import pyplot as plt
from mujoco_playground import wrapper
from mujoco_playground.config import locomotion_params

# PPO hyper-parameters.  Every entry is forwarded as a keyword argument to
# ppo.train further down, except the optional `network_factory` entry, which
# is turned into arguments for ppo_networks.make_ppo_networks.
ppo_params = config_dict.create(
  num_timesteps=50_000_000,
  num_evals=5,
  reward_scaling=1.0,
  clipping_epsilon=0.2,
  episode_length=1000,
  normalize_observations=True,
  action_repeat=1,
  unroll_length=20,
  num_minibatches=32,
  num_updates_per_batch=4,
  discounting=0.97,
  learning_rate=3e-4,
  entropy_cost=0.005,
  # Both environment counts equal the number of MJWarp worlds, because the
  # MJWarp step callable is built for exactly NWORLD worlds.
  num_envs=NWORLD,
  num_eval_envs=NWORLD,
  batch_size=256,
  max_grad_norm=1.0,
  # Optional network sizes; uncomment to override the make_ppo_networks
  # defaults.
  # network_factory=config_dict.create(
  #   policy_hidden_layer_sizes=(512, 256, 128),
  #   value_hidden_layer_sizes=(512, 256, 128),
  # ),
)

# Buffers filled by progress(): step counts, mean episode reward and its
# standard deviation at each evaluation.
x_data, y_data, y_dataerr = [], [], []
# Wall-clock timestamps: the first entry is taken here, one more is appended
# on every progress() call.
times = [datetime.datetime.now()]


def progress(num_steps, metrics):
  """Record one evaluation result and redraw the training curve.

  Args:
    num_steps: number of environment steps taken so far.
    metrics: evaluation metrics; "eval/episode_reward" and
      "eval/episode_reward_std" are read.
  """
  times.append(datetime.datetime.now())

  reward_mean = metrics["eval/episode_reward"]
  reward_std = metrics["eval/episode_reward_std"]
  x_data.append(num_steps)
  y_data.append(reward_mean)
  y_dataerr.append(reward_std)

  # Leave 25% of headroom on the x-axis beyond the planned number of steps.
  x_max = ppo_params["num_timesteps"] * 1.25
  plt.xlim([0, x_max])
  plt.ylim([0, 30])
  plt.xlabel("# environment steps")
  plt.ylabel("reward per episode")
  plt.title(f"y={reward_mean:.3f}")
  plt.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")

  # With wait=True the previous output is only cleared once new output
  # arrives, so the freshly displayed figure stays visible until then.
  figure = plt.gcf()
  display(figure)
  clear_output(wait=True)


# Split the hyper-parameters: `network_factory` (if present) configures the
# networks, everything else goes straight to ppo.train.
ppo_training_params = dict(ppo_params)
_network_kwargs = ppo_training_params.pop("network_factory", None)
if _network_kwargs is None:
  network_factory = ppo_networks.make_ppo_networks
else:
  network_factory = functools.partial(ppo_networks.make_ppo_networks, **_network_kwargs)

train_fn = functools.partial(
  ppo.train,
  network_factory=network_factory,
  progress_fn=progress,
  **ppo_training_params,
)

env = Joystick()

make_inference_fn, params, metrics = train_fn(
  environment=env,
  wrap_env_fn=wrapper.wrap_for_brax_training,
)

# times[0] was taken before training started and times[1] at the first
# progress() call.
jit_duration = times[1] - times[0]
train_duration = times[-1] - times[1]
print(f"time to jit: {jit_duration}")
print(f"time to train: {train_duration}")

In [None]:
rng = jax.random.PRNGKey(0)

jit_reset = jax.jit(env.reset)


def unroll(state):
  """Roll the trained policy out for 1000 steps from `state`.

  Returns the stacked sequence of states produced by the scan.
  """
  policy = make_inference_fn(params, deterministic=True)
  policy_key = jax.random.PRNGKey(0)

  def replicate(leaf):
    # Add a leading axis of size NWORLD holding identical copies of `leaf`.
    return jp.tile(leaf, (NWORLD,) + (1,) * len(leaf.shape))

  def single_step(carry, unused_x):
    action, _ = policy(carry.obs, policy_key)
    # The MJWarp step is built for exactly NWORLD worlds, so the single
    # environment is replicated NWORLD times, stepped, and copy 0 is kept.
    batched_action = jp.tile(action, (NWORLD, 1))
    batched_state = jax.tree.map(replicate, carry)
    batched_state = jax.vmap(env.step)(batched_state, batched_action)
    next_state = jax.tree.map(lambda leaf: leaf[0], batched_state)

    return next_state, next_state

  _, states = jax.lax.scan(single_step, state, length=1000)

  return states


rollout = jax.jit(unroll)(jit_reset(rng))

In [None]:
def _rollout_state_at(index):
  """Return the state at time `index` of the stacked rollout."""
  return jax.tree.map(lambda leaf: leaf[index], rollout)


# Render the first 400 steps of the rollout at the control rate.
rollout_arr = [_rollout_state_at(i) for i in range(400)]
frames = env.render(rollout_arr, camera="track", width=640, height=480)
media.show_video(frames, fps=1.0 / env.dt)