# Rollout rgb_ppo ONNX in MuJoCo (AirbotPlayPickCube)

This notebook runs an exported **rgb_ppo** policy (ONNX) inside a MuJoCo XML (`mjx_single_cube.xml`) and records a video.

**Default run directory:** `/data/user/junzhe/code/mjpl/mujoco_playground/learning/runs/rgbppo_AirbotPlayPickCube__1__1766164466`

Notes:
- The notebook **auto-detects ONNX input signature**:
  - vector obs: `(B, D)`
  - image-only: `(B, H, W, C)` or `(B, C, H, W)`
  - image+state: two inputs (one 4D image tensor + one 2D state tensor)
- For **pure-vision**, you should export a **rgb-only** ONNX (no `state` input). If your ONNX expects `state`, this notebook can feed zeros by default, but that usually degrades performance unless your `state` matches training.


In [1]:
import os
# --- Paths (edit if needed) ---
RUN_DIR = r"/data/user/junzhe/code/mjpl/mujoco_playground/learning/runs/rgbppo_multi_view_AirbotPlayPickCube__1__1766319383"
XML_PATH = r"/data/user/junzhe/code/mjpl/mujoco_playground/mujoco_playground/_src/manipulation/airbot_play/xmls/mjx_single_cube.xml"

# Pick ONE of these (prefer rgb-only for "pure vision"):
ONNX_RGB_ONLY = os.path.join(RUN_DIR, "policy_rgb.onnx")
ONNX_RGB_STATE = os.path.join(RUN_DIR, "policy_rgb_state.onnx")

ONNX_PATH = ONNX_RGB_ONLY if os.path.exists(ONNX_RGB_ONLY) else ONNX_RGB_STATE
print("Using ONNX_PATH =", ONNX_PATH)


Using ONNX_PATH = /data/user/junzhe/code/mjpl/mujoco_playground/learning/runs/rgbppo_multi_view_AirbotPlayPickCube__1__1766319383/policy_rgb_state.onnx


In [2]:
# IMPORTANT: set MUJOCO_GL before importing mujoco
import os
os.environ.setdefault("MUJOCO_GL", "egl")

import numpy as np
np.set_printoptions(precision=3, suppress=True, linewidth=200)

import mujoco
import mediapy
import onnxruntime as ort
from PIL import Image


## Inspect ONNX inputs/outputs

In [3]:
assert os.path.exists(XML_PATH), f"XML not found: {XML_PATH}"
assert os.path.exists(ONNX_PATH), f"ONNX not found: {ONNX_PATH}"

sess = ort.InferenceSession(ONNX_PATH, providers=["CPUExecutionProvider"])
print("Outputs:")
for o in sess.get_outputs():
    print(" ", o.name, o.shape, o.type)
print("Inputs:")
for i in sess.get_inputs():
    print(" ", i.name, i.shape, i.type)

inputs = sess.get_inputs()
outputs = sess.get_outputs()
out_name = outputs[0].name


Outputs:
  action ['batch', 7] tensor(float)
Inputs:
  rgb ['batch', 128, 128, 6] tensor(uint8)
  state ['batch', 50] tensor(float)


## Helpers: resize/render + ONNX feed

In [4]:
def _shape_len(x):
    try:
        return len(x)
    except Exception:
        return None

def _is_image_shape(shape):
    # Accept common 4D forms: (B,H,W,C) or (B,C,H,W)
    return _shape_len(shape) == 4

def _is_vector_shape(shape):
    return _shape_len(shape) == 2

def _infer_io_layout(input_info):
    """Return dict describing expected inputs."""
    infos = [{"name": i.name, "shape": i.shape, "type": i.type} for i in input_info]
    image_inputs = [x for x in infos if _is_image_shape(x["shape"])]
    vec_inputs = [x for x in infos if _is_vector_shape(x["shape"])]

    layout = {"mode": None, "image": None, "state": None, "vector": None}
    if len(infos) == 1 and image_inputs:
        layout["mode"] = "image_only"
        layout["image"] = image_inputs[0]
    elif len(infos) == 2 and len(image_inputs) == 1 and len(vec_inputs) == 1:
        layout["mode"] = "image_state"
        layout["image"] = image_inputs[0]
        layout["state"] = vec_inputs[0]
    elif len(infos) == 1 and vec_inputs:
        layout["mode"] = "vector_only"
        layout["vector"] = vec_inputs[0]
    else:
        layout["mode"] = "unknown"
    return layout

