In [106]:
import mujoco
from mujoco import mjx
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import device_put
import numpy as np 
import yaml
from typing import List, Dict, Text, Any, Sequence, Union, Optional
import time
import functools
import copy
from flax import struct
import logging

### Brax base and state classes

In [107]:
from jax.tree_util import tree_map

@struct.dataclass
class Base:
  """Base functionality extending all brax types.

  These methods allow for brax types to be operated like arrays/matrices.
  """

  def __add__(self, o: Any) -> Any:
    return tree_map(lambda x, y: x + y, self, o)

  def __sub__(self, o: Any) -> Any:
    return tree_map(lambda x, y: x - y, self, o)

  def __mul__(self, o: Any) -> Any:
    return tree_map(lambda x: x * o, self)

  def __neg__(self) -> Any:
    return tree_map(lambda x: -x, self)

  def __truediv__(self, o: Any) -> Any:
    return tree_map(lambda x: x / o, self)

  def reshape(self, shape: Sequence[int]) -> Any:
    return tree_map(lambda x: x.reshape(shape), self)

  def select(self, o: Any, cond: jax.Array) -> Any:
    return tree_map(lambda x, y: (x.T * cond + y.T * (1 - cond)).T, self, o)

  def slice(self, beg: int, end: int) -> Any:
    return tree_map(lambda x: x[beg:end], self)

  def take(self, i, axis=0) -> Any:
    return tree_map(lambda x: jnp.take(x, i, axis=axis, mode='wrap'), self)

  def concatenate(self, *others: Any, axis: int = 0) -> Any:
    return tree_map(lambda *x: jnp.concatenate(x, axis=axis), self, *others)

  def index_set(
      self, idx: Union[jax.Array, Sequence[jax.Array]], o: Any
  ) -> Any:
    return tree_map(lambda x, y: x.at[idx].set(y), self, o)

  def index_sum(
      self, idx: Union[jax.Array, Sequence[jax.Array]], o: Any
  ) -> Any:
    return tree_map(lambda x, y: x.at[idx].add(y), self, o)

  def vmap(self, in_axes=0, out_axes=0):
    """Returns an object that vmaps each follow-on instance method call."""

    # TODO: i think this is kinda handy, but maybe too clever?

    outer_self = self

    class VmapField:
      """Returns instance method calls as vmapped."""

      def __init__(self, in_axes, out_axes):
        self.in_axes = [in_axes]
        self.out_axes = [out_axes]

      def vmap(self, in_axes=0, out_axes=0):
        self.in_axes.append(in_axes)
        self.out_axes.append(out_axes)
        return self

      def __getattr__(self, attr):
        fun = getattr(outer_self.__class__, attr)
        # load the stack from the bottom up
        vmap_order = reversed(list(zip(self.in_axes, self.out_axes)))
        for in_axes, out_axes in vmap_order:
          fun = vmap(fun, in_axes=in_axes, out_axes=out_axes)
        fun = functools.partial(fun, outer_self)
        return fun

    return VmapField(in_axes, out_axes)

  def tree_replace(
      self, params: Dict[str, Optional[jax.typing.ArrayLike]]
  ) -> 'Base':
    """Creates a new object with parameters set.

    Args:
      params: a dictionary of key value pairs to replace

    Returns:
      data clas with new values

    Example:
      If a system has 3 links, the following code replaces the mass
      of each link in the System:
      >>> sys = sys.tree_replace(
      >>>     {'link.inertia.mass', jp.array([1.0, 1.2, 1.3])})
    """
    new = self
    for k, v in params.items():
      new = _tree_replace(new, k.split('.'), v)
    return new

  @property
  def T(self):  # pylint:disable=invalid-name
    return tree_map(lambda x: x.T, self)

def _tree_replace(
    base: Base,
    attr: Sequence[str],
    val: Optional[jax.typing.ArrayLike],
) -> Base:
  """Sets attributes in a struct.dataclass with values."""
  if not attr:
    return base

  # special case for List attribute
  if len(attr) > 1 and isinstance(getattr(base, attr[0]), list):
    lst = copy.deepcopy(getattr(base, attr[0]))

    for i, g in enumerate(lst):
      if not hasattr(g, attr[1]):
        continue
      v = val if not hasattr(val, '__iter__') else val[i]
      lst[i] = _tree_replace(g, attr[1:], v)

    return base.replace(**{attr[0]: lst})

  if len(attr) == 1:
    return base.replace(**{attr[0]: val})

  return base.replace(
      **{attr[0]: _tree_replace(getattr(base, attr[0]), attr[1:], val)}
  )

