In [1]:
import time
import numpy as np
import risk_ext

import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
torch.manual_seed(0)

def vec_to_matrix(game, v):
    """View the flat board vector ``v`` as one row per territory.

    Args:
        game: object exposing ``n_max_territories`` and ``n_max_players``.
        v: flat array-like with a ``reshape`` method holding
            ``n_max_territories * (1 + n_max_players)`` entries.

    Returns:
        ``v`` reshaped to ``(n_max_territories, 1 + n_max_players)``.
        Presumably column 0 is the troop count and column ``p + 1`` flags
        ownership by player ``p`` -- that is how the players in this notebook
        read it; confirm against ``risk_ext``.
    """
    n_rows = game.n_max_territories
    n_cols = game.n_max_players + 1
    return v.reshape((n_rows, n_cols))

def get_state_dim():
    """Return the length of the flat board-state vector.

    A throwaway game is started (arguments ``1, 1, 0``) purely to read the
    ``n_max_territories`` / ``n_max_players`` values off the game object; the
    game itself is discarded.
    """
    probe = risk_ext.start_game(1, 1, 0)
    n_columns = probe.n_max_players + 1
    return probe.n_max_territories * n_columns
    
def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
    """Assemble a fully connected network as an ``nn.Sequential``.

    Args:
        sizes: layer widths, input first and output last.
        activation: module class instantiated after every hidden layer.
        output_activation: module class instantiated after the final layer.

    Returns:
        ``nn.Sequential`` alternating ``nn.Linear`` and activation modules;
        it is empty when ``sizes`` has fewer than two entries.
    """
    final_idx = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        # Only the last linear layer is followed by the output activation.
        modules.append(output_activation() if idx == final_idx else activation())
    return nn.Sequential(*modules)

class NNPlayer:
    """Player whose attack target is sampled from a small policy network.

    The network maps the flat board vector to ``n_actions`` logits; the
    sampled index is used as the territory to attack. The attacking territory
    is chosen by a fixed rule (first territory owned by the acting player).

    Args:
        n_actions: number of logits / selectable targets (default 3, the
            value previously hard-coded; presumably one per territory --
            keep in sync with the game configuration).
        lr: Adam learning rate.
        hidden_sizes: widths of the hidden layers of the policy network.
    """

    def __init__(self, n_actions=3, lr=0.01, hidden_sizes=(32,)):
        self.state_dim = get_state_dim()
        self.n_actions = n_actions
        self.lr = lr

        layer_sizes = [self.state_dim] + list(hidden_sizes) + [self.n_actions]
        self.logits_net = mlp(sizes=layer_sizes)
        self.optimizer = Adam(self.logits_net.parameters(), lr=self.lr)

    def get_policy(self, obs):
        """Return the categorical action distribution for observation(s) ``obs``."""
        logits = self.logits_net(obs)
        return Categorical(logits=logits)

    def get_action(self, obs):
        """Sample one action index (a Python int) for a single observation."""
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            return self.get_policy(obs).sample().item()

    def compute_loss(self, obs, act, weights):
        """Policy-gradient surrogate loss: ``-(log pi(act | obs) * weights).mean()``."""
        logp = self.get_policy(obs).log_prob(act)
        return -(logp * weights).mean()

    def act(self, game, state_vec):
        """Choose ``(attack_from, attack_to)`` for the player whose turn it is.

        Args:
            game: game object exposing ``n_max_territories``,
                ``n_max_players`` and ``player_idx``.
            state_vec: flat board vector (array-like).

        Returns:
            ``(attack_from, attack_to)``: index of the first territory owned
            by the acting player, and the target index sampled from the policy.
        """
        # Read the acting player's ownership column (same convention as
        # DumbPlayer) instead of always reading player 0's column, so the
        # agent also works when it is not seated at index 0.
        owner_col = vec_to_matrix(game, state_vec)[:, game.player_idx + 1]
        attack_from = (owner_col == 1).argmax()
        # Copy into a fresh float32 array before handing it to torch:
        # presumably the board array is read-only, which is what makes
        # torch.as_tensor emit the warning seen in the cell outputs -- confirm.
        obs = torch.as_tensor(np.array(state_vec, dtype=np.float32))
        attack_to = self.get_action(obs)
        return attack_from, attack_to

    def learn(self, obs, actions, weights):
        """Take one gradient step on a batch and return the loss tensor.

        Args:
            obs: sequence of flat board vectors, one per recorded step.
            actions: sequence of sampled action indices, aligned with ``obs``.
            weights: sequence of per-step weights (episode returns).

        Returns:
            The scalar loss tensor computed before the optimizer step.
        """
        # Stack in NumPy first: building a tensor straight from a Python list
        # of ndarrays is much slower.
        obs_t = torch.as_tensor(np.asarray(obs, dtype=np.float32))
        # Categorical.log_prob indexes with int64, so build that dtype directly.
        act_t = torch.as_tensor(actions, dtype=torch.int64)
        weights_t = torch.as_tensor(weights, dtype=torch.float32)

        self.optimizer.zero_grad()
        batch_loss = self.compute_loss(obs=obs_t, act=act_t, weights=weights_t)
        batch_loss.backward()
        self.optimizer.step()
        return batch_loss

