# MJX Humanoid Training (Minimal)

- 목적: `models/humanoid.xml`로 MJX FoPG(APG) 최소 예제 학습
- 원칙: mjx만 사용, 핵심만 간결히 구현 (Brax 미사용)


In [2]:
# Setup: imports, JAX precision
import os
# Set XLA_PYTHON_CLIENT_MEM_FRACTION to 0.8. The assignment is placed before
# `import jax` below (presumably so the value is seen when JAX initialises
# its backend -- confirm).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"

import jax
import jax.numpy as jnp
from jax import random, jit, value_and_grad
from jax import config as jax_config
# Turn on the 64-bit flag and request "high" as the default matmul precision.
jax_config.update("jax_enable_x64", True)
jax_config.update("jax_default_matmul_precision", "high")

import numpy as np
import mujoco
import mujoco.mjx as mjx

# Path of the MuJoCo XML model that the model-loading code reads.
XML_PATH = "models/humanoid.xml"


Failed to import warp: No module named 'warp'
Failed to import mujoco.mjx.third_party.mujoco_warp as mujoco_warp: No module named 'warp'


In [3]:
# Load the MuJoCo model from XML and convert it into an MJX model.
# NOTE(review): the recorded output of this cell is
# "NotImplementedError: [mjSENS_JOINTLIMITFRC] not supported"; presumably
# raised by mjx.put_model for a sensor in the XML -- confirm and, if so,
# remove that sensor from the model before running.
m = mujoco.MjModel.from_xml_path(XML_PATH)
sys = mjx.put_model(m)

# Sizes of qpos, qvel and ctrl as reported by the host model.
q_size, qd_size, act_size = m.nq, m.nv, m.nu

# Initial state: a copy of the model's qpos0 with all velocities zero.
# (No keyframe is read here.)
q0 = np.array(m.qpos0, copy=True)
qd0 = np.zeros(qd_size)

@jit
def pipeline_init(q, qd):
    """Create MJX data holding the given state and run ``mjx.forward`` on it.

    Args:
        q: Value stored in the ``qpos`` field.
        qd: Value stored in the ``qvel`` field.

    Returns:
        The data object returned by ``mjx.forward`` for that state.
    """
    seeded = mjx.make_data(sys).replace(qpos=q, qvel=qd)
    return mjx.forward(sys, seeded)

@jit
def pipeline_step(data, ctrl):
    """Write ``ctrl`` into ``data`` and advance it with one ``mjx.step``.

    Args:
        data: MJX data object to advance.
        ctrl: Value stored in the ``ctrl`` field before stepping.

    Returns:
        The data object returned by ``mjx.step``.
    """
    return mjx.step(sys, data.replace(ctrl=ctrl))


NotImplementedError: [<mjtSensor.mjSENS_JOINTLIMITFRC: 22>] not supported

In [None]:
# Minimal MLP policy (stateless)

def mlp_init(rng, in_dim, hidden, out_dim):
    """Create parameters for a three-layer MLP.

    Each weight matrix is drawn from a standard normal and scaled by
    ``1 / sqrt(fan_in)``; each bias starts at zero.

    Args:
        rng: PRNG key, split into one sub-key per layer.
        in_dim: Input width.
        hidden: Width of both hidden layers.
        out_dim: Output width.

    Returns:
        The flat tuple ``(w1, b1, w2, b2, w3, b3)``.
    """
    layer_shapes = ((in_dim, hidden), (hidden, hidden), (hidden, out_dim))
    layer_keys = random.split(rng, 3)

    flat = []
    for layer_key, (fan_in, fan_out) in zip(layer_keys, layer_shapes):
        scale = 1.0 / jnp.sqrt(fan_in)
        flat.append(random.normal(layer_key, (fan_in, fan_out)) * scale)
        flat.append(jnp.zeros((fan_out,)))
    return tuple(flat)

@jit
def mlp_apply(params, x):
    """Evaluate the three-layer MLP on ``x``.

    Every layer, including the last, is an affine map followed by ``tanh``,
    so each output component lies in [-1, 1].

    Args:
        params: The tuple ``(w1, b1, w2, b2, w3, b3)``.
        x: Input array fed to the first layer.

    Returns:
        The activation of the final layer.
    """
    w1, b1, w2, b2, w3, b3 = params
    hidden = x
    for weight, bias in ((w1, b1), (w2, b2), (w3, b3)):
        hidden = jnp.tanh(hidden @ weight + bias)
    return hidden

# Policy input is qpos and qvel concatenated; its output has one entry per
# actuator.
act_dim = act_size
obs_dim = q_size + qd_size

