# Imports

In [1]:
import os
import subprocess

from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os


import jax
from jax import numpy as jp
import numpy as np
from matplotlib import pyplot as plt
import mediapy as media

import mujoco
from mujoco import mjx

import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List


import mediapy as media
import matplotlib.pyplot as plt

# Compact NumPy printing: 3 decimals, no scientific notation, rows up to 100 characters.
np.set_printoptions(linewidth=100, suppress=True, precision=3)

In [13]:
# Build the MuJoCo model, its simulation state, and an offscreen renderer.
# Only one model is loaded here; to simulate the basic scene instead, point
# this path at "XML/basic.xml".
mj_model = mujoco.MjModel.from_xml_path("XML/robot.xml")

mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)

# Create the MJX counterparts of the model and data for the JAX-based simulation.
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

Below, we print `qpos` from both MuJoCo and MJX. Notice that `qpos` on the MuJoCo `MjData` is a NumPy array living on the CPU, while `qpos` on the MJX `mjx.Data` is a JAX Array living on whichever device JAX is using (GPU or CPU).

In [14]:
# qpos: generalized position of a robot / simulation model.
# Show the values and container type of each copy; for the MJX copy also show its devices.
host_qpos = mj_data.qpos
print(host_qpos, type(host_qpos))
mjx_qpos = mjx_data.qpos
print(mjx_qpos, type(mjx_qpos), mjx_qpos.devices())

[0.    0.    1.    1.    0.    0.    0.    0.1   0.    1.5   0.707 0.    0.707 0.    0.1   0.
 2.    1.    0.    0.    0.   ] <class 'numpy.ndarray'>
[0.    0.    1.    1.    0.    0.    0.    0.1   0.    1.5   0.707 0.    0.707 0.    0.1   0.
 2.    1.    0.    0.    0.   ] <class 'jaxlib.xla_extension.ArrayImpl'> {CpuDevice(id=0)}


## Simulation run via standard MuJoCo

In [17]:
# Rollout settings.
duration = 2  # (seconds)
framerate = 60  # (Hz)

# Rendering options: enable the joint visualization flag.
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

# Reset the simulation state, then step and collect rendered frames.
mujoco.mj_resetData(mj_model, mj_data)
frames = []
while mj_data.time < duration:
    mujoco.mj_step(mj_model, mj_data)
    # Render only once the simulation clock has passed the next frame slot,
    # so about `framerate` frames are kept per simulated second.
    frame_due = mj_data.time * framerate > len(frames)
    if frame_due:
        renderer.update_scene(mj_data, scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)

# Play the collected frames as a video.
media.show_video(frames, fps=framerate)

0
This browser does not support the video tag.


## Simulation run via MuJoCo MJX on GPU/CPU

In [19]:
# JIT-compile the MJX step function with JAX.
jit_step = jax.jit(mjx.step)

duration = 2  # (seconds)
framerate = 60  # (Hz)

# Reset the MuJoCo state and copy it into a fresh MJX data structure.
mujoco.mj_resetData(mj_model, mj_data)
mjx_data = mjx.put_data(mj_model, mj_data)

frames = []
while mjx_data.time < duration:
    mjx_data = jit_step(mjx_model, mjx_data)
    # Render only once the simulation clock has passed the next frame slot.
    if mjx_data.time * framerate > len(frames):
        # Convert the MJX state back to a MuJoCo data object before updating the scene.
        mj_data = mjx.get_data(mj_model, mjx_data)
        renderer.update_scene(mj_data, scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)

# Play the collected frames as a video.
media.show_video(frames, fps=framerate)

0
This browser does not support the video tag.
