In [1]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

import mediapy as media
import matplotlib.pyplot as plt

# Compact numpy reprs: three decimals, no scientific notation, 100-character lines.
np.set_printoptions(
    linewidth=100,
    suppress=True,
    precision=3,
)

In [2]:
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
# from ml_collections import config_dict


import jax
from jax import numpy as jp
import numpy as np
# from flax.training import orbax_utils
# from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
# from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
from brax.base import System

from copy import deepcopy as dc

# TwoArm URM Env


In [None]:
# `jax.devices` is a function: call it to get the device list, then show the
# first entry.
print(jax.devices()[0])

## Visualize a Rollout

Let's instantiate the environment and visualize a short rollout.

NOTE: Since episodes terminate early if the torso is below the healthy z-range, the only relevant contacts for this task are between the feet and the plane. We turn off other contacts.

In [None]:
# instantiate the environment
# The environment is looked up by name through `envs.get_environment`.
# NOTE(review): 'urm2arm_prop' is presumably registered with brax outside this
# notebook — confirm the registering module is imported before this cell runs.
env_name = 'urm2arm_prop'
env = envs.get_environment(env_name)




In [5]:
# JIT-compile the environment's reset and step entry points in one pass.
jit_reset, jit_step = (jax.jit(fn) for fn in (env.reset, env.step))

# Reset with a fixed seed and start the rollout from the initial pipeline state.
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

In [None]:
# Grab a trajectory by applying one randomly chosen one-hot action per step.
# Continues from the `state` / `rollout` created by the reset cell.
key = jax.random.key(110)
print(env.action_size)
for i in range(20):
  # Keep `key` as the carried key and sample only from the fresh `subkey`,
  # so no key is consumed twice.
  key, subkey = jax.random.split(key)
  # `shape` is given as a tuple; `idx` is a length-1 integer array in
  # [0, env.action_size).
  idx = jax.random.randint(subkey, (1,), 0, env.action_size)
  # NOTE(review): the control vector length is hard-coded to 16 — confirm it
  # matches what `env.step` expects (presumably env.action_size).
  ctrl = jp.zeros(16)
  ctrl = ctrl.at[idx].set(1.0)
  print(idx, ctrl)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)
print("done")
# Looks up the entry of `env.e_list` for the last sampled index.
# NOTE(review): `e_list` is defined by the environment class, not visible here.
print(env.e_list[idx])

In [None]:
# Render every stored pipeline state from the 'top' camera and play the frames
# back at 1 / env.dt frames per second.
media.show_video(env.render(rollout, camera='top'), fps=1.0 / env.dt)

In [None]:
# Show the first device of the 'gpu' backend.
# NOTE(review): presumably this raises when no GPU backend is available — confirm.
print(jax.devices('gpu')[0])

## Train Two Arm URM Policy

In [3]:
# instantiate the environment
# Rebinds the notebook-level `env_name` / `env` to the environment used for
# training; later cells read both names.
# NOTE(review): 'urm2arm_binary' is presumably registered with brax outside
# this notebook — confirm.
env_name = 'urm2arm_binary'
env = envs.get_environment(env_name)

def policy_params_fn(current_step, make_policy, params):
  """Checkpoint callback: saves `params` under a path suffixed with the step.

  Args:
    current_step: value appended to the checkpoint file name.
    make_policy: accepted to match the callback signature; not used.
    params: parameters handed to `model.save_params`.
  """
  del make_policy  # unused
  checkpoint_path = f'/home/tlee_theaiinstitute_com/mjx_brax_policy/test_new_{current_step}'
  model.save_params(checkpoint_path, params)


# Evaluation cadence and number of parallel environments for PPO.
num_evals = 50
num_envs = 1000

# Bind the PPO hyperparameters up front; the environment and the progress
# callback are supplied when `train_fn` is called.
train_fn = functools.partial(
    ppo.train,
    # budget and evaluation
    num_timesteps=40*num_evals*num_envs,
    num_evals=num_evals,
    episode_length=40,
    action_repeat=1,
    # rollout collection and batching
    num_envs=num_envs,
    unroll_length=40,
    batch_size=100,
    num_minibatches=10,
    num_updates_per_batch=2,
    # optimisation
    learning_rate=3e-4,
    discounting=0.97,
    entropy_cost=1e-3,
    clipping_epsilon=0.3,
    reward_scaling=0.1,
    normalize_observations=False,
    # policy options and bookkeeping
    discrete_action=True,
    spectral_norm_actor=False,
    seed=0,
    policy_params_fn=policy_params_fn,
)


