In [4]:
import mujoco
from mujoco import mjx
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import device_put
import numpy as np 
import yaml
from typing import List, Dict, Text
import mediapy as media

def load_params(param_path: Text) -> Dict:
    """Read a YAML parameter file and return its parsed contents.

    Args:
        param_path: Filesystem path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` yields for the document — normally a dict
        when the top-level YAML node is a mapping.
    """
    with open(param_path, "rb") as stream:
        # safe_load only builds plain Python objects (no arbitrary tags).
        return yaml.safe_load(stream)

# Read the run configuration, then load the MuJoCo model file it names.
params = load_params(param_path="params/params.yaml")
model = mujoco.MjModel.from_xml_path(params["XML_PATH"])

@jax.vmap
def batched_step(vel):
    """Step one fresh MJX state per element of ``vel`` and return qpos[0].

    ``jax.vmap`` maps this over the leading axis of ``vel``. For each entry a
    new state is created from the module-level ``mjx_model`` (which must be
    assigned before this is first called), its first generalized velocity is
    set to that entry, and a single ``mjx.step`` is taken.

    Returns:
        The first generalized position after the step, one value per element
        of ``vel``.
    """
    initial = mjx.make_data(mjx_model)
    initial = initial.replace(qvel=initial.qvel.at[0].set(vel))
    stepped = mjx.step(mjx_model, initial)
    return stepped.qpos[0]

def serial_step(vel):
    """Step a fresh CPU MuJoCo state once with ``qvel[0]`` set to ``vel``.

    Serial, single-instance counterpart of ``batched_step``: a new
    ``mujoco.MjData`` is built from the module-level ``model``, its first
    generalized velocity is set, and one ``mujoco.mj_step`` is taken.

    Args:
        vel: Value written into the first generalized-velocity slot before
            stepping.

    Returns:
        A copy of the full ``qpos`` vector after the step.
    """
    data = mujoco.MjData(model)
    # Apply the requested initial velocity; a hard-coded 0 here would make
    # the ``vel`` argument a no-op.
    data.qvel[0] = vel
    mujoco.mj_step(model, data)
    # Copy so the returned array does not alias the local MjData's buffer.
    return data.qpos.copy()

def serial_step_mjx(vel, mjx_data):
    """Advance a single (unbatched) MJX state by one step after setting qvel[0].

    Args:
        vel: Value written into the first generalized-velocity slot.
        mjx_data: The ``mjx.Data`` to step. It is not modified in place; a
            new copy with the updated ``qvel`` is built via ``replace``.

    Returns:
        The ``mjx.Data`` returned by ``mjx.step`` using the module-level
        ``mjx_model``.
    """
    updated = mjx_data.replace(qvel=mjx_data.qvel.at[0].set(vel))
    return mjx.step(mjx_model, updated)

In [7]:
# Place the model on the default JAX device, then step five copies in
# parallel, each starting with a different qvel[0] evenly spaced in [0, 0.5].
mjx_model = mjx.device_put(model)
vel = jnp.linspace(0.0, 0.5, num=5)
pos = jax.jit(batched_step)(vel)

print(pos)

[0.03094968 0.03104968 0.03114968 0.03124968 0.03134968]


In [8]:
import time

# Benchmark: N_TIMING_ITERS batched calls, each stepping TIMING_BATCH_SIZE
# independent copies of the model once.
N_TIMING_ITERS = 100
TIMING_BATCH_SIZE = 1000

# Inputs, device model and the jitted function are loop-invariant: build them
# once so the timer measures stepping, not setup / re-wrapping overhead.
vel = jnp.linspace(0.0, 0.2, TIMING_BATCH_SIZE)
mjx_model = mjx.device_put(model)
# NOTE(review): `backend=` pins execution to GPU and fails without one.
jit_batched_step = jax.jit(batched_step, backend="gpu")

# Warm-up call so XLA compilation is excluded from the measurement.
jit_batched_step(vel).block_until_ready()

start_time = time.perf_counter()
for _ in range(N_TIMING_ITERS):
    pos = jit_batched_step(vel)
# JAX dispatches asynchronously; wait for the final result so the timer
# covers the device work rather than just the dispatch.
pos.block_until_ready()
end_time = time.perf_counter()

print(f"Time to complete {N_TIMING_ITERS * TIMING_BATCH_SIZE:,} steps: {end_time - start_time}")


In [None]:
renderer = mujoco.Renderer(model)


def get_image(mjx_model, mjx_data, camera: str) -> np.ndarray:
    """Render one frame of an MJX state from the named camera.

    The device-side ``mjx_data`` is copied into a fresh host-side
    ``mujoco.MjData`` built from the module-level ``model``. ``mj_forward``
    then recomputes derived quantities, and the shared module-level
    ``renderer`` draws the scene.

    Args:
        mjx_model: Not used by the body; the host-side global ``model`` is
            what gets rendered.
        mjx_data: Device-side state to visualize.
        camera: Name of the model camera to render from.

    Returns:
        The image produced by ``renderer.render()``.
    """
    host_data = mujoco.MjData(model)
    # Pull the device state into the host MjData object.
    mjx.device_get_into(host_data, mjx_data)
    mujoco.mj_forward(model, host_data)
    renderer.update_scene(host_data, camera=camera)
    return renderer.render()

In [None]:
# Initialize the device-side model and state.
mjx_model = mjx.device_put(model)
mjx_data = mjx.make_data(mjx_model)

# Compile the MJX step once; the model is passed explicitly on every call.
jit_step = jax.jit(mjx.step)

# NOTE(review): assumes the model XML defines a camera named "side" — confirm.
state = mjx_data
rollout = [state]
images = [get_image(mjx_model, state, camera='side')]

# Constant control applied on every step (loop-invariant, so built once).
ctrl = -0.1 * jnp.ones(mjx_model.nu)

# Grab a 10-step trajectory, rendering a frame after each step.
for i in range(10):
    state = jit_step(mjx_model, state.replace(ctrl=ctrl))
    rollout.append(state)
    images.append(get_image(mjx_model, state, camera='side'))

# One physics step per frame, so play back at the simulation rate.
media.show_video(images, fps=1.0 / model.opt.timestep)