# Experiment 9



In [2]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
from egg import core
from egg.zoo.simple_autoenc.features import OneHotLoader
from egg.zoo.simple_autoenc.archs import Sender
from egg.zoo.simple_autoenc.train import get_params

In [3]:
opts = core.init(params=['--random_seed=13', 
                         '--n_epochs=50',
                         '--batch_size=1'])
opts.n_features = 10
opts.batches_per_epoch = 1000
opts.sender_entropy_coeff = 0.01
opts.receiver_entropy_coeff = 0.01
opts.executive_sender_entropy_coeff = 0.01
opts.alphabet_size = 8
opts.sender_population_size = 1
opts.receiver_population_size = 3
opts.lr = 1e-1

In [4]:
def loss(sender_input, _message, _receiver_input, receiver_output, _labels):
    acc = (receiver_output == sender_input.argmax(dim=1)).detach().float().mean(dim=0)
    return -acc, {'acc': acc.item()}

train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                            batches_per_epoch=opts.batches_per_epoch)
test_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                            batches_per_epoch=opts.batches_per_epoch, seed=7)

class Receiver(nn.Module):
    def __init__(self, n_hidden, n_features):
        super(Receiver, self).__init__()
        self.output = core.RelaxedEmbedding(n_hidden, n_features)

    def forward(self, x, _input):
        return self.output(x)

senders = [core.ReinforceWrapper(Sender(opts.alphabet_size, opts.n_features)) 
           for _ in range(opts.sender_population_size)]
executive_senders = [core.ReinforceWrapper(Sender(opts.receiver_population_size, 1)) 
                     for _ in range(opts.sender_population_size)]
receivers = [core.ReinforceWrapper(Receiver(opts.n_features, opts.alphabet_size)) 
             for _ in range(opts.receiver_population_size)]

In [5]:
class MultiAgentGame(nn.Module):

    def __init__(self, senders, receivers, executive_senders, loss, 
                 sender_entropy_coeff=opts.sender_entropy_coeff, 
                 receiver_entropy_coeff=opts.sender_entropy_coeff,
                 executive_sender_entropy_coeff=opts.executive_sender_entropy_coeff):
        super(MultiAgentGame, self).__init__()
        self.senders = senders
        self.receivers = receivers
        self.executive_senders = executive_senders
        self.loss = loss

        self.receiver_entropy_coeff = receiver_entropy_coeff
        self.sender_entropy_coeff = sender_entropy_coeff
        self.executive_sender_entropy_coeff = executive_sender_entropy_coeff

        self.mean_baseline = 0.0
        self.n_points = 0.0

    def forward(self, sender_input, labels, receiver_input=None):
        idx = np.random.choice(len(self.senders))       
        executive_sender = self.executive_senders[idx]
        sender = self.senders[idx]
        receiver_id, executive_sender_log_prob, executive_sender_entropy = executive_sender(torch.ones(1, 1))
        receiver = self.receivers[receiver_id.item()]
        message, sender_log_prob, sender_entropy = sender(sender_input)
        receiver_output, receiver_log_prob, receiver_entropy = receiver(message, receiver_input)

        loss, rest_info = self.loss(sender_input, message, receiver_input, receiver_output, labels)
        advantage = (loss.detach() - self.mean_baseline) 
        sender_loss = advantage * (sender_log_prob + receiver_log_prob)
        exec_sender_loss = advantage * (executive_sender_log_prob + receiver_log_prob)
        receiver_loss = advantage * receiver_log_prob
        policy_loss = (sender_loss + receiver_loss + exec_sender_loss).mean()
        
        entropy_loss = -(sender_entropy.mean() * self.sender_entropy_coeff + 
                         receiver_entropy.mean() * self.receiver_entropy_coeff +
                         executive_sender_entropy.mean() * self.executive_sender_entropy_coeff)

        if self.training:
            self.n_points += 1.0
            self.mean_baseline += (loss.detach().mean().item() -
                                   self.mean_baseline) / self.n_points

        full_loss = policy_loss + entropy_loss

        rest_info['baseline'] = self.mean_baseline
        rest_info['loss'] = loss.mean().item()
        rest_info['sender_entropy'] = sender_entropy.mean()
        rest_info['receiver_entropy'] = receiver_entropy.mean()
        rest_info['executive_sender_entropy'] = executive_sender_entropy.mean()
        return full_loss, rest_info

