# SSM-MetaRL-TestCompute Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sunghunkwag/SSM-MetaRL-TestCompute/blob/main/demo.ipynb)

This notebook demonstrates the **SSM-MetaRL-TestCompute** framework, which combines:
- **State Space Models (SSM)** for efficient sequence modeling
- **Meta-Learning (MAML)** for fast adaptation
- **Test-Time Adaptation** for continual learning

## Features
- ✅ Meta-training with MetaMAML
- ✅ Test-time adaptation with gradient updates
- ✅ Gymnasium environment integration
- ✅ CPU-optimized for low-compute scenarios

## 1. Installation

Install required dependencies:

In [None]:
# Install dependencies
!pip install torch numpy gymnasium

# Clone the repository if not already cloned
import os
if not os.path.exists('SSM-MetaRL-TestCompute'):
    !git clone https://github.com/sunghunkwag/SSM-MetaRL-TestCompute.git
    os.chdir('SSM-MetaRL-TestCompute')
else:
    os.chdir('SSM-MetaRL-TestCompute')

print("✓ Installation complete!")

## 2. Import Libraries

In [None]:
import torch
import torch.nn as nn
import numpy as np
import gymnasium as gym
from collections import OrderedDict

# Import SSM-MetaRL components
from core.ssm import StateSpaceModel
from meta_rl.meta_maml import MetaMAML
from adaptation.test_time_adaptation import Adapter, AdaptationConfig
from env_runner.environment import Environment

# Sanity check: report the PyTorch build and whether CUDA is available.
print(
    "✓ Libraries imported successfully!",
    f"PyTorch version: {torch.__version__}",
    f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}",
    sep="\n",
)

## 3. Configuration

Set up hyperparameters for the experiment:

In [None]:
# Configuration
class Config:
    """Hyperparameters for the demo, kept as plain class attributes."""

    # --- Environment ---
    env_name: str = 'CartPole-v1'
    batch_size: int = 1

    # --- Model architecture ---
    state_dim: int = 32
    hidden_dim: int = 64

    # --- Meta-training ---
    num_epochs: int = 10  # kept small for the demo
    episodes_per_task: int = 3
    inner_lr: float = 0.01
    outer_lr: float = 0.001

    # --- Test-time adaptation ---
    adapt_lr: float = 0.01
    num_adapt_steps: int = 20  # kept small for the demo

config = Config()
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Echo the settings that matter most for this run.
print("\n".join([
    "Configuration:",
    f"  Environment: {config.env_name}",
    f"  Device: {device}",
    f"  State dim: {config.state_dim}",
    f"  Hidden dim: {config.hidden_dim}",
    f"  Meta-training epochs: {config.num_epochs}",
]))

## 4. Initialize Environment and Model

In [None]:
# Build the project's environment wrapper and read its spaces.
env = Environment(env_name=config.env_name, batch_size=config.batch_size)
obs_space, action_space = env.observation_space, env.action_space

# Box spaces contribute their first shape entry; any other space is expected
# to expose `.n`.
if isinstance(obs_space, gym.spaces.Box):
    input_dim = obs_space.shape[0]
else:
    input_dim = obs_space.n
# The model predicts the next observation, so output width equals input width.
output_dim = input_dim

print("Environment initialized:")
print(f"  Observation space: {obs_space}")
print(f"  Action space: {action_space}")
print(f"  Input/Output dim: {input_dim}/{output_dim}")

# Construct the State Space Model, then move it to the selected device.
model = StateSpaceModel(
    state_dim=config.state_dim,
    input_dim=input_dim,
    output_dim=output_dim,
    hidden_dim=config.hidden_dim,
)
model = model.to(device)

print(f"\n✓ Model initialized with {sum(param.numel() for param in model.parameters())} parameters")

## 5. Data Collection Function

In [None]:
def collect_data(env, policy_model, num_episodes=3, max_steps_per_episode=50, device='cpu'):
    """
    Roll out ``policy_model`` in ``env`` and return the visited transitions.

    Actions are derived from the model output. For a ``gym.spaces.Discrete``
    action space the first ``n`` outputs are treated as logits and one action
    is sampled from their softmax; for any other space the flattened output
    is passed to the environment as the action.

    Args:
        env: Environment wrapper exposing ``reset()``, ``step(action)``,
            ``action_space`` and ``batch_size``. ``step`` must return
            ``(next_obs, reward, terminated, truncated, info)``.
        policy_model: Callable as ``policy_model(obs, hidden)`` returning
            ``(output, next_hidden)``, and providing ``eval()`` and
            ``init_hidden(batch_size=...)``.
        num_episodes: Number of rollout segments to collect.
        max_steps_per_episode: Upper bound on steps taken per segment.
        device: Device the observation tensor is moved to before the
            forward pass.

    Returns:
        dict: NumPy arrays under the keys ``'observations'``, ``'actions'``,
        ``'rewards'``, ``'next_observations'`` and ``'dones'``, with one entry
        per environment step taken.

    Notes:
        * The environment is reset only after a segment ends with ``done``.
          A segment stopped by the step cap is followed by a segment that
          continues from the same state and hidden state.
        * The model is switched to eval mode for the rollout and returned to
          training mode afterwards if it was in training mode on entry, so
          collecting data does not silently change how later training runs.
    """
    all_obs, all_actions, all_rewards, all_next_obs, all_dones = [], [], [], [], []

    # Remember the caller's mode; objects without a `training` flag are
    # simply left in eval mode, as before.
    was_training = getattr(policy_model, 'training', False)
    policy_model.eval()

    # The action-space type does not change during a rollout; check it once.
    discrete_actions = isinstance(env.action_space, gym.spaces.Discrete)

    try:
        obs = env.reset()
        hidden_state = policy_model.init_hidden(batch_size=env.batch_size)

        for _ in range(num_episodes):
            steps_in_ep = 0
            done = False

            while not done and steps_in_ep < max_steps_per_episode:
                # Add a leading dimension before the forward pass.
                obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)

                with torch.no_grad():
                    action_logits, next_hidden_state = policy_model(obs_tensor, hidden_state)

                    if discrete_actions:
                        # Only the first n_actions outputs are used as logits.
                        n_actions = env.action_space.n
                        probs = torch.softmax(action_logits[:, :n_actions], dim=-1)
                        action = torch.multinomial(probs, 1).item()
                    else:
                        action = action_logits.cpu().numpy().flatten()

                next_obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                all_obs.append(obs)
                all_actions.append(action)
                all_rewards.append(reward)
                all_next_obs.append(next_obs)
                all_dones.append(done)

                obs = next_obs
                hidden_state = next_hidden_state
                steps_in_ep += 1

            if done:
                # Start the next segment from a fresh episode and hidden state.
                obs = env.reset()
                hidden_state = policy_model.init_hidden(batch_size=env.batch_size)
    finally:
        if was_training:
            policy_model.train()

    return {
        'observations': np.array(all_obs),
        'actions': np.array(all_actions),
        'rewards': np.array(all_rewards),
        'next_observations': np.array(all_next_obs),
        'dones': np.array(all_dones)
    }

print("✓ Data collection function defined")

print("✓ Data collection function defined")

## 6. Meta-Training with MetaMAML

Train the model using Model-Agnostic Meta-Learning (MAML):

In [None]:
# Wrap the model in the MAML meta-learner.
meta_maml = MetaMAML(model=model, inner_lr=config.inner_lr, outer_lr=config.outer_lr)

print("Starting meta-training...\n")

for epoch in range(config.num_epochs):
    # Each epoch gathers a fresh batch of transitions with the current model.
    task_data = collect_data(
        env,
        model,
        num_episodes=config.episodes_per_task,
        max_steps_per_episode=50,
        device=device,
    )

    # Inputs are observations; targets are the observations that followed.
    obs_tensor = torch.tensor(task_data['observations'], dtype=torch.float32, device=device)
    next_obs_tensor = torch.tensor(task_data['next_observations'], dtype=torch.float32, device=device)

    # NOTE: the support and query sets are the same tensors in this demo.
    meta_loss = meta_maml.meta_update(
        support_data=(obs_tensor, next_obs_tensor),
        query_data=(obs_tensor, next_obs_tensor),
        num_inner_steps=5,
    )

    # Report on every second epoch (2, 4, 6, ...).
    if epoch % 2 == 1:
        print(f"Epoch {epoch+1}/{config.num_epochs} - Meta Loss: {meta_loss:.4f}")

print("\n✓ Meta-training completed!")

## 7. Test-Time Adaptation

Demonstrate adaptation on freshly collected data (here gathered from the same environment used for meta-training):

In [None]:
# Snapshot the meta-learned weights as detached copies, keyed by parameter name.
meta_learned_params = OrderedDict(
    (param_name, weights.detach().clone())
    for param_name, weights in model.named_parameters()
)

# Configure the adapter: step size, number of gradient steps, and MSE loss.
adapt_config = AdaptationConfig(
    learning_rate=config.adapt_lr,
    num_steps=config.num_adapt_steps,
    loss_fn=nn.MSELoss(),
)
adapter = Adapter(model=model, config=adapt_config)

print("Starting test-time adaptation...\n")

# Gather a small batch of transitions to adapt on.
test_data = collect_data(env, model, num_episodes=2, max_steps_per_episode=30, device=device)

obs_test = torch.tensor(test_data['observations'], dtype=torch.float32, device=device)
next_obs_test = torch.tensor(test_data['next_observations'], dtype=torch.float32, device=device)

# Run the adaptation; the result is indexed below as a sequence of losses.
adaptation_losses = adapter.adapt(observations=obs_test, targets=next_obs_test)

print("\nAdaptation complete!")
print(f"  Initial loss: {adaptation_losses[0]:.4f}")
print(f"  Final loss: {adaptation_losses[-1]:.4f}")
print(f"  Improvement: {adaptation_losses[0] - adaptation_losses[-1]:.4f}")

## 8. Visualize Adaptation Progress

In [None]:
import matplotlib.pyplot as plt

# Plot the adaptation loss curve, one point per recorded loss value.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(adaptation_losses, linewidth=2)
ax.set_xlabel('Adaptation Step', fontsize=12)
ax.set_ylabel('Loss (MSE)', fontsize=12)
ax.set_title('Test-Time Adaptation Progress', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

print("✓ Visualization complete!")

## 9. Evaluate Adapted Model

In [None]:
# Evaluate the adapted model: run a few episodes, sampling actions from the
# model output, and report the reward collected in each.
model.eval()
test_episodes = 5
max_steps = 200  # per-episode step cap
total_rewards = []

# The action-space type is fixed for the environment; check it once.
discrete_actions = isinstance(env.action_space, gym.spaces.Discrete)

for ep in range(test_episodes):
    obs = env.reset()
    # Size the recurrent state from the environment, as collect_data does,
    # rather than hard-coding a batch of 1.
    hidden_state = model.init_hidden(batch_size=env.batch_size)
    episode_reward = 0
    done = False
    steps = 0

    while not done and steps < max_steps:
        # Add a leading dimension before the forward pass.
        obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)

        with torch.no_grad():
            action_logits, hidden_state = model(obs_tensor, hidden_state)

            if discrete_actions:
                # Only the first n_actions outputs are used as logits.
                n_actions = env.action_space.n
                probs = torch.softmax(action_logits[:, :n_actions], dim=-1)
                action = torch.multinomial(probs, 1).item()
            else:
                action = action_logits.cpu().numpy().flatten()

        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_reward += reward
        steps += 1

    total_rewards.append(episode_reward)
    print(f"Episode {ep+1}: Reward = {episode_reward:.2f}, Steps = {steps}")

print(f"\nAverage Reward: {np.mean(total_rewards):.2f} ± {np.std(total_rewards):.2f}")
print("✓ Evaluation complete!")

## 10. Summary

This demo showcased:

1. ✅ **Meta-Training**: Trained a State Space Model using MetaMAML
2. ✅ **Test-Time Adaptation**: Adapted the model to new data with gradient updates
3. ✅ **Evaluation**: Tested the adapted model on the environment

### Key Takeaways

- **SSM**: Efficient sequence modeling with linear complexity
- **MetaMAML**: Fast adaptation from meta-learned initialization
- **Adaptation**: Continual learning at test time

### Next Steps

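- Try different environments (e.g., `'MountainCar-v0'`, `'Acrobot-v1'`). Note that actions are sampled from the first `n_actions` model outputs and the output size equals the observation size, so an environment with more actions than observation dimensions (such as `'MountainCar-v0'`: 2 observation values, 3 actions) will never select its highest-numbered action without further changes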
- Adjust hyperparameters for better performance
- Extend to more complex tasks

---

**Repository**: [SSM-MetaRL-TestCompute](https://github.com/sunghunkwag/SSM-MetaRL-TestCompute)

**License**: MIT