In [None]:
# Select the EGL backend through MuJoCo's `MUJOCO_GL` environment variable.
%env MUJOCO_GL=egl
# Install ffmpeg through apt only when no `ffmpeg` executable is already on PATH.
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
# Quietly install/upgrade the simulation and media packages used below.
# NOTE(review): `cv2` and `matplotlib` are imported later in this notebook but are not
# listed here; they presumably come preinstalled or as transitive dependencies — confirm.
!python3 -m pip install -Uq brax mediapy mujoco mujoco_mjx

In [None]:
from sys import stderr

# Each dependency is imported inside its own try/except so that a missing package
# produces a short installation hint on stderr instead of a bare traceback.

# NumPy: required (a missing install calls `exit(1)`).
print("Importing NumPy...")
try:
    import numpy as np

    # Shared array print options: 3-digit precision, suppress=True, 100-column lines.
    np.set_printoptions(precision=3, suppress=True, linewidth=100)
except ModuleNotFoundError:
    print(
        "ERROR: NumPy is not installed. It's installable via `pip` as `numpy`.",
        file=stderr,
    )
    exit(1)

# JAX: required (a missing install calls `exit(1)`).
print("Importing JAX...")
try:
    import jax

    # Same print options as NumPy, applied through `jax.numpy`.
    jax.numpy.set_printoptions(precision=3, suppress=True, linewidth=100)
except ModuleNotFoundError:
    print(
        "ERROR: JAX is not installed. It's installable at `https://docs.jax.dev/en/latest/quickstart.html`.",
        file=stderr,
    )
    exit(1)

# MuJoCo: required (a missing install calls `exit(1)`).
print("Importing MuJoCo...")
try:
    import mujoco
except ModuleNotFoundError:
    print(
        "ERROR: MuJoCo is not installed. It's installable via `pip` as `mujoco`.",
        file=stderr,
    )
    exit(1)

# MJX: required (a missing install calls `exit(1)`).
print("Importing MJX...")
try:
    from mujoco import mjx
except ModuleNotFoundError:
    print(
        "ERROR: MJX (JAX rewrite of MuJoCo) is not installed. It's installable via `pip` as `mujoco_mjx`.",
        file=stderr,
    )
    exit(1)

# Brax: required (a missing install calls `exit(1)`).
print("Importing Brax...")
try:
    import brax
except ModuleNotFoundError:
    print(
        "ERROR: Brax is not installed. It's installable via `pip` as `brax`.",
        file=stderr,
    )
    exit(1)

# MediaPy: a missing install only prints the message; execution continues.
# NOTE(review): unlike the imports above there is no `exit(1)` here — presumably
# intentional, since later cells wrap every `media.show_video` call in try/except; confirm.
print("Importing MediaPy...")
try:
    import mediapy as media
except ModuleNotFoundError:
    print(
        "ERROR: MediaPy is not installed. It's installable via `pip` as `mediapy`.",
        file=stderr,
    )

# Matplotlib: a missing install only prints the message; execution continues.
# NOTE(review): no `exit(1)` here either — the training progress callback wraps its
# `plt` calls in try/except, so this looks deliberate; confirm.
print("Importing MatPlotLib...")
try:
    import matplotlib.pyplot as plt
except ModuleNotFoundError:
    print(
        "ERROR: MatPlotLib is not installed. It's installable via `pip` as `matplotlib`.",
        file=stderr,
    )

In [None]:
print("Defining constants...")

# Library imports:
from etils import epath

# NOTE(review): this import is disabled, so the Go1 XML path constants at the bottom of
# this block (which are built from `mjx_env.ROOT_PATH`) are disabled as well.
# from mujoco_playground._src import mjx_env

# Number of legs generated for the robot; drives the per-leg loop in the MJCF cell.
N_LEGS = 3
# Name of a MuJoCo Playground environment; not referenced by the code visible in this notebook.
COMPARABLE_MUJOCO_PLAYGROUND_ENV = "Go1JoystickRoughTerrain"
# Output files of the MJCF-generation cell: the robot, the scene, and a wrapper
# model that includes both.
GENERATED_ROBOT_MJCF_XML_PATH = "generated-robot-mjcf.xml"
GENERATED_SCENE_MJCF_XML_PATH = "generated-scene-mjcf.xml"
GENERATED_MJCF_XML_PATH = "generated-mjcf.xml"
# NOTE(review): not referenced by the visible code — the PPO setup cell passes a literal
# `num_timesteps=20_000_000` instead; confirm which value is intended.
TRAINING_STEPS = 1_000_000

