# SSM-MetaRL-Unified Demo: Experience-Augmented Meta-RL

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sunghunkwag/SSM-MetaRL-Unified/blob/main/demo.ipynb)

This notebook demonstrates the **SSM-MetaRL-Unified** framework, which integrates:
- **State Space Models (SSM)** for efficient sequence modeling
- **Meta-Learning (MAML)** for fast adaptation
- **Standard Test-Time Adaptation** (baseline)
- **Hybrid Test-Time Adaptation** with **Experience Replay** (novel approach)

## Key Innovation

The unified framework allows you to compare two adaptation strategies:
1. **Standard Mode**: Adapts using only current task data
2. **Hybrid Mode**: Augments adaptation with past experiences from a replay buffer

The hybrid approach is intended to make adaptation more robust and more sample-efficient; sections 5 to 7 run both modes on the same environment so you can compare them.
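
As a rough illustration of how the two modes differ, the sketch below adds a down-weighted loss on replayed transitions to the loss on the current transition. The additive form and the names `mse`, `predict`, and `hybrid_loss` are assumptions made for this example; the repository's `HybridAdapter` may combine the terms differently.

```python
def mse(pred, target):
    """Mean squared error between two equal-length sequences of floats."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def predict(weight, obs):
    """Toy next-observation predictor: scales every feature by one weight."""
    return [weight * o for o in obs]


def hybrid_loss(weight, current, replay_batch, experience_weight=0.1):
    """Current-transition loss plus a down-weighted mean loss over replayed pairs.

    ASSUMPTION: simple additive weighting, not the repository's exact formula.
    """
    obs, next_obs = current
    current_term = mse(predict(weight, obs), next_obs)
    if not replay_batch:  # standard mode: nothing to replay
        return current_term
    replay_term = sum(
        mse(predict(weight, o), n) for o, n in replay_batch
    ) / len(replay_batch)
    return current_term + experience_weight * replay_term


current = ([1.0, 2.0], [1.0, 2.0])
replay = [([1.0, 1.0], [2.0, 2.0]), ([2.0, 0.0], [4.0, 0.0])]
standard = hybrid_loss(1.0, current, [])       # current data only
hybrid = hybrid_loss(1.0, current, replay)     # current data plus replay
print(standard, hybrid)
```

With an empty replay batch the function reduces to the standard objective, which is why the two modes can share one code path in this sketch.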

## 1. Installation

Install required dependencies and clone the repository:

In [None]:
# Install dependencies
!pip install torch numpy gymnasium matplotlib -q

# Clone the unified repository if needed, then move into it
import os
if not os.path.exists('SSM-MetaRL-Unified'):
    !git clone https://github.com/sunghunkwag/SSM-MetaRL-Unified.git
os.chdir('SSM-MetaRL-Unified')

print("✓ Installation complete!")

## 2. Import Modules

Import all necessary components from the unified framework:

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym

# Import unified framework components
from core.ssm import StateSpaceModel
from meta_rl.meta_maml import MetaMAML
from adaptation import StandardAdapter, StandardAdaptationConfig
from adaptation import HybridAdapter, HybridAdaptationConfig
from experience.experience_buffer import ExperienceBuffer
from env_runner.environment import Environment

print("✓ Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")

## 3. Setup Environment and Model

Initialize the CartPole environment and create the SSM model:

In [None]:
# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env_name = 'CartPole-v1'
state_dim = 32
hidden_dim = 64

# Initialize environment
env = Environment(env_name=env_name, batch_size=1)
input_dim = env.observation_space.shape[0]
output_dim = input_dim  # Predict next observation

print(f"Environment: {env_name}")
print(f"Input dimension: {input_dim}")
print(f"Output dimension: {output_dim}")

# Create SSM model
model = StateSpaceModel(
    state_dim=state_dim,
    input_dim=input_dim,
    output_dim=output_dim,
    hidden_dim=hidden_dim
).to(device)

print(f"\n✓ Model created with {sum(p.numel() for p in model.parameters()):,} parameters")
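
For readers new to state space models, the sketch below shows the discrete linear recurrence that such models build on: h_t = A h_{t-1} + B x_t and y_t = C h_t. It is written in plain Python with hypothetical helper names (`matvec`, `ssm_step`, `ssm_scan`) and is not the repository's `StateSpaceModel`, which takes a `hidden_dim` and presumably adds learned, nonlinear layers.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def ssm_step(A, B, C, hidden, x):
    """One step of h_t = A h_{t-1} + B x_t, y_t = C h_t."""
    new_hidden = [a + b for a, b in zip(matvec(A, hidden), matvec(B, x))]
    return matvec(C, new_hidden), new_hidden


def ssm_scan(A, B, C, inputs, hidden):
    """Run the recurrence over a sequence, carrying the hidden state forward."""
    outputs = []
    for x in inputs:
        y, hidden = ssm_step(A, B, C, hidden, x)
        outputs.append(y)
    return outputs, hidden


# Toy sizes: state_dim=2, input_dim=1, output_dim=1
A = [[0.5, 0.0], [0.0, 0.9]]   # state transition (two decaying modes)
B = [[1.0], [1.0]]             # input projection
C = [[1.0, -1.0]]              # readout
outputs, final_hidden = ssm_scan(A, B, C, [[1.0], [0.0], [0.0]], [0.0, 0.0])
print(outputs, final_hidden)
```

`ssm_step` returns an `(output, hidden)` pair, the same calling pattern as the `model(obs_tensor, hidden_state)` calls used in the rest of this notebook.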

## 4. Meta-Training with MetaMAML

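Meta-train the model with MetaMAML on a next-observation prediction objective (MSE loss), using one support/query trajectory pair per epoch, so that it can adapt quickly at test time: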

In [None]:
def collect_trajectory(env, model, num_steps=50, device='cpu'):
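    """Collect `num_steps` transitions under a random policy.

    Returns (obs_seq, next_obs_seq), each a float32 tensor of shape (1, T, D).
    Episodes that end early are reset, so a sequence may span several episodes.
    """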
    observations = []
    next_observations = []
    
    obs = env.reset()
    hidden_state = model.init_hidden(batch_size=1)
    
    for _ in range(num_steps):
        obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
        
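        # The forward pass only advances the hidden state; its prediction is not
        # used to choose actions, which are sampled uniformly at random.
        with torch.no_grad():
            output, hidden_state = model(obs_tensor, hidden_state)
        action = env.action_space.sample()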
        
        next_obs, reward, done, info = env.step(action)
        
        observations.append(obs)
        next_observations.append(next_obs)
        
        obs = next_obs
        if done:
            obs = env.reset()
            hidden_state = model.init_hidden(batch_size=1)
    
    # Convert to tensors with shape (1, T, D)
    obs_seq = torch.tensor(np.array(observations), dtype=torch.float32).unsqueeze(0).to(device)
    next_obs_seq = torch.tensor(np.array(next_observations), dtype=torch.float32).unsqueeze(0).to(device)
    
    return obs_seq, next_obs_seq

# Meta-training
print("Starting meta-training...")
meta_learner = MetaMAML(model=model, inner_lr=0.01, outer_lr=0.001)
meta_losses = []

num_epochs = 10
for epoch in range(num_epochs):
    # Collect support and query data
    obs_support, next_obs_support = collect_trajectory(env, model, num_steps=50, device=device)
    obs_query, next_obs_query = collect_trajectory(env, model, num_steps=50, device=device)
    
    # Create tasks list
    tasks = [(obs_support, next_obs_support, obs_query, next_obs_query)]
    
    # Meta-update
    initial_hidden = model.init_hidden(batch_size=1)
    loss = meta_learner.meta_update(
        tasks=tasks,
        initial_hidden_state=initial_hidden,
        loss_fn=nn.MSELoss()
    )
    
    meta_losses.append(loss)
    
    if (epoch + 1) % 2 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}: Meta Loss = {loss:.4f}")

print("\n✓ Meta-training completed!")

# Plot meta-training loss
plt.figure(figsize=(10, 4))
plt.plot(meta_losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Meta Loss')
plt.title('Meta-Training Progress')
plt.grid(True)
plt.show()
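
To make the inner/outer structure of MAML concrete, here is a small self-contained sketch on a scalar model `y = w * x`, where the gradient and Hessian of the MSE loss are available in closed form. All function names are hypothetical, and this is not how `MetaMAML.meta_update` is implemented; it only mirrors the `(support_x, support_y, query_x, query_y)` task layout used above.

```python
def mse(w, xs, ys):
    """Mean squared error of the scalar model y = w * x."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def mse_grad(w, xs, ys):
    """d(mse)/dw."""
    return sum(2.0 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def mse_hessian(xs):
    """d^2(mse)/dw^2, which does not depend on w for this model."""
    return sum(2.0 * x * x for x in xs) / len(xs)


def maml_step(w, tasks, inner_lr, outer_lr):
    """One meta-update: adapt on support data, score on query data."""
    meta_grad, meta_loss = 0.0, 0.0
    for xs_s, ys_s, xs_q, ys_q in tasks:
        # Inner loop: a single gradient step on the support set
        w_adapted = w - inner_lr * mse_grad(w, xs_s, ys_s)
        meta_loss += mse(w_adapted, xs_q, ys_q)
        # Chain rule through the inner step: d(w_adapted)/dw = 1 - inner_lr * H
        meta_grad += mse_grad(w_adapted, xs_q, ys_q) * (
            1.0 - inner_lr * mse_hessian(xs_s)
        )
    n = len(tasks)
    return w - outer_lr * meta_grad / n, meta_loss / n


# Two toy tasks with slopes 1 and 3: (support_x, support_y, query_x, query_y)
tasks = [
    ([1.0, 2.0], [1.0, 2.0], [0.5, 1.5], [0.5, 1.5]),
    ([1.0, 2.0], [3.0, 6.0], [0.5, 1.5], [1.5, 4.5]),
]
w, history = 0.0, []
for _ in range(200):
    w, loss = maml_step(w, tasks, inner_lr=0.05, outer_lr=0.05)
    history.append(loss)
print(f"meta-learned w = {w:.3f}, meta loss {history[0]:.3f} -> {history[-1]:.3f}")
```

In this symmetric toy problem the meta-learned weight settles near 2.0, midway between the two task slopes, which is the initialization from which one inner step gets closest to either task.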

## 5. Test-Time Adaptation: Standard Mode

Test the standard adaptation approach (baseline):

In [None]:
print("Testing Standard Adaptation Mode...")

# Create standard adapter
standard_config = StandardAdaptationConfig(
    learning_rate=0.01,
    num_steps=5
)
standard_adapter = StandardAdapter(
    model=model,
    config=standard_config,
    device=device
)

# Test adaptation
obs = env.reset()
hidden_state = model.init_hidden(batch_size=1)
standard_losses = []

num_adapt_steps = 20
for step in range(num_adapt_steps):
    obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
    
    with torch.no_grad():
        output, next_hidden_state = model(obs_tensor, hidden_state)
        action = env.action_space.sample()
    
    next_obs, reward, done, info = env.step(action)
    next_obs_tensor = torch.tensor(next_obs, dtype=torch.float32).unsqueeze(0).to(device)
    
    # Perform adaptation
    loss_val, _ = standard_adapter.update_step(
        x=obs_tensor,
        y=next_obs_tensor,
        hidden_state=hidden_state
    )
    
    standard_losses.append(loss_val)
    
    obs = next_obs
    hidden_state = next_hidden_state
    
    if done:
        obs = env.reset()
        hidden_state = model.init_hidden(batch_size=1)

print("✓ Standard adaptation completed!")
print(f"  Initial loss: {standard_losses[0]:.4f}")
print(f"  Final loss: {standard_losses[-1]:.4f}")
print(f"  Improvement: {(1 - standard_losses[-1]/standard_losses[0])*100:.1f}%")
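
The loop above can be pictured with a toy version of online test-time adaptation: for each incoming pair, record the loss, then take a few gradient steps on that pair. The function `adapt_online` is a hypothetical stand-in for `StandardAdapter.update_step`, using a scalar model; whether the real adapter reports the loss before or after its update is not something this sketch assumes to know.

```python
def adapt_online(weight, stream, learning_rate=0.01, num_steps=5):
    """Adapt a scalar model y = weight * x on a stream of (x, y) pairs.

    Returns the adapted weight and the pre-update squared error at each
    stream position (ASSUMPTION: loss is recorded before the update).
    """
    losses = []
    for x, y in stream:
        losses.append((weight * x - y) ** 2)
        for _ in range(num_steps):
            grad = 2.0 * x * (weight * x - y)  # d/dw of (w*x - y)^2
            weight -= learning_rate * grad
    return weight, losses


# 20 identical observations of a target slope of 2.0
stream = [(1.0, 2.0)] * 20
weight, losses = adapt_online(0.0, stream)
print(f"adapted weight = {weight:.3f}")
print(f"first loss = {losses[0]:.3f}, last loss = {losses[-1]:.3f}")
```

The defaults mirror the `learning_rate=0.01` and `num_steps=5` values in `StandardAdaptationConfig` above; on real environment data the per-step losses are noisy, so a monotone decrease like the one in this toy stream should not be expected.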

## 6. Test-Time Adaptation: Hybrid Mode (Experience-Augmented)

Now test the hybrid adaptation with experience replay:

In [None]:
print("Testing Hybrid Adaptation Mode (with Experience Replay)...")

# Initialize experience buffer
buffer_size = 1000
experience_buffer = ExperienceBuffer(max_size=buffer_size, device=str(device))
print(f"Initialized ExperienceBuffer (max_size={buffer_size})")

# Populate buffer with some initial experiences
print("Populating experience buffer...")
for _ in range(5):
    obs_traj, next_obs_traj = collect_trajectory(env, model, num_steps=20, device=device)
    # Add one (1, D) transition per time step
    for t in range(obs_traj.shape[1]):
        experience_buffer.add(
            obs_traj[:, t, :],
            next_obs_traj[:, t, :]
        )

print(f"Buffer populated with {len(experience_buffer)} experiences")

# Create hybrid adapter
hybrid_config = HybridAdaptationConfig(
    learning_rate=0.01,
    num_steps=5,
    experience_batch_size=16,
    experience_weight=0.1
)
hybrid_adapter = HybridAdapter(
    model=model,
    config=hybrid_config,
    experience_buffer=experience_buffer,
    device=device
)

# Test adaptation
obs = env.reset()
hidden_state = model.init_hidden(batch_size=1)
hybrid_losses = []

num_adapt_steps = 20
for step in range(num_adapt_steps):
    obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
    
    with torch.no_grad():
        output, next_hidden_state = model(obs_tensor, hidden_state)
        action = env.action_space.sample()
    
    next_obs, reward, done, info = env.step(action)
    next_obs_tensor = torch.tensor(next_obs, dtype=torch.float32).unsqueeze(0).to(device)
    
    # Perform hybrid adaptation (uses experience replay)
    loss_val, _ = hybrid_adapter.update_step(
        x_current=obs_tensor,
        y_current=next_obs_tensor,
        hidden_state_current=hidden_state
    )
    
    hybrid_losses.append(loss_val)
    
    obs = next_obs
    hidden_state = next_hidden_state
    
    if done:
        obs = env.reset()
        hidden_state = model.init_hidden(batch_size=1)

print("✓ Hybrid adaptation completed!")
print(f"  Initial loss: {hybrid_losses[0]:.4f}")
print(f"  Final loss: {hybrid_losses[-1]:.4f}")
print(f"  Improvement: {(1 - hybrid_losses[-1]/hybrid_losses[0])*100:.1f}%")

env.close()
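
If you want a mental model of the replay buffer used above, the following sketch implements a fixed-capacity store with uniform sampling. The class `ReplayBuffer` and its `sample` method are assumptions for illustration; the only parts of the real `ExperienceBuffer` interface visible in this notebook are `max_size`, `add`, and `len()`.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-capacity FIFO store of (observation, next_observation) pairs."""

    def __init__(self, max_size):
        # A bounded deque drops the oldest entry once max_size is reached
        self._items = deque(maxlen=max_size)

    def add(self, obs, next_obs):
        self._items.append((obs, next_obs))

    def sample(self, batch_size, rng=random):
        """Uniformly sample up to batch_size pairs without replacement."""
        k = min(batch_size, len(self._items))
        return rng.sample(list(self._items), k)

    def __len__(self):
        return len(self._items)


buffer = ReplayBuffer(max_size=100)
for t in range(150):  # 150 inserts into a 100-slot buffer evicts the first 50
    buffer.add([float(t)], [float(t + 1)])
batch = buffer.sample(16, rng=random.Random(0))
print(len(buffer), len(batch))
```

In the cell above, five 20-step trajectories add 100 transitions, well below `max_size=1000`, so nothing would be evicted under this eviction scheme.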

## 7. Comparison: Standard vs. Hybrid Adaptation

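Visualize the two loss curves side by side. Both runs use random actions on different episodes, and both adapters were given the same `model` object, so treat the comparison as illustrative rather than as a controlled benchmark: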

In [None]:
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(standard_losses, marker='o', label='Standard', color='blue', alpha=0.7)
plt.plot(hybrid_losses, marker='s', label='Hybrid (Experience-Augmented)', color='red', alpha=0.7)
plt.xlabel('Adaptation Step')
plt.ylabel('Loss')
plt.title('Adaptation Loss Comparison')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
methods = ['Standard', 'Hybrid']
initial_losses = [standard_losses[0], hybrid_losses[0]]
final_losses = [standard_losses[-1], hybrid_losses[-1]]

x = np.arange(len(methods))
width = 0.35

plt.bar(x - width/2, initial_losses, width, label='Initial Loss', alpha=0.7)
plt.bar(x + width/2, final_losses, width, label='Final Loss', alpha=0.7)
plt.xlabel('Adaptation Mode')
plt.ylabel('Loss')
plt.title('Initial vs. Final Loss')
plt.xticks(x, methods)
plt.legend()
plt.grid(True, axis='y')

plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("COMPARISON SUMMARY")
print("="*60)
print("Standard Mode:")
print(f"  Loss reduction: {(1 - standard_losses[-1]/standard_losses[0])*100:.1f}%")
print("\nHybrid Mode (Experience-Augmented):")
print(f"  Loss reduction: {(1 - hybrid_losses[-1]/hybrid_losses[0])*100:.1f}%")
print(f"  Buffer size: {len(experience_buffer)} experiences")
print("="*60)
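
The loss-reduction figure printed above compares a single first loss with a single last loss, which is sensitive to noise and fails if the first loss is zero. A hedged alternative is to average over a short window at each end; `loss_reduction` below is a hypothetical helper, not part of the framework.

```python
def loss_reduction(losses, window=1):
    """Percentage drop from the mean of the first `window` losses to the
    mean of the last `window` losses. window=1 matches the formula above."""
    if not 1 <= window <= len(losses):
        raise ValueError("window must be between 1 and len(losses)")
    start = sum(losses[:window]) / window
    end = sum(losses[-window:]) / window
    if start == 0:
        return 0.0  # nothing to reduce; avoids ZeroDivisionError
    return (1 - end / start) * 100


example = [4.0, 2.0, 3.0, 1.0]
print(f"{loss_reduction(example):.1f}%")            # first vs. last value
print(f"{loss_reduction(example, window=2):.1f}%")  # first two vs. last two
```

With the lists from this notebook, a call such as `loss_reduction(standard_losses, window=5)` would give a smoother figure for the summary table.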

## 8. Conclusion

This demo showcased the **SSM-MetaRL-Unified** framework's dual adaptation capabilities:

1. **Standard Adaptation**: Uses only current task data for adaptation (baseline approach)
2. **Hybrid Adaptation**: Augments adaptation with past experiences from a replay buffer (novel approach)

### Key Takeaways

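- The **ExperienceBuffer** lets the adapter reuse transitions from past trajectories during adaptation
- The **hybrid loss** combines the current transition with experiences sampled from the buffer (see `experience_batch_size` and `experience_weight` in `HybridAdaptationConfig`)
- This approach may be particularly beneficial in settings such as the following, none of which this demo evaluates: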
  - Sparse reward environments
  - Non-stationary tasks
  - Sample-limited scenarios

### Next Steps

- Try different buffer sizes and experience weights
- Test on more complex environments
- Benchmark on MuJoCo tasks against state-of-the-art methods
- Compare against LSTM, GRU, and Transformer baselines

For more information, visit the [GitHub repository](https://github.com/sunghunkwag/SSM-MetaRL-Unified)!