In [None]:
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from pettingzoo.mpe import simple_tag_v2
import supersuit as ss

In [None]:
import torch
import torch.nn as nn

class Actor(nn.Module):
  """Policy network mapping an observation to a distribution over discrete actions.

  Args:
    state_size: Length of the observation vector fed to the first layer.
    action_size: Number of discrete actions (size of the output layer).
  """

  def __init__(self, state_size, action_size):
    super(Actor, self).__init__()
    self.fc1 = nn.Linear(state_size, 64)
    self.fc2 = nn.Linear(64, 64)
    # Output layer: one logit per discrete action.
    self.fc3 = nn.Linear(64, action_size)

  def forward(self, state):
    """Return action probabilities for `state`.

    Args:
      state: Observation(s) as a tensor, NumPy array or nested sequence whose
        last dimension is `state_size`.

    Returns:
      Tensor with the same leading shape as `state` and a last dimension of
      `action_size`; entries along the last dimension sum to 1.
    """
    # torch.tensor() would copy (and warn on) tensor inputs and would keep
    # NumPy's default float64, which nn.Linear rejects. as_tensor with the
    # layer's dtype/device handles tensors, arrays and lists uniformly.
    weight = self.fc1.weight
    state = torch.as_tensor(state, dtype=weight.dtype, device=weight.device)
    x = torch.relu(self.fc1(state))
    x = torch.relu(self.fc2(x))
    # Softmax over the action dimension turns logits into probabilities.
    action_probs = torch.softmax(self.fc3(x), dim=-1)
    return action_probs


class Critic(nn.Module):
  """State-value network mapping an observation/state vector to a scalar value.

  Args:
    state_size: Length of the input vector.
  """

  def __init__(self, state_size):
    super(Critic, self).__init__()
    self.fc1 = nn.Linear(state_size, 64)
    # Single output unit: the estimated value.
    self.fc2 = nn.Linear(64, 1)

  def forward(self, state):
    """Return the value estimate for `state`.

    Args:
      state: Input as a tensor, NumPy array or nested sequence whose last
        dimension is `state_size`.

    Returns:
      Tensor with the same leading shape as `state` and a last dimension of 1.
    """
    # torch.tensor() would copy (and warn on) tensor inputs and would keep
    # NumPy's default float64, which nn.Linear rejects. as_tensor with the
    # layer's dtype/device handles tensors, arrays and lists uniformly.
    weight = self.fc1.weight
    state = torch.as_tensor(state, dtype=weight.dtype, device=weight.device)
    x = torch.relu(self.fc1(state))
    value = self.fc2(x)
    return value

In [None]:
import random
import numpy as np
from collections import deque

class ReplayBuffer:
  """Bounded FIFO store of transitions with uniform random sampling."""

  def __init__(self, capacity):
    # A deque with maxlen drops its oldest entry once `capacity` is reached.
    self.buffer = deque(maxlen=capacity)

  def push(self, obs, actions, rewards, next_obs, done):
    """Append one transition as an (obs, actions, rewards, next_obs, done) tuple."""
    self.buffer.append((obs, actions, rewards, next_obs, done))

  def sample(self, batch_size):
    """Return a list of `batch_size` distinct stored transitions chosen at random."""
    transitions = random.sample(self.buffer, batch_size)
    return transitions

In [None]:
#DO ALL THE ENV SETUP HERE
#init actor and critic
#create target networks
#