# ROOT_PATH = mjx_env.ROOT_PATH / "locomotion" / "go1"
# FEET_ONLY_FLAT_TERRAIN_XML = ROOT_PATH / "xmls" / "scene_mjx_feetonly_flat_terrain.xml"
# FEET_ONLY_ROUGH_TERRAIN_XML = (
#     ROOT_PATH / "xmls" / "scene_mjx_feetonly_rough_terrain.xml"
# )
# FULL_FLAT_TERRAIN_XML = ROOT_PATH / "xmls" / "scene_mjx_flat_terrain.xml"
# FULL_COLLISIONS_FLAT_TERRAIN_XML = (
#     ROOT_PATH / "xmls" / "scene_mjx_fullcollisions_flat_terrain.xml"
# )

def task_to_xml(task_name: str) -> epath.Path:
    """Map a task name ("flat_terrain" or "rough_terrain") to its scene XML path.

    NOTE(review): both `FEET_ONLY_*_TERRAIN_XML` constants appear in this file only
    as commented-out code. Unless they are defined elsewhere, every call raises
    `NameError` while the lookup table is being built, before `task_name` is
    even looked up.
    """
    return {
        "flat_terrain": FEET_ONLY_FLAT_TERRAIN_XML,
        "rough_terrain": FEET_ONLY_ROUGH_TERRAIN_XML,
    }[task_name]


# One identifier per foot, numbered from 1: "foot #1", "foot #2", ...
# NOTE(review): the MJCF-generation cell names its foot sites/geoms `leg_{i}_foot`,
# not "foot #n" — confirm these names are still meant to be used.
FEET_SITES = [f"foot #{number}" for number in range(1, N_LEGS + 1)]

# Foot geoms reuse the site names (this is the same list object, not a copy).
FEET_GEOMS = FEET_SITES

# Position-sensor name derived from each foot site name.
FEET_POS_SENSOR = [site + " pos" for site in FEET_SITES]


# Name given to the robot's root body in the generated MJCF.
ROOT_BODY = "trunk"  # "torso"

# Sensor names.
UPVECTOR_SENSOR = "upvector"
GLOBAL_LINVEL_SENSOR = "global_linvel"
GLOBAL_ANGVEL_SENSOR = "global_angvel"
LOCAL_LINVEL_SENSOR = "local_linvel"
ACCELEROMETER_SENSOR = "accelerometer"
GYRO_SENSOR = "gyro"


def inches(murica):
    """Convert a length given in inches to meters (1 in = 0.0254 m)."""
    meters_per_inch = 0.0254
    return murica * meters_per_inch


def grams(g):
    """Convert a mass given in grams to kilograms (1 g = 0.001 kg)."""
    kilograms_per_gram = 0.001
    return g * kilograms_per_gram


def kg_cm(torque):
    """Convert a torque given in kgf*cm to newton-meters (1 kgf*cm = 0.0980665 N*m)."""
    newton_meters_per_kgf_cm = 0.0980665
    return torque * newton_meters_per_kgf_cm


# Geometry, specified in inches and converted by `inches()` (factor 0.0254, i.e. meters).
SPHERE_RADIUS = inches(2.0)
LENGTH_HIP_TO_KNEE = inches(2.0)
LENGTH_KNEE_TO_FOOT = inches(4.0)

# Joint angle limits in degrees.
# NOTE(review): none of these six limits is referenced by the MJCF-generation code
# visible in this notebook, so the generated joints carry no `range` — confirm intent.
HIP_MIN_DEGREES = -90
HIP_MAX_DEGREES = 90
KNEE_MIN_DEGREES = -60
KNEE_MAX_DEGREES = 60
LEG_YAW_MIN_DEGREES = -30
LEG_YAW_MAX_DEGREES = 30

# Part dimensions. LEG_RADIUS is the default geom size and FOOT_RADIUS sizes the foot
# spheres; PUSH_ROD_SPACING and KNEE_RADIUS are not referenced by the visible code.
PUSH_ROD_SPACING = inches(0.5)
LEG_RADIUS = inches(0.1)
FOOT_RADIUS = inches(0.125)
KNEE_RADIUS = LEG_RADIUS * 1.5

# Masses, specified in grams and converted by `grams()` (factor 0.001, i.e. kilograms).
# LEG_DENSITY is a mass per unit length: it is multiplied by a segment length to get
# that segment's mass.
LEG_DENSITY = grams(1.0) / inches(1.0)
SPHERE_MASS = grams(10.0)
SERVO_MASS = grams(19.0)
FOOT_CAP_MASS = grams(1.0)
# Applied as a fraction: the sphere mass is scaled by (1.0 + this value), so 0.1 adds 10%.
EXTRA_SPHERE_MASS_PERCENTAGE_IM_FORGETTING = 0.1

# Not referenced by the code visible in this notebook.
PUPIL_SIZE_RELATIVE = None  # 0.75
PUPIL_SIZE_PROTRUSION = 0.05

