In [1]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
import jax
import jax.numpy as jnp
import numpy as np
import mujoco
import mujoco.mjx as mjx
import matplotlib.pyplot as plt


# ---------- Load a simple model with contact ----------
xml = """
<mujoco>
  <option gravity="0 0 -9.81" timestep="0.01"/>
  <worldbody>
    <body name="ball" pos="0 0 1">
      <geom type="sphere" size="0.05" mass="0.1" condim="3"/>
      <joint name="slide" type="free"/>
    </body>
    <geom type="plane" size="1 1 0.05" pos="0 0 0"/>
  </worldbody>
  <actuator>
    <motor joint="slide" ctrllimited="true" ctrlrange="-50 50"/>
  </actuator>
</mujoco>
"""

# Build model and MJX objects
m = mujoco.MjModel.from_xml_string(xml)
d0 = mujoco.MjData(m)
mjx_model = mjx.put_model(m)
mjx_data0 = mjx.put_data(m, d0)

# ---------- Rollout function ----------
def rollout(ctrl_seq):
    """Simulate control sequence and return final z-height."""
    def step_fn(data, u):
        # replace control and advance simulation
        data = data.replace(ctrl=u)
        data = mjx.step(mjx_model, data)
        return data, None

    data, _ = jax.lax.scan(step_fn, mjx_data0, ctrl_seq)
    return data.qpos[2]  # final z position

# ---------- Objective ----------
def objective(ctrl_seq):
    # maximize height ⇒ minimize negative height
    return -rollout(ctrl_seq)


# ---------- Zeroth- and first-order gradients ----------
def grad_fo(ctrl_seq):
    return jax.grad(objective)(ctrl_seq)

def grad_zo(ctrl_seq, eps=0.05, M=32, key=jax.random.PRNGKey(0)):
    n = ctrl_seq.shape[0]
    Z = jax.random.normal(key, (M, n))
    def single(z):
        return (objective(ctrl_seq + eps*z) - objective(ctrl_seq)) * z / eps
    grads = jax.vmap(single)(Z)
    return jnp.mean(grads, axis=0)

# ---------- Adaptive α-order estimator ----------
def interp_mc_grad(ctrl, eps=0.05, M=32, gamma=1.0, key=jax.random.PRNGKey(0)):
    n = ctrl.shape[0]
    subkeys = jax.random.split(key, 2)

    def fo_sample(z):
        return jax.grad(objective)(ctrl + eps*z)
    def zo_sample(z):
        return (objective(ctrl + eps*z) - objective(ctrl)) * (-z) / eps

    Z1 = jax.random.normal(subkeys[0], (M, n))
    Z0 = jax.random.normal(subkeys[1], (M, n))
    grad1_samples = jax.vmap(fo_sample)(Z1)
    grad0_samples = jax.vmap(zo_sample)(Z0)

    grad1 = jnp.mean(grad1_samples, axis=0)
    grad0 = jnp.mean(grad0_samples, axis=0)
    var1 = jnp.mean(jnp.sum((grad1_samples - grad1)**2, axis=1))
    var0 = jnp.mean(jnp.sum((grad0_samples - grad0)**2, axis=1))
    B = jnp.linalg.norm(grad1 - grad0)
    eps_conf = jnp.sqrt(var0 / M)
    alpha_inf = var0 / (var0 + var1 + 1e-12)
    alpha = jnp.where(alpha_inf * B <= gamma - eps_conf,
                      alpha_inf,
                      jnp.clip((gamma - eps_conf) / (B + 1e-12), 0.0, 1.0))
    grad_mix = alpha * grad1 + (1 - alpha) * grad0
    return grad_mix, alpha

# ---------- Run ----------
T = 50
ctrl = jnp.zeros(T)

grad_mix, alpha = interp_mc_grad(ctrl, eps=0.05, M=8)
print(f"α = {float(alpha):.3f}")
print(f"‖grad‖ = {float(jnp.linalg.norm(grad_mix)):.3f}")

plt.plot(np.arange(T), np.asarray(grad_mix))
plt.title(f"Adaptive α = {float(alpha):.2f}")
plt.xlabel("time step")
plt.ylabel("∂loss/∂u_t")
plt.tight_layout()
plt.show()




TypeError: scan body function carry input and carry output must have equal types, but they differ:

The input carry component data.ctrl has type float32[1] but the corresponding output carry component has type float32[], so the shapes do not match.

Revise the function so that all output types match the corresponding input types.

In [2]:
import mujoco.viewer


def record_trajectory(ctrl_seq):
    data = mjx_data0
    traj = []
    for u in np.array(ctrl_seq):
        data = data.replace(ctrl=u)
        data = mjx.step(mjx_model, data)
        traj.append(np.array(data.qpos))
    return np.array(traj)

import mujoco as mj
import imageio
import os

def replay_traj_to_gif(traj, mj_model, gif_path="ball_traj.gif", fps=60, height=450, width=450):
    os.makedirs(os.path.dirname(gif_path) or ".", exist_ok=True)
    d = mj.MjData(mj_model)
    renderer = mj.Renderer(mj_model, height=height, width=width)
    frames = []

    for qpos in traj:
        d.qpos[:] = qpos
        mj.mj_forward(mj_model, d)
        renderer.update_scene(d)
        img = renderer.render()
        frames.append(img)

    renderer.close()
    imageio.mimsave(gif_path, frames, fps=fps)
    print(f"Saved animation to: {gif_path}")
    return gif_path

In [3]:
# Simulate trajectory using MJX
traj = record_trajectory(np.zeros(T))
replay_traj_to_gif(traj, m, gif_path="ball_bounce.gif", fps=int(1/m.opt.timestep))


Saved animation to: ball_bounce.gif


'ball_bounce.gif'