# I-MPPI: Interactive Informative Model Predictive Path Integral Control

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/riccardo-enr/jax_mppi/blob/main/docs/examples/i_mppi_interactive_simulation.ipynb)

This notebook provides an interactive simulation of the **two-layer I-MPPI architecture** for informative path planning with GPU acceleration.

## Architecture Overview

```
Layer 2 (FSMI Analyzer, ~5 Hz)         Layer 3 (I-MPPI Controller, ~50 Hz)
+---------------------------------+     +-----------------------------------+
| Full FSMI (O(n^2))             |     | Biased MPPI + Uniform-FSMI (O(n)) |
| - Occupancy grid analysis      | --> | - Tracks Layer 2 reference traj   |
| - Global information planning  |     | - Local informative reactivity    |
| - Reference trajectory output  |     | - Obstacle avoidance              |
+---------------------------------+     +-----------------------------------+
```

**Layer 2** generates reference trajectories maximizing global information gain using Fast Shannon Mutual Information (FSMI) on the occupancy grid.

**Layer 3** tracks the reference while maintaining local informative viewpoints via Uniform-FSMI, ensuring reactive exploration even between Layer 2 updates.

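**Cost function**: `J = Tracking(ref) + Obstacles - lambda * Uniform_FSMI(local)`, where `lambda` is the **Info Weight** applied to the Uniform-FSMI term (not the MPPI temperature **Lambda** listed in the parameter table below).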

In [None]:
# @title Setup: Install jax_mppi and dependencies (Colab only)
import os
import shutil
import sys

if "COLAB_GPU" in os.environ or "COLAB_RELEASE_TAG" in os.environ:
    # Running in Google Colab
    target_dir = "/content/jax_mppi"
    
    # Change to /content first to avoid directory issues
    os.chdir("/content")
    
    # Remove existing directory if present
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
        print(f"Removed existing directory: {target_dir}")
    
    # Clone the repository
    !git clone https://github.com/riccardo-enr/jax_mppi.git /content/jax_mppi
    
    # Install Python dependencies
    %pip install -q jaxtyping chex matplotlib ipywidgets
    
    # Add the source directory to sys.path directly
    # (avoids scikit-build-core compilation issues on Colab)
    sys.path.insert(0, "/content/jax_mppi/src")
    
    os.chdir("/content/jax_mppi")
    print("Setup complete!")
else:
    print("Not running in Colab -- assuming local installation.")

# Verify JAX and GPU
import jax

# Report the JAX build and the devices it can see.
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")

# A device counts as a GPU when its string form mentions "gpu" or "cuda"
# (case-insensitive).
device_names = [str(d).lower() for d in jax.devices()]
gpu_available = any(
    "gpu" in name or "cuda" in name for name in device_names
)
if not gpu_available:
    print("WARNING: No GPU detected. Simulation will run on CPU (slower).")
    print("  In Colab: Runtime > Change runtime type > GPU")
else:
    print("GPU detected.")

In [None]:
# Imports
import time
from functools import partial

import ipywidgets as widgets
import jax
import jax.numpy as jnp
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output, display

from jax_mppi import mppi
from jax_mppi.i_mppi.environment import (
    GOAL_POS,
    INFO_ZONES,
    augmented_dynamics,
    informative_running_cost,
)
from jax_mppi.i_mppi.fsmi import (
    FSMIConfig,
    FSMITrajectoryGenerator,
    UniformFSMI,
    UniformFSMIConfig,
)
from jax_mppi.i_mppi.map import GridMap
from jax_mppi.i_mppi.planner import biased_mppi_command

print("All imports successful.")

## Environment Setup

The simulation takes place in a **14m x 12m office-like environment** with:
- **Walls** (gray): Physical obstacles the quadrotor must avoid
- **Info zones** (yellow): Unknown regions with high entropy -- the robot gains information by flying near them
- **Goal** (red star): Target position the robot must reach
- **Start** (green circle): Initial robot position

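The occupancy grid encodes uncertainty as a per-cell occupancy value: 0.5 = unknown, 0.2 = known free, 0.9 = known occupied. A few regions use intermediate values (e.g. 0.35, 0.8) on the same 0 (free) to 1 (occupied) scale.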

In [None]:
def create_occupancy_grid():
    """Create an office-like occupancy grid with rooms and corridors.

    The grid is indexed as ``grid[row, col]``. Every cell starts at 0.5 and
    rectangular regions are then painted one after another; later regions
    overwrite earlier ones, so the order of the table below is significant.

    The "top"/"bottom" room labels follow array row order (row 0 first).
    ``plot_environment`` draws the grid with ``origin="lower"``, so the
    rooms appear vertically flipped relative to those labels in the figures.

    Returns:
        Tuple ``(grid, map_origin, resolution, width, height)`` where
        ``grid`` has shape ``(height, width)``, ``map_origin`` is
        ``[0.0, 0.0]``, ``resolution`` is the cell size (0.1) and
        ``width``/``height`` are the grid dimensions in cells.
    """
    world_width = 14.0
    world_height = 12.0
    resolution = 0.1

    # Round before truncating: a float quotient such as 13.999999999999998
    # would otherwise lose a whole cell. The rest of the notebook derives
    # step counts with int(round(...)) as well.
    width = int(round(world_width / resolution))
    height = int(round(world_height / resolution))

    # Each entry is (row_start, row_stop, col_start, col_stop, value);
    # ``None`` leaves that bound open. The index ranges are written for the
    # 120 x 140 cell grid produced by the constants above.
    regions = [
        # Known free space (central corridor)
        (35, 85, 5, 135, 0.2),
        # Outer walls
        (0, 5, None, None, 0.9),
        (115, 120, None, None, 0.9),
        (None, None, 0, 5, 0.9),
        (None, None, 135, 140, 0.9),
        # Room 1: Bottom-left office
        (85, 115, 5, 45, 0.9),
        (35, 115, 5, 10, 0.9),
        (35, 85, 40, 45, 0.9),
        (35, 45, 40, 45, 0.2),  # doorway
        (40, 80, 10, 40, 0.5),  # unknown interior
        (40, 50, 30, 40, 0.35),
        # Room 2: Top-left office
        (5, 35, 5, 45, 0.9),
        (5, 35, 40, 45, 0.9),
        (28, 36, 40, 45, 0.2),  # doorway
        (10, 30, 10, 40, 0.5),
        # Room 3: Bottom-right office
        (85, 115, 95, 135, 0.9),
        (85, 115, 130, 135, 0.9),
        (35, 85, 95, 100, 0.9),
        (40, 50, 95, 100, 0.2),  # doorway
        (40, 80, 100, 130, 0.5),
        (50, 60, 105, 115, 0.8),
        (65, 75, 120, 125, 0.8),
        # Room 4: Top-right office
        (5, 35, 95, 135, 0.9),
        (5, 35, 95, 100, 0.9),
        (28, 36, 95, 100, 0.2),  # doorway
        (10, 30, 100, 130, 0.5),
        (25, 32, 100, 110, 0.35),
        # Central obstacles
        (45, 55, 50, 60, 0.85),
        (65, 75, 70, 80, 0.85),
        (40, 45, 85, 90, 0.8),
        (75, 80, 20, 25, 0.8),
        # Narrow passages
        (35, 85, 45, 52, 0.2),
        (55, 65, 60, 70, 0.2),
        # Info zones (high uncertainty)
        (50, 75, 12, 35, 0.5),
        (55, 70, 15, 30, 0.55),
        (12, 28, 102, 128, 0.5),
        (15, 25, 105, 125, 0.55),
        # Additional complexity
        (70, 82, 48, 52, 0.9),
        (72, 80, 48, 52, 0.2),
        (72, 80, 45, 48, 0.52),
        (55, 70, 90, 95, 0.9),
        (35, 36, 40, 45, 0.75),
        (84, 85, 95, 100, 0.75),
    ]

    grid = 0.5 * jnp.ones((height, width))
    for row_start, row_stop, col_start, col_stop, value in regions:
        grid = grid.at[row_start:row_stop, col_start:col_stop].set(value)

    map_origin = jnp.array([0.0, 0.0])
    return grid, map_origin, resolution, width, height


# Build environment
# Build the environment once at module level. The simulation builder and the
# plotting cells further down read these names directly.
grid_array, map_origin, map_resolution, grid_width, grid_height = (
    create_occupancy_grid()
)
# Bundle the grid and its metadata into the library's GridMap container; it
# is handed to the Layer 2 planner, and its `.grid` field is read inside the
# simulation step.
grid_map_obj = GridMap(
    grid=grid_array,
    origin=map_origin,
    resolution=map_resolution,
    width=grid_width,
    height=grid_height,
)


# Visualize
def plot_environment(
    ax, grid, resolution, show_labels=True, start_pos=(1.0, 5.0)
):
    """Plot the occupancy grid with walls, info zones, start, and goal.

    Args:
        ax: Matplotlib axes to draw on.
        grid: 2-D occupancy array indexed ``[row, col]``; drawn with
            ``origin="lower"`` so row 0 sits at the bottom of the figure.
        resolution: Cell size used to convert grid indices to metres.
        show_labels: When true, write "Info N" at the centre of each zone.
        start_pos: ``(x, y)`` of the start marker. Defaults to the
            notebook's nominal start ``(1.0, 5.0)``; pass ``None`` to omit
            the marker (e.g. when the caller draws the actual start itself).
    """
    world_w = grid.shape[1] * resolution
    world_h = grid.shape[0] * resolution
    ax.imshow(
        np.array(grid),
        origin="lower",
        extent=[0, world_w, 0, world_h],
        cmap="gray_r",
        vmin=0,
        vmax=1,
        alpha=0.8,
    )

    # Info zones: the first four columns of each row are read as
    # (centre x, centre y, width, height). Convert once to NumPy so the loop
    # does not index the source array element by element.
    for zone_no, zone in enumerate(np.asarray(INFO_ZONES), start=1):
        cx, cy, w, h = (float(v) for v in zone[:4])
        ax.add_patch(
            mpatches.FancyBboxPatch(
                (cx - w / 2, cy - h / 2),
                w,
                h,
                boxstyle="round,pad=0.05",
                facecolor="yellow",
                alpha=0.3,
                edgecolor="orange",
                linewidth=1.5,
            )
        )
        if show_labels:
            ax.text(
                cx, cy, f"Info {zone_no}", ha="center", va="center", fontsize=8
            )

    # Start and goal
    if start_pos is not None:
        ax.plot(
            float(start_pos[0]),
            float(start_pos[1]),
            "go",
            markersize=10,
            label="Start",
            zorder=5,
        )
    ax.plot(
        float(GOAL_POS[0]),
        float(GOAL_POS[1]),
        "r*",
        markersize=15,
        label="Goal",
        zorder=5,
    )

    # Frame the whole map with a 0.5 m margin. For the 14 m x 12 m grid this
    # reproduces the previous fixed limits (-0.5..14.5, -0.5..12.5).
    margin = 0.5
    ax.set_xlim(-margin, world_w + margin)
    ax.set_ylim(-margin, world_h + margin)
    ax.set_xlabel("X (m)")
    ax.set_ylabel("Y (m)")
    ax.set_aspect("equal")


# Standalone overview figure of the environment, with a colorbar for the
# occupancy scale.
fig, ax = plt.subplots(figsize=(10, 8))
plot_environment(ax, grid_array, map_resolution)
ax.set_title("Office Environment: Occupancy Grid")
ax.legend(loc="upper left")
occupancy_image = ax.images[0]
fig.colorbar(
    occupancy_image,
    ax=ax,
    label="Occupancy (0=free, 0.5=unknown, 1=occupied)",
)
fig.tight_layout()
plt.show()

## Interactive Parameter Configuration

Adjust the parameters below and click **Run Simulation** to execute the I-MPPI controller.

| Parameter | Description | Effect |
|-----------|-------------|--------|
| **Samples** | Number of MPPI rollout samples | More = better trajectories, slower |
| **Horizon** | Planning horizon (steps) | Longer = further look-ahead |
| **Lambda** | MPPI temperature | Lower = more greedy (exploits best) |
| **Info Weight** | Uniform-FSMI weight in Layer 3 | Higher = more exploration-driven |
| **FSMI Beams** | Number of sensor beams (Layer 2) | More = better information estimate |
| **FSMI Range** | Max sensor range (Layer 2) | Longer = wider planning scope |

In [None]:
# Visualization helper functions


def plot_trajectory_2d(
    ax, history_x, grid, resolution, title="I-MPPI Trajectory"
):
    """Draw the flown XY path on top of the environment map.

    The path is rendered segment by segment so each piece can take its own
    viridis colour (early = dark, late = bright). The first sample is marked
    with a green circle, the last with a blue square and the goal with a
    red star.
    """
    plot_environment(ax, grid, resolution, show_labels=False)

    # Only the first two state columns are used as the (x, y) position.
    xy = np.array(history_x[:, :2])
    shades = plt.cm.viridis(np.linspace(0, 1, len(xy)))
    for k, shade in enumerate(shades[:-1]):
        segment = xy[k : k + 2]
        ax.plot(segment[:, 0], segment[:, 1], color=shade, linewidth=2)

    first, last = xy[0], xy[-1]
    ax.plot(first[0], first[1], "go", markersize=10, zorder=5)
    ax.plot(last[0], last[1], "bs", markersize=8, zorder=5)
    ax.plot(
        float(GOAL_POS[0]), float(GOAL_POS[1]), "r*", markersize=15, zorder=5
    )
    ax.set_title(title)
    ax.set_title(title)


def plot_info_levels(ax, history_info, dt):
    """Plot one line per info zone showing its level against time.

    Args:
        ax: Axes to draw on.
        history_info: Array-like with one row per step and one column per
            info zone.
        dt: Time between consecutive rows, used to build the time axis.
    """
    levels = np.array(history_info)
    times = np.arange(len(levels)) * dt
    for zone_no in range(1, levels.shape[1] + 1):
        ax.plot(
            times,
            levels[:, zone_no - 1],
            linewidth=2,
            label=f"Info Zone {zone_no}",
        )
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Information Level")
    ax.set_title("Information Zone Depletion")
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_control_inputs(axes, actions, dt):
    """Plot each control channel on its own axes (one per subplot).

    Args:
        axes: Sequence of axes; channel ``k`` is drawn on ``axes[k]``. At
            most four channels are drawn.
        actions: Array-like with one row per step and one column per
            control channel.
        dt: Time between consecutive rows, used to build the time axis.
    """
    controls = np.array(actions)
    times = np.arange(len(controls)) * dt
    # (y-axis label, line colour) for each control channel, in column order.
    channel_style = (
        ("Thrust (N)", "#1f77b4"),
        ("Omega X (rad/s)", "#ff7f0e"),
        ("Omega Y (rad/s)", "#2ca02c"),
        ("Omega Z (rad/s)", "#d62728"),
    )
    for channel, (panel, (ylabel, line_color)) in enumerate(
        zip(axes, channel_style)
    ):
        panel.plot(
            times, controls[:, channel], color=line_color, linewidth=1, alpha=0.8
        )
        panel.set_ylabel(ylabel)
        panel.grid(True, alpha=0.3)
    axes[-1].set_xlabel("Time (s)")
    axes[0].set_title("Control Inputs")


def plot_position_3d(ax, history_x):
    """Draw the XYZ path on a 3-D axes, coloured by progress along the run.

    The first sample is marked in green and the goal with a red star.
    """
    # The first three state columns are used as the (x, y, z) position.
    xyz = np.array(history_x[:, :3])
    shades = plt.cm.viridis(np.linspace(0, 1, len(xyz)))
    for k, shade in enumerate(shades[:-1]):
        segment = xyz[k : k + 2]
        ax.plot(
            segment[:, 0],
            segment[:, 1],
            segment[:, 2],
            color=shade,
            linewidth=1.5,
        )

    x0, y0, z0 = xyz[0]
    ax.scatter(x0, y0, z0, color="green", s=60, zorder=5)

    goal_x, goal_y, goal_z = (float(GOAL_POS[k]) for k in range(3))
    ax.scatter(
        goal_x,
        goal_y,
        goal_z,
        color="red",
        marker="*",
        s=100,
        zorder=5,
    )
    ax.set_xlabel("X (m)")
    ax.set_ylabel("Y (m)")
    ax.set_zlabel("Z (m)")
    ax.set_title("3D Trajectory")

In [None]:
# Simulation helper functions

# Constants

# Step size passed as `dt` to `augmented_dynamics` and to the Layer 2
# planner; also used to build the time axes of the plots.
# NOTE(review): DT = 0.05 corresponds to 20 Hz, yet step counts are derived
# from CONTROL_HZ = 50 (sim_steps = duration * CONTROL_HZ). Each run therefore
# advances sim_steps * DT = 2.5x the requested duration in simulated time --
# confirm whether DT should be 1 / CONTROL_HZ or CONTROL_HZ should be 1 / DT.
DT = 0.05
NX = 16  # 13 (quad) + 3 (info zones)
NU = 4  # control channels: thrust + three angular rates
CONTROL_HZ = 50.0  # Layer 3 rate; converts a duration into a step count
FSMI_HZ = 5.0  # Layer 2 replanning rate
# Number of control steps between two Layer 2 reference updates (10 here).
FSMI_STEPS = int(round(CONTROL_HZ / FSMI_HZ))

# Per-channel control bounds, ordered [thrust, omega_x, omega_y, omega_z].
U_MIN = jnp.array([0.0, -10.0, -10.0, -10.0])
U_MAX = jnp.array([4.0 * 9.81, 10.0, 10.0, 10.0])
# Nominal control used to initialise MPPI and as the thrust offset of the
# reference controller; presumably hover thrust per unit mass -- TODO confirm.
U_INIT = jnp.array([9.81, 0.0, 0.0, 0.0])
# Diagonal matrix holding the squares of the listed per-channel values.
NOISE_SIGMA = jnp.diag(jnp.array([2.0, 0.5, 0.5, 0.5]) ** 2)


def make_u_ref_from_traj(current_state, ref_traj):
    """Turn a position reference into a clipped control reference.

    A proportional law on the position error between each reference point
    and the current position (first three state entries) produces one
    control row ``[thrust, omega_x, omega_y, omega_z]`` per reference point.
    The z error drives thrust around ``U_INIT[0]``, the y error drives
    omega_x, the negated x error drives omega_y, and omega_z is held at
    zero. The result is clipped to ``[U_MIN, U_MAX]``.
    """
    thrust_gain = 3.0
    rate_gain = 0.6

    offset = ref_traj - current_state[:3][None, :]
    err_x, err_y, err_z = offset[:, 0], offset[:, 1], offset[:, 2]

    roll_rate = rate_gain * err_y
    channels = [
        U_INIT[0] + thrust_gain * err_z,
        roll_rate,
        -rate_gain * err_x,
        jnp.zeros_like(roll_rate),
    ]
    return jnp.clip(jnp.stack(channels, axis=1), U_MIN, U_MAX)


def compute_smoothness(actions, positions, dt_val):
    """Return two smoothness scores for a run.

    Args:
        actions: Array with one row per step and one column per control.
        positions: Array with one row per step; only the first three
            columns are used.
        dt_val: Time between consecutive rows.

    Returns:
        ``(action_jerk_mean, traj_jerk_mean)``: the mean row-wise norm of
        the second finite difference of ``actions`` divided by
        ``dt_val**2``, and the mean row-wise norm of the third finite
        difference of the position obtained by differencing and dividing
        by ``dt_val`` three times in a row.
    """

    def _mean_row_norm(rows):
        # Average Euclidean length of the rows.
        return jnp.mean(jnp.linalg.norm(rows, axis=1))

    action_second_diff = jnp.diff(actions, n=2, axis=0) / (dt_val**2)

    # position -> velocity -> acceleration -> jerk
    derivative = positions[:, :3]
    for _ in range(3):
        derivative = jnp.diff(derivative, axis=0) / dt_val

    return _mean_row_norm(action_second_diff), _mean_row_norm(derivative)


def build_sim_fn(
    config,
    fsmi_planner,
    uniform_fsmi,
    uniform_fsmi_config,
    horizon,
    sim_steps,
):
    """Build a JIT-compiled simulation function.

    The returned function runs the whole closed loop inside one
    ``jax.lax.scan``. It also reads the module-level ``grid_map_obj``,
    ``DT`` and ``FSMI_STEPS``.

    Args:
        config: MPPI configuration forwarded to ``biased_mppi_command``.
        fsmi_planner: Layer 2 planner; its
            ``get_reference_trajectory(state, info_data, horizon, DT)`` is
            called and the first element of its return value is used as the
            reference trajectory.
        uniform_fsmi: Layer 3 information model; its ``compute`` method is
            passed to the running cost as ``uniform_fsmi_fn``.
        uniform_fsmi_config: Only its ``info_weight`` attribute is read.
        horizon: Horizon argument handed to the Layer 2 planner.
        sim_steps: Number of scan iterations (control steps) to simulate.

    Returns:
        A jitted ``sim_fn(initial_state, initial_ctrl_state)`` returning
        ``(final_state, history_x, history_info, targets, actions)``, where
        the four histories are stacked along the step axis.
    """

    def step_fn(carry, t):
        # carry = (state, MPPI controller state, current reference trajectory)
        current_state, current_ctrl_state, ref_traj = carry

        # Layer 2: Full FSMI reference trajectory (slow, 5 Hz)
        # State entries from index 13 onward are the info-zone levels.
        current_info = current_state[13:]
        # Replan only on steps that are a multiple of FSMI_STEPS.
        do_update = jnp.equal(jnp.mod(t, FSMI_STEPS), 0)
        info_data = (grid_map_obj.grid, current_info)

        # Both branches are lambdas closing over the traced values; the
        # false branch simply keeps the reference carried from the last step.
        ref_traj = jax.lax.cond(
            do_update,
            lambda _: fsmi_planner.get_reference_trajectory(
                current_state, info_data, horizon, DT
            )[0],
            lambda _: ref_traj,
            operand=None,
        )

        # Layer 3: Biased I-MPPI with Uniform-FSMI (fast, 50 Hz)
        # Bind the per-step reference and map into the running cost.
        cost_fn = partial(
            informative_running_cost,
            target=ref_traj,
            grid_map=grid_map_obj.grid,
            uniform_fsmi_fn=uniform_fsmi.compute,
            info_weight=uniform_fsmi_config.info_weight,
        )

        # Control-space reference derived from the position reference.
        U_ref_local = make_u_ref_from_traj(current_state, ref_traj)

        action, next_ctrl_state = biased_mppi_command(
            config,
            current_ctrl_state,
            current_state,
            augmented_dynamics,
            cost_fn,
            U_ref_local,
            bias_alpha=0.2,
        )

        # Advance the simulated system with the chosen action.
        next_state = augmented_dynamics(current_state, action, dt=DT)

        # Per-step outputs: state AFTER the step, info levels from BEFORE
        # the step, first point of the reference, and the applied action.
        return (next_state, next_ctrl_state, ref_traj), (
            next_state,
            current_info,
            ref_traj[0],
            action,
        )

    def sim_fn(initial_state, initial_ctrl_state):
        # Seed the carry with a reference planned from the initial state.
        info_data = (grid_map_obj.grid, initial_state[13:])
        init_ref_traj, _ = fsmi_planner.get_reference_trajectory(
            initial_state, info_data, horizon, DT
        )
        (
            (final_state, final_ctrl_state, _),
            (history_x, history_info, targets, actions),
        ) = jax.lax.scan(
            step_fn,
            (initial_state, initial_ctrl_state, init_ref_traj),
            jnp.arange(sim_steps),
        )
        # The final controller state is not returned to the caller.
        return final_state, history_x, history_info, targets, actions

    return jax.jit(sim_fn)


print("Simulation engine ready.")

In [None]:
# Interactive widgets

# System parameters: start position and simulated duration.
w_start_x = widgets.FloatSlider(
    description="Start X (m):",
    value=1.0,
    min=0.5,
    max=12.0,
    step=0.5,
)
w_start_y = widgets.FloatSlider(
    description="Start Y (m):",
    value=5.0,
    min=0.5,
    max=9.0,
    step=0.5,
)
w_duration = widgets.FloatSlider(
    description="Duration (s):",
    value=15.0,
    min=5.0,
    max=60.0,
    step=5.0,
)

# Controller parameters (Layer 3 MPPI).
w_samples = widgets.IntSlider(
    description="Samples:",
    value=1000,
    min=100,
    max=5000,
    step=100,
)
w_horizon = widgets.IntSlider(
    description="Horizon:",
    value=40,
    min=10,
    max=80,
    step=5,
)
# Log slider: min/max are exponents, the value itself is the temperature.
w_lambda = widgets.FloatLogSlider(
    description="Lambda:",
    value=0.1,
    min=-2,
    max=1,
    step=0.1,
)
w_info_weight = widgets.FloatSlider(
    description="Info Weight:",
    value=5.0,
    min=0.0,
    max=20.0,
    step=1.0,
)

# FSMI parameters (Layer 2 planner).
w_fsmi_beams = widgets.IntSlider(
    description="FSMI Beams:",
    value=12,
    min=4,
    max=24,
    step=2,
)
w_fsmi_range = widgets.FloatSlider(
    description="FSMI Range:",
    value=5.0,
    min=2.0,
    max=10.0,
    step=0.5,
)

# Trigger button and the output pane that collects the run's text and plots.
run_button = widgets.Button(
    description="Run Simulation",
    icon="play",
    button_style="success",
    layout=widgets.Layout(width="200px", height="40px"),
)
output_area = widgets.Output()


def run_simulation(button):
    """Button callback: simulate with the current widget values and plot.

    Everything printed or plotted goes into ``output_area``, which is
    cleared first. The simulation function is called twice: once so that
    compilation is measured separately, then again for the timed run whose
    results are reported and plotted.

    Args:
        button: The clicked widget (unused; required by ``on_click``).
    """
    with output_area:
        clear_output(wait=True)

        # Snapshot the widget values for this run.
        start_x, start_y = w_start_x.value, w_start_y.value
        sim_duration = w_duration.value
        num_samples = w_samples.value
        horizon = w_horizon.value
        lambda_ = w_lambda.value
        info_weight = w_info_weight.value
        fsmi_beams, fsmi_range = w_fsmi_beams.value, w_fsmi_range.value

        sim_steps = int(round(sim_duration * CONTROL_HZ))

        banner = "=" * 60
        print(banner)
        print("I-MPPI Interactive Simulation")
        print(banner)
        print(f"  Start: ({start_x}, {start_y})")
        print(f"  Duration: {sim_duration}s ({sim_steps} steps)")
        print(
            f"  Samples: {num_samples}, Horizon: {horizon}, Lambda: {lambda_}"
        )
        print(f"  Info Weight: {info_weight}")
        print(f"  FSMI Beams: {fsmi_beams}, Range: {fsmi_range}m")
        print()

        # Initial state: 13 quadrotor entries (position in 0:3, entry 6 set
        # to 1 as qw) followed by the three info-zone levels.
        quad_state = (
            jnp.zeros(13)
            .at[:3]
            .set(jnp.array([start_x, start_y, -2.0]))
            .at[6]
            .set(1.0)
        )
        state = jnp.concatenate(
            [quad_state, jnp.array([100.0, 100.0, 100.0])]
        )

        # Layer 2: full-FSMI reference planner.
        fsmi_planner = FSMITrajectoryGenerator(
            config=FSMIConfig(
                use_grid_fsmi=True,
                goal_pos=GOAL_POS,
                fov_rad=1.57,
                num_beams=fsmi_beams,
                max_range=fsmi_range,
                ray_step=0.15,
                sigma_range=0.15,
                gaussian_truncation_sigma=3.0,
                trajectory_subsample_rate=8,
                info_weight=25.0,
                motion_weight=0.5,
            ),
            info_zones=INFO_ZONES,
            grid_map=grid_map_obj,
        )

        # Layer 3: Uniform-FSMI term, weighted by the slider value.
        uniform_fsmi_config = UniformFSMIConfig(
            fov_rad=1.57,
            num_beams=6,
            max_range=2.5,
            ray_step=0.2,
            info_weight=info_weight,
        )
        uniform_fsmi = UniformFSMI(
            uniform_fsmi_config, map_origin, map_resolution
        )

        # MPPI controller
        config, ctrl_state = mppi.create(
            nx=NX,
            nu=NU,
            noise_sigma=NOISE_SIGMA,
            num_samples=num_samples,
            horizon=horizon,
            lambda_=lambda_,
            u_min=U_MIN,
            u_max=U_MAX,
            u_init=U_INIT,
            step_dependent_dynamics=True,
        )

        sim_fn = build_sim_fn(
            config,
            fsmi_planner,
            uniform_fsmi,
            uniform_fsmi_config,
            horizon,
            sim_steps,
        )

        # First call: dominated by JIT compilation. Wait on every output so
        # the measurement covers the whole computation.
        print("JIT compiling... (this may take 1-2 min on first run)")
        compile_start = time.perf_counter()
        warm = sim_fn(state, ctrl_state)
        for leaf in jax.tree_util.tree_leaves(warm):
            leaf.block_until_ready()
        compile_time = time.perf_counter() - compile_start
        print(f"  Compilation: {compile_time:.1f}s")

        # Second call: the run that is reported.
        print("Running simulation...")
        run_start = time.perf_counter()
        final_state, history_x, history_info, targets, actions = sim_fn(
            state, ctrl_state
        )
        final_state.block_until_ready()
        runtime = time.perf_counter() - run_start

        # Metrics
        action_jerk, traj_jerk = compute_smoothness(actions, history_x, DT)
        final_pos = final_state[:3]
        goal_dist = float(jnp.linalg.norm(final_pos - GOAL_POS))

        print(
            f"  Runtime: {runtime:.2f}s "
            f"({runtime / sim_duration:.2f}x realtime)"
        )
        print(
            f"  Final position: [{final_pos[0]:.2f}, "
            f"{final_pos[1]:.2f}, {final_pos[2]:.2f}]"
        )
        print(f"  Distance to goal: {goal_dist:.2f}m")
        print(f"  Info levels: {final_state[13:]}")
        print(f"  Action jerk: {float(action_jerk):.4f}")
        print(f"  Trajectory jerk: {float(traj_jerk):.4f}")
        print()

        # Figure: 2x2 layout; the lower-right quadrant is split into two
        # stacked control plots by addressing it on a 4x2 grid.
        fig = plt.figure(figsize=(16, 12))

        map_ax = fig.add_subplot(2, 2, 1)
        plot_trajectory_2d(
            map_ax,
            history_x,
            grid_array,
            map_resolution,
            "I-MPPI Trajectory (2D)",
        )

        info_ax = fig.add_subplot(2, 2, 2)
        plot_info_levels(info_ax, history_info, DT)

        space_ax = fig.add_subplot(2, 2, 3, projection="3d")
        plot_position_3d(space_ax, history_x)

        thrust_ax = fig.add_subplot(4, 2, 6)
        rates_ax = fig.add_subplot(4, 2, 8)
        acts = np.array(actions)
        t_arr = np.arange(len(acts)) * DT

        thrust_ax.plot(t_arr, acts[:, 0], color="#1f77b4", linewidth=1)
        thrust_ax.set_ylabel("Thrust (N)")
        thrust_ax.set_title("Control Inputs")
        thrust_ax.grid(True, alpha=0.3)

        for column, line_color, tag in (
            (1, "#ff7f0e", "wx"),
            (2, "#2ca02c", "wy"),
            (3, "#d62728", "wz"),
        ):
            rates_ax.plot(
                t_arr, acts[:, column], color=line_color, linewidth=1, label=tag
            )
        rates_ax.set_ylabel("Angular rates")
        rates_ax.set_xlabel("Time (s)")
        rates_ax.legend(fontsize=8)
        rates_ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()


# Layout: three titled columns side by side, with the run button and the
# output pane stacked underneath.
system_box = widgets.VBox(
    children=[
        widgets.HTML("<h3>System Parameters</h3>"),
        w_start_x,
        w_start_y,
        w_duration,
    ]
)
controller_box = widgets.VBox(
    children=[
        widgets.HTML("<h3>Controller Parameters</h3>"),
        w_samples,
        w_horizon,
        w_lambda,
        w_info_weight,
    ]
)
fsmi_box = widgets.VBox(
    children=[
        widgets.HTML("<h3>FSMI Parameters (Layer 2)</h3>"),
        w_fsmi_beams,
        w_fsmi_range,
    ]
)
param_panel = widgets.HBox(children=[system_box, controller_box, fsmi_box])

# Hook up the callback before the controls become visible.
run_button.on_click(run_simulation)

display(widgets.VBox(children=[param_panel, run_button, output_area]))

## MPPI Simulation with Fixed Parameters

Run the I-MPPI controller with fixed parameters for a reproducible comparison baseline.

The two-layer architecture uses:
- **Layer 2**: Full FSMI reference trajectory generation
- **Layer 3**: Biased MPPI tracking + local Uniform-FSMI information gain

In [None]:
# Fixed-parameter MPPI simulation


def run_fixed_simulation():
    """Run the two-layer controller once with hard-coded parameters.

    Uses a 30 s duration, horizon 40, 1000 samples, lambda 0.1 and start
    position (1.0, 5.0, -2.0). The simulation function is called twice so
    that the timed run excludes compilation.

    Returns:
        Dict with keys ``history_x``, ``history_info``, ``actions``,
        ``runtime_s``, ``action_jerk``, ``traj_jerk``, ``goal_dist`` and
        ``final_state``.
    """
    horizon = 40
    sim_duration = 30.0
    sim_steps = int(round(sim_duration * CONTROL_HZ))

    # Initial state: 13 quadrotor entries (position in 0:3, entry 6 set to
    # 1) followed by the three info-zone levels.
    quad_state = (
        jnp.zeros(13).at[:3].set(jnp.array([1.0, 5.0, -2.0])).at[6].set(1.0)
    )
    state = jnp.concatenate([quad_state, jnp.array([100.0, 100.0, 100.0])])

    # Layer 2 planner
    fsmi_planner = FSMITrajectoryGenerator(
        config=FSMIConfig(
            use_grid_fsmi=True,
            goal_pos=GOAL_POS,
            fov_rad=1.57,
            num_beams=12,
            max_range=5.0,
            ray_step=0.15,
            sigma_range=0.15,
            gaussian_truncation_sigma=3.0,
            trajectory_subsample_rate=8,
            info_weight=25.0,
            motion_weight=0.5,
        ),
        info_zones=INFO_ZONES,
        grid_map=grid_map_obj,
    )

    # Layer 3 information term
    uniform_fsmi_config = UniformFSMIConfig(
        fov_rad=1.57,
        num_beams=6,
        max_range=2.5,
        ray_step=0.2,
        info_weight=5.0,
    )
    uniform_fsmi = UniformFSMI(uniform_fsmi_config, map_origin, map_resolution)

    # MPPI controller
    mppi_config, mppi_state = mppi.create(
        nx=NX,
        nu=NU,
        noise_sigma=NOISE_SIGMA,
        num_samples=1000,
        horizon=horizon,
        lambda_=0.1,
        u_min=U_MIN,
        u_max=U_MAX,
        u_init=U_INIT,
        step_dependent_dynamics=True,
    )

    sim_fn = build_sim_fn(
        mppi_config,
        fsmi_planner,
        uniform_fsmi,
        uniform_fsmi_config,
        horizon,
        sim_steps,
    )

    print("Running MPPI...")

    # Warm-up call: triggers compilation; wait for every output to finish.
    print("  JIT compiling...")
    warm = sim_fn(state, mppi_state)
    for leaf in jax.tree_util.tree_leaves(warm):
        leaf.block_until_ready()

    # Timed call
    started = time.perf_counter()
    final_state, history_x, history_info, targets, actions = sim_fn(
        state, mppi_state
    )
    final_state.block_until_ready()
    elapsed = time.perf_counter() - started

    action_jerk, traj_jerk = compute_smoothness(actions, history_x, DT)
    goal_dist = float(jnp.linalg.norm(final_state[:3] - GOAL_POS))

    print(
        f"  MPPI: {elapsed:.2f}s "
        f"({elapsed / sim_duration:.2f}x realtime), "
        f"goal dist: {goal_dist:.2f}m"
    )

    return {
        "history_x": history_x,
        "history_info": history_info,
        "actions": actions,
        "runtime_s": elapsed,
        "action_jerk": float(action_jerk),
        "traj_jerk": float(traj_jerk),
        "goal_dist": goal_dist,
        "final_state": final_state,
    }


# Announce the run (followed by one blank line), then execute it and keep
# the results for the plotting cell.
print("Running fixed-parameter simulation...", end="\n\n")
mppi_results = run_fixed_simulation()

In [None]:
# MPPI results visualization

# One row of three panels: trajectory, info levels, control inputs.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# 2D trajectory, drawn segment by segment so colour encodes time.
ax = axes[0]
plot_environment(ax, grid_array, map_resolution, show_labels=False)
history_x = mppi_results["history_x"]
positions = np.array(history_x[:, :2])
n_steps = len(positions)
colors_traj = plt.cm.viridis(np.linspace(0, 1, n_steps))
for k, shade in enumerate(colors_traj[:-1]):
    segment = positions[k : k + 2]
    ax.plot(segment[:, 0], segment[:, 1], color=shade, linewidth=2)
ax.set_title("I-MPPI Trajectory")

# Info level depletion
plot_info_levels(axes[1], mppi_results["history_info"], DT)
axes[1].set_title("Information Zone Depletion")

# Control inputs: all four channels share one axes.
acts = np.array(mppi_results["actions"])
t_arr = np.arange(len(acts)) * DT
ax3 = axes[2]
for column, (line_color, tag) in enumerate(
    [
        ("#1f77b4", "Thrust"),
        ("#ff7f0e", "wx"),
        ("#2ca02c", "wy"),
        ("#d62728", "wz"),
    ]
):
    ax3.plot(t_arr, acts[:, column], color=line_color, linewidth=1, label=tag)
ax3.set_ylabel("Control Input")
ax3.set_xlabel("Time (s)")
ax3.set_title("Control Inputs")
ax3.legend(fontsize=8)
ax3.grid(True, alpha=0.3)

plt.suptitle(
    "I-MPPI: Full FSMI (Layer 2) + Biased MPPI with Uniform-FSMI (Layer 3)",
    fontsize=14,
)
plt.tight_layout()
plt.show()

# Performance summary
# Two-column text table of the headline metrics from the fixed run.
r = mppi_results
rule = "=" * 60
print()
print(rule)
print(f"{'Metric':<25} {'Value':>15}")
print("-" * 60)
# (row label, value, float format applied inside the 15-wide value column)
for metric, value, spec in (
    ("Runtime (ms)", r["runtime_s"] * 1000, ".1f"),
    ("Goal Distance (m)", r["goal_dist"], ".2f"),
    ("Action Jerk", r["action_jerk"], ".4f"),
    ("Trajectory Jerk", r["traj_jerk"], ".4f"),
):
    print(f"{metric:<25} {value:>15{spec}}")
print(rule)