# I-MPPI: Interactive Informative Model Predictive Path Integral Control

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/riccardo-enr/jax_mppi/blob/main/examples/i_mppi/i_mppi_interactive_simulation.ipynb)

This notebook provides an interactive simulation of the **two-layer I-MPPI architecture** for informative path planning with GPU acceleration.

## Architecture Overview

```
Layer 2 (FSMI Analyzer, ~5 Hz)         Layer 3 (I-MPPI Controller, ~50 Hz)
+---------------------------------+     +-----------------------------------+
| Full FSMI (O(n^2))             |     | Biased MPPI + Uniform-FSMI (O(n)) |
| - Occupancy grid analysis      | --> | - Tracks Layer 2 reference traj   |
| - Global information planning  |     | - Local informative reactivity    |
| - Reference trajectory output  |     | - Obstacle avoidance              |
+---------------------------------+     +-----------------------------------+
```

**Layer 2** generates reference trajectories that maximize global information gain, using Fast Shannon Mutual Information (FSMI) computed on the occupancy grid.

**Layer 3** tracks the reference while maintaining local informative viewpoints via Uniform-FSMI, ensuring reactive exploration even between Layer 2 updates.

**Cost function**: `J = Tracking(ref) + Obstacles - lambda * Uniform_FSMI(local)`

In [None]:
import os
import shutil
import sys

try:
    import google.colab  # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # Running in Google Colab
    target_dir = "/content/jax_mppi"

    # Change to /content first to avoid directory issues
    os.chdir("/content")

    # Remove existing directory if present
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
        print(f"Removed existing directory: {target_dir}")

    # Clone the repository (shallow clone for speed)
    !git clone --depth 1 https://github.com/riccardo-enr/jax_mppi.git /content/jax_mppi

    if not os.path.exists("/content/jax_mppi/src/jax_mppi"):
        raise RuntimeError(
            "git clone failed -- check that the repository is public "
            "and the URL is correct."
        )

    # Install Python dependencies
    %pip install -q jaxtyping chex matplotlib ipywidgets Pillow

    # Add source and examples directories to sys.path
    sys.path.insert(0, "/content/jax_mppi/src")
    sys.path.insert(0, "/content/jax_mppi/examples/i_mppi")

    # Verify helper modules exist
    helpers_dir = "/content/jax_mppi/examples/i_mppi"
    required_modules = ["env_setup.py", "viz_utils.py", "sim_utils.py"]
    missing = [m for m in required_modules if not os.path.exists(os.path.join(helpers_dir, m))]

    if missing:
        raise RuntimeError(
            f"Helper modules not found in repository: {missing}\n"
            f"Make sure you are using the latest version of this notebook."
        )

    os.chdir("/content/jax_mppi")
    print("Setup complete!")
else:
    print("Not running in Colab -- assuming local installation.")

# Verify JAX and GPU
import jax

print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
gpu_available = any(
    "gpu" in str(d).lower() or "cuda" in str(d).lower() for d in jax.devices()
)
if gpu_available:
    print("GPU detected.")
else:
    print("WARNING: No GPU detected. Simulation will run on CPU (slower).")
    print("  In Colab: Runtime > Change runtime type > GPU")

In [None]:
import importlib
import os
import sys
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Make the helper modules (env_setup, viz_utils, sim_utils) importable.
# They sit next to this notebook in examples/i_mppi/, but the working
# directory depends on how the notebook was launched, so probe a few
# likely locations and put the first one that contains env_setup.py at
# the front of sys.path.
_cwd = os.getcwd()
_candidates = [
    _cwd,                                       # launched from the notebook dir
    os.path.join(_cwd, "examples", "i_mppi"),   # launched from the repo root
    "/content/jax_mppi/examples/i_mppi",        # Colab checkout
]
_helpers_dir = next(
    (
        _path
        for _path in _candidates
        if os.path.isfile(os.path.join(_path, "env_setup.py"))
    ),
    None,
)
if _helpers_dir is not None and _helpers_dir not in sys.path:
    sys.path.insert(0, _helpers_dir)

