In [1]:
import mujoco
from mujoco import mjx
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import device_put
import numpy as np 
import yaml
from typing import List, Dict, Text, Any, Sequence, Union, Optional
import time
import functools
import copy
from flax import struct
import logging
import os

# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".5"


### Brax base and state classes

In [2]:
from jax.tree_util import tree_map

@struct.dataclass
class Base:
  """Base functionality extending all brax types.

  These methods allow for brax types to be operated like arrays/matrices.
  """

  def __add__(self, o: Any) -> Any:
    return tree_map(lambda x, y: x + y, self, o)

  def __sub__(self, o: Any) -> Any:
    return tree_map(lambda x, y: x - y, self, o)

  def __mul__(self, o: Any) -> Any:
    return tree_map(lambda x: x * o, self)

  def __neg__(self) -> Any:
    return tree_map(lambda x: -x, self)

  def __truediv__(self, o: Any) -> Any:
    return tree_map(lambda x: x / o, self)

  def reshape(self, shape: Sequence[int]) -> Any:
    return tree_map(lambda x: x.reshape(shape), self)

  def select(self, o: Any, cond: jax.Array) -> Any:
    return tree_map(lambda x, y: (x.T * cond + y.T * (1 - cond)).T, self, o)

  def slice(self, beg: int, end: int) -> Any:
    return tree_map(lambda x: x[beg:end], self)

  def take(self, i, axis=0) -> Any:
    return tree_map(lambda x: jnp.take(x, i, axis=axis, mode='wrap'), self)

  def concatenate(self, *others: Any, axis: int = 0) -> Any:
    return tree_map(lambda *x: jnp.concatenate(x, axis=axis), self, *others)

  def index_set(
      self, idx: Union[jax.Array, Sequence[jax.Array]], o: Any
  ) -> Any:
    return tree_map(lambda x, y: x.at[idx].set(y), self, o)

  def index_sum(
      self, idx: Union[jax.Array, Sequence[jax.Array]], o: Any
  ) -> Any:
    return tree_map(lambda x, y: x.at[idx].add(y), self, o)

  def vmap(self, in_axes=0, out_axes=0):
    """Returns an object that vmaps each follow-on instance method call."""

    # TODO: i think this is kinda handy, but maybe too clever?

    outer_self = self

    class VmapField:
      """Returns instance method calls as vmapped."""

      def __init__(self, in_axes, out_axes):
        self.in_axes = [in_axes]
        self.out_axes = [out_axes]

      def vmap(self, in_axes=0, out_axes=0):
        self.in_axes.append(in_axes)
        self.out_axes.append(out_axes)
        return self

      def __getattr__(self, attr):
        fun = getattr(outer_self.__class__, attr)
        # load the stack from the bottom up
        vmap_order = reversed(list(zip(self.in_axes, self.out_axes)))
        for in_axes, out_axes in vmap_order:
          fun = vmap(fun, in_axes=in_axes, out_axes=out_axes)
        fun = functools.partial(fun, outer_self)
        return fun

    return VmapField(in_axes, out_axes)

  def tree_replace(
      self, params: Dict[str, Optional[jax.typing.ArrayLike]]
  ) -> 'Base':
    """Creates a new object with parameters set.

    Args:
      params: a dictionary of key value pairs to replace

    Returns:
      data clas with new values

    Example:
      If a system has 3 links, the following code replaces the mass
      of each link in the System:
      >>> sys = sys.tree_replace(
      >>>     {'link.inertia.mass', jp.array([1.0, 1.2, 1.3])})
    """
    new = self
    for k, v in params.items():
      new = _tree_replace(new, k.split('.'), v)
    return new

  @property
  def T(self):  # pylint:disable=invalid-name
    return tree_map(lambda x: x.T, self)