layout = _infer_io_layout(inputs)
print("Detected ONNX layout:", layout["mode"])


Detected ONNX layout: image_state


In [5]:
def render_rgb(renderer: mujoco.Renderer, data: mujoco.MjData, camera: str) -> np.ndarray:
    """Return RGB uint8 image in HWC."""
    renderer.update_scene(data, camera=camera)
    img = renderer.render()
    if img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    return img

def resize_hwc_uint8(img: np.ndarray, H: int, W: int) -> np.ndarray:
    """Resize HWC uint8 to (H,W,3)."""
    pil = Image.fromarray(img)
    pil = pil.resize((W, H), resample=Image.BILINEAR)
    out = np.asarray(pil, dtype=np.uint8)
    if out.ndim == 2:
        out = np.repeat(out[..., None], 3, axis=2)
    if out.shape[-1] == 4:
        out = out[..., :3]
    return out

def to_onnx_image(img_hwc_u8: np.ndarray, image_input) -> np.ndarray:
    """Pack image into expected ONNX image tensor (batch included)."""
    shape = image_input["shape"]  # e.g., [None, H, W, 3] or [None, 3, H, W]
    typ = image_input["type"]

    # Determine expected H/W/C positions
    # If second dim is 3, assume NCHW; else assume NHWC.
    is_nchw = (shape[1] == 3) if (len(shape) == 4 and isinstance(shape[1], int)) else False

    x = img_hwc_u8
    if is_nchw:
        x = np.transpose(x, (2, 0, 1))  # CHW
    x = x[None, ...]  # add batch

    # Cast based on input type
    if "uint8" in typ:
        return x.astype(np.uint8)
    else:
        # Most exported graphs cast to float internally anyway; float32 is safe.
        return x.astype(np.float32)

def to_onnx_state(state_vec: np.ndarray, state_input) -> np.ndarray:
    typ = state_input["type"]
    x = state_vec[None, :]
    if "float" in typ:
        return x.astype(np.float32)
    return x.astype(np.float32)


## (Optional) Vector obs builder
Only used if your ONNX was accidentally exported as a vector-only policy.

In [6]:
_ARM_JOINTS = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"]
_FINGER_JOINTS = ["endleft", "endright"]

def build_vector_obs(model: mujoco.MjModel, data: mujoco.MjData) -> np.ndarray:
    """A best-effort vector obs (may NOT match training)."""
    gripper_site_id = model.site("endpoint").id
    obj_body_id = model.body("box").id

    # Joint indices
    all_joints = _ARM_JOINTS + _FINGER_JOINTS
    robot_qpos_adr = np.array([model.jnt_qposadr[model.joint(j).id] for j in all_joints])

    gripper_pos = data.site_xpos[gripper_site_id]
    gripper_mat = data.site_xmat[gripper_site_id].reshape(9)

    obj_xmat = data.xmat[obj_body_id].reshape(9)
    obj_xpos = data.xpos[obj_body_id]

    # WARNING: This likely differs from your MJX training obs. Use only if your ONNX is vector-only.
    obs = np.concatenate([
        data.qpos[robot_qpos_adr],
        data.qvel[robot_qpos_adr],
        gripper_pos,
        gripper_mat,
        obj_xmat,
        obj_xpos - gripper_pos,
    ]).astype(np.float32)
    return obs


## Rollout

In [9]:
# Load MuJoCo model
model = mujoco.MjModel.from_xml_path(XML_PATH)
data = mujoco.MjData(model)

# Render setup (this is only for visualization; policy input will be resized as needed)
cameras = ["side", "front"]
vis_h, vis_w = 480, 640
renderer = mujoco.Renderer(model, height=vis_h, width=vis_w)

# Control/sim timing
ctrl_dt = 0.02
sim_dt = 0.005
n_substeps = int(round(ctrl_dt / sim_dt))
model.opt.timestep = sim_dt

# Action scaling (delta-pos style). You may need to tune this.
action_scale = 0.04

# Determine policy image size if needed
if layout["mode"] in ("image_only", "image_state"):
    img_in = layout["image"]
    sh = img_in["shape"]
    # infer expected H,W from shape
    if sh[1] == 3:  # NCHW
        H = int(sh[2]); W = int(sh[3])
    else:           # NHWC
        H = int(sh[1]); W = int(sh[2])
    policy_H, policy_W = H, W
    print("Policy expects image:", (policy_H, policy_W))
