# Installations and Imports

In [1]:
%%capture
!pip install mujoco mujoco_mjx brax playground
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
!apt-get update
!apt-get install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf

In [16]:
%pip install -q -U 'google-generativeai>=0.8.3'

In [2]:
import jax
import numpy as np
import matplotlib.pyplot as plt
import mediapy
import os
from tqdm import tqdm

os.environ["MUJOCO_GL"] = "egl"

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

# Clone Repository

In [4]:
!git clone https://github.com/shaoanlu/control_system_project_template.git

fatal: destination path 'control_system_project_template' already exists and is not an empty directory.


In [5]:
%cd control_system_project_template

/content/control_system_project_template


In [6]:
!pwd

/content/control_system_project_template


In [7]:
!git clone https://github.com/shaoanlu/gemini_maze_exploration.git

fatal: destination path 'gemini_maze_exploration' already exists and is not an empty directory.


# Prepare simulation

### !!!!! BEWARE !!!!!
The following `mv` command overwrites the xml file installed in the mujoco_playground.

In [None]:
!mv gemini_maze_exploration/assets/scene_mjx_flat_terrain_maze.xml /usr/local/lib/python3.11/dist-packages/mujoco_playground/_src/locomotion/go1/xmls/scene_mjx_feetonly_flat_terrain.xml

In [None]:
# import mujoco after setting MUJOCO_GL to prevent errors in video rendering
import mujoco

In [8]:
from examples.mujoco_Go1.env_wrapper import Go1Env
from examples.mujoco_Go1.ppo import PPO, PPOParams, PPOParamsBuilder
from src.control.controller_factory import ControllerFactory

In [9]:
velocity_kick_range = [0.0, 0.0]  # Disable velocity kick.
kick_duration_range = [0.05, 0.2]


def sample_pert(rng, env, state):
    """
    Random purterbation to the robot
    """
    rng, key1, key2 = jax.random.split(rng, 3)
    pert_mag = jax.random.uniform(key1, minval=velocity_kick_range[0], maxval=velocity_kick_range[1])
    duration_seconds = jax.random.uniform(key2, minval=kick_duration_range[0], maxval=kick_duration_range[1])
    duration_steps = jax.numpy.round(duration_seconds / env.dt).astype(jax.numpy.int32)
    state.info["pert_mag"] = pert_mag
    state.info["pert_duration"] = duration_steps
    state.info["pert_duration_seconds"] = duration_seconds
    return rng

## Setup Gemini Client

In [10]:
from google import genai
from IPython.display import Image, Markdown
from google.colab import userdata


client = genai.Client(api_key=userdata.get("GOOGLE_API_KEY"), http_options={"api_version": "v1alpha"})

## Import Gridmap And LLM Wrapper

In [11]:
from gemini_maze_exploration.src.gemini_chat import GeminiThinkChat

import numpy as np
from matplotlib import pyplot as plt
from dataclasses import dataclass
from typing import Tuple, List


@dataclass
class GridConfig:
    grid_size: Tuple[int, int] = (50, 50)
    image_size: int = 50  # 50x50 RGB image representing traversibility of a 5x5 maze
    threshold: int = 10  # White color


class GridManager:
    def __init__(self, config: GridConfig):
        self.config = config
        self.grid = np.empty(config.grid_size)
        self.img2grid_scale = self.config.image_size // self.config.grid_size[0]

    def load_from_image(self, image_path: str) -> np.ndarray:
        """Load and process grid from image.

        Thr output is a binary 2D grid where 1 represents a valid position and 0 an invalid one.
        This method also involes a transformation of coordinates from image to grid space.
        As image coordinates are (0, 0) at the top-left corner and grid coordinates are (0, 0) at the bottom-left corner,
        we need to invert the y-axis and scale the coordinates accordingly.
        """
        im = plt.imread(image_path, format="jpeg")
        self.image = im
        for x in range(self.config.grid_size[0]):
            for y in range(self.config.grid_size[1]):
                color = im[
                    self.config.image_size - self.img2grid_scale * x - int(np.ceil(self.img2grid_scale / 2)),
                    self.img2grid_scale * y + int(np.floor(self.img2grid_scale / 2)),
                ][0]
                self.grid[y, x] = 1 if color >= self.config.threshold else 0
        return self.grid

    def is_valid_position(self, position: np.ndarray, scale: float = 10) -> bool:
        """Check if position is valid in the grid."""
        x, y = np.clip(scale * position, 0, np.inf).astype(int)
        return self.grid[x, y] == 1


grid_config = GridConfig()
grid_manager = GridManager(grid_config)
grid_manager.load_from_image("gemini_maze_exploration/assets/floor1.jpg")