# Work-in-progress training loop. It relies on names created by the setup
# steps listed above (env, episodes, actors, target_actors, replay_buffer,
# n_agents, batch_size, device).
# NOTE(review): the unpacking below assumes `env` follows the PettingZoo
# parallel API (reset -> (obs, infos); step -> (obs, rewards, terminations,
# truncations, infos)) and that `replay_buffer` stores 5-field transitions in
# a `.buffer` container -- confirm against the setup code.
for ep in range(episodes):
  obs, info = env.reset()
  dones = {agent: False for agent in env.agents}

  while not all(dones.values()):
    # Snapshot the agent list before stepping so that every field of the
    # stored transition is indexed by the same agents, even if the
    # environment changes `env.agents` during the step.
    agent_ids = list(env.agents)

    actions = {}
    for i, agent in enumerate(agent_ids):
      state = obs[agent]
      with torch.no_grad():
        probs = actors[i](state)  # probability distribution over actions
        action = torch.multinomial(probs, num_samples=1).item()  # sample one action
      actions[agent] = action

    next_obs, rewards, terminations, truncations, _ = env.step(actions)
    # An agent is finished when it is terminated OR truncated; looking at
    # only one of the two flags can keep the loop from ever ending.
    next_dones = {
        agent: bool(terminations[agent] or truncations[agent])
        for agent in agent_ids
    }

    # Store in replay buffer. The done flags are the ones produced by this
    # step (not the flags from before the step).
    replay_buffer.push(
        [obs[agent] for agent in agent_ids],
        [actions[agent] for agent in agent_ids],
        [rewards[agent] for agent in agent_ids],
        [next_obs[agent] for agent in agent_ids],
        [next_dones[agent] for agent in agent_ids]
    )

    obs = next_obs
    dones = next_dones

    # Sampling more items than are stored raises ValueError, so wait until
    # the buffer can supply a full batch.
    if len(replay_buffer.buffer) < batch_size:
      continue

    for i in range(n_agents):
      # sample() yields a list of transitions; zip(*...) regroups it into one
      # sequence per field.
      obs_s, act_s, rew_s, next_obs_s, done_s = zip(*replay_buffer.sample(batch_size))

      # Per-agent observation batches, each (batch, obs_dim_j).
      obs_per_agent = [
          torch.as_tensor(np.stack([row[j] for row in obs_s])).float().to(device)
          for j in range(n_agents)
      ]
      next_obs_per_agent = [
          torch.as_tensor(np.stack([row[j] for row in next_obs_s])).float().to(device)
          for j in range(n_agents)
      ]

      # Joint (all-agent) tensors, concatenated along the feature dimension.
      obs_cat = torch.cat(obs_per_agent, dim=1)
      act_cat = torch.as_tensor(np.array(act_s)).float().to(device)  # (batch, n_agents) action indices
      next_obs_cat = torch.cat(next_obs_per_agent, dim=1)
      with torch.no_grad():
        next_actions = [target_actors[j](next_obs_per_agent[j]) for j in range(n_agents)]
      # TODO: the critic/actor losses and target-network updates for agent i
      # are not implemented yet; rew_s and done_s are unused until then.



# Main Training Loop

In [None]:
def train(maddpg: "MADDPG", env, episodes=200):
    """Run `episodes` episodes, storing every transition and updating after each step.

    Args:
        maddpg: Trainer exposing `agents`, `actors`, `noises`, `buffer` and
            `update()`. The annotation is a quoted forward reference because
            no class named `MADDPG` is defined at this point in the notebook;
            an unquoted name would raise NameError when the cell runs.
        env: Environment whose `reset()` returns a per-agent observation dict,
            whose `step(actions)` returns `(next_obs, rewards, done, infos)`
            and which provides `state()` and `close()`.
        episodes: Number of episodes to run.

    The environment is closed when training finishes, including when an
    exception interrupts it.
    """
    try:
        for ep in range(episodes):
            obs = env.reset()
            state = env.state()
            # Restart each agent's exploration-noise process for the new episode.
            for a in maddpg.agents:
                maddpg.noises[a].reset()

            done = {a: False for a in maddpg.agents}
            episode_reward = {a: 0.0 for a in maddpg.agents}

            while not all(done.values()):

                acts = {}
                # Action selection needs no gradients.
                with torch.no_grad():
                    for a in maddpg.agents:
                        obs_tensor = torch.FloatTensor(obs[a]).unsqueeze(0)
                        action = maddpg.actors[a](obs_tensor).detach().numpy()[0]
                        action += maddpg.noises[a].sample()
                        # Keep the noisy action inside [-1, 1].
                        acts[a] = np.clip(action, -1.0, 1.0)

                next_obs, rews, done, infos = env.step(acts)
                next_state = env.state()

                maddpg.buffer.push(state, obs, acts, rews, next_state, next_obs, done)

                obs, state = next_obs, next_state

                for a in maddpg.agents:
                    episode_reward[a] += rews[a]

                maddpg.update()

            print(f"Episode {ep:3d} rewards:", episode_reward)
    finally:
        # Release the environment even if an episode raised.
        env.close()

