In [2]:
import mujoco
import mujoco.mjx as mjx
import jax
import jax.numpy as jnp
import numpy as np

In [3]:
def alpha_order_grad(cost_fn, u, eps=1e-3, M=32, gamma=1.0, key=jax.random.PRNGKey(0)):
    """Adaptive alpha-order gradient estimate of a scalar cost J(u).

    Blends a first-order (pathwise) estimate with a zeroth-order
    (finite-difference / REINFORCE-style) estimate, both evaluated under
    Gaussian perturbations ``u + eps * z`` with ``z ~ N(0, I)``.

    Args:
        cost_fn: Callable mapping ``u`` to a scalar; must be differentiable
            with ``jax.grad``.
        u: 1-D parameter array, shape ``(1,)`` or ``(m,)``.
        eps: Scale of the Gaussian perturbation applied to ``u``.
        M: Number of perturbation samples drawn for each estimator.
        gamma: Threshold in the robust rule; alpha is reduced whenever
            ``alpha_inf * B`` exceeds ``gamma - eps_conf``.
        key: JAX PRNG key. The default is constructed once, when the function
            is defined, so calls that omit ``key`` reuse identical noise.

    Returns:
        ``(grad_mix, alpha, (grad1, grad0))`` where ``grad1`` is the
        first-order estimate, ``grad0`` the zeroth-order estimate, ``alpha``
        the weight placed on ``grad1``, and
        ``grad_mix = alpha * grad1 + (1 - alpha) * grad0``.
    """
    n = u.size
    k1, k0 = jax.random.split(key)
    # Independent noise for the two estimators.
    Z1 = jax.random.normal(k1, (M, n))
    Z0 = jax.random.normal(k0, (M, n))

    # FoBG samples (pathwise): exact gradient at each perturbed point.
    grad_fn = jax.grad(cost_fn)
    grad1_s = jax.vmap(lambda z: grad_fn(u + eps * z))(Z1)
    grad1 = jnp.mean(grad1_s, axis=0)

    # ZoBG samples (REINFORCE/FD) with f(u) as baseline.
    # The difference must be f(u + eps*z) - f(u): E[(f(u + eps*z) - f(u)) * z / eps]
    # approximates +grad J, whereas the reversed difference estimates -grad J.
    f0 = cost_fn(u)
    grad0_s = jax.vmap(lambda z: (cost_fn(u + eps * z) - f0) * z / eps)(Z0)
    grad0 = jnp.mean(grad0_s, axis=0)

    # Adaptive alpha (Suh et al. Eq. 4-5, robust rule)
    # Empirical total variance of each estimator (summed over coordinates).
    var1 = jnp.mean(jnp.sum((grad1_s - grad1) ** 2, axis=1))
    var0 = jnp.mean(jnp.sum((grad0_s - grad0) ** 2, axis=1))
    # Disagreement between the two estimates, used as an empirical bias measure.
    B = jnp.linalg.norm(grad1 - grad0)
    # Standard-error style confidence radius of the zeroth-order mean.
    eps_conf = jnp.sqrt(var0 / M)
    # Inverse-variance weighting; 1e-12 guards against 0/0 when both variances vanish.
    alpha_inf = var0 / (var0 + var1 + 1e-12)
    # Keep alpha_inf while alpha_inf * B fits within the budget; otherwise shrink
    # alpha so that alpha * B == gamma - eps_conf, clipped into [0, 1].
    alpha = jnp.where(alpha_inf * B <= (gamma - eps_conf),
                      alpha_inf,
                      jnp.clip((gamma - eps_conf) / (B + 1e-12), 0.0, 1.0))
    grad_mix = alpha * grad1 + (1.0 - alpha) * grad0
    return grad_mix, alpha, (grad1, grad0)