def _tree_replace(
    base: Base,
    attr: Sequence[str],
    val: Optional[jax.typing.ArrayLike],
) -> Base:
  """Sets attributes in a struct.dataclass with values."""
  if not attr:
    return base

  # special case for List attribute
  if len(attr) > 1 and isinstance(getattr(base, attr[0]), list):
    lst = copy.deepcopy(getattr(base, attr[0]))

    for i, g in enumerate(lst):
      if not hasattr(g, attr[1]):
        continue
      v = val if not hasattr(val, '__iter__') else val[i]
      lst[i] = _tree_replace(g, attr[1:], v)

    return base.replace(**{attr[0]: lst})

  if len(attr) == 1:
    return base.replace(**{attr[0]: val})

  return base.replace(
      **{attr[0]: _tree_replace(getattr(base, attr[0]), attr[1:], val)}
  )

@struct.dataclass
class State(Base):
  """A minimal state class (only containing mjx.Data).

  Args:
    pipeline_state: the physics state, mjx.Data
  """

  data: mjx.Data

In [3]:
def load_params(param_path: Text) -> Dict:
    with open(param_path, "rb") as file:
        params = yaml.safe_load(file)
    return params

params = load_params("params/params.yaml")
model = mujoco.MjModel.from_xml_path(params["XML_PATH"])
mjdata = mujoco.MjData(model)
model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
model.opt.iterations = 1
model.opt.ls_iterations = 1

mjx_model = mjx.device_put(model)

"""takes state class and ctrl (action) vector, returns next step's state"""
def single_step(state, ctrl):
    data0 = state.data
    data0 = data0.replace(ctrl=ctrl)
    data = mjx.step(mjx_model, data0)
    state = state.replace(data=data)
    return state

def serial_step(vel):
    data = mujoco.MjData(model)
    print(data.qpos)
    data.qvel[0] = 0
    # qvel[0] = vel
    # data = data.replace(qvel=qvel)
    mujoco.mj_step(model, data)
    
    return data.qpos

def serial_step_mjx(vel):
    mjx_data = mjx.make_data(mjx_model)    
    print(mjx_data.qpos)
    qvel = mjx_data.qvel.at[0].set(vel)
    mjx_data = mjx_data.replace(qvel=qvel)
    mjx_data = mjx.step(mjx_model, mjx_data)
    # mjx.forward(model, mjx_data)
    
    return mjx_data.qpos

In [4]:
process_count = jax.process_count()
process_id = jax.process_index()
local_device_count = jax.local_device_count()
local_devices_to_use = local_device_count
print(
    f"Device count: {jax.device_count()}, process count: "
    f"{process_count} (id {process_id}), local device count: "
    f"{local_device_count}, devices to be used count: {local_devices_to_use}")
device_count = local_devices_to_use * process_count

Device count: 1, process count: 1 (id 0), local device count: 1, devices to be used count: 1


In [13]:
n_envs_small = 1
n_envs_large = 256
key = random.PRNGKey(0)
small_ctrl = random.uniform(key, shape=(n_envs_small, mjx_model.nu))
large_ctrl = random.uniform(key, shape=(n_envs_large, mjx_model.nu))
blah = random.uniform(key, shape=(n_envs_large, 1))

def reset(val: int) -> State:
    """Resets the environment to an initial state."""
    data = mjx.make_data(mjx_model)
    # data = data.replace(qvel=jnp.zer(mjx_model.nv, x))
    data = mjx.forward(mjx_model, data)
    return State(data)

large_ctrl.shape

(256, 30)

In [14]:
reset_fn = jax.jit(jax.vmap(reset))
single_batch_step = jax.vmap(single_step)
# returns the state object with a batch axis for each attribute in data (batch_size=n_envs_large)
env_state = reset_fn(blah)
print(env_state.data.qpos.shape)

steps = 100

(256, 74)


### one step no scan inside step function

In [15]:
jit_single_batch_step = jax.jit(single_batch_step)
def f(state ,_):
    return (jit_single_batch_step(state, large_ctrl), None)

jit_f = jit(f)

In [19]:
start_time = time.time()