# Reload library modules so source edits take effect without kernel restart
import jax_mppi.i_mppi.environment as _env_mod

importlib.reload(_env_mod)

# Notebook-local helper modules (found via the sys.path setup above).
import env_setup
import sim_utils
import viz_utils

# These from-imports run after the environment module was reloaded, so
# GOAL_POS and INFO_ZONES are bound from the reloaded module.
from jax_mppi import mppi
from jax_mppi.i_mppi.environment import GOAL_POS, INFO_ZONES
from jax_mppi.i_mppi.fsmi import (
    FSMIConfig,
    FSMITrajectoryGenerator,
    UniformFSMI,
    UniformFSMIConfig,
)

# Reload the helper modules before pulling names out of them, so the
# from-imports below pick up edits made since the kernel started.
importlib.reload(env_setup)
importlib.reload(viz_utils)
importlib.reload(sim_utils)

from env_setup import create_grid_map
from sim_utils import (
    CONTROL_HZ,
    DT,
    NOISE_SIGMA,
    NU,
    NX,
    U_INIT,
    U_MAX,
    U_MIN,
    build_sim_fn,
    compute_smoothness,
)
from viz_utils import (
    create_trajectory_gif,
    plot_environment,
    plot_info_levels,
)

print("All imports successful.")

## Environment Setup

The simulation takes place in a **14m x 12m office-like environment** with:
- **Walls** (gray): Physical obstacles the quadrotor must avoid
- **Info zones** (yellow): Unknown regions with high entropy -- the robot gains information by flying near them
- **Goal** (red star): Target position the robot must reach
- **Start** (green circle): Initial robot position

The occupancy grid encodes uncertainty: 0.5 = unknown, 0.2 = known free, 0.9 = known occupied.

In [None]:
# Build the map once; all four values are reused by later cells.
grid_map_obj, grid_array, map_origin, map_resolution = create_grid_map()

# Single-axes figure showing the occupancy grid.
fig, ax = plt.subplots(figsize=(10, 8))
plot_environment(ax, grid_array, map_resolution)
ax.set_title("Office Environment: Occupancy Grid")
ax.legend(loc="upper left")

# The colorbar is attached to the first image on the axes
# (presumably the grid drawn by plot_environment -- verify in viz_utils).
occupancy_image = ax.images[0]
fig.colorbar(
    occupancy_image,
    ax=ax,
    label="Occupancy (0=free, 0.5=unknown, 1=occupied)",
)
fig.tight_layout()
plt.show()

## Interactive Parameter Configuration

Adjust the parameters below and click **Run Simulation** to execute the I-MPPI controller.

| Parameter | Description | Effect |
|-----------|-------------|--------|
| **Samples** | Number of MPPI rollout samples | More = better trajectories, slower |
| **Horizon** | Planning horizon (steps) | Longer = further look-ahead |
| **Lambda** | MPPI temperature | Lower = more greedy (exploits best) |
| **Info Weight** | Uniform-FSMI weight in Layer 3 | Higher = more exploration-driven |
| **FSMI Beams** | Number of sensor beams (Layer 2) | More = better information estimate |
| **FSMI Range** | Max sensor range (Layer 2) | Longer = wider planning scope |

In [None]:
import ipywidgets as widgets
from IPython.display import clear_output, display

# Shared container: run_simulation() fills it, the visualization cell reads it.
sim_results = {}

# --- System parameters: start position and simulated duration ---
w_start_x = widgets.FloatSlider(
    description="Start X (m):", value=1.0, min=0.5, max=12.0, step=0.5
)
w_start_y = widgets.FloatSlider(
    description="Start Y (m):", value=5.0, min=0.5, max=9.0, step=0.5
)
w_duration = widgets.FloatSlider(
    description="Duration (s):", value=30.0, min=5.0, max=60.0, step=5.0
)

