<a href="https://colab.research.google.com/github/wbeard01/PCGD/blob/main/simple_adversary_pcgd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Template / graphing from: https://github.com/tsmatz/reinforcement-learning-tutorials/tree/master

In [None]:
!pip install pettingzoo

In [None]:
from pettingzoo.mpe import simple_adversary_v3

In [None]:
import gym
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import defaultdict
import matplotlib.pyplot as plt
from scipy.sparse.linalg import gmres
from scipy.sparse.linalg import LinearOperator
import pandas as pd
import time

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Policy network for agents / adversaries
# Fully-connected network with three hidden layers and ReLU activation
class PolicyPi(nn.Module):
    """Fully-connected policy network that maps an observation to action logits.

    Three hidden layers of width ``hidden_dim`` with ReLU activations are
    followed by a linear output layer. The output is unnormalised (logits);
    callers apply softmax / cross-entropy themselves.

    Args:
        input_dim: Size of the observation vector.
        hidden_dim: Width of each hidden layer.
        n_actions: Number of logits produced (size of the action set).
    """

    def __init__(self, input_dim, hidden_dim=256, n_actions=5):
        super().__init__()

        # Attribute names and creation order are kept stable on purpose: they
        # determine the state_dict keys and the order in which .parameters()
        # yields tensors.
        self.firstHidden = nn.Linear(input_dim, hidden_dim)
        self.secondHidden = nn.Linear(hidden_dim, hidden_dim)
        self.thirdHidden = nn.Linear(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_actions)

    def forward(self, s):
        """Return the action logits for the observation batch ``s``."""
        outs = s
        for layer in (self.firstHidden, self.secondHidden, self.thirdHidden):
            outs = F.relu(layer(outs))
        return self.classify(outs)

In [None]:
# Simultaneous Gradient Descent (SimGD)
class SimGD:
    """Plain gradient-descent optimizer acting on a single policy network.

    Args:
        policy: Module whose parameters are updated.
        lr: Step size applied to the gradient of the loss.
    """

    def __init__(self, policy, lr):
        self.policy = policy
        self.lr = lr

    def zero_grad(self):
        """Reset every populated ``.grad`` buffer of the policy to zero."""
        for param in self.policy.parameters():
            if param.grad is not None:
                # detach_() is the in-place variant; the out-of-place detach()
                # returns a new tensor and leaves param.grad attached to its graph.
                param.grad.detach_()
                param.grad.zero_()

    def step(self, loss):
        """Take one descent step: ``param <- param - lr * d(loss)/d(param)``.

        Args:
            loss: Scalar tensor whose graph reaches the policy parameters.
        """
        params = list(self.policy.parameters())
        grads = torch.autograd.grad(loss, params)
        # The update itself must not be recorded by autograd.
        with torch.no_grad():
            for param, grad in zip(params, grads):
                param -= self.lr * grad