# Position-actuator defaults: SERVO_TORQUE_NM bounds `forcerange` symmetrically and
# SERVO_KP is the actuator's `kp`.
SERVO_TORQUE_NM = 5  # kg_cm(2.7)
SERVO_KP = 21.1  # from <https://github.com/google-deepmind/mujoco/issues/1075>: see line <https://github.com/google-deepmind/mujoco_menagerie/blob/cfd91c5605e90f0b77860ae2278ff107366acc87/robotis_op3/op3.xml#L62>

# Joint defaults; a value of None means the attribute is left out of the generated
# <default><joint> element.
JOINT_DAMPING = 1.084
JOINT_STIFFNESS = None
JOINT_ARMATURE = 0.045
JOINT_FRICTION_LOSS = 0.03

In [None]:
# Based on the ROBOTIS op3 spec: <https://github.com/google-deepmind/mujoco_menagerie/blob/main/robotis_op3/op3.xml>


# Library imports:
import xml.etree.ElementTree as XML


print("Generating MJCF XML files...")


# Root <mujoco> element of the robot model:
robot = XML.Element("mujoco", {"model": "eye-robot"})


# <default> section: attribute values shared by the geoms, joints and position
# actuators declared later.
default = XML.SubElement(robot, "default")

# Geoms default to leg-thickness capsules with contype/conaffinity set to 0.
XML.SubElement(
    default,
    "geom",
    {
        "type": "capsule",
        "size": f"{LEG_RADIUS}",
        "solref": ".004 1",
        "contype": "0",
        "conaffinity": "0",
    },
)

# Only joint properties that are configured (not None) become attributes.
joint_kwargs = {
    attribute: f"{value}"
    for attribute, value in (
        ("damping", JOINT_DAMPING),
        ("stiffness", JOINT_STIFFNESS),
        ("armature", JOINT_ARMATURE),
        ("frictionloss", JOINT_FRICTION_LOSS),
    )
    if value is not None
}
XML.SubElement(default, "joint", joint_kwargs)
del joint_kwargs

# Position actuators share one gain and a symmetric force limit.
XML.SubElement(
    default,
    "position",
    {
        "kp": f"{SERVO_KP}",
        "forcerange": f"{-SERVO_TORQUE_NM} {SERVO_TORQUE_NM}",
    },
)


# Top-level declaration for physical objects:
worldbody = XML.SubElement(robot, "worldbody")

# Root body, placed one lower-leg length above the origin.
root = XML.SubElement(
    worldbody, "body", name=ROOT_BODY, pos=f"0 0 {LENGTH_KNEE_TO_FOOT}"
)

XML.SubElement(root, "freejoint")

# The central sphere is modeled as carrying every servo. Each leg is given three
# actuated joints (hip, yaw, knee) by the leg loop, so the servo count follows
# N_LEGS instead of being frozen at 9.
SERVOS_PER_LEG = 3
sphere_mass = (SPHERE_MASS + SERVOS_PER_LEG * N_LEGS * SERVO_MASS) * (
    1.0 + EXTRA_SPHERE_MASS_PERCENTAGE_IM_FORGETTING
)

sphere = XML.SubElement(root, "body", name="sphere")
XML.SubElement(
    sphere,
    "geom",
    name="sphere",
    type="sphere",
    size=f"{SPHERE_RADIUS}",
    rgba="1 1 1 1",
    mass=f"{sphere_mass}",
)
# Site that the gyro and accelerometer sensors below attach to.
XML.SubElement(
    sphere,
    "site",
    name="imu",
)


# Containers that the leg loop fills in with per-leg contact pairs and actuators:
contact = XML.SubElement(robot, "contact")
actuator = XML.SubElement(robot, "actuator")


# Sensors: the IMU here, plus one force sensor per foot added by the leg loop.
sensor = XML.SubElement(robot, "sensor")
XML.SubElement(sensor, "gyro", site="imu", name="gyro")
XML.SubElement(sensor, "accelerometer", site="imu", name="accelerometer")