class DumbPlayer:
    """Scripted opponent with a fixed rule and nothing to learn."""

    def act(self, game, state_vec):
        """Attack from the first owned territory into the first non-owned one.

        Ownership is read from column ``game.player_idx + 1`` of the board
        matrix; a value of 1 marks a territory as owned by the acting player.
        """
        board = vec_to_matrix(game, state_vec)
        mine = board[:, game.player_idx + 1] == 1
        source = mine.argmax()
        target = (~mine).argmax()
        return source, target

    def learn(self, obs, actions, weights):
        """No-op counterpart of ``NNPlayer.learn``; always returns ``None``."""
        return None

In [2]:
# Game configuration. Env.reset() reads both names every time it starts a game,
# so re-running this cell changes the games produced by later resets.
# The `players` list built further down must have exactly n_players entries.
# NOTE(review): NNPlayer uses 3 actions; presumably that should track
# n_territories -- confirm before changing either value.
n_players = 3
n_territories = 3

In [3]:
class Env:
    """Thin gym-like wrapper around a ``risk_ext`` game.

    The game size comes from the notebook-level ``n_players`` and
    ``n_territories`` variables, read at every ``reset``.
    """

    def __init__(self):
        # Number of games started so far; doubles as the default seed.
        self.n_games = 0

    def reset(self, seed=None):
        """Start a new game and return its initial board state.

        When ``seed`` is ``None`` the running game counter is used, so
        successive resets yield different but repeatable games.
        """
        self.n_games += 1
        game_seed = self.n_games if seed is None else seed
        self.game = risk_ext.start_game(n_players, n_territories, game_seed)
        return self.game.board_state

    def step(self, action):
        """Apply ``action`` and return ``(board_state, reward, done)``.

        ``action`` is unpacked into the positional arguments of
        ``game.step``. The episode counts as finished when ``game.phase == 3``
        (presumably the game-over phase -- confirm against ``risk_ext``), and
        the reward is 1.0 on that final step and 0.0 otherwise.
        """
        self.game.step(*action)
        finished = self.game.phase == 3
        return self.game.board_state, float(finished), finished

In [4]:
class PlayerBatch:
    """Accumulates one player's training data across several episodes.

    Per-step lists (``obs``, ``actions``, ``weights``) line up index by index
    once an episode has been finished; ``returns`` and ``lengths`` hold one
    entry per finished episode.
    """

    def __init__(self):
        self.obs, self.actions, self.weights = [], [], []
        self.returns, self.lengths = [], []
        self.start_episode()

    def start_episode(self):
        """Reset the counters that describe the episode in progress."""
        self.ep_returns = 0
        self.ep_length = 0

    def record(self, obs, action):
        """Store a copy of ``obs`` and the second element of ``action``."""
        # Copy so later changes to the caller's array do not alter stored data.
        snapshot = obs.copy()
        self.obs.append(snapshot)
        self.actions.append(action[1])
        self.ep_length += 1

    def finish_episode(self, reward):
        """Close the current episode, weighting each of its steps by ``reward``."""
        self.ep_returns = reward
        self.returns.append(reward)
        self.lengths.append(self.ep_length)
        # Every step of the episode gets the same weight: the episode return.
        self.weights.extend([reward] * self.ep_length)
        self.start_episode()

In [5]:
# One shared environment, used as a global by train_one_epoch and play_game.
env = Env()
# Seat order: index i in this list plays as game.player_idx == i.
# Seat 0 is the learning agent; seats 1 and 2 are scripted opponents.
players = [NNPlayer(), DumbPlayer(), DumbPlayer()]