In [None]:
# Polymatrix Competitive Gradient Descent (PCGD)
class PCGD:
    """Polymatrix Competitive Gradient Descent over a set of per-agent policies.

    All flat vectors handled by this class (``zeta``, the argument and result
    of ``mvp``, the update passed to ``update_parameters``) are laid out as one
    contiguous block per agent, in the order of ``self.agents``; inside a block
    the entries follow the order of ``policy.parameters()``.

    Args:
        policies: Mapping from agent name to that agent's policy module.
        eta: Step size; also scales the mixed second-derivative terms in
            ``mvp``.
        agents: Optional ordered list of agent names. Defaults to the three
            agents ``adversary_0``, ``agent_0``, ``agent_1``.
    """

    def __init__(self, policies, eta, agents=None):
        self.policies = policies
        self.eta = eta
        if agents is None:
            agents = ["adversary_0", "agent_0", "agent_1"]
        self.agents = list(agents)

    def _block_sizes(self):
        """Return the number of scalar parameters of each agent, in agent order."""
        return [
            sum(p.numel() for p in self.policies[agent].parameters())
            for agent in self.agents
        ]

    def zero_grad(self):
        """Reset every populated ``.grad`` buffer of every agent's policy."""
        for agent in self.agents:
            for param in self.policies[agent].parameters():
                if param.grad is not None:
                    # In-place detach; the out-of-place detach() would be a no-op here.
                    param.grad.detach_()
                    param.grad.zero_()

    def custom_flatten(self, gp):
        """Concatenate an iterable of tensors into a single 1-D tensor."""
        return torch.concat([g.flatten() for g in gp])

    def loss_matrix(self, log_probs, cum_rewards):
        """Build the matrix of pairwise losses used for the mixed derivatives.

        Entry ``[row, col]`` is
        ``mean(cum_rewards[row] * log_probs[col] * log_probs[row])``.

        Args:
            log_probs: Mapping agent name -> per-step log-probabilities.
            cum_rewards: Mapping agent name -> per-step cumulative rewards.

        Returns:
            A ``(len(agents), len(agents))`` tensor that stays attached to the
            autograd graph of ``log_probs`` (and on their device).
        """
        agents = self.agents
        entries = [
            (cum_rewards[row] * log_probs[col] * log_probs[row]).mean()
            for row in agents
            for col in agents
        ]
        # Stacking keeps the result on the inputs' device instead of always
        # allocating the matrix on the CPU.
        return torch.stack(entries).reshape(len(agents), len(agents))

    def zeta(self, log_probs, cum_rewards):
        """Return the concatenated first-order gradients of all agents.

        For each agent this is the gradient of
        ``mean(log_probs[agent] * cum_rewards[agent])`` with respect to that
        agent's own parameters (the SimGD direction).
        """
        self.zero_grad()
        pieces = []
        for agent in self.agents:
            reward = (log_probs[agent] * cum_rewards[agent]).mean()
            grads = torch.autograd.grad(
                reward,
                self.policies[agent].parameters(),
                retain_graph=True,
                create_graph=True,
            )
            pieces.append(self.custom_flatten(grads))
        return torch.concat(pieces)

    def mvp(self, loss_mat, vec):
        """Matrix-vector product with ``A = I + eta * H_o``.

        ``H_o`` holds only the off-diagonal (cross-agent) blocks: block
        ``(row, col)`` is the derivative, with respect to agent ``row``'s
        parameters, of the gradient of ``loss_mat[row, col]`` with respect to
        agent ``col``'s parameters. The product is formed with two autograd
        passes, so the matrix is never materialised.

        Args:
            loss_mat: Output of ``loss_matrix`` (must still carry its graph).
            vec: Flat vector with one entry per parameter of every agent.

        Returns:
            Column tensor of shape ``(N, 1)`` with the same dtype/device as
            ``vec``.
        """
        self.zero_grad()
        vec = vec.reshape(-1, 1)
        agents = self.agents
        blocks = torch.split(vec, self._block_sizes())
        new_blocks = []
        for row, row_agent in enumerate(agents):
            # Identity part of A.
            acc = blocks[row].clone()
            for col, col_agent in enumerate(agents):
                if row == col:
                    continue
                grads = self.custom_flatten(
                    torch.autograd.grad(
                        loss_mat[row, col],
                        self.policies[col_agent].parameters(),
                        retain_graph=True,
                        create_graph=True,
                    )
                ).reshape(-1, 1)
                # The solver may hand over a float64 CPU vector; autograd needs
                # the cotangent to match the dtype/device of `grads`.
                cotangent = blocks[col].to(grads)
                vjp = self.custom_flatten(
                    torch.autograd.grad(
                        grads,
                        self.policies[row_agent].parameters(),
                        [cotangent],
                        retain_graph=True,
                    )
                ).reshape(-1, 1)
                acc += (self.eta * vjp).to(acc.device)
            new_blocks.append(acc)
        return torch.concat(new_blocks)

    def compute_loss_mat_update_iterative(self, loss_mat, zeta):
        """Solve ``A x = zeta`` with GMRES and return the update ``eta * x``.

        Args:
            loss_mat: Output of ``loss_matrix`` (with graph).
            zeta: Flat first-order gradient vector (output of ``zeta``).

        Returns:
            CPU tensor holding the flat parameter update. The GMRES
            convergence flag is not inspected.
        """
        def matvec(v):
            # .cpu() makes the NumPy conversion work for CUDA tensors as well.
            return self.mvp(loss_mat, torch.tensor(v)).detach().cpu().numpy()

        size = zeta.shape[0]
        A = LinearOperator((size, size), matvec=matvec)
        b = zeta.detach().cpu().numpy()
        solution = gmres(A, b)[0]
        return self.eta * torch.tensor(solution)

    def update_parameters(self, update):
        """Add the flat ``update`` vector to the parameters of every agent.

        Args:
            update: Flat tensor laid out agent by agent, parameter by
                parameter (same layout as ``zeta``).
        """
        flat = update.reshape(-1)
        offset = 0
        # The parameter write must not be tracked by autograd.
        with torch.no_grad():
            for agent in self.agents:
                for param in self.policies[agent].parameters():
                    count = param.numel()
                    chunk = flat[offset:offset + count].reshape(param.shape)
                    # The solver result lives on the CPU; move it next to the parameter.
                    param.add_(chunk.to(param.device))
                    offset += count

