In [12]:
%%writefile environment.py
import numpy as np

def two_hot(n, a, b):
    """Encode the index pair (a, b) as a single vector of length 2 * n.

    The first n slots one-hot encode ``a`` and the last n slots one-hot
    encode ``b``; every other entry is 0.
    """
    encoded = np.zeros(2 * n)
    hot_positions = [a, n + b]  # a in the first half, b in the second half
    encoded[hot_positions] = 1
    return encoded

class Game(object):
    """One round of a two-number comparison game between a sender and a receiver.

    Each round draws two distinct numbers ``x`` and ``y`` from ``range(n)``.
    ``target`` is 0 when the first number is the larger one and 1 otherwise.
    When ``flip`` is 1 the receiver sees the pair in swapped order, and
    ``target`` is inverted accordingly so that it still refers to the
    position of the larger number in the receiver's view.
    """

    def __init__(self, n):
        """Create a game over the numbers ``0 .. n - 1`` and draw an initial flip."""
        self.n = n
        self.target = None
        self.flip = None
        self.x = None
        self.y = None
        # True once the flip has been applied to the current (x, y, target).
        self._flip_applied = False
        self.reset()

    def reset(self):
        """Re-draw ``flip`` uniformly from {0, 1}.

        The current pair and target are left untouched; call
        ``sender_input`` afterwards to start a new round.
        """
        self.flip = np.random.randint(2)

    def sender_input(self):
        """Draw a new pair of distinct numbers and return its two-hot encoding.

        Stores the pair in ``self.x`` / ``self.y`` and sets ``self.target``
        (0 if x > y, else 1) from the sender's point of view.
        """
        x, y = np.random.choice(range(self.n), 2, replace=False)
        self.x, self.y = x, y
        self.target = 0 if x > y else 1
        # A fresh pair has not been flipped for the receiver yet.
        self._flip_applied = False
        return two_hot(self.n, x, y)

    def receiver_input(self, val):
        """Return the receiver's observation: two-hot pair followed by ``val``.

        ``val`` is the scalar message and must lie in [0, 1] (with a 1e-8
        tolerance). When ``flip`` is 1, the stored pair is swapped and the
        target inverted. This is done at most once per ``sender_input``
        draw, so calling this method several times in the same round
        returns the same view instead of flipping back and forth.
        """
        assert -1e-8 < val < 1 + 1e-8
        assert self.target is not None, 'call sender_input() before receiver_input()'
        if self.flip == 1 and not self._flip_applied:
            self.target = 1 - self.target
            self.x, self.y = self.y, self.x
            self._flip_applied = True
        return np.concatenate((two_hot(self.n, self.x, self.y), [val]))

    def reward(self, out):
        """Return 1 if ``out`` equals the current target, otherwise 0."""
        assert self.target is not None
        return 1 if out == self.target else 0

if __name__ == '__main__':
    # Smoke test: print the sender and receiver views for two rounds,
    # re-drawing the flip before the second round.
    g = Game(10)
    for round_idx in range(2):
        if round_idx > 0:
            g.reset()
        print(g.sender_input())
        print(g.receiver_input(0.5))

Overwriting environment.py


In [23]:
%%writefile train.py

from environment import Game
from model import *
from itertools import count
import torch.optim as optim
import logging

# NOTE(review): presumably the range the sender's message is kept in
# (Game.receiver_input asserts 0..1); not used anywhere in this script yet.
bound = [0., 1.]

# Hyper parameters
explore_sigma = 0.05
n_r = 10

n_numbers = 10
n_hidden = 50
n_games = 32
# One independent game per slot of the pool.
game_pool = [Game(n_numbers) for _ in range(n_games)]

# Networks: receiver, sender actor and sender critic.
R = Receiver(n_numbers, n_hidden)
SA = SenderActor(n_numbers, n_hidden)
SC = SenderCritic(n_numbers, n_hidden)

# One plain SGD optimizer per network, all with the same learning rate.
optim_r = optim.SGD(R.parameters(), lr=1e-4)
optim_sa = optim.SGD(SA.parameters(), lr=1e-4)
optim_sc = optim.SGD(SC.parameters(), lr=1e-4)

running_succ_rate = 0
for epoch in count():
    succ_rate = 0
    for game in game_pool:
        game.reset()

    # TODO: the per-epoch play/update step is not implemented yet, so
    # succ_rate is never changed from 0 and this loop runs forever.

    # Exponential moving average of the per-epoch success rate.
    running_succ_rate = 0.95 * running_succ_rate + 0.05 * succ_rate
    print('successful_rate = {}'.format(running_succ_rate))

Overwriting train.py


In [21]:
%%writefile model.py

import torch 
import torch.nn as nn
import torch.nn.functional as F

class Receiver(nn.Module):
    """Receiver network: scores the two possible answers for an observation.

    Format:
    input vector: (2n + 1)
    output vector(prob): (2)

    The input is a batch of shape (batch, 2n + 1); the output rows are
    softmax probabilities over the two answers.
    """

    def __init__(self, n, hid):
        """Build a one-hidden-layer MLP with ``hid`` SELU units."""
        super(Receiver, self).__init__()
        self.n = n
        self.hid = hid
        self.net = nn.Sequential(
            nn.Linear(2 * n + 1, hid),
            nn.SELU(),
            nn.Linear(hid, 2),
            # Explicit dim: normalise over the two answers of each row.
            # Leaving dim unset relies on the deprecated implicit choice
            # and emits a warning on every forward pass.
            nn.Softmax(dim=1)
        )

    def forward(self, input):
        """Return per-row probabilities for a (batch, 2n + 1) input."""
        input_size = input.size(1)
        assert input_size == self.n * 2 + 1
        return self.net(input)

class SenderActor(nn.Module):
    """Sender policy network: maps the sender's observation to one scalar.

    Format:
    input vector: (2n)
    output vector: (1)

    NOTE(review): the original note describes the output as a real number
    in [0, 1], but the network ends in a plain linear layer, so the module
    itself does not bound it; presumably the caller clips it. Confirm.
    """

    def __init__(self, n, hid):
        """Build a one-hidden-layer MLP with ``hid`` SELU units."""
        super(SenderActor, self).__init__()
        self.n, self.hid = n, hid
        layers = [
            nn.Linear(2 * n, hid),
            nn.SELU(),
            nn.Linear(hid, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        """Return a (batch, 1) tensor for a (batch, 2n) input."""
        expected_width = 2 * self.n
        assert input.size(1) == expected_width
        return self.net(input)

class SenderCritic(nn.Module):
    """Sender critic network: maps a (2n + 1)-wide input to one Q value.

    Format:
    input vector: (2n + 1)
    output vector(Q value): (1)
    """

    def __init__(self, n, hid):
        """Build a one-hidden-layer MLP with ``hid`` SELU units."""
        super(SenderCritic, self).__init__()
        self.n = n
        self.hid = hid
        self.net = nn.Sequential(
            # The first layer must accept 2n + 1 features to match the
            # width that forward() asserts; with 2n every call failed.
            nn.Linear(2 * n + 1, hid),
            nn.SELU(),
            nn.Linear(hid, 1)
        )

    def forward(self, input):
        """Return a (batch, 1) tensor of Q values for a (batch, 2n + 1) input."""
        input_size = input.size(1)
        assert input_size == self.n * 2 + 1
        return self.net(input)


Overwriting model.py
