# CA12: Multi-Agent Reinforcement Learning and Advanced Policy Methods

## Deep Reinforcement Learning - Session 12

**Multi-Agent Reinforcement Learning (MARL), Advanced Policy Gradient Methods, and Distributed Training**

This notebook explores advanced reinforcement learning topics including multi-agent systems, sophisticated policy gradient methods, distributed training techniques, and modern approaches to collaborative and competitive learning environments.

### Learning Objectives:
1. Understand multi-agent reinforcement learning fundamentals
2. Implement cooperative and competitive MARL algorithms
3. Master advanced policy gradient methods (PPO, TRPO, SAC variants)
4. Explore distributed training and asynchronous methods
5. Implement communication and coordination mechanisms
6. Understand game-theoretic foundations of MARL
7. Apply meta-learning and few-shot adaptation
8. Analyze emergent behaviors in multi-agent systems

### Notebook Structure:
1. **Multi-Agent Foundations** - Game theory and MARL basics
2. **Cooperative Multi-Agent Learning** - Centralized training, decentralized execution
3. **Competitive and Mixed-Motive Systems** - Self-play and adversarial training
4. **Advanced Policy Methods** - PPO variants, SAC improvements, TRPO
5. **Distributed Reinforcement Learning** - A3C, IMPALA, and modern distributed methods
6. **Communication and Coordination** - Message passing and emergent communication
7. **Meta-Learning in RL** - Few-shot adaptation and transfer learning
8. **Comprehensive Applications** - Real-world multi-agent scenarios

---

In [1]:
# Essential Imports and Advanced Setup for Multi-Agent RL
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal, Categorical, MultivariateNormal, kl_divergence
import torch.multiprocessing as mp
import gymnasium as gym
from gymnasium import spaces
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from collections import defaultdict, deque, namedtuple
import random
import pickle
import json
import copy
import time
import threading
from typing import Tuple, List, Dict, Optional, Union, NamedTuple, Any
import warnings
from dataclasses import dataclass, field
import math
from tqdm import tqdm
from abc import ABC, abstractmethod
import itertools
warnings.filterwarnings('ignore')

# Advanced imports for multi-agent systems
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import DataLoader, Dataset
import networkx as nx
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Game theory and optimization
# (cvxpy was imported here but never used; dropping it removes an optional dependency that breaks the cell)
from scipy.optimize import minimize, linprog
from scipy.special import softmax

# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
random.seed(SEED)

# Device configuration with multi-GPU support
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_gpus = torch.cuda.device_count()
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # autotuning can select different kernels across runs, hurting reproducibility

print(f"🤖 Multi-Agent Reinforcement Learning Environment Setup")
print(f"Device: {device}")
print(f"Available GPUs: {n_gpus}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Advanced plotting configuration
plt.style.use('seaborn-v0_8')
plt.rcParams['figure.figsize'] = (16, 10)
plt.rcParams['font.size'] = 12
plt.rcParams['axes.titlesize'] = 16
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 11
plt.rcParams['ytick.labelsize'] = 11
plt.rcParams['legend.fontsize'] = 11

# Color schemes for multi-agent visualizations
agent_colors = sns.color_palette("Set2", 8)
performance_colors = sns.color_palette("viridis", 6)
sns.set_palette(agent_colors)

# Configuration classes for advanced RL
@dataclass
class MultiAgentConfig:
    """Configuration for multi-agent systems."""
    n_agents: int = 2
    state_dim: int = 10
    action_dim: int = 4
    hidden_dim: int = 128
    lr: float = 3e-4
    gamma: float = 0.99
    tau: float = 0.005
    batch_size: int = 256
    buffer_size: int = 100000
    update_freq: int = 10
    communication: bool = False
    message_dim: int = 32
    coordination_mechanism: str = "centralized"  # centralized, decentralized, mixed

@dataclass 
class PolicyConfig:
    """Configuration for advanced policy methods."""
    algorithm: str = "PPO"  # PPO, TRPO, SAC, DDPG, TD3
    clip_ratio: float = 0.2
    target_kl: float = 0.01
    entropy_coef: float = 0.01
    value_coef: float = 0.5
    max_grad_norm: float = 0.5
    n_epochs: int = 10
    minibatch_size: int = 64
    use_gae: bool = True
    gae_lambda: float = 0.95

# Global configurations
ma_config = MultiAgentConfig()
policy_config = PolicyConfig()

print("✅ Multi-Agent RL environment setup complete!")
print(f"🎯 Configuration: {ma_config.n_agents} agents, {ma_config.coordination_mechanism} coordination")
print("🚀 Ready for advanced multi-agent reinforcement learning!")


# Section 1: Multi-Agent Foundations and Game Theory

## 1.1 Theoretical Foundation

### Multi-Agent Reinforcement Learning (MARL)

Multi-Agent Reinforcement Learning extends single-agent RL to environments with multiple learning agents. Key challenges include:

1. **Non-stationarity**: The environment appears non-stationary from each agent's perspective as other agents learn
2. **Partial observability**: Agents may have limited information about others' actions and observations
3. **Credit assignment**: Determining individual contributions to team rewards
4. **Scalability**: The joint action space $|\mathcal{A}_1| \times \cdots \times |\mathcal{A}_n|$ grows exponentially with the number of agents
5. **Equilibrium concepts**: Finding stable solutions in multi-agent settings

### Game-Theoretic Foundations

**Nash Equilibrium**: A strategy profile where no agent can improve by unilaterally changing strategy.

For agents $i = 1, \ldots, n$ with strategy spaces $S_i$ and utility functions $u_i(s_1, \ldots, s_n)$, where $s_{-i}$ denotes the strategies of all agents other than $i$:
$$s^* = (s_1^*, \ldots, s_n^*) \text{ is a Nash equilibrium if } \forall i,\ \forall s_i \in S_i: u_i(s_i^*, s_{-i}^*) \geq u_i(s_i, s_{-i}^*)$$
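The definition also covers mixed strategies. For 2x2 games, a fully mixed equilibrium follows from the indifference condition: each player randomizes so that the opponent is indifferent between its two actions. The sketch below is a minimal illustration (the function name `mixed_nash_2x2` is our own, not a library API) and only handles the 2x2 case.

```python
import numpy as np

def mixed_nash_2x2(A, B):
    """Fully mixed Nash equilibrium of a 2x2 bimatrix game via the indifference condition.

    A[i, j] / B[i, j]: row / column player's payoff when row plays i and column plays j.
    Returns (p, q): probability that the row player plays action 0 and that the
    column player plays action 0, or None if no fully mixed equilibrium exists.
    """
    A, B = np.asarray(A, dtype=float), np.asarray(B, dtype=float)
    denom_q = A[0, 0] - A[0, 1] - A[1, 0] + A[1, 1]
    denom_p = B[0, 0] - B[1, 0] - B[0, 1] + B[1, 1]
    if np.isclose(denom_q, 0.0) or np.isclose(denom_p, 0.0):
        return None
    # q makes the row player indifferent between its two rows
    q = (A[1, 1] - A[0, 1]) / denom_q
    # p makes the column player indifferent between its two columns
    p = (B[1, 1] - B[1, 0]) / denom_p
    if not (0.0 < p < 1.0 and 0.0 < q < 1.0):
        return None
    return p, q

p, q = mixed_nash_2x2([[1, -1], [-1, 1]], [[-1, 1], [1, -1]])  # matching pennies
print(p, q)  # 0.5 0.5
```

Matching pennies has no pure equilibrium, so this is its only Nash equilibrium. For the Prisoner's Dilemma the function returns `None`: defection is dominant, so no fully mixed equilibrium exists.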

**Pareto Optimality**: A strategy profile is Pareto optimal if no other profile makes at least one agent strictly better off without making any agent worse off.

**Stackelberg Equilibrium**: In a leader-follower game, the leader commits to a strategy first and the follower best-responds to it; the leader chooses its commitment by anticipating that response.
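To make the leader-follower structure concrete, here is a small pure-strategy sketch. It assumes the row player leads, and it breaks follower ties in the leader's favour (the "strong" Stackelberg convention). The function `stackelberg_pure` is a hypothetical helper written for illustration.

```python
import numpy as np

def stackelberg_pure(A, B):
    """Pure-strategy Stackelberg outcome with the row player as leader.

    The leader commits to a row; the follower observes it and best-responds.
    Ties among follower best responses are broken in the leader's favour.
    Returns (leader_action, follower_action).
    """
    A, B = np.asarray(A, dtype=float), np.asarray(B, dtype=float)
    best = None
    for i in range(A.shape[0]):
        # Follower's best responses to leader action i
        responses = np.flatnonzero(B[i] == B[i].max())
        # Among tied responses, pick the one the leader prefers
        j = int(responses[np.argmax(A[i, responses])])
        if best is None or A[i, j] > A[best]:
            best = (i, j)
    return best

A = [[2, 4], [1, 3]]  # leader (row player) payoffs
B = [[1, 0], [0, 1]]  # follower (column player) payoffs
print(stackelberg_pure(A, B))  # (1, 1)
```

In the simultaneous-move version of this game, row 0 strictly dominates, so the unique pure Nash equilibrium is (0, 0) with leader payoff 2. By committing to row 1, the leader induces the follower to play column 1 and earns 3. Commitment power can therefore change the outcome.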

### MARL Paradigms

1. **Independent Learning**: Each agent treats others as part of the environment
2. **Joint Action Learning**: Agents learn about others' actions and adapt accordingly  
3. **Multi-Agent Actor-Critic (MAAC)**: Centralized training with decentralized execution
4. **Communication-Based Learning**: Agents exchange information to coordinate

### Cooperation vs Competition Spectrum

- **Fully Cooperative**: Shared reward, common goal (e.g., team sports)
- **Fully Competitive**: Zero-sum game (e.g., adversarial settings)
- **Mixed-Motive**: Partially cooperative and competitive (e.g., resource sharing)
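For two-player matrix games, the position on this spectrum can be read directly from the payoff matrices. A minimal sketch (the helper name is our own) follows. It only checks exact identical-payoff and zero-sum structure, so a constant-sum game (where $A + B = c$ everywhere) would be labelled mixed-motive even though it is strategically equivalent to zero-sum.

```python
import numpy as np

def classify_game(A, B, atol=1e-9):
    """Place a two-player game on the cooperation/competition spectrum.

    'fully cooperative': both players receive identical payoffs (shared reward)
    'fully competitive': payoffs sum to zero in every cell (zero-sum)
    'mixed-motive':      anything else
    """
    A, B = np.asarray(A, dtype=float), np.asarray(B, dtype=float)
    if np.allclose(A, B, atol=atol):
        return "fully cooperative"
    if np.allclose(A + B, 0.0, atol=atol):
        return "fully competitive"
    return "mixed-motive"

print(classify_game([[2, 0], [0, 1]], [[2, 0], [0, 1]]))        # coordination game
print(classify_game([[1, -1], [-1, 1]], [[-1, 1], [1, -1]]))    # matching pennies
print(classify_game([[-1, -3], [0, -2]], [[-1, 0], [-3, -2]]))  # prisoner's dilemma
```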

### Mathematical Formulation

**Markov Game (Stochastic Game)**, which reduces to a Multi-Agent MDP (MMDP) when all agents share one reward function:
- State space: $\mathcal{S}$
- Joint action space: $\mathcal{A} = \mathcal{A}_1 \times ... \times \mathcal{A}_n$
- Transition dynamics: $P(s'|s, a_1, ..., a_n)$
- Reward functions: $R_i(s, a_1, ..., a_n, s')$ for each agent $i$
- Discount factor: $\gamma \in [0, 1)$
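A tabular two-agent instance of this formulation can be stored as plain arrays. The sketch below is illustrative only. The class name and array layout are our own choices, and it simplifies the reward to its expectation $R_i(s, a_1, a_2)$ instead of $R_i(s, a_1, a_2, s')$.

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class TabularMarkovGame:
    """Two-agent tabular Markov game (S, A_1 x A_2, P, R_1, R_2, gamma).

    P[s, a1, a2, s'] : transition probabilities (each P[s, a1, a2, :] sums to 1)
    R[i, s, a1, a2]  : expected reward of agent i (simplification: no dependence on s')
    """
    P: np.ndarray
    R: np.ndarray
    gamma: float = 0.99

    def __post_init__(self):
        assert np.allclose(self.P.sum(axis=-1), 1.0), "each P[s, a1, a2, :] must be a distribution"
        assert 0.0 <= self.gamma < 1.0

    def step(self, s, a1, a2, rng):
        """Sample s' ~ P(. | s, a1, a2) and return it with both agents' rewards."""
        s_next = rng.choice(self.P.shape[-1], p=self.P[s, a1, a2])
        return int(s_next), self.R[:, s, a1, a2]

    def is_fully_cooperative(self):
        """True when every agent receives the same reward (an MMDP in the strict sense)."""
        return bool(np.allclose(self.R[0], self.R[1]))

# Toy game: next state is a1 XOR a2; both agents are rewarded for matching actions
P = np.zeros((2, 2, 2, 2))
for s in range(2):
    for a1 in range(2):
        for a2 in range(2):
            P[s, a1, a2, a1 ^ a2] = 1.0
R = np.zeros((2, 2, 2, 2))
R[:, :, 0, 0] = R[:, :, 1, 1] = 1.0
game = TabularMarkovGame(P, R, gamma=0.9)
s_next, r = game.step(0, 1, 1, np.random.default_rng(0))
```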

**Policy Gradient in MARL**:
$$\nabla_{\theta_i} J_i(\theta_i) = \mathbb{E}_{\tau \sim \pi_{\theta}}[\sum_{t=0}^T \nabla_{\theta_i} \log \pi_{\theta_i}(a_{i,t}|o_{i,t}) A_i^t]$$

Where $A_i^t$ is agent $i$'s advantage at time $t$, which can be computed using various methods including multi-agent value functions.
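The per-agent gradient above can be computed with one surrogate loss per agent. The minimal PyTorch sketch below uses independent linear softmax policies and random placeholder advantages; in practice $A_i^t$ would come from a (possibly centralized) critic. Each $\theta_i$ receives gradient only from its own agent's log-probability term.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)
n_agents, obs_dim, n_actions, T = 2, 3, 4, 5

# One independent policy pi_{theta_i}(a | o_i) per agent (toy linear softmax policies)
policies = [nn.Linear(obs_dim, n_actions) for _ in range(n_agents)]

# Toy trajectory: per-agent observations, actions, and advantages (random placeholders)
obs = torch.randn(n_agents, T, obs_dim)
actions = torch.randint(0, n_actions, (n_agents, T))
advantages = torch.randn(n_agents, T)

losses = []
for i, policy in enumerate(policies):
    dist = Categorical(logits=policy(obs[i]))
    log_probs = dist.log_prob(actions[i])  # log pi_{theta_i}(a_{i,t} | o_{i,t})
    # Negative policy-gradient objective; advantages are treated as constants
    losses.append(-(log_probs * advantages[i].detach()).sum())

total = sum(losses)
total.backward()  # each theta_i only gets gradient from its own term
```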

---

In [None]:
# Game Theory Utilities and Basic Multi-Agent Framework

class GameTheoryUtils:
    """Utility class for game-theoretic analysis."""
    
    @staticmethod
    def find_nash_equilibria(payoff_matrices):
        """
        Find pure strategy Nash equilibria for n-player games.
        
        Args:
            payoff_matrices: List of payoff matrices, one per player
        Returns:
            List of Nash equilibrium strategy profiles
        """
        n_players = len(payoff_matrices)
        if n_players != 2:
            raise NotImplementedError("Only 2-player games supported")
            
        matrix_a, matrix_b = payoff_matrices[0], payoff_matrices[1]
        nash_equilibria = []
        
        rows, cols = matrix_a.shape
        
        for i in range(rows):
            for j in range(cols):
                # Check if (i,j) is a Nash equilibrium
                is_nash = True
                
                # Check if player 1 wants to deviate
                for i_prime in range(rows):
                    if matrix_a[i_prime, j] > matrix_a[i, j]:
                        is_nash = False
                        break
                
                # Check if player 2 wants to deviate
                if is_nash:
                    for j_prime in range(cols):
                        if matrix_b[i, j_prime] > matrix_b[i, j]:
                            is_nash = False
                            break
                
                if is_nash:
                    nash_equilibria.append((i, j))
        
        return nash_equilibria
    
    @staticmethod
    def is_pareto_optimal(payoff_matrices, strategy_profile):
        """Check if a strategy profile is Pareto optimal."""
        current_payoffs = [matrix[strategy_profile] for matrix in payoff_matrices]
        
        # Check all other strategy profiles
        for profile in itertools.product(*[range(matrix.shape[i]) for i, matrix in enumerate(payoff_matrices)]):
            if profile == strategy_profile:
                continue
                
            candidate_payoffs = [matrix[profile] for matrix in payoff_matrices]
            
            # Check if candidate dominates current
            dominates = True
            strictly_better = False
            
            for i in range(len(current_payoffs)):
                if candidate_payoffs[i] < current_payoffs[i]:
                    dominates = False
                    break
                elif candidate_payoffs[i] > current_payoffs[i]:
                    strictly_better = True
            
            if dominates and strictly_better:
                return False
        
        return True
    
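    @staticmethod
    def compute_best_response(payoff_matrix, opponent_strategy):
        """Compute a pure best response (one-hot vector) to the opponent's mixed strategy."""
        expected_payoffs = payoff_matrix @ opponent_strategy
        # NumPy arrays are mutable, so index assignment replaces the JAX-style .at[].set()
        best_response = np.zeros_like(expected_payoffs, dtype=float)
        best_response[np.argmax(expected_payoffs)] = 1.0
        return best_response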

class MultiAgentEnvironment:
    """Base class for multi-agent environments."""
    
    def __init__(self, n_agents, state_dim, action_dim, cooperative=True):
        self.n_agents = n_agents
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.cooperative = cooperative
        self.state = None
        self.step_count = 0
        self.max_steps = 200
        
    def reset(self):
        """Reset environment to initial state."""
        self.state = np.random.randn(self.state_dim)
        self.step_count = 0
        return [self.state.copy() for _ in range(self.n_agents)]
    
    def step(self, actions):
        """Execute joint action and return next states, rewards, dones."""
        self.step_count += 1
        
        # Simple dynamics: state evolves based on joint action
        joint_action = np.mean(actions, axis=0)
        noise = np.random.randn(self.state_dim) * 0.1
        self.state = 0.9 * self.state + 0.1 * joint_action[:self.state_dim] + noise
        
        # Compute rewards
        if self.cooperative:
            # Cooperative: shared reward based on coordination
            coordination_bonus = -np.mean([np.linalg.norm(actions[i] - joint_action) for i in range(self.n_agents)])
            base_reward = -np.linalg.norm(self.state)  # Drive state to origin
            rewards = [base_reward + coordination_bonus] * self.n_agents
        else:
            # Competitive: individual rewards with competition
            rewards = []
            for i in range(self.n_agents):
                individual_reward = -np.linalg.norm(self.state - actions[i][:self.state_dim])
                competition_penalty = sum([np.linalg.norm(actions[i] - actions[j]) 
                                         for j in range(self.n_agents) if j != i]) * 0.1
                rewards.append(individual_reward - competition_penalty)
        
        done = self.step_count >= self.max_steps
        next_states = [self.state.copy() for _ in range(self.n_agents)]
        
        return next_states, rewards, done
    
    def render(self):
        """Visualize current environment state."""
        pass

# Demonstration of game theory concepts
def demonstrate_game_theory():
    """Demonstrate basic game theory concepts."""
    print("🎯 Game Theory Analysis Demo")
    
    # Prisoner's Dilemma
    print("\n1. Prisoner's Dilemma:")
    # Player 1's payoff matrix (rows: cooperate, defect)
    prisoner_a = np.array([[-1, -3], [0, -2]])  # (cooperate, defect) vs (cooperate, defect)
    # Player 2's payoff matrix 
    prisoner_b = np.array([[-1, 0], [-3, -2]])
    
    print("Player 1 payoff matrix:")
    print(prisoner_a)
    print("Player 2 payoff matrix:")
    print(prisoner_b)
    
    nash_eq = GameTheoryUtils.find_nash_equilibria([prisoner_a, prisoner_b])
    print(f"Nash equilibria: {nash_eq}")
    
    for eq in nash_eq:
        is_pareto = GameTheoryUtils.is_pareto_optimal([prisoner_a, prisoner_b], eq)
        print(f"Strategy {eq}: Pareto optimal = {is_pareto}")
    
    # Coordination Game
    print("\n2. Coordination Game:")
    coord_a = np.array([[2, 0], [0, 1]])
    coord_b = np.array([[2, 0], [0, 1]])
    
    print("Coordination game (both players have same payoffs):")
    print(coord_a)
    
    nash_eq = GameTheoryUtils.find_nash_equilibria([coord_a, coord_b])
    print(f"Nash equilibria: {nash_eq}")
    
    return prisoner_a, prisoner_b, coord_a, coord_b

# Test multi-agent environment
def test_multi_agent_env():
    """Test the basic multi-agent environment."""
    print("\n🤖 Multi-Agent Environment Test")
    
    # Cooperative environment
    print("Testing cooperative environment:")
    coop_env = MultiAgentEnvironment(n_agents=3, state_dim=4, action_dim=4, cooperative=True)
    states = coop_env.reset()
    print(f"Initial states shape: {[s.shape for s in states]}")
    
    # Random actions
    actions = [np.random.randn(coop_env.action_dim) for _ in range(coop_env.n_agents)]
    next_states, rewards, done = coop_env.step(actions)
    
    print(f"Rewards (cooperative): {rewards}")
    print(f"All agents get same reward: {len(set(rewards)) == 1}")
    
    # Competitive environment  
    print("\nTesting competitive environment:")
    comp_env = MultiAgentEnvironment(n_agents=3, state_dim=4, action_dim=4, cooperative=False)
    states = comp_env.reset()
    next_states, rewards, done = comp_env.step(actions)
    
    print(f"Rewards (competitive): {rewards}")
    print(f"Agents get different rewards: {len(set(rewards)) > 1}")
    
    return coop_env, comp_env

# Run demonstrations
game_matrices = demonstrate_game_theory()
environments = test_multi_agent_env()

print("\n✅ Game theory and multi-agent foundations implemented successfully!")

# Section 2: Cooperative Multi-Agent Learning

## 2.1 Centralized Training, Decentralized Execution (CTDE)

The CTDE paradigm is fundamental to modern cooperative MARL:

**Training Phase**: 
- Central coordinator has access to global information
- Can compute joint value functions and coordinate policy updates
- Addresses non-stationarity through centralized critic

**Execution Phase**:
- Each agent acts based on local observations only
- No communication required during deployment
- Maintains scalability and robustness

### Multi-Agent Actor-Critic (MAAC)

**Centralized Critic**: Estimates joint action-value function $Q(s, a_1, ..., a_n)$

**Actor Update**: Each agent $i$ updates policy using centralized critic:
$$\nabla_{\theta_i} J_i = \mathbb{E}[\nabla_{\theta_i} \log \pi_{\theta_i}(a_i|o_i) \cdot Q^{\pi}(s, a_1, ..., a_n)]$$

**Critic Update**: Minimize joint TD error:
$$L(\phi) = \mathbb{E}[(Q_{\phi}(s, a_1, ..., a_n) - y)^2]$$
$$y = r + \gamma Q_{\phi'}(s', \pi_{\theta_1'}(o_1'), ..., \pi_{\theta_n'}(o_n'))$$
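The split between decentralized actors and a centralized critic shows up in the inputs each network sees. Below is a shape-level sketch of the target $y$ with toy linear networks. It assumes the concatenated observations stand in for the global state $s$, which is a common choice when no separate global state is available.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_agents, obs_dim, act_dim, batch = 3, 4, 2, 8
gamma = 0.99

# Decentralized target actors: each sees only its own observation o_i'
target_actors = [nn.Sequential(nn.Linear(obs_dim, act_dim), nn.Tanh()) for _ in range(n_agents)]
# Centralized target critic: sees all observations and all agents' actions
target_critic = nn.Linear(n_agents * obs_dim + n_agents * act_dim, 1)

next_obs = torch.randn(batch, n_agents, obs_dim)  # o_1', ..., o_n' (placeholders)
rewards = torch.randn(batch)
dones = torch.zeros(batch)

with torch.no_grad():
    # a_i' = pi_{theta_i'}(o_i'): each action uses local information only
    next_actions = torch.stack([target_actors[i](next_obs[:, i]) for i in range(n_agents)], dim=1)
    # The critic input concatenates every agent's observation and action
    critic_in = torch.cat([next_obs.flatten(1), next_actions.flatten(1)], dim=-1)
    y = rewards + gamma * (1.0 - dones) * target_critic(critic_in).squeeze(-1)
```

At execution time only `target_actors`-style networks (their online counterparts) are needed, so agents need no communication once training ends.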

### Multi-Agent Deep Deterministic Policy Gradient (MADDPG)

Extension of DDPG to multi-agent settings:

1. **Centralized Critics**: Each agent maintains its own critic that uses global information
2. **Experience Replay**: Shared replay buffer with transitions $(s, a_1, ..., a_n, r_1, ..., r_n, s')$
3. **Target Networks**: Slow-updating target networks for stability

**Critic Loss for Agent $i$**:
$$L_i(\phi_i) = \mathbb{E}[(Q_{\phi_i}(s, a_1, ..., a_n) - y_i)^2]$$
$$y_i = r_i + \gamma Q_{\phi_i'}(s', \mu_{\theta_1'}(o_1'), ..., \mu_{\theta_n'}(o_n'))$$

**Actor Loss for Agent $i$**:
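$$L_i(\theta_i) = -\mathbb{E}\left[Q_{\phi_i}(s, a_1, \ldots, a_n)\big|_{a_i=\mu_{\theta_i}(o_i)}\right]$$

where the other agents' actions $a_j$ ($j \neq i$) are taken from the replay buffer.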

### Counterfactual Multi-Agent Policy Gradients (COMA)

Uses counterfactual reasoning for credit assignment:

**Counterfactual Baseline**:
$$A_i(s, a) = Q(s, a) - \sum_{a_i'} \pi_i(a_i'|o_i) Q(s, a_{-i}, a_i')$$

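The baseline marginalizes out agent $i$'s action while keeping the other agents' actions $a_{-i}$ fixed. $A_i$ therefore measures how much agent $i$'s chosen action improved on its own average behavior, isolating its contribution to the shared team reward.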
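A batched sketch of the counterfactual advantage follows. It assumes a critic that returns $Q(s, (a_{-i}, a_i'))$ for every action $a_i'$ of agent $i$ in a single output vector, with the other agents' actions held fixed. The function name is our own.

```python
import torch

def coma_advantage(q_values, pi, action):
    """Counterfactual advantage for one agent (batched).

    q_values: [B, |A_i|], q_values[b, k] = Q(s, (a_{-i}, k)) with other agents' actions fixed
    pi:       [B, |A_i|], agent i's policy pi_i(. | o_i)
    action:   [B], index of the action agent i actually took
    """
    baseline = (pi * q_values).sum(dim=-1)                          # sum_{a_i'} pi_i(a_i'|o_i) Q(s, (a_{-i}, a_i'))
    q_taken = q_values.gather(-1, action.unsqueeze(-1)).squeeze(-1)  # Q(s, a)
    return q_taken - baseline

q = torch.tensor([[1.0, 3.0, 2.0]])
pi = torch.tensor([[0.2, 0.5, 0.3]])
adv = coma_advantage(q, pi, torch.tensor([1]))  # 3.0 - (0.2 + 1.5 + 0.6), about 0.7
```

Because the baseline does not depend on agent $i$'s own action, the advantage has zero mean under $\pi_i$. Subtracting it reduces variance without biasing the policy gradient.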

### Value Decomposition Networks (VDN)

Decomposes team value function into individual components:
$$Q_{tot}(s, a) = \sum_{i=1}^n Q_i(o_i, a_i)$$

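**Advantages**:
- Each $Q_i$ conditions only on agent $i$'s local observation, so execution is fully decentralized and naturally handles partial observability
- Because $Q_{tot}$ is a sum, each agent greedily maximizing its own $Q_i$ also maximizes $Q_{tot}$
- Training end-to-end on the team reward through $Q_{tot}$ gives implicit credit assignment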

**Limitations**:
- Additivity assumption may be too restrictive
- Cannot represent complex coordination patterns
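The decentralization property of the additive form can be checked numerically. The sketch uses random per-agent utilities as placeholders for learned $Q_i(o_i, a_i)$ at fixed observations. It compares each agent's independent argmax with a brute-force search over all $|\mathcal{A}|^n$ joint actions.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
n_agents, n_actions = 3, 4
# Per-agent utilities Q_i(o_i, a_i) for fixed observations (random placeholders)
q_i = rng.normal(size=(n_agents, n_actions))

def q_tot(joint_action):
    """VDN mixing: Q_tot(s, a) = sum_i Q_i(o_i, a_i)."""
    return sum(q_i[i, a] for i, a in enumerate(joint_action))

# Decentralized execution: each agent maximizes its own Q_i
decentralized = tuple(int(a) for a in q_i.argmax(axis=1))
# Centralized brute force over all 4^3 = 64 joint actions
joint_actions = list(itertools.product(range(n_actions), repeat=n_agents))
centralized = max(joint_actions, key=q_tot)
```

The two agree for any additive decomposition. This is also why VDN cannot express payoffs where one agent's best action depends on another's: such interactions do not decompose into a sum.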

---

In [None]:
# Multi-Agent Deep Deterministic Policy Gradient (MADDPG) Implementation

class Actor(nn.Module):
    """Actor network for MADDPG."""
    
    def __init__(self, obs_dim, action_dim, hidden_dim=128, max_action=1.0):
        super(Actor, self).__init__()
        self.max_action = max_action
        
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh()
        )
    
    def forward(self, obs):
        return self.max_action * self.net(obs)

class Critic(nn.Module):
    """Centralized critic for MADDPG."""
    
    def __init__(self, total_obs_dim, total_action_dim, hidden_dim=128):
        super(Critic, self).__init__()
        
        self.net = nn.Sequential(
            nn.Linear(total_obs_dim + total_action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, obs, actions):
        return self.net(torch.cat([obs, actions], dim=-1))

class MADDPGAgent:
    """Single agent in MADDPG framework."""
    
    def __init__(self, agent_id, obs_dim, action_dim, total_obs_dim, total_action_dim,
                 lr_actor=1e-4, lr_critic=1e-3, gamma=0.99, tau=0.005):
        self.agent_id = agent_id
        self.gamma = gamma
        self.tau = tau
        
        # Networks
        self.actor = Actor(obs_dim, action_dim).to(device)
        self.critic = Critic(total_obs_dim, total_action_dim).to(device)
        self.target_actor = Actor(obs_dim, action_dim).to(device)
        self.target_critic = Critic(total_obs_dim, total_action_dim).to(device)
        
        # Copy parameters to target networks
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.target_critic.load_state_dict(self.critic.state_dict())
        
        # Optimizers
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_critic)
        
        # Noise for exploration
        self.noise_scale = 0.1
        self.noise_decay = 0.9999
    
    def act(self, obs, add_noise=True):
        """Select action given observation."""
        obs = torch.FloatTensor(obs).to(device)
        action = self.actor(obs).cpu().data.numpy()
        
        if add_noise:
            noise = np.random.normal(0, self.noise_scale, size=action.shape)
            action += noise
            self.noise_scale *= self.noise_decay
        
        return np.clip(action, -1, 1)
    
    def update_critic(self, obs, actions, rewards, next_obs, next_actions, dones):
        """Update critic network."""
        obs = torch.FloatTensor(obs).to(device)
        actions = torch.FloatTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        next_obs = torch.FloatTensor(next_obs).to(device)
        next_actions = torch.FloatTensor(next_actions).to(device)
        dones = torch.BoolTensor(dones).to(device)
        
        # Current Q-values
        current_q = self.critic(obs, actions).squeeze()
        
        # Target Q-values
        with torch.no_grad():
            target_q = self.target_critic(next_obs, next_actions).squeeze()
            target_q = rewards + self.gamma * target_q * ~dones
        
        # Critic loss
        critic_loss = F.mse_loss(current_q, target_q)
        
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0)
        self.critic_optimizer.step()
        
        return critic_loss.item()
    
    def update_actor(self, obs, actions):
        """Update actor network."""
        obs = torch.FloatTensor(obs).to(device)
        actions = torch.FloatTensor(actions).to(device)
        
        # Replace this agent's action with current policy
        actions_pred = actions.clone()
        agent_obs = obs[:, self.agent_id]  # This agent's observations
        actions_pred[:, self.agent_id] = self.actor(agent_obs)
        
        # Actor loss: maximize Q-value
        actor_loss = -self.critic(obs.view(obs.size(0), -1), 
                                 actions_pred.view(actions_pred.size(0), -1)).mean()
        
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
        self.actor_optimizer.step()
        
        return actor_loss.item()
    
    def soft_update(self):
        """Soft update of target networks."""
        for target, source in zip(self.target_actor.parameters(), self.actor.parameters()):
            target.data.copy_(self.tau * source.data + (1.0 - self.tau) * target.data)
        
        for target, source in zip(self.target_critic.parameters(), self.critic.parameters()):
            target.data.copy_(self.tau * source.data + (1.0 - self.tau) * target.data)

class MADDPG:
    """Multi-Agent Deep Deterministic Policy Gradient."""
    
    def __init__(self, n_agents, obs_dim, action_dim, buffer_size=100000):
        self.n_agents = n_agents
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        
        total_obs_dim = n_agents * obs_dim
        total_action_dim = n_agents * action_dim
        
        # Create agents
        self.agents = [
            MADDPGAgent(i, obs_dim, action_dim, total_obs_dim, total_action_dim)
            for i in range(n_agents)
        ]
        
        # Replay buffer
        self.replay_buffer = ReplayBuffer(buffer_size)
        
    def act(self, observations, add_noise=True):
        """Get actions from all agents."""
        actions = []
        for i, agent in enumerate(self.agents):
            action = agent.act(observations[i], add_noise)
            actions.append(action)
        return actions
    
    def step(self, states, actions, rewards, next_states, dones):
        """Store experience and update agents."""
        # Store experience
        self.replay_buffer.push(states, actions, rewards, next_states, dones)
        
        # Update agents if enough samples
        if len(self.replay_buffer) > ma_config.batch_size:
            self.update()
    
    def update(self):
        """Update all agents."""
        batch = self.replay_buffer.sample(ma_config.batch_size)
        states, actions, rewards, next_states, dones = batch
        
        # Prepare data for centralized training
        states_flat = np.array(states).reshape(len(states), -1)
        actions_flat = np.array(actions).reshape(len(actions), -1)
        next_states_flat = np.array(next_states).reshape(len(next_states), -1)
        
        # Get next actions from target actors (no gradients are needed for the TD target)
        next_states_t = torch.FloatTensor(np.array(next_states)).to(device)
        with torch.no_grad():
            next_actions = [agent.target_actor(next_states_t[:, i])
                            for i, agent in enumerate(self.agents)]
        
        next_actions_flat = torch.cat(next_actions, dim=-1).cpu().numpy()
        
        # Update each agent
        losses = {'actor': [], 'critic': []}
        for i, agent in enumerate(self.agents):
            agent_rewards = np.array(rewards)[:, i]
            agent_dones = np.array(dones)
            
            # Update critic
            critic_loss = agent.update_critic(
                states_flat, actions_flat, agent_rewards,
                next_states_flat, next_actions_flat, agent_dones
            )
            losses['critic'].append(critic_loss)
            
            # Update actor
            actor_loss = agent.update_actor(np.array(states), np.array(actions))  # [batch, n_agents, dim] arrays
            losses['actor'].append(actor_loss)
            
            # Soft update target networks
            agent.soft_update()
        
        return losses

class ReplayBuffer:
    """Replay buffer for multi-agent experiences."""
    
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, states, actions, rewards, next_states, dones):
        """Store a transition."""
        self.buffer.append((states, actions, rewards, next_states, dones))
    
    def sample(self, batch_size):
        """Sample a batch of transitions."""
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return states, actions, rewards, next_states, dones
    
    def __len__(self):
        return len(self.buffer)

# Value Decomposition Network (VDN) Implementation
class VDNAgent(nn.Module):
    """Individual agent network for VDN."""
    
    def __init__(self, obs_dim, action_dim, hidden_dim=64):
        super(VDNAgent, self).__init__()
        
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
    
    def forward(self, obs):
        return self.net(obs)

class VDN:
    """Value Decomposition Network for cooperative MARL."""
    
    def __init__(self, n_agents, obs_dim, action_dim, lr=1e-3):
        self.n_agents = n_agents
        self.agents = [VDNAgent(obs_dim, action_dim).to(device) for _ in range(n_agents)]
        self.target_agents = [VDNAgent(obs_dim, action_dim).to(device) for _ in range(n_agents)]
        
        # Copy parameters
        for agent, target in zip(self.agents, self.target_agents):
            target.load_state_dict(agent.state_dict())
        
        self.optimizers = [optim.Adam(agent.parameters(), lr=lr) for agent in self.agents]
        self.replay_buffer = ReplayBuffer(10000)
        
    def act(self, observations, epsilon=0.1):
        """Epsilon-greedy action selection."""
        actions = []
        for i, agent in enumerate(self.agents):
            if np.random.random() < epsilon:
                action = np.random.randint(agent.net[-1].out_features)
            else:
                obs = torch.FloatTensor(observations[i]).to(device)
                q_values = agent(obs)
                action = q_values.argmax().item()
            actions.append(action)
        return actions
    
    def update(self, batch_size=32, gamma=0.99, tau=0.005):
        """Update all agents jointly through the additive team value Q_tot = sum_i Q_i."""
        if len(self.replay_buffer) < batch_size:
            return None
        
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
        
        # Team reward (sum of per-agent rewards) and episode-termination flags
        team_rewards = torch.as_tensor(np.array([np.sum(r) for r in rewards]), dtype=torch.float32, device=device)
        team_dones = torch.as_tensor(np.array([np.any(d) for d in dones]), dtype=torch.float32, device=device)
        
        q_tot = torch.zeros(batch_size, device=device)
        next_q_tot = torch.zeros(batch_size, device=device)
        for i, (agent, target_agent) in enumerate(zip(self.agents, self.target_agents)):
            agent_states = torch.as_tensor(np.array([s[i] for s in states]), dtype=torch.float32, device=device)
            agent_actions = torch.as_tensor(np.array([a[i] for a in actions]), dtype=torch.long, device=device)
            agent_next_states = torch.as_tensor(np.array([s[i] for s in next_states]), dtype=torch.float32, device=device)
            
            # Q_i(o_i, a_i) for the action actually taken, summed into Q_tot
            q_tot = q_tot + agent(agent_states).gather(1, agent_actions.unsqueeze(1)).squeeze(1)
            with torch.no_grad():
                next_q_tot = next_q_tot + target_agent(agent_next_states).max(1)[0]
        
        # VDN: one TD target on the summed value (not a separate team-reward target per agent)
        target_q = team_rewards + gamma * next_q_tot * (1.0 - team_dones)
        loss = F.mse_loss(q_tot, target_q)
        
        for optimizer in self.optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in self.optimizers:
            optimizer.step()
        
        # Soft update target networks
        for agent, target_agent in zip(self.agents, self.target_agents):
            for param, target_param in zip(agent.parameters(), target_agent.parameters()):
                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
        
        return loss.item()

print("🤖 Cooperative multi-agent algorithms implemented successfully!")
print("✅ MADDPG, VDN, and supporting utilities ready for training!")

# Section 3: Advanced Policy Gradient Methods

## 3.1 Proximal Policy Optimization (PPO)

PPO limits how far each update can move the policy by clipping the probability ratio. It is a simpler, first-order alternative to TRPO's explicit KL constraint.

### PPO-Clip Objective

**Probability Ratio**:
$$r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$$

**Clipped Objective**:
$$L^{CLIP}(\theta) = \hat{\mathbb{E}}_t[\min(r_t(\theta)A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t)]$$

Where $\epsilon$ is the clipping parameter (typically 0.1-0.3) and $A_t$ is the advantage estimate.
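The clipped objective translates almost line-for-line into PyTorch. A minimal sketch (the function name is our own) that returns the negated surrogate so it can be minimized:

```python
import torch

def ppo_clip_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Negative PPO-Clip surrogate, averaged over the batch (minimize this)."""
    ratio = torch.exp(log_probs - old_log_probs)  # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()

# Three samples: ratio 1.5 with A > 0, ratio 0.5 with A > 0, ratio 0.5 with A < 0
old = torch.zeros(3)
new = torch.log(torch.tensor([1.5, 0.5, 0.5])).requires_grad_(True)
adv = torch.tensor([1.0, 1.0, -1.0])
loss = ppo_clip_loss(new, old, adv)
loss.backward()
```

Taking the minimum makes the objective pessimistic. Samples whose ratio has already moved past $1 \pm \epsilon$ in the direction favoured by the advantage (the first and third samples) contribute zero gradient. The second sample still gets gradient because moving its ratio up would improve the objective.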

### Trust Region Policy Optimization (TRPO)

TRPO constrains policy updates to stay within a trust region:

**Objective**:
$$\max_\theta \hat{\mathbb{E}}_t[\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}A_t]$$

**Subject to**:
$$\hat{\mathbb{E}}_t[KL[\pi_{\theta_{old}}(\cdot|s_t), \pi_\theta(\cdot|s_t)]] \leq \delta$$

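**Conjugate Gradient Solution**:
TRPO linearizes the objective and takes a quadratic approximation of the KL constraint around $\theta_{old}$, which yields a natural-gradient step. Conjugate gradient computes $H^{-1} g$ using only Hessian-vector products, without ever forming $H$. A backtracking line search then enforces the KL constraint and checks that the surrogate improves:
$$g = \nabla_\theta L(\theta)\big|_{\theta=\theta_{old}}$$
$$H = \nabla_\theta^2 \hat{\mathbb{E}}_t[KL[\pi_{\theta_{old}}(\cdot|s_t), \pi_\theta(\cdot|s_t)]]\big|_{\theta=\theta_{old}}$$
$$\theta_{new} = \theta_{old} + \sqrt{\frac{2\delta}{g^T H^{-1} g}} H^{-1} g$$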
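The step above can be sketched in NumPy. In a real TRPO implementation `Hvp` would compute Fisher or KL-Hessian vector products through autograd. Here it is an explicit small matrix so the result can be checked, and the line search is omitted. Function names are our own.

```python
import numpy as np

def conjugate_gradient(Avp, b, n_iters=10, tol=1e-10):
    """Solve A x = b for symmetric positive-definite A using only products A @ v."""
    x = np.zeros_like(b)
    r = b.copy()  # residual b - A x (x = 0 initially)
    p = r.copy()
    rs_old = r @ r
    for _ in range(n_iters):
        Ap = Avp(p)
        alpha = rs_old / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if rs_new < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

def trpo_step(g, Hvp, delta=0.01):
    """Full natural-gradient step sqrt(2 * delta / g^T H^{-1} g) * H^{-1} g."""
    x = conjugate_gradient(Hvp, g)              # x ~= H^{-1} g
    step_size = np.sqrt(2.0 * delta / (g @ x))  # scales the step so 0.5 * s^T H s = delta
    return step_size * x

H = np.array([[2.0, 0.0], [0.0, 8.0]])  # toy KL Hessian
g = np.array([1.0, 1.0])                # toy surrogate gradient
s = trpo_step(g, lambda v: H @ v, delta=0.01)
```

The step size is chosen so that the quadratic KL model $\frac{1}{2} s^T H s$ equals $\delta$ exactly. The line search is needed because the true KL and objective differ from their approximations.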

### Soft Actor-Critic (SAC)

SAC maximizes both expected return and entropy for better exploration:

**Objective**:
$$J(\pi) = \sum_{t=0}^T \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}[r(s_t, a_t) + \alpha \mathcal{H}(\pi(\cdot|s_t))]$$

Where $\alpha$ is the temperature parameter controlling the exploration-exploitation trade-off: a larger $\alpha$ favors higher-entropy, more exploratory policies.

**Soft Q-Function Updates**:
$$J_Q(\phi) = \mathbb{E}_{(s_t, a_t, r_t, s_{t+1}) \sim \mathcal{D}}[\frac{1}{2}(Q_\phi(s_t, a_t) - y_t)^2]$$
$$y_t = r_t + \gamma \mathbb{E}_{a_{t+1} \sim \pi}[Q_{\phi'}(s_{t+1}, a_{t+1}) - \alpha \log \pi(a_{t+1}|s_{t+1})]$$

**Policy Updates**:
$$J_\pi(\theta) = \mathbb{E}_{s_t \sim \mathcal{D}, a_t \sim \pi_\theta}[\alpha \log \pi_\theta(a_t|s_t) - Q_\phi(s_t, a_t)]$$
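A single-sample estimate of the soft target $y_t$ takes only a few lines of PyTorch. Beyond the equation above, this sketch assumes two things: it takes the minimum of two target critics (clipped double-Q, as the `SACAgent` code below also does), and it uses one sampled next action in place of the expectation. The function name is hypothetical:

```python
import torch

def soft_q_target(reward, done, next_q1, next_q2, next_log_prob, alpha=0.2, gamma=0.99):
    """Soft Bellman target y = r + gamma * (1 - d) * [min(Q1', Q2') - alpha * log pi(a'|s')].

    All inputs are 1-D tensors of shape (batch,); `done` is 1.0 for terminal transitions.
    """
    soft_value = torch.min(next_q1, next_q2) - alpha * next_log_prob
    return reward + gamma * (1.0 - done) * soft_value

reward = torch.tensor([1.0, 0.5])
done = torch.tensor([0.0, 1.0])
y = soft_q_target(reward, done,
                  next_q1=torch.tensor([10.0, 7.0]),
                  next_q2=torch.tensor([9.0, 8.0]),
                  next_log_prob=torch.tensor([-1.0, -2.0]))
# First sample: 1 + 0.99 * (9 + 0.2) = 10.108; second sample is terminal, so y = r = 0.5.
print(y)
```

The entropy term enters with a negative log-probability, so a policy that was uncertain about $a_{t+1}$ (low $\log \pi$) raises the target value.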

### Advanced Advantage Estimation

**Generalized Advantage Estimation (GAE)**:
$$A_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^\infty (\gamma\lambda)^l \delta_{t+l}^V$$

Where $\delta_t^V = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the TD error.

GAE balances bias and variance:
- $\lambda = 0$: Low variance, high bias (the one-step TD error $\delta_t$)
- $\lambda = 1$: High variance, low bias (the Monte Carlo return minus $V(s_t)$)
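Both limiting cases of $\lambda$ can be checked numerically. The sketch below computes GAE with the usual backward recursion $A_t = \delta_t + \gamma\lambda A_{t+1}$; the function name and the convention that `dones[t] = 1` means the episode ended after step $t$ are assumptions of this example:

```python
import numpy as np

def gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    """GAE via the backward recursion A_t = delta_t + gamma * lam * (1 - d_t) * A_{t+1}.

    delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t); dones[t] = 1.0 means the
    episode terminated after step t, and last_value is V(s_T) for bootstrapping.
    """
    next_values = np.append(values[1:], last_value)
    advantages = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_values[t] * nonterminal - values[t]
        running = delta + gamma * lam * nonterminal * running
        advantages[t] = running
    return advantages

rewards = np.array([1.0, 0.0, 2.0])
values = np.array([0.5, 0.4, 0.3])
dones = np.zeros(3)
last_value, gamma = 0.6, 0.99

# lambda = 0: advantages are exactly the one-step TD errors delta_t.
td_errors = rewards + gamma * np.append(values[1:], last_value) - values
adv_lam0 = gae(rewards, values, last_value, dones, gamma=gamma, lam=0.0)

# lambda = 1: A_0 is the discounted (bootstrapped) Monte Carlo return minus V(s_0).
mc_return = rewards[0] + gamma * rewards[1] + gamma**2 * rewards[2] + gamma**3 * last_value
adv_lam1 = gae(rewards, values, last_value, dones, gamma=gamma, lam=1.0)
print(np.allclose(adv_lam0, td_errors), np.isclose(adv_lam1[0], mc_return - values[0]))
```

The $\lambda = 1$ case works because the value terms telescope: $\sum_l \gamma^l \delta_{t+l} = \sum_l \gamma^l r_{t+l} + \gamma^T V(s_T) - V(s_t)$.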

### Multi-Agent Policy Gradient Extensions

**Multi-Agent PPO (MAPPO)**:
- Centralized value function: $V(s_1, ..., s_n)$
- Individual actor updates with shared value baseline
- Addresses non-stationarity through centralized training

**Multi-Agent SAC (MASAC)**:
- Individual entropy regularization per agent
- Shared experience replay buffer
- Independent policy and Q-function updates
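The centralized-critic / decentralized-actor split behind MAPPO can be sketched in PyTorch. This is a minimal illustration under stated assumptions: the class names are hypothetical, the critic simply concatenates all agents' observations (a global state vector is a common alternative), and no training loop is shown:

```python
import torch
import torch.nn as nn

class CentralizedCritic(nn.Module):
    """Value network that sees every agent's observation (used only during training)."""
    def __init__(self, obs_dim, n_agents, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim * n_agents, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, joint_obs):
        # joint_obs: (batch, n_agents, obs_dim) -> flatten the agents into one input vector
        return self.net(joint_obs.flatten(start_dim=1)).squeeze(-1)

class DecentralizedActor(nn.Module):
    """Policy pi_i(a | o_i) that only sees its own observation (usable at execution time)."""
    def __init__(self, obs_dim, n_actions, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, obs):
        return torch.distributions.Categorical(logits=self.net(obs))

n_agents, obs_dim, n_actions, batch = 3, 4, 5, 8
critic = CentralizedCritic(obs_dim, n_agents)
actors = [DecentralizedActor(obs_dim, n_actions) for _ in range(n_agents)]

joint_obs = torch.randn(batch, n_agents, obs_dim)
values = critic(joint_obs)                                   # shared baseline, shape (batch,)
actions = [actors[i](joint_obs[:, i]).sample() for i in range(n_agents)]  # each uses only o_i
print(values.shape, [a.shape for a in actions])
```

At execution time only the actors are needed, so the critic's access to all observations does not violate decentralized execution.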

---

In [None]:
# Advanced Policy Gradient Methods Implementation
import copy  # deepcopy is used below to create the SAC target networks

class PPONetwork(nn.Module):
    """Combined actor-critic network for PPO."""
    
    def __init__(self, obs_dim, action_dim, hidden_dim=64, discrete=True):
        super(PPONetwork, self).__init__()
        self.discrete = discrete
        
        # Shared layers
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Actor head
        if discrete:
            self.actor = nn.Linear(hidden_dim, action_dim)
        else:
            self.actor_mean = nn.Linear(hidden_dim, action_dim)
            self.actor_logstd = nn.Parameter(torch.zeros(1, action_dim))
        
        # Critic head
        self.critic = nn.Linear(hidden_dim, 1)
    
    def forward(self, obs):
        shared_features = self.shared(obs)
        value = self.critic(shared_features)
        
        if self.discrete:
            action_logits = self.actor(shared_features)
            return action_logits, value
        else:
            action_mean = self.actor_mean(shared_features)
            action_std = torch.exp(self.actor_logstd.expand_as(action_mean))
            return (action_mean, action_std), value
    
    def get_action_and_value(self, obs, action=None):
        if self.discrete:
            logits, value = self.forward(obs)
            probs = Categorical(logits=logits)
            if action is None:
                action = probs.sample()
            return action, probs.log_prob(action), probs.entropy(), value
        else:
            (mean, std), value = self.forward(obs)
            probs = Normal(mean, std)
            if action is None:
                action = probs.sample()
            return action, probs.log_prob(action).sum(-1), probs.entropy().sum(-1), value

class PPOAgent:
    """Proximal Policy Optimization agent."""
    
    def __init__(self, obs_dim, action_dim, lr=3e-4, discrete=True):
        self.network = PPONetwork(obs_dim, action_dim, discrete=discrete).to(device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr, eps=1e-5)
        self.discrete = discrete
        
        # PPO hyperparameters
        self.clip_coef = 0.2
        self.ent_coef = 0.01
        self.vf_coef = 0.5
        self.max_grad_norm = 0.5
        self.target_kl = 0.01
        
    def get_action_and_value(self, obs, action=None):
        return self.network.get_action_and_value(obs, action)
    
    def update(self, rollouts, n_epochs=10, minibatch_size=64):
        """Update PPO using clipped objective."""
        obs, actions, logprobs, returns, values, advantages = rollouts
        
        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        clipfracs = []
        total_losses = []
        
        for epoch in range(n_epochs):
            # Random minibatches
            indices = torch.randperm(len(obs))
            
            for start in range(0, len(obs), minibatch_size):
                end = start + minibatch_size
                mb_indices = indices[start:end]
                
                mb_obs = obs[mb_indices]
                mb_actions = actions[mb_indices]
                mb_logprobs = logprobs[mb_indices]
                mb_returns = returns[mb_indices]
                mb_values = values[mb_indices]
                mb_advantages = advantages[mb_indices]
                
                # Forward pass
                _, newlogprob, entropy, newvalue = self.get_action_and_value(mb_obs, mb_actions)
                
                # Policy loss
                logratio = newlogprob - mb_logprobs
                ratio = logratio.exp()
                
                with torch.no_grad():
                    # Calculate approximate KL divergence
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs.append(((ratio - 1.0).abs() > self.clip_coef).float().mean().item())
                
                # Clipped surrogate objective
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - self.clip_coef, 1 + self.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()
                
                # Value loss
                # view(-1) keeps shape (batch,) even for a minibatch of size 1 (squeeze() would give a scalar)
                v_loss = F.mse_loss(newvalue.view(-1), mb_returns)
                
                # Entropy loss
                entropy_loss = entropy.mean()
                
                # Total loss
                loss = pg_loss - self.ent_coef * entropy_loss + v_loss * self.vf_coef
                
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
                self.optimizer.step()
                
                total_losses.append(loss.item())
            
            # Early stopping based on KL divergence
            if approx_kl > self.target_kl:
                break
        
        return {
            'total_loss': np.mean(total_losses),
            'policy_loss': pg_loss.item(),
            'value_loss': v_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'approx_kl': approx_kl.item(),
            'clipfrac': np.mean(clipfracs)
        }

class SACAgent:
    """Soft Actor-Critic agent."""
    
    def __init__(self, obs_dim, action_dim, lr=3e-4, alpha=0.2, tau=0.005):
        # Actor network
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        ).to(device)
        
        self.actor_mean = nn.Linear(256, action_dim).to(device)
        self.actor_logstd = nn.Linear(256, action_dim).to(device)
        
        # Q networks
        self.q1 = nn.Sequential(
            nn.Linear(obs_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        ).to(device)
        
        self.q2 = nn.Sequential(
            nn.Linear(obs_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        ).to(device)
        
        # Target Q networks
        self.target_q1 = copy.deepcopy(self.q1)
        self.target_q2 = copy.deepcopy(self.q2)
        
        # Optimizers
        self.actor_optimizer = optim.Adam(list(self.actor.parameters()) + 
                                        list(self.actor_mean.parameters()) + 
                                        list(self.actor_logstd.parameters()), lr=lr)
        self.q1_optimizer = optim.Adam(self.q1.parameters(), lr=lr)
        self.q2_optimizer = optim.Adam(self.q2.parameters(), lr=lr)
        
        # Hyperparameters
        self.alpha = alpha
        self.tau = tau
        self.gamma = 0.99
        
        # Automatic entropy tuning (log_alpha starts at log(alpha) so the initial temperature matches `alpha`)
        self.target_entropy = -action_dim
        self.log_alpha = torch.tensor([np.log(alpha)], dtype=torch.float32, requires_grad=True, device=device)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
    
    def sample(self, obs):
        """Reparameterized tanh-Gaussian sample; returns (action, log_prob) as differentiable tensors."""
        features = self.actor(obs)
        mean = self.actor_mean(features)
        log_std = torch.clamp(self.actor_logstd(features), -20, 2)
        std = torch.exp(log_std)
        
        normal = Normal(mean, std)
        x = normal.rsample()  # Reparameterization trick keeps the sample differentiable w.r.t. the actor
        action = torch.tanh(x)
        
        # Change-of-variables correction for squashing actions into [-1, 1] with tanh
        log_prob = normal.log_prob(x) - torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(-1, keepdim=True)  # sum over action dims (works for batched and single obs)
        return action, log_prob
    
    def get_action(self, obs, deterministic=False):
        """Sample an action for environment interaction; returns (numpy action, log_prob or None)."""
        obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
        with torch.no_grad():
            if deterministic:
                mean = self.actor_mean(self.actor(obs))
                return torch.tanh(mean).cpu().numpy(), None
            action, log_prob = self.sample(obs)
        return action.cpu().numpy(), log_prob
    
    def update(self, batch):
        """Update SAC networks."""
        states, actions, rewards, next_states, dones = batch
        
        states = torch.as_tensor(states, dtype=torch.float32, device=device)
        actions = torch.as_tensor(actions, dtype=torch.float32, device=device)
        # Reshape rewards/dones to (batch, 1) so they broadcast against (batch, 1) Q-values
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device).view(-1, 1)
        next_states = torch.as_tensor(next_states, dtype=torch.float32, device=device)
        dones = torch.as_tensor(dones, dtype=torch.float32, device=device).view(-1, 1)
        
        with torch.no_grad():
            # Next actions and log probabilities from the current policy
            next_actions, next_log_probs = self.sample(next_states)
            
            # Soft target: clipped double-Q minus the entropy term
            target_q1 = self.target_q1(torch.cat([next_states, next_actions], dim=1))
            target_q2 = self.target_q2(torch.cat([next_states, next_actions], dim=1))
            target_q = torch.min(target_q1, target_q2) - self.alpha * next_log_probs
            target_q = rewards + self.gamma * (1 - dones) * target_q
        
        # Current Q-values
        current_q1 = self.q1(torch.cat([states, actions], dim=1))
        current_q2 = self.q2(torch.cat([states, actions], dim=1))
        
        # Q-function losses
        q1_loss = F.mse_loss(current_q1, target_q)
        q2_loss = F.mse_loss(current_q2, target_q)
        
        # Update Q-functions
        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()
        
        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()
        
        # Update policy: actions stay tensors so gradients flow from Q back through the reparameterized sample
        new_actions, log_probs = self.sample(states)
        
        q1_new = self.q1(torch.cat([states, new_actions], dim=1))
        q2_new = self.q2(torch.cat([states, new_actions], dim=1))
        q_new = torch.min(q1_new, q2_new)
        
        actor_loss = (self.alpha * log_probs - q_new).mean()
        
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        # Update alpha (temperature parameter)
        alpha_loss = (-self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
        
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        
        self.alpha = self.log_alpha.exp().item()
        
        # Soft update target networks
        self.soft_update()
        
        return {
            'q1_loss': q1_loss.item(),
            'q2_loss': q2_loss.item(),
            'actor_loss': actor_loss.item(),
            'alpha_loss': alpha_loss.item(),
            'alpha': self.alpha
        }
    
    def soft_update(self):
        """Soft update target networks."""
        for target_param, param in zip(self.target_q1.parameters(), self.q1.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        
        for target_param, param in zip(self.target_q2.parameters(), self.q2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

class GAEBuffer:
    """Buffer for collecting trajectories and computing GAE."""
    
    def __init__(self, size, obs_dim, action_dim, gamma=0.99, gae_lambda=0.95):
        self.size = size
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        
        self.obs = np.zeros((size, obs_dim), dtype=np.float32)
        self.actions = np.zeros((size, action_dim), dtype=np.float32)
        self.rewards = np.zeros(size, dtype=np.float32)
        self.values = np.zeros(size, dtype=np.float32)
        self.logprobs = np.zeros(size, dtype=np.float32)
        self.dones = np.zeros(size, dtype=np.float32)
        
        self.ptr = 0
        self.max_size = size
    
    def store(self, obs, action, reward, value, logprob, done):
        """Store a single transition."""
        self.obs[self.ptr] = obs
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.values[self.ptr] = value
        self.logprobs[self.ptr] = logprob
        self.dones[self.ptr] = done
        
        self.ptr = (self.ptr + 1) % self.max_size
    
    def compute_gae(self, last_value=0):
        """Compute GAE advantages and returns.
        
        Assumes dones[t] = 1 when the episode terminated after transition t, so both the
        bootstrap value and the running GAE sum are cut at that step.
        """
        advantages = np.zeros_like(self.rewards)
        
        last_gae = 0
        for t in reversed(range(self.size)):
            next_nonterminal = 1.0 - self.dones[t]
            next_value = last_value if t == self.size - 1 else self.values[t + 1]
            
            delta = self.rewards[t] + self.gamma * next_value * next_nonterminal - self.values[t]
            advantages[t] = last_gae = delta + self.gamma * self.gae_lambda * next_nonterminal * last_gae
        
        returns = advantages + self.values
        return advantages, returns
    
    def get_batch(self):
        """Get all stored data as tensors."""
        return {
            'obs': torch.FloatTensor(self.obs).to(device),
            'actions': torch.FloatTensor(self.actions).to(device),
            'rewards': torch.FloatTensor(self.rewards).to(device),
            'values': torch.FloatTensor(self.values).to(device),
            'logprobs': torch.FloatTensor(self.logprobs).to(device),
            'dones': torch.FloatTensor(self.dones).to(device)
        }

# Demonstration function
def demonstrate_advanced_policies():
    """Demonstrate advanced policy methods on a simple environment."""
    print("🎯 Advanced Policy Methods Demo")
    
    # Create simple continuous control environment
    obs_dim, action_dim = 4, 2
    
    # PPO demonstration
    print("\n1. PPO Agent:")
    ppo_agent = PPOAgent(obs_dim, action_dim, discrete=False)
    obs = torch.randn(1, obs_dim, device=device)  # keep inputs on the same device as the networks
    action, logprob, entropy, value = ppo_agent.get_action_and_value(obs)
    print(f"PPO Action shape: {action.shape}, Value: {value.item():.3f}")
    
    # SAC demonstration
    print("\n2. SAC Agent:")
    sac_agent = SACAgent(obs_dim, action_dim)
    action, log_prob = sac_agent.get_action(obs.cpu().numpy())  # batch containing one observation
    print(f"SAC Action: {action}, Log Prob: {log_prob.item():.3f}")
    
    print("\n✅ Advanced policy methods demonstrated successfully!")

# Run demonstration
demonstrate_advanced_policies()

print("🚀 Advanced policy gradient methods implemented successfully!")
print("✅ PPO, SAC, and GAE utilities ready for multi-agent training!")

# Section 4: Distributed Reinforcement Learning

## 4.1 Asynchronous Methods

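Distributed RL runs learning across multiple environments and workers in parallel. The main gain is shorter wall-clock training time through higher data throughput; parallel actors also decorrelate the stream of experience, although sample efficiency does not necessarily improve.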

### Asynchronous Advantage Actor-Critic (A3C)

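A3C runs multiple actor-learners in parallel, each interacting with its own environment instance and holding a local copy of the shared parameters.

**Local Gradient Accumulation**:
Each worker $i$ copies the global parameters, acts for up to $t_{max}$ steps, and accumulates
$$d\theta_i = \sum_{t=1}^{t_{max}} \left[\nabla_\theta \log \pi_\theta(a_t|s_t) A_t + \beta \nabla_\theta \mathcal{H}(\pi_\theta(\cdot|s_t))\right]$$
A separate value-loss gradient is accumulated for the critic in the same way.

**Global Network Update**:
Each worker applies its own accumulated gradient to the shared parameters as soon as it is ready, without waiting for or locking out the other workers:
$$\theta_{global} \leftarrow \theta_{global} + \alpha \, d\theta_i$$

Where $A_t$ is computed using n-step returns or GAE.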
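A3C applies each worker's gradient to the shared parameters without locks (Hogwild-style), so a gradient may be computed on a slightly stale copy. The sketch below illustrates only that update pattern: it uses Python threads and a 1-D quadratic loss instead of an actor-critic objective (an assumption for brevity), and because it minimizes a loss it subtracts the gradient, which is the same as ascending the negated objective. The function name is hypothetical:

```python
import threading
import numpy as np

def hogwild_minimize(n_workers=4, steps_per_worker=200, lr=0.05, target=3.0):
    """Lock-free asynchronous SGD on f(theta) = (theta - target)^2 with a shared parameter array."""
    shared = np.zeros(1)  # "global network" parameters, updated in place by every worker

    def worker():
        for _ in range(steps_per_worker):
            local = shared.copy()                        # sync the local copy with the global one
            grad = 2.0 * (local - target)                # gradient on the (possibly stale) local copy
            np.subtract(shared, lr * grad, out=shared)   # apply to the global copy without a lock

    threads = [threading.Thread(target=worker) for _ in range(n_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return float(shared[0])

print(hogwild_minimize())  # close to 3.0 despite unsynchronized updates
```

Updates are written in place with `out=shared`, so every thread reads and modifies the same buffer; occasional stale or overwritten updates slow progress slightly but do not prevent convergence on this well-conditioned problem.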

### IMPALA (Importance Weighted Actor-Learner Architecture)

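IMPALA decouples acting from learning: actors generate trajectories with a behavior policy $\mu$ that can lag several updates behind the learner's current policy $\pi$, so the learner's data is off-policy. V-trace corrects for this lag with truncated importance sampling:

**V-trace Target**:
$$v_t = V(s_t) + \sum_{i=0}^{n-1} \gamma^i \left(\prod_{j=0}^{i-1} c_{t+j}\right) \rho_{t+i} \left[r_{t+i} + \gamma V(s_{t+i+1}) - V(s_{t+i})\right]$$

The empty product for $i = 0$ equals 1.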

**Importance Weights**:
$$\rho_t = \min(\bar{\rho}, \frac{\pi(a_t|s_t)}{\mu(a_t|s_t)})$$
$$c_t = \min(\bar{c}, \frac{\pi(a_t|s_t)}{\mu(a_t|s_t)})$$

Where $\mu$ is the behavior policy and $\pi$ is the target policy.
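The V-trace sum can be computed with an equivalent backward recursion, $v_t - V(s_t) = \rho_t \delta_t + \gamma c_t (v_{t+1} - V(s_{t+1}))$. This NumPy sketch assumes a single trajectory without episode boundaries, and the function name is illustrative. It also checks a sanity property: on-policy ($\pi = \mu$, so every ratio is 1) V-trace reduces to the n-step bootstrapped return:

```python
import numpy as np

def vtrace_targets(rewards, values, bootstrap_value, rhos, gamma=0.99, rho_bar=1.0, c_bar=1.0):
    """V-trace targets via the backward recursion.

    `rhos` are the raw ratios pi(a_t|s_t) / mu(a_t|s_t); terminations are ignored for brevity.
    """
    clipped_rho = np.minimum(rho_bar, rhos)
    clipped_c = np.minimum(c_bar, rhos)
    next_values = np.append(values[1:], bootstrap_value)
    deltas = clipped_rho * (rewards + gamma * next_values - values)
    vs = np.zeros_like(values)
    correction = 0.0  # v_{t+1} - V(s_{t+1}); zero beyond the end of the trajectory
    for t in reversed(range(len(rewards))):
        correction = deltas[t] + gamma * clipped_c[t] * correction
        vs[t] = values[t] + correction
    return vs

rewards = np.array([1.0, 0.0, 1.0])
values = np.array([0.2, 0.1, 0.4])

# On-policy check: with all ratios equal to 1, v_0 is the 3-step bootstrapped return.
vs = vtrace_targets(rewards, values, bootstrap_value=0.5, rhos=np.ones(3), gamma=0.9)
n_step = rewards[0] + 0.9 * rewards[1] + 0.9**2 * rewards[2] + 0.9**3 * 0.5
print(vs[0], n_step)
```

At the other extreme, ratios of 0 (actions the target policy would never take) zero out every correction and leave the targets equal to the current value estimates.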

### Distributed PPO (D-PPO)

Scales PPO to distributed settings by parallelizing rollout collection and synchronizing gradients, so each update behaves like a single-process PPO update on the combined batch:

1. **Rollout Collection**: Workers collect experience in parallel
2. **Gradient Aggregation**: Central server aggregates gradients
3. **Synchronized Updates**: Global policy update after each epoch

**Gradient Synchronization**:
$$g_{global} = \frac{1}{N} \sum_{i=1}^{N} g_i$$

Where $g_i$ is the gradient from worker $i$.
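The synchronization rule can be checked numerically: when every worker computes the gradient of a mean loss on an equal-sized shard, averaging the $N$ worker gradients reproduces the full-batch gradient. In this sketch a linear model with MSE loss stands in for the PPO network and loss (an assumption for brevity); the identity depends only on the loss being a per-sample mean:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)

def grads_for(xb, yb):
    """Gradient of the mean-squared-error loss on one shard, as a list of tensors."""
    model.zero_grad()
    nn.functional.mse_loss(model(xb), yb).backward()
    return [p.grad.clone() for p in model.parameters()]

# Each of N = 4 "workers" computes gradients on its own equal-sized shard of the batch.
worker_grads = [grads_for(xb, yb) for xb, yb in zip(x.chunk(4), y.chunk(4))]

# g_global = (1/N) * sum_i g_i, parameter by parameter (what an averaging all-reduce computes).
g_global = [torch.stack(gs).mean(dim=0) for gs in zip(*worker_grads)]

full_batch = grads_for(x, y)
print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_global, full_batch)))
```

If each worker normalizes advantages using only its own shard, or shards differ in size, the averaged gradient is no longer exactly the full-batch one.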

## 4.2 Evolutionary Strategies (ES) in RL

ES provides gradient-free optimization for RL policies:

**Population-Based Update**:
$$\theta_{t+1} = \theta_t + \alpha \frac{1}{\sigma \lambda} \sum_{i=1}^{\lambda} R_i \epsilon_i$$

Where:
- $\epsilon_i \sim \mathcal{N}(0, I)$ are random perturbations
- $R_i$ is the return achieved by perturbed policy $\theta_t + \sigma \epsilon_i$
- $\lambda$ is the population size
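The population-based update can be sketched directly from the formula. Assumptions in this example: the "return" is a toy fitness function peaked at a known target, and the mean return is subtracted as a baseline before the update (a common variance-reduction choice not shown in the formula). Names are illustrative:

```python
import numpy as np

def es_optimize(fitness, theta, sigma=0.1, pop_size=50, lr=0.05, iters=300, seed=0):
    """Basic ES: theta <- theta + lr / (sigma * pop_size) * sum_i R_i * eps_i."""
    rng = np.random.default_rng(seed)
    for _ in range(iters):
        eps = rng.standard_normal((pop_size, theta.size))               # eps_i ~ N(0, I)
        returns = np.array([fitness(theta + sigma * e) for e in eps])   # R_i at theta + sigma * eps_i
        returns -= returns.mean()  # baseline: lowers variance; E[eps] = 0, so it adds no bias on average
        theta = theta + lr / (sigma * pop_size) * (returns @ eps)
    return theta

target = np.array([1.0, -2.0, 0.5])

def fitness(th):
    """Toy return: highest (zero) at `target`."""
    return -np.sum((th - target) ** 2)

theta = es_optimize(fitness, np.zeros(3))
print(np.round(theta, 2))
```

The `EvolutionaryStrategy.update` method below goes one step further and standardizes the returns, which also makes the step size insensitive to the scale of the reward.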

### Advantages of ES:
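1. **Parallelizable**: Each worker evaluates a different policy perturbation independently
2. **Gradient-free**: Only episode returns are needed, with no backpropagation, so the policy itself need not be differentiable
3. **Robust**: Less sensitive to hyperparameters
4. **Communication efficient**: If workers share random seeds, each can regenerate every perturbation $\epsilon_i$ locally, so only scalar returns need to be exchanged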
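Exchanging only scalars works when every worker can regenerate the others' perturbations, which a shared seed scheme allows. This sketch assumes seeds agreed in advance (for example `base_seed + worker_id`, a hypothetical scheme); the function names are illustrative:

```python
import numpy as np

def perturbation(seed, n_params):
    """Regenerate the noise vector eps for a given integer seed (same seed -> same vector)."""
    return np.random.default_rng(seed).standard_normal(n_params)

def es_step_from_scalars(theta, seeds, returns, sigma=0.1, lr=0.01):
    """Rebuild the ES update from (seed, return) pairs alone, as any worker could."""
    total = sum(R * perturbation(s, theta.size) for s, R in zip(seeds, returns))
    return theta + lr / (len(seeds) * sigma) * total

theta = np.zeros(5)
seeds = [11, 22, 33]          # hypothetical: agreed in advance, never transmitted per step
returns = [1.0, -0.5, 0.25]   # the only numbers each worker has to broadcast

# Two workers that never exchanged noise vectors still compute identical parameters.
theta_worker_a = es_step_from_scalars(theta, seeds, returns)
theta_worker_b = es_step_from_scalars(theta.copy(), seeds, returns)
print(np.array_equal(theta_worker_a, theta_worker_b))
```

The per-step message is one float per worker instead of a full parameter-sized vector, at the cost of each worker regenerating all $\lambda$ noise vectors locally.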

## 4.3 Multi-Agent Distributed Learning

### Centralized Training with Decentralized Execution (CTDE) at Scale

**Hierarchical Coordination**:
- **Global Coordinator**: Manages high-level strategy
- **Local Coordinators**: Handle subgroup coordination
- **Individual Agents**: Execute local policies

**Communication Patterns**:
1. **Broadcast**: Central coordinator broadcasts information to all agents
2. **Reduce**: Agents send information to central coordinator
3. **All-reduce**: All agents receive aggregated information from all others
4. **Ring**: Information flows in a circular pattern
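The four communication patterns can be simulated on plain NumPy vectors. This is a single-process sketch with hypothetical function names; in particular the ring version is the naive one that forwards whole vectors, not the bandwidth-optimal reduce-scatter/all-gather variant used by collective-communication libraries:

```python
import numpy as np

def broadcast(value, n_agents):
    """Coordinator sends the same value to every agent."""
    return [np.copy(value) for _ in range(n_agents)]

def reduce_sum(values):
    """Agents send their values to the coordinator, which sums them."""
    return np.sum(values, axis=0)

def all_reduce_sum(values):
    """Every agent ends up with the sum of all values (here: reduce, then broadcast)."""
    return broadcast(reduce_sum(values), len(values))

def ring_all_reduce_sum(values):
    """Naive ring: each agent forwards the message it last received to its neighbour.

    After n - 1 hops every agent has seen every other agent's original value once.
    """
    n = len(values)
    totals = [np.copy(v) for v in values]
    messages = [np.copy(v) for v in values]
    for _ in range(n - 1):
        messages = [messages[(i - 1) % n] for i in range(n)]  # agent i receives from agent i - 1
        totals = [totals[i] + messages[i] for i in range(n)]
    return totals

values = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0]), np.array([7.0, 8.0])]
print(reduce_sum(values))              # the coordinator's sum
print(ring_all_reduce_sum(values)[0])  # every agent holds the same sum
```

The ring needs no central node, which removes the coordinator bottleneck, but it takes $n - 1$ sequential hops before every agent has the result.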

### Parameter Server Architecture

- **Parameter Server**: Maintains the global model parameters
- **Workers**: Pull parameters, compute gradients, and push updates

**Asynchronous Updates**:
$$\theta_{t+1} = \theta_t - \alpha \sum_{i \in \text{available}} \nabla_i$$

**Advantages**:
- Fault tolerance through redundancy
- Scalable to thousands of workers
- Flexible resource allocation
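Asynchronous updates have a cost: a worker's gradient is computed on parameters that may be several versions old by the time the server applies it. A deterministic toy model (an assumption, not a real parameter server) makes this visible: every gradient applied at step $t$ was computed at $\theta_{t-\tau}$, and the objective is $f(\theta) = (\theta - 3)^2$. The function name is hypothetical:

```python
def async_sgd_with_staleness(lr, staleness, steps=200, target=3.0, theta0=0.0):
    """theta_{t+1} = theta_t - lr * grad f(theta_{t - staleness}) for f(theta) = (theta - target)^2."""
    history = [theta0] * (staleness + 1)   # treat the first `staleness` versions as theta0
    for _ in range(steps):
        stale_theta = history[-1 - staleness]
        grad = 2.0 * (stale_theta - target)
        history.append(history[-1] - lr * grad)
    return history[-1]

print(async_sgd_with_staleness(lr=0.5, staleness=0))  # fresh gradients: exact in one step
print(async_sgd_with_staleness(lr=0.1, staleness=2))  # smaller step tolerates staleness 2
print(async_sgd_with_staleness(lr=0.5, staleness=2))  # same step as the first run, now diverges
```

With no staleness a step size of 0.5 solves the quadratic in one update, yet the same step size diverges once gradients are two versions old. Bounding staleness or lowering the learning rate restores convergence.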

---

In [None]:
# Distributed Reinforcement Learning Implementation

from multiprocessing import Process, Queue, Value, Array
import queue
import threading
from threading import Lock
import time

class ParameterServer:
    """Parameter server for distributed RL."""
    
    def __init__(self, model_state_dict):
        self.params = {k: v.clone().share_memory_() for k, v in model_state_dict.items()}
        self.lock = Lock()  # threading.Lock: serializes updates from threads within this process only
        self.version = Value('i', 0)
        self.update_count = Value('i', 0)
    
    def get_parameters(self):
        """Get current parameters."""
        with self.lock:
            return {k: v.clone() for k, v in self.params.items()}, self.version.value
    
    def update_parameters(self, gradients, lr=1e-4):
        """Update parameters with gradients."""
        with self.lock:
            for key, grad in gradients.items():
                if key in self.params:
                    self.params[key] -= lr * grad
            
            self.version.value += 1
            self.update_count.value += 1
    
    def get_stats(self):
        """Get server statistics."""
        return {
            'version': self.version.value,
            'updates': self.update_count.value
        }

class A3CWorker:
    """A3C worker for distributed training (assumes a discrete-action Gymnasium environment)."""
    
    def __init__(self, worker_id, global_model, local_model, env_fn, gamma=0.99, n_steps=5):
        self.worker_id = worker_id
        self.global_model = global_model
        self.local_model = local_model
        self.env = env_fn()
        self.gamma = gamma
        self.n_steps = n_steps
        # Simplification: each worker keeps its own Adam state over the shared global parameters
        self.optimizer = optim.Adam(global_model.parameters(), lr=1e-4)
        # Keep the current state between calls so episodes continue across n-step rollouts
        self.state, _ = self.env.reset()
        
    def compute_n_step_returns(self, rewards, next_value, dones):
        """Compute n-step bootstrapped returns, cutting the bootstrap at episode ends."""
        returns = []
        R = next_value
        
        for i in reversed(range(len(rewards))):
            R = rewards[i] + self.gamma * R * (1 - dones[i])
            returns.insert(0, R)
        
        return returns
    
    def train_step(self):
        """Single n-step training step for the A3C worker."""
        # Sync local model with global model
        self.local_model.load_state_dict(self.global_model.state_dict())
        
        states, actions, rewards, dones = [], [], [], []
        done = False
        
        # Act for up to n_steps; no gradients are needed while collecting data
        for _ in range(self.n_steps):
            state_tensor = torch.as_tensor(self.state, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                logits, _ = self.local_model(state_tensor)
                action = Categorical(logits=logits).sample().item()
            
            # Gymnasium API: step returns (obs, reward, terminated, truncated, info)
            next_state, reward, terminated, truncated, _ = self.env.step(action)
            done = terminated or truncated  # simplification: truncation handled like termination
            
            states.append(self.state)
            actions.append(action)
            rewards.append(reward)
            dones.append(float(done))
            
            if done:
                self.state, _ = self.env.reset()
                break
            self.state = next_state
        
        # Bootstrap from V(s_{t+n}) unless the episode ended
        with torch.no_grad():
            if done:
                next_value = 0.0
            else:
                state_tensor = torch.as_tensor(self.state, dtype=torch.float32).unsqueeze(0)
                _, next_value = self.local_model(state_tensor)
                next_value = next_value.item()
        
        returns = torch.tensor(self.compute_n_step_returns(rewards, next_value, dones),
                               dtype=torch.float32)
        
        # Re-evaluate the rollout WITH gradients (the acting pass above ran under no_grad)
        states = torch.as_tensor(np.array(states), dtype=torch.float32)
        actions = torch.as_tensor(actions, dtype=torch.long)
        logits, values = self.local_model(states)
        values = values.view(-1)
        dist = Categorical(logits=logits)
        log_probs = dist.log_prob(actions)
        
        # Actor loss with advantages treated as constants
        advantages = returns - values.detach()
        actor_loss = -(log_probs * advantages).mean()
        
        # Critic loss
        critic_loss = F.mse_loss(values, returns)
        
        # Entropy bonus
        entropy = dist.entropy().mean()
        
        total_loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy
        
        # Compute gradients on the local model (clear gradients left over from the previous step)
        self.local_model.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.local_model.parameters(), 40)
        
        # Copy local gradients onto the global parameters, then step the shared model
        self.optimizer.zero_grad()
        for global_param, local_param in zip(self.global_model.parameters(),
                                             self.local_model.parameters()):
            if local_param.grad is not None:
                global_param.grad = local_param.grad.clone()
        
        self.optimizer.step()
        
        return {
            'total_loss': total_loss.item(),
            'actor_loss': actor_loss.item(),
            'critic_loss': critic_loss.item(),
            'entropy': entropy.item()
        }

class IMPALALearner:
    """IMPALA learner with V-trace correction (single trajectory, episode boundaries ignored)."""
    
    def __init__(self, model, lr=1e-4, rho_bar=1.0, c_bar=1.0):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.rho_bar = rho_bar  # Truncation of rho_t: sets the V-trace fixed point and weights the policy gradient
        self.c_bar = c_bar      # Truncation of trace coefficients c_t: controls how far corrections propagate
        
    def vtrace(self, rewards, values, behavior_log_probs, target_log_probs, bootstrap_value, gamma=0.99):
        """Compute V-trace targets v_s and policy-gradient advantages (both treated as constants)."""
        with torch.no_grad():
            # Importance sampling ratios pi / mu
            rhos = torch.exp(target_log_probs - behavior_log_probs)
            clipped_rhos = torch.clamp(rhos, max=self.rho_bar)
            clipped_cs = torch.clamp(rhos, max=self.c_bar)
            
            # TD errors weighted by the truncated ratios
            values_t_plus_1 = torch.cat([values[1:], bootstrap_value.view(1)])
            deltas = clipped_rhos * (rewards + gamma * values_t_plus_1 - values)
            
            # Backward recursion: v_s - V(x_s) = delta_s + gamma * c_s * (v_{s+1} - V(x_{s+1}))
            vs = []
            v_s = values[-1] + deltas[-1]
            vs.append(v_s)
            for i in reversed(range(len(deltas) - 1)):
                v_s = values[i] + deltas[i] + gamma * clipped_cs[i] * (v_s - values_t_plus_1[i])
                vs.append(v_s)
            vs.reverse()
            vs = torch.stack(vs)
            
            # Policy-gradient advantage: rho_s * (r_s + gamma * v_{s+1} - V(x_s))
            vs_t_plus_1 = torch.cat([vs[1:], bootstrap_value.view(1)])
            pg_advantages = clipped_rhos * (rewards + gamma * vs_t_plus_1 - values)
        return vs, pg_advantages
    
    def update(self, batch):
        """Update IMPALA learner."""
        states, actions, rewards, behavior_log_probs, bootstrap_value = batch
        
        # Forward pass
        logits, values = self.model(states)
        values = values.view(-1)
        
        # Log probabilities of the behavior actions under the current (target) policy
        log_probs_all = F.log_softmax(logits, dim=-1)
        target_log_probs = log_probs_all.gather(1, actions.long().unsqueeze(-1)).squeeze(-1)
        
        # V-trace targets and policy-gradient advantages (computed without gradients)
        vtrace_targets, pg_advantages = self.vtrace(rewards, values, behavior_log_probs,
                                                    target_log_probs, bootstrap_value)
        
        # Losses
        policy_loss = -(target_log_probs * pg_advantages).mean()
        value_loss = F.mse_loss(values, vtrace_targets)
        
        # Entropy regularization
        entropy = -(log_probs_all.exp() * log_probs_all).sum(-1).mean()
        
        total_loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
        
        # Update
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 40)
        self.optimizer.step()
        
        return {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'entropy': entropy.item(),
            'total_loss': total_loss.item()
        }

class DistributedPPOCoordinator:
    """Coordinator for distributed PPO training."""
    
    def __init__(self, n_workers, obs_dim, action_dim, lr=3e-4):
        self.n_workers = n_workers
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        
        # Global model
        self.global_model = PPONetwork(obs_dim, action_dim, discrete=True)
        self.optimizer = optim.Adam(self.global_model.parameters(), lr=lr)
        
        # Communication queues
        self.task_queues = [Queue() for _ in range(n_workers)]
        self.result_queue = Queue()
        
        # Training statistics
        self.episode_rewards = []
        self.losses = []
    
    def collect_rollouts(self, n_steps=128):
        """Coordinate rollout collection across workers."""
        # Send collection tasks to workers
        for i in range(self.n_workers):
            self.task_queues[i].put(('collect', n_steps))
        
        # Collect results
        all_rollouts = []
        for _ in range(self.n_workers):
            rollouts = self.result_queue.get()
            all_rollouts.append(rollouts)
        
        return all_rollouts
    
    def aggregate_rollouts(self, rollouts_list):
        """Aggregate rollouts from all workers."""
        aggregated = {
            'obs': [],
            'actions': [],
            'rewards': [],
            'values': [],
            'log_probs': [],
            'advantages': [],
            'returns': []
        }
        
        for rollouts in rollouts_list:
            for key in aggregated:
                aggregated[key].extend(rollouts[key])
        
        # Convert to tensors
        for key in aggregated:
            aggregated[key] = torch.FloatTensor(aggregated[key])
        
        return aggregated
    
    def update_global_model(self, rollouts):
        """Update global model using aggregated rollouts."""
        ppo_agent = PPOAgent(self.obs_dim, self.action_dim)
        ppo_agent.network = self.global_model
        ppo_agent.optimizer = self.optimizer
        
        # Prepare rollouts for PPO update
        obs = rollouts['obs']
        actions = rollouts['actions']
        log_probs = rollouts['log_probs']
        returns = rollouts['returns']
        values = rollouts['values']
        advantages = rollouts['advantages']
        
        ppo_rollouts = (obs, actions, log_probs, returns, values, advantages)
        losses = ppo_agent.update(ppo_rollouts)
        
        return losses
    
    def broadcast_parameters(self):
        """Send updated parameters to all workers."""
        state_dict = self.global_model.state_dict()
        for i in range(self.n_workers):
            self.task_queues[i].put(('update_params', state_dict))

class EvolutionaryStrategy:
    """Simple evolutionary strategy for RL."""
    
    def __init__(self, model, population_size=50, sigma=0.1, lr=0.01):
        self.model = model
        self.population_size = population_size
        self.sigma = sigma
        self.lr = lr
        
        # Get parameter shapes
        self.param_shapes = []
        self.param_sizes = []
        for param in model.parameters():
            self.param_shapes.append(param.shape)
            self.param_sizes.append(param.numel())
        
        self.total_params = sum(self.param_sizes)
    
    def generate_population(self):
        """Generate population of parameter perturbations."""
        return [np.random.randn(self.total_params) for _ in range(self.population_size)]
    
    def set_parameters(self, flat_params):
        """Set model parameters from flattened array."""
        idx = 0
        with torch.no_grad():
            for param, size, shape in zip(self.model.parameters(), self.param_sizes, self.param_shapes):
                param_values = flat_params[idx:idx+size].reshape(shape)
                param.copy_(torch.FloatTensor(param_values))
                idx += size
    
    def get_parameters(self):
        """Get flattened model parameters."""
        params = []
        for param in self.model.parameters():
            params.append(param.detach().cpu().numpy().flatten())
        return np.concatenate(params)
    
    def update(self, rewards, perturbations):
        """Update parameters using ES."""
        # Normalize rewards
        rewards = np.array(rewards)
        rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-8)
        
        # Compute parameter update
        current_params = self.get_parameters()
        param_update = np.zeros_like(current_params)
        
        for reward, perturbation in zip(rewards, perturbations):
            param_update += reward * perturbation
        
        param_update = self.lr * param_update / (self.population_size * self.sigma)
        
        # Update parameters
        new_params = current_params + param_update
        self.set_parameters(new_params)
        
        return param_update

# Demonstration functions
def demonstrate_parameter_server():
    """Demonstrate parameter server functionality."""
    print("🖥️  Parameter Server Demo")
    
    # Create dummy model
    model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
    
    # Initialize parameter server
    param_server = ParameterServer(model.state_dict())
    
    print(f"Initial version: {param_server.get_stats()['version']}")
    
    # Simulate gradient update
    dummy_gradients = {name: torch.randn_like(param) for name, param in model.named_parameters()}
    param_server.update_parameters(dummy_gradients)
    
    print(f"After update: {param_server.get_stats()}")
    
    return param_server

def demonstrate_evolutionary_strategy():
    """Demonstrate evolutionary strategy."""
    print("\n🧬 Evolutionary Strategy Demo")
    
    # Create simple model
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    es = EvolutionaryStrategy(model, population_size=10, sigma=0.1)
    
    # Generate population
    population = es.generate_population()
    print(f"Generated population of size: {len(population)}")
    print(f"Parameter dimensionality: {es.total_params}")
    
    # Placeholder fitness: random rewards stand in for evaluating theta + sigma * eps_i
    rewards = np.random.randn(len(population))
    es.update(rewards, population)
    
    print("✅ ES update completed")
    
    return es

# Run demonstrations
print("🌐 Distributed Reinforcement Learning Systems")
param_server_demo = demonstrate_parameter_server()
es_demo = demonstrate_evolutionary_strategy()

print("\n🚀 Distributed RL implementations ready!")
print("✅ Parameter server, A3C, IMPALA, and ES components implemented!")
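
`EvolutionaryStrategy.update` assumes each reward was measured at the perturbed parameters $\theta + \sigma\epsilon_i$, whereas the demo above feeds it random placeholder rewards. The standalone NumPy sketch below shows the complete loop on a toy objective. The helper name `es_optimize`, the quadratic fitness, and the hyperparameters are illustrative assumptions; the update rule itself (standardized rewards, step $\frac{\text{lr}}{N\sigma}\sum_i F_i \epsilon_i$) mirrors the class.

```python
import numpy as np

def es_optimize(fitness_fn, theta0, sigma=0.1, lr=0.01, population_size=50,
                iterations=300, seed=0):
    """Hypothetical helper: OpenAI-style ES on a flat parameter vector."""
    rng = np.random.default_rng(seed)
    theta = np.array(theta0, dtype=float)
    for _ in range(iterations):
        # Sample N perturbation directions eps_i ~ N(0, I)
        eps = rng.standard_normal((population_size, theta.size))
        # Evaluate fitness at the perturbed parameters theta + sigma * eps_i
        rewards = np.array([fitness_fn(theta + sigma * e) for e in eps])
        # Standardize rewards, as EvolutionaryStrategy.update does
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
        # Gradient estimate: (1 / (N * sigma)) * sum_i F_i * eps_i
        theta += lr / (population_size * sigma) * (eps.T @ rewards)
    return theta

# Toy fitness: negative squared distance to a target parameter vector
target = np.array([1.0, -2.0, 0.5])
fitness = lambda params: -np.sum((params - target) ** 2)

theta = es_optimize(fitness, theta0=np.zeros(3))
print(np.round(theta, 2))  # close to target; exact digits depend on the seed
```

With standardized rewards the step length is roughly bounded by `lr / sigma` per iteration, which is why the toy problem converges steadily instead of diverging on large raw fitness values.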

# Section 5: Communication and Coordination in Multi-Agent Systems

## 5.1 Communication Protocols

Multi-agent systems often require sophisticated communication mechanisms to achieve coordination and share information effectively. This section explores various communication paradigms and their implementation in reinforcement learning contexts.

### Communication Types:
1. **Direct Communication**: Explicit message passing between agents
2. **Emergent Communication**: Learned communication protocols through RL
3. **Indirect Communication**: Environment-mediated information sharing
4. **Broadcast vs. Targeted**: Communication scope and recipients

### Mathematical Framework:
For agent $i$ sending message $m_i^t$ at time $t$:
$$m_i^t = \text{CommPolicy}_i(s_i^t, h_i^t)$$

Where $h_i^t$ is agent $i$'s communication history. Each receiving agent $j$ then conditions its policy on the messages sent by the other agents:
$$\pi_j(a_j^t | s_j^t, \{m_k^t\}_{k \neq j})$$
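
The two equations can be sketched in PyTorch as follows. This is a minimal illustration under assumptions the text does not fix: messages are continuous vectors, the history $h_i^t$ is a `GRUCell` state updated with the mean of incoming messages, and agent $j$ aggregates $\{m_k^t\}_{k \neq j}$ by averaging. `CommAgent` and its methods are hypothetical names.

```python
import torch
import torch.nn as nn

class CommAgent(nn.Module):
    """Hypothetical agent: m_i = CommPolicy_i(s_i, h_i) and pi_j(a | s_j, {m_k}_{k != j})."""

    def __init__(self, obs_dim, msg_dim, n_actions):
        super().__init__()
        self.comm_policy = nn.Linear(obs_dim + msg_dim, msg_dim)    # CommPolicy_i
        self.history_cell = nn.GRUCell(msg_dim, msg_dim)            # updates h_i
        self.action_head = nn.Linear(obs_dim + msg_dim, n_actions)  # pi_j

    def message(self, s, h):
        # m_i^t from the local state s_i^t and communication history h_i^t
        return torch.tanh(self.comm_policy(torch.cat([s, h], dim=-1)))

    def action_dist(self, s, incoming):
        # incoming: [n_agents - 1, msg_dim]; aggregate by averaging (an assumption)
        agg = incoming.mean(dim=0)
        logits = self.action_head(torch.cat([s, agg], dim=-1))
        return torch.distributions.Categorical(logits=logits)

    def update_history(self, incoming, h):
        # GRUCell is called with a batch dimension of 1, which is then removed
        return self.history_cell(incoming.mean(dim=0, keepdim=True), h.unsqueeze(0)).squeeze(0)

torch.manual_seed(0)
n_agents, obs_dim, msg_dim, n_actions = 3, 5, 4, 2
agents = [CommAgent(obs_dim, msg_dim, n_actions) for _ in range(n_agents)]
states = torch.randn(n_agents, obs_dim)
histories = [torch.zeros(msg_dim) for _ in range(n_agents)]

with torch.no_grad():
    # Every agent emits a message at time t
    messages = torch.stack([agents[i].message(states[i], histories[i]) for i in range(n_agents)])
    actions = []
    for j in range(n_agents):
        others = torch.cat([messages[:j], messages[j + 1:]])  # {m_k^t} for k != j
        actions.append(agents[j].action_dist(states[j], others).sample().item())
        histories[j] = agents[j].update_history(others, histories[j])  # h_j^{t+1}

print(messages.shape, actions)
```

Averaging keeps the policy input size fixed regardless of team size; attention over messages is a common alternative when agents need to weight senders differently.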

### Key Challenges:
- **Communication Overhead**: Balancing information sharing with computational cost
- **Partial Observability**: Deciding what information to communicate
- **Communication Noise**: Handling unreliable communication channels
- **Scalability**: Maintaining efficiency as the number of agents increases

## 5.2 Coordination Mechanisms

### Centralized Coordination:
- Global coordinator makes joint decisions
- Can compute jointly optimal actions, but the joint action space grows exponentially with the number of agents
- Single point of failure

### Decentralized Coordination:
- Agents coordinate through local interactions
- Scalable and robust
- May lead to suboptimal solutions

### Hierarchical Coordination:
- Multi-level coordination structure
- Combines benefits of centralized and decentralized approaches
- Natural for many real-world scenarios

### Market-Based Coordination:
- Agents bid for tasks or resources
- Economically motivated coordination
- Natural load balancing
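
The `MarketBasedCoordination` class in the next cell awards each task to its highest bidder independently, so one agent can win several tasks. When each agent can take at most one task, a simple alternative is a greedy sequential auction: repeatedly award the highest remaining bid and remove that agent and task from later rounds. The helper below is a hypothetical sketch; greedy assignment is not guaranteed to maximize total value (the Hungarian algorithm does that for one-to-one assignment).

```python
import torch

def greedy_unit_demand_auction(bids):
    """Hypothetical helper: each agent wins at most one task.

    bids: [n_agents, n_tasks] tensor. Returns {task_index: agent_index}.
    """
    remaining = bids.clone().float()
    n_agents, n_tasks = remaining.shape
    assignments = {}
    for _ in range(min(n_agents, n_tasks)):
        # argmax without dim returns an index into the flattened matrix
        flat_idx = torch.argmax(remaining).item()
        agent, task = divmod(flat_idx, n_tasks)
        assignments[task] = agent
        # Remove the winning agent and the awarded task from further rounds
        remaining[agent, :] = float('-inf')
        remaining[:, task] = float('-inf')
    return assignments

bids = torch.tensor([[5.0, 4.0],
                     [3.0, 1.0]])
print(torch.argmax(bids, dim=0))          # per-task auction: agent 0 wins both tasks
print(greedy_unit_demand_auction(bids))   # {0: 0, 1: 1}
```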

In [None]:
# Communication and Coordination Implementation

class CommunicationChannel:
    """Communication channel for multi-agent systems."""
    
    def __init__(self, n_agents, message_dim=16, noise_std=0.1):
        self.n_agents = n_agents
        self.message_dim = message_dim
        self.noise_std = noise_std
        self.message_history = []
        
    def send_message(self, sender_id, message, recipients=None):
        """Send message from one agent to others."""
        if recipients is None:
            recipients = list(range(self.n_agents))
            recipients.remove(sender_id)
        
        # Add noise to simulate real-world communication
        noisy_message = message + torch.randn_like(message) * self.noise_std
        
        comm_event = {
            'sender': sender_id,
            'recipients': recipients,
            'message': noisy_message,
            'timestamp': len(self.message_history)
        }
        
        self.message_history.append(comm_event)
        return comm_event
    
    def get_messages_for_agent(self, agent_id, last_n=5):
        """Get the last_n messages addressed to a specific agent (oldest first)."""
        relevant_messages = []
        # Scan newest-first so last_n counts this agent's messages, not all events
        for event in reversed(self.message_history):
            if agent_id in event['recipients']:
                relevant_messages.append({
                    'sender': event['sender'],
                    'message': event['message'],
                    'timestamp': event['timestamp']
                })
                if len(relevant_messages) == last_n:
                    break
        return relevant_messages[::-1]
    
    def clear_history(self):
        """Clear communication history."""
        self.message_history = []

class AttentionCommunication(nn.Module):
    """Attention-based communication mechanism."""
    
    def __init__(self, obs_dim, message_dim=16, n_heads=4):
        super().__init__()
        self.obs_dim = obs_dim
        self.message_dim = message_dim
        self.n_heads = n_heads
        
        # Message encoding
        self.message_encoder = nn.Sequential(
            nn.Linear(obs_dim, message_dim),
            nn.ReLU(),
            nn.Linear(message_dim, message_dim)
        )
        
        # Attention mechanism
        self.attention = nn.MultiheadAttention(message_dim, n_heads, batch_first=True)
        
        # Message processing
        self.message_processor = nn.Sequential(
            nn.Linear(message_dim, message_dim),
            nn.ReLU(),
            nn.Linear(message_dim, message_dim)
        )
    
    def forward(self, observations, messages=None):
        """
        Args:
            observations: [batch_size, n_agents, obs_dim]
            messages: [batch_size, n_agents, message_dim] or None
        """
        batch_size, n_agents, _ = observations.shape
        
        # Encode observations into messages
        encoded_messages = self.message_encoder(observations)  # [batch, n_agents, message_dim]
        
        if messages is not None:
            # Combine with previous messages
            combined_messages = encoded_messages + messages
        else:
            combined_messages = encoded_messages
        
        # Apply attention across agents
        attended_messages, attention_weights = self.attention(
            combined_messages, combined_messages, combined_messages
        )
        
        # Process attended messages
        processed_messages = self.message_processor(attended_messages)
        
        return processed_messages, attention_weights

class CoordinationMechanism:
    """Base class for coordination mechanisms."""
    
    def __init__(self, n_agents):
        self.n_agents = n_agents
        self.coordination_history = []
    
    def coordinate(self, agent_states, task_requirements):
        """Coordinate agents based on states and task requirements."""
        raise NotImplementedError
    
    def evaluate_coordination(self, joint_actions, outcomes):
        """Evaluate the quality of coordination."""
        raise NotImplementedError

class MarketBasedCoordination(CoordinationMechanism):
    """Market-based coordination using auction mechanisms."""
    
    def __init__(self, n_agents, n_tasks=5):
        super().__init__(n_agents)
        self.n_tasks = n_tasks
        self.task_values = torch.rand(n_tasks) * 10  # Task values
        
    def conduct_auction(self, agent_bids):
        """
        Conduct a first-price sealed-bid auction independently for each task.
        
        Each task goes to its highest bidder, so one agent may win several tasks.
        
        Args:
            agent_bids: [n_agents, n_tasks] - bid matrix
        
        Returns:
            task_assignments: [n_tasks] - winning agent for each task
            winning_bids: [n_tasks] - winning bid amounts
        """
        # torch.max along dim 0 returns (values, indices) in a single call
        winning_bids, winning_agents = torch.max(agent_bids, dim=0)
        
        return winning_agents, winning_bids
    
    def coordinate(self, agent_capabilities, task_requirements):
        """Coordinate using market mechanism."""
        # Generate bids based on capabilities and task requirements
        agent_bids = torch.zeros(self.n_agents, self.n_tasks)
        
        for i in range(self.n_agents):
            for j in range(self.n_tasks):
                # Simple bidding strategy: capability match * task value - cost
                capability_match = torch.dot(agent_capabilities[i], task_requirements[j])
                cost = torch.norm(agent_capabilities[i] - task_requirements[j])
                agent_bids[i, j] = capability_match * self.task_values[j] - cost
        
        # Conduct auction
        assignments, winning_bids = self.conduct_auction(agent_bids)
        
        coordination_result = {
            'assignments': assignments,
            'bids': agent_bids,
            'winning_bids': winning_bids,
            'total_value': torch.sum(winning_bids)
        }
        
        self.coordination_history.append(coordination_result)
        return coordination_result

class HierarchicalCoordination(CoordinationMechanism):
    """Hierarchical coordination with multiple levels."""
    
    def __init__(self, n_agents, hierarchy_levels=2):
        super().__init__(n_agents)
        self.hierarchy_levels = hierarchy_levels
        self.create_hierarchy()
    
    def create_hierarchy(self):
        """Create hierarchical structure."""
        self.hierarchy = {}
        agents_per_level = [self.n_agents]
        
        for level in range(self.hierarchy_levels):
            agents_at_level = max(1, agents_per_level[-1] // 2)
            agents_per_level.append(agents_at_level)
            
            self.hierarchy[level] = {
                'coordinators': list(range(agents_at_level)),
                'subordinates': list(range(agents_per_level[level]))
            }
    
    def coordinate_level(self, level, agent_states):
        """Coordinate agents at specific hierarchy level."""
        if level >= self.hierarchy_levels:
            return agent_states
        
        coordinators = self.hierarchy[level]['coordinators']
        subordinates = self.hierarchy[level]['subordinates']
        
        # High-level coordination decisions
        coordination_decisions = []
        for coordinator_id in coordinators:
            # Simple coordination: average subordinate states
            subordinate_indices = subordinates[coordinator_id::len(coordinators)]
            if subordinate_indices:
                avg_state = torch.mean(agent_states[subordinate_indices], dim=0)
                coordination_decisions.append(avg_state)
            else:
                coordination_decisions.append(torch.zeros_like(agent_states[0]))
        
        return torch.stack(coordination_decisions)
    
    def coordinate(self, agent_states, global_objective):
        """Hierarchical coordination process."""
        current_states = agent_states
        coordination_trace = []
        
        for level in range(self.hierarchy_levels):
            level_decisions = self.coordinate_level(level, current_states)
            coordination_trace.append(level_decisions)
            current_states = level_decisions
        
        # Final global decision
        global_decision = torch.mean(current_states, dim=0)
        
        return {
            'global_decision': global_decision,
            'level_decisions': coordination_trace,
            'hierarchy': self.hierarchy
        }

class EmergentCommunicationAgent(nn.Module):
    """Agent that learns to communicate through RL."""
    
    def __init__(self, obs_dim, action_dim, message_dim=8, vocab_size=16):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.message_dim = message_dim
        self.vocab_size = vocab_size
        
        # Observation encoding
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32)
        )
        
        # Message generation
        self.message_generator = nn.Sequential(
            nn.Linear(32, message_dim),
            nn.ReLU(),
            nn.Linear(message_dim, vocab_size)
        )
        
        # Message interpretation
        self.message_interpreter = nn.Sequential(
            nn.Linear(vocab_size, message_dim),
            nn.ReLU(),
            nn.Linear(message_dim, 16)
        )
        
        # Action policy (considering messages)
        self.action_policy = nn.Sequential(
            nn.Linear(32 + 16, 64),  # obs_encoding + message_interpretation
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )
        
        # Value function
        self.value_function = nn.Sequential(
            nn.Linear(32 + 16, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
    
    def generate_message(self, obs):
        """Generate message based on observation."""
        obs_encoding = self.obs_encoder(obs)
        message_logits = self.message_generator(obs_encoding)
        
        # Sample message from categorical distribution
        message_dist = Categorical(logits=message_logits)
        message = message_dist.sample()
        message_log_prob = message_dist.log_prob(message)
        
        return message, message_log_prob
    
    def interpret_messages(self, messages):
        """Interpret received messages."""
        # Convert discrete messages to one-hot
        one_hot_messages = F.one_hot(messages, self.vocab_size).float()
        
        # Average messages from multiple agents
        if len(one_hot_messages.shape) > 1:
            avg_message = torch.mean(one_hot_messages, dim=0)
        else:
            avg_message = one_hot_messages
        
        return self.message_interpreter(avg_message)
    
    def forward(self, obs, received_messages=None):
        """Forward pass considering observations and messages."""
        obs_encoding = self.obs_encoder(obs)
        
        if received_messages is not None:
            message_info = self.interpret_messages(received_messages)
        else:
            # No messages: zero vector matching the interpreter's 16-dim output,
            # with the same batch shape and device as the observation encoding
            message_info = obs_encoding.new_zeros(*obs_encoding.shape[:-1], 16)
        combined_input = torch.cat([obs_encoding, message_info], dim=-1)
        
        # Generate action probabilities
        action_logits = self.action_policy(combined_input)
        action_probs = F.softmax(action_logits, dim=-1)
        
        # Generate value estimate
        value = self.value_function(combined_input)
        
        return action_probs, value

# Demonstration functions
def demonstrate_communication():
    """Demonstrate communication mechanisms."""
    print("📡 Communication Mechanisms Demo")
    
    # Initialize communication channel
    comm_channel = CommunicationChannel(n_agents=4, message_dim=8)
    
    # Simulate message exchange
    message = torch.randn(8)
    comm_event = comm_channel.send_message(sender_id=0, message=message, recipients=[1, 2, 3])
    
    print(f"Message sent from agent 0 to agents {comm_event['recipients']}")
    print(f"Message shape: {comm_event['message'].shape}")
    
    # Get messages for specific agent
    messages = comm_channel.get_messages_for_agent(agent_id=1)
    print(f"Agent 1 received {len(messages)} messages")
    
    return comm_channel

def demonstrate_coordination():
    """Demonstrate coordination mechanisms."""
    print("\n🤝 Coordination Mechanisms Demo")
    
    # Market-based coordination
    market_coord = MarketBasedCoordination(n_agents=4, n_tasks=3)
    
    # Generate random agent capabilities and task requirements
    agent_capabilities = torch.randn(4, 5)
    task_requirements = torch.randn(3, 5)
    
    coordination_result = market_coord.coordinate(agent_capabilities, task_requirements)
    
    print("Market-based coordination result:")
    print(f"Task assignments: {coordination_result['assignments']}")
    print(f"Total value: {coordination_result['total_value']:.2f}")
    
    # Hierarchical coordination
    hierarchical_coord = HierarchicalCoordination(n_agents=8, hierarchy_levels=2)
    agent_states = torch.randn(8, 6)
    
    hierarchy_result = hierarchical_coord.coordinate(agent_states, global_objective=None)
    print(f"\nHierarchical coordination levels: {len(hierarchy_result['level_decisions'])}")
    print(f"Global decision shape: {hierarchy_result['global_decision'].shape}")
    
    return market_coord, hierarchical_coord

def demonstrate_emergent_communication():
    """Demonstrate emergent communication."""
    print("\n🗣️  Emergent Communication Demo")
    
    # Create emergent communication agent
    agent = EmergentCommunicationAgent(obs_dim=10, action_dim=4, message_dim=8, vocab_size=16)
    
    # Generate observation
    obs = torch.randn(10)
    
    # Generate message
    message, message_log_prob = agent.generate_message(obs)
    print(f"Generated message: {message.item()}, log prob: {message_log_prob.item():.3f}")
    
    # Forward pass with message
    action_probs, value = agent(obs, received_messages=torch.tensor([message]))
    print(f"Action probabilities shape: {action_probs.shape}")
    print(f"Value estimate: {value.item():.3f}")
    
    return agent

# Run demonstrations
print("🌐 Communication and Coordination Systems")
comm_demo = demonstrate_communication()
coord_demo = demonstrate_coordination()
emergent_demo = demonstrate_emergent_communication()

print("\n🚀 Communication and coordination implementations ready!")
print("✅ Multi-agent communication, coordination, and emergent protocols implemented!")

# Section 6: Meta-Learning and Adaptation in Multi-Agent Systems

## 6.1 Meta-Learning Foundations

Meta-learning, or "learning to learn," is particularly important in multi-agent systems where agents must quickly adapt to:
- New opponent strategies
- Changing team compositions  
- Novel task distributions
- Dynamic environment conditions

### Mathematical Framework:
Given a distribution of tasks $\mathcal{T}$, MAML-style meta-learning with a single inner gradient step seeks parameters $\theta$ that perform well *after* one adaptation step on each task:
$$\theta^* = \arg\min_\theta \mathbb{E}_{\tau \sim \mathcal{T}} \left[ \mathcal{L}_\tau(\theta - \alpha \nabla_\theta \mathcal{L}_\tau(\theta)) \right]$$

Where $\alpha$ is the inner learning rate and $\mathcal{L}_\tau$ is the loss on task $\tau$.
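
The objective can be checked on a toy problem where everything is analytic. Assume, purely for illustration, scalar tasks with loss $\mathcal{L}_\tau(\theta) = (\theta - c_\tau)^2$. Then $\theta - \alpha \nabla_\theta \mathcal{L}_\tau(\theta) - c_\tau = (1 - 2\alpha)(\theta - c_\tau)$, so the meta-objective equals $(1-2\alpha)^2\,\mathbb{E}_\tau[(\theta - c_\tau)^2]$ and is minimized at $\theta^* = \mathbb{E}_\tau[c_\tau]$. The sketch differentiates through the inner step with `create_graph=True` (second-order MAML); the task centers and step sizes are arbitrary choices.

```python
import torch

def task_loss(theta, c):
    # Toy per-task loss L_tau(theta) = (theta - c_tau)^2
    return (theta - c) ** 2

def maml_meta_loss(theta, task_centers, alpha):
    losses = []
    for c in task_centers:
        inner_loss = task_loss(theta, c)
        # create_graph=True keeps the inner gradient differentiable (second-order MAML)
        (grad,) = torch.autograd.grad(inner_loss, theta, create_graph=True)
        phi = theta - alpha * grad          # one inner adaptation step
        losses.append(task_loss(phi, c))    # loss after adaptation
    return torch.stack(losses).mean()

task_centers = torch.tensor([1.0, 2.0, 6.0])
alpha = 0.1
theta = torch.tensor(0.0, requires_grad=True)
meta_optimizer = torch.optim.SGD([theta], lr=0.5)

for _ in range(100):
    meta_optimizer.zero_grad()
    meta_loss = maml_meta_loss(theta, task_centers, alpha)
    meta_loss.backward()
    meta_optimizer.step()

print(round(theta.item(), 3))  # 3.0, the mean of the task centers
```

The meta-gradient here is $2(1-2\alpha)^2(\theta - \bar{c})$: the inner step shrinks every task's error by $(1-2\alpha)$, and the outer loop still drives $\theta$ to the point from which one step adapts best on average.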

## 6.2 Model-Agnostic Meta-Learning (MAML) for Multi-Agent Systems

MAML can be extended to multi-agent settings where agents must quickly adapt their policies to new scenarios:

### Multi-Agent MAML Objective:
$$\min_{\theta_1, ..., \theta_n} \sum_{i=1}^n \mathbb{E}_{\tau \sim \mathcal{T}} \left[ \mathcal{L}_{\tau,i}(\phi_{i,\tau}) \right]$$

Where $\phi_{i,\tau} = \theta_i - \alpha_i \nabla_{\theta_i} \mathcal{L}_{\tau,i}(\theta_i)$

## 6.3 Few-Shot Learning in Multi-Agent Contexts

### Key Challenges:
1. **Opponent Modeling**: Quickly learning opponent behavior patterns
2. **Team Formation**: Adapting to new team compositions
3. **Strategy Transfer**: Applying learned strategies to new scenarios
4. **Communication Adaptation**: Adjusting communication protocols

### Applications:
- **Multi-Agent Navigation**: Adapting to new environments with different agents
- **Competitive Games**: Quickly learning counter-strategies
- **Cooperative Tasks**: Forming effective teams with unknown agents
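
For small discrete action spaces, fast opponent modeling can be as simple as keeping Dirichlet-smoothed counts of the opponent's actions and best-responding to the predicted mixture. This is a swapped-in, count-based technique rather than the neural `OpponentModel` implemented later in this section; the class name, the prior, and the rock-paper-scissors payoff matrix are illustrative assumptions.

```python
import numpy as np

class DirichletOpponentModel:
    """Hypothetical count-based opponent model with a symmetric Dirichlet prior."""

    def __init__(self, n_actions, prior=1.0):
        self.counts = np.full(n_actions, prior, dtype=float)

    def update(self, opponent_action):
        self.counts[opponent_action] += 1.0

    def predict(self):
        # Posterior mean of the opponent's mixed strategy
        return self.counts / self.counts.sum()

def best_response(payoff, opponent_probs):
    # payoff[a, b]: my reward for playing a when the opponent plays b
    return int(np.argmax(payoff @ opponent_probs))

# Rock-paper-scissors payoffs for actions (rock, paper, scissors)
rps = np.array([[ 0, -1,  1],
                [ 1,  0, -1],
                [-1,  1,  0]], dtype=float)

model = DirichletOpponentModel(n_actions=3)
for _ in range(5):
    model.update(0)  # the opponent keeps playing rock

print(model.predict())                      # probabilities 0.75, 0.125, 0.125
print(best_response(rps, model.predict()))  # 1 (paper)
```

The prior acts as pseudo-counts: a larger `prior` makes the model slower to commit to an exploitative response, which trades adaptation speed for robustness against opponents who switch strategies.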

## 6.4 Continual Learning in Dynamic Multi-Agent Environments

### Catastrophic Forgetting Problem:
In multi-agent systems, agents may forget how to handle previously encountered opponents or scenarios when learning new ones.

### Solutions:
1. **Elastic Weight Consolidation (EWC)**: Protect important parameters
2. **Progressive Networks**: Expand capacity for new tasks
3. **Memory-Augmented Networks**: Store and replay important experiences
4. **Meta-Learning**: Learn how to quickly adapt without forgetting
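
As a concrete instance of solution 1, EWC adds the quadratic penalty $\frac{\lambda}{2}\sum_k F_k(\theta_k - \theta^*_{A,k})^2$ to the new task's loss, where $\theta^*_A$ are the parameters after the previous task and $F$ is a diagonal Fisher information estimate. The sketch below uses the empirical Fisher (squared per-sample log-likelihood gradients at the observed labels) on a small classifier; the function names, the random data, and $\lambda = 100$ are assumptions.

```python
import torch
import torch.nn as nn

def estimate_diag_fisher(model, inputs, targets):
    """Empirical diagonal Fisher: mean of squared per-sample log-likelihood gradients."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    for x, y in zip(inputs, targets):
        model.zero_grad()
        log_probs = torch.log_softmax(model(x.unsqueeze(0)), dim=-1)
        log_probs[0, y].backward()
        for n, p in model.named_parameters():
            fisher[n] += p.grad.detach() ** 2
    return {n: f / len(inputs) for n, f in fisher.items()}

def ewc_penalty(model, fisher, old_params, lam=100.0):
    """(lam / 2) * sum_k F_k * (theta_k - theta*_k)^2"""
    penalty = 0.0
    for n, p in model.named_parameters():
        penalty = penalty + (fisher[n] * (p - old_params[n]) ** 2).sum()
    return 0.5 * lam * penalty

torch.manual_seed(0)
model = nn.Linear(4, 3)
inputs, targets = torch.randn(16, 4), torch.randint(0, 3, (16,))

# After training on task A: snapshot parameters and their estimated importance
fisher = estimate_diag_fisher(model, inputs, targets)
old_params = {n: p.detach().clone() for n, p in model.named_parameters()}

# While training on task B one would minimize: task_b_loss + ewc_penalty(model, fisher, old_params)
print(ewc_penalty(model, fisher, old_params).item())  # 0.0 at the old parameters
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.1)
print(ewc_penalty(model, fisher, old_params).item())  # positive once parameters drift
```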

## 6.5 Self-Play and Population-Based Training

### Self-Play Evolution:
Agents improve by playing against previous versions of themselves or a diverse population of strategies.

### Population Diversity:
$$\text{Diversity} = \mathbb{E}_{\pi_i, \pi_j \sim P} [D(\pi_i, \pi_j)]$$

Where $P$ is the population and $D$ measures strategic distance between policies.
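
One concrete, assumed choice for $D$ is the mean total-variation distance between two policies' action distributions over a fixed set of probe states, with the expectation taken over distinct pairs in the population. The function below is a hypothetical sketch; the `compute_diversity` method in the next cell instead uses parameter-space distance, which is cheaper but does not directly measure behavior.

```python
import numpy as np

def population_diversity(action_probs):
    """Hypothetical helper.

    action_probs: [n_policies, n_probe_states, n_actions], each row a distribution.
    Returns the mean over distinct pairs (i, j) of the average total-variation
    distance between pi_i and pi_j on the probe states.
    """
    n_policies = action_probs.shape[0]
    distances = []
    for i in range(n_policies):
        for j in range(n_policies):
            if i != j:
                tv = 0.5 * np.abs(action_probs[i] - action_probs[j]).sum(axis=-1)
                distances.append(tv.mean())
    return float(np.mean(distances))

identical = np.array([[[0.5, 0.5]], [[0.5, 0.5]]])
opposite = np.array([[[1.0, 0.0]], [[0.0, 1.0]]])
print(population_diversity(identical))  # 0.0
print(population_diversity(opposite))   # 1.0
```

Total variation lies in $[0, 1]$, so this score is comparable across populations of different sizes, unlike a raw parameter-norm distance whose scale depends on network size.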

### Benefits:
- Robust strategy development
- Automatic curriculum generation
- Exploration of diverse play styles
- Prevention of exploitation vulnerabilities

In [None]:
# Meta-Learning and Adaptation Implementation

import copy
from collections import defaultdict

class MAMLAgent(nn.Module):
    """Multi-Agent Model-Agnostic Meta-Learning Agent."""
    
    def __init__(self, obs_dim, action_dim, hidden_dim=128, meta_lr=1e-3, inner_lr=1e-2):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.meta_lr = meta_lr
        self.inner_lr = inner_lr
        
        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), 
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        # Meta-optimizer
        self.meta_optimizer = optim.Adam(self.parameters(), lr=meta_lr)
        
    def forward(self, obs):
        """Forward pass."""
        policy_logits = self.policy_net(obs)
        value = self.value_net(obs)
        return F.softmax(policy_logits, dim=-1), value
    
    def _actor_critic_loss(self, model, batch, gamma=0.99):
        """One-step actor-critic loss shared by the inner and outer loops."""
        obs, actions, rewards, next_obs, dones = batch
        action_probs, values = model(obs)
        values = values.squeeze(-1)                        # [batch]
        with torch.no_grad():
            next_values = model(next_obs)[1].squeeze(-1)   # [batch], bootstrap target
        # Squeezing avoids broadcasting [batch] against [batch, 1] into [batch, batch]
        targets = rewards + gamma * next_values * (1 - dones)
        advantages = (targets - values).detach()
        log_probs = torch.log(action_probs.gather(1, actions.unsqueeze(1)).squeeze(1) + 1e-8)
        policy_loss = -(log_probs * advantages).mean()
        value_loss = F.mse_loss(values, targets)
        return policy_loss + 0.5 * value_loss
    
    def inner_update(self, support_batch, num_steps=5):
        """Perform inner-loop adaptation on a task's support batch."""
        # Adapt a temporary copy so the meta-parameters stay untouched
        adapted_model = copy.deepcopy(self)
        inner_optimizer = optim.SGD(adapted_model.parameters(), lr=self.inner_lr)
        
        for _ in range(num_steps):
            loss = self._actor_critic_loss(adapted_model, support_batch)
            inner_optimizer.zero_grad()
            loss.backward()
            inner_optimizer.step()
        
        return adapted_model
    
    def meta_update(self, tasks_batch):
        """First-order MAML (FOMAML) meta-update over a batch of tasks.
        
        inner_update adapts a deep copy with a standard optimizer, which cuts the
        autograd graph back to self.parameters(); calling backward() on the query
        loss would leave the meta-parameters without gradients. Instead, query-loss
        gradients taken at the adapted parameters are averaged over tasks and
        applied to the meta-parameters (the first-order approximation).
        """
        self.meta_optimizer.zero_grad()
        meta_losses = []
        
        for support_batch, query_batch in tasks_batch:
            # Inner adaptation, then evaluation on the query set
            adapted_model = self.inner_update(support_batch)
            meta_loss = self._actor_critic_loss(adapted_model, query_batch)
            
            # Gradients w.r.t. the adapted copy (same parameter order as self)
            grads = torch.autograd.grad(meta_loss, list(adapted_model.parameters()))
            for param, grad in zip(self.parameters(), grads):
                grad = grad / len(tasks_batch)
                param.grad = grad.clone() if param.grad is None else param.grad + grad
            
            meta_losses.append(meta_loss.item())
        
        self.meta_optimizer.step()
        return float(np.mean(meta_losses))

class OpponentModel(nn.Module):
    """Model for predicting opponent behavior."""
    
    def __init__(self, obs_dim, action_dim, opponent_action_dim, hidden_dim=64):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.opponent_action_dim = opponent_action_dim
        
        # Opponent policy predictor
        self.opponent_predictor = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, opponent_action_dim)
        )
        
        # Confidence estimator (update_model does not train this head; fit it separately for meaningful outputs)
        self.confidence_net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
        self.optimizer = optim.Adam(self.parameters(), lr=1e-3)
        self.history = []
    
    def predict_opponent_action(self, obs, my_action):
        """Predict opponent action given observation and my action."""
        input_tensor = torch.cat([obs, F.one_hot(my_action, self.action_dim).float()], dim=-1)
        opponent_logits = self.opponent_predictor(input_tensor)
        confidence = self.confidence_net(input_tensor)
        
        return F.softmax(opponent_logits, dim=-1), confidence
    
    def update_model(self, obs, my_action, opponent_action):
        """Update opponent model with observed behavior."""
        input_tensor = torch.cat([obs, F.one_hot(my_action, self.action_dim).float()], dim=-1)
        predicted_logits = self.opponent_predictor(input_tensor)
        
        # Cross-entropy loss for opponent action prediction
        loss = F.cross_entropy(predicted_logits, opponent_action)
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Store in history
        self.history.append({
            'obs': obs.detach(),
            'my_action': my_action,
            'opponent_action': opponent_action,
            'loss': loss.item()
        })
        
        return loss.item()
    
    def get_adaptation_speed(self):
        """Compute how quickly the model is adapting."""
        if len(self.history) < 10:
            return 0.0
        
        recent_losses = [h['loss'] for h in self.history[-10:]]
        early_losses = [h['loss'] for h in self.history[-20:-10]] if len(self.history) >= 20 else recent_losses
        
        return max(0, np.mean(early_losses) - np.mean(recent_losses))

class PopulationBasedTraining:
    """Population-based training for multi-agent systems."""
    
    def __init__(self, agent_class, population_size=8, mutation_rate=0.1):
        self.agent_class = agent_class
        self.population_size = population_size
        self.mutation_rate = mutation_rate
        self.generation = 0
        
        # Initialize population
        self.population = []
        self.fitness_scores = []
        self.diversity_scores = []
        
        for i in range(population_size):
            agent = agent_class()
            self.population.append(agent)
            self.fitness_scores.append(0.0)
            self.diversity_scores.append(0.0)
    
    def evaluate_fitness(self, agent_idx, opponents, n_games=10):
        """Evaluate agent fitness against opponents."""
        agent = self.population[agent_idx]
        total_reward = 0
        
        for _ in range(n_games):
            # Simple evaluation: random game outcome influenced by agent capability
            # In practice, this would be actual game playing
            game_reward = torch.randn(1).item() + agent_idx * 0.1  # Placeholder
            total_reward += game_reward
        
        avg_fitness = total_reward / n_games
        self.fitness_scores[agent_idx] = avg_fitness
        
        return avg_fitness
    
    def compute_diversity(self, agent_idx):
        """Compute diversity of agent compared to population."""
        agent = self.population[agent_idx]
        diversity_sum = 0
        
        for other_idx, other_agent in enumerate(self.population):
            if other_idx != agent_idx:
                # Simple diversity metric: parameter distance
                param_distance = 0
                for p1, p2 in zip(agent.parameters(), other_agent.parameters()):
                    param_distance += torch.norm(p1 - p2).item()
                diversity_sum += param_distance
        
        avg_diversity = diversity_sum / (self.population_size - 1)
        self.diversity_scores[agent_idx] = avg_diversity
        
        return avg_diversity
    
    def select_parents(self, selection_pressure=0.7):
        """Select parents for next generation."""
        # Combine fitness and diversity scores
        combined_scores = []
        for i in range(self.population_size):
            score = selection_pressure * self.fitness_scores[i] + (1 - selection_pressure) * self.diversity_scores[i]
            combined_scores.append(score)
        
        # Tournament selection
        parents = []
        for _ in range(self.population_size // 2):
            tournament_size = 3
            tournament_indices = np.random.choice(self.population_size, tournament_size, replace=False)
            winner = tournament_indices[np.argmax([combined_scores[i] for i in tournament_indices])]
            parents.append(winner)
        
        return parents
    
    def mutate_agent(self, agent):
        """Mutate agent parameters."""
        mutated_agent = copy.deepcopy(agent)
        
        for param in mutated_agent.parameters():
            if torch.rand(1).item() < self.mutation_rate:
                noise = torch.randn_like(param) * 0.1
                param.data += noise
        
        return mutated_agent
    
    def evolve_generation(self):
        """Evolve population for one generation."""
        # Evaluate all agents
        for i in range(self.population_size):
            self.evaluate_fitness(i, opponents=list(range(self.population_size)))
            self.compute_diversity(i)
        
        # Select parents
        parent_indices = self.select_parents()
        
        # Create next generation
        new_population = []
        
        # Keep top performers
        top_performers = sorted(range(self.population_size), 
                              key=lambda x: self.fitness_scores[x], reverse=True)[:2]
        
        for idx in top_performers:
            new_population.append(copy.deepcopy(self.population[idx]))
        
        # Generate offspring
        while len(new_population) < self.population_size:
            parent_idx = np.random.choice(parent_indices)
            parent = self.population[parent_idx]
            offspring = self.mutate_agent(parent)
            new_population.append(offspring)
        
        # Update population
        self.population = new_population
        self.generation += 1
        
        return {
            'generation': self.generation,
            'avg_fitness': np.mean(self.fitness_scores),
            'max_fitness': np.max(self.fitness_scores),
            'avg_diversity': np.mean(self.diversity_scores)
        }

class SelfPlayTraining:
    """Self-play training system."""
    
    def __init__(self, agent, env, save_frequency=10):
        self.agent = agent
        self.env = env
        self.save_frequency = save_frequency
        
        # Historical opponents (checkpoints)
        self.historical_opponents = []
        self.training_iteration = 0
        
    def add_checkpoint(self):
        """Add current agent as historical opponent."""
        checkpoint = copy.deepcopy(self.agent)
        self.historical_opponents.append({
            'agent': checkpoint,
            'iteration': self.training_iteration,
            'performance': 0.0
        })
        
        # Limit number of historical opponents
        if len(self.historical_opponents) > 20:
            self.historical_opponents.pop(0)
    
    def select_opponent(self, strategy='diverse'):
        """Select opponent for training."""
        if not self.historical_opponents:
            return copy.deepcopy(self.agent)  # Self-play
        
        if strategy == 'diverse':
            # Select diverse set of opponents
            return np.random.choice(self.historical_opponents)['agent']
        
        elif strategy == 'recent':
            # Focus on recent opponents
            recent_opponents = self.historical_opponents[-5:]
            return np.random.choice(recent_opponents)['agent']
        
        elif strategy == 'strongest':
            # Play against strongest opponents
            strongest = max(self.historical_opponents, key=lambda x: x['performance'])
            return strongest['agent']
        
        else:
            return np.random.choice(self.historical_opponents)['agent']
    
    def train_step(self, opponent_strategy='diverse'):
        """Single self-play training step."""
        opponent = self.select_opponent(opponent_strategy)
        
        # Play game against opponent (simplified)
        state = self.env.reset()
        total_reward = 0
        
        for step in range(100):  # Max episode length
            # Agent action
            with torch.no_grad():
                action_probs, _ = self.agent(torch.FloatTensor(state))
                action = Categorical(action_probs).sample().item()
            
            # Opponent action (simplified)
            with torch.no_grad():
                opp_action_probs, _ = opponent(torch.FloatTensor(state))
                opp_action = Categorical(opp_action_probs).sample().item()
            
            # Environment step (placeholder)
            next_state, reward, done, _ = self.env.step([action, opp_action])
            total_reward += reward
            
            if done:
                break
            
            state = next_state
        
        self.training_iteration += 1
        
        # Periodically save checkpoint
        if self.training_iteration % self.save_frequency == 0:
            self.add_checkpoint()
        
        return total_reward

# Demonstration functions
def demonstrate_maml():
    """Demonstrate MAML for multi-agent learning."""
    print("🧠 Meta-Learning (MAML) Demo")
    
    # Create MAML agent
    maml_agent = MAMLAgent(obs_dim=8, action_dim=4, hidden_dim=64)
    
    # Create dummy task batch
    tasks_batch = []
    for _ in range(3):  # 3 tasks
        # Support set
        support_obs = torch.randn(10, 8)
        support_actions = torch.randint(0, 4, (10,))
        support_rewards = torch.randn(10)
        support_next_obs = torch.randn(10, 8)
        support_dones = torch.zeros(10)
        
        support_batch = (support_obs, support_actions, support_rewards, support_next_obs, support_dones)
        
        # Query set
        query_obs = torch.randn(5, 8)
        query_actions = torch.randint(0, 4, (5,))
        query_rewards = torch.randn(5)
        query_next_obs = torch.randn(5, 8)
        query_dones = torch.zeros(5)
        
        query_batch = (query_obs, query_actions, query_rewards, query_next_obs, query_dones)
        
        tasks_batch.append((support_batch, query_batch))
    
    # Perform meta-update
    meta_loss = maml_agent.meta_update(tasks_batch)
    print(f"Meta-loss: {meta_loss:.4f}")
    
    return maml_agent

def demonstrate_opponent_modeling():
    """Demonstrate opponent modeling."""
    print("\n🎯 Opponent Modeling Demo")
    
    opponent_model = OpponentModel(obs_dim=8, action_dim=4, opponent_action_dim=4)
    
    # Simulate opponent interactions
    for _ in range(20):
        obs = torch.randn(8)
        my_action = torch.randint(0, 4, (1,)).item()
        opponent_action = torch.randint(0, 4, (1,))
        
        loss = opponent_model.update_model(obs, my_action, opponent_action)
    
    adaptation_speed = opponent_model.get_adaptation_speed()
    print(f"Adaptation speed: {adaptation_speed:.4f}")
    
    # Test prediction
    test_obs = torch.randn(8)
    test_action = 0
    pred_action_probs, confidence = opponent_model.predict_opponent_action(test_obs, test_action)
    
    print(f"Predicted opponent action probabilities: {pred_action_probs}")
    print(f"Prediction confidence: {confidence.item():.3f}")
    
    return opponent_model

def demonstrate_population_training():
    """Demonstrate population-based training."""
    print("\n🧬 Population-Based Training Demo")
    
    # Define simple agent class for demo
    class SimpleAgent(nn.Module):
        def __init__(self):
            super().__init__()
            self.policy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    
    pbt = PopulationBasedTraining(SimpleAgent, population_size=6)
    
    # Evolve for a few generations
    for generation in range(3):
        stats = pbt.evolve_generation()
        print(f"Generation {stats['generation']}: "
              f"Avg Fitness: {stats['avg_fitness']:.3f}, "
              f"Max Fitness: {stats['max_fitness']:.3f}")
    
    return pbt

# Run demonstrations
print("🎓 Meta-Learning and Adaptation Systems")
maml_demo = demonstrate_maml()
opponent_demo = demonstrate_opponent_modeling()
population_demo = demonstrate_population_training()

print("\n🚀 Meta-learning and adaptation implementations ready!")
print("✅ MAML, opponent modeling, and population-based training implemented!")

# Section 7: Comprehensive Applications and Case Studies

## 7.1 Multi-Agent Resource Allocation

Resource allocation is a fundamental problem in multi-agent systems where agents must efficiently distribute limited resources while considering individual objectives and system-wide constraints.

### Problem Formulation:
- **Agents**: $\mathcal{A} = \{1, 2, \ldots, n\}$
- **Resources**: $\mathcal{R} = \{r_1, r_2, \ldots, r_m\}$ with quantities $\{q_1, q_2, \ldots, q_m\}$
- **Allocations**: $x_{i,j} \geq 0$ is the amount of resource $j$ allocated to agent $i$, and $x_i = (x_{i,1}, \ldots, x_{i,m})$ is agent $i$'s allocation vector
- **Constraints**: $\sum_{i=1}^n x_{i,j} \leq q_j$ for all $j$
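The capacity constraint can be enforced with proportional rationing: when the requests for resource $j$ exceed $q_j$, every request for $j$ is scaled by the same factor $q_j / \sum_i x_{i,j}$. This is the conflict-resolution rule used in `ResourceAllocationEnvironment.step` further down. The standalone NumPy sketch below isolates it; the function name `proportional_allocation` is hypothetical.

```python
import numpy as np

def proportional_allocation(requests, capacities):
    """Scale over-subscribed resources so that sum_i x[i, j] <= q[j] for every j.

    requests:   (n_agents, n_resources) requested amounts, negatives treated as 0
    capacities: (n_resources,) available quantities q_j
    """
    requests = np.clip(np.asarray(requests, dtype=float), 0.0, None)
    capacities = np.asarray(capacities, dtype=float)
    totals = requests.sum(axis=0)
    ratios = np.ones_like(totals)            # 1.0 where demand already fits
    over = totals > capacities               # over-subscribed resources
    ratios[over] = capacities[over] / totals[over]
    return requests * ratios                 # same ratio applied to every agent

alloc = proportional_allocation([[6.0, 1.0], [6.0, 2.0]], capacities=[10.0, 5.0])
print(alloc)  # resource 0 is rationed (12 requested, 10 available); resource 1 is not
```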

### Objective Functions:
1. **Utilitarian**: $\max \sum_{i=1}^n U_i(x_i)$
2. **Egalitarian**: $\max \min_i U_i(x_i)$
3. **Nash Social Welfare**: $\max \prod_{i=1}^n U_i(x_i)$
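The three criteria can rank the same allocations differently. A minimal sketch, assuming linear utilities $U_i(x_i) = w_i \cdot x_i$ (the same form the resource-allocation environment below uses with random weights); the function name `welfare_objectives` is hypothetical:

```python
import numpy as np

def welfare_objectives(allocation, weights):
    """Evaluate the three criteria for linear utilities U_i(x_i) = w_i . x_i (an assumption)."""
    allocation = np.asarray(allocation, dtype=float)
    weights = np.asarray(weights, dtype=float)
    utilities = (allocation * weights).sum(axis=1)   # U_i for each agent i
    return {
        'utilitarian': float(utilities.sum()),   # sum_i U_i
        'egalitarian': float(utilities.min()),   # min_i U_i
        'nash': float(utilities.prod()),         # prod_i U_i, meaningful when every U_i > 0
    }

weights = np.eye(2)                                             # agent i only values resource i
skewed = welfare_objectives([[10.0, 0.0], [0.0, 1.0]], weights)  # U = (10, 1)
even = welfare_objectives([[5.0, 0.0], [0.0, 5.0]], weights)     # U = (5, 5)
print(skewed)  # utilitarian 11, egalitarian 1, nash 10
print(even)    # utilitarian 10, egalitarian 5, nash 25
```

The utilitarian criterion prefers the skewed allocation, while the egalitarian and Nash criteria both prefer the even split.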

## 7.2 Autonomous Vehicle Coordination

Multi-agent reinforcement learning applications in autonomous vehicle systems present unique challenges in safety, efficiency, and scalability.

### Key Components:
- **Vehicle Agents**: Each vehicle acts as an independent learning agent
- **Communication**: V2V (Vehicle-to-Vehicle) and V2I (Vehicle-to-Infrastructure)
- **Objectives**: Safety, traffic flow optimization, fuel efficiency
- **Constraints**: Traffic rules, physical limitations, safety margins

### Coordination Challenges:
1. **Intersection Management**: Distributed traffic light control
2. **Highway Merging**: Cooperative lane changing and merging
3. **Platooning**: Formation and maintenance of vehicle platoons
4. **Emergency Response**: Coordinated response to accidents or hazards
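Platooning is often handled in the control literature with a constant time-headway spacing policy: each follower tracks a gap of $d_0 + h\,v$ that grows with its own speed $v$. The sketch below is a hand-coded controller, not a learned policy, and every gain is a hypothetical value. In an MARL setup the same gap and speed errors are natural candidates for a follower's observation or a shaped reward. The acceleration limit mirrors the `max_accel = 3.0` used by `AutonomousVehicleEnvironment` below.

```python
def follower_acceleration(gap, ego_speed, lead_speed,
                          standstill_gap=5.0, time_headway=1.0,
                          k_gap=0.2, k_speed=0.6, max_accel=3.0):
    """Rule-based platoon follower using a constant time-headway spacing policy.

    gap in m, speeds in m/s; gains and limits are illustrative, not tuned values.
    """
    desired_gap = standstill_gap + time_headway * ego_speed   # gap grows with speed
    gap_error = gap - desired_gap            # > 0: too far behind, close the gap
    speed_error = lead_speed - ego_speed     # > 0: leader is pulling away
    accel = k_gap * gap_error + k_speed * speed_error
    return max(-max_accel, min(max_accel, accel))             # actuator limits in m/s^2

print(follower_acceleration(gap=25.0, ego_speed=20.0, lead_speed=20.0))  # 0.0: at the desired gap
print(follower_acceleration(gap=10.0, ego_speed=20.0, lead_speed=20.0))  # -3.0: too close, brake at the limit
```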

## 7.3 Smart Grid Management

The smart grid represents a complex multi-agent system where various entities must coordinate for efficient energy distribution and consumption.

### Agent Types:
- **Producers**: Power plants, renewable energy sources
- **Consumers**: Residential, commercial, industrial users
- **Storage**: Battery systems, pumped hydro storage
- **Grid Operators**: Transmission and distribution system operators

### Challenges:
- **Demand Response**: Dynamic pricing and consumption adjustment
- **Load Balancing**: Real-time supply-demand matching
- **Renewable Integration**: Managing intermittent energy sources
- **Market Mechanisms**: Automated bidding and trading
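Load balancing and market mechanisms are commonly benchmarked against merit-order (economic) dispatch, where a central operator runs producers in increasing order of marginal cost until demand is covered. A minimal NumPy sketch, assuming a single time step, no network or ramping constraints, and a hypothetical function name. A multi-agent approach would replace this central rule with producers that learn their own output or bids, and the baseline then provides a reference cost.

```python
import numpy as np

def merit_order_dispatch(demand, capacities, costs):
    """Centralized baseline: run the cheapest producers first until demand is met.

    Returns (output per producer, unmet demand).
    """
    capacities = np.asarray(capacities, dtype=float)
    dispatch = np.zeros_like(capacities)
    remaining = float(demand)
    for idx in np.argsort(costs):            # cheapest marginal cost first
        dispatch[idx] = min(capacities[idx], remaining)
        remaining -= dispatch[idx]
        if remaining <= 0:
            break
    return dispatch, max(remaining, 0.0)

dispatch, unmet = merit_order_dispatch(60.0, capacities=[40.0, 30.0, 50.0], costs=[0.12, 0.05, 0.08])
print(dispatch, unmet)  # producer 1 (cheapest) runs at 30, producer 2 covers the other 30
```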

## 7.4 Robotics Swarm Coordination

Swarm robotics involves coordinating large numbers of simple robots to achieve complex collective behaviors.

### Applications:
- **Search and Rescue**: Coordinated search patterns
- **Environmental Monitoring**: Distributed sensor networks
- **Construction**: Collaborative building and assembly
- **Military/Defense**: Autonomous drone swarms

### Technical Challenges:
- **Scalability**: Algorithms that work with hundreds or thousands of agents
- **Fault Tolerance**: Graceful degradation when agents fail
- **Communication Limits**: Bandwidth and range constraints
- **Real-time Coordination**: Fast decision making in dynamic environments
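Scalability and communication limits push swarm algorithms toward local rules. A standard building block is average consensus: each robot repeatedly nudges its estimate toward its neighbours' estimates, and the swarm agrees on the global average with no central coordinator. A minimal sketch, assuming a fixed ring topology and synchronous, lossless messages (both simplifications of real swarms); the function name is hypothetical.

```python
import numpy as np

def ring_consensus(values, steps=200, step_size=0.3):
    """Decentralized average consensus on a ring: each robot only hears its two neighbours.

    step_size should stay below 0.5 on a ring (degree 2) for the iteration to converge.
    """
    x = np.asarray(values, dtype=float).copy()
    for _ in range(steps):
        left, right = np.roll(x, 1), np.roll(x, -1)   # messages from the two neighbours
        # Each update preserves sum(x), so the common limit is the global mean
        x = x + step_size * ((left - x) + (right - x))
    return x

estimates = ring_consensus([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])  # e.g. local sensor readings
print(estimates)  # every entry close to 2.5, the swarm-wide average
```

Each robot exchanges only two messages per round, so per-robot communication does not grow with swarm size, although convergence slows as the ring gets longer.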

## 7.5 Financial Trading Systems

Multi-agent systems in financial markets involve multiple trading agents with different strategies and objectives.

### Agent Categories:
- **Market Makers**: Provide liquidity
- **Arbitrageurs**: Exploit price differences
- **Trend Followers**: Follow market momentum
- **Mean-Reversion Traders**: Bet that prices will revert toward a recent average

### Market Dynamics:
- **Price Discovery**: Collective determination of asset values
- **Liquidity Provision**: Ensuring tradeable markets
- **Risk Management**: Controlling exposure and volatility
- **Regulatory Compliance**: Following trading rules and regulations
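Trend followers and mean-reversion traders differ mainly in how they read the same price history. A toy, rule-based sketch (hand-coded signals rather than learned policies; the window and threshold are arbitrary assumptions, and the function names are hypothetical) shows two agents taking opposite positions on identical data, the kind of heterogeneity that drives price discovery:

```python
import numpy as np

def trend_signal(prices, window=5):
    """Trend follower: +1 (buy) if price is above its moving average, else -1 (sell)."""
    prices = np.asarray(prices, dtype=float)
    return 1 if prices[-1] > prices[-window:].mean() else -1

def mean_reversion_signal(prices, window=5, threshold=1.0):
    """Mean reverter: bet against moves larger than `threshold` standard deviations."""
    recent = np.asarray(prices, dtype=float)[-window:]
    z = (recent[-1] - recent.mean()) / (recent.std() + 1e-8)
    if z > threshold:
        return -1   # price looks stretched upward: sell
    if z < -threshold:
        return 1    # price looks stretched downward: buy
    return 0        # inside the band: stay flat

prices = [10.0, 10.0, 10.0, 10.0, 15.0]  # a sudden jump
print(trend_signal(prices), mean_reversion_signal(prices))  # 1 -1: opposite positions
```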

## 7.6 Game-Theoretic Analysis Framework

### Nash Equilibrium in Multi-Agent RL:
For policies $\pi = (\pi_1, ..., \pi_n)$, a Nash equilibrium satisfies:
$$J_i(\pi_i^*, \pi_{-i}^*) \geq J_i(\pi_i, \pi_{-i}^*) \quad \forall \pi_i, \forall i$$
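For a one-shot game with finitely many actions, the condition reduces to checking that no agent gains from a unilateral deviation. A minimal sketch on the Prisoner's Dilemma (the payoff values are the usual textbook choice, used only for illustration). It restricts each $\pi_i$ to pure strategies, so it finds pure-strategy equilibria only.

```python
import itertools
import numpy as np

# payoffs[a1, a2] = (J_1, J_2); action 0 = cooperate, 1 = defect
payoffs = np.array([[[-1, -1], [-3, 0]],
                    [[0, -3], [-2, -2]]])

def pure_nash_equilibria(payoffs):
    """Return every joint action from which no agent gains by deviating alone."""
    action_counts = payoffs.shape[:-1]
    equilibria = []
    for profile in itertools.product(*(range(k) for k in action_counts)):
        stable = True
        for agent, n_actions in enumerate(action_counts):
            for alt in range(n_actions):
                deviation = list(profile)
                deviation[agent] = alt
                if payoffs[tuple(deviation)][agent] > payoffs[profile][agent]:
                    stable = False
        if stable:
            equilibria.append(profile)
    return equilibria

print(pure_nash_equilibria(payoffs))  # [(1, 1)]: mutual defection
```

Mutual defection is the only equilibrium even though mutual cooperation (total payoff $-2$ versus $-4$) is better for both agents: Nash equilibria need not be socially optimal.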

### Stackelberg Games:
Leader-follower dynamics where one agent commits to a strategy first:
$$\max_{\pi_L} J_L(\pi_L, \pi_F^*(\pi_L))$$
$$\text{s.t. } \pi_F^*(\pi_L) = \arg\max_{\pi_F} J_F(\pi_L, \pi_F)$$
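In a finite matrix game the bilevel problem can be solved by enumeration when the leader commits to a pure action: compute the follower's best response to each leader action, then pick the leader action with the highest resulting payoff. A sketch under that pure-commitment assumption (an optimal mixed commitment generally requires solving one linear program per follower action, which this omits). The payoff matrices are made up to show a first-mover advantage.

```python
import numpy as np

def stackelberg_pure(leader_payoff, follower_payoff):
    """Leader commits to a pure action; follower plays a best response to it."""
    leader_payoff = np.asarray(leader_payoff, dtype=float)
    follower_payoff = np.asarray(follower_payoff, dtype=float)
    best_pair, best_value = None, -np.inf
    for a_leader in range(leader_payoff.shape[0]):
        # pi_F^*(pi_L): follower's best response (first index wins ties)
        a_follower = int(np.argmax(follower_payoff[a_leader]))
        value = leader_payoff[a_leader, a_follower]
        if value > best_value:
            best_pair, best_value = (a_leader, a_follower), value
    return best_pair, float(best_value)

# Rows = leader actions, columns = follower actions
J_L = [[2.0, 4.0],
       [1.0, 3.0]]
J_F = [[1.0, 0.0],
       [0.0, 1.0]]
print(stackelberg_pure(J_L, J_F))  # ((1, 1), 3.0)
```

In the simultaneous game row 0 strictly dominates for the leader, the follower answers with column 0, and the leader gets 2; committing to row 1 first raises the leader's payoff to 3.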

### Cooperative Game Theory:
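- **Shapley Value**: Fair allocation of cooperative gains based on each agent's average marginal contribution
- **Core**: Set of payoff allocations that no coalition can improve on by breaking away, so the grand coalition stays stable
- **Nucleolus**: Allocation that lexicographically minimizes the largest coalition excesses in transferable-utility games; it lies in the core whenever the core is nonempty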
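Concretely, the Shapley value of agent $i$ is its marginal contribution $v(S \cup \{i\}) - v(S)$ averaged over all $n!$ orders in which agents can join the grand coalition. The exact sketch below enumerates every order, so it is only practical for a handful of agents (sampling random orders is the usual approximation). The glove game is a standard textbook example, and the function names are hypothetical.

```python
import itertools
import math

def shapley_values(n_players, value):
    """Exact Shapley values by averaging marginal contributions over all n! join orders."""
    phi = [0.0] * n_players
    for order in itertools.permutations(range(n_players)):
        coalition = frozenset()
        for player in order:
            joined = coalition | {player}
            phi[player] += value(joined) - value(coalition)  # marginal contribution
            coalition = joined
    return [total / math.factorial(n_players) for total in phi]

def glove_value(coalition):
    """Glove game: player 0 holds a left glove, players 1 and 2 hold right gloves; a pair is worth 1."""
    return 1.0 if 0 in coalition and (1 in coalition or 2 in coalition) else 0.0

print(shapley_values(3, glove_value))  # approximately [0.667, 0.167, 0.167]
```

The scarce left-glove holder receives two thirds of the value, and the shares sum to $v(\{0, 1, 2\}) = 1$.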

In [None]:
# Comprehensive Applications and Case Studies Implementation

class ResourceAllocationEnvironment:
    """Multi-agent resource allocation environment."""
    
    def __init__(self, n_agents=4, n_resources=3, resource_capacities=None):
        self.n_agents = n_agents
        self.n_resources = n_resources
        
        if resource_capacities is None:
            self.resource_capacities = torch.ones(n_resources) * 10.0
        else:
            self.resource_capacities = torch.as_tensor(resource_capacities, dtype=torch.float32)
        
        # Agent utility functions (random for demo)
        self.agent_utilities = []
        for _ in range(n_agents):
            utility_weights = torch.rand(n_resources) * 2  # Random utility weights
            self.agent_utilities.append(utility_weights)
        
        self.reset()
    
    def reset(self):
        """Reset environment."""
        self.current_allocations = torch.zeros(self.n_agents, self.n_resources)
        self.remaining_resources = self.resource_capacities.clone()
        self.time_step = 0
        
        return self.get_state()
    
    def get_state(self):
        """Get current state for all agents."""
        states = []
        for i in range(self.n_agents):
            # State includes current allocation and remaining resources
            agent_state = torch.cat([
                self.current_allocations[i],  # Own allocation
                self.remaining_resources,     # Remaining resources
                self.current_allocations.sum(0)  # Total allocated
            ])
            states.append(agent_state)
        
        return torch.stack(states)
    
    def step(self, actions):
        """
        Execute actions for all agents.
        Actions: [n_agents, n_resources] - requested allocation amounts
        """
        actions = torch.as_tensor(actions, dtype=torch.float32)  # accepts lists, arrays, or tensors
        
        # Ensure actions are non-negative and within limits
        actions = torch.clamp(actions, 0, 1)  # Normalized requests
        
        # Scale actions based on remaining resources
        scaled_actions = actions * self.remaining_resources.unsqueeze(0)
        
        # Resolve conflicts using proportional allocation
        total_requests = scaled_actions.sum(0)
        allocation_ratios = torch.ones_like(total_requests)
        
        # Apply capacity constraints
        over_capacity = total_requests > self.remaining_resources
        allocation_ratios[over_capacity] = (self.remaining_resources[over_capacity] / 
                                          total_requests[over_capacity])
        
        # Compute actual allocations
        actual_allocations = scaled_actions * allocation_ratios.unsqueeze(0)
        
        # Update state
        self.current_allocations += actual_allocations
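        # Clamp at zero to absorb tiny negative values from floating-point round-off
        self.remaining_resources = (self.remaining_resources - actual_allocations.sum(0)).clamp(min=0)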
        
        # Compute rewards (utility gained)
        rewards = []
        for i in range(self.n_agents):
            utility = torch.dot(actual_allocations[i], self.agent_utilities[i])
            rewards.append(utility.item())
        
        self.time_step += 1
        done = self.time_step >= 20 or torch.all(self.remaining_resources <= 0.1)
        
        return self.get_state(), rewards, done, {}
    
    def compute_social_welfare(self):
        """Compute total social welfare."""
        total_welfare = 0
        for i in range(self.n_agents):
            agent_welfare = torch.dot(self.current_allocations[i], self.agent_utilities[i])
            total_welfare += agent_welfare.item()
        return total_welfare

class AutonomousVehicleEnvironment:
    """Simplified autonomous vehicle coordination environment."""
    
    def __init__(self, n_vehicles=4, road_length=100):
        self.n_vehicles = n_vehicles
        self.road_length = road_length
        
        self.reset()
    
    def reset(self):
        """Reset environment."""
        # Vehicle positions (random start)
        self.positions = torch.rand(self.n_vehicles) * self.road_length * 0.3
        
        # Vehicle velocities (start slow)
        self.velocities = torch.ones(self.n_vehicles) * 5.0
        
        # Target velocities (desired speed)
        self.target_velocities = torch.rand(self.n_vehicles) * 10 + 10  # 10-20 m/s
        
        self.time_step = 0
        
        return self.get_state()
    
    def get_state(self):
        """Get state for all vehicles."""
        states = []
        for i in range(self.n_vehicles):
            # Find nearest neighbors
            distances = torch.abs(self.positions - self.positions[i])
            distances[i] = float('inf')  # Exclude self
            
            # Get nearest vehicle info
            nearest_idx = torch.argmin(distances)
            relative_pos = self.positions[nearest_idx] - self.positions[i]
            relative_vel = self.velocities[nearest_idx] - self.velocities[i]
            
            vehicle_state = torch.tensor([
                self.positions[i] / self.road_length,  # Normalized position
                self.velocities[i] / 20.0,             # Normalized velocity
                self.target_velocities[i] / 20.0,      # Normalized target velocity
                relative_pos / self.road_length,       # Relative position to nearest
                relative_vel / 20.0,                   # Relative velocity to nearest
                distances.min() / 20.0                 # Distance to nearest vehicle
            ])
            
            states.append(vehicle_state)
        
        return torch.stack(states)
    
    def step(self, actions):
        """
        Execute actions (acceleration commands).
        Actions: [n_vehicles] - acceleration values (-1 to 1)
        """
        actions = torch.as_tensor(actions, dtype=torch.float32)  # accepts lists, arrays, or tensors
        actions = torch.clamp(actions, -1, 1)
        
        dt = 0.1  # Time step
        max_accel = 3.0  # m/s^2
        
        # Update velocities
        accelerations = actions * max_accel
        self.velocities += accelerations * dt
        self.velocities = torch.clamp(self.velocities, 0, 25)  # Speed limits
        
        # Update positions
        self.positions += self.velocities * dt
        
        # Compute rewards
        rewards = []
        for i in range(self.n_vehicles):
            # Reward components
            speed_reward = -torch.abs(self.velocities[i] - self.target_velocities[i]) * 0.1
            
            # Safety reward (maintain distance)
            distances = torch.abs(self.positions - self.positions[i])
            distances[i] = float('inf')
            min_distance = distances.min()
            safety_reward = -10.0 if min_distance < 2.0 else 0.0
            
            # Efficiency reward (progress)
            progress_reward = self.velocities[i] * 0.05
            
            total_reward = speed_reward + safety_reward + progress_reward
            rewards.append(total_reward.item())
        
        self.time_step += 1
        done = self.time_step >= 100 or torch.any(self.positions >= self.road_length)
        
        return self.get_state(), rewards, done, {}

class SmartGridEnvironment:
    """Smart grid multi-agent environment."""
    
    def __init__(self, n_producers=2, n_consumers=3, n_storage=1):
        self.n_producers = n_producers
        self.n_consumers = n_consumers
        self.n_storage = n_storage
        self.n_agents = n_producers + n_consumers + n_storage
        
        # Production capacities and costs
        self.production_capacities = torch.rand(n_producers) * 50 + 20  # 20-70 MW
        self.production_costs = torch.rand(n_producers) * 0.1 + 0.05   # $0.05-0.15/MWh
        
        # Consumer demands
        self.base_demands = torch.rand(n_consumers) * 30 + 10  # 10-40 MW
        
        # Storage capacities
        self.storage_capacities = torch.ones(n_storage) * 100  # 100 MWh
        
        self.reset()
    
    def reset(self):
        """Reset environment."""
        self.current_storage = self.storage_capacities * 0.5  # Start half-full
        self.time_step = 0
        
        # Random demand fluctuation
        self.current_demands = self.base_demands * (0.8 + 0.4 * torch.rand(self.n_consumers))
        
        # Random renewable production (solar/wind variability)
        self.renewable_factor = torch.rand(1).item() * 0.5 + 0.5  # 0.5-1.0
        
        return self.get_state()
    
    def get_state(self):
        """Get state for all agents."""
        states = []
        
        # Producer states
        for i in range(self.n_producers):
            producer_state = torch.tensor([
                self.production_capacities[i] / 100,  # Normalized capacity
                self.production_costs[i] * 10,        # Scaled cost
                self.renewable_factor,                # Renewable availability
                self.current_demands.sum() / 100,     # Total demand
                self.time_step / 24.0                 # Time of day (normalized)
            ])
            states.append(producer_state)
        
        # Consumer states  
        for i in range(self.n_consumers):
            consumer_state = torch.tensor([
                self.current_demands[i] / 50,         # Normalized demand
                self.base_demands[i] / 50,            # Base demand
                np.sin(self.time_step * 2 * np.pi / 24),  # Time of day cycle (time_step is a Python int)
                (self.current_demands.sum() - self.current_demands[i]) / 100,  # Other demand
                self.renewable_factor                 # Renewable availability
            ])
            states.append(consumer_state)
        
        # Storage states
        for i in range(self.n_storage):
            storage_state = torch.tensor([
                self.current_storage[i] / self.storage_capacities[i],  # Charge level
                self.storage_capacities[i] / 100,     # Capacity
                self.current_demands.sum() / 100,     # Total demand
                self.renewable_factor,                # Renewable availability
                self.time_step / 24.0                 # Time of day
            ])
            states.append(storage_state)
        
        return torch.stack(states)
    
    def step(self, actions):
        """
        Execute actions for all agents.
        Actions: [n_agents] - normalized action values
        """
        actions = torch.as_tensor(actions, dtype=torch.float32)  # accepts lists, arrays, or tensors
        actions = torch.clamp(actions, -1, 1)
        
        # Parse actions
        producer_actions = actions[:self.n_producers]  # Production levels
        consumer_actions = actions[self.n_producers:self.n_producers + self.n_consumers]  # Demand response
        storage_actions = actions[self.n_producers + self.n_consumers:]  # Charge/discharge
        
        # Compute actual production
        production = producer_actions * self.production_capacities * self.renewable_factor
        # Clip to [0, capacity] per producer (tensor-valued upper bound)
        production = torch.minimum(production.clamp(min=0), self.production_capacities)
        
        # Compute adjusted demand (demand response)
        adjusted_demands = self.current_demands * (1 + consumer_actions * 0.3)
        adjusted_demands = torch.clamp(adjusted_demands, self.current_demands * 0.7, 
                                     self.current_demands * 1.3)
        
        # Storage actions (positive = discharge, negative = charge)
        storage_power = storage_actions * 20  # Max 20 MW charge/discharge rate
        
        # Update storage levels
        self.current_storage -= storage_power * 0.1  # 0.1 hour time step
        self.current_storage = torch.minimum(self.current_storage.clamp(min=0), self.storage_capacities)
        
        # Balance supply and demand
        total_supply = production.sum() + storage_power.sum()
        total_demand = adjusted_demands.sum()
        imbalance = total_supply - total_demand
        
        # Compute rewards
        rewards = []
        
        # Producer rewards (profit - penalty for imbalance)
        for i in range(self.n_producers):
            revenue = production[i] * 0.1  # $0.1/MWh base price
            cost = production[i] * self.production_costs[i]
            imbalance_penalty = abs(imbalance) * 0.01  # Penalty for grid imbalance
            producer_reward = revenue - cost - imbalance_penalty
            rewards.append(producer_reward.item())
        
        # Consumer rewards (savings from demand response - inconvenience)
        for i in range(self.n_consumers):
            base_cost = self.current_demands[i] * 0.1
            actual_cost = adjusted_demands[i] * 0.1
            inconvenience = abs(consumer_actions[i]) * 2.0  # Cost of changing demand
            consumer_reward = base_cost - actual_cost - inconvenience
            rewards.append(consumer_reward.item())
        
        # Storage rewards (arbitrage opportunities - degradation)
        for i in range(self.n_storage):
            arbitrage_reward = storage_power[i] * 0.02  # Profit from price differences
            degradation_cost = abs(storage_power[i]) * 0.001  # Battery wear
            storage_reward = arbitrage_reward - degradation_cost
            rewards.append(storage_reward.item())
        
        # Update time and demand
        self.time_step += 1
        if self.time_step % 6 == 0:  # Update demand every 6 hours
            self.current_demands = self.base_demands * (0.8 + 0.4 * torch.rand(self.n_consumers))
        
        done = self.time_step >= 24  # One day
        
        info = {
            'total_supply': total_supply.item(),
            'total_demand': total_demand.item(),
            'imbalance': imbalance.item(),
            'renewable_factor': self.renewable_factor
        }
        
        return self.get_state(), rewards, done, info

class MultiAgentGameTheoryAnalyzer:
    """Analyzer for game-theoretic properties of multi-agent systems."""
    
    def __init__(self, n_agents, n_actions):
        self.n_agents = n_agents
        self.n_actions = n_actions
        
    def compute_payoff_matrix(self, agents, env, n_episodes=100):
        """Estimate average payoffs for every fixed joint action profile over n_episodes (`agents` is currently unused)."""
        payoffs = np.zeros([self.n_actions] * self.n_agents + [self.n_agents])
        
        for action_profile in itertools.product(range(self.n_actions), repeat=self.n_agents):
            total_rewards = np.zeros(self.n_agents)
            
            for episode in range(n_episodes):
                state = env.reset()
                episode_rewards = np.zeros(self.n_agents)
                
                for step in range(100):  # Max episode length
                    actions = list(action_profile)
                    next_state, rewards, done, _ = env.step(actions)
                    
                    episode_rewards += np.array(rewards)
                    
                    if done:
                        break
                    
                    state = next_state
                
                total_rewards += episode_rewards
            
            avg_rewards = total_rewards / n_episodes
            payoffs[action_profile] = avg_rewards
        
        return payoffs
    
    def find_nash_equilibria(self, payoff_matrix):
        """Find pure strategy Nash equilibria."""
        nash_equilibria = []
        
        for action_profile in itertools.product(range(self.n_actions), repeat=self.n_agents):
            is_nash = True
            
            for agent in range(self.n_agents):
                current_payoff = payoff_matrix[action_profile][agent]
                
                # Check if agent can improve by changing strategy
                for alt_action in range(self.n_actions):
                    if alt_action == action_profile[agent]:
                        continue
                    
                    alt_profile = list(action_profile)
                    alt_profile[agent] = alt_action
                    alt_payoff = payoff_matrix[tuple(alt_profile)][agent]
                    
                    if alt_payoff > current_payoff:
                        is_nash = False
                        break
                
                if not is_nash:
                    break
            
            if is_nash:
                nash_equilibria.append(action_profile)
        
        return nash_equilibria
    
    def compute_social_welfare(self, payoff_matrix, action_profile):
        """Compute social welfare for given action profile."""
        return np.sum(payoff_matrix[action_profile])
    
    def find_social_optimum(self, payoff_matrix):
        """Find action profile that maximizes social welfare."""
        best_welfare = float('-inf')
        best_profile = None
        
        for action_profile in itertools.product(range(self.n_actions), repeat=self.n_agents):
            welfare = self.compute_social_welfare(payoff_matrix, action_profile)
            if welfare > best_welfare:
                best_welfare = welfare
                best_profile = action_profile
        
        return best_profile, best_welfare

# Demonstration functions
def demonstrate_resource_allocation():
    """Demonstrate resource allocation environment."""
    print("🏭 Resource Allocation Demo")
    
    env = ResourceAllocationEnvironment(n_agents=3, n_resources=2, 
                                      resource_capacities=[20.0, 15.0])
    
    # Random policy simulation
    state = env.reset()
    total_rewards = np.zeros(3)
    
    for step in range(10):
        actions = torch.rand(3, 2) * 0.3  # Random allocation requests
        next_state, rewards, done, _ = env.step(actions)
        
        total_rewards += np.array(rewards)
        
        if done:
            break
        
        state = next_state
    
    social_welfare = env.compute_social_welfare()
    print(f"Final allocations: {env.current_allocations}")
    print(f"Social welfare: {social_welfare:.2f}")
    print(f"Individual rewards: {total_rewards}")
    
    return env

def demonstrate_autonomous_vehicles():
    """Demonstrate autonomous vehicle coordination."""
    print("\n🚗 Autonomous Vehicle Coordination Demo")
    
    env = AutonomousVehicleEnvironment(n_vehicles=4, road_length=100)
    
    state = env.reset()
    print(f"Initial positions: {env.positions}")
    print(f"Target velocities: {env.target_velocities}")
    
    # Simple coordination: maintain spacing
    for step in range(20):
        actions = []
        for i in range(env.n_vehicles):
            # Simple controller: match target speed, avoid collisions
            # Convert to a Python float so every entry of `actions` has the same type
            speed_error = (env.target_velocities[i] - env.velocities[i]).item()
            action = speed_error * 0.1
            
            # Collision avoidance
            distances = torch.abs(env.positions - env.positions[i])
            distances[i] = float('inf')
            min_distance = distances.min()
            
            if min_distance < 5.0:  # Too close
                action = -0.5  # Brake
            
            actions.append(action)
        
        next_state, rewards, done, _ = env.step(actions)
        
        if step % 5 == 0:
            print(f"Step {step}: Positions: {env.positions.numpy().round(1).tolist()}")
        
        if done:
            break
        
        state = next_state
    
    return env

def demonstrate_smart_grid():
    """Demonstrate smart grid coordination."""
    print("\n⚡ Smart Grid Management Demo")
    
    env = SmartGridEnvironment(n_producers=2, n_consumers=2, n_storage=1)
    
    state = env.reset()
    print(f"Production capacities: {env.production_capacities.numpy().round(1)}")
    print(f"Base demands: {env.base_demands.numpy().round(1)}")
    
    total_rewards = np.zeros(env.n_agents)  # producers + consumers + storage
    
    for step in range(12):  # Half day simulation
        # Simple coordination strategies
        actions = []
        
        # Producers: produce based on demand
        total_demand = env.current_demands.sum()
        for i in range(env.n_producers):
            production_ratio = min(1.0, (total_demand / env.production_capacities.sum()).item())
            actions.append(production_ratio)
        
        # Consumers: slight demand response
        for i in range(env.n_consumers):
            demand_response = 0.1 * (torch.randn(1).item())
            actions.append(demand_response)
        
        # Storage: charge during low demand, discharge during high demand
        if total_demand > env.base_demands.sum():
            actions.append(0.5)  # Discharge
        else:
            actions.append(-0.3)  # Charge
        
        next_state, rewards, done, info = env.step(actions)
        total_rewards += np.array(rewards)
        
        if step % 3 == 0:
            print(f"Hour {step}: Supply={info['total_supply']:.1f}, "
                  f"Demand={info['total_demand']:.1f}, "
                  f"Imbalance={info['imbalance']:.1f}")
        
        if done:
            break
        
        state = next_state
    
    print(f"Total rewards: {total_rewards.round(2)}")
    
    return env

# Run comprehensive demonstrations
print("🌟 Comprehensive Multi-Agent Applications")
resource_env = demonstrate_resource_allocation()
vehicle_env = demonstrate_autonomous_vehicles()
grid_env = demonstrate_smart_grid()

print("\n🚀 All comprehensive applications implemented!")
print("✅ Resource allocation, autonomous vehicles, and smart grid systems ready!")

In [None]:
# Comprehensive Evaluation and Training Framework

class MultiAgentTrainingOrchestrator:
    """Orchestrator for comprehensive multi-agent training and evaluation."""
    
    def __init__(self, config):
        self.config = config
        self.training_history = []
        self.evaluation_results = []
        
        # Initialize components based on config
        self.setup_environment()
        self.setup_agents()
        self.setup_evaluation_metrics()
    
    def setup_environment(self):
        """Setup environment based on configuration."""
        env_type = self.config.get('environment', 'resource_allocation')
        
        if env_type == 'resource_allocation':
            self.env = ResourceAllocationEnvironment(
                n_agents=self.config.get('n_agents', 4),
                n_resources=self.config.get('n_resources', 3)
            )
        elif env_type == 'autonomous_vehicles':
            self.env = AutonomousVehicleEnvironment(
                n_vehicles=self.config.get('n_agents', 4),
                road_length=self.config.get('road_length', 100)
            )
        elif env_type == 'smart_grid':
            self.env = SmartGridEnvironment(
                n_producers=self.config.get('n_producers', 2),
                n_consumers=self.config.get('n_consumers', 3),
                n_storage=self.config.get('n_storage', 1)
            )
        else:
            self.env = MultiAgentEnvironment(
                n_agents=self.config.get('n_agents', 4),
                state_dim=self.config.get('state_dim', 10),
                action_dim=self.config.get('action_dim', 4)
            )
    
    def setup_agents(self):
        """Setup agents based on configuration."""
        algorithm = self.config.get('algorithm', 'MADDPG')
        n_agents = self.config.get('n_agents', 4)
        
        self.agents = []
        
        if algorithm == 'MADDPG':
            obs_dim = self.config.get('obs_dim', 8)
            action_dim = self.config.get('action_dim', 4)
            
            for i in range(n_agents):
                agent = MADDPGAgent(
                    agent_id=i,
                    obs_dim=obs_dim,
                    action_dim=action_dim,
                    n_agents=n_agents,
                    lr_actor=self.config.get('lr_actor', 1e-3),
                    lr_critic=self.config.get('lr_critic', 1e-3)
                )
                self.agents.append(agent)
        
        elif algorithm == 'VDN':
            for i in range(n_agents):
                agent = VDNAgent(
                    agent_id=i,
                    obs_dim=self.config.get('obs_dim', 8),
                    action_dim=self.config.get('action_dim', 4),
                    lr=self.config.get('lr', 1e-3)
                )
                self.agents.append(agent)
        
        elif algorithm == 'PPO':
            for i in range(n_agents):
                agent = PPOAgent(
                    obs_dim=self.config.get('obs_dim', 8),
                    action_dim=self.config.get('action_dim', 4),
                    lr=self.config.get('lr', 3e-4)
                )
                self.agents.append(agent)
        
        else:
            raise ValueError(f"Unknown algorithm: {algorithm}")
        
        # Initialize communication if enabled
        if self.config.get('enable_communication', False):
            self.comm_channel = CommunicationChannel(
                n_agents=n_agents,
                message_dim=self.config.get('message_dim', 16)
            )
        else:
            self.comm_channel = None
    
    def setup_evaluation_metrics(self):
        """Setup evaluation metrics."""
        self.metrics = {
            'individual_rewards': [],
            'social_welfare': [],
            'cooperation_score': [],
            'communication_efficiency': [],
            'convergence_rate': [],
            'nash_equilibrium_distance': []
        }
    
    def train_episode(self, episode_idx):
        """Train agents for one episode."""
        state = self.env.reset()
        episode_rewards = np.zeros(len(self.agents))
        episode_length = 0
        
        # Episode-specific metrics
        cooperation_events = 0
        communication_events = 0
        
        while episode_length < self.config.get('max_episode_length', 100):
            actions = []
            
            # Get actions from all agents
            for i, agent in enumerate(self.agents):
                if hasattr(agent, 'get_action'):
                    if self.comm_channel:
                        # Include communication
                        messages = self.comm_channel.get_messages_for_agent(i)
                        action = agent.get_action(state[i], messages)
                    else:
                        action = agent.get_action(state[i])
                else:
                    # Simple policy for baseline
                    action = torch.randint(0, self.config.get('action_dim', 4), (1,)).item()
                
                actions.append(action)
            
            # Execute actions
            next_state, rewards, done, info = self.env.step(actions)
            
            # Store experiences and update agents
            for i, agent in enumerate(self.agents):
                if hasattr(agent, 'store_experience'):
                    agent.store_experience(state[i], actions[i], rewards[i], next_state[i], done)
                
                if hasattr(agent, 'update') and episode_idx % self.config.get('update_freq', 1) == 0:
                    agent.update()
            
            # Handle communication: each agent broadcasts a message with probability 0.1
            if self.comm_channel:
                for i, agent in enumerate(self.agents):
                    if hasattr(agent, 'generate_message') and np.random.rand() < 0.1:
                        message = agent.generate_message(state[i])
                        self.comm_channel.send_message(i, message)
                        communication_events += 1
            
            episode_rewards += np.array(rewards)
            episode_length += 1
            
            if done:
                break
            
            state = next_state
        
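        # Cooperation score (placeholder): 1 / (1 + coefficient of variation) of the
        # per-agent returns, so 1.0 means perfectly equal returns. Using |mean| keeps
        # the ratio non-negative when returns are negative.
        cooperation_score = np.std(episode_rewards) / (np.abs(np.mean(episode_rewards)) + 1e-8)
        cooperation_score = 1.0 / (1.0 + cooperation_score)  # Higher = more equal returns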
        
        # Store episode results
        episode_result = {
            'episode': episode_idx,
            'individual_rewards': episode_rewards,
            'social_welfare': np.sum(episode_rewards),
            'cooperation_score': cooperation_score,
            'communication_events': communication_events,
            'episode_length': episode_length
        }
        
        self.training_history.append(episode_result)
        
        return episode_result
    
    def evaluate_agents(self, n_episodes=10):
        """Comprehensive evaluation of trained agents."""
        print(f"🔍 Evaluating agents over {n_episodes} episodes...")
        
        evaluation_rewards = []
        social_welfares = []
        cooperation_scores = []
        
        for eval_episode in range(n_episodes):
            state = self.env.reset()
            episode_rewards = np.zeros(len(self.agents))
            episode_length = 0
            
            while episode_length < self.config.get('max_episode_length', 100):
                actions = []
                
                # Use deterministic policies for evaluation
                for i, agent in enumerate(self.agents):
                    with torch.no_grad():
                        if hasattr(agent, 'get_action'):
                            if self.comm_channel:
                                messages = self.comm_channel.get_messages_for_agent(i)
                                action = agent.get_action(state[i], messages, deterministic=True)
                            else:
                                action = agent.get_action(state[i], deterministic=True)
                        else:
                            action = torch.randint(0, self.config.get('action_dim', 4), (1,)).item()
                    
                    actions.append(action)
                
                next_state, rewards, done, info = self.env.step(actions)
                episode_rewards += np.array(rewards)
                episode_length += 1
                
                if done:
                    break
                
                state = next_state
            
            evaluation_rewards.append(episode_rewards)
            social_welfares.append(np.sum(episode_rewards))
            
            # Compute cooperation score (|mean| keeps it in (0, 1] for negative returns)
            cooperation_score = np.std(episode_rewards) / (np.abs(np.mean(episode_rewards)) + 1e-8)
            cooperation_score = 1.0 / (1.0 + cooperation_score)
            cooperation_scores.append(cooperation_score)
        
        # Aggregate results
        evaluation_result = {
            'mean_individual_rewards': np.mean(evaluation_rewards, axis=0),
            'std_individual_rewards': np.std(evaluation_rewards, axis=0),
            'mean_social_welfare': np.mean(social_welfares),
            'std_social_welfare': np.std(social_welfares),
            'mean_cooperation_score': np.mean(cooperation_scores),
            'std_cooperation_score': np.std(cooperation_scores)
        }
        
        self.evaluation_results.append(evaluation_result)
        
        return evaluation_result
    
    def run_training(self):
        """Run complete training procedure."""
        n_episodes = self.config.get('n_episodes', 1000)
        eval_freq = self.config.get('eval_freq', 100)
        
        print(f"🚀 Starting training for {n_episodes} episodes...")
        print(f"📊 Algorithm: {self.config.get('algorithm', 'MADDPG')}")
        print(f"🤖 Number of agents: {len(self.agents)}")
        print(f"🌍 Environment: {self.config.get('environment', 'resource_allocation')}")
        
        for episode in range(n_episodes):
            # Training episode
            episode_result = self.train_episode(episode)
            
            # Periodic evaluation
            if episode % eval_freq == 0:
                eval_result = self.evaluate_agents()
                
                print(f"\n📈 Episode {episode} Results:")
                print(f"   Training Social Welfare: {episode_result['social_welfare']:.2f}")
                print(f"   Evaluation Social Welfare: {eval_result['mean_social_welfare']:.2f} ± {eval_result['std_social_welfare']:.2f}")
                print(f"   Cooperation Score: {eval_result['mean_cooperation_score']:.3f}")
                
                # Early stopping: converged once the last 3 evaluation welfare means have std < 0.1
                if len(self.evaluation_results) > 3:
                    recent_performance = [r['mean_social_welfare'] for r in self.evaluation_results[-3:]]
                    if np.std(recent_performance) < 0.1:  # Converged
                        print(f"🎯 Training converged at episode {episode}")
                        break
        
        print("✅ Training completed!")
        
        # Final comprehensive evaluation
        final_evaluation = self.evaluate_agents(n_episodes=50)
        
        return {
            'training_history': self.training_history,
            'evaluation_results': self.evaluation_results,
            'final_evaluation': final_evaluation
        }
    
    def visualize_results(self):
        """Visualize training and evaluation results."""
        if not self.training_history:
            print("❌ No training history to visualize")
            return
        
        plt.figure(figsize=(15, 10))
        
        # Social welfare over training
        plt.subplot(2, 3, 1)
        social_welfares = [result['social_welfare'] for result in self.training_history]
        plt.plot(social_welfares)
        plt.title('Social Welfare During Training')
        plt.xlabel('Episode')
        plt.ylabel('Social Welfare')
        
        # Individual rewards over training
        plt.subplot(2, 3, 2)
        if len(self.training_history) > 0:
            n_agents = len(self.training_history[0]['individual_rewards'])
            for agent_id in range(n_agents):
                agent_rewards = [result['individual_rewards'][agent_id] for result in self.training_history]
                plt.plot(agent_rewards, label=f'Agent {agent_id}')
        plt.title('Individual Rewards During Training')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.legend()
        
        # Cooperation scores
        plt.subplot(2, 3, 3)
        cooperation_scores = [result['cooperation_score'] for result in self.training_history]
        plt.plot(cooperation_scores)
        plt.title('Cooperation Score During Training')
        plt.xlabel('Episode')
        plt.ylabel('Cooperation Score')
        
        # Evaluation results
        if self.evaluation_results:
            plt.subplot(2, 3, 4)
            eval_welfare_means = [result['mean_social_welfare'] for result in self.evaluation_results]
            eval_welfare_stds = [result['std_social_welfare'] for result in self.evaluation_results]
            # Plot against the evaluation index: the extra final evaluation and early
            # stopping mean the points are not spaced exactly eval_freq episodes apart.
            eval_rounds = range(len(self.evaluation_results))
            
            plt.errorbar(eval_rounds, eval_welfare_means, yerr=eval_welfare_stds, capsize=5)
            plt.title('Evaluation Social Welfare')
            plt.xlabel('Evaluation Round')
            plt.ylabel('Social Welfare')
            
            # Final individual performance comparison
            plt.subplot(2, 3, 5)
            final_result = self.evaluation_results[-1]
            agent_means = final_result['mean_individual_rewards']
            agent_stds = final_result['std_individual_rewards']
            agents = range(len(agent_means))
            
            plt.bar(agents, agent_means, yerr=agent_stds, capsize=5)
            plt.title('Final Individual Agent Performance')
            plt.xlabel('Agent ID')
            plt.ylabel('Mean Reward')
        
        # Configuration summary panel
        plt.subplot(2, 3, 6)
        plt.text(0.5, 0.5, f"Algorithm: {self.config.get('algorithm', 'Unknown')}\n"
                            f"Environment: {self.config.get('environment', 'Unknown')}\n"
                            f"Agents: {len(self.agents)}\n"
                            f"Episodes: {len(self.training_history)}",
                 horizontalalignment='center', verticalalignment='center',
                 transform=plt.gca().transAxes, fontsize=12,
                 bbox=dict(boxstyle='round', facecolor='lightblue'))
        plt.title('Configuration Summary')
        plt.axis('off')
        
        plt.tight_layout()
        plt.show()

# Demonstration of comprehensive training
def run_comprehensive_demo():
    """Run comprehensive multi-agent RL demonstration."""
    print("🌟 Comprehensive Multi-Agent RL Training Demo")
    
    # Configuration for different scenarios
    configs = [
        {
            'name': 'MADDPG Resource Allocation',
            'algorithm': 'MADDPG',
            'environment': 'resource_allocation',
            'n_agents': 3,
            'n_resources': 2,
            'obs_dim': 7,  # own_allocation + remaining + total_allocated
            'action_dim': 2,  # allocation for each resource
            'n_episodes': 200,
            'eval_freq': 50,
            'lr_actor': 1e-3,
            'lr_critic': 1e-3
        },
        {
            'name': 'PPO Autonomous Vehicles',
            'algorithm': 'PPO',
            'environment': 'autonomous_vehicles',
            'n_agents': 3,
            'road_length': 100,
            'obs_dim': 6,  # position, velocity, target_velocity, relative_pos, relative_vel, distance
            'action_dim': 3,  # discrete acceleration actions
            'n_episodes': 300,
            'eval_freq': 75,
            'lr': 3e-4,
            'enable_communication': True,
            'message_dim': 8
        }
    ]
    
    results = {}
    
    for config in configs:
        print(f"\n🎯 Running: {config['name']}")
        print("=" * 50)
        
        # Create and run orchestrator
        orchestrator = MultiAgentTrainingOrchestrator(config)
        training_results = orchestrator.run_training()
        
        # Store results
        results[config['name']] = {
            'config': config,
            'results': training_results,
            'orchestrator': orchestrator
        }
        
        # Visualize results
        orchestrator.visualize_results()
        
        print(f"✅ Completed: {config['name']}")
        
        # Print final performance summary
        if training_results['evaluation_results']:
            final_eval = training_results['final_evaluation']
            print("📊 Final Performance Summary:")
            print(f"   Social Welfare: {final_eval['mean_social_welfare']:.2f} ± {final_eval['std_social_welfare']:.2f}")
            print(f"   Individual Rewards: {final_eval['mean_individual_rewards'].round(2)}")
            print(f"   Cooperation Score: {final_eval['mean_cooperation_score']:.3f}")
    
    return results

# Run the comprehensive demonstration
print("🚀 Starting Comprehensive Multi-Agent RL Demonstration")
print("run_comprehensive_demo() trains and evaluates multiple algorithms on different environments...")

# Note: a full training run takes significant time, so this cell only prints the
# demo structure. To launch training, uncomment the next line:
# results = run_comprehensive_demo()
print("📋 Demo Structure:")
print("1. MADDPG on Resource Allocation")
print("2. PPO on Autonomous Vehicle Coordination")
print("3. Comprehensive evaluation and visualization")
print("\n⚠️  Full training takes significant time; call run_comprehensive_demo() to run it")

print("\n🎉 Comprehensive Multi-Agent RL Framework Complete!")
print("✅ Training orchestrator, evaluation framework, and visualization ready!")
print("✅ All advanced multi-agent RL concepts implemented!")
print("\n📚 Notebook Summary:")
print("• Multi-Agent Foundations & Game Theory")
print("• Cooperative Learning (MADDPG, VDN)")
print("• Advanced Policy Methods (PPO, SAC)")
print("• Distributed RL (A3C, IMPALA)")
print("• Communication & Coordination")
print("• Meta-Learning & Adaptation")  
print("• Comprehensive Applications & Case Studies")
print("• Complete Training & Evaluation Framework")
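The orchestrator's cooperation score is a placeholder: it maps the coefficient of variation of per-agent episode returns, CV = std / |mean|, into (0, 1] via 1 / (1 + CV). Equal returns score 1.0 and increasingly unequal returns approach 0. Without the absolute value, a negative mean return makes CV negative and can push the score above 1 (for returns [-1, -3], CV = -0.5 and the score would be 2). The standalone sketch below uses |mean|; the function name `cooperation_score` is hypothetical. Note that it measures how equally reward is shared, not cooperative behavior as such.

```python
import numpy as np


def cooperation_score(episode_rewards, eps=1e-8):
    """Hypothetical helper: map the CV of per-agent returns into (0, 1]."""
    rewards = np.asarray(episode_rewards, dtype=float)
    # |mean| keeps the ratio non-negative even when all agents lose reward.
    cv = np.std(rewards) / (np.abs(np.mean(rewards)) + eps)
    return 1.0 / (1.0 + cv)


print(cooperation_score([5.0, 5.0, 5.0]))  # 1.0 (perfectly equal returns)
print(cooperation_score([1.0, 9.0]))       # unequal returns score lower
print(cooperation_score([-1.0, -3.0]))     # negative returns stay inside (0, 1]
```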

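`run_training` stops early once the standard deviation of the last three evaluation social-welfare means falls below an absolute threshold of 0.1, and it only checks after more than three evaluations have been recorded. The sketch below isolates that rule as a pure function so the window and tolerance can be tested on their own. `has_plateaued`, `window`, `tol`, and `min_evals` are hypothetical names chosen for illustration.

```python
import numpy as np


def has_plateaued(welfare_means, window=3, tol=0.1, min_evals=4):
    """Hypothetical helper: True when the last `window` scores have std < tol.

    min_evals=4 reproduces the orchestrator's `len(...) > 3` guard.
    """
    if len(welfare_means) < max(window, min_evals):
        return False  # not enough evaluations recorded yet
    recent = np.asarray(welfare_means[-window:], dtype=float)
    return bool(np.std(recent) < tol)


print(has_plateaued([1.0, 2.0, 3.0]))         # too few evaluations
print(has_plateaued([0.0, 5.0, 5.0, 5.05]))   # last three nearly constant
print(has_plateaued([0.0, 1.0, 2.0, 3.0]))    # still improving
```

Because the tolerance is absolute, its meaning depends on the reward scale of each environment; a relative tolerance such as `tol * abs(np.mean(recent))` is one alternative, not what the notebook uses.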