@struct.dataclass
class State(Base):
  """A minimal state class (only containing mjx.Data).

  Args:
    pipeline_state: the physics state, mjx.Data
  """

  data: mjx.Data

In [108]:
def load_params(param_path: Text) -> Dict:
    with open(param_path, "rb") as file:
        params = yaml.safe_load(file)
    return params

params = load_params("params/params.yaml")
model = mujoco.MjModel.from_xml_path(params["XML_PATH"])
mjdata = mujoco.MjData(model)
model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
model.opt.iterations = 1
model.opt.ls_iterations = 1

mjx_model = mjx.device_put(model)

"""takes state class and ctrl (action) vector, returns next step's state"""
@jax.vmap
def single_batch_step(state, ctrl):
    data = state.data
    data = data.replace(ctrl=ctrl)
    data = mjx.step(mjx_model, data)
    state.replace(data=data)
    return state

def serial_step(vel):
    data = mujoco.MjData(model)
    print(data.qpos)
    data.qvel[0] = 0
    # qvel[0] = vel
    # data = data.replace(qvel=qvel)
    mujoco.mj_step(model, data)
    
    return data.qpos

def serial_step_mjx(vel):
    mjx_data = mjx.make_data(mjx_model)    
    print(mjx_data.qpos)
    qvel = mjx_data.qvel.at[0].set(vel)
    mjx_data = mjx_data.replace(qvel=qvel)
    mjx_data = mjx.step(mjx_model, mjx_data)
    # mjx.forward(model, mjx_data)
    
    return mjx_data.qpos

In [109]:
process_count = jax.process_count()
process_id = jax.process_index()
local_device_count = jax.local_device_count()
local_devices_to_use = local_device_count
print(
    f"Device count: {jax.device_count()}, process count: "
    f"{process_count} (id {process_id}), local device count: "
    f"{local_device_count}, devices to be used count: {local_devices_to_use}")
device_count = local_devices_to_use * process_count

Device count: 1, process count: 1 (id 0), local device count: 1, devices to be used count: 1


In [110]:
n_envs_small = 1
n_envs_large = 2048
key = random.PRNGKey(0)
small_ctrl = random.uniform(key, shape=(n_envs_small, mjx_model.nu))
large_ctrl = random.uniform(key, shape=(n_envs_large, mjx_model.nu))
blah = random.uniform(key, shape=(n_envs_large, 1))

def reset(val: int) -> State:
    """Resets the environment to an initial state."""
    x = val/2
    data = mjx.make_data(mjx_model)
    data = data.replace(qvel=jnp.full(mjx_model.nv, x))
    data = mjx.forward(mjx_model, data)
    return State(data), x

In [113]:
reset_fn = jax.jit(jax.vmap(reset))
# blah = jnp.full((n_envs_large, 1), 0)
print(blah.shape, blah[0])

# blah = jnp.reshape(blah,
#                          (local_devices_to_use, -1) + blah.shape[1:])
# print(blah.shape)
env_states, x = reset_fn(blah)
print(type(x), type(env_states))


(2048, 1) [0.01946783]
<class 'jaxlib.xla_extension.ArrayImpl'> <class '__main__.State'>


### one step no scan inside step function

In [None]:
start_time = time.time()

jit_single_batch_step = jit(single_batch_step)
jit_single_batch_step(large_batch_state, large_ctrl)
prev = time.time()
print(f"initial execution time: {prev - start_time}")
iters = 25
for _ in range(iters):
    jit_single_batch_step(large_ctrl)
    print(f"{time.time()-prev}")
    prev = time.time()
    
print(f"Steps completed: {iters * n_envs_large}")

In [42]:
start_time = time.time()