# Build every leg: mount frame -> upper leg (hip + yaw joints) -> lower leg (knee
# joint) -> foot, plus that leg's floor contact pair, actuators and foot force sensor.
# NOTE(review): the *_MIN/MAX_DEGREES constants are not applied to these joints;
# confirm whether joint ranges were meant to be emitted here.
for i in range(N_LEGS):
    # Spread the legs evenly around the trunk's z axis.
    leg_mount_to_center = XML.SubElement(
        root,
        "body",
        name=f"leg_{i}_mount_to_center",
        euler=f"0 0 {360 * i / N_LEGS}",
    )
    # The mount sits half a sphere radius out from the center along the rotated x axis.
    leg_mount = XML.SubElement(
        leg_mount_to_center,
        "body",
        name=f"leg_{i}_mount",
        pos=f"{0.5 * SPHERE_RADIUS} 0 0",
    )
    leg = XML.SubElement(leg_mount, "body", name=f"leg_{i}")
    XML.SubElement(leg, "joint", axis="0 1 0", name=f"leg_{i}_hip_joint")
    XML.SubElement(leg, "joint", axis="0 0 1", name=f"leg_{i}_yaw_joint")
    hip_to_knee = XML.SubElement(
        leg,
        "geom",
        name=f"leg_{i}_hip_to_knee",
        fromto=f"0 0 0 {LENGTH_HIP_TO_KNEE} 0 0",
        mass=f"{LEG_DENSITY * LENGTH_HIP_TO_KNEE}",
        rgba="1 0 0 1",
    )
    # The lower leg starts at the knee (end of the upper segment), rotated about y.
    lower_leg = XML.SubElement(
        leg,
        "body",
        name=f"leg_{i}_lower",
        pos=f"{LENGTH_HIP_TO_KNEE} 0 0",
        euler="0 90 0",
    )
    XML.SubElement(lower_leg, "joint", axis="0 1 0", name=f"leg_{i}_knee_joint")
    knee_to_foot = XML.SubElement(
        lower_leg,
        "geom",
        name=f"leg_{i}_knee_to_foot",
        fromto=f"0 0 0 {LENGTH_KNEE_TO_FOOT} 0 0",
        # Mass follows this segment's own length (it previously reused the
        # hip-to-knee length, under-weighting the longer lower leg).
        mass=f"{LEG_DENSITY * LENGTH_KNEE_TO_FOOT}",
        rgba="1 0 0 1",
    )
    foot = XML.SubElement(
        lower_leg,
        "body",
        name=f"leg_{i}_foot",
        pos=f"{LENGTH_KNEE_TO_FOOT} 0 0",
    )
    XML.SubElement(
        foot,
        "geom",
        type="sphere",
        size=f"{FOOT_RADIUS}",
        name=f"leg_{i}_foot",
        mass=f"{FOOT_CAP_MASS}",
        rgba="1 0 0 1",
    )
    # Site with the same name as the foot geom; the force sensor below attaches to it.
    XML.SubElement(
        foot,
        "site",
        name=f"leg_{i}_foot",
    )

    # Explicit contact pair between this foot and the geom named "floor"
    # (the scene file declares that geom).
    XML.SubElement(contact, "pair", geom1=f"leg_{i}_foot", geom2="floor")

    # One position actuator per joint, named after the joint it drives.
    XML.SubElement(
        actuator, "position", name=f"leg_{i}_hip_joint", joint=f"leg_{i}_hip_joint"
    )
    XML.SubElement(
        actuator, "position", name=f"leg_{i}_knee_joint", joint=f"leg_{i}_knee_joint"
    )
    XML.SubElement(
        actuator, "position", name=f"leg_{i}_yaw_joint", joint=f"leg_{i}_yaw_joint"
    )

    XML.SubElement(sensor, "force", site=f"leg_{i}_foot", name=f"leg_{i}_foot_fsr")


# Save the MJCF XML for the robot:
XML.indent(robot)
with open(GENERATED_ROBOT_MJCF_XML_PATH, "wb") as file:
    XML.ElementTree(robot).write(file, encoding="utf-8")
del robot


# Root <mujoco> element of the scene (lighting, sky and floor); it uses the same
# model name as the robot file.
scene = XML.Element("mujoco", {"model": "eye-robot"})


# Focus on the center of the body. TODO: what's `extent`?
XML.SubElement(
    scene,
    "statistic",
    {"center": f"0 0 {LENGTH_KNEE_TO_FOOT}", "extent": "0.6"},
)


# Rendering settings:
visual = XML.SubElement(scene, "visual")

# Headlight (light from the active camera):
XML.SubElement(
    visual,
    "headlight",
    {
        "diffuse": "0.6 0.6 0.6",
        "ambient": "0.3 0.3 0.3",
        "specular": "0 0 0",
    },
)

# Haze at the render limit:
XML.SubElement(visual, "rgba", {"haze": "0.15 0.25 0.35 1"})

# Global camera orientation:
XML.SubElement(visual, "global", {"azimuth": "160", "elevation": "-20"})


# Textures & materials:
asset = XML.SubElement(scene, "asset")

# Sky texture:
XML.SubElement(
    asset,
    "texture",
    {
        "type": "skybox",
        "builtin": "gradient",
        "rgb1": "0.3 0.5 0.7",
        "rgb2": "0 0 0",
        "width": "512",
        "height": "3072",
    },
)

# Ground plane/floor grid texture:
XML.SubElement(
    asset,
    "texture",
    {
        "type": "2d",
        "name": "groundplane",
        "builtin": "checker",
        "mark": "edge",
        "rgb1": "0.2 0.3 0.4",
        "rgb2": "0.1 0.2 0.3",
        "markrgb": "0.8 0.8 0.8",
        "width": "300",
        "height": "300",
    },
)

