# Experiment 8



In [166]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
from egg import core
from egg.zoo.simple_autoenc.features import OneHotLoader
from egg.zoo.simple_autoenc.archs import Sender
from egg.zoo.simple_autoenc.train import get_params

In [185]:
opts = core.init(params=['--random_seed=13', 
                         '--n_epochs=50',
                         '--batch_size=1'])
opts.n_features = 10
opts.batches_per_epoch = 1000
opts.sender_entropy_coeff = 0.01
opts.receiver_entropy_coeff = 0.01
opts.executive_sender_entropy_coeff = 0.01
opts.alphabet_size = 8
opts.sender_population_size = 3
opts.receiver_population_size = 3
opts.lr = 1e-2

In [186]:
def loss(sender_input, _message, _receiver_input, receiver_output, _labels):
    acc = (receiver_output == sender_input.argmax(dim=1)).detach().float().mean(dim=0)
    return -acc, {'acc': acc.item()}

train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                            batches_per_epoch=opts.batches_per_epoch)
test_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                            batches_per_epoch=opts.batches_per_epoch, seed=7)

class Receiver(nn.Module):
    def __init__(self, n_hidden, n_features):
        super(Receiver, self).__init__()
        self.output = core.RelaxedEmbedding(n_hidden, n_features)

    def forward(self, x, _input):
        return self.output(x)

senders = [core.ReinforceWrapper(Sender(opts.alphabet_size, opts.n_features)) 
           for _ in range(opts.sender_population_size)]
executive_senders = [core.ReinforceWrapper(Sender(opts.receiver_population_size, 1)) 
                     for _ in range(opts.sender_population_size)]
receivers = [core.ReinforceWrapper(Receiver(opts.n_features, opts.alphabet_size)) 
             for _ in range(opts.receiver_population_size)]

In [187]:
class MultiAgentGame(nn.Module):

    def __init__(self, senders, receivers, executive_senders, loss, 
                 sender_entropy_coeff=opts.sender_entropy_coeff, 
                 receiver_entropy_coeff=opts.sender_entropy_coeff,
                 executive_sender_entropy_coeff=opts.executive_sender_entropy_coeff):
        super(MultiAgentGame, self).__init__()
        self.senders = senders
        self.receivers = receivers
        self.executive_senders = executive_senders
        self.loss = loss

        self.receiver_entropy_coeff = receiver_entropy_coeff
        self.sender_entropy_coeff = sender_entropy_coeff
        self.executive_sender_entropy_coeff = executive_sender_entropy_coeff

        self.mean_baseline = 0.0
        self.n_points = 0.0

    def forward(self, sender_input, labels, receiver_input=None):
        idx = np.random.choice(len(self.senders))       
        executive_sender = self.executive_senders[idx]
        sender = self.senders[idx]
        receiver_id, executive_sender_log_prob, executive_sender_entropy = executive_sender(torch.ones(1, 1))
        receiver = self.receivers[receiver_id.item()]
        message, sender_log_prob, sender_entropy = sender(sender_input)
        receiver_output, receiver_log_prob, receiver_entropy = receiver(message, receiver_input)

        loss, rest_info = self.loss(sender_input, message, receiver_input, receiver_output, labels)
        advantage = (loss.detach() - self.mean_baseline) 
        sender_loss = advantage * (sender_log_prob + receiver_log_prob)
        exec_sender_loss = advantage * (executive_sender_log_prob + receiver_log_prob)
        receiver_loss = advantage * receiver_log_prob
        policy_loss = (sender_loss + receiver_loss + exec_sender_loss).mean()
        
        entropy_loss = -(sender_entropy.mean() * self.sender_entropy_coeff + 
                         receiver_entropy.mean() * self.receiver_entropy_coeff +
                         executive_sender_entropy.mean() * self.executive_sender_entropy_coeff)

        if self.training:
            self.n_points += 1.0
            self.mean_baseline += (loss.detach().mean().item() -
                                   self.mean_baseline) / self.n_points

        full_loss = policy_loss + entropy_loss

        rest_info['baseline'] = self.mean_baseline
        rest_info['loss'] = loss.mean().item()
        rest_info['sender_entropy'] = sender_entropy.mean()
        rest_info['receiver_entropy'] = receiver_entropy.mean()
        rest_info['executive_sender_entropy'] = executive_sender_entropy.mean()
        return full_loss, rest_info

In [188]:
import seaborn as sns
import matplotlib.pyplot as plt
import neptune
from neptune.experiments import Experiment


class NeptuneMonitor:
        
    def __init__(self, experiment: Experiment = None):
        self.experiment = experiment if experiment else neptune
    
    def log(self, mode, epoch, loss, rest):
        self.experiment.send_metric(f'{mode}_loss', loss)
        for metric, value in rest.items():
            self.experiment.send_metric(f'{mode}_{metric}', value)

            
def save_sender_codebook(experiment, senders, epoch, label):
    figure, axes = plt.subplots(1, len(receivers),sharey=True, figsize=(20,5))
    figure.suptitle(f'Epoch {epoch}')
    for i, (sender, ax) in enumerate(zip(senders, axes)):
        g = sns.heatmap(F.softmax(sender.agent.fc1.weight.detach(), dim=1).numpy(), annot=True, fmt='.2f', ax=ax)
        g.set_title(f'{label} {i}')
    figure.savefig('fig.jpg')
    experiment.log_image(f'{label}s', 'fig.jpg')
    plt.close()