array([[1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 1., 1.],
       [0., 0., 0., ..., 1., 1., 1.],
       [0., 0., 0., ..., 1., 1., 1.]])

## Position Control Helper Functions

In [None]:
import jax
import numpy as np
import scipy.linalg

In [None]:
def lqr_control(pos, tar_pos, yaw, tar_yaw, prev_command):
    Ts = 1
    v_k = prev_command[0]
    A = np.array([[0, 0, -Ts * v_k * np.sin(yaw)],
                    [0, 0, Ts * v_k * np.cos(yaw)],
                    [0, 0, 0]])
    A = A + np.eye(A.shape[0]) * 1e-8  # Add small regularization for numerical stability

    B = np.array([[Ts * np.cos(yaw), 0],
                    [Ts * np.sin(yaw), 0],
                    [0, Ts]])
    
    Q = np.diag([10, 10, 100])  # Penalize position and heading error
    R = np.diag([1, 1])  # Penalize control effort
    P = scipy.linalg.solve_continuous_are(A, B, Q, R)
    
    # Compute the optimal gain matrix K
    K = np.linalg.inv(R) @ B.T @ P
    X_error = np.array([tar_pos[0]-pos[0], tar_pos[1]-pos[1], tar_yaw-yaw])  # tracking error
    U = K @ X_error  # Compute control input
    command = jax.numpy.array([np.clip(U[0], -1.5, 1.5), 0, np.clip(U[1], -np.pi/2, np.pi/2)])
    return command

def p_control(pos, tar_pos, yaw, tar_yaw):
    # if with large yaw error, then do rotate
    # else move forward (toward target position)
    err_yaw = tar_yaw - yaw
    dist = jax.numpy.linalg.norm(tar_pos - pos[:2])
    if np.abs(err_yaw) > np.pi / 18:
        command = jax.numpy.array([0, 0, np.clip(7 * err_yaw, -np.pi / 2, np.pi / 2)])
    else:
        command = jax.numpy.array([np.clip(2 * dist, -1.5, 1.5), 0, 0])
    return command

In [12]:
def clean_instruction(instr: str) -> str:
    char_to_be_deleted = ["\n", "\t", "'", "`", "{", "}"]
    for c in char_to_be_deleted:
        instr = instr.replace(c, "")
    return instr


def get_torso_xyz_from_state(state):
    return np.array(state.data.site_xpos[0])  # + np.array([0.5, 0.5, 0.0])


def generate_position_control_command(state, tar_pos, prev_command=jax.numpy.zeros((3)), use_lqr=False, debug=False):
    pos = get_torso_xyz_from_state(state)
    tar_yaw = np.arctan2(tar_pos[1] - pos[1], tar_pos[0] - pos[0])
    forward_vec = state.data.site_xmat[1] @ np.array([1.0, 0.0, 0.0])
    yaw = np.arctan2(forward_vec[1], forward_vec[0])
    tar_yaw, yaw = np.unwrap([tar_yaw, yaw])
    err_yaw = tar_yaw - yaw
    dist = jax.numpy.linalg.norm(tar_pos - pos[:2])

    # Control strategy
    # if close to target position, then no control
    # else do P control or LQR  control
    is_arrived = False
    if np.abs(dist) <= 0.1:
        command = jax.numpy.zeros((3))
        is_arrived = True
    else:
        if use_lqr:
            try:
                # ARE might have numerical issue, fall back to simple P control in that case
                command = lqr_control(pos, tar_pos, yaw, tar_yaw, prev_command)
            except:
                command = p_control(pos, tar_pos, yaw, tar_yaw)
        else:
            command = p_control(pos, tar_pos, yaw, tar_yaw)

    if debug:
        print(f"\n{pos=}\n{tar_pos=}\t{yaw=}, {tar_yaw=}\n{dist=}\n{command=}")
    info = {
        "pos": pos,
        "tar": np.array([tar_pos[0], tar_pos[1], tar_yaw]),
        "dist": dist,
        "command": command,
        "is_arrived": is_arrived,
    }
    return command, info

In [13]:
def format_failure_message(pos, position_history, failure_type: str) -> str:
    """Format failure message with position history."""
    x, y = pos.astype(int)
    return f"Failed: {failure_type} at ({x}, {y}), traversed cells: {[(int(xy_[0]), int(xy_[1])) for xy_ in position_history]}"


def update_position_history(position_history, current_position):
    # Only update position_history is the robot enters a new cell (by comparing position in integer)
    if len(position_history) == 0:
        position_history.append(current_position.astype(np.int32))
        return position_history
    if np.array_equal(position_history[-1].astype(np.int32), current_position.astype(np.int32)):
        return position_history
    else:
        position_history.append(current_position.astype(np.int32))
    return position_history