# Ground plane/floor material:
XML.SubElement(
    asset,
    "material",
    {
        "name": "groundplane",
        "texture": "groundplane",
        "texuniform": "true",
        "texrepeat": "5 5",
        "reflectance": "0.2",
    },
)


# Then use the textures & materials we just defined on physical things:
worldbody = XML.SubElement(scene, "worldbody")

# Sunlight:
XML.SubElement(
    worldbody,
    "light",
    {"pos": "0 0 1.5", "dir": "0 0 -1", "directional": "true"},
)

# Floor plane (the geom named "floor"), drawn with the groundplane material:
XML.SubElement(
    worldbody,
    "geom",
    {
        "name": "floor",
        "pos": "0 0 -0.05",
        "size": "0 0 0.05",
        "type": "plane",
        "material": "groundplane",
    },
)


# Save the MJCF XML for the scene:
XML.indent(scene)
with open(GENERATED_SCENE_MJCF_XML_PATH, "wb") as file:
    XML.ElementTree(scene).write(file, encoding="utf-8")
del scene


# Wrapper model: nothing of its own, just <include> elements pulling in the scene
# file first and the robot file second.
combined = XML.Element("mujoco", {"model": "eye-robot"})
XML.SubElement(combined, "include", {"file": GENERATED_SCENE_MJCF_XML_PATH})
XML.SubElement(combined, "include", {"file": GENERATED_ROBOT_MJCF_XML_PATH})


# Save the MJCF XML for the combined model:
XML.indent(combined)
with open(GENERATED_MJCF_XML_PATH, "wb") as file:
    XML.ElementTree(combined).write(file, encoding="utf-8")
del combined

In [None]:
print("Library imports...")
from brax import envs
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf, model
from brax.training.agents.ppo import train as ppo
import cv2
from datetime import datetime
import functools
import jax
from jax import numpy as jp
import mujoco
from mujoco import mjx

In [None]:
# Compile the combined MJCF file into a MuJoCo model, allocate its simulation data,
# and create the renderer that the visualization cells below reuse.
print("Building a MuJoCo model...")
mj_model = mujoco.MjModel.from_xml_path(GENERATED_MJCF_XML_PATH)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)

In [None]:
# Create the MJX counterparts of the MuJoCo model and data via `mjx.put_model` / `mjx.put_data`.
print("Porting model to MJX...")
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

In [None]:
# Render options with joint visualization switched on:
print("Visualing with MuJoCo...")
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 3.8  # (seconds)
framerate = 60  # (Hz)

# Simulate with MuJoCo, keeping every rendered frame in memory and also writing it
# to an MP4 file. The writer's frame size is (width, height), taken from a render.
video_filename = "mj.mp4"
frames = []
video_writer = cv2.VideoWriter(
    video_filename,
    cv2.VideoWriter_fourcc(*"mp4v"),
    framerate,
    renderer.render().shape[:2][::-1],
)
try:
    mujoco.mj_resetData(mj_model, mj_data)
    while mj_data.time < duration:
        mujoco.mj_step(mj_model, mj_data)
        # Render only when the simulation clock has passed the next frame time.
        if len(frames) < mj_data.time * framerate:
            renderer.update_scene(mj_data, scene_option=scene_option)
            pixels = renderer.render()
            frames.append(pixels)
            # Reverse the channel axis before handing the frame to OpenCV.
            video_writer.write(pixels[..., ::-1])
finally:
    # Finalize the video file even if stepping or rendering raised.
    video_writer.release()

# Display the video inline. This is best-effort, but only ordinary exceptions are
# swallowed so that KeyboardInterrupt/SystemExit still propagate.
try:
    media.show_video(frames, fps=framerate)
except Exception as error:
    print("`media` failed. If you're not in a Jupyter notebook, this is expected.")
    print(f"  ({type(error).__name__}: {error})")
del frames

In [None]:
# Same rollout as the MuJoCo visualization, but stepped with a JIT-compiled `mjx.step`.
# Relies on `duration`, `framerate`, `scene_option` and `renderer` from earlier cells.
print("Visualing with MJX...")
jit_step = jax.jit(mjx.step)

video_filename = "mjx.mp4"
frames = []
video_writer = cv2.VideoWriter(
    video_filename,
    cv2.VideoWriter_fourcc(*"mp4v"),
    framerate,
    renderer.render().shape[:2][::-1],
)
try:
    # Start again from a reset MuJoCo state and copy it into MJX.
    mujoco.mj_resetData(mj_model, mj_data)
    mjx_data = mjx.put_data(mj_model, mj_data)
    while mjx_data.time < duration:
        mjx_data = jit_step(mjx_model, mjx_data)
        # Render only when the simulation clock has passed the next frame time.
        if len(frames) < mjx_data.time * framerate:
            # Rendering goes through MuJoCo, so convert the MJX data back first.
            mj_data = mjx.get_data(mj_model, mjx_data)
            renderer.update_scene(mj_data, scene_option=scene_option)
            pixels = renderer.render()
            frames.append(pixels)
            # Reverse the channel axis before handing the frame to OpenCV.
            video_writer.write(pixels[..., ::-1])