In [189]:
class CustomTrainer(core.Trainer):

    def train(self, n_epochs):
        while self.epoch < n_epochs:
            train_loss, train_rest = self.train_epoch()
            for i, ex_s in enumerate(self.game.executive_senders):
                self.monitor.experiment.send_metric(f'ex sender {i} grad', ex_s.agent.fc1.weight.grad.sum())
            for i, s in enumerate(self.game.senders):
                self.monitor.experiment.send_metric(f'sender {i} grad', s.agent.fc1.weight.grad.sum())

            self.epoch += 1

            self.monitor.log('train', self.epoch, train_loss, train_rest)
            

            if self.validation_data is not None and self.validation_freq > 0 and self.epoch % self.validation_freq == 0:
                validation_loss, rest = self.eval()
                self.monitor.log('validation', self.epoch, validation_loss, rest)
                print(f'validation: epoch {self.epoch}, loss {validation_loss},  {rest}', flush=True)
                save_sender_codebook(self.monitor.experiment, self.game.senders, self.epoch, 'Sender')
                save_sender_codebook(self.monitor.experiment, self.game.executive_senders, self.epoch, 'Executive sender')


                if self.early_stopping:
                    self.early_stopping.update_values(validation_loss, rest, train_loss, rest, self.epoch)
                    if self.early_stopping.should_stop(): break

In [190]:
game = MultiAgentGame(senders, receivers, executive_senders, loss)
sender_params = [{'params': sender.parameters(), 'lr': opts.lr} for sender in senders]
executive_senders_params = [{'params': ex_sender.parameters(), 'lr': opts.lr*0.01} for ex_sender in executive_senders]
receiver_params = [{'params': receiver.parameters(), 'lr': opts.lr} for receiver in receivers]
optimizer = torch.optim.Adam(sender_params+receiver_params+executive_senders_params)

neptune.init('tomekkorbak/EGG')
trainer = CustomTrainer(game=game, optimizer=optimizer, train_data=train_loader, validation_data=test_loader)
experiment = neptune.create_experiment(name='first-egg-experiment', tags=['egg', 'executive-senders'], params=vars(opts))
trainer.monitor = NeptuneMonitor(experiment=experiment)
trainer.train(n_epochs=opts.n_epochs)
experiment.stop()

EGG-64
https://ui.neptune.ml/tomekkorbak/egg/e/EGG-64
validation: epoch 1, loss -0.006105861626565456,  {'acc': 0.115, 'baseline': -0.1009999999999993, 'loss': -0.115, 'sender_entropy': tensor(1.9241), 'receiver_entropy': tensor(1.7247), 'executive_sender_entropy': tensor(1.0272)}
validation: epoch 2, loss -0.06182699277997017,  {'acc': 0.148, 'baseline': -0.11899999999999979, 'loss': -0.148, 'sender_entropy': tensor(1.7910), 'receiver_entropy': tensor(1.5497), 'executive_sender_entropy': tensor(1.0266)}
validation: epoch 3, loss -0.04432927817106247,  {'acc': 0.186, 'baseline': -0.13766666666666397, 'loss': -0.186, 'sender_entropy': tensor(1.5320), 'receiver_entropy': tensor(1.3243), 'executive_sender_entropy': tensor(1.0242)}
validation: epoch 4, loss -0.19516055285930634,  {'acc': 0.244, 'baseline': -0.1542499999999999, 'loss': -0.244, 'sender_entropy': tensor(1.3201), 'receiver_entropy': tensor(1.0149), 'executive_sender_entropy': tensor(1.0231)}
validation: epoch 5, loss -0.150091

validation: epoch 37, loss 0.014040905050933361,  {'acc': 0.796, 'baseline': -0.6088378378378433, 'loss': -0.796, 'sender_entropy': tensor(0.2425), 'receiver_entropy': tensor(0.0102), 'executive_sender_entropy': tensor(0.9847)}
validation: epoch 38, loss 0.03705231100320816,  {'acc': 0.797, 'baseline': -0.613500000000006, 'loss': -0.797, 'sender_entropy': tensor(0.2197), 'receiver_entropy': tensor(0.0079), 'executive_sender_entropy': tensor(0.9847)}
validation: epoch 39, loss -0.001573710236698389,  {'acc': 0.797, 'baseline': -0.6178974358974395, 'loss': -0.797, 'sender_entropy': tensor(0.2603), 'receiver_entropy': tensor(0.0065), 'executive_sender_entropy': tensor(0.9836)}
validation: epoch 40, loss 0.014994163066148758,  {'acc': 0.796, 'baseline': -0.6229250000000087, 'loss': -0.796, 'sender_entropy': tensor(0.2191), 'receiver_entropy': tensor(0.0052), 'executive_sender_entropy': tensor(0.9829)}
validation: epoch 41, loss -0.0018142202170565724,  {'acc': 0.796, 'baseline': -0.6270731