In [1]:
import mujoco
from mujoco import mjx
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import device_put
import numpy as np 
import yaml
from typing import List, Dict, Text
import time

def load_params(param_path: Text) -> Dict:
    with open(param_path, "rb") as file:
        params = yaml.safe_load(file)
    return params

params = load_params("params/params.yaml")
model = mujoco.MjModel.from_xml_path(params["XML_PATH"])
model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
model.opt.iterations = 3
model.opt.ls_iterations = 3

mjx_model = mjx.device_put(model)

# minimal example code--this is supposed to work
@jax.vmap
def single_batch_step(ctrl):
    mjx_data = mjx.make_data(mjx_model)
    mjx_data = mjx_data.replace(ctrl=ctrl)
    qpos = mjx.step(mjx_model, mjx_data).qpos
    return qpos

def serial_step(vel):
    data = mujoco.MjData(model)
    print(data.qpos)
    data.qvel[0] = 0
    # qvel[0] = vel
    # data = data.replace(qvel=qvel)
    mujoco.mj_step(model, data)
    
    return data.qpos

def serial_step_mjx(vel):
    mjx_data = mjx.make_data(mjx_model)    
    print(mjx_data.qpos)
    qvel = mjx_data.qvel.at[0].set(vel)
    mjx_data = mjx_data.replace(qvel=qvel)
    mjx_data = mjx.step(mjx_model, mjx_data)
    # mjx.forward(model, mjx_data)
    
    return mjx_data.qpos

def take_steps(ctrl, steps, mjx_model):
    # ctrl = network(obs)
    mjx_data = mjx.make_data(mjx_model)
    mjx_data = mjx_data.replace(ctrl=ctrl)
    def f(data, _):
      return (
          mjx.step(mjx_model, data),
          None,
      )
      
    mjx_data, _ = jax.lax.scan(f, mjx_data, (), steps)
    return mjx_data.qpos



# time single steps multiple times

In [2]:
start_time = time.time()

n_envs_small = 1
n_envs_large = 1300
key = random.PRNGKey(0)
small_ctrl = random.uniform(key, shape=(n_envs_small, mjx_model.nu))
large_ctrl = random.uniform(key, shape=(n_envs_large, mjx_model.nu))



In [None]:
jit_single_batch_step = jit(single_batch_step)
jit_single_batch_step(small_ctrl)
prev = time.time()
print(f"initial execution time: {prev - start_time}")
for _ in range(5):
    jit_single_batch_step(large_ctrl)
    print(f"{time.time()-prev}")
    prev = time.time()

In [5]:
start_time = time.time()
n_envs_small = 1
n_envs_large = 128
steps = 100

batched_steps = vmap(lambda ctrl: take_steps(ctrl, steps, mjx_model), in_axes=0)

jit_batch_step = jit(batched_steps)

batch_end_data = jit_batch_step(small_ctrl)
prev = time.time()
print(f"initial execution time: {prev - start_time}")
for _ in range(5):
    batch_end_data = jit_batch_step(large_ctrl)
    print(f"{time.time()-prev}")
    prev = time.time()

initial execution time: 112.37875866889954


2023-11-28 16:50:11.024821: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.88GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 7391266416 bytes.