def run_one_episode(state, env, rng, controller, grid_manager, waypoints, goal_pos=np.array([4, 4])):
    rollout = []
    ref_commands = []
    position_history = []
    tar_wp_idx = 0

    # start closed-loop sim
    sim_time_sec = 30.0
    sim_steps = int(sim_time_sec / env.dt)
    for i in tqdm(range(sim_steps)):
        if state.info["steps_since_last_pert"] < state.info["steps_until_next_pert"]:
            rng = sample_pert(rng, env, state)
        _, rng = jax.random.split(rng)  # this will randomize perturbation

        # calculate controll command
        tar_pos = waypoints[tar_wp_idx]
        cell_center_offset = 0.5
        command, control_info = generate_position_control_command(
            state, tar_pos + cell_center_offset, prev_command=command, use_lqr=False, debug=False
        )
        state.info["command"] = command

        ctrl = controller.control(state.obs["state"])  # controller step
        state = env.step(state, ctrl)  # simulator step

        # record
        rollout.append(state)
        ref_commands.append(command)
        position_history = update_position_history(position_history, control_info["pos"][:2])

        # check failure
        if not grid_manager.is_valid_position(control_info["pos"][:2]):
            return format_failure_message(control_info["pos"][:2], position_history, "Stop"), rollout
        # check arrival at current waypoint
        if control_info["is_arrived"]:
            tar_wp_idx += 1  # proceed to next waypoint

        # Check if the mission is completed
        is_at_last_waypoint = control_info["is_arrived"] and tar_wp_idx == len(waypoints)
        if is_at_last_waypoint:
            return "Success", rollout

    return format_failure_message(control_info["pos"][:2], position_history, "Timeout"), rollout

## Demo: Nagivating a go1 quadruped through a maze with Gemini-thinking Model

In [14]:
import time

In [15]:
# instantiate mujoco Env
env_name = "Go1JoystickFlatTerrain"
rng = jax.random.PRNGKey(0)
env = Go1Env(env_name=env_name)

# Instantiate controller based on env_name
factory = ControllerFactory()
factory.register_controller(PPOParams, PPO)
controller_config = {"npy_path": f"examples/mujoco_Go1/nn_params/{env_name}"}
ppo_params = PPOParamsBuilder().build(config=controller_config)
controller = factory.build(params=ppo_params)

In [16]:
instruction = open("gemini_maze_exploration/prompt.txt", "r").read()
chat = GeminiThinkChat(client, model_name="gemini-2.0-flash-thinking-exp")


retry_delay_sec = 5
goal_pos = np.array([4, 4])
max_attempts = 15
rollout = []
for attempt in range(max_attempts):
    # Initialize for a new round of simulation
    state = env.reset(rng)
    current_position = get_torso_xyz_from_state(state)[:2]
    if attempt == 0:
        prompt = instruction + f"\nStart. you are at ({current_position[0]}, {current_position[1]})"

    # Prompt the LLM to get waypoints suggestion
    waypoints = chat.get_waypoints(prompt)
    if not np.array_equal(waypoints[-1], goal_pos):
        waypoints.append(goal_pos)
    # waypoints = [np.array([0, 0]), np.array([1, 0]), np.array([1, 3]), np.array([4, 3]), np.array([4, 4])]

    # run the mission (simulation)
    result, epi_rollout = run_one_episode(state, env, rng, controller, grid_manager, waypoints, goal_pos=goal_pos)
    rollout.extend(epi_rollout)

    print(f"[Trial {attempt + 1}]\n{prompt=}\n{waypoints=}\n{result=}\n")
    prompt = result

    if result == "Success":
        break

    # add a delay before retrying to avoid API rate limiting
    time.sleep(retry_delay_sec)

result

 16%|█▌        | 238/1500 [01:11<06:19,  3.33it/s]


[Trial 1]
prompt='You are a path planner for a quadruped robot navigating in a partially known 5x5 grid environment. Your task is to guide the robot from the start position (0,0) to the goal position (4,4).\n\nEnvironment specifications:\n- Grid size: 5x5\n- Start position: (0,0)\n- Goal position: (4,4)\n- Some cells are not traversable, but their status is only revealed when the robot attempts to traverse them\n- The traversability of a cell is probabilistic - failed attempts don\'t guarantee the cell is completely untraversable\n\nPath planning constraints:\n1. Maximum 10 waypoints allowed per path\n2. Consecutive waypoints moving in the same direction must be simplified\n   Example: [(0,0), (1,0), (2,0), (2,1)] should be simplified to [(0,0), (2,0), (2,1)]\n\nExpected output format:\nProvide only a list of coordinates in the format:\n[(x0,y0), (x1,y1), ..., (xn,yn)]\n\nPossible execution outcomes:\n1. Success: Robot reaches the goal\n2. Failed (Obstacle): "Failed: stop at (X,Y), tra

 24%|██▎       | 356/1500 [00:13<00:43, 26.06it/s]