In [6]:
def train_one_epoch(batch_size):
    """Play full games until enough data is collected, then update every player.

    Games are played on the global ``env`` with the global ``players``.
    Collection stops at the end of the first game after which any player's
    batch holds more than ``batch_size`` recorded steps.

    Returns:
        ``(loss, returns, lengths)``, each a list with one entry per player:
        the value returned by that player's ``learn`` call, the list of its
        episode returns, and the list of its episode lengths.
    """
    batches = [PlayerBatch() for _ in players]
    obs = env.reset()

    enough_data = False
    while not enough_data:
        seat = env.game.player_idx
        action = players[seat].act(env.game, obs)
        batches[seat].record(obs, action)
        obs, reward, done = env.step(action)
        if not done:
            continue

        # The player to move when the game ends is credited with the reward;
        # every other player closes the episode with 0.
        winner = env.game.player_idx
        for seat_idx, batch in enumerate(batches):
            batch.finish_episode(reward if seat_idx == winner else 0)
            if len(batch.obs) > batch_size:
                enough_data = True
        # A new game is started even after the final episode of the epoch.
        obs = env.reset()

    loss = [
        player.learn(batch.obs, batch.actions, batch.weights)
        for player, batch in zip(players, batches)
    ]
    episode_returns = [batch.returns for batch in batches]
    episode_lengths = [batch.lengths for batch in batches]
    return loss, episode_returns, episode_lengths

In [7]:
# Train for 50 epochs; each epoch keeps playing games until some player's
# batch exceeds 5000 recorded steps, then runs one learn() call per player.
for i in range(50):
    loss, rets, lens = train_one_epoch(5000)
    # Index 0 is the NNPlayer seat: report its loss, mean episode return and
    # mean number of recorded steps per episode.
    print('epoch: %3d \t loss0: %.3f \t return0: %.3f \t ep_len: %.3f'%
            (i, loss[0], np.mean(rets[0]), np.mean(lens[0])))

  attack_to = self.get_action(torch.as_tensor(state_vec, dtype=torch.float32))


epoch:   0 	 loss0: 0.571 	 return0: 0.460 	 ep_len: 16.583
epoch:   1 	 loss0: 0.568 	 return0: 0.483 	 ep_len: 13.203
epoch:   2 	 loss0: 0.577 	 return0: 0.510 	 ep_len: 11.289
epoch:   3 	 loss0: 0.494 	 return0: 0.489 	 ep_len: 10.662
epoch:   4 	 loss0: 0.566 	 return0: 0.561 	 ep_len: 9.891
epoch:   5 	 loss0: 0.524 	 return0: 0.562 	 ep_len: 9.935
epoch:   6 	 loss0: 0.588 	 return0: 0.610 	 ep_len: 11.694
epoch:   7 	 loss0: 0.463 	 return0: 0.578 	 ep_len: 10.993
epoch:   8 	 loss0: 0.524 	 return0: 0.597 	 ep_len: 15.176
epoch:   9 	 loss0: 0.415 	 return0: 0.592 	 ep_len: 19.620
epoch:  10 	 loss0: 0.327 	 return0: 0.570 	 ep_len: 19.388
epoch:  11 	 loss0: 0.355 	 return0: 0.569 	 ep_len: 23.319
epoch:  12 	 loss0: 0.204 	 return0: 0.596 	 ep_len: 38.051
epoch:  13 	 loss0: 0.169 	 return0: 0.555 	 ep_len: 39.133
epoch:  14 	 loss0: 0.199 	 return0: 0.586 	 ep_len: 37.872
epoch:  15 	 loss0: 0.224 	 return0: 0.581 	 ep_len: 67.689
epoch:  16 	 loss0: 0.311 	 return0: 0.627

In [7]:
def play_game(players, verbose=False, max_turns=None):
    """Play one game on the global ``env`` and report who won.

    Args:
        players: sequence indexed by ``game.player_idx``; each entry must
            provide ``act(game, state_vec)``. It needs one entry for every
            seat in the game that ``env`` starts.
        verbose: when true, print ``turn_idx, player_idx, phase, action,
            board_state`` before every move.
        max_turns: optional cap; the game is abandoned once
            ``game.turn_idx`` exceeds it.

    Returns:
        The ``player_idx`` of the player to move when the game finished
        (treated as the winner), or ``-1`` when the game was abandoned
        because of ``max_turns`` -- previously the seat that merely happened
        to be next to move was returned as if it had won.

    Raises:
        IndexError: if the game asks for a seat that ``players`` does not
            cover (e.g. two players passed while ``env`` starts a
            three-player game).
    """
    env.reset()
    while True:
        seat = env.game.player_idx
        if seat >= len(players):
            # Same exception type as the bare list lookup, but say why.
            raise IndexError(
                "game expects a player at seat %d but only %d players were given"
                % (seat, len(players))
            )
        action = players[seat].act(env.game, env.game.board_state)
        if verbose:
            print(env.game.turn_idx, env.game.player_idx, env.game.phase, action, env.game.board_state)
        _, _, done = env.step(action)
        if done:
            return env.game.player_idx
        if max_turns is not None and env.game.turn_idx > max_turns:
            # Nobody won: do not report the player to move as the winner.
            return -1

