# K-Scale Humanoid Benchmark

Welcome to the K-Scale Humanoid Benchmark! This notebook will walk you through training your own reinforcement learning policy, which you can then use to control a K-Scale robot.

*Note:* The Just-In-Time compilation may take a while and cause your Colab instance to appear to disconnect. However, your training cell may actually still be running. Make sure to check before restarting!

*Last updated: 2025/05/15*

## Dependencies and Config

The K-Scale Humanoid Benchmark uses K-Scale's open-source RL framework [K-Sim](https://github.com/kscalelabs/ksim) for training and the [K-Scale API](https://github.com/kscalelabs/kscale) for asset management.

In [None]:
# Install packages

!pip install ksim==0.1.2 xax==0.3.0 mujoco-scenes

In [1]:
# Set up environment variables
# MUJOCO_GL is set here, before `mujoco` is imported in the next cell.
# NOTE(review): presumably selects MuJoCo's headless EGL rendering backend — confirm
# this is the right choice for the machine the notebook runs on.
%env MUJOCO_GL=egl

env: MUJOCO_GL=egl


In [2]:
import asyncio
import functools
import math
from dataclasses import dataclass
from typing import Self

import attrs
import distrax
import equinox as eqx
import jax
import jax.numpy as jnp
import ksim
import mujoco
import mujoco_scenes
import mujoco_scenes.mjcf
import nest_asyncio
import optax
import xax
from jaxtyping import Array, PRNGKeyArray

nest_asyncio.apply()

2025-09-30 19:29:55.662084: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759253395.682631   81197 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759253395.687762   81197 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1759253395.699562   81197 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1759253395.699583   81197 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1759253395.699585   81197 computation_placer.cc:177] computation placer alr

Failed to import warp: No module named 'warp'
Failed to import mujoco.mjx.third_party.mujoco_warp as mujoco_warp: No module named 'warp'


  import pkg_resources


In [3]:
# Default joint angles (in radians), listed in the order of the neural network outputs.
# The angles are written in degrees below and converted once while building the list.
ZEROS: list[tuple[str, float]] = [
    (joint_name, math.radians(angle_deg))
    for joint_name, angle_deg in (
        # Right arm.
        ("dof_right_shoulder_pitch_03", 0.0),
        ("dof_right_shoulder_roll_03", -10.0),
        ("dof_right_shoulder_yaw_02", 0.0),
        ("dof_right_elbow_02", 90.0),
        ("dof_right_wrist_00", 0.0),
        # Left arm.
        ("dof_left_shoulder_pitch_03", 0.0),
        ("dof_left_shoulder_roll_03", 10.0),
        ("dof_left_shoulder_yaw_02", 0.0),
        ("dof_left_elbow_02", -90.0),
        ("dof_left_wrist_00", 0.0),
        # Right leg.
        ("dof_right_hip_pitch_04", -20.0),
        ("dof_right_hip_roll_03", -0.0),
        ("dof_right_hip_yaw_03", 0.0),
        ("dof_right_knee_04", -50.0),
        ("dof_right_ankle_02", 30.0),
        # Left leg.
        ("dof_left_hip_pitch_04", 20.0),
        ("dof_left_hip_roll_03", 0.0),
        ("dof_left_hip_yaw_03", 0.0),
        ("dof_left_knee_04", 50.0),
        ("dof_left_ankle_02", -30.0),
    )
]

## Rewards

When training a reinforcement learning agent, the most important thing to define is what reward you want the agent to maximize. `ksim` includes a number of useful default rewards for training walking agents, but it is often a good idea to define new rewards to encourage specific types of behavior. The cell below shows an example of how to define a custom reward. A similar pattern can be used to define custom objectives, events, observations, and more.

In [4]:
@attrs.define(frozen=True, kw_only=True)
class JointPositionPenalty(ksim.JointDeviationPenalty):
    """Joint-deviation penalty whose per-joint targets are looked up in ``ZEROS``."""

    @classmethod
    def create_from_names(
        cls,
        names: list[str],
        physics_model: ksim.PhysicsModel,
        scale: float = -1.0,
        scale_by_curriculum: bool = False,
    ) -> Self:
        """Build the penalty for the given joints, targeting their ``ZEROS`` angles.

        Args:
            names: Joint names to penalize; each must appear in ``ZEROS``.
            physics_model: Physics model forwarded to ``cls.create``.
            scale: Reward scale forwarded to ``cls.create``.
            scale_by_curriculum: Forwarded to ``cls.create``.

        Returns:
            The penalty instance produced by ``cls.create``.

        Raises:
            KeyError: If a name in ``names`` is not listed in ``ZEROS``.
        """
        target_by_joint = dict(ZEROS)
        targets = tuple(target_by_joint[joint] for joint in names)

        return cls.create(
            physics_model=physics_model,
            joint_names=tuple(names),
            joint_targets=targets,
            scale=scale,
            scale_by_curriculum=scale_by_curriculum,
        )