# Fixed seed for the initial policy weights; 128 units per hidden layer.
rng = random.PRNGKey(0)
params = mlp_init(rng, obs_dim, hidden=128, out_dim=act_dim)


In [None]:
# Rollout and loss (FoPG)

@jit
def get_obs(d):
    """Return ``d.qpos`` followed by ``d.qvel`` as a single array."""
    state_parts = (d.qpos, d.qvel)
    return jnp.concatenate(state_parts)

@jit
def reward_fn(d):
    """Dense, differentiable reward computed from ``d.qpos`` and ``d.qvel``.

    The reward is ``exp(-(1.4 - qpos[2])**2)`` minus two quadratic velocity
    penalties: ``1e-3`` times the squared norm of ``qvel[:3]`` and ``1e-4``
    times the squared norm of ``qvel[3:]``.

    NOTE(review): if the model's root is a MuJoCo free joint, ``qvel[:3]``
    would be the base *linear* velocity and ``qvel[3:6]`` the angular
    velocity -- confirm that this split penalises what is intended.

    Args:
        d: Object exposing ``qpos`` and ``qvel`` arrays.

    Returns:
        The scalar reward.
    """
    height_err = 1.4 - d.qpos[2]
    upright = jnp.exp(-jnp.square(height_err))

    qvel_sq = jnp.square(d.qvel)
    lead_pen = jnp.sum(qvel_sq[:3])
    rest_pen = jnp.sum(qvel_sq[3:])
    return upright - 1e-3 * lead_pen - 1e-4 * rest_pen

def rollout_return(params, key, horizon=256):
    """Discounted return of one policy rollout from the fixed initial state.

    The rollout starts at ``pipeline_init(q0, qd0)``. At every step the
    policy ``mlp_apply`` maps the current observation to a control, the
    simulation advances by one ``pipeline_step``, and ``reward_fn`` is
    evaluated on the *new* state. Step ``t`` (0-based) is weighted by
    ``0.99 ** t``.

    The time loop is a ``jax.lax.scan`` so the step body is traced once
    instead of being unrolled ``horizon`` times, and ``horizon`` is a static
    argument so callers may pass it explicitly.

    Args:
        params: Policy parameters accepted by ``mlp_apply``.
        key: PRNG key. Unused, because the policy and the initial state are
            deterministic; kept so the signature stays stable.
        horizon: Number of simulation steps. Must be a Python int; each
            distinct value triggers a separate compilation.

    Returns:
        Scalar array holding the discounted sum of rewards.
    """
    del key  # Nothing in the rollout is stochastic.
    gamma = 0.99

    def _policy_step(d, _):
        act = mlp_apply(params, get_obs(d))
        d = pipeline_step(d, act)
        # Reward is evaluated on the post-step state.
        return d, reward_fn(d)

    _, rewards = jax.lax.scan(
        _policy_step, pipeline_init(q0, qd0), None, length=horizon
    )
    discounts = gamma ** jnp.arange(horizon)
    return jnp.sum(discounts * rewards)


# `horizon` sets the scan length, so it has to be a compile-time constant.
rollout_return = jit(rollout_return, static_argnames=("horizon",))

def _negated_return(p, k):
    """Return ``-rollout_return(p, k)``, the quantity that training minimises."""
    return -rollout_return(p, k)


loss = jit(_negated_return)


In [None]:
# SGD training (very short demo)

@jit
def sgd_update(params, grads, lr):
    """Apply one plain SGD step, ``p - lr * g``, to every parameter leaf.

    Works for any pytree of parameters (not only the six-tensor MLP tuple),
    so the update does not need to change when the network layout changes.

    Args:
        params: Pytree of parameter arrays.
        grads: Pytree of gradients with the same structure as ``params``.
        lr: Learning rate.

    Returns:
        A pytree with the same structure as ``params`` holding the updated
        values.
    """
    return jax.tree_util.tree_map(
        lambda param, grad: param - lr * grad, params, grads
    )

# Training hyper-parameters.
key = random.PRNGKey(42)
lr = 3e-4
num_updates = 200
log_every = 20  # Print the loss once every this many updates.

# Returns (loss, d loss / d params); differentiation is w.r.t. the first
# argument only.
val_and_grad = jit(value_and_grad(loss))

for i in range(num_updates):
    key, sub = random.split(key)
    v, g = val_and_grad(params, sub)
    params = sgd_update(params, g, lr)
    done = i + 1
    if done % log_every == 0:
        print(done, "loss:", float(v))