# --- Controller (MPPI) parameters ---
w_samples = widgets.IntSlider(
    description="Samples:", value=1000, min=100, max=5000, step=100
)
w_horizon = widgets.IntSlider(
    description="Horizon:", value=40, min=10, max=80, step=5
)
# Log-scale slider: min/max/step are exponents of the base (10 by default),
# while value is the actual number, i.e. lambda ranges over 0.01 .. 10.
w_lambda = widgets.FloatLogSlider(
    description="Lambda:", value=0.1, min=-2, max=1, step=0.1
)
w_info_weight = widgets.FloatSlider(
    description="Info Weight:", value=5.0, min=0.0, max=20.0, step=1.0
)

# --- FSMI (Layer 2) parameters ---
w_fsmi_beams = widgets.IntSlider(
    description="FSMI Beams:", value=12, min=4, max=24, step=2
)
w_fsmi_range = widgets.FloatSlider(
    description="FSMI Range:", value=5.0, min=2.0, max=10.0, step=0.5
)

# Trigger button plus the output pane that captures the run's printed text.
run_button = widgets.Button(
    description="Run Simulation",
    icon="play",
    button_style="success",
    layout=widgets.Layout(height="40px", width="200px"),
)
output_area = widgets.Output()


def run_simulation(button):
    """Run one I-MPPI simulation using the current widget values.

    Registered as the click handler of ``run_button``. Everything printed
    here is routed to ``output_area``. On success, the truncated histories
    and summary metrics are written into the module-level ``sim_results``
    dict, which the visualization cell reads.

    ``sim_results`` is emptied at the start of every run, so a run that
    fails part-way leaves it empty instead of leaving the previous run's
    data in place to be plotted by mistake.

    Args:
        button: The widget that fired the click event; unused.
    """
    with output_area:
        clear_output(wait=True)

        # Invalidate the previous run's results up front (see docstring).
        sim_results.clear()

        # Read parameters
        start_x = w_start_x.value
        start_y = w_start_y.value
        sim_duration = w_duration.value
        num_samples = w_samples.value
        horizon = w_horizon.value
        lambda_ = w_lambda.value
        info_weight = w_info_weight.value
        fsmi_beams = w_fsmi_beams.value
        fsmi_range = w_fsmi_range.value

        # Requested duration expressed in control steps.
        sim_steps = int(round(sim_duration * CONTROL_HZ))

        print("=" * 60)
        print("I-MPPI Simulation")
        print("=" * 60)
        print(f"  Start: ({start_x}, {start_y})")
        print(f"  Duration: {sim_duration}s ({sim_steps} steps)")
        print(
            f"  Samples: {num_samples}, Horizon: {horizon}, "
            f"Lambda: {lambda_}"
        )
        print(f"  Info Weight: {info_weight}")
        print(f"  FSMI Beams: {fsmi_beams}, Range: {fsmi_range}m")
        print()

        # Initial state: a 13-element vehicle state (position in the first
        # three entries, entry 6 set to 1.0 as qw) followed by three info
        # levels initialised to 100.
        # NOTE(review): the meaning of the remaining state entries is
        # defined by the dynamics in sim_utils -- confirm there.
        start_pos = jnp.array([start_x, start_y, -2.0])
        info_init = jnp.array([100.0, 100.0, 100.0])
        x0 = jnp.zeros(13)
        x0 = x0.at[:3].set(start_pos)
        x0 = x0.at[6].set(1.0)  # qw=1
        state = jnp.concatenate([x0, info_init])

        # Layer 2: FSMI config. Only the beam count and range come from the
        # widgets; info_weight here is fixed and is independent of the
        # "Info Weight" slider, which feeds the Layer 3 config below.
        fsmi_config = FSMIConfig(
            use_grid_fsmi=True,
            goal_pos=GOAL_POS,
            fov_rad=1.57,
            num_beams=fsmi_beams,
            max_range=fsmi_range,
            ray_step=0.15,
            sigma_range=0.15,
            gaussian_truncation_sigma=3.0,
            trajectory_subsample_rate=8,
            info_weight=25.0,
            motion_weight=0.5,
        )
        fsmi_planner = FSMITrajectoryGenerator(
            config=fsmi_config,
            info_zones=INFO_ZONES,
            grid_map=grid_map_obj,
        )

        # Layer 3: Uniform-FSMI config (weight taken from the slider).
        uniform_fsmi_config = UniformFSMIConfig(
            fov_rad=1.57,
            num_beams=6,
            max_range=2.5,
            ray_step=0.2,
            info_weight=info_weight,
        )
        uniform_fsmi = UniformFSMI(
            uniform_fsmi_config,
            map_origin,
            map_resolution,
        )

        # MPPI config
        config, ctrl_state = mppi.create(
            nx=NX,
            nu=NU,
            noise_sigma=NOISE_SIGMA,
            num_samples=num_samples,
            horizon=horizon,
            lambda_=lambda_,
            u_min=U_MIN,
            u_max=U_MAX,
            u_init=U_INIT,
            step_dependent_dynamics=True,
        )

        # Build simulation
        sim_fn = build_sim_fn(
            config,
            fsmi_planner,
            uniform_fsmi,
            uniform_fsmi_config,
            grid_map_obj,
            horizon,
            sim_steps,
        )

        # JIT compile + run. The timer wraps this single call, so the
        # reported runtime covers compilation as well as execution.
        print("JIT compiling + running...")
        t0 = time.perf_counter()
        final_state, history_x, history_info, targets, actions, done_step = sim_fn(
            state, ctrl_state
        )
        # JAX dispatches asynchronously; wait for the result before
        # stopping the clock.
        final_state.block_until_ready()
        runtime = time.perf_counter() - t0

        # Truncate history to actual simulation length. A positive
        # done_step is treated as the step at which the task completed;
        # otherwise the run is treated as having used every step.
        done_step_int = int(done_step)
        if done_step_int > 0:
            n_active = done_step_int
            actual_duration = n_active * DT
            print(f"  Task completed at step {n_active} ({actual_duration:.1f}s)")
        else:
            n_active = sim_steps
            actual_duration = sim_duration
            print(f"  Timeout reached ({sim_duration}s)")

        history_x = history_x[:n_active]
        history_info = history_info[:n_active]
        actions = actions[:n_active]

        # Metrics
        action_jerk, traj_jerk = compute_smoothness(actions, history_x, DT)
        final_pos = final_state[:3]
        goal_dist = float(jnp.linalg.norm(final_pos - GOAL_POS))

        print(f"  Runtime: {runtime:.2f}s ({runtime / sim_duration:.2f}x realtime)")
        print(f"  Goal distance: {goal_dist:.2f}m")
        print(f"  Info levels: {final_state[13:]}")
        print()

        # Store results for the visualization cell
        sim_results.update({
            "history_x": history_x,
            "history_info": history_info,
            "actions": actions,
            "runtime_s": runtime,
            "action_jerk": float(action_jerk),
            "traj_jerk": float(traj_jerk),
            "goal_dist": goal_dist,
            "final_state": final_state,
            "done_step": done_step_int,
            "actual_duration": actual_duration,
        })
        print("Done. Run the next cell to visualize results and generate GIF.")