In [None]:
# Create and display learning curve for all agents
def show_pretty_learning_graph():
    """Plot the smoothed learning curve of every agent, one subplot per agent.

    Reads the module-level ``env`` (for the agent names) and ``reward_records``
    (per-agent list of recorded rewards). Each curve is the mean over the last
    five recorded values (fewer at the start), with a shaded band of one
    standard deviation over the same window.
    """
    window = 5
    fig, axs = plt.subplots(1, 3)
    fig.suptitle('Simple Adversary Reward (PCGD)')
    for ax, agent in zip(axs, env.agents):
        records = reward_records[agent]
        # Trailing windows: records[0:1], records[0:2], ..., then always the
        # five most recent entries.
        windows = [
            records[max(0, idx - window + 1):idx + 1]
            for idx in range(len(records))
        ]
        average_reward = np.array([np.average(w) for w in windows])
        std = np.array([np.std(w) for w in windows])

        ax.set_title(agent)
        # Labels and legend are set per subplot; the pyplot-level calls only
        # affect the current (last) axes.
        ax.set_xlabel("training step")
        ax.set_ylabel("cumulative reward")
        ax.plot(average_reward, label="average reward (last 5 steps)")
        ax.fill_between(range(len(records)), average_reward - std,
                        average_reward + std, alpha=0.2)
        ax.legend(loc="lower right")
    fig.set_size_inches(22, 10)
    plt.show()

# One policy network per agent; the adversary observes 8 values, the two
# cooperative agents observe 10 each.
policy_pi = {
    "adversary_0": PolicyPi(8).to(device),
    "agent_0": PolicyPi(10).to(device),
    "agent_1": PolicyPi(10).to(device),
}

# Per-agent SimGD optimizers (lr 0.01) and the joint PCGD optimizer (eta 0.6)
opts = {name: SimGD(net, lr=0.01) for name, net in policy_pi.items()}
pcgd = PCGD(policy_pi, 0.6)

# Environment and training hyper-parameters
env = simple_adversary_v3.env(render_mode="rgb_array")
reward_records = defaultdict(list)
batch_size = 2048  # 2 ** 11 trajectories per epoch
epochs = 1001
gamma = 0.99

# Empty table that collects one row of training metadata per epoch
metadata_columns = ("epoch", "trajectories", "adversary reward", "agent reward")
df = pd.DataFrame({column: [] for column in metadata_columns})

# Choose action from policy network
def pick_sample(s, agent):
    """Sample one action index for ``agent`` from its policy given observation ``s``.

    The observation is turned into a batch of one, passed through
    ``policy_pi[agent]`` without gradient tracking, and an action is drawn from
    the softmax of the resulting logits.

    Returns:
        The sampled action index as a Python int.
    """
    observation = torch.tensor(np.expand_dims(s, axis=0), dtype=torch.float).to(device)
    with torch.no_grad():
        action_logits = policy_pi[agent](observation).squeeze(dim=0)
        action_probs = F.softmax(action_logits, dim=-1)
        choice = torch.multinomial(action_probs, num_samples=1)
    return choice.tolist()[0]