@attrs.define(frozen=True, kw_only=True)
class BentArmPenalty(JointPositionPenalty):
    """Penalizes the ten arm joints for deviating from their ``ZEROS`` angles."""

    @classmethod
    def create_penalty(
        cls,
        physics_model: ksim.PhysicsModel,
        scale: float = -1.0,
        scale_by_curriculum: bool = False,
    ) -> Self:
        """Create the penalty over the shoulder, elbow and wrist joints of both arms."""
        arm_joint_names = [
            # Right arm.
            "dof_right_shoulder_pitch_03",
            "dof_right_shoulder_roll_03",
            "dof_right_shoulder_yaw_02",
            "dof_right_elbow_02",
            "dof_right_wrist_00",
            # Left arm.
            "dof_left_shoulder_pitch_03",
            "dof_left_shoulder_roll_03",
            "dof_left_shoulder_yaw_02",
            "dof_left_elbow_02",
            "dof_left_wrist_00",
        ]
        return cls.create_from_names(
            names=arm_joint_names,
            physics_model=physics_model,
            scale=scale,
            scale_by_curriculum=scale_by_curriculum,
        )


@attrs.define(frozen=True, kw_only=True)
class StraightLegPenalty(JointPositionPenalty):
    """Penalizes the hip roll and hip yaw joints for deviating from their ``ZEROS`` angles."""

    @classmethod
    def create_penalty(
        cls,
        physics_model: ksim.PhysicsModel,
        scale: float = -1.0,
        scale_by_curriculum: bool = False,
    ) -> Self:
        """Create the penalty over the hip roll and hip yaw joints of both legs."""
        hip_joint_names = [
            # Left leg.
            "dof_left_hip_roll_03",
            "dof_left_hip_yaw_03",
            # Right leg.
            "dof_right_hip_roll_03",
            "dof_right_hip_yaw_03",
        ]
        return cls.create_from_names(
            names=hip_joint_names,
            physics_model=physics_model,
            scale=scale,
            scale_by_curriculum=scale_by_curriculum,
        )

## Actor-Critic Model

We train our reinforcement learning agent using an RNN-based actor and critic, which we define below.

In [5]:
class Actor(eqx.Module):
    """Recurrent actor for the walking task.

    An input projection feeds a stack of GRU cells; an output projection then
    produces the parameters (means, standard deviations and mixture logits) of
    a mixture of Gaussians for every action dimension.
    """

    input_proj: eqx.nn.Linear
    rnns: tuple[eqx.nn.GRUCell, ...]
    output_proj: eqx.nn.Linear
    # Hyperparameters are static fields. `eqx.field(static=True)` replaces the
    # deprecated `eqx.static_field()`, which emits a warning for each field.
    num_inputs: int = eqx.field(static=True)
    num_outputs: int = eqx.field(static=True)
    num_mixtures: int = eqx.field(static=True)
    min_std: float = eqx.field(static=True)
    max_std: float = eqx.field(static=True)
    var_scale: float = eqx.field(static=True)

    def __init__(
        self,
        key: PRNGKeyArray,
        *,
        num_inputs: int,
        num_outputs: int,
        min_std: float,
        max_std: float,
        var_scale: float,
        hidden_size: int,
        num_mixtures: int,
        depth: int,
    ) -> None:
        """Initialize the actor's layers.

        Args:
            key: PRNG key used to initialize all layers.
            num_inputs: Length of the observation vector.
            num_outputs: Number of action dimensions.
            min_std: Offset added to the softplus output before scaling.
            max_std: Upper clip applied to the standard deviations.
            var_scale: Multiplier applied to the standard deviations before clipping.
            hidden_size: Width of the input projection and of every GRU cell.
            num_mixtures: Number of Gaussian components per action dimension.
            depth: Number of stacked GRU cells.
        """
        # Project input to hidden size
        key, input_proj_key = jax.random.split(key)
        self.input_proj = eqx.nn.Linear(
            in_features=num_inputs,
            out_features=hidden_size,
            key=input_proj_key,
        )

        # Create RNN layers, one PRNG key per layer.
        key, rnn_key = jax.random.split(key)
        rnn_keys = jax.random.split(rnn_key, depth)
        self.rnns = tuple(
            eqx.nn.GRUCell(
                input_size=hidden_size,
                hidden_size=hidden_size,
                key=layer_key,
            )
            for layer_key in rnn_keys
        )

        # Project to output: one mean, one std and one logit per (output, mixture) pair.
        self.output_proj = eqx.nn.Linear(
            in_features=hidden_size,
            out_features=num_outputs * 3 * num_mixtures,
            key=key,
        )

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_mixtures = num_mixtures
        self.min_std = min_std
        self.max_std = max_std
        self.var_scale = var_scale

    def forward(self, obs_n: Array, carry: Array) -> tuple[distrax.Distribution, Array]:
        """Run one step of the actor.

        Args:
            obs_n: Unbatched observation vector passed to the input projection.
            carry: Hidden states indexed by layer; ``carry[i]`` is the state
                fed to ``self.rnns[i]``.

        Returns:
            The action distribution and the new hidden states stacked along
            axis 0 (one row per GRU layer).
        """
        x_n = self.input_proj(obs_n)
        out_carries = []
        for i, rnn in enumerate(self.rnns):
            # Each layer's new hidden state is also the input to the next layer.
            x_n = rnn(x_n, carry[i])
            out_carries.append(x_n)
        out_n = self.output_proj(x_n)

        # Reshape the output to be a mixture of gaussians.
        slice_len = self.num_outputs * self.num_mixtures
        mean_nm = out_n[..., :slice_len].reshape(self.num_outputs, self.num_mixtures)
        std_nm = out_n[..., slice_len : slice_len * 2].reshape(self.num_outputs, self.num_mixtures)
        logits_nm = out_n[..., slice_len * 2 :].reshape(self.num_outputs, self.num_mixtures)

        # Softplus and clip to ensure positive standard deviations.
        std_nm = jnp.clip((jax.nn.softplus(std_nm) + self.min_std) * self.var_scale, max=self.max_std)

        # Apply bias to the means: each output is offset by its ZEROS angle, so
        # num_outputs has to match len(ZEROS) for the broadcast to line up.
        mean_nm = mean_nm + jnp.array([v for _, v in ZEROS])[:, None]

        dist_n = ksim.MixtureOfGaussians(means_nm=mean_nm, stds_nm=std_nm, logits_nm=logits_nm)

        return dist_n, jnp.stack(out_carries, axis=0)


