# CPT-Aware Multi-Agent Implementation with PettingZoo

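This notebook implements a CPT-aware (Cumulative Prospect Theory) variant of the DDPG agent in PettingZoo's competitive multi-agent environment `simple_tag_v3`, with one independent learner per agent. In each update step, the sampled rewards are passed through a CPT-style value function before the Q-target is computed, which makes the agents risk-sensitive. This value function uses a power law for gains and a power law scaled by a loss-aversion factor for losses. Only this value function is used; CPT's probability-weighting component is not modeled. The PettingZoo integration, metrics logging, and visualization code are otherwise kept from the original version.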

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from pettingzoo.mpe import simple_tag_v3

# --- CPT Transformation Function ---
# CPT-style value function applied element-wise to rewards: a power law with
# exponent `alpha` on gains, and a power law with exponent `beta` scaled by the
# loss-aversion factor `lambda_` on losses. No probability weighting is applied.
def cpt_transform_tensor(rewards, alpha=0.88, beta=0.88, lambda_=2.25):
    """Return the CPT value of each reward.

    Computes ``max(r, 0)**alpha - lambda_ * max(-r, 0)**beta`` element-wise.
    The input is cast to float first, so integer tensors are accepted. The
    result has the same shape as ``rewards``.
    """
    r = rewards.float()
    gain_part = r.clamp(min=0).pow(alpha)
    # Magnitude of the loss side; zero wherever the reward is non-negative.
    loss_part = r.clamp(max=0).neg().pow(beta) * -lambda_
    return gain_part + loss_part

# --- Actor Network ---
class Actor(nn.Module):
    """Deterministic policy network: a 3-layer MLP mapping states to actions.

    The final Tanh bounds every action component to [-1, 1].
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),  # assumes the environment's actions live in [-1, 1]
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Map ``state`` (last dimension ``state_dim``) to an action tensor."""
        policy_net = self.net
        return policy_net(state)

# --- Critic Network ---
class Critic(nn.Module):
    """Q-network: a 3-layer MLP scoring a (state, action) pair with one scalar."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        widths = [state_dim + action_dim, hidden_dim, hidden_dim]
        modules = []
        # Two hidden Linear+ReLU stages, followed by a linear scalar head.
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        modules.append(nn.Linear(hidden_dim, 1))
        self.net = nn.Sequential(*modules)

    def forward(self, state, action):
        """Concatenate state and action on the last axis and return Q-values."""
        joint = torch.cat((state, action), dim=-1)
        return self.net(joint)

# --- CPT-DDPG Agent ---
class CPTDDPG:
    """DDPG learner whose TD targets are built from CPT-transformed rewards.

    Keeps an actor/critic pair together with target copies that track them via
    Polyak averaging (rate ``tau``). In ``update``, sampled rewards are passed
    through ``cpt_transform_tensor`` before the one-step target is formed. The
    critic therefore learns the discounted sum of *transformed* rewards.

    Args:
        state_dim: Length of the observation vector fed to both networks.
        action_dim: Length of the action vector produced by the actor.
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for the critic.
        gamma: Discount factor used in the TD target.
        tau: Soft-update rate for the target networks.
    """

    def __init__(self, state_dim, action_dim, actor_lr=1e-3, critic_lr=1e-3, gamma=0.99, tau=0.005):
        self.gamma = gamma
        self.tau = tau

        # Actor network and a target copy initialised with identical weights.
        self.actor = Actor(state_dim, action_dim)
        self.actor_target = Actor(state_dim, action_dim)
        self.actor_target.load_state_dict(self.actor.state_dict())

        # Critic network and a target copy initialised with identical weights.
        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

    def select_action(self, state):
        """Return the actor's deterministic action for a single observation.

        Args:
            state: 1-D array-like observation of length ``state_dim``.

        Returns:
            NumPy array of length ``action_dim``. No exploration noise is added
            here; callers that want exploration must add it themselves.
        """
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            action = self.actor(state)
        return action.cpu().numpy()[0]

    def update(self, replay_buffer, batch_size=64):
        """Run one critic step and one actor step, then soft-update the targets.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, reward, next_state, done)`` tensors.
            batch_size: Number of transitions to sample.

        Returns:
            Tuple ``(critic_loss, actor_loss)`` as Python floats.
        """
        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        # The critic outputs shape (batch, 1). Reward/done may arrive as
        # (batch,) or even 0-d. Without reshaping, `reward + ... * target_q`
        # broadcasts to (batch, batch) and the TD target is silently wrong.
        reward = cpt_transform_tensor(reward).reshape(-1, 1)
        done = done.float().reshape(-1, 1)

        with torch.no_grad():
            next_action = self.actor_target(next_state)
            target_q = self.critic_target(next_state, next_action)
            target_q = reward + (1.0 - done) * self.gamma * target_q

        current_q = self.critic(state, action)
        critic_loss = F.mse_loss(current_q, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Deterministic policy gradient: maximise Q(s, actor(s)).
        actor_loss = -self.critic(state, self.actor(state)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self._soft_update(self.critic, self.critic_target)
        self._soft_update(self.actor, self.actor_target)

        return critic_loss.item(), actor_loss.item()

    def _soft_update(self, source, target):
        """Polyak-average ``source`` parameters into ``target`` with rate ``tau``."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.mul_(1.0 - self.tau).add_(param, alpha=self.tau)

# --- Minimal Replay Buffer ---
class ReplayBuffer:
    """Fixed-capacity ring buffer of (state, action, reward, next_state, done).

    Each element is expected to be a tensor; ``sample`` stacks them along a
    new leading batch dimension. Once ``capacity`` transitions are stored, the
    oldest slot is overwritten.

    Args:
        capacity: Maximum number of stored transitions (must be positive).

    Raises:
        ValueError: If ``capacity`` is not positive.
    """

    def __init__(self, capacity):
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # slot that the next add() writes to

    def add(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest once the buffer is full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            # While filling, `position == len(self.buffer)`, so appending is
            # the same as writing at `position`.
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            Tuple ``(state, action, reward, next_state, done)``. State, action
            and next_state are stacked to ``(batch_size, *element_shape)``.
            Reward and done are flattened to 1-D (length ``batch_size`` when
            each stored reward/done holds a single value).

        Raises:
            ValueError: If the buffer is empty.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        indices = np.random.randint(0, len(self.buffer), size=batch_size)
        states, actions, rewards, next_states, dones = zip(*(self.buffer[i] for i in indices))
        return (
            torch.stack(states),
            torch.stack(actions),
            # reshape(-1) instead of squeeze(): squeeze() collapses a batch of
            # one into a 0-d tensor.
            torch.stack(rewards).reshape(-1),
            torch.stack(next_states),
            torch.stack(dones).reshape(-1),
        )

    def size(self):
        """Return the number of stored transitions."""
        return len(self.buffer)

    def __len__(self):
        return len(self.buffer)


## Training Loop with PettingZoo Integration

The cell below runs the training loop on PettingZoo's `simple_tag_v3` environment. Each agent in the environment gets its own CPT-DDPG learner and replay buffer. The CPT transformation is applied inside each learner's update step, to the rewards sampled from its replay buffer. The per-episode totals that are logged and printed are the raw, untransformed environment rewards.

In [None]:
# Initialize the PettingZoo environment.
# `simple_tag_v3` defaults to Discrete action spaces, which have no
# `.shape[0]`; DDPG needs the continuous (Box) variant.
env = simple_tag_v3.env(continuous_actions=True)
env.reset()

# Create one CPT-DDPG learner and one replay buffer per agent.
agents = {}
replay_buffers = {}
for agent_id in env.possible_agents:
    state_dim = env.observation_space(agent_id).shape[0]
    action_dim = env.action_space(agent_id).shape[0]
    agents[agent_id] = CPTDDPG(state_dim, action_dim)
    replay_buffers[agent_id] = ReplayBuffer(capacity=100000)

num_episodes = 50
min_buffer_size = 64
batch_size = 64
# Std-dev of Gaussian noise added to the actor's [-1, 1] output. A
# deterministic DDPG policy does not explore on its own.
exploration_noise = 0.1

# Log episode rewards per agent. `possible_agents` is used instead of
# `env.agents` because the latter is empty once every agent is done.
episode_rewards = {agent_id: [] for agent_id in env.possible_agents}


def _to_env_action(raw_action, space):
    """Affinely map an actor action in [-1, 1] onto the bounds of Box ``space``."""
    scaled = space.low + (raw_action + 1.0) * 0.5 * (space.high - space.low)
    return np.clip(scaled, space.low, space.high).astype(space.dtype)


for episode in range(num_episodes):
    env.reset()
    total_rewards = {agent_id: 0.0 for agent_id in env.possible_agents}
    # Observation and raw (actor-space) action from each agent's previous
    # turn. In the AEC API, env.last() reports the reward for an action on
    # that agent's *next* turn, so transitions are completed there.
    prev_obs = {}
    prev_action = {}

    for agent_id in env.agent_iter():
        obs, reward, termination, truncation, info = env.last()
        # Copy so stored transitions never alias arrays owned by the env.
        obs = np.array(obs, dtype=np.float32)
        total_rewards[agent_id] += float(reward)

        if agent_id in prev_obs:
            replay_buffers[agent_id].add(
                torch.tensor(prev_obs[agent_id]),
                torch.tensor(prev_action[agent_id]),
                torch.tensor([float(reward)], dtype=torch.float32),
                torch.tensor(obs),
                # Only a real termination stops bootstrapping; a time-limit
                # truncation still has a future value.
                torch.tensor([float(termination)], dtype=torch.float32),
            )
            if replay_buffers[agent_id].size() >= min_buffer_size:
                # update() returns (critic_loss, actor_loss) if logging is wanted.
                agents[agent_id].update(replay_buffers[agent_id], batch_size=batch_size)

        if termination or truncation:
            # Finished agents must be stepped with None until they are removed.
            env.step(None)
            continue

        raw_action = agents[agent_id].select_action(obs)
        noise = exploration_noise * np.random.randn(*raw_action.shape)
        raw_action = np.clip(raw_action + noise, -1.0, 1.0).astype(np.float32)
        # Store the actor-space action so the critic is trained on the same
        # parameterisation the actor produces.
        prev_obs[agent_id] = obs
        prev_action[agent_id] = raw_action
        env.step(_to_env_action(raw_action, env.action_space(agent_id)))

    # Log and print the rewards for this episode
    for agent_id in env.possible_agents:
        episode_rewards[agent_id].append(total_rewards[agent_id])
        print(f"Episode {episode} - Agent {agent_id}: Total Reward = {total_rewards[agent_id]:.2f}")

env.close()


## Visualizations

The cell below plots each agent's total (raw, untransformed) reward per episode.

In [None]:
# One line per agent: total reward collected in each episode.
fig, ax = plt.subplots(figsize=(12, 6))
for agent_id, rewards in episode_rewards.items():
    ax.plot(rewards, label=f'Agent {agent_id}')
ax.set_xlabel('Episode')
ax.set_ylabel('Total Reward')
ax.set_title('Episode Rewards per Agent')
ax.legend()
plt.show()