env_state, times = jax.lax.scan(jit_f, env_state, (), length=steps)
    
print(f"{steps * n_envs_large} steps completed in {time.time()-start_time} seconds")

25600 steps completed in 8.15187406539917 seconds


In [17]:
jit_single_batch_step = jax.jit(single_batch_step)

In [20]:
start_time = time.time()

jit_single_batch_step(env_state, large_ctrl)
prev = time.time()
print(f"initial execution time: {prev - start_time}")
for _ in range(steps):
    env_state = jit_single_batch_step(env_state, large_ctrl)
    print(f"{time.time()-prev}")
    prev = time.time()
    
print(f"{steps * n_envs_large + steps} steps completed in {time.time()-start_time} seconds")

initial execution time: 0.08024907112121582
0.08127737045288086
0.08141756057739258
0.08145451545715332
0.08017158508300781
0.08051609992980957
0.08052468299865723
0.0788581371307373
0.0792696475982666
0.08059453964233398
0.08382916450500488
0.0826718807220459
0.08217501640319824
0.08235883712768555
0.08407402038574219
0.08199644088745117
0.08224773406982422
0.0833585262298584
0.08191037178039551
0.08121228218078613
0.08164596557617188
0.08332228660583496
0.08403611183166504
0.08244514465332031
0.08173274993896484
0.08148002624511719
0.08144259452819824
0.08344054222106934
0.0826103687286377
0.08182072639465332
0.08192324638366699
0.08421158790588379
0.08213257789611816
0.0837864875793457
0.08339691162109375
0.08277511596679688
0.08310174942016602
0.08156156539916992
0.08157515525817871
0.08267092704772949
0.08363771438598633
0.09081363677978516
0.08409976959228516
0.08292675018310547
0.0819239616394043
0.08315372467041016
0.08321428298950195
0.08463025093078613
0.08644795417785645
0.0

### In conclusion:

1. Use the first method, scanning a jitted function that takes one batch step `length` times. 
2. mjxData is not first-order vectorizable (as in you cant have an array on mjxData objects), but it still is vectorizable in that it can be taken as an argument in a vmapped function, and its attributes (qpos, qvel, ctrl, etc.) will gain a leading "batch" dimension. 

### Testing mjData vmapping

In [19]:
@vmap
def twoargs(ctrl):
    data = mjx.make_data(mjx_model)
    return data.replace(ctrl=ctrl)

@vmap
def zeroctrl(mjx_data, ctrl):
    return mjx_data.replace(ctrl=ctrl)

In [22]:
jit_twoargs = jit(twoargs)
jit_zeroctrl = jit(zeroctrl)
zeros = jnp.zeros((n_envs_large, mjx_model.nu))


In [23]:
data_1 = jit_twoargs(large_ctrl)
data_1.ctrl

Array([[9.09454703e-01, 3.68365049e-01, 7.46057868e-01, ...,
        7.32579350e-01, 3.80138040e-01, 5.04539847e-01],
       [8.52779508e-01, 2.25239992e-02, 1.83701515e-04, ...,
        5.84446311e-01, 4.56560373e-01, 3.11894655e-01],
       [1.00612044e-01, 5.34569263e-01, 5.25110483e-01, ...,
        8.11771512e-01, 7.01126814e-01, 1.95618153e-01],
       ...,
       [2.59530187e-01, 8.63910198e-01, 8.89778972e-01, ...,
        8.04281712e-01, 2.51758814e-01, 9.61673498e-01],
       [9.13880110e-01, 4.04026389e-01, 1.58157945e-01, ...,
        5.03633142e-01, 5.00825763e-01, 2.66343951e-01],
       [3.64324570e-01, 8.98343801e-01, 6.76937461e-01, ...,
        3.46303225e-01, 3.63741875e-01, 8.23707700e-01]], dtype=float32)

In [28]:
data_2 = jit_zeroctrl(data_1, zeros)
print(data_2.ctrl.shape)
data_2.ctrl

(128, 30)


Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)