In [6]:
import seaborn as sns
import matplotlib.pyplot as plt
import neptune
from neptune.experiments import Experiment


class NeptuneMonitor:
        
    def __init__(self, experiment: Experiment = None):
        self.experiment = experiment if experiment else neptune
    
    def log(self, mode, epoch, loss, rest):
        self.experiment.send_metric(f'{mode}_loss', loss)
        for metric, value in rest.items():
            self.experiment.send_metric(f'{mode}_{metric}', value)

            
def save_sender_codebook(experiment, senders, epoch, label):
    figure, axes = plt.subplots(1, len(receivers),sharey=True, figsize=(20,5))
    figure.suptitle(f'Epoch {epoch}')
    for i, (sender, ax) in enumerate(zip(senders, axes)):
        g = sns.heatmap(F.softmax(sender.agent.fc1.weight.detach(), dim=1).numpy(), annot=True, fmt='.2f', ax=ax)
        g.set_title(f'{label} {i}')
    figure.savefig('fig.jpg')
    experiment.log_image(f'{label}s', 'fig.jpg')
    plt.close()
    
def save_exec_sender_codebook(experiment, senders, epoch, label):
    figure, axes = plt.subplots(1, len(receivers),sharey=True, figsize=(20,5))
    figure.suptitle(f'Epoch {epoch}')
    for i, (sender, ax) in enumerate(zip(senders, axes)):
        g = sns.heatmap(F.softmax(sender.agent.fc1.weight.detach(), dim=1).numpy(), annot=True, fmt='.2f', ax=ax)
        g.set_title(f'{label} {i}')
    figure.savefig('fig.jpg')
    experiment.log_image(f'{label}s', 'fig.jpg')
    plt.close()

    
def is_codebook_shared(senders):
    from itertools import product
    return all(shared_codebooks(sender1, sender2) 
               for sender1, sender2 in product(senders, senders))


def shared_codebooks(sender1, sender2):
    return bool(((sender1.agent.fc1.weight.detach()).argmax(dim=1) == \
               (sender2.agent.fc1.weight.detach()).argmax(dim=1)).all())
        

In [7]:
class CustomTrainer(core.Trainer):

    def train(self, n_epochs):
        while self.epoch < n_epochs:
            train_loss, train_rest = self.train_epoch()
            for i, ex_s in enumerate(self.game.executive_senders):
                self.monitor.experiment.send_metric(f'ex sender {i} grad', ex_s.agent.fc1.weight.grad.sum())
            for i, s in enumerate(self.game.senders):
                self.monitor.experiment.send_metric(f'sender {i} grad', s.agent.fc1.weight.grad.sum())

            self.epoch += 1

            self.monitor.log('train', self.epoch, train_loss, train_rest)
            self.monitor.experiment.send_metric('shared_codebook', is_codebook_shared(self.game.senders))


            if self.validation_data is not None and self.validation_freq > 0 and self.epoch % self.validation_freq == 0:
                validation_loss, rest = self.eval()
                self.monitor.log('validation', self.epoch, validation_loss, rest)
                print(f'validation: epoch {self.epoch}, loss {validation_loss},  {rest}', flush=True)
                save_sender_codebook(self.monitor.experiment, self.game.senders, self.epoch, 'Sender')
                save_exec_sender_codebook(self.monitor.experiment, self.game.executive_senders, self.epoch, 'Executive sender')


                if self.early_stopping:
                    self.early_stopping.update_values(validation_loss, rest, train_loss, rest, self.epoch)
                    if self.early_stopping.should_stop(): break
        
        
             

