# Installations and Imports

In [1]:
%%capture
!pip install mujoco mujoco_mjx brax playground
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
!apt-get update
!apt-get install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf

In [2]:
%pip install -q -U 'google-generativeai>=0.8.3'

In [1]:
import jax
import numpy as np
import mediapy
import os
from tqdm import tqdm

os.environ["MUJOCO_GL"] = "egl"

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

# Clone Repository

In [2]:
!git clone https://github.com/shaoanlu/llm_mjx_go1_playground.git

fatal: destination path 'llm_mjx_go1_playground' already exists and is not an empty directory.


In [3]:
%cd llm_mjx_go1_playground

/content/llm_mjx_go1_playground


# Prepare simulation

### !!!!! BEWARE !!!!!
The following `mv` command overwrites the xml file installed in the mujoco_playground. The modified xml has wall objects added into the scene.

In [4]:
!mv examples/scene_mjx_flat_terrain_maze.xml /usr/local/lib/python3.11/dist-packages/mujoco_playground/_src/locomotion/go1/xmls/scene_mjx_feetonly_flat_terrain.xml

mv: cannot stat 'examples/scene_mjx_flat_terrain_maze.xml': No such file or directory


In [5]:
# import mujoco after setting MUJOCO_GL to prevent errors in video rendering
import mujoco

In [6]:
from typing import List
from mujoco_playground._src import mjx_env

from src.environment.env_wrapper import Go1Env
from src.control.controller_factory import ControllerFactory
from src.control.position_controller import PositionController, PositionCommand
from src.control.algorithms.mlp import MLPPolicy, MLPPolicyParams
from src.mission_executer import EpisodeResult, MissionConfig, MissionExecuter
from src.control.position_controller import (
    create_position_controller,
    PositionControllerParams,
    SequentialControllerParams,
    PolarCoordinateControllerParams,
)

In [7]:
velocity_kick_range = [0.0, 0.0]  # Disable velocity kick.
kick_duration_range = [0.05, 0.2]


def sample_pert(rng, env, state):
    """
    Random purterbation to the robot
    """
    rng, key1, key2 = jax.random.split(rng, 3)
    pert_mag = jax.random.uniform(key1, minval=velocity_kick_range[0], maxval=velocity_kick_range[1])
    duration_seconds = jax.random.uniform(key2, minval=kick_duration_range[0], maxval=kick_duration_range[1])
    duration_steps = jax.numpy.round(duration_seconds / env.dt).astype(jax.numpy.int32)
    state.info["pert_mag"] = pert_mag
    state.info["pert_duration"] = duration_steps
    state.info["pert_duration_seconds"] = duration_seconds
    return rng

## Setup Gemini Client

In [8]:
from google import genai
from google.colab import userdata


client = genai.Client(api_key=userdata.get("GOOGLE_API_KEY"), http_options={"api_version": "v1alpha"})

## Create Traversability Map And LLM Navigator

In [9]:
from src.environment.traversability_map import TraversabilityMap, TraversabilityMapConfig

trav_map_config = TraversabilityMapConfig()
trav_map = TraversabilityMap(trav_map_config)
trav_map.load_from_image("assets/floor1.jpg")

array([[1., 0., 0., 0., 0.],
       [1., 1., 1., 1., 0.],
       [1., 1., 0., 1., 0.],
       [1., 0., 0., 1., 0.],
       [0., 0., 1., 1., 1.]])

## Position Control Helper Functions

In [10]:
def update_position_history(position_history, current_position):
    # Only update position_history is the robot enters a new cell (by comparing position in integer)
    if len(position_history) == 0:
        position_history.append(current_position.astype(np.int32))
        return position_history
    if np.array_equal(position_history[-1].astype(np.int32), current_position.astype(np.int32)):
        return position_history
    else:
        position_history.append(current_position.astype(np.int32))
    return position_history


