# Initialize a game

In [1]:
from ConnectN import ConnectN

game_setting = {'size':(6,6), 'N':4, 'pie_rule':True}
game = ConnectN(**game_setting)


In [2]:
% matplotlib notebook

from Play import Play


gameplay=Play(ConnectN(**game_setting), 
              player1=None, 
              player2=None)


<IPython.core.display.Javascript object>

# Define our policy

Please go ahead and define your own policy! See if you can train it under 1000 games and with only 1000 steps of exploration in each move.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import *
import numpy as np

from ConnectN import ConnectN
game_setting = {'size':(6,6), 'N':4}
game = ConnectN(**game_setting)

class Policy(nn.Module):

    def __init__(self, game):
        super(Policy, self).__init__()

        # input = 6x6 board
        # convert to 5x5x8
        self.conv1 = nn.Conv2d(1, 16, kernel_size=2, stride=1, bias=False)
        # 5x5x16 to 3x3x32
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, bias=False)

        self.size=3*3*32
        
        # the part for actions
        self.fc_action1 = nn.Linear(self.size, self.size//4)
        self.fc_action2 = nn.Linear(self.size//4, 36)
        
        # the part for the value function
        self.fc_value1 = nn.Linear(self.size, self.size//6)
        self.fc_value2 = nn.Linear(self.size//6, 1)
        self.tanh_value = nn.Tanh()
        
    def forward(self, x):

        y = F.leaky_relu(self.conv1(x))
        y = F.leaky_relu(self.conv2(y))
        y = y.view(-1, self.size)
        
        # action head
        a = self.fc_action2(F.leaky_relu(self.fc_action1(y)))
        
        avail = (torch.abs(x.squeeze())!=1).type(torch.FloatTensor)
        avail = avail.view(-1, 36)
        maxa = torch.max(a)
        exp = avail*torch.exp(a-maxa)
        prob = exp/torch.sum(exp)
        
        # value head
        value = self.tanh_value(self.fc_value2(F.leaky_relu( self.fc_value1(y) )))
        return prob.view(6,6), value

policy = Policy(game)


# Define a MCTS player for Play

In [4]:
import MCTS

from copy import copy

def Policy_Player_MCTS(game):
    mytree = MCTS.Node(copy(game))
    for _ in range(1000):
        mytree.explore(policy)
       
    mytreenext, (v, nn_v, p, nn_p) = mytree.next(temperature=0.1)
    
    return mytreenext.game.last_move

import random

def Random_Player(game):
    return random.choice(game.available_moves())    


# Play a game against a random policy

In [5]:
% matplotlib notebook

from Play import Play


gameplay=Play(ConnectN(**game_setting), 
              player1=Policy_Player_MCTS, 
              player2=None)


<IPython.core.display.Javascript object>

# Training

In [6]:
# initialize our alphazero agent and optimizer
import torch.optim as optim

game=ConnectN(**game_setting)
policy = Policy(game)
optimizer = optim.Adam(policy.parameters(), lr=.01, weight_decay=1.e-5)

! pip install progressbar



Beware, training is **VERY VERY** slow!!

In [None]:
# train our agent

from collections import deque
import MCTS

# try a higher number
episodes = 2000

import progressbar as pb
widget = ['training loop: ', pb.Percentage(), ' ', 
          pb.Bar(), ' ', pb.ETA() ]
timer = pb.ProgressBar(widgets=widget, maxval=episodes).start()

outcomes = []
policy_loss = []

Nmax = 1000

for e in range(episodes):

    mytree = MCTS.Node(game)
    logterm = []
    vterm = []
    
    while mytree.outcome is None:
        for _ in range(Nmax):
            mytree.explore(policy)
            if mytree.N >= Nmax:
                break
            
        current_player = mytree.game.player
        mytree, (v, nn_v, p, nn_p) = mytree.next()
        mytree.detach_mother()
        
        loglist = torch.log(nn_p)*p
        constant = torch.where(p>0, p*torch.log(p),torch.tensor(0.))
        logterm.append(-torch.sum(loglist-constant))

        vterm.append(nn_v*current_player)
        
    # we compute the "policy_loss" for computing gradient
    outcome = mytree.outcome
    outcomes.append(outcome)
    
    loss = torch.sum( (torch.stack(vterm)-outcome)**2 + torch.stack(logterm) )
    optimizer.zero_grad()
    loss.backward()
    policy_loss.append(float(loss))

    optimizer.step()
    
    if e%10==0:
        print("game: ",e+1, ", mean loss: {:3.2f}".format(np.mean(policy_loss[-20:])),
              ", recent outcomes: ", outcomes[-10:])
    
    if e%500==0:
        torch.save(policy,'6-6-4-pie-{:d}.mypolicy'.format(e))
    del loss
    
    timer.update(e+1)
    
timer.finish()

  "type " + obj.__name__ + ". It won't be checked "
training loop:   0% |                                          | ETA:  19:12:56

game:  1 , mean loss: 30.47 , recent outcomes:  [1]


training loop:   0% |                                          | ETA:  12:59:58

game:  11 , mean loss: 28.04 , recent outcomes:  [1, -1, 1, 1, -1, 1, 1, 1, 1, 1]


training loop:   1% |                                           | ETA:  8:44:21

game:  21 , mean loss: 22.92 , recent outcomes:  [1, 1, 1, -1, 1, -1, 1, 1, 1, 1]


training loop:   1% |                                           | ETA:  7:33:28

game:  31 , mean loss: 17.35 , recent outcomes:  [1, -1, -1, 1, 1, 1, 1, 1, -1, 1]


training loop:   2% |                                           | ETA:  7:06:10

game:  41 , mean loss: 16.55 , recent outcomes:  [-1, 1, 1, 1, 1, 1, -1, 1, 1, -1]


training loop:   2% |#                                          | ETA:  6:54:36

game:  51 , mean loss: 16.71 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, -1, 1]


training loop:   3% |#                                          | ETA:  6:41:38

game:  61 , mean loss: 13.40 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


training loop:   3% |#                                          | ETA:  6:18:25

game:  71 , mean loss: 12.06 , recent outcomes:  [-1, 1, 1, 1, 1, 1, 1, 1, 1, -1]


training loop:   4% |#                                          | ETA:  6:03:43

game:  81 , mean loss: 11.43 , recent outcomes:  [1, 1, 1, 1, -1, 1, 1, 1, 1, 1]


training loop:   4% |#                                          | ETA:  5:55:39

game:  91 , mean loss: 11.11 , recent outcomes:  [1, 1, 1, 1, 1, -1, -1, 1, 1, 1]


training loop:   5% |##                                         | ETA:  5:56:48

game:  101 , mean loss: 15.18 , recent outcomes:  [-1, 1, 1, 1, 1, -1, 1, 1, 1, -1]


training loop:   5% |##                                         | ETA:  5:53:17

game:  111 , mean loss: 16.27 , recent outcomes:  [1, -1, 1, 1, 1, 1, -1, 1, -1, -1]


training loop:   6% |##                                         | ETA:  5:51:20

game:  121 , mean loss: 13.87 , recent outcomes:  [1, 1, 1, 1, -1, 1, 1, 1, 1, -1]


training loop:   6% |##                                         | ETA:  5:51:24

game:  131 , mean loss: 12.86 , recent outcomes:  [1, 1, 1, 1, -1, 1, 1, 1, -1, -1]


training loop:   7% |###                                        | ETA:  5:55:45

game:  141 , mean loss: 16.69 , recent outcomes:  [1, -1, 1, 1, -1, 1, -1, -1, -1, -1]


training loop:   7% |###                                        | ETA:  5:53:42

game:  151 , mean loss: 15.42 , recent outcomes:  [-1, 1, 1, 1, 1, 1, -1, -1, 1, 1]


training loop:   8% |###                                        | ETA:  6:02:03

game:  161 , mean loss: 15.31 , recent outcomes:  [1, -1, -1, -1, 1, 1, -1, -1, 1, -1]


training loop:   8% |###                                        | ETA:  6:04:49

game:  171 , mean loss: 18.40 , recent outcomes:  [1, 1, 1, 1, -1, -1, 1, 1, -1, 1]


training loop:   9% |###                                        | ETA:  6:08:01

game:  181 , mean loss: 19.31 , recent outcomes:  [-1, -1, 1, 1, 1, 1, -1, 1, -1, -1]


training loop:   9% |####                                       | ETA:  6:07:55

game:  191 , mean loss: 20.09 , recent outcomes:  [-1, 1, 1, 1, 1, -1, 1, 1, -1, -1]


training loop:  10% |####                                       | ETA:  6:10:13

game:  201 , mean loss: 18.96 , recent outcomes:  [-1, -1, -1, -1, 1, -1, -1, -1, 1, 1]


training loop:  10% |####                                       | ETA:  6:12:19

game:  211 , mean loss: 20.31 , recent outcomes:  [1, -1, 1, 1, 1, -1, -1, 1, -1, 1]


training loop:  11% |####                                       | ETA:  6:12:30

game:  221 , mean loss: 20.67 , recent outcomes:  [1, 1, -1, 1, 1, -1, -1, 1, 1, 1]


training loop:  11% |####                                       | ETA:  6:13:11

game:  231 , mean loss: 20.48 , recent outcomes:  [-1, -1, 1, 1, 1, 1, 1, 1, 1, 1]


training loop:  12% |#####                                      | ETA:  6:10:01

game:  241 , mean loss: 19.56 , recent outcomes:  [1, 1, -1, 1, -1, 1, 1, 1, 1, 1]


training loop:  12% |#####                                      | ETA:  6:10:25

game:  251 , mean loss: 20.47 , recent outcomes:  [1, 1, 0, -1, 1, -1, -1, -1, -1, -1]


training loop:  13% |#####                                      | ETA:  6:08:55

game:  261 , mean loss: 19.60 , recent outcomes:  [1, -1, 1, 1, -1, 1, 1, 1, 1, 1]


training loop:  13% |#####                                      | ETA:  6:10:27

game:  271 , mean loss: 21.19 , recent outcomes:  [1, -1, -1, 1, -1, 1, 1, 1, 1, -1]


training loop:  14% |######                                     | ETA:  6:11:29

game:  281 , mean loss: 22.46 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 0, 1, 1]


training loop:  14% |######                                     | ETA:  6:07:47

game:  291 , mean loss: 17.84 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, -1, 1, 1]


training loop:  15% |######                                     | ETA:  6:03:22

game:  301 , mean loss: 18.96 , recent outcomes:  [-1, 1, 1, -1, 1, 1, -1, -1, 1, -1]


training loop:  15% |######                                     | ETA:  6:00:49

game:  311 , mean loss: 19.38 , recent outcomes:  [1, 1, 1, 1, 1, -1, 1, 1, -1, 1]


training loop:  16% |######                                     | ETA:  5:57:20

game:  321 , mean loss: 16.73 , recent outcomes:  [1, 1, 1, 1, 1, 1, -1, 1, 1, 1]


training loop:  16% |#######                                    | ETA:  5:53:37

game:  331 , mean loss: 14.56 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, 1, -1]