# Wire the button to the simulation callback.
run_button.on_click(run_simulation)

# Three titled columns of sliders, one per parameter group.
system_box = widgets.VBox(
    children=[
        widgets.HTML("<h3>System Parameters</h3>"),
        w_start_x,
        w_start_y,
        w_duration,
    ]
)
controller_box = widgets.VBox(
    children=[
        widgets.HTML("<h3>Controller Parameters</h3>"),
        w_samples,
        w_horizon,
        w_lambda,
        w_info_weight,
    ]
)
fsmi_box = widgets.VBox(
    children=[
        widgets.HTML("<h3>FSMI Parameters (Layer 2)</h3>"),
        w_fsmi_beams,
        w_fsmi_range,
    ]
)

# Columns side by side, then the button, then the output pane underneath.
param_panel = widgets.HBox(children=[system_box, controller_box, fsmi_box])
_dashboard = widgets.VBox(children=[param_panel, run_button, output_area])
display(_dashboard)

In [None]:
# Refuse to plot until the simulation callback has stored a result.
assert sim_results, "No results yet -- click 'Run Simulation' first."

# Needed only by this cell, so it is imported here.
from matplotlib.collections import LineCollection

r = sim_results
history_x = r["history_x"]
history_info = r["history_info"]
actions = r["actions"]

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# 2D trajectory, colored from start to end with the viridis colormap.
ax = axes[0]
plot_environment(ax, grid_array, map_resolution, show_labels=False)
positions = np.array(history_x[:, :2])
n_steps = len(positions)
colors_traj = plt.cm.viridis(np.linspace(0, 1, n_steps))
if n_steps > 1:
    # One LineCollection replaces a separate Line2D artist per segment:
    # segment i joins positions[i] to positions[i + 1] and takes color i.
    segments = np.stack([positions[:-1], positions[1:]], axis=1)
    ax.add_collection(
        LineCollection(segments, colors=colors_traj[:-1], linewidths=2)
    )
    # ax.plot would rescale the view to the new data; do the same here.
    ax.autoscale_view()