finally:
    # Finalize the video file even if stepping or rendering raised.
    video_writer.release()

# Best-effort inline display; only ordinary exceptions are swallowed so that
# KeyboardInterrupt/SystemExit still propagate.
try:
    media.show_video(frames, fps=framerate)
except Exception as error:
    print("`media` failed. If you're not in a Jupyter notebook, this is expected.")
    print(f"  ({type(error).__name__}: {error})")
del frames

In [None]:
print("Randomizing across batches...")
# 4096 independent PRNG keys, all derived from seed 0.
rng = jax.random.split(jax.random.PRNGKey(0), 4096)
# One copy of `mjx_data` per key, each with uniformly random `qpos` of length `nq`.
batch = jax.vmap(
    lambda key: mjx_data.replace(
        qpos=jax.random.uniform(key, (mjx_model.nq,)),
    )
)(rng)

# Step the whole batch at once: the model is shared, the data is mapped over axis 0.
jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
batch = jit_step(mjx_model, batch)

In [None]:
# Convert the batched MJX data into MuJoCo data with `mjx.get_data`.
# NOTE(review): the message reads "to MJX", but `mjx.get_data` is the same call the
# visualization cell uses to go from MJX data back to MuJoCo data — confirm the wording.
print("Porting randomized batches to MJX...")
batched_mj_data = mjx.get_data(mj_model, batch)

In [None]:
print("Defining a MuJoCo/Brax pipeline environment...")


class Humanoid(PipelineEnv):
    """Brax pipeline environment (MJX backend) built from the generated MJCF model.

    Each step's reward is the weighted forward (x) velocity of `subtree_com[1]`, plus
    an alive bonus, minus a quadratic control cost.
    """

    def __init__(
        self,
        forward_reward_weight=1.25,
        ctrl_cost_weight=0.1,
        healthy_reward=5.0,
        terminate_when_unhealthy=True,
        reset_noise_scale=1e-2,
        exclude_current_positions_from_observation=True,
        **kwargs,
    ):
        """Load the generated model, set solver options and store the reward settings.

        Extra keyword arguments are forwarded to `PipelineEnv.__init__`; `n_frames`
        defaults to 5 when not supplied and `backend` is always set to "mjx".
        """
        compiled_model = mujoco.MjModel.from_xml_path(GENERATED_MJCF_XML_PATH)
        solver_options = compiled_model.opt
        solver_options.solver = mujoco.mjtSolver.mjSOL_CG
        solver_options.iterations = 6
        solver_options.ls_iterations = 6

        # A caller-supplied `n_frames` wins; otherwise use 5 physics steps per control step.
        kwargs.setdefault("n_frames", 5)
        kwargs["backend"] = "mjx"

        super().__init__(mjcf.load_model(compiled_model), **kwargs)

        self._forward_reward_weight = forward_reward_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._healthy_reward = healthy_reward
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = (
            exclude_current_positions_from_observation
        )

    def reset(self, rng: jp.ndarray) -> State:
        """Return a fresh state: `qpos0` plus uniform noise, with small random velocities."""
        _, position_key, velocity_key = jax.random.split(rng, 3)

        noise = self._reset_noise_scale
        qpos = self.sys.qpos0 + jax.random.uniform(
            position_key, (self.sys.nq,), minval=-noise, maxval=noise
        )
        qvel = jax.random.uniform(
            velocity_key, (self.sys.nv,), minval=-noise, maxval=noise
        )

        data = self.pipeline_init(qpos, qvel)

        # The first observation is computed with an all-zero action.
        obs = self._get_obs(data, jp.zeros(self.sys.nu))
        reward, done, zero = jp.zeros(3)
        metric_names = (
            "forward_reward",
            "reward_linvel",
            "reward_quadctrl",
            "reward_alive",
            "x_position",
            "y_position",
            "distance_from_origin",
            "x_velocity",
            "y_velocity",
        )
        metrics = dict.fromkeys(metric_names, zero)
        return State(data, obs, reward, done, metrics)

    def step(self, state: State, action: jp.ndarray) -> State:
        """Advance the simulation by one control step and compute reward, done and metrics."""
        previous = state.pipeline_state
        data = self.pipeline_step(previous, action)

        # Velocity is a finite difference of `subtree_com[1]` across the step.
        com_before = previous.subtree_com[1]
        com_after = data.subtree_com[1]
        velocity = (com_after - com_before) / self.dt
        forward_reward = self._forward_reward_weight * velocity[0]

        # Unhealthy (0.0) once q[2] drops below the sphere radius, otherwise 1.0.
        is_healthy = jp.where(data.q[2] < SPHERE_RADIUS, 0.0, 1.0)
        terminate = self._terminate_when_unhealthy
        # With termination enabled the bonus is unconditional; otherwise it is
        # only paid while healthy.
        healthy_reward = (
            self._healthy_reward if terminate else self._healthy_reward * is_healthy
        )

        ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

        obs = self._get_obs(data, action)
        reward = forward_reward + healthy_reward - ctrl_cost
        done = 1.0 - is_healthy if terminate else 0.0
        state.metrics.update(
            forward_reward=forward_reward,
            reward_linvel=forward_reward,
            reward_quadctrl=-ctrl_cost,
            reward_alive=healthy_reward,
            x_position=com_after[0],
            y_position=com_after[1],
            distance_from_origin=jp.linalg.norm(com_after),
            x_velocity=velocity[0],
            y_velocity=velocity[1],
        )

        return state.replace(pipeline_state=data, obs=obs, reward=reward, done=done)

    def _get_obs(self, data: mjx.Data, action: jp.ndarray) -> jp.ndarray:
        """Concatenate positions, velocities, inertias, body velocities and actuator forces.

        When `exclude_current_positions_from_observation` is set, the first two
        entries of `qpos` are dropped. `action` is accepted but not used here.
        """
        if self._exclude_current_positions_from_observation:
            position = data.qpos[2:]
        else:
            position = data.qpos

        # external_contact_forces are excluded
        components = [
            position,
            data.qvel,
            data.cinert[1:].ravel(),
            data.cvel[1:].ravel(),
            data.qfrc_actuator,
        ]
        return jp.concatenate(components)