[Trial 2]
prompt='Failed: Stop at (2, 2), traversed cells: [(0, 0), (1, 0), (1, 1), (2, 1), (2, 2)]'
waypoints=[array([0, 0]), array([3, 0]), array([3, 3]), array([4, 3]), array([4, 4])]
result='Failed: Stop at (3, 1), traversed cells: [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1)]'



 19%|█▊        | 279/1500 [00:05<00:23, 52.94it/s]


[Trial 3]
prompt='Failed: Stop at (3, 1), traversed cells: [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1)]'
waypoints=[array([0, 0]), array([4, 0]), array([4, 4])]
result='Failed: Stop at (4, 0), traversed cells: [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0)]'



 11%|█         | 167/1500 [00:04<00:36, 36.37it/s]


[Trial 4]
prompt='Failed: Stop at (4, 0), traversed cells: [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0)]'
waypoints=[array([0, 0]), array([0, 1]), array([1, 1]), array([2, 1]), array([4, 1]), array([4, 2]), array([4, 3]), array([4, 4])]
result='Failed: Stop at (0, 1), traversed cells: [(0, 0), (0, 1)]'



 24%|██▎       | 354/1500 [00:07<00:23, 48.25it/s]


[Trial 5]
prompt='Failed: Stop at (0, 1), traversed cells: [(0, 0), (0, 1)]'
waypoints=[array([0, 0]), array([3, 0]), array([3, 2]), array([4, 4])]
result='Failed: Stop at (3, 1), traversed cells: [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1)]'



 25%|██▍       | 372/1500 [00:08<00:25, 44.34it/s]


[Trial 6]
prompt='Failed: Stop at (3, 1), traversed cells: [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1)]'
waypoints=[array([0, 0]), array([1, 0]), array([2, 0]), array([2, 3]), array([3, 3]), array([4, 4])]
result='Failed: Stop at (2, 2), traversed cells: [(0, 0), (1, 0), (2, 0), (2, 1), (2, 2)]'



 25%|██▍       | 373/1500 [00:07<00:22, 49.86it/s]


[Trial 7]
prompt='Failed: Stop at (2, 2), traversed cells: [(0, 0), (1, 0), (2, 0), (2, 1), (2, 2)]'
waypoints=[array([0, 0]), array([1, 0]), array([1, 4]), array([4, 4])]
result='Failed: Stop at (1, 4), traversed cells: [(0, 0), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4)]'



 22%|██▏       | 331/1500 [00:06<00:24, 48.27it/s]


[Trial 8]
prompt='Failed: Stop at (1, 4), traversed cells: [(0, 0), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4)]'
waypoints=[array([0, 0]), array([1, 0]), array([2, 1]), array([3, 2]), array([4, 4])]
result='Failed: Stop at (3, 1), traversed cells: [(0, 0), (1, 0), (1, 1), (2, 1), (3, 1)]'



 42%|████▏     | 637/1500 [00:20<00:27, 31.54it/s]

[Trial 9]
prompt='Failed: Stop at (3, 1), traversed cells: [(0, 0), (1, 0), (1, 1), (2, 1), (3, 1)]'
waypoints=[array([0, 0]), array([1, 0]), array([1, 3]), array([4, 3]), array([4, 4])]
result='Success'






'Success'

In [17]:
# Set up the camera
camera = mujoco.MjvCamera()
mujoco.mjv_defaultCamera(camera)

# Object is centered at (0,0,1.5) with size 15m x 15m x 3m
camera.lookat[:] = np.array([2.0, 1.5, 1.5])  # Center of object

# Calculate distance using a safe margin
object_size = 6  # Maximum size in X or Y
fov = 45  # Default MuJoCo camera FOV
safe_margin = 1  # Factor to ensure entire object is in view

# Compute required distance using field of view
camera.distance = (object_size / 2) / np.tan(np.radians(fov / 2)) * safe_margin

# Look directly down
camera.azimuth = 90  # No horizontal rotation
camera.elevation = -70  # Directly downward

In [None]:
render_every = 5
fps = 3.0 / env.dt / render_every  # 3x realtime
traj = rollout[::render_every]

scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True
scene_option.geomgroup[3] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False

frames = env.render(
    traj,
    camera=camera,
    scene_option=scene_option,
    height=480,
    width=640,
)
mediapy.show_video(frames, fps=fps, loop=False)