# Training-curve buffers, appended to by `progress` on every call.
x_data = []
y_data = []
ydataerr = []
# First timestamp is taken before training starts; `progress` adds one per call.
times = [datetime.now()]

# Fixed y-axis range for the reward plot. The lower bound has to be passed
# first to `plt.ylim`; with the bounds swapped matplotlib inverts the axis.
min_y, max_y = -400, 0
def progress(num_steps, metrics):
  """Records the latest evaluation metrics and redraws the reward curve.

  Args:
    num_steps: number of environment steps so far; used as the x-coordinate.
    metrics: mapping providing 'eval/episode_reward' and
      'eval/episode_reward_std'.
  """
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])

  # The x-axis spans twice the configured training budget.
  plt.xlim([0, train_fn.keywords['num_timesteps'] * 2.0])
  plt.ylim([min_y, max_y])

  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.title(f'y={y_data[-1]:.3f}')

  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.show()

# Run training with `progress` passed as the progress callback; the third
# returned value is discarded.
make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)

# times[0] was taken before training started and `progress` appends one
# timestamp per call, so times[1] is the first progress report.
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

## Save and Load Policy

We can save and load the policy using the brax model API.

In [14]:
#@title Save Model
# Persist the trained parameters with the brax model API.
# NOTE(review): the destination is an absolute, user-specific path — adjust it
# for the machine running the notebook.
model_path = '/home/tlee_theaiinstitute_com/mjx_brax_policy/test'
model.save_params(model_path, params)

In [15]:
#@title Load Model and Define Inference Function
# Reload the parameters from `model_path` (set in the save cell), replacing
# the in-memory `params`.
params = model.load_params(model_path)

# Build the policy from the loaded parameters and JIT-compile it.
# `make_inference_fn` is the first value returned by `train_fn`, so the
# training cell must have run in this kernel.
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

## Visualize Policy

Finally we can visualize the policy.

In [16]:
# Build a separate environment instance for evaluation and JIT-compile its
# reset/step entry points (rebinding the notebook-level jit_reset / jit_step).
eval_env = envs.get_environment(env_name)
jit_reset, jit_step = (jax.jit(fn) for fn in (eval_env.reset, eval_env.step))

In [None]:
# initialize the state
rng = jax.random.PRNGKey(10)
state = jit_reset(rng)
rollout = [state.pipeline_state]

# grab a trajectory
n_steps = 40
render_every = 2

for i in range(n_steps):
  # Fresh subkey for each policy call; `rng` carries over to the next step.
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)
  print(ctrl, state.reward)
print("done")

# Render and time the video with `eval_env`, the environment whose step
# function produced this rollout, rather than whatever `env` is currently
# bound to in the kernel. Only every `render_every`-th state is rendered, so
# the frame rate is divided by the same factor.
# media.write_video('/home/tlee_theaiinstitute_com/mjx_brax_policy/rotate_90deg.mp4', eval_env.render(rollout[::render_every], camera='top'), fps=1.0 / eval_env.dt / render_every)
media.show_video(
    eval_env.render(rollout[::render_every], camera='top'),
    fps=1.0 / eval_env.dt / render_every)

In [None]:
# Number of pipeline states currently stored in `rollout`.
print(len(rollout))

# MJX Policy in MuJoCo

We can also perform the physics step using the original MuJoCo Python bindings to show that the policy trained in MJX works in MuJoCo.

In [None]:
# Drive the policy with the MuJoCo (CPU) stepper instead of the MJX pipeline.
# Relies on `n_steps`, `render_every`, `rng` and `jit_inference_fn` from the
# earlier cells.
mj_model = eval_env.sys.mj_model
mj_data = mujoco.MjData(mj_model)