training loop:  17% |#######                                    | ETA:  5:49:39

game:  341 , mean loss: 15.37 , recent outcomes:  [1, 1, 1, -1, 1, 1, 1, 1, 1, 1]


training loop:  17% |#######                                    | ETA:  5:47:42

game:  351 , mean loss: 20.55 , recent outcomes:  [1, 1, -1, -1, 1, 1, -1, 1, 1, 0]


training loop:  18% |#######                                    | ETA:  5:44:32

game:  361 , mean loss: 19.66 , recent outcomes:  [1, 1, -1, 1, 1, 1, 1, 1, 1, 1]


training loop:  18% |#######                                    | ETA:  5:41:42

game:  371 , mean loss: 12.97 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


training loop:  19% |########                                   | ETA:  5:36:42

game:  381 , mean loss: 11.29 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, -1, 1]


training loop:  19% |########                                   | ETA:  5:32:15

game:  391 , mean loss: 12.35 , recent outcomes:  [1, 1, 1, -1, 1, -1, 1, 1, 1, 1]


training loop:  20% |########                                   | ETA:  5:29:16

game:  401 , mean loss: 18.01 , recent outcomes:  [-1, -1, 1, 1, 1, 1, 1, -1, -1, 1]


training loop:  20% |########                                   | ETA:  5:25:12

