# Experiment 5

Lewis signaling game with simple reinforcement (reward 1 for both agents for correct action, reward 0 for incorrect action).

In [1]:
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.optim import Adam
import seaborn as sns

random.seed(44)
torch.manual_seed(1773)
sns.set(rc={'figure.figsize':(15,20)})

In [2]:
class SpeakerNet(nn.Module):
    def __init__(self, input_size, alphabet_size):
        super().__init__()
        self.linear = nn.Embedding(input_size, alphabet_size)

    def forward(self, input):
        return F.softmax(self.linear(input), dim=1)
    
class ListenerNet(nn.Module):
    def __init__(self, alphabet_size, output_size):
        super().__init__()
        self.linear = nn.Embedding(alphabet_size, output_size)

    def forward(self, input):
        return F.softmax(self.linear(input), dim=1)

In [3]:
class Agent:
    def __init__(self, id, world_size, alphabet_size, learning_rate=10):
        self.id = id
        self.speak = SpeakerNet(world_size, alphabet_size)
        self.listen = ListenerNet(alphabet_size, world_size)
        self.speak_optimizer = Adam(params=self.speak.parameters(), lr=learning_rate)
        self.listen_optimizer = Adam(params=self.listen.parameters(), lr=learning_rate)
                
    def observe_and_speak(self, observation):
        probs = self.speak(observation)
        distribution = Categorical(probs)
        message = distribution.sample()
        self.last_sent_message, self.last_sent_message_probs = message, distribution.log_prob(message)
        return message
    
    def get_reward_for_speaking(self, reward):
        self.speak_optimizer.zero_grad()  
        loss = -(self.last_sent_message_probs * reward).sum()
        loss.backward()
        self.speak_optimizer.step()
        return loss
    
    def listen_and_predict(self, message):
        probs = self.listen(message)
        distribution = Categorical(probs)
        action = distribution.sample()
        self.last_action, self.last_action_probs = action, distribution.log_prob(action)
        return action
    
    def get_reward_for_listening(self, reward):
        self.listen_optimizer.zero_grad()  
        loss = -(self.last_action_probs * reward).sum()
        loss.backward()
        self.listen_optimizer.step()
        return loss

In [4]:
class World:
    def __init__(self, batch_size, world_size):
        self.batch_size, self.world_size = batch_size, world_size
        self.reset()
    
    def reset(self):
        self.hidden_state = torch.randint(
            high=self.world_size, 
            size=(self.batch_size,),
            dtype=torch.long
        )
    
    def receive_observation(self):
        return self.hidden_state
    
    def evaluate_prediction(self, pred):
        return (pred == self.hidden_state).float()

In [9]:
def rescale(rewards):
    return (rewards - rewards.mean()) / (rewards.std() + 1e-6)

def train(population_size=1):
    population = [Agent(id=i, world_size=10, alphabet_size=15, learning_rate=10)
                  for i in range(population_size)]
    world = World(batch_size=500, world_size=10)
    last_rewards, last_sender_loss, last_receiver_loss = 0, 0, 0
    for epoch in range(100_000):

        # Choose two agents (with replacement, with order) for each epoch
        sender, receiver = random.choices(population, k=2)
        observation = world.receive_observation()
        message = sender.observe_and_speak(observation)
        prediction = receiver.listen_and_predict(message)
        rewards = world.evaluate_prediction(prediction)
        rewards = rescale(rewards)
        receiver_loss = receiver.get_reward_for_listening(rewards)
        speaker_loss = sender.get_reward_for_speaking(rewards)
        world.reset()
        
        if epoch % 1000 == 0:
            print(f'Epoch {epoch}, agent {sender.id} talking to agent {receiver.id}')
            print(f'Reward {last_rewards}, speaker loss: {last_sender_loss/(1000):4f}, receiver loss: {last_receiver_loss/(1000):.4f}')
            print(observation[:5], message[:5], prediction[:5], rewards[:5])
            last_rewards, last_sender_loss, last_receiver_loss = 0, 0, 0
        

In [10]:
train(5)

Epoch 0, agent 1 talking to agent 2
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([0, 5, 2, 6, 5]) tensor([11,  1,  8,  9,  8]) tensor([2, 4, 6, 6, 0]) tensor([-0.3142, -0.3142, -0.3142,  3.1766, -0.3142])
Epoch 1000, agent 1 talking to agent 1
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([9, 5, 8, 1, 4]) tensor([ 4, 10,  8,  0,  1]) tensor([1, 8, 8, 9, 6]) tensor([-0.3689, -0.3689,  2.7053, -0.3689, -0.3689])
Epoch 2000, agent 2 talking to agent 2
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([1, 7, 4, 7, 0]) tensor([14, 11,  1, 11, 13]) tensor([9, 7, 4, 7, 0]) tensor([-0.8536,  1.1692,  1.1692,  1.1692,  1.1692])
Epoch 3000, agent 2 talking to agent 4
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([3, 3, 8, 8, 0]) tensor([ 8,  8,  8,  8, 13]) tensor([9, 9, 9, 9, 5]) tensor([-0.6509, -0.6509, -0.6509, -0.6509, -0.6509])
Epoch 4000, agent 4 talking to agent 4
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tens

Epoch 37000, agent 0 talking to agent 3
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([2, 2, 0, 6, 5]) tensor([12, 12, 13,  9, 14]) tensor([4, 4, 5, 6, 5]) tensor([-0.8536, -0.8536, -0.8536,  1.1692,  1.1692])
Epoch 38000, agent 1 talking to agent 4
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([6, 8, 1, 5, 2]) tensor([ 9,  8,  0, 10,  5]) tensor([6, 9, 3, 8, 2]) tensor([ 1.5045, -0.6634, -0.6634, -0.6634,  1.5045])
Epoch 39000, agent 0 talking to agent 3
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([7, 4, 1, 9, 5]) tensor([11,  1,  0,  3, 14]) tensor([6, 4, 3, 1, 5]) tensor([-0.8536,  1.1692, -0.8536, -0.8536,  1.1692])
Epoch 40000, agent 3 talking to agent 1
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([3, 8, 2, 0, 4]) tensor([8, 6, 5, 5, 1]) tensor([8, 9, 0, 0, 6]) tensor([-0.6602, -0.6602, -0.6602,  1.5116, -0.6602])
Epoch 41000, agent 4 talking to agent 0
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
t

Epoch 74000, agent 4 talking to agent 0
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([2, 4, 8, 3, 1]) tensor([11, 12, 10,  0,  3]) tensor([7, 2, 5, 1, 9]) tensor([0., 0., 0., 0., 0.])
Epoch 75000, agent 4 talking to agent 2
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([6, 7, 2, 4, 8]) tensor([11,  2, 11, 12, 10]) tensor([7, 8, 7, 2, 5]) tensor([0., 0., 0., 0., 0.])
Epoch 76000, agent 1 talking to agent 1
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([2, 8, 8, 2, 9]) tensor([5, 8, 8, 5, 4]) tensor([0, 8, 8, 0, 1]) tensor([-0.3476,  2.8710,  2.8710, -0.3476, -0.3476])
Epoch 77000, agent 4 talking to agent 4
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([2, 1, 9, 6, 9]) tensor([11,  3,  6, 11,  6]) tensor([2, 1, 9, 2, 9]) tensor([ 0.4585,  0.4585,  0.4585, -2.1766,  0.4585])
Epoch 78000, agent 1 talking to agent 0
Reward 0, speaker loss: 0.000000, receiver loss: 0.0000
tensor([5, 1, 8, 7, 8]) tensor([10,  0,  8, 11,  8]