<a href="https://colab.research.google.com/github/ttktjmt/mjlab/blob/main/notebooks/create_new_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **🤖 CartPole Tutorial with MJLab**

This notebook demonstrates how to create a custom reinforcement learning task using MJLab. We'll build a CartPole environment from scratch, including:

1. **Robot Definition** - Define the CartPole model in MuJoCo XML
2. **Task Configuration** - Set up observations, actions, rewards, and terminations
3. **Training** - Train a policy using PPO
4. **Evaluation** - Visualize/Record the trained policy

> **Note**: This tutorial is based on the official MJLab documentation.

## **📦 Setup and Installation**

In [1]:
# Clone the mjlab repository
!if [ ! -d "mjlab" ]; then git clone -q https://github.com/mujocolab/mjlab.git; fi
%cd /content/mjlab

# Install mjlab in editable mode
!uv pip install --system -e . -q

print("✓ Installation complete!")

/content/mjlab
✓ Installation complete!


### **🔑 WandB Setup (Optional)**

Configure Weights & Biases for experiment tracking. Add your WandB API key to Colab Secrets:
- `WANDB_API_KEY`: from [wandb.ai/authorize](https://wandb.ai/authorize)
- `WANDB_ENTITY`: your organization name

In [2]:
import os
from google.colab import userdata

try:
    os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')
    os.environ['WANDB_ENTITY'] = userdata.get('WANDB_ENTITY')
    print("✓ WandB configured successfully!")
except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
    # Raised when a secret is missing or notebook access to it is disabled
    print("⚠ WandB secrets not found. Training will proceed without WandB logging.")

✓ WandB configured successfully!
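
Outside Colab there is no `userdata` module, so the two variables would have to come from the shell environment instead. The following is a minimal sketch of a check for that case; `missing_wandb_vars` is a hypothetical helper name, and only the two variable names come from the setup notes above.

```python
import os

# Variable names taken from the WandB setup notes above.
REQUIRED_VARS = ("WANDB_API_KEY", "WANDB_ENTITY")


def missing_wandb_vars(environ=os.environ):
    """Return the required variable names that are unset or empty."""
    return [name for name in REQUIRED_VARS if not environ.get(name)]


missing = missing_wandb_vars()
if missing:
    print(f"WandB logging disabled; missing: {', '.join(missing)}")
else:
    print("WandB environment variables are set.")
```

Passing a plain dictionary instead of `os.environ` makes the helper easy to try without touching the real environment.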


---

## **🤖 Step 1: Define the Robot**

We'll create a simple CartPole robot with:
- A sliding cart (1 DOF)
- A hinged pole (1 DOF)
- A velocity actuator to control the cart

### **📁 Structure Directories**

In [3]:
# Create the cartpole robot directory structure
!mkdir -p /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/
!mkdir -p /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls

print("✓ Directory structure created")

✓ Directory structure created


### **📝 Create MuJoCo XML Model**

This XML defines the CartPole physics:
- **Ground plane** for visualization
- **Cart body** with a sliding joint (±2 m range)
- **Pole body** with a hinge joint (±90° range)
- **Velocity actuator** for cart control

In [4]:
%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls/cartpole.xml
<mujoco model="cartpole">
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <worldbody>
    <geom name="ground" type="plane" pos="0 0 0" size="5 5 0.1" rgba="0.8 0.9 0.8 1"/>
    <body name="cart" pos="0 0 0.1">
      <geom type="box" size="0.2 0.1 0.1" rgba="0.2 0.2 0.8 1" mass="1.0"/>
      <joint name="slide" type="slide" axis="1 0 0" limited="true" range="-2 2"/>
      <body name="pole" pos="0 0 0.1">
        <geom type="capsule" size="0.05 0.5" fromto="0 0 0 0 0 1" rgba="0.8 0.2 0.2 1" mass="2.0"/>
        <joint name="hinge" type="hinge" axis="0 1 0" range="-90 90"/>
      </body>
    </body>
  </worldbody>
  <actuator>
    <velocity name="slide_velocity" joint="slide" ctrlrange="-20 20" kv="20"/>
  </actuator>
</mujoco>

Writing /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls/cartpole.xml
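
The joint limits and actuator wiring described above can be checked without MuJoCo installed. The sketch below parses a trimmed copy of the model with the standard library only; `summarize` is a hypothetical helper written for this illustration, not part of mjlab or MuJoCo.

```python
import xml.etree.ElementTree as ET

# Trimmed copy of the joint and actuator elements from cartpole.xml.
CARTPOLE_XML = """
<mujoco model="cartpole">
  <worldbody>
    <body name="cart" pos="0 0 0.1">
      <joint name="slide" type="slide" axis="1 0 0" limited="true" range="-2 2"/>
      <body name="pole" pos="0 0 0.1">
        <joint name="hinge" type="hinge" axis="0 1 0" range="-90 90"/>
      </body>
    </body>
  </worldbody>
  <actuator>
    <velocity name="slide_velocity" joint="slide" ctrlrange="-20 20" kv="20"/>
  </actuator>
</mujoco>
"""


def summarize(xml_text):
    """Collect joint types/ranges and actuator targets from a MuJoCo XML string."""
    root = ET.fromstring(xml_text)
    joints = {
        j.get("name"): (j.get("type"), tuple(float(v) for v in j.get("range").split()))
        for j in root.iter("joint")
    }
    # Each child of <actuator> names the joint it drives.
    actuators = {a.get("name"): a.get("joint") for a in root.find("actuator")}
    return joints, actuators


joints, actuators = summarize(CARTPOLE_XML)
print(joints)
print(actuators)
```

The hinge range is in degrees because the model sets `angle="degree"` in its `<compiler>` element.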


### **⚙️ Create Robot Configuration**

In [5]:
%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/cartpole_constants.py
from pathlib import Path
import mujoco

from mjlab import MJLAB_SRC_PATH
from mjlab.entity import Entity, EntityCfg, EntityArticulationInfoCfg
from mjlab.actuator import XmlVelocityActuatorCfg  # Config for the velocity actuator

CARTPOLE_XML: Path = (
  MJLAB_SRC_PATH / "asset_zoo" / "robots" / "cartpole" / "xmls" / "cartpole.xml"
)
assert CARTPOLE_XML.exists(), f"XML not found: {CARTPOLE_XML}"

def get_spec() -> mujoco.MjSpec:
  return mujoco.MjSpec.from_file(str(CARTPOLE_XML))

def get_cartpole_robot_cfg() -> EntityCfg:
  """Get a fresh CartPole robot configuration instance."""
  actuators = (
    XmlVelocityActuatorCfg(
      joint_names_expr=("slide",),  # Matches the joint targeted by the XML actuator
    ),
  )
  # Register the actuator config with the entity's articulation info
  articulation = EntityArticulationInfoCfg(actuators=actuators)
  return EntityCfg(
    spec_fn=get_spec,
    articulation=articulation,
  )

if __name__ == "__main__":
  import mujoco.viewer as viewer
  robot = Entity(get_cartpole_robot_cfg())
  viewer.launch(robot.spec.compile())

Writing /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/cartpole_constants.py


In [6]:
%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/__init__.py
# Empty __init__.py to mark the cartpole robot directory as a Python package

Writing /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/__init__.py


In [7]:
import sys

# Append src dir to python path
mjlab_src = '/content/mjlab/src'
if mjlab_src not in sys.path:
    sys.path.insert(0, mjlab_src)
    print(f"✓ Added {mjlab_src} to Python path")

✓ Added /content/mjlab/src to Python path


### **✅ Verify Robot Setup**

Let's test that the robot can be loaded correctly.

In [8]:
from mjlab.entity import Entity
from mjlab.asset_zoo.robots.cartpole.cartpole_constants import get_cartpole_robot_cfg

# Load the robot
robot = Entity(get_cartpole_robot_cfg())
model = robot.spec.compile()

# Display robot information
print("✓ CartPole robot loaded successfully!")
print(f"  • Degrees of Freedom (DOF): {model.nv}")
print(f"  • Number of Actuators: {model.nu}")
print(f"  • Bodies: {model.nbody}")
print(f"  • Joints: {model.njnt}")

✓ CartPole robot loaded successfully!
  • Degrees of Freedom (DOF): 2
  • Number of Actuators: 1
  • Bodies: 4
  • Joints: 2


### **📋 Register the Robot**

Add the CartPole robot to the asset zoo registry.

In [9]:
# Add CartPole import to robots __init__.py
with open('/content/mjlab/src/mjlab/asset_zoo/robots/__init__.py', 'a') as f:
    f.write('\n# CartPole robot\n')
    f.write('from mjlab.asset_zoo.robots.cartpole.cartpole_constants import ')
    f.write('get_cartpole_robot_cfg as get_cartpole_robot_cfg\n')

print("✓ CartPole robot registered in asset zoo")

✓ CartPole robot registered in asset zoo


---

## **🎯 Step 2: Define the Task (MDP)**

Now we'll define the Markov Decision Process:
- **Observations**: pole angle, angular velocity, cart position, cart velocity
- **Actions**: cart velocity commands
- **Rewards**: upright reward + effort penalty
- **Terminations**: pole tips over or timeout
- **Events**: random pushes for robustness

### **📁 Create Task Directory**

In [10]:
!mkdir -p /content/mjlab/src/mjlab/tasks/cartpole

print("✓ Task directory created")

✓ Task directory created


### **📝 Create Environment Configuration**

This file contains all MDP components:
1. **Scene Config**: 64 parallel environments
2. **Actions**: Joint velocity control with 20.0 scale
3. **Observations**: Normalized state variables
4. **Rewards**: Upright reward (5.0) + effort penalty (-0.01)
5. **Events**: Joint resets + random pushes
6. **Terminations**: Pole tipped (>30°) or timeout (10s)
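
To make the reward and termination numbers in this list concrete, here is a plain-Python sketch for a single environment. It mirrors the listed values (weight 5.0 on the cosine of the pole angle, coefficient 0.01 on squared control, 30° tip limit) and ignores any per-step time scaling the reward manager may apply; the function names are hypothetical.

```python
import math

UPRIGHT_WEIGHT = 5.0               # weight on cos(pole angle)
EFFORT_COEFF = 0.01                # coefficient on squared control
TIP_LIMIT_RAD = math.radians(30)   # termination threshold


def step_reward(pole_angle, ctrl):
    """Upright reward minus effort penalty for one environment step."""
    return UPRIGHT_WEIGHT * math.cos(pole_angle) - EFFORT_COEFF * ctrl ** 2


def is_tipped(pole_angle):
    """True once the pole angle exceeds the tip limit in either direction."""
    return abs(pole_angle) > TIP_LIMIT_RAD


print(step_reward(0.0, 0.0))              # upright pole, no control effort
print(step_reward(0.0, 10.0))             # same pose, control of 10
print(is_tipped(math.radians(31)))        # just past the limit
```

With these values, a full-scale control of 20 costs 4.0 per step, which is close to the maximum upright reward of 5.0, so the policy is pushed toward small corrections.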

In [11]:
%%writefile /content/mjlab/src/mjlab/tasks/cartpole/cartpole_env_cfg.py
# TODO: Split this file into
#   1. /content/mjlab/src/mjlab/tasks/cartpole/env_cfg.py
#   2. /content/mjlab/src/mjlab/tasks/cartpole/rl_cfg.py
"""CartPole task environment configuration."""

import math
import torch

from mjlab.envs import ManagerBasedRlEnvCfg
from mjlab.envs.mdp.actions import JointVelocityActionCfg
from mjlab.managers.manager_term_config import (
  ObservationGroupCfg,
  ObservationTermCfg,
  RewardTermCfg,
  TerminationTermCfg,
  EventTermCfg,
)
from mjlab.managers.scene_entity_config import SceneEntityCfg
from mjlab.scene import SceneCfg
from mjlab.sim import MujocoCfg, SimulationCfg
from mjlab.viewer import ViewerConfig
from mjlab.asset_zoo.robots.cartpole.cartpole_constants import get_cartpole_robot_cfg
from mjlab.rl import RslRlOnPolicyRunnerCfg
from mjlab.envs import mdp

# ==============================================================================
# Scene Configuration
# ==============================================================================

SCENE_CFG = SceneCfg(
  num_envs=64,  # Number of parallel environments
  extent=1.0,   # Spacing between environments
  entities={"robot": get_cartpole_robot_cfg()},
)

VIEWER_CONFIG = ViewerConfig(
  origin_type=ViewerConfig.OriginType.ASSET_BODY,
  asset_name="robot",
  body_name="pole",
  distance=3.0,
  elevation=10.0,
  azimuth=90.0,
)

SIM_CFG = SimulationCfg(
  mujoco=MujocoCfg(
    timestep=0.02,  # 50 Hz control
    iterations=1,
  ),
)

# ==============================================================================
# Actions
# ==============================================================================

def create_cartpole_actions() -> dict[str, JointVelocityActionCfg]:
  """Create CartPole actions."""
  return {
    "slider": JointVelocityActionCfg(
      asset_name="robot",
      actuator_names=(".*",),
      scale=20.0,
      use_default_offset=False,
    ),
  }

# ==============================================================================
# Observations
# ==============================================================================

def create_cartpole_observations() -> dict[str, ObservationGroupCfg]:
  """Create CartPole observations."""
  policy_terms = {
    "angle": ObservationTermCfg(
      func=lambda env: env.sim.data.qpos[:, 1:2] / math.pi
    ),
    "ang_vel": ObservationTermCfg(
      func=lambda env: env.sim.data.qvel[:, 1:2] / 5.0
    ),
    "cart_pos": ObservationTermCfg(
      func=lambda env: env.sim.data.qpos[:, 0:1] / 2.0
    ),
    "cart_vel": ObservationTermCfg(
      func=lambda env: env.sim.data.qvel[:, 0:1] / 20.0
    ),
  }

  return {
    "policy": ObservationGroupCfg(
      terms=policy_terms,
      concatenate_terms=True,
    ),
    "critic": ObservationGroupCfg(
      terms=policy_terms,  # Critic uses same observations
      concatenate_terms=True,
    ),
  }

# ==============================================================================
# Rewards
# ==============================================================================

def compute_upright_reward(env):
  """Reward for keeping pole upright (cosine of angle)."""
  return env.sim.data.qpos[:, 1].cos()

def compute_effort_penalty(env):
  """Penalty for control effort."""
  return -0.01 * (env.sim.data.ctrl[:, 0] ** 2)

def create_cartpole_rewards() -> dict[str, RewardTermCfg]:
  """Create CartPole rewards."""
  return {
    "upright": RewardTermCfg(func=compute_upright_reward, weight=5.0),
    "effort": RewardTermCfg(func=compute_effort_penalty, weight=1.0),
  }

# ==============================================================================
# Events
# ==============================================================================

def random_push_cart(env, env_ids, force_range=(-5, 5)):
  """Apply random force to cart for robustness training."""
  n = len(env_ids)
  random_forces = (
    torch.rand(n, device=env.device) *
    (force_range[1] - force_range[0]) +
    force_range[0]
  )
  env.sim.data.qfrc_applied[env_ids, 0] = random_forces

def create_cartpole_events() -> dict[str, EventTermCfg]:
  """Create CartPole events."""
  return {
    "reset_robot_joints": EventTermCfg(
      func=mdp.reset_joints_by_offset,
      mode="reset",
      params={
        "asset_cfg": SceneEntityCfg("robot"),
        "position_range": (-0.1, 0.1),
        "velocity_range": (-0.1, 0.1),
      },
    ),
    "random_push": EventTermCfg(
      func=random_push_cart,
      mode="interval",
      interval_range_s=(1.0, 2.0),
      params={"force_range": (-20.0, 20.0)},
    ),
  }

# ==============================================================================
# Terminations
# ==============================================================================

def check_pole_tipped(env):
  """Check if pole has tipped beyond 30 degrees."""
  return env.sim.data.qpos[:, 1].abs() > math.radians(30)

def create_cartpole_terminations() -> dict[str, TerminationTermCfg]:
  """Create CartPole terminations."""
  return {
    "timeout": TerminationTermCfg(func=mdp.time_out, time_out=True),
    "tipped": TerminationTermCfg(func=check_pole_tipped, time_out=False),
  }

# ==============================================================================
# Environment Configuration
# ==============================================================================

def create_cartpole_env_cfg() -> ManagerBasedRlEnvCfg:
  """Create CartPole environment configuration."""
  return ManagerBasedRlEnvCfg(
    scene=SCENE_CFG,
    observations=create_cartpole_observations(),
    actions=create_cartpole_actions(),
    rewards=create_cartpole_rewards(),
    events=create_cartpole_events(),
    terminations=create_cartpole_terminations(),
    sim=SIM_CFG,
    viewer=VIEWER_CONFIG,
    decimation=1,           # No action repeat
    episode_length_s=10.0,  # 10 second episodes
  )

# Module-level constant for gymnasium registration
CARTPOLE_ENV_CFG = create_cartpole_env_cfg()

Writing /content/mjlab/src/mjlab/tasks/cartpole/cartpole_env_cfg.py
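
The scaling constants in this configuration are easier to read with numbers plugged in. The sketch below reproduces the observation divisors from the file and, as an assumption about how the action term behaves, maps a raw policy output to a command as `raw_action * scale` clamped to the actuator's `ctrlrange`. Both function names are hypothetical.

```python
import math


def normalize_observation(pole_angle, pole_ang_vel, cart_pos, cart_vel):
    """Apply the same divisors as the observation terms in the config."""
    return (
        pole_angle / math.pi,   # "angle"
        pole_ang_vel / 5.0,     # "ang_vel"
        cart_pos / 2.0,         # "cart_pos" (slide range is -2 to 2)
        cart_vel / 20.0,        # "cart_vel" (ctrlrange is -20 to 20)
    )


def scaled_command(raw_action, scale=20.0, ctrl_range=(-20.0, 20.0)):
    """Assumed mapping: scale the raw action, then clamp to the control range."""
    low, high = ctrl_range
    return max(low, min(high, raw_action * scale))


print(normalize_observation(math.pi / 2, 2.5, 1.0, -10.0))
print(scaled_command(0.5))
print(scaled_command(-2.0))
```

Under this assumption, a raw action in [-1, 1] spans the full velocity range of the actuator, and anything larger saturates at the limit.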


### **📋 Register the Task Environment**

Register the CartPole task with the mjlab task registry.

In [23]:
%%writefile /content/mjlab/src/mjlab/tasks/cartpole/__init__.py
from mjlab.rl import RslRlOnPolicyRunnerCfg
from mjlab.tasks.cartpole.cartpole_env_cfg import create_cartpole_env_cfg
from mjlab.tasks.registry import list_tasks, register_mjlab_task

# NOTE: The keyword arguments follow the existing
# "Mjlab-Tracking-Flat-Unitree-G1" registration. Building the runner config
# with default values and omitting `runner_cls` are assumptions; adjust them
# to match the installed mjlab version.
register_mjlab_task(
  task_id="Mjlab-Cartpole",
  env_cfg=create_cartpole_env_cfg(),
  play_env_cfg=create_cartpole_env_cfg(),
  rl_cfg=RslRlOnPolicyRunnerCfg(),
)

# Report the registration result
if "Mjlab-Cartpole" in list_tasks():
    print("✓ Mjlab-Cartpole successfully registered in mjlab registry.")
else:
    print("✗ Failed to register Mjlab-Cartpole.")


Overwriting /content/mjlab/src/mjlab/tasks/cartpole/__init__.py


In [39]:
from mjlab.tasks.registry import list_tasks

print("Available Tasks:")
for task in list_tasks():
    print(f"  {task}")

Available Tasks:
  Mjlab-Lift-Cube-Yam
  Mjlab-Tracking-Flat-Unitree-G1
  Mjlab-Tracking-Flat-Unitree-G1-No-State-Estimation
  Mjlab-Velocity-Flat-Unitree-G1
  Mjlab-Velocity-Flat-Unitree-Go1
  Mjlab-Velocity-Flat-Unitree-Go1-ActuatorNet
  Mjlab-Velocity-Rough-Unitree-G1
  Mjlab-Velocity-Rough-Unitree-Go1


In [15]:
# Add CartPole task import to tasks __init__.py
with open('/content/mjlab/src/mjlab/tasks/__init__.py', 'a') as f:
    f.write('\n# CartPole task\n')
    f.write('from mjlab.tasks import cartpole\n')

print("✓ CartPole task registered")

✓ CartPole task registered


In [16]:
# Reload mjlab as an editable package
%cd /content/mjlab
!uv pip install --system -e .

import importlib
import mjlab.tasks.cartpole

importlib.reload(mjlab.tasks)
importlib.reload(mjlab.tasks.cartpole)
importlib.reload(mjlab.asset_zoo.robots.cartpole)

/content/mjlab
Using Python 3.12.12 environment at: /usr
Resolved 130 packages in 255ms
Prepared 1 package in 79ms
Uninstalled 1 package in 0.50ms
Installed 1 package in 3ms
 ~ mjlab==0.1.0 (from file:///content/mjlab)
Using device: cpu

✓ Mjlab-Cartpole successfully registered in Gym registry.


<module 'mjlab.asset_zoo.robots.cartpole' from '/content/mjlab/src/mjlab/asset_zoo/robots/cartpole/__init__.py'>

### **✅ Verify Environment Registration**

Let's test that the environment is properly registered and can be created.

In [17]:
# import gymnasium as gym
# from mjlab.tasks import cartpole

# # Check if environment is registered
# env_specs = gym.envs.registry
# if "Mjlab-Cartpole" in env_specs:
#     print("✓ Mjlab-Cartpole environment successfully registered!\n")

#     # Create a test environment, passing the determined device
#     # The 'device' variable is already defined in a previous cell (1BEPKiedIlRh)
#     env = gym.make("Mjlab-Cartpole", headless=True, device=device)

#     print("Environment Details:")
#     print(f"  • Observation space: {env.observation_space}")
#     print(f"  • Action space: {env.action_space}")
#     print(f"  • Number of environments: {env.unwrapped.num_envs}")

#     # Test a step
#     obs, info = env.reset()
#     print(f"\n  • Observation shape: {obs['policy'].shape}")
#     print(f"  • Sample observation: {obs['policy'][0]}")

#     env.close()
#     print("\n✓ Environment test completed successfully!")
# else:
#     print("✗ Environment not found in registry")
#     print("Available environments:\n")
#     print(sorted(list(env_specs.keys())))


---

## **🚀 Step 3: Train the Agent**

Now let's train a PPO policy to balance the CartPole!

**Training Configuration:**
- Algorithm: PPO (Proximal Policy Optimization)
- Parallel Environments: 64
- Episode Length: 10 seconds (500 steps @ 50Hz)
- Total Steps: 3 million in the command below (increase to ~5-10 million if the policy has not converged)
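
The figures in this list follow from the environment settings, and the arithmetic is worth spelling out. The sketch below assumes that the step budget counts transitions summed over all parallel environments; how mjlab interprets `--total_steps` is not confirmed here, so treat the per-environment numbers as an estimate.

```python
EPISODE_LENGTH_S = 10.0   # episode_length_s in the environment config
TIMESTEP_S = 0.02         # MuJoCo timestep
DECIMATION = 1            # physics steps per policy step
NUM_ENVS = 64             # parallel environments
TOTAL_STEPS = 3_000_000   # value passed to --total_steps

# Policy steps in one full-length episode.
steps_per_episode = round(EPISODE_LENGTH_S / (TIMESTEP_S * DECIMATION))

# Assumption: the budget is shared evenly across the parallel environments.
steps_per_env = TOTAL_STEPS / NUM_ENVS
episodes_per_env = steps_per_env / steps_per_episode

print(f"Steps per episode:      {steps_per_episode}")
print(f"Steps per environment:  {steps_per_env:.0f}")
print(f"Full episodes per env:  {episodes_per_env:.2f}")
```

Early in training most episodes end well before the timeout, so the actual number of episodes per environment will be higher than this full-length estimate.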

> **Note**: Training in Colab runs in headless mode (no visualization). Progress will be logged to the console and optionally to WandB.

In [18]:
# Train the CartPole task
# This will take several minutes depending on your training configuration
# !uv run train Mjlab-Cartpole --system --headless --total_steps 3000000
!python /content/mjlab/src/mjlab/scripts/train.py Mjlab-Cartpole --headless --total_steps 3000000

Using device: cpu

✓ Mjlab-Cartpole successfully registered in Gym registry.
╭─ Invalid choice ─────────────────────────────────────────────────────────────╮
│ invalid choice 'Mjlab-Cartpole' for argument '{Mjlab-Lift-Cube-Yam,          │
│ Mjlab-Tracking-Flat-Unitree-G1,                                              │
│ Mjlab-Tracking-Flat-Unitree-G1-No-State-Estimation,                          │
│ Mjlab-Velocity-Flat-Unitree-G1,Mjlab-Velocity-Flat-Unitree-Go1,              │
│ Mjlab-Velocity-Flat-Unitree-Go1-ActuatorNet,Mjlab-Velocity-Rough-Unitree-G1, │
│ Mjlab-Velocity-Rough

### **📊 Monitor Training Progress**

If WandB is configured, you can monitor training in real-time.

In [19]:
import os

if 'WANDB_API_KEY' in os.environ:
    entity = os.environ.get('WANDB_ENTITY', 'your-entity')
    print(f"📊 WandB Dashboard: https://wandb.ai/{entity}")
    print("\nTraining metrics to watch:")
    print("  • Episode Reward Mean - Should increase over time")
    print("  • Episode Length Mean - Should approach max episode length")
    print("  • Policy Loss - Should stabilize")
    print("  • Value Loss - Should decrease")
else:
    print("⚠ WandB not configured")
    print("Training logs are saved locally in: logs/rsl_rl/cartpole/")

📊 WandB Dashboard: https://wandb.ai/ttktjmt-org-org

Training metrics to watch:
  • Episode Reward Mean - Should increase over time
  • Episode Length Mean - Should approach max episode length
  • Policy Loss - Should stabilize
  • Value Loss - Should decrease


### **📁 Locate Training Checkpoints**

After training, checkpoints are saved locally.

In [20]:
import os
from pathlib import Path

# Find the most recent training run
log_dir = Path("logs/rsl_rl/cartpole")
if log_dir.exists():
    runs = sorted(log_dir.glob("*"), key=os.path.getmtime, reverse=True)
    if runs:
        latest_run = runs[0]
        print(f"✓ Latest training run: {latest_run.name}\n")

        # Sort checkpoints by iteration number (model_<iteration>.pt);
        # a plain string sort would place model_1000.pt before model_200.pt.
        checkpoints = sorted(
            latest_run.glob("model_*.pt"),
            key=lambda p: int(p.stem.split("_")[-1]),
        )
        if checkpoints:
            print(f"Found {len(checkpoints)} checkpoints:")
            for ckpt in checkpoints[-5:]:  # Show last 5
                size_mb = ckpt.stat().st_size / (1024 * 1024)
                print(f"  • {ckpt.name} ({size_mb:.2f} MB)")

            # Store the latest checkpoint (highest iteration), which is not
            # necessarily the best-performing one
            best_checkpoint = str(checkpoints[-1])
            print(f"\n💾 Latest checkpoint: {best_checkpoint}")
        else:
            print("⚠ No checkpoints found yet")
    else:
        print("⚠ No training runs found")
else:
    print("⚠ Log directory not found. Have you run training yet?")

⚠ Log directory not found. Have you run training yet?
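
When picking the newest checkpoint, the ordering of file names matters. The sketch below uses made-up file names following the `model_*.pt` pattern and compares a plain string sort with a sort on the numeric suffix; `iteration_of` is a hypothetical helper that assumes the suffix is always an integer.

```python
from pathlib import Path


def iteration_of(path):
    """Extract the integer suffix from a name such as model_200.pt."""
    return int(Path(path).stem.split("_")[-1])


# Made-up checkpoint names for illustration.
names = ["model_0.pt", "model_50.pt", "model_100.pt", "model_1000.pt", "model_200.pt"]

by_string = sorted(names)
by_iteration = sorted(names, key=iteration_of)

print("String sort, last entry:   ", by_string[-1])
print("Numeric sort, last entry:  ", by_iteration[-1])
```

The string sort compares character by character, so `model_50.pt` ends up last even though `model_1000.pt` is the most recent iteration.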


---

## **🎮 Step 4: Evaluate the Trained Policy**

Let's test the trained policy! Since we're in Colab (headless), we can:
1. Run the policy and print statistics
2. Generate a video recording (optional)

### **🎯 Run Policy Inference**

Replace `<checkpoint_path>` with your actual checkpoint path from above.

In [21]:
# IMPORTANT: Update this with your checkpoint path!
checkpoint_path = "logs/rsl_rl/cartpole/YYYY-MM-DD_HH-MM-SS/model_XXXX.pt"

# Run this cell after updating the path:
!uv run play Mjlab-Cartpole --system --checkpoint_file {checkpoint_path} --headless --num_envs 4

print("ℹ️ Update the checkpoint_path variable above with your actual checkpoint,")
print("   then re-run this cell.")

Using CPython 3.13.9
Creating virtual environment at: .venv
Installed 154 packages in 1.32s
Traceback (most recent call last):
  File "/content/mjlab/.venv/bin/play", line 4, in <module>
    from mjlab.scripts.play import main
  File "/content/mjlab/src/mjlab/scripts/play.py", line 15, in <module>
    from mjlab.tasks.registry import list_tasks, load_env_cfg, load_rl_cfg, load_runner_cls
  File "/content/mjlab/src/mjlab/tasks/__init__.py", line 5, in <module>
    import_packages(__name__, _BLACKLIST_PKGS)
    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/mjlab/src/mjlab/utils/lab_api/tasks/importer.py", line 40, in import_packages
    for _ in _walk_packages(
             ~~~~~~~~~~~~~~^
      package.__path__, package.__na

In [22]:
from google.colab import output

output.serve_kernel_port_as_iframe(
    port=8081,
    height=700
)

<IPython.core.display.Javascript object>

### **📹 Generate Video Recording**

Record a video of the trained policy in `.viser` format for visualization.