<!-- Centered layout with a university logo -->
<div align="center">

  <!-- University Logo -->
  <img src="https://cdn.freebiesupply.com/logos/large/2x/sharif-logo-png-transparent.png" width="180" height="180" style="margin-bottom: 10px;">
  
  <!-- Assignment Title -->
  <h1></h1>
  <h1 style="color:#0F5298; font-size: 40px; font-weight: bold; margin-bottom: 5px;">Deep Reinforcement Learning</h1>
  <h2 style="color:#0F5298; font-size: 32px; font-weight: normal; margin-top: 0px;">Assignment 10 - Multi-Agent Reinforcement Learning</h2>

  <!-- Department and University -->
  <h3 style="color:#696880; font-size: 24px; margin-top: 20px;">Computer Engineering Department</h3>
  <h3 style="color:#696880; font-size: 22px; margin-top: -5px;">Sharif University of Technology</h3>

  <!-- Semester -->
  <h3 style="color:#696880; font-size: 22px; margin-top: 20px;">Spring 2025</h3>

  <!-- Authors -->
  <h3 style="color:green; font-size: 22px; margin-top: 20px;">Full name: [FULL_NAME]</h3>
  <h3 style="color:green; font-size: 22px; margin-top: 20px;">Student ID: [STUDENT_ID]</h3>

  <!-- Horizontal Line for Separation -->
  <hr style="border: 1px solid #0F5298; width: 80%; margin-top: 30px;">

</div>


## Setup & Overview  
In this notebook, we explore Multi-Agent Reinforcement Learning (MARL) through various algorithms and environments.  
We implement and compare several approaches:
- **Independent Q-Learning** (IQL) - Each agent learns independently
- **QMIX** - Value decomposition for cooperative settings
- **MADDPG** - Multi-Agent Deep Deterministic Policy Gradient, an actor-critic method for mixed cooperative-competitive environments
- **Communication Protocols** - CommNet and TarMAC
- **Self-Play** - Training against past versions of the agent itself

We start with classic game-theory environments such as the Prisoner's Dilemma and coordination games, then move on to more complex multi-agent scenarios.

Follow the instructions carefully and complete the sections marked with **TODO**.


## Setup and Environment

In the upcoming cells, we import necessary libraries, set up utility functions for reproducibility and plotting, and define the basic components of our multi-agent experiments.


In [None]:
# %% [code]
import numpy as np
import random
import time
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from collections import deque, defaultdict
import itertools
from dataclasses import dataclass
from copy import deepcopy
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

plt.rcParams['figure.dpi'] = 100
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)


In [None]:
# %% [code]
def plot_logs(df, x_key, y_key, legend_key, **kwargs):
    """Plot learning curves for multi-agent experiments"""
    num = len(df[legend_key].unique())
    pal = sns.color_palette("hls", num)
    if 'palette' not in kwargs:
        kwargs['palette'] = pal
    ax = sns.lineplot(x=x_key, y=y_key, data=df, hue=legend_key, **kwargs)
    return ax

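def plot_game_matrix(matrix, title="Game Matrix", figsize=(8, 6)):
    """Plot a 2-player game whose payoffs have shape [n_actions_1, n_actions_2, 2].

    imshow cannot draw an [n, m, 2] array, so cells are colored by the total
    payoff (Agent 1 + Agent 2) and annotated with "r1, r2".
    """
    fig, ax = plt.subplots(figsize=figsize)
    
    # Create heatmap of the total (social) payoff of each joint action
    im = ax.imshow(matrix.sum(axis=-1), cmap='RdYlBu', aspect='auto')
    
    # Add text annotations with each agent's individual payoff
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            ax.text(j, i, f'{matrix[i, j][0]:.1f}, {matrix[i, j][1]:.1f}',
                    ha="center", va="center", color="black", fontsize=12)
    
    # Set labels
    ax.set_xticks(range(matrix.shape[1]))
    ax.set_yticks(range(matrix.shape[0]))
    ax.set_xticklabels([f'Action {j}' for j in range(matrix.shape[1])])
    ax.set_yticklabels([f'Action {i}' for i in range(matrix.shape[0])])
    ax.set_xlabel('Agent 2 Action')
    ax.set_ylabel('Agent 1 Action')
    ax.set_title(title)
    
    # Add colorbar
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Total Payoff (Agent 1 + Agent 2)')
    
    plt.tight_layout()
    plt.show()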

def set_seed(s):
    """Set random seeds for reproducibility"""
    np.random.seed(s)
    random.seed(s)
    torch.manual_seed(s)

set_seed(42)


## Game Theory Foundations

We start with classic game theory environments to understand multi-agent interactions:

### Prisoner's Dilemma
- **Mutual cooperation (C, C)**: Both agents cooperate and each gets a moderate reward (3, 3)
- **Unilateral defection (D, C) or (C, D)**: The defector gets the highest reward (5) while the cooperator gets nothing (0)
- **Mutual defection (D, D)**: Both agents defect and each gets a low reward (1, 1)

### Coordination Game  
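- **Pure Coordination**: Both agents are rewarded only when they choose the same action, and (A, A) pays more than (B, B)
- **Battle of the Sexes**: The agents prefer different equilibria, but coordinating on either one is still better than miscoordinating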
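The difference between these games shows up in each agent's best response to the other's action. A minimal sketch, assuming the same `payoff[a1, a2] = (r1, r2)` layout as the matrices defined in the next cell (index 0 is Cooperate or A, index 1 is Defect or B); `PD`, `COORD` and `best_response` are illustrative names, not notebook variables:

```python
import numpy as np

# Payoff arrays in the notebook's layout: payoff[a1, a2] = (r1, r2)
PD = np.array([[[3, 3], [0, 5]],
               [[5, 0], [1, 1]]])
COORD = np.array([[[2, 2], [0, 0]],
                  [[0, 0], [1, 1]]])

def best_response(payoff, player, other_action):
    """Action maximizing `player`'s payoff while the other agent's action is held fixed."""
    if player == 0:
        return int(np.argmax(payoff[:, other_action, 0]))  # Agent 1 picks the row
    return int(np.argmax(payoff[other_action, :, 1]))      # Agent 2 picks the column

for name, game in [("Prisoner's Dilemma", PD), ("Coordination", COORD)]:
    responses = [best_response(game, 0, a2) for a2 in range(2)]
    print(f"{name}: Agent 1's best response to Agent 2 playing 0 / 1 -> {responses}")
```

In the Prisoner's Dilemma, Agent 1's best response is Defect (index 1) whatever Agent 2 does, while in the Coordination Game it simply copies Agent 2's action.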


In [None]:
# %% [code]
# Define game matrices
PRISONERS_DILEMMA = np.array([
    [[3, 3], [0, 5]],  # Agent 1: C, Agent 2: C,D
    [[5, 0], [1, 1]]   # Agent 1: D, Agent 2: C,D
])

COORDINATION_GAME = np.array([
    [[2, 2], [0, 0]],   # Agent 1: A, Agent 2: A,B
    [[0, 0], [1, 1]]    # Agent 1: B, Agent 2: A,B
])

BATTLE_OF_SEXES = np.array([
    [[2, 1], [0, 0]],   # Agent 1: A, Agent 2: A,B
    [[0, 0], [1, 2]]    # Agent 1: B, Agent 2: A,B
])

print("Prisoner's Dilemma Matrix:")
print("Agent 1\\Agent 2 | Cooperate | Defect")
print("Cooperate        |   3, 3    |  0, 5")
print("Defect           |   5, 0    |  1, 1")
print("\nNash Equilibrium: (Defect, Defect)")
print("Social optimum: (Cooperate, Cooperate), which Pareto-dominates (Defect, Defect)")

plot_game_matrix(PRISONERS_DILEMMA, "Prisoner's Dilemma")


**Q:** Why is the Nash Equilibrium (Defect, Defect) suboptimal in the Prisoner's Dilemma?

**A:** The Nash Equilibrium (Defect, Defect) is suboptimal because each agent plays its individually rational strategy, yet the resulting outcome is worse for both than mutual cooperation. Defect strictly dominates Cooperate: it pays more whatever the other agent does (5 > 3 if the other cooperates, 1 > 0 if the other defects), so a rational agent defects both to exploit a cooperator and to avoid being exploited. However, if both could commit to cooperation, they would both be better off ((3, 3) instead of (1, 1)); that is, (Cooperate, Cooperate) Pareto-dominates the equilibrium. This illustrates the fundamental tension between individual rationality and collective welfare in social dilemmas.
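Both claims in this answer, that Defect is individually rational and that mutual cooperation leaves both agents better off, can be checked directly on the payoffs. A short sketch using an illustrative copy of the Prisoner's Dilemma array (`PD`, `C` and `D` are names chosen here, not notebook variables):

```python
import numpy as np

# Prisoner's Dilemma payoffs, payoff[a1, a2] = (r1, r2)
PD = np.array([[[3, 3], [0, 5]],
               [[5, 0], [1, 1]]])
C, D = 0, 1   # action indices: Cooperate, Defect

# Defect strictly dominates Cooperate: it pays more against every opponent action
defect_dominates_1 = all(PD[D, a2, 0] > PD[C, a2, 0] for a2 in (C, D))  # Agent 1
defect_dominates_2 = all(PD[a1, D, 1] > PD[a1, C, 1] for a1 in (C, D))  # Agent 2

# (C, C) Pareto-dominates the equilibrium (D, D): both agents are strictly better off
cc_beats_dd = bool(np.all(PD[C, C] > PD[D, D]))

print(defect_dominates_1, defect_dominates_2, cc_beats_dd)
```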


In [None]:
# %% [code]
plot_game_matrix(COORDINATION_GAME, "Coordination Game")
print("\nCoordination Game:")
print("Agent 1\\Agent 2 | Action A | Action B")
print("Action A         |   2, 2   |  0, 0")
print("Action B         |   0, 0   |  1, 1")
print("\nPure-strategy Nash Equilibria: (A,A) and (B,B)")
print("Pareto Optimal: (A,A)")

plot_game_matrix(BATTLE_OF_SEXES, "Battle of Sexes")
print("\nBattle of Sexes:")
print("Agent 1\\Agent 2 | Action A | Action B")
print("Action A         |   2, 1   |  0, 0")
print("Action B         |   0, 0   |  1, 2")
print("\nPure-strategy Nash Equilibria: (A,A) and (B,B)")
print("Agent 1 prefers (A,A), Agent 2 prefers (B,B)")


## Multi-Agent Environment

We'll create a simple environment that accepts any 2-player payoff matrix. Games with more than two agents are not implemented yet (`step` raises `NotImplementedError`).


In [None]:
# %% [code]
@dataclass
class MultiAgentGame:
    """Multi-agent game environment"""
    payoff_matrix: np.ndarray
    num_agents: int = 2
    num_actions: int = 2
    
    def step(self, actions):
        """Execute actions and return rewards"""
        if self.num_agents == 2:
            return self.payoff_matrix[actions[0], actions[1]]
        else:
            # For more than 2 agents, we'll use a different structure
            raise NotImplementedError("Only 2-agent games implemented")
    
    def get_optimal_strategies(self):
        """Find pure-strategy Nash equilibria (mixed-strategy equilibria are not considered)"""
        equilibria = []
        
        # Check all pure strategy combinations
        for a1 in range(self.num_actions):
            for a2 in range(self.num_actions):
                is_equilibrium = True
                
                # Check if agent 1 wants to deviate
                for a1_dev in range(self.num_actions):
                    if a1_dev != a1:
                        if self.payoff_matrix[a1_dev, a2][0] > self.payoff_matrix[a1, a2][0]:
                            is_equilibrium = False
                            break
                
                # Check if agent 2 wants to deviate
                if is_equilibrium:
                    for a2_dev in range(self.num_actions):
                        if a2_dev != a2:
                            if self.payoff_matrix[a1, a2_dev][1] > self.payoff_matrix[a1, a2][1]:
                                is_equilibrium = False
                                break
                
                if is_equilibrium:
                    equilibria.append((a1, a2))
        
        return equilibria

# Test the environment
pd_env = MultiAgentGame(PRISONERS_DILEMMA)
print("Prisoner's Dilemma Nash Equilibria:", pd_env.get_optimal_strategies())

coord_env = MultiAgentGame(COORDINATION_GAME)
print("Coordination Game Nash Equilibria:", coord_env.get_optimal_strategies())


## Independent Q-Learning (IQL)

The simplest approach to multi-agent RL: each agent learns independently, treating other agents as part of the environment.

**Key Characteristics:**
- Each agent maintains its own Q-table
- No communication between agents
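- Non-stationarity: each agent's environment changes as the other agents learn
- No general convergence guarantees, since the convergence results for single-agent Q-learning assume a stationary environment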
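Non-stationarity is easy to see in the Prisoner's Dilemma: from one agent's point of view the "environment" includes the other agent's policy, so the quantity its Q-table estimates moves whenever that policy changes. A small sketch, assuming Agent 2 plays a fixed mixed strategy that cooperates with probability `p_cooperate` (`PD` restates the `PRISONERS_DILEMMA` payoffs; the function name is illustrative):

```python
import numpy as np

# Prisoner's Dilemma payoffs, payoff[a1, a2] = (r1, r2); 0 = Cooperate, 1 = Defect
PD = np.array([[[3, 3], [0, 5]],
               [[5, 0], [1, 1]]])

def expected_rewards_agent1(p_cooperate):
    """Agent 1's expected reward for each of its actions when Agent 2
    cooperates with probability p_cooperate."""
    opponent_policy = np.array([p_cooperate, 1.0 - p_cooperate])
    return PD[:, :, 0] @ opponent_policy   # [E[r | Cooperate], E[r | Defect]]

# As Agent 2 drifts from cooperating to defecting, Agent 1's targets shift
for p in (1.0, 0.5, 0.0):
    print(f"P(Agent 2 cooperates) = {p}: {expected_rewards_agent1(p)}")
```

The value that Agent 1's Q-estimate for Cooperate tracks falls from 3 to 0 as Agent 2 moves from always cooperating to always defecting, even though Agent 1's own behavior has not changed.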


In [None]:
# %% [code]
class IndependentQLearning:
    """Independent Q-Learning agent"""
    
    def __init__(self, num_actions, learning_rate=0.1, epsilon=0.1, gamma=0.9):
        self.num_actions = num_actions
        self.lr = learning_rate
        self.epsilon = epsilon
        self.gamma = gamma
        self.Q = np.zeros(num_actions)
        self.action_counts = np.zeros(num_actions)
        
    def get_action(self):
        """Epsilon-greedy action selection"""
        if np.random.random() < self.epsilon:
            return np.random.randint(self.num_actions)
        else:
            return np.argmax(self.Q)
    
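    def update(self, action, reward):
        """Update Q-values.

        The repeated matrix game is stateless (there is no next state to bootstrap
        from), so this is a bandit-style running average and `gamma` is unused.
        """
        self.action_counts[action] += 1
        self.Q[action] += self.lr * (reward - self.Q[action])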
    
    def reset(self):
        """Reset agent state"""
        self.Q = np.zeros(self.num_actions)
        self.action_counts = np.zeros(self.num_actions)

class MultiAgentExperiment:
    """Run multi-agent experiments"""
    
    def __init__(self, env, agents, num_episodes=1000):
        self.env = env
        self.agents = agents
        self.num_episodes = num_episodes
        self.logs = []
        
    def run(self):
        """Run the experiment"""
        for episode in tqdm(range(self.num_episodes), desc="Training"):
            # Get actions from all agents
            actions = [agent.get_action() for agent in self.agents]
            
            # Execute actions and get rewards
            rewards = self.env.step(actions)
            
            # Update all agents
            for i, agent in enumerate(self.agents):
                agent.update(actions[i], rewards[i])
            
            # Log episode data
            self.logs.append({
                'episode': episode,
                'actions': actions.copy(),
                'rewards': rewards.copy(),
                'agent1_action': actions[0],
                'agent2_action': actions[1],
                'agent1_reward': rewards[0],
                'agent2_reward': rewards[1]
            })
        
        return pd.DataFrame(self.logs)


In [None]:
# %% [code]
# Test Independent Q-Learning on Prisoner's Dilemma
print("Testing Independent Q-Learning on Prisoner's Dilemma...")

# Create two IQL agents
agent1 = IndependentQLearning(num_actions=2, learning_rate=0.1, epsilon=0.1)
agent2 = IndependentQLearning(num_actions=2, learning_rate=0.1, epsilon=0.1)

# Run experiment
experiment = MultiAgentExperiment(pd_env, [agent1, agent2], num_episodes=2000)
logs = experiment.run()

# Analyze results
print(f"\nFinal Q-values:")
print(f"Agent 1 Q-values: {agent1.Q}")
print(f"Agent 2 Q-values: {agent2.Q}")

print(f"\nAction counts:")
print(f"Agent 1 action counts: {agent1.action_counts}")
print(f"Agent 2 action counts: {agent2.action_counts}")

# Plot learning curves
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Plot rewards over time
axes[0].plot(logs['episode'], logs['agent1_reward'], alpha=0.3, label='Agent 1', color='blue')
axes[0].plot(logs['episode'], logs['agent2_reward'], alpha=0.3, label='Agent 2', color='red')

# Plot moving averages
window = 100
logs['agent1_ma'] = logs['agent1_reward'].rolling(window=window).mean()
logs['agent2_ma'] = logs['agent2_reward'].rolling(window=window).mean()

axes[0].plot(logs['episode'], logs['agent1_ma'], label='Agent 1 (MA)', color='blue', linewidth=2)
axes[0].plot(logs['episode'], logs['agent2_ma'], label='Agent 2 (MA)', color='red', linewidth=2)
axes[0].set_xlabel('Episode')
axes[0].set_ylabel('Reward')
axes[0].set_title('Reward Learning Curves')
axes[0].legend()
axes[0].grid(True)

# Plot action selection over time
axes[1].plot(logs['episode'], logs['agent1_action'], alpha=0.3, label='Agent 1', color='blue')
axes[1].plot(logs['episode'], logs['agent2_action'], alpha=0.3, label='Agent 2', color='red')
axes[1].set_xlabel('Episode')
axes[1].set_ylabel('Action')
axes[1].set_title('Action Selection Over Time')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.show()


**Q:** What do you observe about the convergence behavior of Independent Q-Learning in the Prisoner's Dilemma?

**A:** Independent Q-Learning typically ends up at the Nash Equilibrium (Defect, Defect). Because Defect strictly dominates Cooperate, each agent eventually estimates a higher value for Defect than for Cooperate, whatever the other agent is doing (for example, defecting against a cooperator earns 5 instead of 3). The agents neither coordinate nor communicate, so nothing lets them escape this individually rational but collectively suboptimal outcome. The reward curves are noisy early on because of exploration and then settle near the equilibrium payoff of (1, 1); with `epsilon = 0.1` both agents keep exploring, so the moving averages hover slightly above 1 rather than reaching it exactly.
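As a worked example of where the curves settle, assume both agents have converged to greedy Defect but keep exploring with `epsilon = 0.1`, as in the experiment above. Since exploration picks uniformly among the two actions, each agent defects with probability 1 - ε + ε/2 = 0.95, and the expected per-step reward is a probability-weighted sum over the four joint actions (a sketch; `PD` restates the `PRISONERS_DILEMMA` payoffs):

```python
import numpy as np

# Prisoner's Dilemma payoffs, payoff[a1, a2] = (r1, r2); 0 = Cooperate, 1 = Defect
PD = np.array([[[3, 3], [0, 5]],
               [[5, 0], [1, 1]]])
epsilon, num_actions = 0.1, 2

# Assumption: the greedy action is Defect; exploration also picks Defect half of the time
p_defect = (1 - epsilon) + epsilon / num_actions   # about 0.95
policy = np.array([1 - p_defect, p_defect])        # [P(Cooperate), P(Defect)]

# Expected reward per agent: sum over joint actions of P(a1) * P(a2) * r(a1, a2)
joint_prob = np.outer(policy, policy)               # [2, 2]
expected = (joint_prob[..., None] * PD).sum(axis=(0, 1))
print(expected)   # [Agent 1, Agent 2]
```

This gives 1.1475 per agent: the excess over 1 comes mainly from exploiting the other agent's occasional exploratory cooperation.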


## QMIX: Value Decomposition for Cooperative MARL

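QMIX addresses the credit assignment problem in cooperative multi-agent settings by factorizing the joint action-value Q_tot into per-agent utilities Q_i. A mixing network, whose weights are produced from the global state by hypernetworks, combines the Q_i into Q_tot and is constrained to be monotonic in every Q_i.

**Key Properties:**
- **Monotonicity**: ∂Q_tot/∂Q_i ≥ 0 for all agents i, enforced by taking the absolute value of the mixing weights
- **Decentralized Execution**: Each agent acts greedily on its own Q_i; monotonicity guarantees that these individual argmax actions together maximize Q_tot
- **Centralized Training**: The mixer uses global state information, which is needed only during training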
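The link between monotonicity and decentralized execution can be checked numerically by enumerating every joint action and comparing the maximizer of Q_tot with each agent's own argmax. A small numpy sketch, assuming a mixer with the same structure as `QMIXMixer` (absolute-valued weights, ELU hidden layer) but with fixed random weights standing in for the hypernetwork outputs; `monotonic_mix` and the sizes are illustrative:

```python
import itertools
import numpy as np

def elu(x):
    return np.where(x > 0, x, np.exp(x) - 1.0)

def monotonic_mix(q, w1, b1, w2, b2):
    """QMIX-style mixer sketch: abs() keeps the weights non-negative."""
    hidden = elu(q @ np.abs(w1) + b1)        # [hidden_dim]
    return float(hidden @ np.abs(w2) + b2)   # scalar Q_tot

rng = np.random.default_rng(0)
num_agents, num_actions, hidden_dim = 2, 3, 4
agent_q = rng.normal(size=(num_agents, num_actions))  # per-agent utilities Q_i(a_i)
w1 = rng.normal(size=(num_agents, hidden_dim))
b1 = rng.normal(size=hidden_dim)
w2 = rng.normal(size=hidden_dim)
b2 = rng.normal()

def q_tot(joint):
    # Pick each agent's Q-value for its action in `joint`, then mix
    return monotonic_mix(agent_q[np.arange(num_agents), list(joint)], w1, b1, w2, b2)

# Centralized: search all joint actions for the one maximizing Q_tot
joint_best = max(itertools.product(range(num_actions), repeat=num_agents), key=q_tot)
# Decentralized: each agent greedily maximizes its own Q_i
greedy = tuple(int(a) for a in agent_q.argmax(axis=1))
print(joint_best, greedy)
```

Dropping the `np.abs` calls removes the guarantee, and for some random weights the two tuples then differ.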


In [None]:
# %% [code]
class QMIXAgent(nn.Module):
    """Individual Q-network for QMIX"""
    
    def __init__(self, obs_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
    
    def forward(self, obs):
        return self.network(obs)

class QMIXMixer(nn.Module):
    """QMIX mixing network ensuring monotonicity"""
    
    def __init__(self, num_agents, state_dim, hidden_dim=32):
        super().__init__()
        self.num_agents = num_agents
        
        # Hypernetworks for mixing weights
        self.hyper_w1 = nn.Linear(state_dim, num_agents * hidden_dim)
        self.hyper_w2 = nn.Linear(state_dim, hidden_dim)
        self.hyper_b1 = nn.Linear(state_dim, hidden_dim)
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, q_vals, state):
        """
        Mix individual Q-values into joint Q-value
        Args:
            q_vals: [batch_size, num_agents] individual Q-values
            state: [batch_size, state_dim] global state
        """
        batch_size = q_vals.size(0)
        
        # Generate mixing weights (ensure monotonicity with abs)
        w1 = torch.abs(self.hyper_w1(state))
        b1 = self.hyper_b1(state)
        w2 = torch.abs(self.hyper_w2(state))
        b2 = self.hyper_b2(state)
        
        # Reshape weights
        w1 = w1.view(batch_size, self.num_agents, -1)
        w2 = w2.view(batch_size, -1, 1)
        
        # Mixing computation
        hidden = F.elu(torch.bmm(q_vals.unsqueeze(1), w1) + b1.unsqueeze(1))
        q_tot = torch.bmm(hidden, w2) + b2.unsqueeze(1)
        
        return q_tot.reshape(batch_size)  # [batch_size], matching reward/done batches of shape [batch_size]

class QMIX:
    """QMIX algorithm implementation"""
    
    def __init__(self, num_agents, obs_dim, action_dim, state_dim, lr=0.0005):
        self.num_agents = num_agents
        self.agents = nn.ModuleList([
            QMIXAgent(obs_dim, action_dim) for _ in range(num_agents)
        ])
        self.mixer = QMIXMixer(num_agents, state_dim)
        
        # Target networks
        self.target_agents = nn.ModuleList([
            QMIXAgent(obs_dim, action_dim) for _ in range(num_agents)
        ])
        self.target_mixer = QMIXMixer(num_agents, state_dim)
        
        # Copy parameters to target networks
        self.update_target_networks()
        
        # Optimizers
        self.optimizer = optim.Adam(
            list(self.agents.parameters()) + list(self.mixer.parameters()),
            lr=lr
        )
        
        self.gamma = 0.99
        self.tau = 0.005  # Soft update parameter
    
    def update_target_networks(self):
        """Copy parameters to target networks"""
        for i in range(self.num_agents):
            self.target_agents[i].load_state_dict(self.agents[i].state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
    
    def soft_update_target_networks(self):
        """Soft update target networks"""
        for i in range(self.num_agents):
            for param, target_param in zip(self.agents[i].parameters(), 
                                        self.target_agents[i].parameters()):
                target_param.data.copy_(self.tau * param.data + 
                                      (1 - self.tau) * target_param.data)
        
        for param, target_param in zip(self.mixer.parameters(), 
                                     self.target_mixer.parameters()):
            target_param.data.copy_(self.tau * param.data + 
                                  (1 - self.tau) * target_param.data)
    
    def get_actions(self, observations, epsilon=0.0):
        """Get actions for all agents"""
        actions = []
        for i, obs in enumerate(observations):
            if np.random.random() < epsilon:
                actions.append(np.random.randint(self.agents[i].network[-1].out_features))
            else:
                with torch.no_grad():
                    q_vals = self.agents[i](obs)
                    actions.append(q_vals.argmax().item())
        return actions
    
    def update(self, batch):
        """Update QMIX networks"""
        obs_batch, action_batch, reward_batch, next_obs_batch, state_batch, next_state_batch, done_batch = batch
        
        # Current Q-values
        current_q_vals = []
        for i in range(self.num_agents):
            q_vals = self.agents[i](obs_batch[i])
            current_q_vals.append(q_vals.gather(1, action_batch[i].unsqueeze(1)).squeeze(1))
        
        current_q_vals = torch.stack(current_q_vals, dim=1)
        current_q_tot = self.mixer(current_q_vals, state_batch)
        
        # Target Q-values
        with torch.no_grad():
            next_q_vals = []
            for i in range(self.num_agents):
                next_q_vals.append(self.target_agents[i](next_obs_batch[i]).max(1)[0])
            
            next_q_vals = torch.stack(next_q_vals, dim=1)
            next_q_tot = self.target_mixer(next_q_vals, next_state_batch)
            target_q_tot = reward_batch + self.gamma * next_q_tot * (1 - done_batch)
        
        # Compute loss
        loss = F.mse_loss(current_q_tot, target_q_tot)
        
        # Update networks
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Soft update target networks
        self.soft_update_target_networks()
        
        return loss.item()


## MADDPG: Multi-Agent Deep Deterministic Policy Gradient

MADDPG extends DDPG to multi-agent settings using centralized training with decentralized execution (CTDE).

**Key Features:**
- **Centralized Critics**: Each agent's critic sees all observations and actions
- **Decentralized Actors**: Each agent's actor only sees its own observations
- **Handles Mixed Environments**: Works for both cooperative and competitive settings
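- **Addresses Non-stationarity**: Because each critic is conditioned on every agent's action, its environment stays stationary during training even as the other agents' policies change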
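The centralized critic's TD target for agent i is y_i = r_i + γ Q_i'(o_1', ..., o_N', a_1', ..., a_N') (1 - done), where each a_j' comes from agent j's target actor applied to its own observation only. A numpy sketch of how this target is assembled, using hypothetical linear actors and a linear critic in place of the networks in the next cell; the point is the input layout (all observations, then all actions) and keeping the target at shape [batch, 1]:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_agents, obs_dim, act_dim = 5, 3, 4, 2
gamma = 0.99

next_obs = [rng.normal(size=(batch, obs_dim)) for _ in range(num_agents)]
# Hypothetical linear target actors: a_j' = tanh(o_j' @ W_j), each sees only its own observation
target_actor_w = [rng.normal(size=(obs_dim, act_dim)) for _ in range(num_agents)]
next_actions = [np.tanh(o @ w) for o, w in zip(next_obs, target_actor_w)]

# Centralized critic input: all observations followed by all actions
critic_in = np.concatenate(next_obs + next_actions, axis=1)  # [batch, N * (obs_dim + act_dim)]
critic_w = rng.normal(size=(critic_in.shape[1], 1))          # hypothetical linear critic
next_q = critic_in @ critic_w                                # [batch, 1]

reward_i = rng.normal(size=batch)   # agent i's rewards, [batch]
done = np.zeros(batch)
# Reshape to [batch, 1]; adding a [batch] array to next_q would broadcast to [batch, batch]
target = reward_i.reshape(-1, 1) + gamma * next_q * (1 - done.reshape(-1, 1))
print(critic_in.shape, target.shape)
```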


In [None]:
# %% [code]
class Actor(nn.Module):
    """Actor network for MADDPG"""
    
    def __init__(self, obs_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh()  # Actions in [-1, 1]
        )
    
    def forward(self, obs):
        return self.network(obs)

class Critic(nn.Module):
    """Centralized critic network for MADDPG"""
    
    def __init__(self, obs_dim, action_dim, num_agents, hidden_dim=64):
        super().__init__()
        self.num_agents = num_agents
        total_obs_dim = obs_dim * num_agents
        total_action_dim = action_dim * num_agents
        
        self.network = nn.Sequential(
            nn.Linear(total_obs_dim + total_action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, obs, actions):
        """
        Args:
            obs: [batch_size, num_agents * obs_dim] concatenated observations
            actions: [batch_size, num_agents * action_dim] concatenated actions
        """
        x = torch.cat([obs, actions], dim=-1)
        return self.network(x)

class MADDPGAgent:
    """Individual MADDPG agent"""
    
    def __init__(self, agent_id, obs_dim, action_dim, num_agents, lr=0.001, tau=0.005):
        self.agent_id = agent_id
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.num_agents = num_agents
        
        # Networks
        self.actor = Actor(obs_dim, action_dim)
        self.critic = Critic(obs_dim, action_dim, num_agents)
        self.target_actor = Actor(obs_dim, action_dim)
        self.target_critic = Critic(obs_dim, action_dim, num_agents)
        
        # Copy parameters to target networks
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.target_critic.load_state_dict(self.critic.state_dict())
        
        # Optimizers
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
        
        self.tau = tau
        self.gamma = 0.99
    
    def act(self, obs, noise=0.0):
        """Get action with optional noise"""
        with torch.no_grad():
            action = self.actor(obs)
            if noise > 0:
                action += torch.randn_like(action) * noise
            return torch.clamp(action, -1, 1)
    
    def update(self, batch, other_agents):
        """Update actor and critic networks"""
        obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = batch
        
        # Prepare centralized inputs
        all_obs = torch.cat([obs_batch[i] for i in range(self.num_agents)], dim=1)
        all_actions = torch.cat([action_batch[i] for i in range(self.num_agents)], dim=1)
        all_next_obs = torch.cat([next_obs_batch[i] for i in range(self.num_agents)], dim=1)
        
        # Current Q-value
        current_q = self.critic(all_obs, all_actions)
        
        # Target Q-value
        with torch.no_grad():
            next_actions = []
            for i in range(self.num_agents):
                if i == self.agent_id:
                    next_actions.append(self.target_actor(next_obs_batch[i]))
                else:
                    next_actions.append(other_agents[i].target_actor(next_obs_batch[i]))
            
            all_next_actions = torch.cat(next_actions, dim=1)
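            # Reshape reward/done to [batch_size, 1] to match the critic output; a
            # [batch_size] tensor would otherwise broadcast the target to [batch_size, batch_size]
            target_q = (reward_batch[self.agent_id].reshape(-1, 1)
                        + self.gamma * self.target_critic(all_next_obs, all_next_actions)
                        * (1 - done_batch.reshape(-1, 1)))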
        
        # Critic loss
        critic_loss = F.mse_loss(current_q, target_q)
        
        # Update critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        
        # Actor loss
        current_actions = []
        for i in range(self.num_agents):
            if i == self.agent_id:
                current_actions.append(self.actor(obs_batch[i]))
            else:
                current_actions.append(action_batch[i])
        
        all_current_actions = torch.cat(current_actions, dim=1)
        actor_loss = -self.critic(all_obs, all_current_actions).mean()
        
        # Update actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        # Soft update target networks
        self.soft_update_target_networks()
        
        return critic_loss.item(), actor_loss.item()
    
    def soft_update_target_networks(self):
        """Soft update target networks"""
        for param, target_param in zip(self.actor.parameters(), self.target_actor.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        
        for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

class MADDPG:
    """MADDPG algorithm coordinator"""
    
    def __init__(self, num_agents, obs_dim, action_dim, lr=0.001, tau=0.005):
        self.num_agents = num_agents
        self.agents = [
            MADDPGAgent(i, obs_dim, action_dim, num_agents, lr, tau)
            for i in range(num_agents)
        ]
    
    def act(self, observations, noise=0.0):
        """Get actions for all agents"""
        actions = []
        for i, obs in enumerate(observations):
            action = self.agents[i].act(obs, noise)
            actions.append(action)
        return actions
    
    def update(self, batch):
        """Update all agents"""
        losses = []
        for i in range(self.num_agents):
            # Pass the full agent list: MADDPGAgent.update indexes it by global agent id
            other_agents = self.agents
            critic_loss, actor_loss = self.agents[i].update(batch, other_agents)
            losses.append((critic_loss, actor_loss))
        return losses


## Communication Protocols

Communication is crucial for coordination in multi-agent systems. We implement two key approaches:

### CommNet: Emergent Communication
- Agents learn to communicate through continuous messages
- No predefined communication protocol
- Communication emerges from the task requirements
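In the original CommNet formulation, each agent j receives the average of the other agents' hidden states, c_j = (1/(J-1)) Σ_{k≠j} h_k, and updates h_j ← f(h_j, c_j) with a module shared by all agents. A minimal numpy sketch of one such step, assuming a single tanh layer for f; `commnet_step`, `H` and `C` are illustrative names with random weights:

```python
import numpy as np

def commnet_step(h, H, C):
    """One CommNet communication step for J agents (illustrative sketch).

    h: [J, d] hidden states, one row per agent.
    Agent j receives c_j, the mean of the OTHER agents' hidden states, and
    updates h_j <- tanh(h_j @ H + c_j @ C), with H and C shared by all agents.
    """
    J = h.shape[0]
    c = (h.sum(axis=0, keepdims=True) - h) / (J - 1)   # leave-one-out mean, [J, d]
    return np.tanh(h @ H + c @ C), c

rng = np.random.default_rng(0)
J, d = 4, 8
h = rng.normal(size=(J, d))
H = rng.normal(size=(d, d))   # hypothetical random weights
C = rng.normal(size=(d, d))
h_next, c = commnet_step(h, H, C)
print(h_next.shape)
```

The notebook's `CommNetAgent` simplifies this by averaging over all agents including the agent itself and by using a GRU for f.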

### TarMAC: Targeted Multi-Agent Communication
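- Attention-based communication: each sender broadcasts a key (signature) along with its message, and each receiver scores those keys against its own query
- Agents can selectively attend to the messages that are most relevant to them
- Generalizes CommNet's uniform averaging: with equal attention weights, it reduces to the mean of the incoming messages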
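Targeted communication of this kind can be sketched as scaled dot-product attention: each sender emits a key and a value, and a receiver weights the values by softmax(q·k / √d_k) computed from its own query. The numpy sketch below uses illustrative names and random vectors; it also shows the sense in which this generalizes averaging, since identical keys give uniform weights and the received message becomes the plain mean of the values:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def targeted_read(query, keys, values):
    """Receiver attends over senders' messages with scaled dot-product attention.

    query: [d_k], keys: [num_senders, d_k], values: [num_senders, d_v]
    """
    scores = keys @ query / np.sqrt(query.shape[-1])   # [num_senders]
    weights = softmax(scores)
    return weights @ values, weights                   # [d_v], [num_senders]

rng = np.random.default_rng(0)
q = rng.normal(size=4)
K = rng.normal(size=(3, 4))
V = rng.normal(size=(3, 6))
message, w = targeted_read(q, K, V)

# With identical keys every sender gets the same weight: CommNet-style averaging
uniform_msg, uniform_w = targeted_read(q, np.tile(K[0], (3, 1)), V)
print(w.round(3), uniform_w.round(3))
```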


In [None]:
# %% [code]
class CommNetAgent(nn.Module):
    """CommNet agent with emergent communication"""
    
    def __init__(self, obs_dim, action_dim, comm_dim=32, hidden_dim=64):
        super().__init__()
        self.comm_dim = comm_dim
        
        # Observation encoder
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, comm_dim)
        )
        
        # Communication GRU
        self.comm_gru = nn.GRU(comm_dim, comm_dim, batch_first=True)
        
        # Action decoder
        self.action_decoder = nn.Sequential(
            nn.Linear(comm_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
    
    def forward(self, obs, comm_hidden=None, num_comm_rounds=2):
        """
        Args:
            obs: [batch_size, obs_dim] observations; the averaging step below
                 treats dim 0 as the agent axis (one row per agent)
            comm_hidden: [batch_size, comm_dim] communication hidden state
            num_comm_rounds: number of communication rounds
        Returns:
            action_logits: [batch_size, action_dim], comm_hidden: [batch_size, comm_dim]
        """
        # Encode observation
        encoded_obs = self.obs_encoder(obs)
        
        if comm_hidden is None:
            comm_hidden = encoded_obs
        
        # Communication rounds
        for _ in range(num_comm_rounds):
            # Average over rows (simplified CommNet: the mean also includes the agent itself)
            comm_input = comm_hidden.mean(dim=0, keepdim=True).expand_as(comm_hidden)
            # GRU over a length-1 sequence: input [batch, 1, comm_dim], h_0 [1, batch, comm_dim].
            # Keep the final hidden state h_n so comm_hidden stays [batch, comm_dim]
            _, h_n = self.comm_gru(comm_input.unsqueeze(1), comm_hidden.unsqueeze(0))
            comm_hidden = h_n.squeeze(0)
        
        # Decode to action
        action_logits = self.action_decoder(comm_hidden)
        return action_logits, comm_hidden

class TarMACAgent(nn.Module):
    """TarMAC agent with attention-based communication"""
    
    def __init__(self, obs_dim, action_dim, comm_dim=32, hidden_dim=64):
        super().__init__()
        self.comm_dim = comm_dim
        
        # Observation encoder
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, comm_dim)
        )
        
        # Attention components
        self.query_net = nn.Linear(comm_dim, comm_dim)
        self.key_net = nn.Linear(comm_dim, comm_dim)
        self.value_net = nn.Linear(comm_dim, comm_dim)
        
        # Action decoder
        self.action_decoder = nn.Sequential(
            nn.Linear(comm_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
    
    def forward(self, obs, other_messages=None):
        """
        Args:
            obs: [batch_size, obs_dim] agent observation
            other_messages: [batch_size, num_others, comm_dim] messages from other agents
        """
        # Encode observation
        encoded_obs = self.obs_encoder(obs)
        
        if other_messages is None:
            # No communication, use only own observation
            attended_message = encoded_obs
        else:
            # Compute attention
            query = self.query_net(encoded_obs)  # [batch_size, comm_dim]
            keys = self.key_net(other_messages)  # [batch_size, num_others, comm_dim]
            values = self.value_net(other_messages)  # [batch_size, num_others, comm_dim]
            
            # Scaled dot-product attention weights (scaled by sqrt(comm_dim), as in TarMAC)
            scores = torch.bmm(query.unsqueeze(1), keys.transpose(1, 2)) / self.comm_dim ** 0.5  # [batch_size, 1, num_others]
            attention_weights = F.softmax(scores, dim=-1)  # [batch_size, 1, num_others]
            
            # Weighted sum of values
            attended_message = torch.bmm(attention_weights, values).squeeze(1)  # [batch_size, comm_dim]
            
            # Combine with own observation
            attended_message = attended_message + encoded_obs
        
        # Decode to action
        action_logits = self.action_decoder(attended_message)
        return action_logits

class CommunicationExperiment:
    """Experiment framework for communication protocols"""
    
    def __init__(self, env, agents, num_episodes=1000):
        self.env = env
        self.agents = agents
        self.num_episodes = num_episodes
        self.logs = []
    
    def run_commnet_experiment(self):
        """Roll out CommNet agents (this loop only samples actions; no parameters are updated)"""
        print("Running CommNet experiment...")
        
        for episode in tqdm(range(self.num_episodes), desc="CommNet Rollouts"):
            # Initialize communication hidden states
            comm_hidden = None
            
            # Get actions with communication
            actions = []
            for i, agent in enumerate(self.agents):
                obs = torch.randn(1, 2)  # Simplified observation
                action_logits, comm_hidden = agent(obs, comm_hidden)
                action = Categorical(logits=action_logits).sample().item()
                actions.append(action)
            
            # Execute actions
            rewards = self.env.step(actions)
            
            # Log episode
            self.logs.append({
                'episode': episode,
                'actions': actions,
                'rewards': rewards,
                'method': 'CommNet'
            })
        
        return pd.DataFrame(self.logs)
    
    def run_tarmac_experiment(self):
        """Roll out TarMAC agents (this loop only samples actions; no parameters are updated)"""
        print("Running TarMAC experiment...")
        
        for episode in tqdm(range(self.num_episodes), desc="TarMAC Rollouts"):
            # One simplified (random) observation per agent, reused for its message and its action
            observations = [torch.randn(1, 2) for _ in self.agents]
            
            actions = []
            with torch.no_grad():
                # Collect messages from all agents: each has shape [1, comm_dim]
                messages = [agent.obs_encoder(obs) for agent, obs in zip(self.agents, observations)]
                
                # Get actions with attention
                for i, agent in enumerate(self.agents):
                    # Other agents' messages as one batch: [1, num_others, comm_dim]
                    other_messages = torch.cat([m for j, m in enumerate(messages) if j != i], dim=0).unsqueeze(0)
                    
                    action_logits = agent(observations[i], other_messages)
                    action = Categorical(logits=action_logits).sample().item()
                    actions.append(action)
            
            # Execute actions
            rewards = self.env.step(actions)
            
            # Log episode
            self.logs.append({
                'episode': episode,
                'actions': actions,
                'rewards': rewards,
                'method': 'TarMAC'
            })
        
        return pd.DataFrame(self.logs)