class Critic(eqx.Module):
    """Recurrent critic for the walking task.

    Mirrors the actor's layout (input projection, stacked GRU cells, output
    projection) but produces a single value estimate.
    """

    input_proj: eqx.nn.Linear
    rnns: tuple[eqx.nn.GRUCell, ...]
    output_proj: eqx.nn.Linear
    # Static field; `eqx.field(static=True)` replaces the deprecated `eqx.static_field()`.
    num_inputs: int = eqx.field(static=True)

    def __init__(
        self,
        key: PRNGKeyArray,
        *,
        num_inputs: int,
        hidden_size: int,
        depth: int,
    ) -> None:
        """Initialize the critic's layers.

        Args:
            key: PRNG key used to initialize all layers.
            num_inputs: Length of the observation vector.
            hidden_size: Width of the input projection and of every GRU cell.
            depth: Number of stacked GRU cells.
        """
        num_outputs = 1

        # Project input to hidden size
        key, input_proj_key = jax.random.split(key)
        self.input_proj = eqx.nn.Linear(
            in_features=num_inputs,
            out_features=hidden_size,
            key=input_proj_key,
        )

        # Create RNN layers, one PRNG key per layer.
        key, rnn_key = jax.random.split(key)
        rnn_keys = jax.random.split(rnn_key, depth)
        self.rnns = tuple(
            eqx.nn.GRUCell(
                input_size=hidden_size,
                hidden_size=hidden_size,
                key=layer_key,
            )
            for layer_key in rnn_keys
        )

        # Project to output
        self.output_proj = eqx.nn.Linear(
            in_features=hidden_size,
            out_features=num_outputs,
            key=key,
        )

        self.num_inputs = num_inputs

    def forward(self, obs_n: Array, carry: Array) -> tuple[Array, Array]:
        """Run one step of the critic.

        Args:
            obs_n: Unbatched observation vector passed to the input projection.
            carry: Hidden states indexed by layer; ``carry[i]`` is the state
                fed to ``self.rnns[i]``.

        Returns:
            The output of the final projection (a length-1 array) and the new
            hidden states stacked along axis 0 (one row per GRU layer).
        """
        x_n = self.input_proj(obs_n)
        out_carries = []
        for i, rnn in enumerate(self.rnns):
            # Each layer's new hidden state is also the input to the next layer.
            x_n = rnn(x_n, carry[i])
            out_carries.append(x_n)
        out_n = self.output_proj(x_n)

        return out_n, jnp.stack(out_carries, axis=0)


class Model(eqx.Module):
    """Container pairing the walking-task actor with its critic."""

    actor: Actor
    critic: Critic

    def __init__(
        self,
        key: PRNGKeyArray,
        *,
        num_actor_inputs: int,
        num_actor_outputs: int,
        num_critic_inputs: int,
        min_std: float,
        max_std: float,
        var_scale: float,
        hidden_size: int,
        num_mixtures: int,
        depth: int,
    ) -> None:
        """Build the actor and the critic from two keys split off ``key``.

        ``hidden_size`` and ``depth`` are shared by both networks; the
        remaining keyword arguments are forwarded to the actor, except
        ``num_critic_inputs`` which sizes the critic's input.
        """
        # One PRNG key per sub-network.
        actor_key, critic_key = jax.random.split(key)

        # Settings common to both recurrent networks.
        shared_kwargs = dict(hidden_size=hidden_size, depth=depth)

        self.actor = Actor(
            actor_key,
            num_inputs=num_actor_inputs,
            num_outputs=num_actor_outputs,
            num_mixtures=num_mixtures,
            min_std=min_std,
            max_std=max_std,
            var_scale=var_scale,
            **shared_kwargs,
        )
        self.critic = Critic(
            critic_key,
            num_inputs=num_critic_inputs,
            **shared_kwargs,
        )

  num_inputs: int = eqx.static_field()
  num_outputs: int = eqx.static_field()
  num_mixtures: int = eqx.static_field()
  min_std: float = eqx.static_field()
  max_std: float = eqx.static_field()
  var_scale: float = eqx.static_field()
  num_inputs: int = eqx.static_field()