def run_one_episode(
    env: Go1Env,
    command_generator: PositionController,
    command_follower: MLPPolicy,
    traversability_map: TraversabilityMap,
    waypoints: List[np.ndarray],
    max_sim_steps: int,
    rng: jax.Array,
) -> EpisodeResult:
    rollout: List[mjx_env.State] = []
    position_history: List[np.ndarray] = []
    waypoint_idx: int = 0

    _, rng = jax.random.split(rng)
    state = env.reset(rng)

    # start closed-loop sim
    for i in tqdm(range(max_sim_steps)):
        _, rng = jax.random.split(rng)  # this will randomize perturbation
        rng = sample_pert(rng, env, state)

        # calculate controll command
        cell_center_offset = 0.5
        target_position = waypoints[waypoint_idx] + cell_center_offset
        pos_command: PositionCommand = command_generator.compute_command(state, target_position)
        state.info["command"] = pos_command.command

        ctrl = command_follower.control(state.obs["state"])  # controller step
        state = env.step(state, ctrl)  # simulator step

        # record
        rollout.append(state)
        position_history = update_position_history(position_history, pos_command.info.pos[:2])

        # check failure
        if not traversability_map.is_valid_position(pos_command.info.pos[:2]):
            return EpisodeResult(status="Stop", position_history=position_history, rollout=rollout)

        # check arrival at current waypoint
        if pos_command.info.is_arrived:
            waypoint_idx = min(waypoint_idx + 1, len(waypoints))  # proceed to next waypoint

            # Check if the arrived at the last waypoint
            if waypoint_idx == len(waypoints):
                return EpisodeResult(status="Success", position_history=position_history, rollout=rollout)

    return EpisodeResult(status="Timeout", position_history=position_history, rollout=rollout)

## Demo: Nagivating a go1 quadruped through a maze with Gemini-thinking Model

