# MPC with MjWarp

In [9]:

# jupyter notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8888 --NotebookApp.port_retries=0 --no-browser

from brax import base as brax_base
from brax.io import html
from brax.io import mjcf
import jax
from jax import numpy as jnp
from IPython.display import HTML
import mujoco
from mujoco import mjx
from mujoco.mjx._src import dataclasses
import mujoco_warp as mjwarp
import mediapy as media
import numpy as np
import os
import time
import tqdm
from typing import Callable
import warp as wp
from warp.jax_experimental.ffi import jax_callable

# This ensures JAX embeds Warp kernels into its own computation graph.
# The flag is appended to whatever XLA_FLAGS already holds so that options set
# by the user or the launch environment are not silently discarded.
_XLA_GRAPH_FLAG = "--xla_gpu_graph_min_graph_size=1"
_xla_flags = os.environ.get("XLA_FLAGS", "")
if _XLA_GRAPH_FLAG not in _xla_flags.split():
  os.environ["XLA_FLAGS"] = f"{_xla_flags} {_XLA_GRAPH_FLAG}".strip()

# Humanoid Stand Up Task

In [10]:
NCONMAX = 81920
NJMAX = NCONMAX * 4
NWORLDS = 1024
HORIZON = 128
NSTEPS = 1000
PLAN_EVERY = 10

# Get MuJoCo's humanoid model.
print('Getting MuJoCo humanoid XML description from GitHub:')
!git clone https://github.com/google-deepmind/mujoco
humanoid_file = 'mujoco/model/humanoid/humanoid.xml'
spec = mujoco.MjSpec.from_file(humanoid_file)

spec.option.jacobian = mujoco.mjtJacobian.mjJAC_SPARSE
spec.option.integrator = mujoco.mjtIntegrator.mjINT_RK4

# Initialise to squat position
mjm = spec.compile()
mjm.opt.iterations = 1
mjm.opt.ls_iterations = 4
mjd = mujoco.MjData(mjm)
key = mjm.key('squat').id
mujoco.mj_resetDataKeyframe(mjm, mjd, key)

# make warp model/data
m = mjwarp.put_model(mjm)
d = mjwarp.put_data(mjm, mjd, nworld=NWORLDS, nconmax=131012, njmax=131012 * 4)

# make jax model/data
mjxm = mjx.put_model(mjm) 
mjxd = mjx.put_data(mjm, mjd)
mjxd = mjx.kinematics(mjxm, mjxd)

# Cost 
def cost_fn(qpos, qvel, ctrl):
  """Per-timestep cost: weighted squared distance of the root to a target.

  Only the first three entries of `qpos` contribute. `qvel` and `ctrl` are
  accepted so the function matches the planner's cost signature, but they do
  not influence the result.

  Args:
    qpos: generalized positions of a single world; `qpos[:3]` is used.
    qvel: generalized velocities (unused).
    ctrl: control vector (unused).

  Returns:
    Scalar cost, 100 * sum(weight * (qpos[:3] - target) ** 2).
  """
  body_pos = qpos[:3]
  target_body_pos = jnp.array([0.0, 0.0, 1.2])
  # The third component is weighted ten times more than the first two.
  target_body_weight = jnp.array([0.1, 0.1, 1.0])
  pos_cost = jnp.dot(
      (body_pos - target_body_pos) ** 2,
      target_body_weight
  )
  return 100 * pos_cost



Getting MuJoCo humanoid XML description from GitHub:
fatal: destination path 'mujoco' already exists and is not an empty directory.


# Define MPPI Planner

In [11]:


# Cost callback: invoked as cost(qpos, qvel, ctrl) and expected to return the
# cost of that single timestep.
CostFn = Callable[[jax.Array, jax.Array, jax.Array], jax.Array]
# Dynamics callback: invoked as step_fn(qpos, qvel, ctrl) and expected to return
# the pair (next_qpos, next_qvel).
StepFn = Callable[[jax.Array, jax.Array, jax.Array], tuple[jax.Array, jax.Array]]

