In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# CommNet

- Learns communication **not** from manually specified protocols, but through backpropagation of the RL signal.
- Considers a cooperative task among $J$ agents.
- The whole system can be viewed as a single large model.

Given a view of the state for all agents $s = \{s_1,...,s_J\}$, a controller maps states to actions $a = \Phi(s)$ where $a = \{a_1,...,a_J\}$ is a concatenation of discrete actions.

$\Phi$ is built from modules $f^i$, which are multi-layer neural networks. Here $i \in \{0, \dots, K\}$ indexes the communication step, and $K$ is the number of communication steps in the network.

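Each $f^i$ takes two input vectors for each agent $j$: the hidden state $h^i_j$ and the communication vector $c^i_j$. It outputs a vector $h^{i+1}_j$. When $f^i$ is a single-layer network, $h^{i+1}_j = \sigma(H^i_j h^i_j + C^i_j c^i_j)$. The next communication vector is the mean of the *other* agents' hidden states: $c^{i+1}_j = \frac{1}{J-1} \sum_{j' \neq j} h^{i+1}_{j'}$.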

The first layer of the network is an encoder $h^0_j = r(s_j)$. For most tasks, the paper uses a single-layer network.

The initial communication vector is zero: $c^0_j = 0 \;\; \forall j$.

The output of the network is a decoder $q$ that maps $h^K_j$ to a distribution over the action space, $a_j \sim q(h^K_j)$. It is a single-layer network with a softmax output.

## Variations
- Only allow communication between agents within a certain range.
- Create skip connection between input encoding $h^0_j$ to various comm layers
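- Temporal recurrence: replace the communication step $i$ with the time step $t$ and reuse the same module at every time step. The network then runs as a recurrent net (optionally an LSTM) whose hidden state carries over between environment steps.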

## Components

In [2]:
class ReplayMemory:
    """Fixed-capacity ring buffer of (obs, action, reward, next_obs, not_done) transitions.

    All storage is preallocated as float32 numpy arrays. Once ``capacity``
    transitions have been written, the write cursor wraps around and the oldest
    slots are overwritten. ``sample`` draws indices uniformly with replacement
    and returns torch tensors placed on ``device``.
    """

    def __init__(self, capacity, obs_shape, action_shape, device):
        self.capacity = capacity

        def _alloc(*trailing):
            # One leading slot axis, followed by the per-transition shape.
            return np.empty((capacity, *trailing), dtype=np.float32)

        self.obs = _alloc(*obs_shape)
        self.obs_next = _alloc(*obs_shape)
        self.actions = _alloc(*action_shape)
        self.rewards = _alloc(1)
        self.not_dones = _alloc(1)

        self.device = device

        self.idx = 0       # next slot to be written
        self.full = False  # becomes True the first time the cursor wraps to 0

    def __len__(self):
        if self.full:
            return self.capacity
        return self.idx

    def add(self, obs, actions, rewards, obs_next, not_done):
        """Write one transition into the current slot and advance the cursor."""
        slot = self.idx
        targets_and_values = (
            (self.obs, obs),
            (self.actions, actions),
            (self.rewards, rewards),
            (self.obs_next, obs_next),
            (self.not_dones, not_done),
        )
        for target, value in targets_and_values:
            np.copyto(target[slot], value)

        self.idx = (slot + 1) % self.capacity
        if self.idx == 0:
            self.full = True

    def sample(self, batch_size):
        """Return ``batch_size`` uniformly drawn transitions as torch tensors."""
        upper = self.capacity if self.full else self.idx
        idxs = np.random.randint(0, upper, size=batch_size)

        def _gather(buffer):
            return torch.as_tensor(buffer[idxs], device=self.device)

        obs = _gather(self.obs).float()
        obs_next = _gather(self.obs_next).float()
        actions = _gather(self.actions)
        rewards = _gather(self.rewards)
        not_dones = _gather(self.not_dones)

        return obs, actions, rewards, obs_next, not_dones

In [3]:
class Encoder(nn.Module):
    """Single-layer state encoder r(s): a Linear projection followed by ReLU."""

    def __init__(self, state_dim, h_dim):
        super().__init__()
        # The attribute name ``net`` determines the state_dict keys.
        self.net = nn.Linear(state_dim, h_dim)

    def forward(self, s):
        projected = self.net(s)
        return torch.relu(projected)

In [190]:
class Decoder(nn.Module):
    """Per-agent linear action head q(h^K_j) that returns unnormalised logits.

    Agent ``j``'s weight matrix, of shape ``(h_dim, action_dim)``, is stored
    flattened as row ``j`` of an ``nn.Embedding`` table, so one module serves
    every agent slot. The map has no bias, and no softmax is applied here.
    """

    def __init__(self, n_agents, h_dim, action_dim):
        super().__init__()

        self.h_dim = h_dim
        self.action_dim = action_dim

        self.q = nn.Embedding(n_agents, h_dim * action_dim)

    def forward(self, h, agent_idx):
        """Compute action logits for one agent.

        Args:
            h: hidden state whose last dimension is ``h_dim``.
            agent_idx: agent slot id (int or 0-d integer tensor).

        Returns:
            ``h @ W_j`` where ``W_j`` is the agent's ``(h_dim, action_dim)`` matrix.
        """
        # Build the index on the embedding's device so the lookup also works
        # after the module has been moved off the CPU. ``as_tensor`` also
        # accepts an existing tensor without the copy warning of ``torch.tensor``.
        idx = torch.as_tensor(agent_idx, device=self.q.weight.device)
        weight = self.q(idx).view(self.h_dim, self.action_dim)
        return torch.matmul(h, weight)

In [191]:
# class Module(nn.Module):
#     """Each agent has its own separate module. Paper specifies tanh() nonlinearities on output"""
#     def __init__(self, h_dim, c_dim, n_layers, hidden_dim):
#         super().__init__()
        
#         self.H = nn.ModuleList([nn.Linear(h_dim, hidden_dim)])
#         self.C = nn.ModuleList([nn.Linear(c_dim, hidden_dim)])
        
#         for _ in range(n_layers - 1):
#             self.H.append(nn.ReLU())
#             self.C.append(nn.ReLU())
            
#             self.H.append(nn.Linear(hidden_dim, hidden_dim))
#             self.C.append(nn.Linear(hidden_dim, hidden_dim))
            
#         self.H.append(nn.Linear(hidden_dim, h_dim))
#         self.C.append(nn.Linear(hidden_dim, c_dim))
        
#     def forward(self, h, c):
#         for layer in self.H:
#             h = layer(h)
            
#         for layer in self.C:
#             c = layer(c)
            
#         return torch.tanh(c + h)

In [192]:
class Module(nn.Module):
    """One CommNet communication step f^i, holding weights for every agent slot.

    For agent ``j`` the step computes::

        hidden = h_j @ H_j + c_j @ C_j    # per-agent, bias-free input maps
        h_j'   = net(hidden)              # ReLU -> [Linear -> ReLU] * (n_layers - 2) -> Linear

    The per-agent matrices ``H_j`` and ``C_j`` are stored flattened as rows of
    ``nn.Embedding`` tables and reshaped to ``(h_dim, hidden_dim)`` on lookup.
    The communication vector ``c`` must therefore also be ``h_dim`` wide.

    Notes:
        * The final Linear output is returned without a nonlinearity.
        * ``n_layers`` values below 2 still yield one ReLU followed by one Linear.
    """

    def __init__(self, n_agents, h_dim, n_layers, hidden_dim):
        super().__init__()

        self.h_dim = h_dim
        self.hidden_dim = hidden_dim

        self.H = nn.Embedding(num_embeddings=n_agents, embedding_dim=h_dim * hidden_dim)
        self.C = nn.Embedding(num_embeddings=n_agents, embedding_dim=h_dim * hidden_dim)

        # The ReLU comes first: it is the nonlinearity applied to H h + C c.
        self.net = nn.ModuleList([nn.ReLU()])

        for _ in range(n_layers - 2):
            self.net.append(nn.Linear(hidden_dim, hidden_dim))
            self.net.append(nn.ReLU())

        self.net.append(nn.Linear(hidden_dim, h_dim))

    def grab_emb(self, agent_idx):
        """Return the flattened (H_j, C_j) rows for the given agent index tensor."""
        H = self.H(agent_idx)
        C = self.C(agent_idx)
        return H, C

    def forward(self, h, c, agent_idx):
        """Advance agent ``agent_idx`` by one communication step.

        Args:
            h: hidden state with last dimension ``h_dim``.
            c: communication vector with last dimension ``h_dim``.
            agent_idx: agent slot id (int or 0-d integer tensor).

        Returns:
            The next hidden state, with last dimension ``h_dim``.
        """
        # Place the index on the embeddings' device so the lookup also works
        # after ``.to(device)``; ``as_tensor`` also accepts tensors without warning.
        idx = torch.as_tensor(agent_idx, device=self.H.weight.device)
        H, C = self.grab_emb(idx)

        hidden = (torch.matmul(h, H.view(self.h_dim, self.hidden_dim))
                  + torch.matmul(c, C.view(self.h_dim, self.hidden_dim)))

        for layer in self.net:
            hidden = layer(hidden)

        return hidden

In [204]:
class CommNetBase:
    """CommNet controller: encoder -> K communication steps -> per-agent decoder.

    There are ``n_agents`` agent slots, each with its own weights inside every
    ``Module`` and the ``Decoder``. A call to :meth:`act` runs only the slots
    listed in ``agent_idxs``.

    Note: ``Module`` multiplies the communication vector by an
    ``(h_dim, hidden_dim)`` matrix, and ``h_to_c`` averages hidden states, so
    ``c_dim`` must equal ``h_dim``.
    """

    def __init__(self, state_dim, h_dim, c_dim, action_dim, hidden_dim, n_layers, n_agents,
                 n_comm_steps, lr, batch_size, device):

        # The encoder maps one state vector to n_agents slices of width h_dim.
        self.encoder = Encoder(state_dim, h_dim * n_agents).to(device)
        self.decoder = Decoder(n_agents, h_dim, action_dim).to(device)

        # One Module per communication step. Each Module stores separate H/C
        # weights for every agent slot.
        self.modules = [
            Module(n_agents, h_dim, n_layers, hidden_dim).to(device)
            for _ in range(n_comm_steps)
        ]

        self.optim = Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            lr=lr
        )

        for m in self.modules:
            self.optim.add_param_group({'params': list(m.parameters())})

        self.n_comm_steps = n_comm_steps
        self.n_agents = n_agents
        self.c_dim = c_dim
        self.batch_size = batch_size
        self.device = device

    def act(self, s, agent_idxs):
        """Run the controller for the active agents and sample one-hot actions.

        Args:
            s: state tensor of shape ``(batch, state_dim)``.
            agent_idxs: sequence of active agent slot ids.

        Returns:
            ``(action_logits, actions)``: two lists with one entry per active
            agent. Each logit tensor has shape ``(batch, action_dim)``; each
            action is the corresponding hard Gumbel-softmax sample.
        """
        n_active = len(agent_idxs)

        # NOTE(review): encoder slices are taken by *position* j in agent_idxs,
        # not by slot id agent_idxs[j]. Confirm which one is intended.
        h_0 = self.encoder(s).chunk(self.n_agents, dim=1)
        batch_size = h_0[0].shape[0]

        hs = [h_0[j] for j in range(n_active)]
        # c^0_j = 0 for every agent, allocated alongside the hidden states.
        comms = [torch.zeros(batch_size, self.c_dim, device=h_0[0].device)
                 for _ in range(n_active)]

        for step, module in enumerate(self.modules):
            if step > 0:
                # c^{i}_j is the mean of the other active agents' h^{i}.
                comms = [self.h_to_c(hs, j, batch_size, n_active) for j in range(n_active)]
            hs = [module(hs[j], comms[j], agent_idxs[j]) for j in range(n_active)]

        # With zero communication steps this decodes the encoder output directly.
        action_logits = [self.decoder(hs[j], agent_idxs[j]) for j in range(n_active)]
        actions = [F.gumbel_softmax(logits, hard=True) for logits in action_logits]

        return action_logits, actions

    def h_to_c(self, h, j, batch_size, n_agents):
        """Return the communication vector for agent ``j``.

        This is the mean of the hidden states ``h[0..n_agents-1]`` of every
        agent except ``j``: c_j = 1/(J-1) * sum_{j' != j} h_{j'}.

        Args:
            h: list of per-agent hidden tensors, each ``(batch_size, c_dim)``.
            j: position of the receiving agent in ``h``.
            batch_size: batch size, used only when there is nobody to listen to.
            n_agents: number of communicating agents J (the leading entries of ``h``).
        """
        others = [h[a] for a in range(n_agents) if a != j]
        if not others:
            # A lone agent receives no messages. Return zeros instead of 0/0.
            device = h[0].device if len(h) > 0 else None
            return torch.zeros(batch_size, self.c_dim, device=device)
        # Normalise by the number of agents actually communicating (J - 1),
        # not by the total number of agent slots.
        return torch.stack(others).sum(dim=0) / len(others)
        

In [205]:
# Seed torch so weight initialisation and later sampling are reproducible
# on a fresh kernel (Restart & Run All).
torch.manual_seed(0)

# 500 agent slots, of which only the ones passed to `act` are run.
# c_dim must equal h_dim for this implementation.
learner = CommNetBase(
    state_dim=64, h_dim=16, c_dim=16, action_dim=3, hidden_dim=64,
    n_layers=3, n_agents=500, n_comm_steps=4,
    lr=0.01, batch_size=1, device='cpu',
)

In [206]:
# One random 64-dim state, with three active agent slots.
demo_state = torch.rand(1, 64)
active_agents = [1, 2, 22]
learner.act(demo_state, active_agents)

([tensor([[0.0091, 0.3712, 0.2957]], grad_fn=<MmBackward>),
  tensor([[ 1.2547, -1.2044, -0.5666]], grad_fn=<MmBackward>),
  tensor([[ 0.8637,  0.2336, -0.0886]], grad_fn=<MmBackward>)],
 [tensor([[0., 1., 0.]], grad_fn=<AddBackward0>),
  tensor([[1., 0., 0.]], grad_fn=<AddBackward0>),
  tensor([[1., 0., 0.]], grad_fn=<AddBackward0>)])