In [9]:
game = MultiAgentGame(senders, receivers, executive_senders, loss)
sender_params = [{'params': sender.parameters(), 'lr': opts.lr} for sender in senders]
executive_senders_params = [{'params': ex_sender.parameters(), 'lr': opts.lr} for ex_sender in executive_senders]
receiver_params = [{'params': receiver.parameters(), 'lr': opts.lr} for receiver in receivers]
optimizer = torch.optim.Adam(sender_params+receiver_params+executive_senders_params)

neptune.init('tomekkorbak/EGG')
trainer = CustomTrainer(game=game, optimizer=optimizer, train_data=train_loader, validation_data=test_loader)
experiment = neptune.create_experiment(name='first-egg-experiment', tags=['egg', 'executive-senders'], params=vars(opts))
trainer.monitor = NeptuneMonitor(experiment=experiment)
trainer.train(n_epochs=opts.n_epochs)
experiment.stop()

EGG-84
https://ui.neptune.ml/tomekkorbak/egg/e/EGG-84
validation: epoch 1, loss -0.025697851553559303,  {'acc': 0.39, 'baseline': -0.33400000000000285, 'loss': -0.39, 'sender_entropy': tensor(0.0437), 'receiver_entropy': tensor(0.0453), 'executive_sender_entropy': tensor(0.0142)}
validation: epoch 2, loss -0.03087226301431656,  {'acc': 0.496, 'baseline': -0.4180000000000053, 'loss': -0.496, 'sender_entropy': tensor(0.0500), 'receiver_entropy': tensor(0.0082), 'executive_sender_entropy': tensor(0.0012)}
validation: epoch 3, loss -0.013852248899638653,  {'acc': 0.497, 'baseline': -0.43766666666665943, 'loss': -0.497, 'sender_entropy': tensor(0.0193), 'receiver_entropy': tensor(0.0082), 'executive_sender_entropy': tensor(0.0012)}
validation: epoch 4, loss -0.014215932227671146,  {'acc': 0.497, 'baseline': -0.45125000000000365, 'loss': -0.497, 'sender_entropy': tensor(0.0275), 'receiver_entropy': tensor(0.0037), 'executive_sender_entropy': tensor(0.0014)}
validation: epoch 5, loss -0.15761

validation: epoch 36, loss -0.012845792807638645,  {'acc': 0.5, 'baseline': -0.49708333333332216, 'loss': -0.5, 'sender_entropy': tensor(2.2438e-07), 'receiver_entropy': tensor(0.0017), 'executive_sender_entropy': tensor(0.0001)}
validation: epoch 37, loss -0.0412503182888031,  {'acc': 0.498, 'baseline': -0.4974594594594496, 'loss': -0.498, 'sender_entropy': tensor(2.2372e-07), 'receiver_entropy': tensor(0.0287), 'executive_sender_entropy': tensor(1.3268e-08)}
validation: epoch 38, loss -0.015549964271485806,  {'acc': 0.5, 'baseline': -0.49697368421051674, 'loss': -0.5, 'sender_entropy': tensor(2.2535e-07), 'receiver_entropy': tensor(0.0004), 'executive_sender_entropy': tensor(1.3267e-08)}
validation: epoch 39, loss -4.0327948227059096e-05,  {'acc': 0.5, 'baseline': -0.4960256410256504, 'loss': -0.5, 'sender_entropy': tensor(2.2938e-07), 'receiver_entropy': tensor(0.0004), 'executive_sender_entropy': tensor(1.3267e-08)}
validation: epoch 40, loss -6.206531543284655e-05,  {'acc': 0.5, '

In [None]:
bool((torch.Tensor([[1, 1,], [2, 2]]) == torch.Tensor([[1, 1,], [2, 2]])).all())