# Advanced Multi-Agent RL Features Demo

This notebook demonstrates state-of-the-art features for multi-agent reinforcement learning:

1. **Attention-based Communication (TarMAC)** - Learned, targeted communication
2. **Graph Neural Networks (GNN)** - Scalable agent coordination
3. **Recurrent Policies (LSTM)** - Handling partial observability
4. **Intrinsic Curiosity Module (ICM)** - Exploration bonus

## Why These Features Matter

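- **TarMAC**: Enables agents to learn WHAT to communicate and to WHOM, improving coordination
- **GNN**: Shares weights across agents and models their relationships as a graph, so the same network applies to teams of different sizes
- **LSTM**: Gives agents memory of past observations, which helps when the current observation alone does not reveal the full state (common in real-world scenarios)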
- **ICM**: Improves exploration in sparse reward environments

## Academic/Interview Relevance

These techniques come from widely cited papers at major ML venues such as ICML and ICLR (full citations in Section 7).

## Setup and Imports

In [None]:
import sys
import os
sys.path.append('..')

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Import agents
from src.marl.agents import (
    DQNAgent,
    AttentionDQNAgent,
    GNNDQNAgent,
    LSTMDQNAgent
)

# Import environments
from src.marl.environments import MultiAgentGridWorld

# Import utilities
from src.marl.utils import attention, graph_networks, curiosity

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 1. Attention-based Communication (TarMAC)

### Key Concepts:
- Agents generate **messages** from their observations
- **Attention mechanism** determines which agents' messages to focus on
- **Gated integration** controls how much to use communicated information

### Reference:
Das et al., "TarMAC: Targeted Multi-Agent Communication" (ICML 2019)
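The key concepts above can be sketched in NumPy. This is a minimal, hypothetical illustration of one round of targeted communication, not the `AttentionDQNAgent` implementation: each sender broadcasts a key (its "signature") and a value (the message content), and each receiver forms a query and attends over the senders' keys. Mirroring the training loop below, an agent only attends to *other* agents' messages. The sigmoid gate is one plausible form of gated integration and is an assumption, as are all function and weight names.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability
    z = np.exp(x - x.max(axis=axis, keepdims=True))
    return z / z.sum(axis=axis, keepdims=True)

def targeted_communication(hidden, W_k, W_v, W_q):
    """One round of TarMAC-style targeted attention over other agents' messages.

    hidden: (n_agents, d_h) array, one hidden state per agent.
    """
    keys = hidden @ W_k      # (n, d_k) signatures broadcast by each sender
    values = hidden @ W_v    # (n, d_v) message contents
    queries = hidden @ W_q   # (n, d_k) what each receiver is looking for
    scores = queries @ keys.T / np.sqrt(keys.shape[-1])  # (receiver, sender)
    np.fill_diagonal(scores, -np.inf)   # assumption: no attention to one's own message
    weights = softmax(scores, axis=-1)  # each row sums to 1 over senders
    return weights @ values, weights    # aggregated message per receiver

def gated_integration(hidden, context, W_g):
    # Hypothetical gate g = sigmoid([h; c] W_g) controls how much message to mix in
    g = 1.0 / (1.0 + np.exp(-(np.concatenate([hidden, context], axis=-1) @ W_g)))
    return g * context + (1.0 - g) * hidden

rng = np.random.default_rng(0)
n_agents, d_h, d_k = 3, 8, 4
hidden = rng.normal(size=(n_agents, d_h))
W_k, W_q = rng.normal(size=(d_h, d_k)), rng.normal(size=(d_h, d_k))
W_v = rng.normal(size=(d_h, d_h))        # d_v = d_h so the gate can mix c and h
W_g = rng.normal(size=(2 * d_h, d_h))

context, weights = targeted_communication(hidden, W_k, W_v, W_q)
updated = gated_integration(hidden, context, W_g)
print(weights.round(3))  # row i: how much agent i attends to each sender
```

Each row of `weights` is a distribution over senders, so inspecting it shows which teammates a receiver is listening to.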

In [None]:
# Create environment
env = MultiAgentGridWorld(
    grid_size=(8, 8),
    n_agents=3,
    n_targets=3,
    max_steps=50
)

# Create attention-based agents
attention_agents = [
    AttentionDQNAgent(
        agent_id=i,
        observation_space=env.observation_space,
        action_space=env.action_space,
        message_dim=32,
        hidden_dim=64,
        num_heads=4,
        use_communication=True,
        device=device
    )
    for i in range(3)
]

print("✓ Created 3 agents with TarMAC communication")
print(f"  - Message dimension: 32")
print(f"  - Number of attention heads: 4")
print(f"  - Hidden dimension: 64")

### Training Loop with Communication

In [None]:
def train_with_communication(agents, env, n_episodes=100):
    """
    Train agents with attention-based communication.
    """
    episode_rewards = []
    
    for episode in tqdm(range(n_episodes), desc="Training with TarMAC"):
        observations, _ = env.reset()
        episode_reward = 0
        done = False
        step = 0
        
        # Store messages from all agents
        messages = {i: None for i in range(len(agents))}
        
        while not done and step < 50:
            actions = {}
            new_messages = {}
            
            # Each agent selects action and generates message
            for i, agent in enumerate(agents):
                # Collect messages from OTHER agents
                other_messages = []
                for j in range(len(agents)):
                    if j != i and messages[j] is not None:
                        other_messages.append(messages[j])
                
                # Stack messages if available
                if len(other_messages) > 0:
                    other_messages_tensor = torch.stack(other_messages)
                else:
                    other_messages_tensor = None
                
                # Get action and message
                action, message = agent.get_action(
                    observations[i],
                    other_messages_tensor,
                    training=True
                )
                
                actions[i] = action
                new_messages[i] = message
            
            # Update messages
            messages = new_messages
            
            # Step environment
            next_observations, rewards, terminated, truncated, _ = env.step(actions)
            
            # Store experiences and update agents
            for i, agent in enumerate(agents):
                agent.store_experience(
                    observations[i],
                    actions[i],
                    rewards[i],
                    next_observations[i],
                    terminated[i]
                )
                agent.update()
            
            observations = next_observations
            episode_reward += sum(rewards.values())
            # Stop on termination or on the environment's time-limit truncation
            done = all(terminated.values()) or all(truncated.values())
            step += 1
        
        episode_rewards.append(episode_reward)
    
    return episode_rewards

# Train agents
rewards = train_with_communication(attention_agents, env, n_episodes=100)

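# Plot raw episode rewards and a 10-episode moving average
window = 10
moving_avg = np.convolve(rewards, np.ones(window) / window, mode='valid')
plt.figure(figsize=(10, 5))
plt.plot(rewards, alpha=0.3, label='Episode reward')
# Align each average with the last episode in its window
plt.plot(np.arange(window - 1, len(rewards)), moving_avg, linewidth=2, label=f'{window}-episode moving average')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Training with Attention-based Communication (TarMAC)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()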

print(f"\n✓ Training complete!")
print(f"  Average reward (last 20 episodes): {np.mean(rewards[-20:]):.2f}")

### Inspecting Attention Weights

Let's run one greedy episode and record the attention weights each agent assigns to the other agents' messages:

In [None]:
# Run one episode and collect attention weights
observations, _ = env.reset()
messages = {i: None for i in range(len(attention_agents))}
attention_history = []

for step in range(10):  # Run for 10 steps
    actions = {}
    new_messages = {}
    step_attention = []
    
    for i, agent in enumerate(attention_agents):
        other_messages = [messages[j] for j in range(len(attention_agents)) if j != i and messages[j] is not None]
        
        if len(other_messages) > 0:
            other_messages_tensor = torch.stack(other_messages)
        else:
            other_messages_tensor = None
        
        action, message = agent.get_action(observations[i], other_messages_tensor, training=False)
        actions[i] = action
        new_messages[i] = message
        
        # Get attention weights
        attn_weights = agent.get_attention_weights()
        if attn_weights is not None:
            step_attention.append(attn_weights.detach().cpu().numpy())
    
    messages = new_messages
    observations, _, terminated, truncated, _ = env.step(actions)
    
    if len(step_attention) > 0:
        attention_history.append(step_attention)
    
    if all(terminated.values()):
        break

print("\n✓ Collected attention weights during execution")
print(f"  Steps recorded: {len(attention_history)}")
print("\nAttention weights show which agents each agent focuses on during communication.")

## 2. Graph Neural Networks (GNN)

### Key Concepts:
- Model agents as **nodes** in a graph
- Communication as **message passing** on edges
- **Dynamic graph construction** based on proximity
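- Weights are shared across nodes, so the parameter count does not grow with the number of agents, which helps scale to large teams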

### Reference:
Veličković et al., "Graph Attention Networks" (ICLR 2018)
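A single-head graph attention layer, together with the proximity-based graph construction listed above, can be sketched in NumPy as follows. This is an illustrative sketch of the GAT update from Veličković et al., not the code behind `GNNDQNAgent`; the function names, the `radius` threshold, and returning features before the output nonlinearity are assumptions.

```python
import numpy as np

def proximity_adjacency(positions, radius):
    """Connect agents whose Euclidean distance is <= radius (self-loops included)."""
    diffs = positions[:, None, :] - positions[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    return (dists <= radius).astype(float)  # diagonal is 1 since distance to self is 0

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_layer(h, adj, W, a):
    """Single-head graph attention layer (output before the nonlinearity).

    h: (n, f_in) node features, adj: (n, n) 0/1 adjacency,
    W: (f_in, f_out) shared projection, a: (2 * f_out,) attention vector.
    """
    Wh = h @ W                                     # (n, f_out)
    f_out = Wh.shape[1]
    # e_ij = LeakyReLU(a^T [Wh_i || Wh_j]), split into the i-part and the j-part
    e = leaky_relu((Wh @ a[:f_out])[:, None] + (Wh @ a[f_out:])[None, :])
    e = np.where(adj > 0, e, -np.inf)              # attend only to neighbours
    e = e - e.max(axis=1, keepdims=True)           # stable softmax over each row
    alpha = np.exp(e)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    return alpha @ Wh, alpha                       # attention-weighted neighbour features

rng = np.random.default_rng(0)
positions = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])  # agent 2 is far away
adj = proximity_adjacency(positions, radius=2.0)
h = rng.normal(size=(3, 6))
W = rng.normal(size=(6, 4))
a = rng.normal(size=8)
out, alpha = gat_layer(h, adj, W, a)
print(adj)
print(alpha.round(3))
```

Because `W` and `a` are shared by every node, the same layer runs unchanged on a graph of 3 or 300 agents; only the adjacency matrix grows.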

In [None]:
# Create GNN-based agents
gnn_agents = [
    GNNDQNAgent(
        agent_id=i,
        observation_space=env.observation_space,
        action_space=env.action_space,
        n_agents=3,
        gnn_type="gat",  # Graph Attention Network
        num_gnn_layers=2,
        num_heads=4,
        device=device
    )
    for i in range(3)
]

print("✓ Created 3 agents with Graph Neural Networks")
print(f"  - GNN type: Graph Attention Network (GAT)")
print(f"  - Number of GNN layers: 2")
print(f"  - Number of attention heads: 4")
print("\nGNNs enable scalable communication by modeling agent relationships as graphs.")

### Demonstrating GNN Architecture

In [None]:
# Show GNN architecture
print("\nGNN Q-Network Architecture:")
print("="*50)
print(gnn_agents[0].q_network)
print("="*50)
print("\nKey Components:")
print("1. Observation Encoder - Processes individual observations")
print("2. GNN Layers - Message passing between agents")
print("3. Q-Value Head - Outputs Q-values for actions")

## 3. Recurrent Policies (LSTM)

### Key Concepts:
- **Memory** of past observations
- Handle **partial observability** (POMDPs)
- **Temporal reasoning** over sequences

### Reference:
Hausknecht & Stone, "Deep Recurrent Q-Learning for Partially Observable MDPs" (AAAI Fall Symposium 2015)
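The memory mechanism can be sketched as a DRQN-style forward pass: an LSTM cell folds each observation into a hidden state, and a linear head reads Q-values from that state. This is a minimal NumPy sketch with assumed dimensions and random initialization, not the `LSTMDQNAgent` implementation; training (e.g. on sampled sequences of length `sequence_length`) is omitted, and the class name is hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class RecurrentQNetwork:
    """Minimal DRQN-style forward pass: LSTM cell + linear Q-value head."""

    def __init__(self, obs_dim, hidden_dim, n_actions, seed=0):
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(hidden_dim)
        # Input and recurrent weights for the four gates (i, f, g, o), stacked
        self.W = rng.uniform(-scale, scale, size=(obs_dim, 4 * hidden_dim))
        self.U = rng.uniform(-scale, scale, size=(hidden_dim, 4 * hidden_dim))
        self.b = np.zeros(4 * hidden_dim)
        self.W_q = rng.uniform(-scale, scale, size=(hidden_dim, n_actions))
        self.hidden_dim = hidden_dim

    def initial_state(self):
        # Zero memory at the start of every episode
        return np.zeros(self.hidden_dim), np.zeros(self.hidden_dim)

    def step(self, obs, state):
        h, c = state
        z = obs @ self.W + h @ self.U + self.b
        i, f, g, o = np.split(z, 4)
        i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
        g = np.tanh(g)
        c = f * c + i * g        # update cell memory
        h = o * np.tanh(c)       # new hidden state summarises the history
        q_values = h @ self.W_q  # Q(h_t, a) for every action
        return q_values, (h, c)

net = RecurrentQNetwork(obs_dim=5, hidden_dim=16, n_actions=4)
obs_sequence = np.random.default_rng(1).normal(size=(8, 5))

state = net.initial_state()        # reset memory at episode start
for obs in obs_sequence:
    q_values, state = net.step(obs, state)
action = int(np.argmax(q_values))  # greedy action given the whole history

# Same final observation, but with no history: the Q-values differ
q_fresh, _ = net.step(obs_sequence[-1], net.initial_state())
print(q_values.round(3), q_fresh.round(3))
```

The last comparison is the point of recurrence: the same observation can map to different Q-values depending on what the agent saw before it.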

In [None]:
# Create LSTM-based agents
lstm_agents = [
    LSTMDQNAgent(
        agent_id=i,
        observation_space=env.observation_space,
        action_space=env.action_space,
        lstm_hidden_dim=64,
        num_lstm_layers=1,
        sequence_length=8,
        device=device
    )
    for i in range(3)
]

print("✓ Created 3 agents with LSTM memory")
print(f"  - LSTM hidden dimension: 64")
print(f"  - Number of LSTM layers: 1")
print(f"  - Sequence length for training: 8")
print("\nLSTM lets agents condition on their observation history, not just the current observation.")

## 4. Intrinsic Curiosity Module (ICM)

### Key Concepts:
- **Intrinsic reward** based on prediction error
- Encourages exploration of novel states
- Learns **features** through the inverse-dynamics objective, so they capture what the agent's actions can affect
- Combines **inverse dynamics** (predict action) and **forward dynamics** (predict next state)

### Reference:
Pathak et al., "Curiosity-driven Exploration by Self-supervised Prediction" (ICML 2017)
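For reference, the quantities described above can be written as follows, using the notation of Pathak et al.: $\phi$ is the feature encoder, $g$ the inverse model with parameters $\theta_I$, $f$ the forward model with parameters $\theta_F$, and $L_I$ the inverse-model loss (cross-entropy for discrete actions). How the `beta` and `eta` arguments of `curiosity.IntrinsicCuriosityModule` map onto $\beta$ and $\eta$ here is an assumption based on the paper, not something this notebook confirms.

$$
\hat{a}_t = g\big(\phi(s_t), \phi(s_{t+1}); \theta_I\big), \qquad
\hat{\phi}(s_{t+1}) = f\big(\phi(s_t), a_t; \theta_F\big)
$$

$$
L_F = \tfrac{1}{2}\big\lVert \hat{\phi}(s_{t+1}) - \phi(s_{t+1}) \big\rVert_2^2, \qquad
r^{i}_t = \tfrac{\eta}{2}\big\lVert \hat{\phi}(s_{t+1}) - \phi(s_{t+1}) \big\rVert_2^2
$$

$$
\min_{\theta_I,\, \theta_F} \; (1 - \beta)\, L_I(\hat{a}_t, a_t) + \beta\, L_F
$$

The policy is trained on $r_t = r^e_t + r^i_t$, so $\eta$ sets how strongly curiosity competes with the extrinsic reward.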

In [None]:
# Create ICM module
icm_module = curiosity.IntrinsicCuriosityModule(
    obs_dim=env.observation_space.shape[0],
    action_dim=env.action_space.n,
    feature_dim=32,
    hidden_dim=64,
    beta=0.2,
    eta=0.5
).to(device)

print("✓ Created Intrinsic Curiosity Module")
print(f"  - Feature dimension: 32")
print(f"  - Beta (inverse vs. forward loss weighting): 0.2")
print(f"  - Eta (intrinsic reward scale): 0.5")
print("\nICM Architecture:")
print("1. Feature Encoder - Learns task-relevant representations")
print("2. Inverse Model - Predicts action from state transition")
print("3. Forward Model - Predicts next state features (error = curiosity)")

### Demonstrating Intrinsic Rewards

In [None]:
# Collect intrinsic rewards during exploration
observations, _ = env.reset()
intrinsic_rewards_history = []

for step in range(20):
    # Random actions for exploration
    actions = {i: env.action_space.sample() for i in range(3)}
    next_observations, rewards, terminated, truncated, _ = env.step(actions)
    
    # Compute intrinsic rewards for each agent
    step_intrinsic_rewards = []
    for i in range(3):
        obs_tensor = torch.FloatTensor(observations[i]).unsqueeze(0).to(device)
        next_obs_tensor = torch.FloatTensor(next_observations[i]).unsqueeze(0).to(device)
        action_tensor = torch.LongTensor([actions[i]]).to(device)
        
        intrinsic_reward, losses = icm_module(obs_tensor, next_obs_tensor, action_tensor)
        step_intrinsic_rewards.append(intrinsic_reward.item())
    
    intrinsic_rewards_history.append(step_intrinsic_rewards)
    observations = next_observations
    
    if all(terminated.values()):
        break

# Visualize intrinsic rewards
intrinsic_rewards_history = np.array(intrinsic_rewards_history)

plt.figure(figsize=(12, 4))
for i in range(3):
    plt.plot(intrinsic_rewards_history[:, i], label=f'Agent {i}', marker='o')
plt.xlabel('Step')
plt.ylabel('Intrinsic Reward')
plt.title('Intrinsic Curiosity Rewards During Exploration')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

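print("\n✓ Intrinsic rewards computed successfully!")
print("Higher rewards mean the forward model predicted the transition poorly (a more novel/surprising state).")
print("Note: the ICM is not trained in this cell, so these values come from an untrained model.")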
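The ICM in the cell above is never updated, so its rewards reflect an untrained forward model. Once the forward model is trained, the bonus for a transition it has seen repeatedly shrinks, which is what pushes the agent toward new states. The sketch below illustrates this under simplifying assumptions: fixed, randomly drawn features (no encoder learning), a linear forward model, and a normalized step size chosen so that each update exactly halves the prediction error on this one transition. All names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
feat_dim, n_actions, eta = 4, 3, 0.5

# Hypothetical fixed features for a single transition s --a--> s'
phi_s = rng.normal(size=feat_dim)
phi_next = rng.normal(size=feat_dim)
action = 1
# Forward-model input: [phi(s); one-hot(a)]
x = np.concatenate([phi_s, np.eye(n_actions)[action]])

W = np.zeros((x.size, feat_dim))  # linear forward model, untrained
step_size = 0.5 / (x @ x)         # normalized step: each update halves the error on x

intrinsic_rewards = []
for visit in range(20):
    error = x @ W - phi_next                                  # phi_hat(s') - phi(s')
    intrinsic_rewards.append(0.5 * eta * np.sum(error ** 2))  # r^i = eta/2 * ||error||^2
    # SGD step on L_F = 0.5 * ||x @ W - phi_next||^2; dL_F/dW = outer(x, error)
    W -= step_size * np.outer(x, error)

print([round(r, 4) for r in intrinsic_rewards[:4]])
```

Each revisit cuts the intrinsic reward by a factor of four: the error halves, and the reward is quadratic in the error.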

## 5. Comparing All Approaches

Let's summarize the qualitative trade-offs between the agent architectures. These are design-level characteristics, not measured benchmark results:

In [None]:
# Qualitative, design-level comparison (no training runs are compared here)
print("Feature Comparison:")
print("="*80)
print(f"{'Feature':<30} {'Baseline DQN':<15} {'TarMAC':<15} {'GNN':<15} {'LSTM':<15}")
print("="*80)
print(f"{'Communication':<30} {'No':<15} {'Yes (Attn)':<15} {'Yes (Graph)':<15} {'No':<15}")
print(f"{'Memory':<30} {'No':<15} {'No':<15} {'No':<15} {'Yes (LSTM)':<15}")
print(f"{'Scalability (rough # agents)':<30} {'<10':<15} {'<20':<15} {'100+':<15} {'<10':<15}")
print(f"{'Partial Observability':<30} {'Poor':<15} {'Poor':<15} {'Poor':<15} {'Good':<15}")
print(f"{'Parameters':<30} {'Low':<15} {'Medium':<15} {'Medium':<15} {'Medium':<15}")
print("="*80)
print("\nKey Insights:")
print("• TarMAC: Best for tasks requiring selective communication")
print("• GNN: Best for large-scale multi-agent systems")
print("• LSTM: Best for partially observable environments")
print("• ICM: Useful as add-on for sparse reward environments")

## 6. Implementation Highlights for Interviews

### Key Technical Points to Discuss:

#### 1. **Attention Mechanism**
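```python
import math
import torch

# Scaled dot-product attention for communication (one head shown;
# multi-head attention runs several heads in parallel on split projections)
Q = self.q_proj(query)  # What I'm looking for
K = self.k_proj(key)    # What others are broadcasting
V = self.v_proj(value)  # The actual message content

d_k = K.shape[-1]
attention_weights = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
output = attention_weights @ V
```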

#### 2. **Graph Neural Networks**
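```python
# Message passing on agent graph
for layer in gnn_layers:
    # Aggregate neighbor features (adjacency row-normalized, with self-loops)
    messages = adjacency @ node_features
    # Update each node from its own features and its aggregated messages
    node_features = layer(node_features, messages)
```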

#### 3. **Intrinsic Curiosity**
```python
# Forward-model prediction error (in feature space) as curiosity
phi_s, phi_next = encoder(state), encoder(next_state)
predicted_phi_next = forward_model(phi_s, action)
intrinsic_reward = 0.5 * eta * ((predicted_phi_next - phi_next) ** 2).sum()
total_reward = extrinsic_reward + intrinsic_reward
```

#### 4. **LSTM for Memory**
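```python
# Maintain hidden state across steps (reset to zeros at episode start)
h_t, c_t = lstm_cell(observation_t, (h_prev, c_prev))  # e.g. torch.nn.LSTMCell
q_values = q_head(h_t)
h_prev, c_prev = h_t, c_t  # carry memory to the next step
```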

## 7. Research Papers and Citations

### Implemented Features:

1. **TarMAC (Targeted Multi-Agent Communication)**
   - Das et al., ICML 2019
   - "TarMAC: Targeted Multi-Agent Communication"

2. **Graph Attention Networks**
   - Veličković et al., ICLR 2018
   - "Graph Attention Networks"

3. **Intrinsic Curiosity Module**
   - Pathak et al., ICML 2017
   - "Curiosity-driven Exploration by Self-supervised Prediction"

4. **Deep Recurrent Q-Learning**
   - Hausknecht & Stone, AAAI Fall Symposium 2015
   - "Deep Recurrent Q-Learning for Partially Observable MDPs"

5. **Random Network Distillation**
   - Burda et al., ICLR 2019
   - "Exploration by Random Network Distillation"

### Additional References:
- QMIX (Rashid et al., ICML 2018)
- CommNet (Sukhbaatar et al., NeurIPS 2016)
- MADDPG (Lowe et al., NeurIPS 2017)

## Summary

This notebook demonstrated four cutting-edge features for multi-agent RL:

✅ **Attention-based Communication (TarMAC)** - Learned selective communication

✅ **Graph Neural Networks** - Scalable coordination for many agents

✅ **Recurrent Policies (LSTM)** - Memory for partial observability

✅ **Intrinsic Curiosity Module** - Exploration in sparse reward environments

### Next Steps:
1. Train agents for more episodes to see full convergence
2. Try different hyperparameters
3. Test on other environments (Cooperative Navigation, Predator-Prey)
4. Combine features (e.g., GNN + ICM)
5. Implement additional features (Meta-learning, Hierarchical policies)

### For Your Project/Interview:
- Understand the **mathematical foundations** of each technique
- Be able to explain **when and why** to use each approach
- Discuss **trade-offs** (computation, scalability, performance)
- Mention relevant **research papers** and cite properly