renderer = mujoco.Renderer(mj_model)
# The first observation is built with an all-zero control of length mj_model.nu.
ctrl = jp.zeros(mj_model.nu)
images = []
for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)

  # Convert the current MuJoCo data with `mjx.put_data` and build the
  # observation from it together with the previous control.
  obs = eval_env._get_obs(mjx.put_data(mj_model, mj_data), ctrl)
  ctrl, _ = jit_inference_fn(obs, act_rng)
  print(ctrl)
  
  # Apply the policy output, then advance `eval_env._n_frames` physics steps.
  mj_data.ctrl = ctrl
  for _ in range(eval_env._n_frames):
    mujoco.mj_step(mj_model, mj_data)  # Physics step using MuJoCo mj_step.
  # print(mj_data.qpos)
  # Capture a frame from the 'side' camera every `render_every` policy steps.
  if i % render_every == 0:
    renderer.update_scene(mj_data, camera='side')
    images.append(renderer.render())
# plt.show(images[100])
# No fps is passed here (the argument is commented out).
media.show_video(images) #, fps=1.0 / eval_env.dt / render_every)
print("done")
# media.write_video('/home/tlee_theaiinstitute_com/mjx_brax_policy/test.mp4', images, fps=1.0 / env.dt / render_every)

In [None]:
# Create a fresh MjData, advance it by one MuJoCo step and display a single
# frame from the 'side' camera.
# Reuses the `renderer` created in the previous cell, so that cell must have
# run first.
mj_model = eval_env.sys.mj_model
mj_data = mujoco.MjData(mj_model)
mujoco.mj_step(mj_model, mj_data)
renderer.update_scene(mj_data, camera='side')
im = renderer.render()
plt.imshow(im)

# Training a Policy with Domain Randomization

We might also want to include randomization over certain `mjModel` parameters while training a policy. In MJX, we can easily create a batch of environments with randomized values populated in `mjx.Model`. Below, we show a function that randomizes friction and actuator gain/bias.

In [None]:
def domain_randomize(sys, rng):
  """Builds batched friction and actuator parameters, one set per PRNG key.

  Args:
    sys: model whose `geom_friction`, `actuator_gainprm` and
      `actuator_biasprm` fields are randomized.
    rng: batch of PRNG keys; the sampling function is vmapped over it.

  Returns:
    A `(sys, in_axes)` pair: `sys` with the three fields replaced by their
    batched versions, and a tree with the same structure as `sys` that holds 0
    for those three fields and None everywhere else.
  """
  gain_lo, gain_hi = -5, 5

  def _sample_one(env_key):
    # Friction: one draw in [0.6, 1.4) written to column 0 of every geom.
    _, sub = jax.random.split(env_key, 2)
    mu = jax.random.uniform(sub, (1,), minval=0.6, maxval=1.4)
    new_friction = sys.geom_friction.at[:, 0].set(mu)
    # Actuators: one shared offset added to column 0 of the gain params;
    # column 1 of the bias params receives the negated result.
    _, sub = jax.random.split(sub, 2)
    offset = jax.random.uniform(sub, (1,), minval=gain_lo, maxval=gain_hi)
    shifted_gain = offset + sys.actuator_gainprm[:, 0]
    new_gain = sys.actuator_gainprm.at[:, 0].set(shifted_gain)
    new_bias = sys.actuator_biasprm.at[:, 1].set(-shifted_gain)
    return new_friction, new_gain, new_bias

  # Field name -> batched value, with a leading axis of size len(rng).
  batched = dict(zip(
      ('geom_friction', 'actuator_gainprm', 'actuator_biasprm'),
      jax.vmap(_sample_one)(rng),
  ))

  # Start from "nothing is batched" and mark the randomized fields as axis 0.
  in_axes = jax.tree_util.tree_map(lambda x: None, sys)
  in_axes = in_axes.tree_replace({name: 0 for name in batched})

  sys = sys.tree_replace(batched)

  return sys, in_axes

If we wanted 10 environments with randomized friction and actuator params, we can call `domain_randomize` with 10 PRNG keys. It returns a batched `mjx.Model` along with an `in_axes` tree (same structure as the model) that marks the batched fields with axis `0` and every other field with `None`.

In [None]:
# Ten PRNG keys -> ten randomized copies of the model parameters.
rng = jax.random.split(jax.random.PRNGKey(0), 10)
batched_sys, _ = domain_randomize(env.sys, rng)

# The batched model gains a leading axis compared with the single model.
print('Single env friction shape: ', env.sys.geom_friction.shape)
print('Batched env friction shape: ', batched_sys.geom_friction.shape)

# First friction coefficient of geom 0: original value vs. the random draws.
print('Friction on geom 0: ', env.sys.geom_friction[0, 0])
print('Random frictions on geom 0: ', batched_sys.geom_friction[:, 0, 0])