# Experiment 4

In [1]:
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.optim import Adam
import seaborn as sns

random.seed(44)
torch.manual_seed(1773)
sns.set(rc={'figure.figsize':(15,20)})

In [3]:
class SpeakerNet(nn.Module):
    def __init__(self, input_size, alphabet_size):
        super().__init__()
        self.linear = nn.Embedding(input_size, alphabet_size)

    def forward(self, input):
        return F.softmax(self.linear(input), dim=1)
    
class ListenerNet(nn.Module):
    def __init__(self, alphabet_size, output_size):
        super().__init__()
        self.linear = nn.Embedding(alphabet_size, output_size)

    def forward(self, input):
        return F.softmax(self.linear(input), dim=1)

In [4]:
class Agent:
    def __init__(self, id, world_size, alphabet_size, learning_rate=0.1):
        self.id = id
        self.speak = SpeakerNet(world_size, alphabet_size)
        self.listen = ListenerNet(alphabet_size, world_size)
        self.speak_optimizer = Adam(params=self.speak.parameters(), lr=learning_rate)
        self.listen_optimizer = Adam(params=self.listen.parameters(), lr=learning_rate)
        
    def observe_and_speak(self, observation):
        probs = self.speak(observation)
        distribution = Categorical(probs)
        message = distribution.sample()
        self.last_message_probs, self.last_message = distribution.log_prob(message), message
        return message
    
    def get_reward_for_speaking(self, rewards):
        self.speak_optimizer.zero_grad()  
        loss = -(self.last_message_probs * rewards).sum()
        loss.backward()
        self.speak_optimizer.step()
        return loss
    
    def listen_and_predict(self, message):
        self.received_message = message
        return self.listen(message)
    
    def get_reward_for_listening(self, loss):
        self.listen_optimizer.zero_grad()  
        loss.backward()
        self.listen_optimizer.step()
        return self.listen.linear.weight.grad

In [5]:
class DifferentiableWorld:
    def __init__(self, batch_size, world_size):
        self.batch_size, self.world_size = batch_size, world_size
        self.reset()
    
    def reset(self):
        self.hidden_state = torch.randint(
            high=self.world_size, 
            size=(self.batch_size,),
            dtype=torch.long
        )
    
    def receive_observation(self):
        return self.hidden_state
    
    def evaluate_prediction(self, pred):
        return F.nll_loss(pred, self.hidden_state)

In [7]:
def train(population_size=1):
    population = [Agent(id=i, world_size=10, alphabet_size=15, learning_rate=1) for i in range(population_size)]
    world = DifferentiableWorld(batch_size=16, world_size=10)
    for epoch in range(1_000_000):
        # Choose two agents (with replacement, with order) for each epoch
        sender, receiver = random.choices(population, k=2)
        observation = world.receive_observation()
        message = sender.observe_and_speak(observation)
        prediction = receiver.listen_and_predict(message)
        receiver_loss = world.evaluate_prediction(prediction)
        rewards_for_sender = receiver.get_reward_for_listening(receiver_loss)
        speaker_loss = sender.get_reward_for_speaking(rewards_for_sender[message].sum(dim=1))
        world.reset()
        if epoch % 10000 == 0:
            print(f'Epoch {epoch}, agent {sender.id} talking to agent {receiver.id}')
            print(f'Speaker loss: {speaker_loss:.2f}, receiver loss: {receiver_loss:.2f}')
            print(observation[:5], message[:5], prediction.argmax(dim=1)[:5])


In [8]:
train(3)

Epoch 0, agent 2 talking to agent 0
Speaker loss: -0.00, receiver loss: -0.09
tensor([1, 1, 4, 4, 4]) tensor([10,  8,  4, 10,  5]) tensor([0, 3, 9, 0, 5])
Epoch 10000, agent 2 talking to agent 1
Speaker loss: 0.00, receiver loss: -0.56
tensor([5, 0, 5, 9, 3]) tensor([8, 2, 8, 2, 4]) tensor([5, 2, 5, 2, 3])
Epoch 20000, agent 0 talking to agent 2
Speaker loss: -0.00, receiver loss: -0.25
tensor([0, 1, 7, 3, 4]) tensor([9, 9, 9, 4, 9]) tensor([1, 1, 1, 3, 1])
Epoch 30000, agent 0 talking to agent 2
Speaker loss: 0.00, receiver loss: -0.12
tensor([6, 9, 5, 8, 9]) tensor([9, 9, 9, 9, 9]) tensor([1, 1, 1, 1, 1])
Epoch 40000, agent 1 talking to agent 0
Speaker loss: 0.00, receiver loss: -0.19
tensor([2, 2, 9, 8, 1]) tensor([2, 2, 2, 0, 0]) tensor([2, 2, 2, 1, 1])
Epoch 50000, agent 2 talking to agent 0
Speaker loss: 0.00, receiver loss: -0.50
tensor([7, 9, 8, 0, 1]) tensor([2, 2, 2, 2, 0]) tensor([2, 2, 2, 2, 1])
Epoch 60000, agent 1 talking to agent 0
Speaker loss: 0.00, receiver loss: -0.1

In [9]:
train(1)

Epoch 0, agent 0 talking to agent 0
Speaker loss: -0.00, receiver loss: -0.11
tensor([7, 4, 2, 6, 0]) tensor([ 8, 14,  8, 12, 14]) tensor([2, 1, 2, 5, 1])
Epoch 10000, agent 0 talking to agent 0
Speaker loss: 0.00, receiver loss: -0.56
tensor([5, 9, 7, 1, 8]) tensor([7, 1, 8, 7, 7]) tensor([1, 9, 7, 1, 1])
Epoch 20000, agent 0 talking to agent 0
Speaker loss: 0.00, receiver loss: -0.62
tensor([8, 9, 4, 0, 7]) tensor([ 7,  1,  1, 10,  8]) tensor([8, 9, 9, 0, 7])
Epoch 30000, agent 0 talking to agent 0
Speaker loss: 0.00, receiver loss: -0.38
tensor([7, 6, 4, 3, 7]) tensor([8, 7, 1, 7, 8]) tensor([7, 8, 9, 8, 7])
Epoch 40000, agent 0 talking to agent 0
Speaker loss: 0.00, receiver loss: -0.44
tensor([6, 0, 8, 9, 2]) tensor([ 7, 10,  7,  1,  7]) tensor([8, 0, 8, 9, 8])
Epoch 50000, agent 0 talking to agent 0
Speaker loss: 0.00, receiver loss: -0.31
tensor([9, 6, 2, 9, 7]) tensor([1, 7, 7, 1, 8]) tensor([9, 8, 8, 9, 7])
Epoch 60000, agent 0 talking to agent 0
Speaker loss: 0.00, receiver l