else:
    policy_H, policy_W = None, None

# Determine state dim if needed
if layout["mode"] == "image_state":
    state_dim = int(layout["state"]["shape"][1])
    print("Policy expects state_dim:", state_dim)
else:
    state_dim = None

def render_multi_view_input():
    images = []
    for cam in cameras:
        img = render_rgb(renderer, data, camera=cam)
        img_rs = resize_hwc_uint8(img, policy_H, policy_W)
        images.append(img_rs)
    if not images:
        raise RuntimeError("No cameras configured for rendering.")
    return np.concatenate(images, axis=2)

def policy_step():
    """Compute action from current sim state."""
    feed = {}
    if layout["mode"] == "image_only":
        img_rs = render_multi_view_input()
        feed[layout["image"]["name"]] = to_onnx_image(img_rs, layout["image"])
    elif layout["mode"] == "image_state":
        img_rs = render_multi_view_input()
        feed[layout["image"]["name"]] = to_onnx_image(img_rs, layout["image"])

        # Default: feed zeros. Replace this with a training-consistent state vector if you trained with state.
        state = np.zeros((state_dim,), dtype=np.float32)
        feed[layout["state"]["name"]] = to_onnx_state(state, layout["state"])
    elif layout["mode"] == "vector_only":
        obs = build_vector_obs(model, data)
        feed[layout["vector"]["name"]] = obs[None, :].astype(np.float32)
    else:
        raise RuntimeError("Unknown ONNX input layout; cannot run. Check sess.get_inputs().")

    action = sess.run([out_name], feed)[0][0].astype(np.float32)
    return action

# Reset to home keyframe if available
mujoco.mj_resetData(model, data)
key_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_KEY, "home")
if key_id >= 0:
    mujoco.mj_resetDataKeyframe(model, data, key_id)
    try:
        data.ctrl[:] = model.keyframe("home").ctrl
    except Exception:
        pass
mujoco.mj_forward(model, data)

print("Starting rollout...")


frames = []
T = 10.0  # seconds
fps = 30

while data.time < T:
    # Visualization frame (large)
    img_vis = render_rgb(renderer, data, camera=cameras[0])

    # Policy action
    action = policy_step()

    # Apply action as delta on ctrl (best-effort; adjust to your control mode)
    if data.ctrl.size == action.size:
        data.ctrl[:] = data.ctrl + action_scale * action
        # clip to actuator range
        ctrl_range = model.actuator_ctrlrange
        data.ctrl[:] = np.clip(data.ctrl, ctrl_range[:, 0], ctrl_range[:, 1])
    else:
        # If dims mismatch, print once and stop
        raise RuntimeError(f"action dim {action.size} != ctrl dim {data.ctrl.size}. Your ONNX likely targets a different action space/control mode.")

    # step physics
    for _ in range(n_substeps):
        mujoco.mj_step(model, data)

    # video sampling
    if len(frames) < data.time * fps:
        frames.append(img_vis)

mediapy.show_video(frames, fps=fps)


Policy expects image: (128, 128)
Policy expects state_dim: 50
Starting rollout...


0
This browser does not support the video tag.


## Save MP4 (optional)

In [8]:
import imageio.v2 as imageio
out_mp4 = os.path.join(r"/data/user/junzhe/code/mjpl/mujoco_playground/learning/runs/rgbppo_AirbotPlayPickCube__1__1766164466", "onnx_rollout.mp4")
imageio.mimwrite(out_mp4, frames, fps=30, codec="libx264", quality=8)
print("Wrote:", out_mp4)


Wrote: /data/user/junzhe/code/mjpl/mujoco_playground/learning/runs/rgbppo_AirbotPlayPickCube__1__1766164466/onnx_rollout.mp4


    ## If your ONNX expects `state`

    For **rgb+state** ONNX, you must feed a `state` vector that matches training.
    If you exported `policy_rgb_state.onnx` from your MJX wrapper training, the correct `state` is whatever your MJX env provided as `obs["state"]`.
    Running the same ONNX inside a standalone MuJoCo XML without reproducing that state definition will usually reduce performance.

    If you want **pure-vision** inference, re-export **rgb-only** ONNX and use it here.
    