# Register the class under a name so later cells can build it with `envs.get_environment`.
print("Registering that environment...")
env_name = "humanoid"
envs.register_environment(env_name, Humanoid)

In [None]:
# Instantiate the registered environment by name.
print("Instantiating that environment...")
env = envs.get_environment(env_name)

# Wrap the environment's reset/step in `jax.jit`. The message now names `step`,
# which is the function actually wrapped below (it previously said `swap`).
print("Defining (but not yet compiling) JIT-compiled `reset` and `step`...")
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [None]:
# Initialize the state from a fixed seed and record its pipeline state.
print("Initializing state...")
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# Grab a short trajectory under a constant control of -0.1 on every actuator.
print("Computing a trajectory...")
ctrl = -0.1 * jp.ones(env.sys.nu)  # loop-invariant, so built once
for i in range(10):
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

print("Visualizing that trajectory...")
trajectory_visualized = env.render(rollout)
framerate = int(1.0 / env.dt)
# Best-effort inline display; only ordinary exceptions are swallowed so that
# KeyboardInterrupt/SystemExit still propagate.
try:
    media.show_video(trajectory_visualized, fps=framerate)
except Exception as error:
    print("`media` failed. If you're not in a Jupyter notebook, this is expected.")
    print(f"  ({type(error).__name__}: {error})")

# Also save the frames to an MP4 file; the frame size is (width, height).
video_filename = "trajectory.mp4"
video_writer = cv2.VideoWriter(
    video_filename,
    cv2.VideoWriter_fourcc(*"mp4v"),
    framerate,
    trajectory_visualized[0].shape[:2][::-1],
)
try:
    for pixels in trajectory_visualized:
        # Reverse the channel axis before handing the frame to OpenCV.
        video_writer.write(pixels[..., ::-1])
finally:
    # Finalize the video file even if a write raised.
    video_writer.release()
del trajectory_visualized

In [None]:
print("Setting up a training loop...")
# PPO trainer with its hyperparameters bound; the environment and progress callback
# are supplied when it is called.
# NOTE(review): `num_timesteps` is a literal here while a `TRAINING_STEPS` constant
# (1_000_000) is defined in the constants cell — confirm which value is intended.
train_fn = functools.partial(
    ppo.train,
    num_timesteps=20_000_000,
    num_evals=5,
    reward_scaling=0.1,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=10,
    num_minibatches=24,
    num_updates_per_batch=8,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-3,
    num_envs=3072,
    batch_size=512,
    seed=0,
)


# Series appended to by the `progress` callback: step counts, mean episode reward
# and its standard deviation.
x_data = []
y_data = []
ydataerr = []
# `start_time` is the reference for the "Time to JIT" message; `jit_time` is set on the
# first `progress` call.
start_time = datetime.now()
jit_time = None

# Fixed y-axis limits for the reward plot.
max_y, min_y = 13000, 0