## Config

The [ksim framework](https://github.com/kscalelabs/ksim) is based on [xax](https://github.com/kscalelabs/xax), a JAX training library built by K-Scale. To provide configuration options, xax uses a Config dataclass to parse command-line options. We define the config here.

In [6]:
@dataclass
class HumanoidWalkingTaskConfig(ksim.PPOConfig):
    """Config for the humanoid walking task.

    Extends ``ksim.PPOConfig`` with the model, curriculum and optimizer
    options used by this notebook.
    """

    # Model parameters.
    hidden_size: int = xax.field(
        value=128,
        help="The hidden size for the actor and critic RNNs.",
    )
    depth: int = xax.field(
        value=5,
        help="The number of stacked GRU layers in the actor and critic.",
    )
    num_mixtures: int = xax.field(
        value=5,
        help="The number of mixtures for the actor.",
    )
    var_scale: float = xax.field(
        value=0.5,
        help="The scale for the standard deviations of the actor.",
    )
    use_acc_gyro: bool = xax.field(
        value=True,
        help="Whether the actor uses the IMU acceleration and gyroscope observations.",
    )

    # Curriculum parameters.
    # NOTE(review): none of these four fields is read anywhere in this notebook;
    # HumanoidWalkingTask.get_curriculum hard-codes min_level_steps=5. Confirm
    # whether they are consumed elsewhere or should be wired in.
    num_curriculum_levels: int = xax.field(
        value=100,
        help="The number of curriculum levels to use.",
    )
    increase_threshold: float = xax.field(
        value=5.0,
        help="Increase the curriculum level when the mean trajectory length is above this threshold.",
    )
    decrease_threshold: float = xax.field(
        value=1.0,
        help="Decrease the curriculum level when the mean trajectory length is below this threshold.",
    )
    min_level_steps: int = xax.field(
        value=1,
        help="The minimum number of steps to wait before changing the curriculum level.",
    )

    # Optimizer parameters.
    learning_rate: float = xax.field(
        value=3e-4,
        help="Learning rate for PPO.",
    )
    adam_weight_decay: float = xax.field(
        value=1e-5,
        help="Weight decay for the optimizer; 0.0 selects plain Adam, any other value selects AdamW.",
    )

## Task

The meat-and-potatoes of our training code is the task. This defines the observations, rewards, model calling logic, and everything else needed by `ksim` to train our reinforcement learning agent.

In [7]:
class HumanoidWalkingTask(ksim.PPOTask[HumanoidWalkingTaskConfig]):
    """PPO task that trains the walking policy for the "kbot" model.

    The ``get_*`` methods supply the simulation ingredients (model, actuators,
    randomizers, events, resets, observations, rewards, terminations,
    curriculum); the remaining methods define how the actor and critic from
    ``Model`` are called on those observations.
    """

    def get_optimizer(self) -> optax.GradientTransformation:
        """Return Adam when ``adam_weight_decay`` is exactly 0.0, otherwise AdamW.

        Both variants use ``config.learning_rate``.
        """
        return (
            optax.adam(self.config.learning_rate)
            if self.config.adam_weight_decay == 0.0
            else optax.adamw(self.config.learning_rate, weight_decay=self.config.adam_weight_decay)
        )

    def get_mujoco_model(self) -> mujoco.MjModel:
        """Resolve the "kbot" MJCF path through ksim and load it with the "smooth" scene."""
        # NOTE(review): asyncio.run is called from inside a notebook here; the
        # nest_asyncio.apply() call in the imports cell is presumably what makes
        # that work — confirm before removing it.
        mjcf_path = asyncio.run(ksim.get_mujoco_model_path("kbot", name="robot"))
        return mujoco_scenes.mjcf.load_mjmodel(mjcf_path, scene="smooth")

    def get_mujoco_model_metadata(self, mj_model: mujoco.MjModel) -> ksim.Metadata:
        """Fetch the "kbot" metadata through ksim.

        ``mj_model`` is not used by this implementation.

        Raises:
            ValueError: If the joint metadata or the actuator metadata is missing.
        """
        metadata = asyncio.run(ksim.get_mujoco_model_metadata("kbot"))
        if metadata.joint_name_to_metadata is None:
            raise ValueError("Joint metadata is not available")
        if metadata.actuator_type_to_metadata is None:
            raise ValueError("Actuator metadata is not available")
        return metadata

    def get_actuators(
        self,
        physics_model: ksim.PhysicsModel,
        metadata: ksim.Metadata | None = None,
    ) -> ksim.Actuators:
        """Return position actuators built from the physics model and metadata.

        ``metadata`` is optional in the signature but asserted to be present.
        """
        assert metadata is not None, "Metadata is required"
        return ksim.PositionActuators(
            physics_model=physics_model,
            metadata=metadata,
        )

    def get_physics_randomizers(self, physics_model: ksim.PhysicsModel) -> list[ksim.PhysicsRandomizer]:
        """Return the physics randomizers (friction, armature, mass, damping, joint zero position).

        ``physics_model`` is not used by this implementation.
        """
        return [
            ksim.StaticFrictionRandomizer(),
            ksim.ArmatureRandomizer(),
            ksim.AllBodiesMassMultiplicationRandomizer(scale_lower=0.95, scale_upper=1.05),
            ksim.JointDampingRandomizer(),
            # Bounds of -2 and +2 degrees, converted to radians.
            ksim.JointZeroPositionRandomizer(scale_lower=math.radians(-2), scale_upper=math.radians(2)),
        ]

    def get_events(self, physics_model: ksim.PhysicsModel) -> list[ksim.Event]:
        """Return the events applied during rollouts: a single push event.

        All angular force components are set to 0.0, so only the linear
        components are non-zero.
        """
        return [
            ksim.PushEvent(
                x_force=1.0,
                y_force=1.0,
                z_force=0.3,
                force_range=(0.5, 1.0),
                x_angular_force=0.0,
                y_angular_force=0.0,
                z_angular_force=0.0,
                interval_range=(0.5, 4.0),
            ),
        ]

    def get_resets(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reset]:
        """Return the resets: joint positions built from ``ZEROS`` (scale 0.1) and joint velocities."""
        return [
            ksim.RandomJointPositionReset.create(physics_model, {k: v for k, v in ZEROS}, scale=0.1),
            ksim.RandomJointVelocityReset(),
        ]

    def get_observations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Observation]:
        """Return every observation computed for the task.

        ``run_actor`` and ``run_critic`` read only a subset of these by key.
        The base linear/angular velocity and acceleration observations and the
        actuator acceleration observation are not read by either of them in
        this notebook (NOTE(review): they may still be consumed inside ksim —
        confirm before removing).
        """
        return [
            ksim.TimestepObservation(),
            # Noise arguments are given in degrees and converted to radians.
            ksim.JointPositionObservation(noise=math.radians(2)),
            ksim.JointVelocityObservation(noise=math.radians(10)),
            ksim.ActuatorForceObservation(),
            ksim.CenterOfMassInertiaObservation(),
            ksim.CenterOfMassVelocityObservation(),
            ksim.BasePositionObservation(),
            ksim.BaseOrientationObservation(),
            ksim.BaseLinearVelocityObservation(),
            ksim.BaseAngularVelocityObservation(),
            ksim.BaseLinearAccelerationObservation(),
            ksim.BaseAngularAccelerationObservation(),
            ksim.ActuatorAccelerationObservation(),
            ksim.ProjectedGravityObservation.create(
                physics_model=physics_model,
                framequat_name="imu_site_quat",
                lag_range=(0.0, 0.1),
                noise=math.radians(1),
            ),
            # The two sensor observations are read below under the keys
            # "sensor_observation_imu_acc" and "sensor_observation_imu_gyro".
            ksim.SensorObservation.create(
                physics_model=physics_model,
                sensor_name="imu_acc",
                noise=1.0,
            ),
            ksim.SensorObservation.create(
                physics_model=physics_model,
                sensor_name="imu_gyro",
                noise=math.radians(10),
            ),
        ]

    def get_commands(self, physics_model: ksim.PhysicsModel) -> list[ksim.Command]:
        """Return no commands; this task does not condition the policy on any command."""
        return []

    def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]:
        """Return the reward terms: positively scaled rewards plus negatively scaled penalties.

        The last two entries are the custom penalties defined earlier in this
        notebook.
        """
        return [
            # Standard rewards.
            ksim.NaiveForwardReward(clip_max=1.25, in_robot_frame=False, scale=3.0),
            ksim.NaiveForwardOrientationReward(scale=1.0),
            ksim.StayAliveReward(scale=1.0),
            ksim.UprightReward(scale=0.5),
            # Avoid movement penalties.
            ksim.AngularVelocityPenalty(index=("x", "y"), scale=-0.1),
            # Note: ("z") is the plain string "z", not a one-element tuple.
            ksim.LinearVelocityPenalty(index=("z"), scale=-0.1),
            # Normalization penalties.
            ksim.AvoidLimitsPenalty.create(physics_model, scale=-0.01),
            ksim.JointAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),
            ksim.JointJerkPenalty(scale=-0.01, scale_by_curriculum=True),
            ksim.LinkAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),
            ksim.LinkJerkPenalty(scale=-0.01, scale_by_curriculum=True),
            ksim.ActionAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),
            # Bespoke rewards.
            BentArmPenalty.create_penalty(physics_model, scale=-0.1),
            StraightLegPenalty.create_penalty(physics_model, scale=-0.1),
        ]

    def get_terminations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Termination]:
        """Return the termination conditions: a z-height check and a distance-from-origin check."""
        return [
            ksim.BadZTermination(unhealthy_z_lower=0.6, unhealthy_z_upper=1.2),
            ksim.FarFromOriginTermination(max_dist=10.0),
        ]

    def get_curriculum(self, physics_model: ksim.PhysicsModel) -> ksim.Curriculum:
        """Return a distance-from-origin curriculum with ``min_level_steps=5``."""
        # NOTE(review): the config declares min_level_steps (default 1) and other
        # curriculum fields, but none of them is read here — confirm whether the
        # hard-coded 5 is intentional.
        return ksim.DistanceFromOriginCurriculum(
            min_level_steps=5,
        )

    def get_model(self, key: PRNGKeyArray) -> Model:
        """Build the actor-critic model.

        The input sizes are hard-coded and must equal the lengths of the
        vectors concatenated in ``run_actor`` and ``run_critic``. For the
        actor, 45 is consistent with sin/cos of a scalar timestep (2), joint
        positions and velocities (2 x 20) and projected gravity (3); 51 adds
        the two 3-element IMU readings.
        """
        return Model(
            key,
            num_actor_inputs=51 if self.config.use_acc_gyro else 45,
            num_actor_outputs=len(ZEROS),
            num_critic_inputs=446,
            min_std=0.001,
            max_std=1.0,
            var_scale=self.config.var_scale,
            hidden_size=self.config.hidden_size,
            num_mixtures=self.config.num_mixtures,
            depth=self.config.depth,
        )

    def run_actor(
        self,
        model: Actor,
        observations: xax.FrozenDict[str, Array],
        commands: xax.FrozenDict[str, Array],
        carry: Array,
    ) -> tuple[distrax.Distribution, Array]:
        """Assemble the actor's observation vector and run one actor step.

        The IMU accelerometer and gyroscope readings are appended only when
        ``config.use_acc_gyro`` is set. ``commands`` is not used.

        Returns:
            The action distribution and the actor's next carry.
        """
        time_1 = observations["timestep_observation"]
        joint_pos_n = observations["joint_position_observation"]
        joint_vel_n = observations["joint_velocity_observation"]
        proj_grav_3 = observations["projected_gravity_observation"]
        imu_acc_3 = observations["sensor_observation_imu_acc"]
        imu_gyro_3 = observations["sensor_observation_imu_gyro"]

        obs = [
            jnp.sin(time_1),
            jnp.cos(time_1),
            joint_pos_n,  # NUM_JOINTS
            joint_vel_n,  # NUM_JOINTS
            proj_grav_3,  # 3
        ]
        if self.config.use_acc_gyro:
            obs += [
                imu_acc_3,  # 3
                imu_gyro_3,  # 3
            ]

        obs_n = jnp.concatenate(obs, axis=-1)
        action, carry = model.forward(obs_n, carry)

        return action, carry

    def run_critic(
        self,
        model: Critic,
        observations: xax.FrozenDict[str, Array],
        commands: xax.FrozenDict[str, Array],
        carry: Array,
    ) -> tuple[Array, Array]:
        """Assemble the critic's observation vector and run one critic step.

        Unlike the actor, the critic always receives the IMU readings, and it
        also sees center-of-mass, actuator-force and base-pose observations.
        Joint velocities are divided by 10 and actuator forces by 100 before
        concatenation. ``commands`` is not used.

        Returns:
            The critic output and the critic's next carry.
        """
        time_1 = observations["timestep_observation"]
        dh_joint_pos_j = observations["joint_position_observation"]
        dh_joint_vel_j = observations["joint_velocity_observation"]
        com_inertia_n = observations["center_of_mass_inertia_observation"]
        com_vel_n = observations["center_of_mass_velocity_observation"]
        imu_acc_3 = observations["sensor_observation_imu_acc"]
        imu_gyro_3 = observations["sensor_observation_imu_gyro"]
        proj_grav_3 = observations["projected_gravity_observation"]
        act_frc_obs_n = observations["actuator_force_observation"]
        base_pos_3 = observations["base_position_observation"]
        base_quat_4 = observations["base_orientation_observation"]

        # NOTE(review): with NUM_JOINTS = 20 the widths annotated below sum to 334,
        # while get_model sets num_critic_inputs=446. The 160 / 96 figures are
        # presumably stale; for the sizes to agree the two center-of-mass
        # observations would need to total 368 elements — confirm against the model.
        obs_n = jnp.concatenate(
            [
                jnp.sin(time_1),
                jnp.cos(time_1),
                dh_joint_pos_j,  # NUM_JOINTS
                dh_joint_vel_j / 10.0,  # NUM_JOINTS
                com_inertia_n,  # annotated as 160; see NOTE(review) above
                com_vel_n,  # annotated as 96; see NOTE(review) above
                imu_acc_3,  # 3
                imu_gyro_3,  # 3
                proj_grav_3,  # 3
                act_frc_obs_n / 100.0,  # NUM_JOINTS
                base_pos_3,  # 3
                base_quat_4,  # 4
            ],
            axis=-1,
        )

        return model.forward(obs_n, carry)

    def _model_scan_fn(
        self,
        actor_critic_carry: tuple[Array, Array],
        xs: tuple[ksim.Trajectory, PRNGKeyArray],
        model: Model,
    ) -> tuple[tuple[Array, Array], ksim.PPOVariables]:
        """Scan body: evaluate the actor and critic on a single transition.

        Args:
            actor_critic_carry: The (actor, critic) carries entering this step.
            xs: The transition for this step and its PRNG key.
            model: The actor-critic model.

        Returns:
            The carries for the next step and the PPO variables (action
            log-probabilities and value) for this transition.
        """
        transition, rng = xs

        actor_carry, critic_carry = actor_critic_carry
        actor_dist, next_actor_carry = self.run_actor(
            model=model.actor,
            observations=transition.obs,
            commands=transition.command,
            carry=actor_carry,
        )

        # Gets the log probabilities of the action.
        log_probs = actor_dist.log_prob(transition.action)
        assert isinstance(log_probs, Array)

        value, next_critic_carry = self.run_critic(
            model=model.critic,
            observations=transition.obs,
            commands=transition.command,
            carry=critic_carry,
        )

        transition_ppo_variables = ksim.PPOVariables(
            log_probs=log_probs,
            # Drop the trailing length-1 axis produced by the critic's output layer.
            values=value.squeeze(-1),
        )

        # Where the transition is marked done, restart from the initial carry;
        # otherwise keep the carries just produced by the networks.
        next_carry = jax.tree.map(
            lambda x, y: jnp.where(transition.done, x, y),
            self.get_initial_model_carry(rng),
            (next_actor_carry, next_critic_carry),
        )

        return next_carry, transition_ppo_variables

    def get_ppo_variables(
        self,
        model: Model,
        trajectory: ksim.Trajectory,
        model_carry: tuple[Array, Array],
        rng: PRNGKeyArray,
    ) -> tuple[ksim.PPOVariables, tuple[Array, Array]]:
        """Scan ``_model_scan_fn`` over the trajectory.

        ``rng`` is split into one key per step (``len(trajectory.done)`` keys).

        Returns:
            The per-step PPO variables and the carry after the last step.
        """
        scan_fn = functools.partial(self._model_scan_fn, model=model)
        next_model_carry, ppo_variables = xax.scan(
            scan_fn,
            model_carry,
            (trajectory, jax.random.split(rng, len(trajectory.done))),
            jit_level=4,
        )
        return ppo_variables, next_model_carry

    def get_initial_model_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]:
        """Return all-zero (actor, critic) carries, each of shape (depth, hidden_size).

        ``rng`` is not used.
        """
        return (
            jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),
            jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),
        )

    def sample_action(
        self,
        model: Model,
        model_carry: tuple[Array, Array],
        physics_model: ksim.PhysicsModel,
        physics_state: ksim.PhysicsState,
        observations: xax.FrozenDict[str, Array],
        commands: xax.FrozenDict[str, Array],
        rng: PRNGKeyArray,
        argmax: bool,
    ) -> ksim.Action:
        """Run the actor and pick an action.

        Takes the distribution's mode when ``argmax`` is true, otherwise
        samples with ``rng``. Only the actor is run, so the critic carry is
        returned unchanged. ``physics_model`` and ``physics_state`` are not used.
        """
        actor_carry_in, critic_carry_in = model_carry
        action_dist_j, actor_carry = self.run_actor(
            model=model.actor,
            observations=observations,
            commands=commands,
            carry=actor_carry_in,
        )
        action_j = action_dist_j.mode() if argmax else action_dist_j.sample(seed=rng)
        return ksim.Action(action=action_j, carry=(actor_carry, critic_carry_in))