In [None]:
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Categorical
from pettingzoo.mpe import simple_tag_v2
import supersuit as ss

# ----------------------------
# 1) Replay Buffer
# ----------------------------
class ReplayBuffer:
    """Bounded FIFO memory of full multi-agent transitions."""

    def __init__(self, max_size=int(1e6)):
        # Once `max_size` entries are held, appending evicts the oldest one.
        self.buffer = deque(maxlen=max_size)

    def __len__(self):
        return len(self.buffer)

    def push(self, state, obs, actions, rewards, next_state, next_obs, dones):
        """Append one transition as a 7-field tuple, in argument order."""
        transition = (state, obs, actions, rewards, next_state, next_obs, dones)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions and regroup them by field.

        Returns an iterator yielding seven lists (states, obs, actions,
        rewards, next_states, next_obs, dones), each of length `batch_size`.
        """
        drawn = random.sample(self.buffer, batch_size)
        per_field = zip(*drawn)
        return map(list, per_field)


# ----------------------------
# 2) Actor & Critic for Discrete Actions
# ----------------------------
class Actor(nn.Module):
    """Two-hidden-layer MLP policy that outputs action probabilities.

    Args:
        obs_dim: Size of the observation vector.
        act_dim: Number of discrete actions.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, obs_dim, act_dim, hidden_size=128):
        super().__init__()
        layers = [
            nn.Linear(obs_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, act_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, obs):
        """Map observations to a softmax distribution over the last dimension."""
        action_logits = self.net(obs)
        return torch.softmax(action_logits, dim=-1)


class Critic(nn.Module):
    """Two-hidden-layer MLP scoring a (state, joint one-hot actions) pair.

    Args:
        state_dim: Size of the state vector.
        total_act_dim: Length of the concatenated one-hot action vector.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, state_dim, total_act_dim, hidden_size=256):
        super().__init__()
        input_dim = state_dim + total_act_dim
        layers = [
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state, joint_actions_onehot):
        """Concatenate both inputs along the last dimension and return one value per row."""
        critic_input = torch.cat((state, joint_actions_onehot), dim=-1)
        q_value = self.net(critic_input)
        return q_value


# ----------------------------
# 3) MADDPG for Discrete Actions
# ----------------------------
class MADDPGDiscrete:
    """Multi-agent actor-critic trainer for discrete action spaces.

    Holds one actor (plus target actor and optimizer) per agent and a single
    critic (plus target critic) that scores the environment state together
    with the concatenated one-hot actions of all agents.

    Args:
        env: Environment exposing `possible_agents`, `reset()`, `state()` and
            `action_spaces[agent].n`.
        gamma: Discount factor used in the critic target.
        tau: Interpolation factor for the soft target-network updates.
        actor_lr: Adam learning rate for every actor.
        critic_lr: Adam learning rate for the critic.
        batch_size: Number of transitions sampled per update.
        buffer_size: Maximum number of transitions kept in the replay buffer.
    """

    def __init__(self, env, gamma=0.95, tau=0.01,
                 actor_lr=1e-3, critic_lr=1e-3,
                 batch_size=1024, buffer_size=int(1e6)):
        self.env = env
        self.agents = env.possible_agents
        self.n_agents = len(self.agents)
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size

        # Dimensions
        example_obs = env.reset()
        if isinstance(example_obs, tuple):
            # reset() may return (observations, infos) instead of the bare
            # observation dict; only the observations are needed here.
            example_obs = example_obs[0]
        state_dim = env.state().shape[0]
        obs_dims = {a: example_obs[a].shape[0] for a in self.agents}
        # Cached as plain ints: used for layer sizes and one-hot widths.
        self.act_dims = {a: int(env.action_spaces[a].n) for a in self.agents}
        total_act_dim = sum(self.act_dims.values())

        # Networks & targets (targets start as exact copies).
        self.actors = {}
        self.targ_actors = {}
        self.actor_optim = {}
        for a in self.agents:
            self.actors[a] = Actor(obs_dims[a], self.act_dims[a])
            self.targ_actors[a] = Actor(obs_dims[a], self.act_dims[a])
            self.targ_actors[a].load_state_dict(self.actors[a].state_dict())
            self.actor_optim[a] = Adam(self.actors[a].parameters(), lr=actor_lr)

        self.critic = Critic(state_dim, total_act_dim)
        self.targ_critic = Critic(state_dim, total_act_dim)
        self.targ_critic.load_state_dict(self.critic.state_dict())
        self.critic_optim = Adam(self.critic.parameters(), lr=critic_lr)

        # Replay buffer
        self.buffer = ReplayBuffer(buffer_size)

    def _onehot_actions(self, actions):
        """Convert a dict of per-agent action-index tensors to one joint one-hot tensor.

        Args:
            actions: Maps each agent to a long tensor of shape (batch,).

        Returns:
            Float tensor of shape (batch, total_act_dim); agents are
            concatenated in `self.agents` order.
        """
        onehots = [
            F.one_hot(actions[a], num_classes=self.act_dims[a]).float()
            for a in self.agents
        ]
        return torch.cat(onehots, dim=-1)

    def _soft_update(self, source, target):
        """Move `target`'s parameters a fraction `tau` of the way towards `source`'s."""
        with torch.no_grad():
            for p, target_p in zip(source.parameters(), target.parameters()):
                target_p.copy_(self.tau * p + (1 - self.tau) * target_p)

    def select_action(self, obs):
        """Sample one discrete action per agent from its current policy.

        Args:
            obs: Maps each agent to its observation vector.

        Returns:
            Dict mapping each agent to a Python int action index.
        """
        actions = {}
        with torch.no_grad():
            for a in self.agents:
                o = torch.as_tensor(obs[a], dtype=torch.float32).unsqueeze(0)  # (1, obs_dim)
                # Categorical needs a tensor; converting the probabilities to
                # NumPy first makes its constructor fail.
                probs = self.actors[a](o).squeeze(0)  # (act_dim,)
                actions[a] = Categorical(probs=probs).sample().item()
        return actions

    def update(self):
        """Run one critic update, one update per actor, then soft-update all targets.

        Does nothing until the buffer holds at least `batch_size` transitions.
        """
        if len(self.buffer) < self.batch_size:
            return

        # Sample batch
        states, obs_b, acts_b, rews_b, next_states, next_obs_b, dones_b = \
            self.buffer.sample(self.batch_size)
        # Stack into one array first: building a tensor from a list of arrays
        # element by element is much slower.
        states = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        next_states = torch.as_tensor(np.asarray(next_states), dtype=torch.float32)

        # Build per-agent tensors
        obs_tensor = {}
        action_tensor = {}
        rew_tensor = {}
        done_tensor = {}
        next_obs_tensor = {}
        for a in self.agents:
            obs_tensor[a] = torch.as_tensor(
                np.vstack([ob[a] for ob in obs_b]), dtype=torch.float32)
            action_tensor[a] = torch.as_tensor(
                [int(ac[a]) for ac in acts_b], dtype=torch.long)
            rew_tensor[a] = torch.as_tensor(
                [[float(r[a])] for r in rews_b], dtype=torch.float32)  # (B,1)
            done_tensor[a] = torch.as_tensor(
                [[float(d[a])] for d in dones_b], dtype=torch.float32)  # (B,1)
            next_obs_tensor[a] = torch.as_tensor(
                np.vstack([no[a] for no in next_obs_b]), dtype=torch.float32)

        # One-hot joint actions
        all_actions_onehot = self._onehot_actions(action_tensor)

        # Everything feeding the critic target is a constant for the
        # optimisation, so no autograd graph is built for it.
        with torch.no_grad():
            # Next actions from target actors => one-hot
            next_action_samples = {
                a: Categorical(probs=self.targ_actors[a](next_obs_tensor[a])).sample()
                for a in self.agents
            }
            next_actions_onehot = self._onehot_actions(next_action_samples)
            q_targ_vals = self.targ_critic(next_states, next_actions_onehot)
            # One bootstrapped target per agent (own reward / done flag).
            target_q = {
                a: rew_tensor[a] + self.gamma * (1 - done_tensor[a]) * q_targ_vals
                for a in self.agents
            }

        # Critic update: average of the MSE against each agent's target.
        current_q = self.critic(states, all_actions_onehot)
        critic_loss = 0
        for a in self.agents:
            critic_loss += F.mse_loss(current_q, target_q[a])
        critic_loss /= self.n_agents

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        # Actor update
        for a in self.agents:
            # re-sample actions for current obs
            dist = Categorical(probs=self.actors[a](obs_tensor[a]))
            sampled = dist.sample()
            log_prob = dist.log_prob(sampled).unsqueeze(1)  # (B,1)

            # joint actions: replace agent a's actions with sampled
            acts = {b: (sampled if b == a else action_tensor[b]) for b in self.agents}
            joint_onehot = self._onehot_actions(acts)

            # The one-hot encoding is not differentiable, so the actor's
            # gradient comes only through log_prob; Q acts as a fixed weight.
            # Evaluating it without grad avoids back-propagating into the
            # critic from the actor loss.
            with torch.no_grad():
                q_val = self.critic(states, joint_onehot)
            # Actor loss: maximize Q => minimize -Q * log_prob
            actor_loss = -(log_prob * q_val).mean()

            self.actor_optim[a].zero_grad()
            actor_loss.backward()
            self.actor_optim[a].step()

        # Soft updates
        for a in self.agents:
            self._soft_update(self.actors[a], self.targ_actors[a])
        self._soft_update(self.critic, self.targ_critic)


# ----------------------------
# 4) Training Loop
# ----------------------------
def train(maddpg, env, episodes=500):
    """Run `episodes` episodes, storing every transition and updating after each step.

    Args:
        maddpg: Trainer exposing `agents`, `select_action(obs)`, `buffer.push(...)`
            and `update()`.
        env: Multi-agent environment taking a dict of actions in `step` and
            providing `reset()`, `state()` and `close()`. Both result shapes
            are accepted: `reset()` returning the observation dict or an
            `(observations, infos)` tuple, and `step()` returning
            `(obs, rewards, dones, infos)` or
            `(obs, rewards, terminations, truncations, infos)`.
        episodes: Number of episodes to run.

    The environment is closed when training finishes, including when an
    exception interrupts it.
    """
    try:
        for ep in range(episodes):
            obs = env.reset()
            if isinstance(obs, tuple):
                # (observations, infos) form: keep only the observations.
                obs = obs[0]
            state = env.state()
            done = {a: False for a in maddpg.agents}
            ep_rewards = {a: 0.0 for a in maddpg.agents}

            while not all(done.values()):
                actions = maddpg.select_action(obs)
                step_result = env.step(actions)
                if len(step_result) == 5:
                    next_obs, rewards, terminations, truncations, _ = step_result
                    # An agent is finished once it is terminated or truncated.
                    done = {a: bool(terminations[a] or truncations[a])
                            for a in terminations}
                else:
                    next_obs, rewards, done, _ = step_result
                next_state = env.state()

                # store transition
                maddpg.buffer.push(state, obs, actions, rewards, next_state, next_obs, done)

                obs, state = next_obs, next_state
                for a in maddpg.agents:
                    ep_rewards[a] += rewards[a]

                # update networks
                maddpg.update()

            print(f"Episode {ep} rewards:", ep_rewards)
    finally:
        # Release the environment even if an episode raised.
        env.close()


if __name__ == "__main__":
    # The trainer steps the environment with a dict holding one action per
    # agent and reads a dict of observations back, i.e. it needs the parallel
    # API. `simple_tag_v2.env(...)` builds the turn-based (AEC) environment,
    # whose reset() returns nothing and whose step() takes a single action.
    base_env = simple_tag_v2.parallel_env(max_cycles=100, continuous_actions=False)
    # Pad observations so every agent's observation has the same length.
    env = ss.pad_observations_v0(base_env)
    maddpg = MADDPGDiscrete(env)
    train(maddpg, env)