def progress(num_steps, metrics):
    """Record evaluation metrics and redraw the reward-vs-steps plot.

    Passed to the trainer as `progress_fn`.

    Args:
        num_steps: Environment step count, used as the x coordinate.
        metrics: Mapping that must contain "eval/episode_reward" and
            "eval/episode_reward_std".

    Side effects: appends to the module-level `x_data`, `y_data` and `ydataerr`
    lists, and on the first call sets `jit_time` and prints the time elapsed since
    `start_time`. Plotting is best-effort: a failure is reported, not raised.
    """
    # Only `jit_time` is rebound here; the lists are mutated in place and need no
    # `global` declaration (the old `global y_dataerr` named a variable that does
    # not exist — the list is called `ydataerr`).
    global jit_time

    if jit_time is None:
        jit_time = datetime.now()
        print(f"Time to JIT: {jit_time - start_time}")

    x_data.append(num_steps)
    y_data.append(metrics["eval/episode_reward"])
    ydataerr.append(metrics["eval/episode_reward_std"])

    # Only ordinary exceptions are swallowed (e.g. `plt` missing or no display), so
    # KeyboardInterrupt/SystemExit still stop a long training run.
    try:
        plt.xlim([0, train_fn.keywords["num_timesteps"] * 1.25])
        plt.ylim([min_y, max_y])

        plt.xlabel("# environment steps")
        plt.ylabel("reward per episode")
        plt.title(f"y={y_data[-1]:.3f}")

        plt.errorbar(x_data, y_data, yerr=ydataerr)
        plt.show()
    except Exception as error:
        print("`plt` failed. If you're not in a Jupyter notebook, this is expected.")
        print(f"  ({type(error).__name__}: {error})")


# Run training; `progress` is passed as the trainer's `progress_fn` callback.
print("Training...")
make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

# `jit_time` is set by the first `progress` call; if the callback never ran it is
# still None and this subtraction raises TypeError.
print(f"Time to train (after JIT-compiling): {datetime.now() - jit_time}")

In [None]:
# Save the trained parameters to a path relative to the working directory.
model_path = "./mjx_brax_policy"
model.save_params(model_path, params)

In [None]:
# Reload the parameters from disk (replacing the in-memory `params`) and build a
# JIT-wrapped inference function from them.
params = model.load_params(model_path)

inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

In [None]:
# Separate environment instance for evaluation. This rebinds `jit_reset` and `jit_step`
# to the evaluation environment's methods.
eval_env = envs.get_environment(env_name)

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

In [None]:
# Initialize the evaluation state from a fixed seed.
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
rollout = [state.pipeline_state]

# Grab a trajectory driven by the trained policy.
n_steps = 500
render_every = 2

for i in range(n_steps):
    # Fresh key for this action; the other half of the split carries on.
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

    # Stop as soon as the environment reports termination.
    if state.done:
        break

# Render with the evaluation environment that produced the rollout (this used `env`
# before), showing every `render_every`-th state at a matching frame rate.
# Best-effort: only ordinary exceptions are swallowed.
try:
    media.show_video(
        eval_env.render(rollout[::render_every]),
        fps=1.0 / eval_env.dt / render_every,
    )
except Exception as error:
    print("`media` failed. If you're not in a Jupyter notebook, this is expected.")
    print(f"  ({type(error).__name__}: {error})")

In [None]:
# Run the trained policy again, stepping the physics with MuJoCo's `mj_step` on the
# evaluation environment's underlying MuJoCo model. Relies on `rng`, `n_steps` and
# `render_every` from the previous cell.
mj_model = eval_env.sys.mj_model
mj_data = mujoco.MjData(mj_model)

renderer = mujoco.Renderer(mj_model)
ctrl = jp.zeros(mj_model.nu)

images = []
for i in range(n_steps):
    act_rng, rng = jax.random.split(rng)

    # Build the observation through the environment's own `_get_obs`, from an MJX
    # copy of the current MuJoCo data and the previous control.
    obs = eval_env._get_obs(mjx.put_data(mj_model, mj_data), ctrl)
    ctrl, _ = jit_inference_fn(obs, act_rng)

    # Hold the control for as many physics steps as the environment uses per step.
    mj_data.ctrl = ctrl
    for _ in range(eval_env._n_frames):
        mujoco.mj_step(mj_model, mj_data)  # Physics step using MuJoCo mj_step.

    if i % render_every == 0:
        renderer.update_scene(mj_data)
        images.append(renderer.render())

# Best-effort inline display; only ordinary exceptions are swallowed so that
# KeyboardInterrupt/SystemExit still propagate.
try:
    media.show_video(images, fps=1.0 / eval_env.dt / render_every)
except Exception as error:
    print("`media` failed. If you're not in a Jupyter notebook, this is expected.")
    print(f"  ({type(error).__name__}: {error})")