In [None]:
# ---- 1-DoF MuJoCo model (slide joint along z) ----
# MJCF description: one body ("ball") carrying a sphere geom on a slide joint
# along the z axis (limited to the range [0, 2]), a plane geom in the world
# body, and a single motor on the joint with control range [-20, 20].
xml = """
<mujoco>
  <option timestep="0.005" gravity="0 0 -9.81"/>
  <worldbody>
    <body name="ball" pos="0 0 1.0">
      <joint name="slide_z" type="slide" axis="0 0 1" limited="true" range="0 2"/>
      <geom type="sphere" size="0.05" mass="0.2" condim="3"/>
    </body>
    <geom type="plane" size="1 1 0.05" pos="0 0 0"/>
  </worldbody>
  <actuator>
    <motor joint="slide_z" ctrlrange="-20 20"/>
  </actuator>
</mujoco>
"""

# Build model and data
# Host-side MuJoCo model/data first, then their MJX counterparts, which
# rollout_mjx uses as the model and initial data for mjx.step.
m = mujoco.MjModel.from_xml_string(xml)
d0 = mujoco.MjData(m)
mjx_model = mjx.put_model(m)
mjx_data0 = mjx.put_data(m, d0)

# Simulation and cost constants shared by the cells below.
dt = float(m.opt.timestep)  # integration step read back from the model (0.005 in the XML)
T = 200                 # horizon in steps: 200 * 0.005 s = 1 second
z_target = 1.2  # value the final height is compared against in cost_mjx
w_u = 1e-2  # weight on the squared-control term in cost_mjx

def rollout_mjx(u_seq):
    """Roll the MJX model forward under a sequence of controls.

    The state is reset to ``qpos = [1.0]``, ``qvel = [0.0]`` and one
    ``mjx.step`` is applied for every entry of ``u_seq``.

    Args:
        u_seq: Per-step controls. Each entry gets a leading axis added before
            being written to ``ctrl`` (a scalar entry becomes shape ``(1,)``).

    Returns:
        q_hist: The position recorded before each step, followed by the final
            position; one more row than ``u_seq`` has entries.
        v_hist: The matching velocities, laid out the same way.
    """
    init_state = mjx_data0.replace(qpos=jnp.array([1.0]), qvel=jnp.array([0.0]))

    def advance(state, ctrl_t):
        # Write the control, emit the pre-step (q, v), and carry the stepped state.
        state = state.replace(ctrl=ctrl_t[None])
        return mjx.step(mjx_model, state), (state.qpos, state.qvel)

    last_state, (q_pre, v_pre) = jax.lax.scan(advance, init_state, u_seq)

    # The scan outputs only pre-step states, so append the terminal one.
    q_hist = jnp.vstack([q_pre, last_state.qpos[None, :]])
    v_hist = jnp.vstack([v_pre, last_state.qvel[None, :]])
    return q_hist, v_hist


def cost_mjx(u):
    """Scalar cost: squared final-height error plus a weighted control penalty.

    Args:
        u: Control parameters; only ``u[0]`` enters the penalty term.

    Returns:
        ``0.5 * (zT - z_target)**2 + 0.5 * w_u * u[0]**2`` where ``zT`` is the
        value returned by ``rollout_final_height(u)``.
    """
    # NOTE(review): rollout_final_height is defined outside this cell;
    # presumably it returns the final ball height of a rollout driven by u.
    height_error = rollout_final_height(u) - z_target
    tracking_cost = 0.5 * height_error**2
    effort_cost = 0.5 * w_u * (u[0]**2)
    return tracking_cost + effort_cost

In [12]:
# Sanity check: evaluate the rollout helper at zero control.
# NOTE(review): rollout_final_height is not defined in the cells visible here;
# it must exist in the kernel before this cell (and cost_mjx) can run.
u0 = jnp.array([0.0])
rollout_final_height(u0)

Array(-0.00036674, dtype=float32)

In [9]:
# List the joint names of the model built in the model/config cell.
# Reuses the existing `m` instead of re-importing mujoco and rebuilding the
# model here, so `m` stays the same object that mjx_model / mjx_data0 were
# created from.
print("Joints:", [m.joint(i).name for i in range(m.njnt)])


Joints: ['slide_z']
