# AxiomCUDA Usage Example

This notebook demonstrates how to use **AxiomCUDA**, a GPU-accelerated implementation of the AXIOM (Active Inference agent) framework by VERSES Research.

> **Credit**: Based on the original AXIOM work by VERSES Research (https://github.com/VersesTech/axiom)

AxiomCUDA provides high-performance CUDA-accelerated tensor operations while maintaining the same API as the original AXIOM framework. This notebook covers:

1. Setup and imports
2. Environment initialization
3. Model initialization (SMM, RMM, TMM, IMM)
4. Running inference and planning
5. Visualizing model states
6. Running full episodes
7. Analyzing results

---

## 1. Setup and Imports

First, we import the necessary modules from axiomcuda and check GPU availability.

In [None]:
# Core axiomcuda imports
import axiomcuda
from axiomcuda import visualize as vis
from axiomcuda import infer as ax
from axiomcuda.models import rmm as rmm_tools
from axiomcuda.models import imm as imm_tools
from axiomcuda.models import smm as smm_tools
from axiomcuda.models import tmm as tmm_tools
from axiomcuda import config as ax_config
from axiomcuda import planner
from axiomcuda.defaults import ExperimentConfig, create_smm_configs

# JAX imports
import jax
import jax.numpy as jnp
import jax.random as jr

# Other utilities
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Sequence

# Environment
import gymnasium

# Check GPU availability
print(f"CUDA Available: {axiomcuda.cuda_available()}")
print(f"JAX Devices: {jax.devices()}")

# Set random seeds for reproducibility
SEED = 42
key = jr.PRNGKey(SEED)
np.random.seed(SEED)

---

## 2. Initialize Environment

Create a Gameworld environment. AxiomCUDA works with Atari-like environments from the Gameworld suite.

In [None]:
# Import gameworld environments (triggers registration)
try:
    import gameworld.envs
    GAMEWORLD_AVAILABLE = True
except ImportError:
    print("Warning: gameworld package not available. Using mock environment.")
    GAMEWORLD_AVAILABLE = False

if GAMEWORLD_AVAILABLE:
    # Create environment
    GAME_NAME = "Explode"
    env = gymnasium.make(f'Gameworld-{GAME_NAME}-v0')
    
    # Reset environment to get initial observation
    obs, info = env.reset(seed=SEED)
    obs = obs.astype(np.uint8)
    
    print(f"Environment: {GAME_NAME}")
    print(f"Observation shape: {obs.shape}")
    print(f"Action space: {env.action_space}")
    action_dim = env.action_space.n
    print(f"Number of actions: {action_dim}")
else:
    # Mock environment for demonstration
    print("Creating mock environment...")
    obs = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)  # upper bound is exclusive
    action_dim = 4
    env = None

In [None]:
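# Display the initial observation
# (the axes handle is named `plot_ax` so it does not shadow the `ax` inference module)
fig, plot_ax = plt.subplots(figsize=(8, 6))
plot_ax.imshow(obs)
plot_ax.set_title("Initial Observation")
plot_ax.axis('off')
plt.show()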

---

## 3. Initialize Models

AxiomCUDA uses a modular architecture with four main models:

- **SMM (Slot Mixture Model)**: Object discovery and segmentation
- **TMM (Transition Mixture Model)**: Object dynamics and motion patterns
- **RMM (Reward Mixture Model)**: Reward function and interaction modeling
- **IMM (Identity Mixture Model)**: Object type identification
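
The SMM configuration in this section uses `input_dim=5`, meaning each pixel is described by an (x, y, r, g, b) vector. As a rough illustration of that representation, the sketch here flattens a tiny image into per-pixel features. The `pixel_features` helper and its scaling (coordinates to [-1, 1], colors to [0, 1]) are assumptions made for illustration; AxiomCUDA's own preprocessing may normalize differently.

```python
def pixel_features(image):
    """Flatten an image into one (x, y, r, g, b) tuple per pixel.

    `image` is a nested list indexed as image[row][col] -> (r, g, b) with
    0-255 channel values. The scaling used here is an illustrative assumption.
    """
    height, width = len(image), len(image[0])
    features = []
    for row in range(height):
        for col in range(width):
            # Map pixel indices to [-1, 1]; a single row or column maps to 0.
            x = 2.0 * col / (width - 1) - 1.0 if width > 1 else 0.0
            y = 2.0 * row / (height - 1) - 1.0 if height > 1 else 0.0
            r, g, b = (channel / 255.0 for channel in image[row][col])
            features.append((x, y, r, g, b))
    return features


# A 2x3 image: top row red, bottom row blue.
tiny_image = [
    [(255, 0, 0)] * 3,
    [(0, 0, 255)] * 3,
]
features = pixel_features(tiny_image)
print(len(features), "pixels,", len(features[0]), "features each")
print(features[0])   # top-left pixel
print(features[-1])  # bottom-right pixel
```

A mixture model fitted to such vectors can group pixels that are close in both position and color, which is the intuition behind slot-based object discovery.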

Let's create configurations for each model.

In [None]:
# Create SMM configuration (Slot Mixture Model)
smm_config = ax_config.SMMConfig(
    width=160,
    height=210,
    input_dim=5,  # x, y, r, g, b
    slot_dim=2,
    num_slots=32,
    use_bias=True,
    ns_a=1.0,
    ns_b=1.0,
    dof_offset=10.0,
    mask_prob=(0.0, 0.0, 0.0, 0.0, 1.0),
    scale=(0.075, 0.075, 0.75, 0.75, 0.75),
    transform_inv_v_scale=100.0,
    bias_inv_v_scale=0.001,
    num_e_steps=2,
    learning_rate=1.0,
    beta=0.0,
    eloglike_threshold=5.0,
    max_grow_steps=20,
)
print("SMM Configuration created")
print(f"  - Number of slots: {smm_config.num_slots}")
print(f"  - Input dimension: {smm_config.input_dim}")

In [None]:
# Create TMM configuration (Transition Mixture Model)
tmm_config = ax_config.TMMConfig(
    n_total_components=200,
    state_dim=2,
    dt=1.0,
    vu=0.05,
    use_bias=True,
    sigma_sqr=2.0,
    logp_threshold=-0.00001,
    position_threshold=0.15,
    use_unused_counter=True,
    use_velocity=True,
    clip_value=5e-4,
)
print("TMM Configuration created")
print(f"  - Total components: {tmm_config.n_total_components}")
print(f"  - State dimension: {tmm_config.state_dim}")

In [None]:
# Create IMM configuration (Identity Mixture Model)
imm_config = ax_config.IMMConfig(
    num_object_types=32,
    num_features=5,
    i_ell_threshold=-500.0,
    cont_scale_identity=0.5,
    color_precision_scale=1.0,
    color_only_identity=False,
)
print("IMM Configuration created")
print(f"  - Number of object types: {imm_config.num_object_types}")
print(f"  - Number of features: {imm_config.num_features}")

In [None]:
# Create RMM configuration (Reward Mixture Model)
rmm_config = ax_config.RMMConfig(
    num_components_per_switch=25,
    num_switches=100,
    num_object_types=32,
    num_features=5,
    num_continuous_dims=7,
    interact_with_static=False,
    r_ell_threshold=-100.0,
    i_ell_threshold=-500.0,
    cont_scale_identity=0.5,
    cont_scale_switch=25.0,
    discrete_alphas=(1e-4, 1e-4, 1e-4, 1e-4, 1.0, 1e-4),
    r_interacting=0.6,
    r_interacting_predict=0.6,
    forward_predict=False,
    stable_r=False,
    relative_distance=True,
    absolute_distance_scale=False,
    reward_prob_threshold=0.45,
    color_precision_scale=1.0,
    color_only_identity=False,
    exclude_background=True,
    use_ellipses_for_interaction=True,
    velocity_scale=10.0,
)
print("RMM Configuration created")
print(f"  - Components per switch: {rmm_config.num_components_per_switch}")
print(f"  - Number of switches: {rmm_config.num_switches}")

In [None]:
# Create Planner configuration (MPPI)
planner_config = ax_config.PlannerConfig(
    num_steps=24,        # Planning horizon
    num_policies=512,    # Number of action sequences to sample
    num_samples_per_policy=1,
    topk_ratio=0.1,      # Fraction of top samples for refitting
    random_ratio=0.5,    # Fraction of random samples
    alpha=1.0,           # Learning rate
    temperature=10.0,    # Temperature for softmax
    normalize=True,
    iters=1,
    gamma=0.99,          # Discount factor
    repeat_prob=0.0,
    info_gain=1.0,       # Weight for information gain
    lazy_reward=False,
    sample_action=False,
)
print("Planner Configuration created")
print(f"  - Planning horizon: {planner_config.num_steps}")
print(f"  - Number of policies: {planner_config.num_policies}")

In [None]:
# Create the full experiment configuration
config = ExperimentConfig(
    name="axiomcuda_demo",
    id="demo_001",
    group="demo",
    seed=SEED,
    game="Explode" if env is not None else "Mock",
    num_steps=1000,
    smm=(smm_config,),  # Tuple of SMM configs (can be multiple for hierarchical)
    imm=imm_config,
    tmm=tmm_config,
    rmm=rmm_config,
    planner=planner_config,
    moving_threshold=1e-2,
    used_threshold=0.2,
    min_track_steps=(1, 1),
    max_steps_tracked_unused=10,
    prune_every=500,
    use_unused_counter=True,
    project="axiomcuda",
    precision_type="float32",
    layer_for_dynamics=0,
    warmup_smm=False,
    num_warmup_steps=50,
    velocity_clip_value=7.5e-4,
    bmr_samples=2000,
    bmr_pairs=2000,
)

print("Experiment Configuration created successfully!")
print(f"\nConfiguration Summary:")
print(f"  - Experiment name: {config.name}")
print(f"  - Game: {config.game}")
print(f"  - Seed: {config.seed}")
print(f"  - Number of steps: {config.num_steps}")
print(f"  - Number of SMM layers: {len(config.smm) if isinstance(config.smm, Sequence) else 1}")

In [None]:
# Initialize all models using axiomcuda.init
# This creates the initial carry dictionary containing all model states

key, subkey = jr.split(key)
carry = ax.init(subkey, config, obs, action_dim)

print("Models initialized successfully!")
print(f"\nCarry dictionary keys:")
for key_name in carry.keys():
    print(f"  - {key_name}")

In [None]:
# Examine the structure of each model
print("SMM Model Structure:")
print(f"  - Type: {type(carry['smm_model'])}")
if hasattr(carry['smm_model'], 'stats'):
    print(f"  - Stats keys: {list(carry['smm_model'].stats.keys())}")

print("\nTMM Model Structure:")
print(f"  - Type: {type(carry['tmm_model'])}")

print("\nRMM Model Structure:")
print(f"  - Type: {type(carry['rmm_model'])}")
if hasattr(carry['rmm_model'], 'used_mask'):
    print(f"  - Used components: {carry['rmm_model'].used_mask.sum()}")

print("\nIMM Model Structure:")
print(f"  - Type: {type(carry['imm_model'])}")
if hasattr(carry['imm_model'], 'used_mask'):
    print(f"  - Used identities: {carry['imm_model'].used_mask.sum()}")

---

## 4. Run Inference

Now let's run a single step through the environment to see how the models work together:

1. **Plan**: Use MPPI to select an action
2. **Step**: Execute the action in the environment
3. **Update**: Update models with the new observation
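
To give a feel for what the planning step does with `gamma`, `temperature`, and `topk_ratio`, here is a generic sampling-based planner sketch in plain Python. It is not AxiomCUDA's implementation: the helper names are hypothetical, and the convention that the temperature multiplies the returns before the softmax is an assumption.

```python
import math


def discounted_return(rewards, gamma):
    """Sum of rewards along one rollout, discounted by gamma per step."""
    return sum((gamma ** t) * r for t, r in enumerate(rewards))


def softmax_weights(scores, temperature):
    """Softmax over temperature-scaled scores (assumed inverse-temperature form)."""
    scaled = [temperature * s for s in scores]
    largest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_indices(scores, topk_ratio):
    """Indices of the best-scoring fraction of candidates (at least one)."""
    k = max(1, int(len(scores) * topk_ratio))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


# Three hypothetical candidate policies with predicted per-step rewards.
predicted_rewards = [
    [0.0, 0.0, 1.0],  # reward arrives late, so it is discounted
    [1.0, 0.0, 0.0],  # reward arrives immediately
    [0.0, 0.0, 0.0],  # no reward at all
]
returns = [discounted_return(r, gamma=0.99) for r in predicted_rewards]
weights = softmax_weights(returns, temperature=10.0)
elite = top_k_indices(returns, topk_ratio=0.1)

print("returns:", returns)
print("weights:", weights)
print("elite policies:", elite)
```

In planners of this family, the weights or the elite set are typically used to refit the sampling distribution over action sequences before the first action of the best plan is executed.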

In [None]:
# Step 1: Plan action using MPPI (Model Predictive Path Integral)
key, subkey = jr.split(key)
action, carry, plan_info = ax.plan_fn(subkey, carry, config, action_dim)

print(f"Planned action: {action}")
print(f"\nPlanning info keys: {list(plan_info.keys())}")

# Inspect plan info
if 'rewards' in plan_info:
    print(f"\nRewards shape: {plan_info['rewards'].shape}")
    print(f"Expected utility shape: {plan_info['expected_utility'].shape}")
    print(f"Expected info gain shape: {plan_info['expected_info_gain'].shape}")

In [None]:
# Step 2: Execute action in environment
if env is not None:
    # Pass a plain Python int; the planner may return a JAX scalar
    obs, reward, done, truncated, info = env.step(int(action))
    obs = obs.astype(np.uint8)
else:
    # Mock step
    obs = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)
    reward = np.random.randn()
    done = False
    truncated = False

print(f"Step result:")
print(f"  - Action taken: {action}")
print(f"  - Reward received: {reward:.4f}")
print(f"  - Episode done: {done}")
print(f"  - New observation shape: {obs.shape}")

In [None]:
# Step 3: Update models with the new observation
carry, records = ax.step_fn(
    carry,
    config,
    obs,
    jnp.array(reward),
    action,
    num_tracked=0,
    update=True,
    remap_color=False,
)

print("Models updated successfully!")
print(f"\nRecords keys: {list(records.keys())}")

# Inspect the records
if 'decoded_mu' in records:
    print(f"\nDecoded means shape: {records['decoded_mu'][0].shape}")
if 'switches' in records:
    print(f"TMM switches: {records['switches']}")
if 'rmm_switches' in records:
    print(f"RMM switches: {records['rmm_switches']}")

---

## 5. Visualize Models

Let's visualize the internal state of each model after the first update.
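
Several plots in this section draw a 2-D Gaussian component as an ellipse. As background, the sketch here shows one standard way to turn a 2x2 covariance matrix into ellipse parameters using the closed-form eigendecomposition of a symmetric matrix. The function name and the `n_std` scaling are illustrative assumptions, not part of the `vis` module.

```python
import math


def covariance_to_ellipse(cov, n_std=2.0):
    """Return (semi_major, semi_minor, angle_radians) for a 2x2 covariance.

    `cov` is [[a, b], [b, c]]; the angle is that of the major axis measured
    from the x-axis.
    """
    (a, b), (_, c) = cov
    mean_var = (a + c) / 2.0
    spread = math.sqrt(((a - c) / 2.0) ** 2 + b ** 2)
    lam_max = mean_var + spread
    lam_min = max(mean_var - spread, 0.0)  # guard against tiny negative round-off
    angle = 0.5 * math.atan2(2.0 * b, a - c)
    return n_std * math.sqrt(lam_max), n_std * math.sqrt(lam_min), angle


# Axis-aligned Gaussian: variance 4 along x, variance 1 along y.
semi_major, semi_minor, angle = covariance_to_ellipse(
    [[4.0, 0.0], [0.0, 1.0]], n_std=1.0
)
print(semi_major, semi_minor, angle)

# Correlated Gaussian: the same ellipse rotated by 45 degrees.
rot_major, rot_minor, rot_angle = covariance_to_ellipse(
    [[2.5, 1.5], [1.5, 2.5]], n_std=1.0
)
print(rot_major, rot_minor, math.degrees(rot_angle))
```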

In [None]:
# Visualize RMM (Reward Mixture Model)
# The RMM components represent learned reward functions and interactions

fig, axes = plt.subplots(2, 2, figsize=(16, 16))

# Color by different attributes
colorize_options = ['switch', 'reward', 'cluster', 'infogain']

for idx, colorize in enumerate(colorize_options):
    # Use `plot_ax` rather than `ax` so the `ax` inference module is not overwritten
    plot_ax = axes[idx // 2, idx % 2]
    try:
        # Get the plot
        rmm_img = vis.plot_rmm(
            carry['rmm_model'],
            carry['imm_model'],
            colorize=colorize,
            return_ax=False
        )
        plot_ax.imshow(rmm_img)
        plot_ax.set_title(f'RMM - Colored by {colorize}')
        plot_ax.axis('off')
    except Exception as e:
        plot_ax.text(0.5, 0.5, f'Error: {str(e)}', ha='center', va='center')
        plot_ax.set_title(f'RMM - {colorize} (Error)')

plt.tight_layout()
plt.suptitle('RMM State Visualization', fontsize=16, y=1.02)
plt.show()

print("Colorization options:")
print("  - switch: Components colored by TMM switch assignment")
print("  - reward: Components colored by reward value (red=negative, green=positive)")
print("  - cluster: Components colored by identity")
print("  - infogain: Components colored by information gain")

In [None]:
# Visualize SMM slots
# SMM performs object discovery and segmentation

if 'decoded_mu' in records and 'decoded_sigma' in records:
    smm_img = vis.plot_smm(
        records['decoded_mu'][0],
        records['decoded_sigma'][0],
        carry['smm_model'].stats['offset'],
        carry['smm_model'].stats['stdevs'],
        width=160,
        height=210,
        qz=records['qz'][0] if 'qz' in records else None
    )
    
    # `plot_ax` avoids shadowing the `ax` inference module
    fig, plot_ax = plt.subplots(figsize=(8, 6))
    plot_ax.imshow(smm_img)
    plot_ax.set_title('SMM Slot Visualization')
    plot_ax.axis('off')
    plt.show()
    
    print(f"Number of slots: {len(records['decoded_mu'][0])}")
    print(f"Each ellipse represents a discovered object/slot")

In [None]:
# Visualize the planning results
# Shows the top-k trajectories considered by MPPI

if env is not None:
    try:
        plan_img = vis.plot_plan(
            obs,
            plan_info,
            carry['tracked_obj_ids'][config.layer_for_dynamics],
            carry['smm_model'].stats,
            decoded_mu=records['decoded_mu'][0] if 'decoded_mu' in records else None,
            topk=5
        )
        
        # `plot_ax` avoids shadowing the `ax` inference module
        fig, plot_ax = plt.subplots(figsize=(8, 6))
        plot_ax.imshow(plan_img)
        plot_ax.set_title('MPPI Planning Visualization')
        plot_ax.axis('off')
        plt.show()
        
        print("The visualization shows:")
        print("  - Current observation as background")
        print("  - Top-5 planned trajectories overlaid")
        print("  - Trajectories colored by expected reward")
    except Exception as e:
        print(f"Error plotting plan: {e}")

In [None]:
# Visualize Identity Model (IMM)
# IMM learns to identify object types based on their visual features

try:
    imm_img = vis.plot_identity_model(
        carry['imm_model'],
        color_only_identity=imm_config.color_only_identity
    )
    
    # `plot_ax` avoids shadowing the `ax` inference module
    fig, plot_ax = plt.subplots(figsize=(12, 8))
    plot_ax.imshow(imm_img)
    plot_ax.set_title('IMM - Learned Object Identities')
    plot_ax.axis('off')
    plt.show()
    
    print("Each subplot shows a learned object identity:")
    print("  - Ellipse shows shape and size")
    print("  - Color shows learned RGB values")
    print("  - Title indicates if this identity is active (True/False)")
except Exception as e:
    print(f"Error plotting IMM: {e}")

In [None]:
# Visualize TMM (Transition Mixture Model) components
# TMM models different motion patterns (static, constant velocity, etc.)

if hasattr(carry['tmm_model'], 'transitions'):
    try:
        tmm_img = vis.plot_tmm(
            carry['tmm_model'].transitions,
            carry['tmm_model'].used_mask,
            width=160,
            height=210
        )
        
        # `plot_ax` avoids shadowing the `ax` inference module
        fig, plot_ax = plt.subplots(figsize=(12, 8))
        plot_ax.imshow(tmm_img)
        plot_ax.set_title('TMM - Transition Components')
        plot_ax.axis('off')
        plt.show()
        
        used_components = int(carry['tmm_model'].used_mask.sum())
        print(f"Total TMM components: {len(carry['tmm_model'].transitions)}")
        print(f"Used components: {used_components}")
    except Exception as e:
        print(f"Error plotting TMM: {e}")

---

## 6. Run Episode

Now let's run a full episode (multiple steps) to see how the agent learns over time.
We'll track:
- Rewards collected
- Model component growth
- Planning metrics (expected utility and expected information gain)

The recorded frames are then turned into a video of the gameplay.
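
Before the real loop, the overall plan, step, update pattern can be sketched with stand-ins that need neither a GPU nor Gameworld. Everything here is hypothetical: `ToyEnv` only mimics the Gymnasium five-value `step` return, and `plan`/`update` are a simple running-average agent rather than the AXIOM models.

```python
import random


class ToyEnv:
    """Hypothetical environment following the Gymnasium step convention."""

    def __init__(self, max_steps=20):
        self.max_steps = max_steps
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return self.t, {}

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 2 else 0.0  # only action 2 pays off
        terminated = False                    # no terminal state in this toy
        truncated = self.t >= self.max_steps  # time limit reached
        return self.t, reward, terminated, truncated, {}


def plan(estimates, rng, explore_prob=0.2):
    """Pick the best-looking action, exploring at random some of the time."""
    if rng.random() < explore_prob:
        return rng.randrange(len(estimates))
    return max(range(len(estimates)), key=lambda a: estimates[a])


def update(estimates, counts, action, reward):
    """Incrementally update the mean reward estimate for one action."""
    counts[action] += 1
    estimates[action] += (reward - estimates[action]) / counts[action]


rng = random.Random(42)
toy_env = ToyEnv(max_steps=20)
obs, info = toy_env.reset(seed=42)
estimates, counts = [0.0] * 4, [0] * 4
rewards_log = []

for step in range(100):
    action = plan(estimates, rng)                                      # 1. plan
    obs, reward, terminated, truncated, info = toy_env.step(action)    # 2. step
    update(estimates, counts, action, reward)                          # 3. update
    rewards_log.append(reward)
    if terminated or truncated:
        break

print(f"Ran {len(rewards_log)} steps, total reward {sum(rewards_log):.1f}")
```

Checking both flags matters because Gymnasium reports time limits through `truncated` rather than `terminated`.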

In [None]:
# Reset environment for a fresh episode
if env is not None:
    obs, _ = env.reset(seed=SEED)
    obs = obs.astype(np.uint8)
else:
    obs = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)  # upper bound is exclusive

# Re-initialize models
key, subkey = jr.split(key)
carry = ax.init(subkey, config, obs, action_dim)

print("Environment reset and models re-initialized")
print(f"Starting new episode...")

In [None]:
# Run episode for N steps
N_STEPS = 100  # Number of steps to run

# Data collection arrays
observations = [obs]
rewards_list = []
actions_list = []
component_counts = []
expected_utilities = []
expected_info_gains = []

print(f"Running {N_STEPS} steps...")

for step in range(N_STEPS):
    # Plan action
    key, subkey = jr.split(key)
    action, carry, plan_info = ax.plan_fn(subkey, carry, config, action_dim)
    
    # Track planning metrics
    if 'rewards' in plan_info:
        best_idx = jnp.argsort(plan_info['rewards'][:, :, 0].sum(0))[-1]
        expected_utilities.append(
            float(plan_info['expected_utility'][:, best_idx, :].mean(-1).sum(0))
        )
        expected_info_gains.append(
            float(plan_info['expected_info_gain'][:, best_idx, :].mean(-1).sum(0))
        )
    
    # Track RMM component count
    if hasattr(carry['rmm_model'], 'used_mask'):
        component_counts.append(int(carry['rmm_model'].used_mask.sum()))
    
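    # Execute action
    if env is not None:
        # Pass a plain Python int; the planner may return a JAX scalar
        obs, reward, terminated, truncated, info = env.step(int(action))
        obs = obs.astype(np.uint8)
        # Stop on either a terminal state or a time-limit truncation
        done = terminated or truncated
    else:
        obs = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)
        reward = np.random.randn()
        done = False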
    
    # Store data
    observations.append(obs)
    rewards_list.append(float(reward))
    actions_list.append(int(action))
    
    # Update models
    carry, records = ax.step_fn(
        carry,
        config,
        obs,
        jnp.array(reward),
        action,
        num_tracked=0,
        update=True,
        remap_color=False,
    )
    
    # Progress update every 25 steps
    if (step + 1) % 25 == 0:
        recent_reward = np.mean(rewards_list[-25:])
        recent_components = np.mean(component_counts[-25:]) if component_counts else 0
        print(f"Step {step+1}/{N_STEPS}: "
              f"Avg reward (last 25): {recent_reward:.3f}, "
              f"Avg components: {recent_components:.1f}")
    
    if done:
        print(f"Episode done at step {step+1}")
        break

print(f"\nEpisode complete! Total steps: {len(rewards_list)}")
print(f"Total reward: {sum(rewards_list):.2f}")
print(f"Average reward per step: {np.mean(rewards_list):.3f}")

In [None]:
# Generate video from observations
try:
    import mediapy
    
    # Show video
    print("Generating gameplay video...")
    mediapy.show_video(observations, fps=30, title="AxiomCUDA Gameplay")
except ImportError:
    print("mediapy not available. Showing first few frames instead:")
    
    # Show first 4 frames
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    # `plot_ax` avoids shadowing the `ax` inference module
    for idx, plot_ax in enumerate(axes.flat):
        if idx < len(observations):
            plot_ax.imshow(observations[idx])
            plot_ax.set_title(f'Frame {idx}')
            plot_ax.axis('off')
    plt.tight_layout()
    plt.show()

---

## 7. Analyze Results

Let's analyze what the agent learned during the episode.
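
The reward plot in this section smooths the per-step rewards with a moving average computed in "valid" mode, which yields `len(rewards) - window + 1` points and is why the smoothed curve is plotted starting at x = `window - 1`. A minimal pure-Python sketch of the same computation, with a hypothetical `moving_average` helper and made-up rewards:

```python
def moving_average(values, window):
    """Mean over each full window of `values` (equivalent to 'valid' mode)."""
    if window < 1 or window > len(values):
        raise ValueError("window must be between 1 and len(values)")
    running = sum(values[:window])
    averages = [running / window]
    for i in range(window, len(values)):
        # Slide the window: add the new value, drop the oldest one.
        running += values[i] - values[i - window]
        averages.append(running / window)
    return averages


rewards = [0.0, 1.0, 0.0, 1.0, 1.0, 0.0]
window = 3
smoothed = moving_average(rewards, window)
# Each average is plotted at the index of the last step in its window.
x_positions = list(range(window - 1, len(rewards)))

for x, value in zip(x_positions, smoothed):
    print(f"step {x}: {value:.3f}")
```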

In [None]:
# Plot reward curve
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Reward over time
ax1 = axes[0, 0]
ax1.plot(rewards_list, alpha=0.5, color='blue', label='Step reward')
# Add moving average
window = min(25, len(rewards_list) // 4 + 1)
if window > 1:
    moving_avg = np.convolve(rewards_list, np.ones(window)/window, mode='valid')
    ax1.plot(range(window-1, len(rewards_list)), moving_avg, 
             color='red', linewidth=2, label=f'{window}-step MA')
ax1.set_xlabel('Step')
ax1.set_ylabel('Reward')
ax1.set_title('Reward Curve')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Cumulative reward
ax2 = axes[0, 1]
cumulative_rewards = np.cumsum(rewards_list)
ax2.plot(cumulative_rewards, color='green')
ax2.set_xlabel('Step')
ax2.set_ylabel('Cumulative Reward')
ax2.set_title('Cumulative Reward Over Time')
ax2.grid(True, alpha=0.3)

# 3. Component count over time
if component_counts:
    ax3 = axes[1, 0]
    ax3.plot(component_counts, color='purple')
    ax3.set_xlabel('Step')
    ax3.set_ylabel('RMM Components')
    ax3.set_title('Model Growth: RMM Components')
    ax3.grid(True, alpha=0.3)
    
    # 4. Component growth rate
    ax4 = axes[1, 1]
    if len(component_counts) > 1:
        growth = np.diff(component_counts)
        ax4.bar(range(len(growth)), growth, alpha=0.7, color='orange')
        ax4.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
        ax4.set_xlabel('Step')
        ax4.set_ylabel('Component Change')
        ax4.set_title('Component Growth Rate')
        ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.suptitle('Episode Analysis', fontsize=16, y=1.02)
plt.show()

print(f"\nEpisode Statistics:")
print(f"  - Total steps: {len(rewards_list)}")
print(f"  - Total reward: {sum(rewards_list):.2f}")
print(f"  - Mean reward: {np.mean(rewards_list):.3f} ± {np.std(rewards_list):.3f}")
print(f"  - Max reward: {max(rewards_list):.2f}")
print(f"  - Min reward: {min(rewards_list):.2f}")
if component_counts:
    print(f"  - Final components: {component_counts[-1]}")
    print(f"  - Component growth: {component_counts[-1] - component_counts[0]}")

In [None]:
# Plot planning metrics
if expected_utilities and expected_info_gains:
    fig, axes = plt.subplots(2, 1, figsize=(12, 8))
    
    # Expected utility
    ax1 = axes[0]
    ax1.plot(expected_utilities, color='blue', alpha=0.7)
    ax1.set_xlabel('Step')
    ax1.set_ylabel('Expected Utility')
    ax1.set_title('MPPI Expected Utility Over Time')
    ax1.grid(True, alpha=0.3)
    
    # Expected information gain
    ax2 = axes[1]
    ax2.plot(expected_info_gains, color='red', alpha=0.7)
    ax2.set_xlabel('Step')
    ax2.set_ylabel('Expected Info Gain')
    ax2.set_title('MPPI Expected Information Gain Over Time')
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nPlanning Statistics:")
    print(f"  - Mean expected utility: {np.mean(expected_utilities):.3f}")
    print(f"  - Mean expected info gain: {np.mean(expected_info_gains):.3f}")

In [None]:
# Show final model states
print("Final Model States:\n")

# RMM state
print("RMM (Reward Mixture Model):")
if hasattr(carry['rmm_model'], 'used_mask'):
    used = carry['rmm_model'].used_mask.sum()
    total = len(carry['rmm_model'].used_mask)
    print(f"  - Used components: {used}/{total} ({100*used/total:.1f}%)")

# IMM state
print("\nIMM (Identity Mixture Model):")
if hasattr(carry['imm_model'], 'used_mask'):
    used = carry['imm_model'].used_mask.sum()
    total = len(carry['imm_model'].used_mask)
    print(f"  - Used identities: {used}/{total} ({100*used/total:.1f}%)")

# TMM state
print("\nTMM (Transition Mixture Model):")
if hasattr(carry['tmm_model'], 'used_mask'):
    used = carry['tmm_model'].used_mask.sum()
    total = len(carry['tmm_model'].used_mask)
    print(f"  - Used transitions: {used}/{total} ({100*used/total:.1f}%)")

# SMM state
print("\nSMM (Slot Mixture Model):")
if hasattr(carry['smm_model'], 'stats'):
    print(f"  - Image dimensions: {carry['smm_model'].stats.get('width', 'N/A')}x{carry['smm_model'].stats.get('height', 'N/A')}")
if hasattr(carry['smm_model'], 'prior'):
    print(f"  - Number of slots: {len(carry['smm_model'].prior.alpha)}")

In [None]:
# Final visualizations
fig, axes = plt.subplots(2, 2, figsize=(16, 14))

# 1. RMM with cluster coloring
try:
    rmm_img = vis.plot_rmm(
        carry['rmm_model'], 
        carry['imm_model'],
        colorize='cluster',
        return_ax=False
    )
    axes[0, 0].imshow(rmm_img)
    axes[0, 0].set_title('Final RMM State (Colored by Identity)')
    axes[0, 0].axis('off')
except Exception as e:
    axes[0, 0].text(0.5, 0.5, f'Error: {str(e)}', ha='center', va='center')
    axes[0, 0].set_title('RMM (Error)')

# 2. IMM
try:
    imm_img = vis.plot_identity_model(carry['imm_model'])
    axes[0, 1].imshow(imm_img)
    axes[0, 1].set_title('Final IMM State')
    axes[0, 1].axis('off')
except Exception as e:
    axes[0, 1].text(0.5, 0.5, f'Error: {str(e)}', ha='center', va='center')
    axes[0, 1].set_title('IMM (Error)')

# 3. Last observation
axes[1, 0].imshow(observations[-1])
axes[1, 0].set_title('Final Observation')
axes[1, 0].axis('off')

# 4. Action distribution
unique_actions, counts = np.unique(actions_list, return_counts=True)
axes[1, 1].bar(unique_actions, counts, alpha=0.7)
axes[1, 1].set_xlabel('Action')
axes[1, 1].set_ylabel('Count')
axes[1, 1].set_title('Action Distribution')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.suptitle('Final State Summary', fontsize=16, y=1.02)
plt.show()

---

## Summary

This notebook demonstrated the core functionality of AxiomCUDA:

### Key Components

1. **SMM (Slot Mixture Model)**: Discovers objects by learning slot-based representations
2. **TMM (Transition Mixture Model)**: Learns object dynamics and motion patterns
3. **RMM (Reward Mixture Model)**: Models rewards and object interactions
4. **IMM (Identity Mixture Model)**: Identifies object types from visual features

### Key Functions

- `ax.init()`: Initialize all models
- `ax.plan_fn()`: Plan actions using MPPI
- `ax.step_fn()`: Update models with new observations
- `ax.reduce_fn_rmm()`: Bayesian Model Reduction for component pruning (not exercised in this notebook)

### Visualization Functions

- `vis.plot_rmm()`: Visualize reward model components
- `vis.plot_smm()`: Visualize slot assignments
- `vis.plot_identity_model()`: Visualize learned object identities
- `vis.plot_plan()`: Visualize planning trajectories
- `vis.plot_tmm()`: Visualize transition dynamics

### Further Reading

- Original AXIOM repository: https://github.com/VersesTech/axiom
- Active Inference framework
- Model Predictive Path Integral (MPPI) control

---

**Credits**: This notebook is based on the AXIOM framework by VERSES Research. AxiomCUDA provides a GPU-accelerated implementation while maintaining API compatibility with the original framework.