<a href="https://colab.research.google.com/github/ttktjmt/mjlab/blob/main/notebooks/create_new_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **🤖 CartPole Tutorial with MJLab**

This notebook demonstrates how to create a custom reinforcement learning task using MJLab. We'll build a CartPole environment from scratch, including:

1. **Robot Definition** - Define the CartPole model in MuJoCo XML
2. **Task Configuration** - Set up observations, actions, rewards, and terminations
3. **Training** - Train a policy using PPO
4. **Evaluation** - Visualize and record the trained policy

> **Note**: This tutorial is based on the official MJLab documentation.

## **📦 Setup and Installation**

In [1]:
# Clone the mjlab repository
!if [ ! -d "mjlab" ]; then git clone -q https://github.com/mujocolab/mjlab.git; fi
%cd /content/mjlab

# Install mjlab in editable mode
!uv pip install --system -e . -q

print("✓ Installation complete!")

/content/mjlab
✓ Installation complete!


### **🔑 WandB Setup (Optional)**

Configure Weights & Biases for experiment tracking. Add the following entries to Colab Secrets:
- `WANDB_API_KEY`: from [wandb.ai/authorize](https://wandb.ai/authorize)
- `WANDB_ENTITY`: your WandB entity (user or team) name

In [24]:
import os
from google.colab import userdata

try:
    # Uncomment to disable the WandB logger entirely
    # os.environ['WANDB_MODE'] = 'disabled'

    # Read the WandB credentials from Colab Secrets
    os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')
    os.environ['WANDB_ENTITY'] = userdata.get('WANDB_ENTITY')

    print("✓ WandB configured successfully!")
except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
    # userdata.get raises these when a secret is missing or access was not granted
    os.environ['WANDB_MODE'] = 'disabled'
    print("⚠ WandB secrets not found. Training will proceed without WandB logging.")

✓ WandB configured successfully!


---

## **🤖 Step 1: Define the Robot**

We'll create a simple CartPole robot with:
- A sliding cart (1 DOF)
- A hinged pole (1 DOF)
- A velocity actuator to control the cart

### **📁 Structure Directories**

In [3]:
# Create the cartpole robot directory structure
!mkdir -p /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/
!mkdir -p /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls

print("✓ Directory structure created")

✓ Directory structure created


### **📝 Create MuJoCo XML Model**

This XML defines the CartPole physics:
- **Ground plane** for visualization
- **Cart body** with a sliding joint (±2 m range)
- **Pole body** with a hinge joint (±90° range)
- **Velocity actuator** for cart control

In [4]:
%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls/cartpole.xml
<mujoco model="cartpole">
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <worldbody>
    <geom name="ground" type="plane" pos="0 0 0" size="5 5 0.1" rgba="0.8 0.9 0.8 1"/>
    <body name="cart" pos="0 0 0.1">
      <geom type="box" size="0.2 0.1 0.1" rgba="0.2 0.2 0.8 1" mass="1.0"/>
      <joint name="slide" type="slide" axis="1 0 0" limited="true" range="-2 2"/>
      <body name="pole" pos="0 0 0.1">
        <geom type="capsule" size="0.05 0.5" fromto="0 0 0 0 0 1" rgba="0.8 0.2 0.2 1" mass="2.0"/>
        <joint name="hinge" type="hinge" axis="0 1 0" range="-90 90"/>
      </body>
    </body>
  </worldbody>
  <actuator>
    <velocity name="slide_velocity" joint="slide" ctrlrange="-20 20" kv="20"/>
  </actuator>
</mujoco>

Writing /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls/cartpole.xml
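
The joint limits described above can be sanity-checked without MuJoCo installed. The sketch below parses a trimmed copy of the model (geoms removed) with the standard library and converts the hinge range to radians, on the understanding that `angle="degree"` applies only to how the XML is written and that the compiled model works in radians. The helper names `summarize_joints` and `actuated_joints` are illustrative and not part of mjlab.

```python
import math
import xml.etree.ElementTree as ET

# Trimmed copy of cartpole.xml: only the joints and the actuator are kept.
CARTPOLE_XML = """
<mujoco model="cartpole">
  <compiler angle="degree"/>
  <worldbody>
    <body name="cart" pos="0 0 0.1">
      <joint name="slide" type="slide" axis="1 0 0" limited="true" range="-2 2"/>
      <body name="pole" pos="0 0 0.1">
        <joint name="hinge" type="hinge" axis="0 1 0" range="-90 90"/>
      </body>
    </body>
  </worldbody>
  <actuator>
    <velocity name="slide_velocity" joint="slide" ctrlrange="-20 20" kv="20"/>
  </actuator>
</mujoco>
"""


def summarize_joints(xml_text):
    """Return {joint_name: (type, low, high)} with hinge limits in radians."""
    root = ET.fromstring(xml_text)
    in_degrees = root.find("compiler").get("angle", "degree") == "degree"
    joints = {}
    for joint in root.iter("joint"):
        low, high = (float(v) for v in joint.get("range").split())
        # Slide ranges are lengths (meters); only hinge ranges are angles.
        if joint.get("type") == "hinge" and in_degrees:
            low, high = math.radians(low), math.radians(high)
        joints[joint.get("name")] = (joint.get("type"), low, high)
    return joints


def actuated_joints(xml_text):
    """Return the joint names driven by the actuators in the model."""
    root = ET.fromstring(xml_text)
    return [actuator.get("joint") for actuator in root.find("actuator")]


joints = summarize_joints(CARTPOLE_XML)
print(joints)
print(actuated_joints(CARTPOLE_XML))
```

Only the cart's `slide` joint is actuated; the pole's `hinge` joint swings freely, which is what makes balancing a control problem.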


### **⚙️ Create Robot Configuration**

In [5]:
%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/cartpole_constants.py
from pathlib import Path
import mujoco

from mjlab import MJLAB_SRC_PATH
from mjlab.entity import Entity, EntityCfg, EntityArticulationInfoCfg
from mjlab.actuator import XmlVelocityActuatorCfg

CARTPOLE_XML: Path = (
  MJLAB_SRC_PATH / "asset_zoo" / "robots" / "cartpole" / "xmls" / "cartpole.xml"
)
assert CARTPOLE_XML.exists(), f"XML not found: {CARTPOLE_XML}"

def get_spec() -> mujoco.MjSpec:
  return mujoco.MjSpec.from_file(str(CARTPOLE_XML))

def get_cartpole_robot_cfg() -> EntityCfg:
  """Get a fresh CartPole robot configuration instance."""
  actuators = (
    XmlVelocityActuatorCfg(
      joint_names_expr=("slide",),
    ),
  )
  articulation = EntityArticulationInfoCfg(actuators=actuators)
  return EntityCfg(
    spec_fn=get_spec,
    articulation=articulation
  )

# if __name__ == "__main__":
#   import mujoco.viewer as viewer
#   robot = Entity(get_cartpole_robot_cfg())
#   viewer.launch(robot.spec.compile())

Writing /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/cartpole_constants.py


In [6]:
%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/__init__.py
# Empty __init__.py that marks the cartpole directory as a Python package

Writing /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/__init__.py


In [7]:
import sys

# Append src dir to python path
mjlab_src = '/content/mjlab/src'
if mjlab_src not in sys.path:
    sys.path.insert(0, mjlab_src)
    print(f"✓ Added {mjlab_src} to Python path")

✓ Added /content/mjlab/src to Python path


### **✅ Verify Robot Setup**

Let's test that the robot can be loaded correctly.

In [8]:
from mjlab.entity import Entity
from mjlab.asset_zoo.robots.cartpole.cartpole_constants import get_cartpole_robot_cfg

# Load the robot
robot = Entity(get_cartpole_robot_cfg())
model = robot.spec.compile()

# Display robot information
print("✓ CartPole robot loaded successfully!")
print(f"  • Degrees of Freedom (DOF): {model.nv}")
print(f"  • Number of Actuators: {model.nu}")
print(f"  • Bodies: {model.nbody}")
print(f"  • Joints: {model.njnt}")

✓ CartPole robot loaded successfully!
  • Degrees of Freedom (DOF): 2
  • Number of Actuators: 1
  • Bodies: 4
  • Joints: 2


### **📋 Register the Robot**

Add the CartPole robot to the asset zoo registry.

In [9]:
# Add CartPole import to robots __init__.py (idempotent: safe to re-run this cell)
init_path = '/content/mjlab/src/mjlab/asset_zoo/robots/__init__.py'
import_line = (
    'from mjlab.asset_zoo.robots.cartpole.cartpole_constants import '
    'get_cartpole_robot_cfg as get_cartpole_robot_cfg\n'
)
with open(init_path, 'r+') as f:
    if import_line not in f.read():  # read() leaves the position at end of file
        f.write('\n# CartPole robot\n' + import_line)

print("✓ CartPole robot registered in asset zoo")

✓ CartPole robot registered in asset zoo


---

## **🎯 Step 2: Define the Task (MDP)**

Now we'll define the Markov Decision Process:
- **Observations**: pole angle, angular velocity, cart position, cart velocity
- **Actions**: cart velocity commands
- **Rewards**: upright reward + effort penalty
- **Terminations**: pole tips over or timeout
- **Events**: random pushes for robustness

### **📁 Create Task Directory**

In [10]:
!mkdir -p /content/mjlab/src/mjlab/tasks/cartpole

print("✓ Task directory created")

✓ Task directory created


### **📝 Create Environment Configuration**

This file contains the MDP (Markov Decision Process) components:
1. **Scene Config**: 64 parallel environments
2. **Actions**: Joint velocity control with 20.0 scale
3. **Observations**: Normalized state variables
4. **Rewards**: Upright reward (5.0) + effort penalty (-0.01)
5. **Events**: Joint resets + random pushes
6. **Terminations**: Pole tipped (>30°) or timeout (10s)
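
To make the reward and termination rules concrete, here is a minimal scalar sketch of what they compute for a single environment at a single step. It uses plain Python instead of batched tensors, the function names are illustrative, and it shows the weighted terms before any further scaling the reward manager may apply.

```python
import math

UPRIGHT_WEIGHT = 5.0              # weight of the upright term
EFFORT_COEF = 0.01                # coefficient of the squared-control penalty
TIP_LIMIT_RAD = math.radians(30)  # termination threshold on the pole angle


def step_reward(pole_angle, ctrl):
    """Weighted upright reward plus effort penalty for one step."""
    upright = UPRIGHT_WEIGHT * math.cos(pole_angle)  # maximal when the pole is vertical
    effort = -EFFORT_COEF * ctrl ** 2                # grows with the commanded velocity
    return upright + effort


def is_tipped(pole_angle):
    """True once the pole leans more than 30 degrees to either side."""
    return abs(pole_angle) > TIP_LIMIT_RAD


print(step_reward(0.0, 0.0))
print(is_tipped(math.radians(29)), is_tipped(math.radians(-31)))
```

Because the episode ends at 30°, the upright term never drops below 5·cos(30°) ≈ 4.33 during an episode, so most of the learning signal comes from surviving longer, not from the cosine shape.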

In [11]:
%%writefile /content/mjlab/src/mjlab/tasks/cartpole/env_cfg.py
"""CartPole task environment configuration."""

import math
import torch

from mjlab.envs import ManagerBasedRlEnvCfg
from mjlab.envs.mdp.actions import JointVelocityActionCfg
from mjlab.managers.manager_term_config import (
  ObservationGroupCfg,
  ObservationTermCfg,
  RewardTermCfg,
  TerminationTermCfg,
  EventTermCfg,
)
from mjlab.managers.scene_entity_config import SceneEntityCfg
from mjlab.scene import SceneCfg
from mjlab.sim import MujocoCfg, SimulationCfg
from mjlab.viewer import ViewerConfig
from mjlab.asset_zoo.robots.cartpole.cartpole_constants import get_cartpole_robot_cfg
from mjlab.envs import mdp


def cartpole_env_cfg(play: bool = False) -> ManagerBasedRlEnvCfg:
  """Create CartPole environment configuration.

  Args:
    play: If True, disables corruption and extends episode length for evaluation.
  """

  # ==============================================================================
  # Scene Configuration
  # ==============================================================================

  scene_cfg = SceneCfg(
    num_envs=64 if not play else 16,  # Fewer envs for play mode
    extent=1.0,   # Spacing between environments
    entities={"robot": get_cartpole_robot_cfg()},
  )

  viewer_cfg = ViewerConfig(
    origin_type=ViewerConfig.OriginType.ASSET_BODY,
    asset_name="robot",
    body_name="pole",
    distance=3.0,
    elevation=10.0,
    azimuth=90.0,
  )

  sim_cfg = SimulationCfg(
    mujoco=MujocoCfg(
      timestep=0.02,  # 50 Hz control
      iterations=1,
    ),
  )

  # ==============================================================================
  # Actions
  # ==============================================================================

  actions = {
    "joint_pos": JointVelocityActionCfg(
      asset_name="robot",
      actuator_names=(".*",),
      scale=20.0,
      use_default_offset=False,
    ),
  }

  # ==============================================================================
  # Observations
  # ==============================================================================

  policy_terms = {
    "angle": ObservationTermCfg(
      func=lambda env: env.sim.data.qpos[:, 1:2] / math.pi
    ),
    "ang_vel": ObservationTermCfg(
      func=lambda env: env.sim.data.qvel[:, 1:2] / 5.0
    ),
    "cart_pos": ObservationTermCfg(
      func=lambda env: env.sim.data.qpos[:, 0:1] / 2.0
    ),
    "cart_vel": ObservationTermCfg(
      func=lambda env: env.sim.data.qvel[:, 0:1] / 20.0
    ),
  }

  observations = {
    "policy": ObservationGroupCfg(
      terms=policy_terms,
      concatenate_terms=True,
      enable_corruption=not play,  # Disable corruption in play mode
    ),
    "critic": ObservationGroupCfg(
      terms=policy_terms,  # Critic uses same observations
      concatenate_terms=True,
      enable_corruption=False,
    ),
  }

  # ==============================================================================
  # Rewards
  # ==============================================================================

  def compute_upright_reward(env):
    """Reward for keeping pole upright (cosine of angle)."""
    return env.sim.data.qpos[:, 1].cos()

  def compute_effort_penalty(env):
    """Penalty for control effort."""
    return -0.01 * (env.sim.data.ctrl[:, 0] ** 2)

  rewards = {
    "upright": RewardTermCfg(func=compute_upright_reward, weight=5.0),
    "effort": RewardTermCfg(func=compute_effort_penalty, weight=1.0),
  }

  # ==============================================================================
  # Events
  # ==============================================================================

  def random_push_cart(env, env_ids, force_range=(-5, 5)):
    """Apply random force to cart for robustness training."""
    n = len(env_ids)
    random_forces = (
      torch.rand(n, device=env.device) *
      (force_range[1] - force_range[0]) +
      force_range[0]
    )
    env.sim.data.qfrc_applied[env_ids, 0] = random_forces

  events = {
    "reset_robot_joints": EventTermCfg(
      func=mdp.reset_joints_by_offset,
      mode="reset",
      params={
        "asset_cfg": SceneEntityCfg("robot"),
        "position_range": (-0.1, 0.1),
        "velocity_range": (-0.1, 0.1),
      },
    ),
  }

  # Add random pushes only in training mode
  if not play:
    events["random_push"] = EventTermCfg(
      func=random_push_cart,
      mode="interval",
      interval_range_s=(1.0, 2.0),
      params={"force_range": (-20.0, 20.0)},
    )

  # ==============================================================================
  # Terminations
  # ==============================================================================

  def check_pole_tipped(env):
    """Check if pole has tipped beyond 30 degrees."""
    return env.sim.data.qpos[:, 1].abs() > math.radians(30)

  terminations = {
    "timeout": TerminationTermCfg(func=mdp.time_out, time_out=True),
    "tipped": TerminationTermCfg(func=check_pole_tipped, time_out=False),
  }

  # ==============================================================================
  # Environment Configuration
  # ==============================================================================

  return ManagerBasedRlEnvCfg(
    scene=scene_cfg,
    observations=observations,
    actions=actions,
    rewards=rewards,
    events=events,
    terminations=terminations,
    sim=sim_cfg,
    viewer=viewer_cfg,
    decimation=1,           # No action repeat
    episode_length_s=int(1e9) if play else 10.0,  # Infinite for play, 10s for training
  )

Writing /content/mjlab/src/mjlab/tasks/cartpole/env_cfg.py
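
The observation terms and the push event in this file reduce to simple arithmetic. The sketch below reproduces both for a single environment in plain Python. The function names are illustrative, and the output order assumes that terms are concatenated in the order they are declared.

```python
import math
import random


def normalize_observation(cart_pos, pole_angle, cart_vel, pole_ang_vel):
    """Scale raw state to roughly [-1, 1], in the order angle, ang_vel, cart_pos, cart_vel."""
    return [
        pole_angle / math.pi,  # hinge angle in radians
        pole_ang_vel / 5.0,    # hinge angular velocity
        cart_pos / 2.0,        # slide range is +/-2 m
        cart_vel / 20.0,       # matches the +/-20 control range
    ]


def sample_push(force_range=(-20.0, 20.0), rng=random):
    """Draw one force uniformly from [low, high), like the torch.rand expression."""
    low, high = force_range
    return rng.random() * (high - low) + low


print(normalize_observation(cart_pos=1.0, pole_angle=0.0, cart_vel=-10.0, pole_ang_vel=2.5))
print(sample_push())
```

Dividing by fixed constants keeps the inputs in a comparable range from the first iteration, before the runner's own observation normalization has gathered statistics.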


### **⚙️ Create RL Configuration**

This file defines the PPO (Proximal Policy Optimization) training parameters.

In [12]:
%%writefile /content/mjlab/src/mjlab/tasks/cartpole/rl_cfg.py
"""RL configuration for CartPole task."""

from mjlab.rl.config import (
  RslRlOnPolicyRunnerCfg,
  RslRlPpoActorCriticCfg,
  RslRlPpoAlgorithmCfg,
)


def cartpole_ppo_runner_cfg() -> RslRlOnPolicyRunnerCfg:
  """Create RL runner configuration for CartPole task."""
  return RslRlOnPolicyRunnerCfg(
    policy=RslRlPpoActorCriticCfg(
      init_noise_std=1.0,
      actor_obs_normalization=True,
      critic_obs_normalization=True,
      actor_hidden_dims=(256, 128, 64),  # Smaller network for simpler task
      critic_hidden_dims=(256, 128, 64),
      activation="elu",
    ),
    algorithm=RslRlPpoAlgorithmCfg(
      value_loss_coef=1.0,
      use_clipped_value_loss=True,
      clip_param=0.2,
      entropy_coef=0.01,
      num_learning_epochs=5,
      num_mini_batches=4,
      learning_rate=1.0e-3,
      schedule="adaptive",
      gamma=0.99,
      lam=0.95,
      desired_kl=0.01,
      max_grad_norm=1.0,
    ),
    experiment_name="cartpole",
    save_interval=50,
    num_steps_per_env=24,
    max_iterations=5_000,  # Fewer iterations for simpler task
  )

Writing /content/mjlab/src/mjlab/tasks/cartpole/rl_cfg.py
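
A few derived quantities help when tuning these parameters. The sketch below assumes that each rollout is split evenly into mini-batches and that every epoch visits every mini-batch once, which is the usual PPO arrangement. It does not call the runner itself, and the variable names are illustrative.

```python
NUM_ENVS = 64             # parallel environments in training mode
NUM_STEPS_PER_ENV = 24    # num_steps_per_env
NUM_MINI_BATCHES = 4      # num_mini_batches
NUM_LEARNING_EPOCHS = 5   # num_learning_epochs
GAMMA = 0.99              # discount factor
SIM_DT = 0.02             # seconds per environment step

# Transitions collected per learning iteration.
rollout_size = NUM_ENVS * NUM_STEPS_PER_ENV

# ASSUMPTION: the rollout is split evenly across mini-batches.
mini_batch_size = rollout_size // NUM_MINI_BATCHES

# ASSUMPTION: one gradient step per mini-batch per epoch.
updates_per_iteration = NUM_LEARNING_EPOCHS * NUM_MINI_BATCHES

# Rule-of-thumb horizon over which discounted rewards still matter.
effective_horizon_steps = 1.0 / (1.0 - GAMMA)
effective_horizon_s = effective_horizon_steps * SIM_DT

print(rollout_size, mini_batch_size, updates_per_iteration)
print(round(effective_horizon_steps), round(effective_horizon_s, 2))
```

With 64 environments the rollout is small, so raising `num_envs` is usually the first change to try if training looks noisy.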


### **📋 Register the Task Environment**

Register the CartPole task with the mjlab task registry under the ID `Mjlab-Cartpole`.

In [13]:
%%writefile /content/mjlab/src/mjlab/tasks/cartpole/__init__.py
"""CartPole task registration."""

from mjlab.tasks.registry import register_mjlab_task
from mjlab.tasks.velocity.rl import VelocityOnPolicyRunner

from .env_cfg import cartpole_env_cfg
from .rl_cfg import cartpole_ppo_runner_cfg

register_mjlab_task(
  task_id="Mjlab-Cartpole",
  env_cfg=cartpole_env_cfg(),
  play_env_cfg=cartpole_env_cfg(play=True),
  rl_cfg=cartpole_ppo_runner_cfg(),
  runner_cls=VelocityOnPolicyRunner,
)

Writing /content/mjlab/src/mjlab/tasks/cartpole/__init__.py


---

## **🚀 Step 3: Train the Agent**

Now let's train a PPO policy to balance the CartPole!

**Training Configuration:**
- Algorithm: PPO (Proximal Policy Optimization)
- Parallel Environments: 64
- Episode Length: 10 seconds (500 steps @ 50Hz)
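- Total Steps: ~1.5 million for the 1,000 iterations used below (64 envs × 24 steps per env per iteration); ~7.7 million at the configured default of 5,000 iterations

**⚠️ You may need to create a project named "mjlab" manually in the WandB UI, since Google Colab doesn't have permission to create a new project.**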

In [25]:
# This will take several minutes depending on your training configuration
# !uv run train Mjlab-Cartpole --agent.max-iterations 1000 --agent.save-interval 300
# !python ./src/mjlab/scripts/train.py Mjlab-Cartpole --help
!python /content/mjlab/src/mjlab/scripts/train.py Mjlab-Cartpole --agent.max-iterations 1000 --agent.save-interval 300

Streaming output truncated to the last 5000 lines.
                           Time elapsed: 00:02:57
                                    ETA: 00:00:52

################################################################################
                          Learning iteration 773/1000

                            Total steps: 1188864 
                       Steps per second: 6822 
                        Collection time: 0.115s 
                          Learning time: 0.110s 
                        Mean value loss: 0.0000
                    Mean surrogate loss: 0.0076
                      Mean entropy loss: -4.5735
                            Mean reward: 49.96
                    Mean episode length: 500.00
                  Mean action noise std: 0.00
                 Episode_Reward/upright: 4.9998
                  Episode_Reward/effort: -0.0042
            Episode_Termination/timeout: 1.0000
             Episode_Termination/tip
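
The numbers in this log can be cross-checked against the configuration. The sketch below recomputes the step count at iteration 773 and the episode length in steps. The last quantity rests on an assumption not stated in this notebook: that each reward term is scaled by the step duration before being summed, which would explain why the mean reward sits near 50 instead of 2,500.

```python
NUM_ENVS = 64            # parallel environments in training mode
NUM_STEPS_PER_ENV = 24   # steps collected per env per iteration
SIM_DT = 0.02            # seconds per step (decimation=1)
EPISODE_LENGTH_S = 10.0  # training episode length
UPRIGHT_WEIGHT = 5.0     # weight of the upright reward term


def total_steps(iteration):
    """Environment steps collected once zero-based `iteration` has finished."""
    return (iteration + 1) * NUM_ENVS * NUM_STEPS_PER_ENV


steps_per_episode = round(EPISODE_LENGTH_S / SIM_DT)

# ASSUMPTION: per-step rewards are multiplied by SIM_DT before summation.
best_case_return = UPRIGHT_WEIGHT * SIM_DT * steps_per_episode

print(total_steps(773))            # compare with "Total steps" in the log
print(steps_per_episode)           # compare with "Mean episode length"
print(round(best_case_return, 2))  # compare with "Mean reward"
```

A mean episode length equal to the full 500 steps, together with a timeout rate of 1.0, indicates that episodes end by timeout and not because the pole tipped.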

### **📁 Locate Training Checkpoints**

After training, checkpoints are saved locally under `logs/rsl_rl/cartpole/`, in one timestamped directory per run.
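
Checkpoint files are named `model_<iteration>.pt`, so which one counts as "last" depends on how the names are ordered. The sketch below compares lexicographic and numeric ordering on hypothetical file names; `checkpoint_iteration` is an illustrative helper and not part of mjlab.

```python
from pathlib import Path


def checkpoint_iteration(path):
    """Extract N from a file named model_N.pt."""
    return int(Path(path).stem.split("_")[-1])


# Hypothetical names, as a longer run with a short save interval might produce.
names = ["model_0.pt", "model_50.pt", "model_300.pt", "model_1000.pt"]

lexicographic_last = sorted(names)[-1]                   # compares character by character
numeric_last = max(names, key=checkpoint_iteration)     # compares iteration numbers

print(lexicographic_last)
print(numeric_last)
```

The two orderings agree only while all iteration numbers sort the same way as strings, so a numeric key is the safer choice for runs longer than 999 iterations.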

In [26]:
import os
from pathlib import Path

# Find the most recent training run
log_dir = Path("logs/rsl_rl/cartpole")
if log_dir.exists():
    runs = sorted(log_dir.glob("*"), key=os.path.getmtime, reverse=True)
    if runs:
        latest_run = runs[0]
        print(f"✓ Latest training run: {latest_run.name}\n")

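        # List checkpoints, ordered by iteration number rather than by name
        # (a plain sort would place model_1000.pt before model_300.pt)
        checkpoints = sorted(
            latest_run.glob("model_*.pt"),
            key=lambda p: int(p.stem.split("_")[-1]),
        )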
        if checkpoints:
            print(f"Found {len(checkpoints)} checkpoints:")
            for ckpt in checkpoints[-5:]:  # Show last 5
                size_mb = ckpt.stat().st_size / (1024 * 1024)
                print(f"  • {ckpt.name} ({size_mb:.2f} MB)")

            # Store the last checkpoint path
            last_checkpoint = str(checkpoints[-1])
            print(f"\n💾 Last checkpoint: {last_checkpoint}")
        else:
            print("⚠ No checkpoints found yet")
    else:
        print("⚠ No training runs found")
else:
    print("⚠ Log directory not found. Have you run training yet?")

✓ Latest training run: 2025-12-05_17-47-56

Found 5 checkpoints:
  • model_0.pt (0.99 MB)
  • model_300.pt (1.00 MB)
  • model_600.pt (1.00 MB)
  • model_900.pt (1.00 MB)
  • model_999.pt (1.00 MB)

💾 Last checkpoint: logs/rsl_rl/cartpole/2025-12-05_17-47-56/model_999.pt


---

## **🎮 Step 4: Visualize the Trained Policy**

Let's see the trained policy in action!

### **🌐 Launch the Viser Viewer API**

In [27]:
import subprocess
import sys

process = subprocess.Popen(
  [
    "python",
    "/content/mjlab/src/mjlab/scripts/play.py",
    "Mjlab-Cartpole",
    "--checkpoint_file",
    last_checkpoint,
    "--num_envs",
    "4",
  ],
  stdout=subprocess.PIPE,
  stderr=subprocess.STDOUT,
  universal_newlines=True,
  bufsize=1,
)

for line in process.stdout:
  print(line, end="")
  sys.stdout.flush()

  if "serving" in line.lower() or "running on" in line.lower() or "8081" in line:
    print("\n" + "=" * 52)
    print("✅ Server is running! Execute the next cell to view.")
    print("=" * 52)
    break

[INFO]: Loading checkpoint: model_999.pt
Warp 1.10.1 initialized:
   CUDA Toolkit 12.8, Driver 12.4
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "Tesla T4" (15 GiB, sm_75, mempool enabled)
   Kernel cache:
     /root/.cache/warp/1.10.1
Module mujoco_warp._src.smooth 9ca7ec0 load on device 'cuda:0' took 4.05 ms  (cached)
Module mujoco_warp._src.collision_driver e72006d load on device 'cuda:0' took 0.33 ms  (cached)
Module _nxn_broadphase__locals__kernel_1799b5b8 1799b5b load on device 'cuda:0' took 0.32 ms  (cached)
Module mujoco_warp._src.collision_primitive._create_narrowphase_kernel f53bec7 load on device 'cuda:0' took 2.62 ms  (cached)
Module mujoco_warp._src.constraint fa42ba8 load on device 'cuda:0' took 1.37 ms  (cached)
Module _actuator_velocity__locals__actuator_velocity_7933d235 876a329 load on device 'cuda:0' took 0.53 ms  (cached)
Module mujoco_warp._src.passive fc4f8e1 load on device 'cuda:0' took 0.77 ms  (cached)
Module mujoco_warp._src.forward a88f545 load on
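
Detecting readiness by matching log text depends on the exact wording the server prints. A complementary check, sketched below with only the standard library, polls the TCP port until it accepts a connection. `wait_for_port` is an illustrative helper, and the default host assumes the viewer listens on the local machine.

```python
import socket
import time


def wait_for_port(port, host="127.0.0.1", timeout_s=30.0, poll_s=0.5):
    """Return True once a TCP connection to host:port succeeds, else False."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            # A successful connect means something is listening on the port.
            with socket.create_connection((host, port), timeout=poll_s):
                return True
        except OSError:
            # Connection refused or timed out: wait and retry until the deadline.
            time.sleep(poll_s)
    return False


# Example (not executed here): block until the viewer port used below is open.
# if wait_for_port(8081):
#     print("Viewer is accepting connections")
```

This check only confirms that the port is open; it does not verify that the viewer has finished loading the checkpoint.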

In [28]:
from google.colab import output

output.serve_kernel_port_as_iframe(
    port=8081,
    height=700
)


### **📹 Generate Video Recording**

Record the trained policy in the `.viser` format for later playback.