ax.set_title("I-MPPI Trajectory")

# Info level depletion
plot_info_levels(axes[1], history_info, DT)
axes[1].set_title("Information Zone Depletion")

# Control inputs: one trace per action column, on a shared time axis.
acts = np.array(actions)
t_arr = np.arange(len(acts)) * DT
ax3 = axes[2]
control_traces = [
    ("Thrust", "#1f77b4"),
    ("wx", "#ff7f0e"),
    ("wy", "#2ca02c"),
    ("wz", "#d62728"),
]
for column, (label, color) in enumerate(control_traces):
    ax3.plot(t_arr, acts[:, column], color=color, linewidth=1, label=label)
ax3.set_ylabel("Control Input")
ax3.set_xlabel("Time (s)")
ax3.set_title("Control Inputs")
ax3.legend(fontsize=8)
ax3.grid(True, alpha=0.3)

# A positive done_step marks a run that finished before the time limit.
status = "Completed" if r["done_step"] > 0 else "Timeout"
plt.suptitle(
    f"I-MPPI: Full FSMI (Layer 2) + "
    f"Biased MPPI with Uniform-FSMI (Layer 3) [{status}]",
    fontsize=14,
)
plt.tight_layout()
plt.show()

# Performance summary
print()
print("=" * 60)
print(f"{'Metric':<25} {'Value':>15}")
print("-" * 60)
print(f"{'Status':<25} {status:>15}")
print(f"{'Sim Duration (s)':<25} {r['actual_duration']:>15.1f}")
print(f"{'Runtime (ms)':<25} {r['runtime_s']*1000:>15.1f}")
print(f"{'Goal Distance (m)':<25} {r['goal_dist']:>15.2f}")
print(f"{'Action Jerk':<25} {r['action_jerk']:>15.4f}")
print(f"{'Trajectory Jerk':<25} {r['traj_jerk']:>15.4f}")
print("=" * 60)

# Generate trajectory animation GIF
print()
print("Generating trajectory GIF...")
gif_path = create_trajectory_gif(
    history_x,
    history_info,
    grid_array,
    map_resolution,
    DT,
    save_path="i_mppi_trajectory.gif",
)
print(f"Saved to: {gif_path}")

# Display GIF inline (works in both Colab and Jupyter)
from IPython.display import Image, display

display(Image(filename=gif_path))