def faceoff(players, n_games, max_turns=None):
    """Play ``n_games`` games and collect the result of each.

    Args:
        players: sequence of players, passed to ``play_game`` unchanged.
        n_games: number of games to play.
        max_turns: optional per-game turn cap forwarded to ``play_game``.
            Without it, a game that never reaches its end state keeps this
            function running forever. The default ``None`` means no cap.

    Returns:
        ``np.int32`` array of length ``n_games`` holding the value returned
        by ``play_game`` for each game.
    """
    winners = np.empty(n_games, dtype=np.int32)
    for game_no in range(n_games):
        winners[game_no] = play_game(players, max_turns=max_turns)
    return winners

In [8]:
# Number of games out of 500 won by seat 0 (the NNPlayer).
# NOTE(review): no turn cap is applied here, so a game that never reaches the
# terminal phase loops forever -- the saved run was stopped by hand
# (KeyboardInterrupt).
np.sum(faceoff(players, 500) == 0)

  attack_to = self.get_action(torch.as_tensor(state_vec, dtype=torch.float32))


KeyboardInterrupt: 

In [13]:
# NOTE(review): this cell raises IndexError. Env.reset() still starts games
# with n_players (3) seats, so play_game looks up a third player in this
# two-element list. Pass three players, or set n_players = 2 before resetting.
np.sum(faceoff([DumbPlayer(), DumbPlayer()], 500) == 0)

IndexError: list index out of range

In [8]:
# Trace a single game move by move. Each printed row is:
# turn_idx, player_idx, phase, action (attack_from, attack_to), board_state.
play_game(players, verbose=True, max_turns=10)

0 0 1 (0, 1) [7. 1. 0. 0. 6. 0. 1. 0. 6. 0. 0. 1.]
0 1 1 (1, 0) [7. 1. 0. 0. 5. 0. 1. 0. 6. 0. 0. 1.]
0 2 1 (2, 0) [7. 1. 0. 0. 3. 0. 1. 0. 7. 0. 0. 1.]
1 0 1 (0, 1) [7. 1. 0. 0. 3. 0. 1. 0. 6. 0. 0. 1.]
1 1 1 (1, 0) [6. 1. 0. 0. 3. 0. 1. 0. 6. 0. 0. 1.]
1 2 1 (2, 0) [6. 1. 0. 0. 1. 0. 1. 0. 7. 0. 0. 1.]
2 0 1 (0, 2) [5. 1. 0. 0. 1. 0. 1. 0. 7. 0. 0. 1.]
2 1 1 (1, 0) [3. 1. 0. 0. 2. 0. 1. 0. 7. 0. 0. 1.]
2 2 1 (2, 0) [3. 1. 0. 0. 1. 0. 1. 0. 8. 0. 0. 1.]
3 0 1 (0, 0) [3. 1. 0. 0. 1. 0. 1. 0. 7. 0. 0. 1.]
3 1 1 (1, 0) [3. 1. 0. 0. 2. 0. 1. 0. 7. 0. 0. 1.]
3 2 1 (2, 0) [3. 1. 0. 0. 1. 0. 1. 0. 8. 0. 0. 1.]
4 0 1 (0, 0) [2. 1. 0. 0. 1. 0. 1. 0. 8. 0. 0. 1.]
4 1 1 (1, 0) [2. 1. 0. 0. 2. 0. 1. 0. 8. 0. 0. 1.]
4 2 1 (2, 0) [2. 1. 0. 0. 1. 0. 1. 0. 9. 0. 0. 1.]
5 1 1 (1, 0) [8. 0. 0. 1. 2. 0. 1. 0. 1. 0. 0. 1.]
5 2 1 (0, 1) [9. 0. 0. 1. 1. 0. 1. 0. 1. 0. 0. 1.]


  attack_to = self.get_action(torch.as_tensor(state_vec, dtype=torch.float32))


2