# Run epochs: each epoch samples `batch_size` trajectories, averages their
# PCGD quantities and applies a single joint parameter update.
for i in range(epochs):

    # Gradient information for each sample in the epoch
    zetas = []
    loss_mats = []

    # Run batch
    for j in range(batch_size):

        # Sample one trajectory from the environment
        done = False
        states = defaultdict(list)
        actions = defaultdict(list)
        rewards = defaultdict(list)
        env.reset(seed=(j * epochs + i))
        ss = {}
        for agent in env.agents:
            env.agent_selection = agent
            ss[agent] = env.last()[0]
        while not done:
            # All agents choose their action from the observations of the same step
            t_actions = {}
            for agent in env.agents:
                states[agent].append(ss[agent].tolist())
                t_actions[agent] = pick_sample(ss[agent], agent)
            for agent in env.agents:
                env.agent_selection = agent
                env.step(t_actions[agent])
            for agent in env.agents:
                env.agent_selection = agent
                s, r, term, trunc, _ = env.last()
                ss[agent] = s
                # `done` ends up reflecting the last agent in env.agents
                done = term or trunc
                actions[agent].append(t_actions[agent])
                rewards[agent].append(r)

        # Save optimization information
        pcgd_log_probs = {}
        pcgd_cum_rewards = {}
        for agent in env.agents:
            # Discounted reward-to-go, accumulated from the last step backwards.
            # The buffer is explicitly float so integer-valued rewards cannot
            # truncate the discounted sums.
            reward_len = len(rewards[agent])
            cum_rewards = np.zeros(reward_len, dtype=np.float64)
            running = 0.0
            for t in reversed(range(reward_len)):
                running = rewards[agent][t] + gamma * running
                cum_rewards[t] = running

            t_states = torch.tensor(states[agent], dtype=torch.float).to(device)
            t_taken = torch.tensor(actions[agent], dtype=torch.int64).to(device)
            cum_rewards = torch.tensor(cum_rewards, dtype=torch.float).to(device)
            logits = policy_pi[agent](t_states)
            # cross_entropy returns -log pi(a|s); negate to get the log-probability
            log_probs = -F.cross_entropy(logits, t_taken, reduction="none")

            pcgd_log_probs[agent] = log_probs
            pcgd_cum_rewards[agent] = cum_rewards

        # Record losses for PCGD
        loss_mat = pcgd.loss_matrix(pcgd_log_probs, pcgd_cum_rewards)
        zeta = pcgd.zeta(pcgd_log_probs, pcgd_cum_rewards)
        # loss_mat must keep its graph (second derivatives are taken later);
        # zeta is only used as a detached right-hand side, so drop its graph now
        # instead of holding it for the whole batch.
        loss_mats.append(loss_mat)
        zetas.append(zeta.detach())

    # Perform PCGD update
    batch_zeta = torch.stack(zetas, dim=0).mean(dim=0)
    batch_loss_mat = torch.stack(loss_mats, dim=0).mean(dim=0)
    update = pcgd.compute_loss_mat_update_iterative(batch_loss_mat, batch_zeta)
    pcgd.update_parameters(update)

    # Save details / models as appropriate (values come from the last
    # trajectory of the batch)
    for agent in env.agents:
        print("Run epoch{} with rewards {}".format(i, sum(rewards[agent])))
        if agent == "agent_0":
            # .cpu() keeps the NumPy conversion valid when training on CUDA
            ad = pcgd_cum_rewards["adversary_0"][0].detach().cpu().numpy()
            ag = pcgd_cum_rewards["agent_0"][0].detach().cpu().numpy()
            df.loc[len(df.index)] = [i, i * batch_size, ad, ag]
            df.to_csv(f"pcgd_training_metadata_{i}.csv", index=False)
        torch.save(policy_pi[agent], f"pcgd_{agent}_{i}.model")
        print("MODEL SAVED")
        reward_records[agent].append(sum(rewards[agent]))

    # Plot once per epoch, after every agent's record has been appended, so all
    # curves have the same length.
    show_pretty_learning_graph()

print("\nDone")
env.close()