In [1]:
import mujoco
from mujoco import mjx
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import device_put
import numpy as np 
import yaml
from typing import List, Dict, Text
import time

def load_params(param_path: Text) -> Dict:
    with open(param_path, "rb") as file:
        params = yaml.safe_load(file)
    return params

params = load_params("params/params.yaml")
model = mujoco.MjModel.from_xml_path(params["XML_PATH"])
model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
model.opt.iterations = 1
model.opt.ls_iterations = 1

mjx_model = mjx.device_put(model)

# minimal example code--this is supposed to work
@jax.vmap
def single_batch_step(ctrl):
    mjx_data = mjx.make_data(mjx_model)
    mjx_data = mjx_data.replace(ctrl=ctrl)
    qpos = mjx.step(mjx_model, mjx_data).qpos
    return qpos

def serial_step(vel):
    data = mujoco.MjData(model)
    print(data.qpos)
    data.qvel[0] = 0
    # qvel[0] = vel
    # data = data.replace(qvel=qvel)
    mujoco.mj_step(model, data)
    
    return data.qpos

def serial_step_mjx(vel):
    mjx_data = mjx.make_data(mjx_model)    
    print(mjx_data.qpos)
    qvel = mjx_data.qvel.at[0].set(vel)
    mjx_data = mjx_data.replace(qvel=qvel)
    mjx_data = mjx.step(mjx_model, mjx_data)
    # mjx.forward(model, mjx_data)
    
    return mjx_data.qpos

def take_steps(ctrl, steps, mjx_model):
    # ctrl = network(obs)
    mjx_data = mjx.make_data(mjx_model)
    mjx_data = mjx_data.replace(ctrl=ctrl)
    def f(data, _):
      return (
          mjx.step(mjx_model, data),
          None,
      )
      
    mjx_data, _ = jax.lax.scan(f, mjx_data, (), steps)
    return mjx_data.qpos



# time single steps multiple times

In [5]:
start_time = time.time()

n_envs_small = 1
n_envs_large = 1400
key = random.PRNGKey(0)
small_ctrl = random.uniform(key, shape=(n_envs_small, mjx_model.nu))
large_ctrl = random.uniform(key, shape=(n_envs_large, mjx_model.nu))



In [6]:
jit_single_batch_step = jit(single_batch_step)
jit_single_batch_step(small_ctrl)
prev = time.time()
print(f"initial execution time: {prev - start_time}")
iters = 25
for _ in range(iters):
    jit_single_batch_step(large_ctrl)
    print(f"{time.time()-prev}")
    prev = time.time()
    
print(f"Steps completed: {iters * n_envs_large}")

initial execution time: 0.13068509101867676
75.88655710220337
0.002124786376953125
0.0013523101806640625
0.0010461807250976562
0.004746913909912109
0.1265885829925537
0.11447930335998535
0.11417794227600098
0.11552143096923828
0.11838054656982422
0.1202096939086914
0.11438846588134766
0.1157233715057373
0.11822628974914551
0.11577534675598145
0.11603307723999023
0.11487078666687012
0.11712050437927246
0.11706089973449707
0.11658263206481934
0.11541748046875
0.1166834831237793
0.1181631088256836
0.11561155319213867
0.11559343338012695
Steps completed: 32500


In [None]:
def loopfun():
    jax.lax.fori_loop(0, iters, single_batch_step, )

In [11]:
start_time = time.time()
n_envs_small = 1
n_envs_large = 512
steps = 10

batched_steps = vmap(lambda ctrl: take_steps(ctrl, steps, mjx_model), in_axes=0)

jit_batch_step = jit(batched_steps)

batch_end_data = jit_batch_step(small_ctrl)
prev = time.time()
print(f"initial execution time: {prev - start_time}")
def looper():
    prev = time.time()
    for _ in range(5):
        batch_end_data = jit_batch_step(large_ctrl)
        print(f"{time.time()-prev}")
        prev = time.time()
jit_looper = jit(looper)
jit_looper()

initial execution time: 94.86681985855103
34.524773836135864
0.0002276897430419922
0.00013828277587890625
0.00015544891357421875
0.00014090538024902344


In [12]:
5120/.00015

34133333.333333336