The following picture show the flowchart of the navigation task.
![](https://github.com/shaoanlu/llm_mjx_go1_playground/raw/main/assets/llm_go1_navigation.png)

In [11]:
class DummyPlanner:
    """A planner that only returns a fixed list of waypoints, for debugging purpose"""

    def __init__(self):
        pass

    def plan(self, **kwargs) -> List[np.ndarray]:
        return [np.array([0, 0]), np.array([1, 0]), np.array([1, 3]), np.array([4, 3]), np.array([4, 4])]

    def reset_chat(self):
        pass


dummy_navigator = DummyPlanner()

### LLM navigator
`MissionExecuter` calls the `.plan(prompt=...)` method of the LLM navigator to get the suggest list of waypoints

In [12]:
from src.planning.llm_nagivator import GeminiThinkingNavigator

llm_navigator = GeminiThinkingNavigator(client, model_name="gemini-2.0-flash-thinking-exp")

In [13]:
# instantiate mujoco Env
env_name = "Go1JoystickFlatTerrain"
rng = jax.random.PRNGKey(0)
env = Go1Env(env_name=env_name)

### Controller instantiations

In [20]:
# Instantiate controller based on env_name
factory = ControllerFactory()
controller_config = {"algorithm_type": "mlp", "npy_path": f"src/control/nn_params/{env_name}"}
mlp_params = MLPPolicyParams.from_dict(controller_config)
command_follower = factory.build(params=mlp_params)

pc_config = PositionControllerParams(
    # PolarCoordinateControllerParams() is faster but can deviate from the straight path
    primary_controller=SequentialControllerParams(),
    fallback_controller=SequentialControllerParams(),
)
command_generator = create_position_controller(controller_factory=factory, config=pc_config)

# Instantiate the ochestrator for mission execution
instruction = open("examples/prompt.txt", "r").read()
mission_config = MissionConfig()
mission_config.max_attempts = 20
mission_config.max_sim_steps = 1000
mission_executer = MissionExecuter(config=mission_config, instruciton_prompt=instruction)
mission_config

MissionConfig(goal=(4, 4), max_sim_steps=1000, retry_delay_sec=5, max_attempts=20)

In [15]:
from functools import partial

# Create a partially applied function with fixed arguments for closed-loop simulation
run_one_episode_func = partial(
    run_one_episode,
    env=env,
    command_generator=command_generator,
    command_follower=command_follower,
    traversability_map=trav_map,
    rng=rng,
)

## Start Navigation

I takes some time for the first simulation step. During simulation, the intermediate reuslt of attemps will be printed out, including the prompt, the waypoints suggested by LLM, and the result.

In [16]:
# dummy_navigator for debugging
# result = mission_executer.execute_mission(planner=dummy_navigator, execute_single_attempt=run_one_episode_func)

result = mission_executer.execute_mission(planner=llm_navigator, execute_single_attempt=run_one_episode_func)

  8%|▊         | 81/1000 [01:02<11:53,  1.29it/s]


[Trial 1]
prompt='You are a path planner for a quadruped robot navigating in a partially known 5x5 grid environment. Your task is to guide the robot from the start position (0,0) to the goal position (4,4).\n\nEnvironment specifications:\n- Grid size: 5x5\n- Start position: (0,0)\n- Goal position: (4,4)\n- Some cells are not traversable, but their status is only revealed when the robot attempts to traverse them\n- The traversability of a cell is probabilistic - failed attempts don\'t guarantee the cell is completely untraversable\n\nPath planning constraints:\n1. Maximum 10 waypoints allowed per path\n2. Consecutive waypoints moving in the same direction must be simplified\n   Example: [(0,0), (1,0), (2,0), (2,1)] should be simplified to [(0,0), (2,0), (2,1)]\n\nExpected output format:\nProvide only a list of coordinates in the format:\n[(x0,y0), (x1,y1), ..., (xn,yn)]\n\nPossible execution outcomes:\n1. Success: Robot reaches the goal\n2. Failed (Obstacle): "Failed: stop at (X,Y), tra

 20%|██        | 202/1000 [00:04<00:18, 43.20it/s]


[Trial 2]
prompt='Failed: Stop at (0, 1), traversed cells: [(0, 0), (0, 1)]'
waypoints=[array([0, 0]), array([4, 0]), array([4, 1]), array([4, 2]), array([4, 3]), array([4, 4]), array([4, 4])]
result.status='Stop'(4, 0) 	{result.position_history=}



 24%|██▎       | 236/1000 [00:07<00:24, 31.07it/s]


[Trial 3]
prompt='Failed: Stop at (4, 0), traversed cells: [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0)]'
waypoints=[array([0, 0]), array([3, 0]), array([3, 1]), array([4, 1]), array([4, 4]), array([4, 4])]
result.status='Stop'(3, 1) 	{result.position_history=}



 23%|██▎       | 230/1000 [00:04<00:15, 48.19it/s]


[Trial 4]
prompt='Failed: Stop at (3, 1), traversed cells: [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1)]'
waypoints=[array([0, 0]), array([2, 0]), array([2, 2]), array([2, 4]), array([3, 4]), array([4, 4]), array([4, 4])]
result.status='Stop'(2, 2) 	{result.position_history=}



 41%|████      | 409/1000 [00:07<00:11, 51.65it/s]


[Trial 5]
prompt='Failed: Stop at (2, 2), traversed cells: [(0, 0), (1, 0), (2, 0), (2, 1), (2, 2)]'
waypoints=[array([0, 0]), array([2, 0]), array([2, 1]), array([1, 1]), array([1, 4]), array([4, 4]), array([4, 4])]
result.status='Stop'(1, 4) 	{result.position_history=}



 60%|██████    | 603/1000 [00:12<00:08, 47.41it/s]

[Trial 6]
prompt='Failed: Stop at (1, 4), traversed cells: [(0, 0), (1, 0), (2, 0), (2, 1), (1, 1), (1, 2), (1, 3), (1, 4)]'
waypoints=[array([0, 0]), array([2, 0]), array([2, 1]), array([1, 1]), array([1, 3]), array([4, 3]), array([4, 4]), array([4, 4])]
result.status='Success'(4, 4) 	{result.position_history=}






In [17]:
# Set up the camera
camera = mujoco.MjvCamera()
mujoco.mjv_defaultCamera(camera)

# Object is centered at (0,0,1.5) with size 15m x 15m x 3m
camera.lookat[:] = np.array([2.0, 1.5, 1.5])  # Center of object

# Calculate distance using a safe margin
object_size = 6  # Maximum size in X or Y
fov = 45  # Default MuJoCo camera FOV
safe_margin = 1  # Factor to ensure entire object is in view

# Compute required distance using field of view
camera.distance = (object_size / 2) / np.tan(np.radians(fov / 2)) * safe_margin

# Look directly down
camera.azimuth = 90  # No horizontal rotation
camera.elevation = -70  # Directly downward

In [None]:
render_every = 5
fps = 3.0 / env.dt / render_every  # 3x realtime
traj = result.rollouts[::render_every]

scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True
scene_option.geomgroup[3] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False

frames = env.render(
    traj,
    camera=camera,
    scene_option=scene_option,
    height=480,
    width=640,
)
mediapy.show_video(frames, fps=fps, loop=False)