## Launch TensorBoard

The cell below launches TensorBoard to visualize the training progress.

After launching an experiment, please wait about 5 minutes for the task to start running, then click the reload button in the top-right corner of the TensorBoard page. You can also open the settings and check the "Reload data" option to have TensorBoard reload automatically.

In [8]:
# Launch TensorBoard
# The log directory is the task directory named in the "Launching an Experiment"
# section (humanoid_walking_task), where runs are written as run_[x] subdirectories.
%load_ext tensorboard
%tensorboard --logdir humanoid_walking_task

## Launching an Experiment

To launch an experiment with `xax`, you can use `Task.launch(config)`. Note that this is usually intended to be called from the command-line, so it will by default attempt to parse additional command-line arguments unless `use_cli=False` is set.

By default, runs will be logged to a directory called `run_[x]` in the task directory (/content/humanoid_walking_task/ in Colab). From there, you can download the ckpt `.bin` files and the TensorBoard logs.

Also note that since this is a Jupyter notebook, the task will be unable to find the task training code and emit a warning like "Could not resolve task path for <TASK_NAME>, returning current working directory". You can safely ignore this warning.

In [9]:
# Launch training with an explicit config. use_cli=False stops xax from trying
# to parse additional command-line arguments, as described in the text above.
if __name__ == "__main__":
    HumanoidWalkingTask.launch(
        HumanoidWalkingTaskConfig(
            # Training parameters.
            num_envs=2048,
            batch_size=256,
            num_passes=4,
            epochs_per_log_step=1,
            rollout_length_seconds=8.0,
            global_grad_clip=2.0,
            # Simulation parameters.
            # ctrl_dt / dt = 0.02 / 0.002 = 10 physics steps per control step.
            dt=0.002,
            ctrl_dt=0.02,
            iterations=8,
            ls_iterations=8,
            action_latency_range=(0.003, 0.01),  # Simulate 3-10ms of latency.
            drop_action_prob=0.05,  # Drop 5% of commands.
            # Visualization parameters
            render_track_body_id=0,
            # Checkpointing parameters.
            save_every_n_seconds=60,
        ),
        use_cli=False,
    )



  import pkg_resources


  [1;36mINFO[0m  [90m2025-09-30 19:30:40[0m [[1;34mxax.task.mixins.compile[0m] Setting JAX logging level to INFO
  [1;36mINFO[0m  [90m2025-09-30 19:30:40[0m [[1;34mxax.task.mixins.compile[0m] Setting JAX compilation cache directory to /home/pathofseb/.cache/jax/jaxcache
  [1;36mINFO[0m  [90m2025-09-30 19:30:40[0m [[1;34mxax.task.mixins.compile[0m] Configuring JAX compilation cache parameters


INFO:2025-09-30 19:30:40,803:jax._src.xla_bridge:822: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


  [1;36mINFO[0m  [90m2025-09-30 19:30:40[0m [[1;34mjax._src.xla_bridge[0m] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
 [1;32mSTATUS[0m [90m2025-09-30 19:30:40[0m [[1;34mxax.task.mixins.artifacts[0m] /home/pathofseb/coding/ksim-gym-example/humanoid_walking_task/run_2
 [1;32mSTATUS[0m [90m2025-09-30 19:30:40[0m [[1;34mxax.task.mixins.train[0m] /home/pathofseb/coding/ksim-gym-example
 [1;32mSTATUS[0m [90m2025-09-30 19:30:40[0m [[1;34mxax.task.mixins.train[0m] humanoid_walking_task
 [1;32mSTATUS[0m [90m2025-09-30 19:30:40[0m [[1;34mxax.task.mixins.train[0m] JAX devices: [CpuDevice(id=0)]
  [1;36mINFO[0m  [90m2025-09-30 19:30:41[0m [[1;34mxax.task.mixins.train[0m] Starting a new training run
  [1;35mPING[0m  [90m2025-09-30 19:30:42[0m [[1;34mksim.task.rl[0m] Model size: 1,090,861 parameters
  [1;35mPING[0m  [90m2025-09-30 19:30:42[0m [[1;34mksim.ta

ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'ksim.engine.MjxEngine'>, MjxEngine(
  actuators=<ksim.actuators.PositionActuators object at 0x7e71144e9580>,
  resets=[
    RandomJointPositionReset(scale=0.1, zeros=(0.0, -0.17453292519943295, 0.0, 1.5707963267948966, 0.0, 0.0, 0.17453292519943295, 0.0, -1.5707963267948966, 0.0, -0.3490658503988659, -0.0, 0.0, -0.8726646259971648, 0.5235987755982988, 0.3490658503988659, 0.0, 0.0, 0.8726646259971648, -0.5235987755982988)),
    RandomJointVelocityReset(scale=0.01)
  ],
  events=[
    PushEvent(x_force=1.0, y_force=1.0, z_force=0.3, force_range=(0.5, 1.0), x_angular_force=0.0, y_angular_force=0.0, z_angular_force=0.0, interval_range=(0.5, 4.0))
  ],
  phys_steps_per_ctrl_steps=10,
  min_action_latency_step=1.4999999287538264,
  max_action_latency_step=4.999999762512755,
  drop_action_prob=0.05
). The error was:
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 211, in start
  File "/usr/lib/python3.12/asyncio/base_events.py", line 641, in run_forever
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/nest_asyncio.py", line 133, in _run_once
  File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 519, in dispatch_queue
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 508, in process_one
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 400, in dispatch_shell
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 368, in execute_request
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 767, in execute_request
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 455, in do_execute
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 577, in run_cell
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3116, in run_cell
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3171, in _run_cell
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3394, in run_cell_async
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3639, in run_ast_nodes
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3699, in run_code
  File "/tmp/ipykernel_81197/682409069.py", line 2, in <module>
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/xax/task/mixins/runnable.py", line 51, in launch
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/xax/task/launchers/cli.py", line 40, in launch
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/xax/task/launchers/single_process.py", line 30, in launch
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/xax/task/launchers/single_process.py", line 20, in run_single_process_training
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ksim/task/rl.py", line 1009, in run
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ksim/task/rl.py", line 2042, in run_training
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ksim/task/rl.py", line 1990, in initialize_rl_training
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ksim/task/rl.py", line 1801, in _get_env_state
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/api.py", line 1112, in vmap_f
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 212, in call_wrapped
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 644, in _batch_outer
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 660, in _batch_inner
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 342, in flatten_fun_for_vmap
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 397, in _get_result_paths_thunk
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/ksim/task/rl.py", line 322, in apply_randomizations
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/xax/utils/jax.py", line 139, in wrapped
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 270, in cache_miss
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 139, in _python_pjit_helper
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 643, in _infer_params
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 664, in _infer_params_internal
  File "/home/pathofseb/coding/ksim-gym-example/.venv/lib/python3.12/site-packages/equinox/_module/_module.py", line 516, in __hash__
TypeError: unhashable type: 'list'