jit_single_batch_step = jit(single_batch_step)
jit_single_batch_step(large_ctrl)
prev = time.time()
print(f"initial execution time: {prev - start_time}")
iters = 25
for _ in range(iters):
    jit_single_batch_step(large_ctrl)
    print(f"{time.time()-prev}")
    prev = time.time()
    
print(f"Steps completed: {iters * n_envs_large}")

initial execution time: 11.462520122528076
0.008361577987670898
0.006643056869506836
0.007374763488769531
0.00690460205078125
0.006705284118652344
0.007718801498413086
0.006632089614868164
0.008097410202026367
0.006530046463012695
0.007310390472412109
0.006793498992919922
0.006508588790893555
0.0077953338623046875
0.006516695022583008
0.007936954498291016
0.0061798095703125
0.005064249038696289
0.005731105804443359
0.004906415939331055
0.004903316497802734
0.0060842037200927734
0.0050351619720458984
0.004934787750244141
0.006151437759399414
0.00493621826171875
Steps completed: 51200


In [31]:
# this is not working properly..
@jax.vmap
def single_batch_step_for(ctrl):
    mjx_data = mjx.make_data(mjx_model)
    mjx_data = mjx_data.replace(ctrl=ctrl)
    mjx.step(mjx_model, mjx_data)
    return ctrl

def loopfun(iters):
    return jax.lax.fori_loop(0, iters, lambda i, x: single_batch_step_for(x), large_ctrl)
    
jit_loopfun = jit(loopfun)

start = time.time()
print(jit_loopfun(50))
one = time.time()
print(f"one: {one-start}")
print(jit_loopfun(50))
print(f"two: {time.time()-one}")


[[0.6169447  0.7502353  0.3459301  ... 0.51863337 0.58932376 0.55449045]
 [0.47548127 0.7241138  0.3235253  ... 0.24412751 0.3680569  0.7035271 ]
 [0.6749189  0.26419532 0.22431946 ... 0.65006864 0.8245549  0.13341737]
 ...
 [0.2089827  0.32239807 0.6633911  ... 0.35540724 0.5210892  0.6393193 ]
 [0.3341124  0.7145281  0.17795193 ... 0.37987053 0.28285062 0.5811237 ]
 [0.7939482  0.67065954 0.08375716 ... 0.11895394 0.99804735 0.4388287 ]]
one: 46.03626227378845
[[0.6169447  0.7502353  0.3459301  ... 0.51863337 0.58932376 0.55449045]
 [0.47548127 0.7241138  0.3235253  ... 0.24412751 0.3680569  0.7035271 ]
 [0.6749189  0.26419532 0.22431946 ... 0.65006864 0.8245549  0.13341737]
 ...
 [0.2089827  0.32239807 0.6633911  ... 0.35540724 0.5210892  0.6393193 ]
 [0.3341124  0.7145281  0.17795193 ... 0.37987053 0.28285062 0.5811237 ]
 [0.7939482  0.67065954 0.08375716 ... 0.11895394 0.99804735 0.4388287 ]]
two: 0.006846427917480469


# a different approach

The vectorized function takes multiple sequential steps instead of just one

In [25]:
def take_steps(ctrl, steps, mjx_model):
    mjx_data = mjx.make_data(mjx_model)
    mjx_data = mjx_data.replace(ctrl=ctrl)
    def f(data, _):
      return (
          mjx.step(mjx_model, data),
          None,
      )
      
    mjx_data, _ = jax.lax.scan(f, mjx_data, (), steps)
    return mjx_data.qpos

In [26]:
start_time = time.time()
n_envs_small = 1
n_envs_large = 512
steps = 10

batched_steps = vmap(lambda ctrl: take_steps(ctrl, steps, mjx_model), in_axes=0)

jit_batch_step = jit(batched_steps)

batch_end_data = jit_batch_step(small_ctrl)
prev = time.time()
print(f"initial execution time: {prev - start_time}")
def looper():
    prev = time.time()
    for _ in range(5):
        batch_end_data = jit_batch_step(large_ctrl)
        print(f"{time.time()-prev}")
        prev = time.time()
jit_looper = jit(looper)
jit_looper()

initial execution time: 97.75418162345886
35.59422421455383
0.0002770423889160156
0.00016260147094726562
0.000225067138671875
0.00015163421630859375


In [32]:
5120/.00015

34133333.333333336