game:  411 , mean loss: 17.43 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, -1, 1]


training loop:  21% |#########                                  | ETA:  5:23:32

game:  421 , mean loss: 14.32 , recent outcomes:  [-1, 1, 1, 1, -1, 1, 1, 1, 1, 1]


training loop:  21% |#########                                  | ETA:  5:19:20

game:  431 , mean loss: 14.16 , recent outcomes:  [1, 1, 1, 1, 1, -1, 1, 1, 1, 1]


training loop:  22% |#########                                  | ETA:  5:13:42

game:  441 , mean loss: 9.58 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, 1, -1]


training loop:  22% |#########                                  | ETA:  5:08:54

game:  451 , mean loss: 10.74 , recent outcomes:  [1, 1, 1, -1, 1, 1, 1, 1, 1, 1]


training loop:  23% |#########                                  | ETA:  5:04:32

game:  461 , mean loss: 10.70 , recent outcomes:  [1, -1, 1, 1, 1, 1, 1, 1, 1, 1]


training loop:  23% |##########                                 | ETA:  4:59:44

game:  471 , mean loss: 7.74 , recent outcomes:  [-1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


training loop:  24% |##########                                 | ETA:  4:56:16

game:  481 , mean loss: 10.68 , recent outcomes:  [1, 1, -1, -1, 1, 1, 1, 1, 1, 1]


training loop:  24% |##########                                 | ETA:  4:52:37

game:  491 , mean loss: 9.69 , recent outcomes:  [1, -1, 1, 1, 1, 1, 1, 1, 1, 1]


training loop:  25% |##########                                 | ETA:  4:51:05

game:  501 , mean loss: 10.32 , recent outcomes:  [1, 1, -1, 1, 1, -1, 1, 1, 1, 1]


training loop:  25% |##########                                 | ETA:  4:48:43

game:  511 , mean loss: 13.23 , recent outcomes:  [1, 1, 1, 1, -1, 1, 1, 1, 1, 1]


training loop:  26% |###########                                | ETA:  4:44:44

game:  521 , mean loss: 8.46 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


training loop:  26% |###########                                | ETA:  4:41:30

game:  531 , mean loss: 5.19 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


training loop:  27% |###########                                | ETA:  4:37:39

game:  541 , mean loss: 4.65 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


training loop:  27% |###########                                | ETA:  4:33:44

game:  551 , mean loss: 3.91 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


training loop:  28% |############                               | ETA:  4:30:54

game:  561 , mean loss: 3.82 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, -1, 1]


training loop:  28% |############                               | ETA:  4:27:33

game:  571 , mean loss: 3.61 , recent outcomes:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


training loop:  29% |############                               | ETA:  4:24:53

# setup environment to pit your AI against the challenge policy '6-6-4-pie.policy'

In [7]:
challenge_policy = torch.load('6-6-4-pie.policy')

def Challenge_Player_MCTS(game):
    mytree = MCTS.Node(copy(game))
    for _ in range(1000):
        mytree.explore(challenge_policy)
       
    mytreenext, (v, nn_v, p, nn_p) = mytree.next(temperature=0.1)
    
    return mytreenext.game.last_move



# Let the game begin!

In [8]:
% matplotlib notebook
gameplay=Play(ConnectN(**game_setting), 
              player2=Policy_Player_MCTS, 
              player1=Challenge_Player_MCTS)

<IPython.core.display.Javascript object>