class MPPIConfig(dataclasses.PyTreeNode):
  """Planning Config.

  Immutable bundle of everything the static methods of `MPPIPlanner` need.

  Attributes:
    model: MJX model; the planner reads `nu` and `actuator_ctrlrange` from it
    cost: function returning per-timestep cost; the planner applies it with
      `jax.vmap` across the sampled rollouts
    step_fn: function for stepping simulation during rollout; it is called
      with state and controls that carry a leading sample axis
    noise_scale: standard deviation of zero-mean Gaussian
    horizon: planning duration (steps)
    nspline: number of spline points to explore
    nsample: number of action sequence candidates sampled
    interp: type of action interpolation; 'zero' and 'linear' are the values
      handled by `MPPIPlanner.get_actions`
    inverse_temperature: MPPI inverse temperature; it scales the perturbation
      cost and is the divisor applied to the costs before the softmax
  """
  model: mjx.Model
  cost: CostFn
  step_fn: StepFn
  noise_scale: float
  horizon: int
  nspline: int
  nsample: int
  interp: str
  inverse_temperature: float

class MPPIPlanner:
  """Sampling-based (MPPI) planner over spline-parameterised action sequences.

  The instance only stores an `MPPIConfig`. Every algorithmic step is a static
  method that receives that config explicitly, so the methods can be wrapped
  with `jax.jit` / `jax.vmap` directly.
  """

  def __init__(self,
              model: mjx.Model,
              cost: CostFn,
              step_fn: StepFn,
              noise_scale: float,
              horizon: int,
              nspline: int,
              nsample: int,
              interp: str,
              inverse_temperature: float,
              ) -> None:
    """Bundles the planner settings into an `MPPIConfig`."""
    self._config = MPPIConfig (
      cost = cost,
      model = model,
      step_fn = step_fn,
      noise_scale = noise_scale,
      horizon = horizon,
      nsample = nsample,
      nspline = nspline,
      interp = interp,
      inverse_temperature = inverse_temperature,
    )

  @property
  def config(self) -> MPPIConfig:
    """The configuration this planner was built with."""
    return self._config

  @staticmethod
  def get_actions(p: MPPIConfig, policy: jax.Array) -> jax.Array:
    """Gets actions over a planning duration from a policy.

    Args:
      p: planner config; `interp`, `nspline` and `horizon` are read.
      policy: spline knots, indexed along the first axis.

    Returns:
      One action per planning step (`p.horizon` entries along the first axis).

    Raises:
      ValueError: if `p.interp` is neither 'zero' nor 'linear'.
    """
    if p.interp == 'zero':
      # Zero-order hold: every step uses the knot of the segment it falls in.
      indices = [i * p.nspline // p.horizon for i in range(p.horizon)]
      actions = policy[jnp.array(indices)]
    elif p.interp == 'linear':
      locs = jnp.array([i * p.nspline / p.horizon for i in range(p.horizon)])
      idx = locs.astype(int)
      # The final knot has no right-hand neighbour. Clamp explicitly so the
      # last segment holds the final knot instead of indexing past the end.
      nxt = jnp.minimum(idx + 1, policy.shape[0] - 1)
      actions = jax.vmap(jnp.multiply)(policy[idx], 1 - locs + idx)
      actions += jax.vmap(jnp.multiply)(policy[nxt], locs - idx)
    else:
      raise ValueError(f'unimplemented interpolation method: {p.interp}')
    return actions

  @staticmethod
  def rollout(p: MPPIConfig, qpos: jax.Array,
              qvel:jax.Array, policy: jax.Array) -> jax.Array:
    """Expand the policy into actions and roll out dynamics and cost.

    Args:
      p: planner config; `step_fn` and `cost` are called.
      qpos: initial positions of a single world, tiled across all candidates.
      qvel: initial velocities of a single world, tiled across all candidates.
      policy: batch of candidate spline policies, 3-D with the candidate axis
        first.

    Returns:
      Total cost of each candidate, summed over the planning horizon.
    """
    num_candidates = policy.shape[0]
    qpos = jnp.tile(qpos, (num_candidates, 1))
    qvel = jnp.tile(qvel, (num_candidates, 1))
    actions = jax.vmap(MPPIPlanner.get_actions, in_axes=(None, 0))(p, policy)

    def step(carry, u):
      qpos, qvel = carry
      qpos, qvel = p.step_fn(qpos, qvel, u)
      cost = jax.vmap(p.cost)(qpos, qvel, u)
      # Only the cost is stacked; the state trajectories were never consumed.
      return (qpos, qvel), cost

    # scan iterates over the leading axis, so put time first.
    _, costs = jax.lax.scan(step, (qpos, qvel), actions.transpose(1, 0, 2))

    return jnp.sum(costs, axis=0)

  @staticmethod
  def resample(p: MPPIConfig, policy: jax.Array, steps_per_plan: int) -> jax.Array:
    """Resample policy to new advanced time.

    Args:
      p: planner config; `interp`, `horizon` and `nspline` are read.
      policy: current spline knots.
      steps_per_plan: number of steps executed since the policy was planned.
        It is used as a slice bound, so it must be a static Python int.

    Returns:
      Spline knots describing the plan shifted forward by `steps_per_plan`.

    Raises:
      ValueError: if `p.interp` is neither 'zero' nor 'linear'.
    """
    if p.interp == 'zero':
      return policy  # assuming steps_per_plan < splinesteps
    elif p.interp == 'linear':
      actions = MPPIPlanner.get_actions(p, policy)
      roll = steps_per_plan
      # Grab the final planned action BEFORE rolling: after the roll the last
      # slot holds a wrapped-around, already executed action, which is not
      # what the tail should be padded with.
      last_action = actions[..., [-1], :]
      actions = jnp.roll(actions, -roll, axis=-2)
      actions = actions.at[..., -roll:, :].set(last_action)
      idx = jnp.floor(jnp.linspace(0, p.horizon, p.nspline)).astype(int)
      # linspace includes the endpoint `horizon`, one past the last action.
      idx = jnp.minimum(idx, p.horizon - 1)
      return actions[..., idx, :]
    # Consistent with get_actions: an unknown mode is an error, not a no-op.
    raise ValueError(f'unimplemented interpolation method: {p.interp}')

  @staticmethod
  def improve_policy(
      p: MPPIConfig,
      qpos: jax.Array,
      qvel: jax.Array,
      policy: jax.Array,
      rng: jax.Array,
  ) -> jax.Array:
    """Improves policy.

    Samples `p.nsample` noisy copies of `policy`, rolls each of them out and
    returns their softmax-weighted average.

    Args:
      p: planner config.
      qpos: current positions of a single world.
      qvel: current velocities of a single world.
      policy: nominal spline knots.
      rng: PRNG key used to draw the exploration noise. Pass a fresh key on
        every call; reusing a key reproduces the same noise.

    Returns:
      The updated spline knots, same shape as `policy`.
    """
    # create noisy policies
    noise = (
        jax.random.normal(rng, (p.nsample, p.nspline, p.model.nu)) * p.noise_scale
    )
    policies = policy + noise

    # clamp actions to ctrlrange
    limit = p.model.actuator_ctrlrange
    policies = jnp.clip(policies, limit[:, 0], limit[:, 1])

    # perform parallel rollouts
    costs = MPPIPlanner.rollout(p, qpos, qvel, policies)

    # add perturbation cost
    perturb_cost = jnp.sum(p.inverse_temperature * noise *
                           policies / p.noise_scale**2, axis=(1, 2))
    costs += perturb_cost

    # normalize cost to [0, 1] using only the rollouts that stayed finite.
    # Diverged (NaN/inf) rollouts get infinite cost and hence zero weight;
    # including them in min/max would turn every normalized cost into NaN.
    valid = jnp.isfinite(costs)
    lo = jnp.min(jnp.where(valid, costs, jnp.inf))
    hi = jnp.max(jnp.where(valid, costs, -jnp.inf))
    span = hi - lo
    # All-equal costs would otherwise divide by zero; any positive divisor
    # yields the intended all-zero normalized costs (uniform weights).
    span = jnp.where(span > 0, span, 1.0)
    costs = jnp.where(valid, (costs - lo) / span, jnp.inf)

    # compute update weights
    omega = jax.nn.softmax(-costs / p.inverse_temperature)

    # get final nominal; keep the incoming policy if every rollout diverged.
    improved = jnp.einsum('i,ijk->jk', omega, policies)
    return jnp.where(jnp.any(valid), improved, policy)


def mpc_rollout(
    nsteps,
    steps_per_plan,
    p: MPPIPlanner,
    init_policy,
    rng,
    model,
    data,
):
  """Receding horizon optimization starting from `data`'s state.

  Alternates between one planning round (resample, improve, expand to actions)
  and executing `steps_per_plan` of the planned actions with `mjx.step`.

  Args:
    nsteps: total number of simulation steps to execute.
    steps_per_plan: simulation steps executed per planning round.
    p: planner providing `config`, `resample`, `improve_policy`, `get_actions`.
    init_policy: initial spline knots.
    rng: PRNG key; it is split once per planning round so that each round
      explores with fresh noise.
    model: MJX model used for the executed steps.
    data: MJX data holding the initial state.

  Returns:
    Tuple `(qpos, qvel, costs)` of NumPy arrays with `nsteps` rows/entries:
    the state after each executed step and the cost evaluated before it.
  """
  qpos = np.zeros((nsteps, model.nq))
  qvel = np.zeros((nsteps, model.nv))
  costs = np.zeros(nsteps)

  # Build the jitted callables once instead of re-wrapping them on every
  # iteration of the loops below.
  resample = jax.jit(p.resample, static_argnums=(2,))
  improve_policy = jax.jit(p.improve_policy)
  get_actions = jax.jit(p.get_actions)
  step_cost = jax.jit(p.config.cost)
  sim_step = jax.jit(mjx.step)

  # JAX arrays are immutable, so no defensive copy of init_policy is needed.
  policy = init_policy
  # Ceil division: a trailing partial round is still planned and executed, so
  # no output row is left at its zero initial value.
  num_plans = -(-nsteps // steps_per_plan)
  for plan in tqdm.tqdm(range(num_plans)):
    # Fresh key per round; reusing `rng` would repeat the same noise.
    rng, plan_rng = jax.random.split(rng)

    # resample
    policy = resample(p.config, policy, steps_per_plan)
    # planning
    policy = improve_policy(
        p.config,
        data.qpos,
        data.qvel,
        policy,
        plan_rng,
    )
    # get actions from spline
    actions = get_actions(p.config, policy)

    # rollout
    start = plan * steps_per_plan
    for i in range(min(steps_per_plan, nsteps - start)):
      action = actions[i]
      data = data.replace(ctrl=action)
      costs[start + i] = step_cost(data.qpos, data.qvel, action)
      data = sim_step(model, data)
      qpos[start + i] = data.qpos
      qvel[start + i] = data.qvel

  return qpos, qvel, costs


# Using mjwarp.step

In [12]:
# Step fn
def warp_step(
  qpos_in: wp.array(dtype=wp.float32, ndim=2),
  qvel_in: wp.array(dtype=wp.float32, ndim=2),
  ctrl_in: wp.array(dtype=wp.float32, ndim=2),
  qpos_out: wp.array(dtype=wp.float32, ndim=2),
  qvel_out: wp.array(dtype=wp.float32, ndim=2),
):
  """Runs one `mjwarp.step` on the module-level Warp model `m` and data `d`.

  The inputs are copied into `d`, the simulation is stepped once, and the
  resulting positions and velocities are copied into the output arrays.

  Args:
    qpos_in: positions to load into `d.qpos`.
    qvel_in: velocities to load into `d.qvel`.
    ctrl_in: controls to load into `d.ctrl`.
    qpos_out: destination for `d.qpos` after the step.
    qvel_out: destination for `d.qvel` after the step.
  """
  # Load the caller's state and controls into the shared simulation buffers.
  for dst, src in ((d.qpos, qpos_in), (d.qvel, qvel_in), (d.ctrl, ctrl_in)):
    wp.copy(dst, src)

  mjwarp.step(m, d)

  # Hand the stepped state back through the output buffers.
  for dst, src in ((qpos_out, d.qpos), (qvel_out, d.qvel)):
    wp.copy(dst, src)

# Expose the Warp step to JAX. The two outputs have one row per world.
warp_step_fn = jax_callable(
  warp_step,
  num_outputs=2,
  output_dims={"qpos_out": (NWORLDS, mjm.nq), "qvel_out": (NWORLDS, mjm.nv)},
)

# Instantiate the planner
warp_planner = MPPIPlanner(
    model=mjxm,
    step_fn = warp_step_fn,
    cost=cost_fn,
    noise_scale=2.0,
    horizon=HORIZON,
    nspline=16,
    nsample=NWORLDS,
    interp='zero',
    inverse_temperature=0.01
)

rng = jax.random.PRNGKey(0)
planner_params = warp_planner.config
policy = jnp.zeros((planner_params.nspline, planner_params.model.nu))

# Run once to compile
beg = time.perf_counter()
jax.block_until_ready(mpc_rollout(1, 1, warp_planner, policy,
                                rng, mjx.put_model(mjm),
                                mjx.put_data(mjm, mjd)
                                 )
                     )
end = time.perf_counter()
# The format spec belongs inside the braces; outside it is printed literally.
print(f"Jit time: {end-beg:.4f}")

qpos, qvel, costs = mpc_rollout(NSTEPS, PLAN_EVERY, warp_planner, policy,
                                rng, mjx.put_model(mjm),
                                mjx.put_data(mjm, mjd))

# Now let's render the model into a video.
# Use dedicated names here: rebinding `d` would replace the Warp data that
# `warp_step` reads from module scope, and `sys` shadows the stdlib module name.
render_data = mujoco.MjData(mjm)
brax_sys = mjcf.load_model(mjm)
xstates = []
for qp in qpos.reshape(-1, mjm.nq):
  render_data.qpos = qp
  mujoco.mj_kinematics(mjm, render_data)
  # Index 0 is skipped, as in the original slicing of xpos/xquat.
  x = brax_base.Transform(pos=render_data.xpos[1:].copy(),
                          rot=render_data.xquat[1:].copy())
  xstates.append(brax_base.State(q=None, qd=None, x=x, xd=None, contact=None))

HTML(html.render(brax_sys, xstates))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [01:30<00:00, 90.16s/it]


Jit time: 90.19343079399914.4f


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:45<00:00,  1.05s/it]


# Using mjx.step

In [13]:
def mjx_step_fn(
    qpos_in: jax.Array,
    qvel_in: jax.Array,
    ctrl_in: jax.Array,
) -> tuple[jax.Array, jax.Array]:
  """Steps a single world once with `mjx.step`.

  The module-level `mjxd` is used as a read-only template for every field
  other than the state and controls supplied here.

  Args:
    qpos_in: generalized positions to step from.
    qvel_in: generalized velocities to step from.
    ctrl_in: controls applied during the step.

  Returns:
    Tuple `(qpos, qvel)` after the step.
  """
  # Keep the stepped data local. Rebinding the global `mjxd` from inside a
  # function that is traced (vmap / scan / jit) would store tracers in module
  # state, which then leak into any later use or re-trace.
  data = mjxd.replace(qpos=qpos_in, qvel=qvel_in, ctrl=ctrl_in)
  data = mjx.step(mjxm, data)
  return data.qpos, data.qvel

# Batch the single-world step over the sample axis. A separate name is used so
# that `mjx_step_fn` is not replaced by its own vmap.
mjx_batched_step_fn = jax.vmap(mjx_step_fn)

# Instantiate the planner
mjx_planner = MPPIPlanner(
    model=mjxm,
    step_fn = mjx_batched_step_fn,
    cost=cost_fn,
    noise_scale=2.0,
    horizon=HORIZON,
    nspline=16,
    nsample=NWORLDS,
    interp='zero',
    inverse_temperature=0.01
)
rng = jax.random.PRNGKey(0)
# Read the settings from the planner that is actually used in this section.
planner_params = mjx_planner.config
policy = jnp.zeros((planner_params.nspline, planner_params.model.nu))

# Run once to compile
beg = time.perf_counter()
jax.block_until_ready(mpc_rollout(1, 1, mjx_planner, policy,
                                rng, mjx.put_model(mjm),
                                mjx.put_data(mjm, mjd)
                                 )
                     )
end = time.perf_counter()
# The format spec belongs inside the braces; outside it is printed literally.
print(f"Jit time: {end-beg:.4f}")

# do loop
qpos, qvel, costs = mpc_rollout(NSTEPS, PLAN_EVERY, mjx_planner, policy,
                                rng, mjx.put_model(mjm), mjx.put_data(mjm, mjd))

# Now let's render the model into a video.
# Use dedicated names here: `d` is the Warp data read by `warp_step`, and
# `sys` shadows the stdlib module name.
render_data = mujoco.MjData(mjm)
brax_sys = mjcf.load_model(mjm)
xstates = []
for qp in qpos.reshape(-1, mjm.nq):
  render_data.qpos = qp
  mujoco.mj_kinematics(mjm, render_data)
  # Index 0 is skipped, as in the original slicing of xpos/xquat.
  x = brax_base.Transform(pos=render_data.xpos[1:].copy(),
                          rot=render_data.xquat[1:].copy())
  xstates.append(brax_base.State(q=None, qd=None, x=x, xd=None, contact=None))

HTML(html.render(brax_sys, xstates))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [01:44<00:00, 104.15s/it]


Jit time: 104.18404249299783.4f


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [03:24<00:00,  2.04s/it]
