In [1]:
# --- network ---
INPUT_SHAPE = (3, 3, 3)    # shape handed to the keras Input layer in create_model
LEARNING_RATE = 0.001      # Adam step size

# --- self-play / search ---
EPOCHS = 100               # number of self-play games played by the training loop
C_PUCT = 4                 # exploration weight in the PUCT score (MCTSNode.get_ucb)
MCTS_ITERATIONS = 10       # simulations per self-play move (model_vs_model uses half)
DIRICHLET_ALPHA = 0.3      # concentration of the Dirichlet noise mixed into child priors
EPSILON = 0.75             # weight on the network prior; (1 - EPSILON) goes to the noise

# --- training ---
TRAINING_SIZE = 64         # samples taken from match memory per training round

In [2]:
import numpy as np
import random
from tqdm import tqdm
import math
import keras
from keras.layers import *
from keras.models import Model
from keras.optimizers import Adam
import keras.utils as utils
from keras.callbacks import LearningRateScheduler

In [3]:
class Game:
    """Tic-tac-toe position.

    `board` is a 3x3 float array holding 1 / -1 for the two players' marks and 0 for
    empty cells. `lastmove` stores the *player* (1 or -1) who moved last, or None
    before the first move; `turn()` derives the side to move from it.
    """

    # Every winning line as three (row, col) cells, in the order result() checks them:
    # main diagonal, anti-diagonal, rows top-to-bottom, then columns left-to-right.
    _LINES = (
        ((0, 0), (1, 1), (2, 2)),
        ((0, 2), (1, 1), (2, 0)),
        ((0, 0), (0, 1), (0, 2)),
        ((1, 0), (1, 1), (1, 2)),
        ((2, 0), (2, 1), (2, 2)),
        ((0, 0), (1, 0), (2, 0)),
        ((0, 1), (1, 1), (2, 1)),
        ((0, 2), (1, 2), (2, 2)),
    )

    def __init__(self):
        self.board = np.zeros((3, 3))
        self.lastmove = None

    def no_legal_moves(self):
        """Return True when every cell is occupied."""
        return not self.legal_moves()

    def result(self):
        """Return the mark (1 or -1) of a completed line, or 0 if there is none.

        Lines are checked in `_LINES` order, so the first completed line wins.
        """
        for (r0, c0), (r1, c1), (r2, c2) in self._LINES:
            first = self.board[r0][c0]
            if first != 0 and first == self.board[r1][c1] == self.board[r2][c2]:
                return first
        return 0

    def is_game_over(self):
        """True if someone has won or the board is full."""
        return self.result() != 0 or self.no_legal_moves()

    def copy(self):
        """Return an independent copy (board array is not shared)."""
        g = Game()
        g.board = np.array(self.board, dtype=float)
        g.lastmove = self.lastmove
        return g

    def legal_moves(self):
        """List the empty cells as (row, col) tuples in row-major order."""
        return [(i, j) for i in range(3) for j in range(3) if self.board[i][j] == 0]

    def push(self, action, p):
        """Place mark `p` on cell `action` = (row, col).

        Raises:
            Exception: if `action` is not an empty cell.
        """
        if action not in self.legal_moves():
            raise Exception("Illegal move exception")
        r, c = action
        self.board[r][c] = p
        self.lastmove = p

    def turn(self):
        """Player to move: the opposite of the last mover, or 1 at the start."""
        return -self.lastmove if self.lastmove is not None else 1

In [4]:
def b_to_img(b):
    """Encode a game position as a stack of three 3x3 binary planes.

    Plane 0 marks player 1's stones, plane 1 marks player -1's stones and plane 2 is
    all ones when player 1 is to move (all zeros otherwise).

    Args:
        b: a Game-like object exposing `board` (3x3, values 1 / -1 / 0) and `turn()`.

    Returns:
        A float array of shape (3, 3, 3) indexed as [plane, row, col].

    NOTE(review): planes are stacked on axis 0, while create_model applies
    BatchNormalization(axis=3), i.e. treats the *last* axis as channels. Confirm the
    intended layout; changing it would invalidate previously saved models.
    """
    to_move = b.turn()
    cells = np.asarray(b.board)

    planes = np.zeros((3, 3, 3))
    planes[0][cells == 1] = 1
    planes[1][cells == -1] = 1
    if to_move == 1:
        planes[2].fill(1)
    return planes

In [5]:
def create_model(conv_size=256, conv_depth=1, learning_rate=LEARNING_RATE):
    """Build and compile the policy/value network.

    Args:
        conv_size: number of filters in every 1x1 convolution.
        conv_depth: number of residual blocks after the stem convolution.
        learning_rate: Adam learning rate (defaults to LEARNING_RATE).

    Returns:
        A compiled keras Model taking an INPUT_SHAPE board image and producing
        [policy (9-way softmax over board cells), value (single tanh unit)].
    """
    board_input = Input(shape=INPUT_SHAPE, name='BoardInput')

    # 1x1 stem so the residual Add below sees matching channel counts.
    x = Conv2D(filters=conv_size, kernel_size=1)(board_input)
    for _ in range(conv_depth):
        shortcut = x
        x = Conv2D(filters=conv_size, kernel_size=1)(x)
        x = BatchNormalization(axis=3)(x)
        x = Activation('relu')(x)
        x = Conv2D(filters=conv_size, kernel_size=1)(x)
        x = BatchNormalization(axis=3)(x)
        x = Add()([x, shortcut])
        x = Activation('relu')(x)
    x = Flatten()(x)

    # NOTE(review): 1028 units kept as-is for compatibility with saved models;
    # possibly intended to be 1024.
    x = Dense(1028, activation='relu')(x)
    x = Dropout(.3)(x)
    x = Dense(512, activation='relu')(x)
    x = Dropout(.3)(x)

    policy = Dense(9, activation='softmax', name='policy')(x)
    value = Dense(1, activation='tanh', name='value')(x)

    model = Model(inputs=board_input, outputs=[policy, value])

    # `lr` is a deprecated alias that newer Keras optimizers reject; `learning_rate`
    # is the supported keyword.
    model.compile(loss=['categorical_crossentropy', 'mean_squared_error'],
                  optimizer=Adam(learning_rate=learning_rate))

    return model

In [6]:
# Build a freshly initialised network. To resume from a network previously written
# by `model.save('model.h5')`, use the commented-out load line instead.
model = create_model()
#model = keras.models.load_model('model.h5')

In [7]:
class MCTSNode:
    """A node of the Monte Carlo search tree.

    Wraps a Game position (`state`), the (noise-mixed) prior probability of the move
    that produced it, and PUCT statistics: N (visits), W (summed backed-up value) and
    Q = W / N. Backed-up values are multiplied by the player who made the move into
    this node, so Q is from that player's point of view (assuming the value head
    predicts the result from player 1's perspective, as this notebook's training
    targets do).

    The `nnet` arguments default to None, meaning "use the global `model` as bound at
    call time". A default of `nnet=model` would instead freeze whatever object
    `model` pointed to when the class was defined.

    NOTE(review): Dirichlet noise is mixed in at every expansion, not only at the
    root as in AlphaZero — confirm this is intended.
    """

    def __init__(self, state, prior_probability, action=None, parent=None):
        self.state = state
        self.prior_probability = prior_probability

        self.action = action    # (row, col) move that produced `state`; None for a fresh root
        self.parent = parent

        self.children = []

        self.N = 0
        self.W = 0
        self.Q = 0

    @staticmethod
    def _resolve(nnet):
        # Late-bind the global `model` so rebinding it (e.g. restoring an older
        # network in the training loop) actually reaches the search.
        return model if nnet is None else nnet

    def is_root(self):
        return self.parent is None

    def is_leaf(self):
        return not self.children

    def set_children(self, nnet=None):
        """Expand this node with one child per legal move (no-op on terminal states).

        Child priors are the network policy restricted to legal moves and renormalised,
        then mixed as EPSILON * prior + (1 - EPSILON) * Dirichlet(DIRICHLET_ALPHA) noise.
        """
        if self.state.is_game_over():
            return
        nnet = self._resolve(nnet)

        image = b_to_img(self.state).reshape(1, 3, 3, 3)
        policy = np.asarray(nnet.predict(image, verbose=0)[0][0], dtype=float).reshape(9)

        legal = self.state.legal_moves()
        legal_priors = policy[[r * 3 + c for r, c in legal]]
        total = legal_priors.sum()
        if total > 0:
            legal_priors = legal_priors / total
        else:
            # The net put no mass on any legal move: fall back to uniform instead of
            # dividing by zero and propagating NaN priors.
            legal_priors = np.full(len(legal), 1.0 / len(legal))

        noise = np.random.dirichlet(np.ones(len(legal)) * DIRICHLET_ALPHA)

        for move, prior, eta in zip(legal, legal_priors, noise):
            child_state = self.state.copy()
            child_state.push(move, child_state.turn())
            noisy_prior = EPSILON * prior + (1 - EPSILON) * eta
            self.children.append(MCTSNode(child_state, noisy_prior, action=move, parent=self))

    def get_ucb(self, nnet=None):
        """PUCT score of this node as seen from its parent (`nnet` is unused)."""
        return self.Q + C_PUCT * self.prior_probability * (math.sqrt(self.parent.N) / (1 + self.N))

    def select_child(self, nnet=None):
        """Return the child with the highest PUCT score, or self on a terminal state."""
        if self.state.is_game_over():
            return self
        scores = [child.get_ucb(nnet) for child in self.children]
        return self.children[np.argmax(scores)]

    def rollout(self, nnet=None):
        """Evaluate this node and back the value up to the root.

        Terminal states use the exact game result; otherwise the network's value head
        is used. Each ancestor accumulates the value signed by the player who moved
        into it (`-state.turn()`).
        """
        if self.state.is_game_over():
            result = self.state.result()
        else:
            nnet = self._resolve(nnet)
            image = b_to_img(self.state).reshape(1, 3, 3, 3)
            result = nnet.predict(image, verbose=0)[1][0][0]

        node = self
        while node is not None:
            node.N += 1
            node.W += result * -node.state.turn()
            node.Q = node.W / node.N
            node = node.parent

In [8]:
def _mcts_pick_move(game, nnet, iterations):
    """Run MCTS from `game` guided by `nnet` and sample a root child by visit count.

    At least two simulations are run: the first only evaluates the root, and the
    second expands it, so the root always has visited children to sample from.
    """
    root = MCTSNode(game, 1)
    for _ in range(max(iterations, 2)):
        current = root
        while not current.is_leaf():
            current = current.select_child(nnet=nnet)
        if current.N > 0:
            # Visited leaf: expand it, then evaluate the most promising child.
            current.set_children(nnet=nnet)
            current = current.select_child(nnet=nnet)
        current.rollout(nnet=nnet)

    visits = np.array([child.N for child in root.children], dtype=float)
    return np.random.choice(root.children, p=visits / visits.sum())


def model_vs_model(nnet1, nnet2, matches=10, iterations=None, win_percent=60):
    """Pit two networks against each other, alternating who moves first.

    Each side searches with ITS OWN network. The original passed no network to MCTS,
    so both sides silently used the same default model.

    Args:
        nnet1: reference network (e.g. the previous weights).
        nnet2: candidate network.
        matches: number of games to play.
        iterations: MCTS simulations per move; defaults to MCTS_ITERATIONS // 2.
        win_percent: percentage of matches nnet2 must win to be accepted.

    Returns:
        True if nnet2 won at least `win_percent`% of the matches (draws do not count).
    """
    if iterations is None:
        iterations = MCTS_ITERATIONS // 2

    p1 = p2 = d = 0
    nnet1_player, nnet2_player = 1, -1

    for _ in range(matches):
        game = Game()
        net_for = {nnet1_player: nnet1, nnet2_player: nnet2}
        while not game.is_game_over():
            player = game.turn()
            child = _mcts_pick_move(game, net_for[player], iterations)
            game.push(child.action, player)

        outcome = game.result()
        if outcome == nnet1_player:
            p1 += 1
        elif outcome == nnet2_player:
            p2 += 1
        else:
            d += 1

        # Swap colours so each network plays first in alternate games.
        nnet1_player, nnet2_player = -nnet1_player, -nnet2_player

    print(f'player 1 wins: {p1}')
    print(f'player 2 wins: {p2}')
    print(f'draws: {d}')

    return p2 >= matches * win_percent / 100

In [9]:
# Self-play training loop. Each iteration plays one MCTS self-play game with the
# current `model`. Every position is stored with the one-hot of the move chosen from
# it and, once the game ends, its final result. When more than TRAINING_SIZE samples
# have accumulated, a candidate is trained and kept only if it beats the previous
# weights in model_vs_model.
match_memory = []
results = []

train = True
root = None
x_train = []
policy_train = []
value_train = []

for _ in tqdm(range(EPOCHS)):
    game = Game()
    root = MCTSNode(game, 1)
    while not game.is_game_over():
        for _ in range(MCTS_ITERATIONS):
            current = root
            while not current.is_leaf():
                current = current.select_child(nnet=model)
            if current.N > 0:
                current.set_children(nnet=model)
                current = current.select_child(nnet=model)
            # Pass the current binding of `model` explicitly so a restored network is
            # what actually drives the search.
            current.rollout(nnet=model)

        visits = [node.N for node in root.children]
        total = sum(visits)
        child = np.random.choice(root.children, p=[v / total for v in visits])
        row, col = child.action
        policy_target = np.zeros(9)
        policy_target[row * 3 + col] = 1
        # Record the position the move was chosen FROM. The original stored the board
        # after the move, pairing each policy target with the wrong input.
        match_memory.append([b_to_img(game), policy_target, None])
        game.push(child.action, game.turn())

        # Reuse the chosen subtree, detached so backups stop at the new root instead
        # of climbing into the discarded part of the tree.
        root = child
        root.parent = None

    final_result = game.result()
    results.append(final_result)
    for sample in match_memory:
        if sample[2] is None:
            sample[2] = final_result

    if len(match_memory) > TRAINING_SIZE:
        random.shuffle(match_memory)
        if train:
            # Previous candidate accepted (or first round): start a fresh batch.
            # After a rejection the previous batch is kept and extended below.
            x_train, policy_train, value_train = [], [], []

        for board_img, policy_target, value_target in match_memory[:TRAINING_SIZE]:
            x_train.append(board_img)
            policy_train.append(policy_target)
            value_train.append(value_target)

        x_train = np.asarray(x_train)
        policy_train = np.asarray(policy_train)
        value_train = np.asarray(value_train)

        old_model = create_model()
        old_model.set_weights(model.get_weights())

        model.fit(x=x_train, y=[policy_train, value_train],
                  batch_size=32, verbose=1, epochs=4)

        if not model_vs_model(old_model, model):
            train = False
            x_train = list(x_train)
            policy_train = list(policy_train)
            value_train = list(value_train)
            model = old_model
            print("OLD MODEL WIN")
        else:
            train = True
            del old_model
            match_memory = []
            print("NEW MODEL WIN/DRAW")
        print(f'len match memory: {len(match_memory)}')

 10%|█         | 10/100 [00:15<02:12,  1.48s/it]

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 11%|█         | 11/100 [00:28<07:13,  4.87s/it]

player 1 wins: 6
player 2 wins: 2
draws: 2
OLD MODEL WIN
len match memory: 70
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 12%|█▏        | 12/100 [00:40<10:23,  7.08s/it]

player 1 wins: 5
player 2 wins: 5
draws: 0
OLD MODEL WIN
len match memory: 77
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 13%|█▎        | 13/100 [00:55<13:40,  9.43s/it]

player 1 wins: 6
player 2 wins: 3
draws: 1
OLD MODEL WIN
len match memory: 84
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 14%|█▍        | 14/100 [01:09<15:29, 10.81s/it]

player 1 wins: 5
player 2 wins: 3
draws: 2
OLD MODEL WIN
len match memory: 91
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 15%|█▌        | 15/100 [01:23<16:45, 11.83s/it]

player 1 wins: 5
player 2 wins: 4
draws: 1
OLD MODEL WIN
len match memory: 97
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 16%|█▌        | 16/100 [01:38<17:56, 12.82s/it]

player 1 wins: 4
player 2 wins: 5
draws: 1
OLD MODEL WIN
len match memory: 106
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 17%|█▋        | 17/100 [01:52<18:18, 13.24s/it]

player 1 wins: 5
player 2 wins: 5
draws: 0
OLD MODEL WIN
len match memory: 111
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 18%|█▊        | 18/100 [02:10<20:03, 14.68s/it]

player 1 wins: 3
player 2 wins: 5
draws: 2
OLD MODEL WIN
len match memory: 119
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 19%|█▉        | 19/100 [02:27<20:44, 15.37s/it]

player 1 wins: 7
player 2 wins: 3
draws: 0
OLD MODEL WIN
len match memory: 126
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 20%|██        | 20/100 [02:44<20:54, 15.68s/it]

player 1 wins: 6
player 2 wins: 2
draws: 2
OLD MODEL WIN
len match memory: 135
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 21%|██        | 21/100 [03:01<21:08, 16.06s/it]

player 1 wins: 5
player 2 wins: 4
draws: 1
OLD MODEL WIN
len match memory: 144
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 22%|██▏       | 22/100 [03:17<21:06, 16.24s/it]

player 1 wins: 1
player 2 wins: 8
draws: 1
NEW MODEL WIN/DRAW
len match memory: 0


 31%|███       | 31/100 [03:32<02:24,  2.09s/it]

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 32%|███▏      | 32/100 [03:44<05:25,  4.79s/it]

player 1 wins: 6
player 2 wins: 4
draws: 0
OLD MODEL WIN
len match memory: 68
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 33%|███▎      | 33/100 [03:58<08:34,  7.68s/it]

player 1 wins: 4
player 2 wins: 2
draws: 4
OLD MODEL WIN
len match memory: 76
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 34%|███▍      | 34/100 [04:11<10:10,  9.25s/it]

player 1 wins: 5
player 2 wins: 4
draws: 1
OLD MODEL WIN
len match memory: 82
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 35%|███▌      | 35/100 [04:25<11:29, 10.61s/it]

player 1 wins: 5
player 2 wins: 5
draws: 0
OLD MODEL WIN
len match memory: 89
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 36%|███▌      | 36/100 [04:40<12:44, 11.95s/it]

player 1 wins: 4
player 2 wins: 5
draws: 1
OLD MODEL WIN
len match memory: 95
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 37%|███▋      | 37/100 [04:54<13:10, 12.55s/it]

player 1 wins: 2
player 2 wins: 7
draws: 1
NEW MODEL WIN/DRAW
len match memory: 0


 46%|████▌     | 46/100 [05:09<01:54,  2.11s/it]

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 47%|████▋     | 47/100 [05:21<04:35,  5.20s/it]

player 1 wins: 2
player 2 wins: 5
draws: 3
OLD MODEL WIN
len match memory: 65
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 48%|████▊     | 48/100 [05:34<06:35,  7.60s/it]

player 1 wins: 4
player 2 wins: 4
draws: 2
OLD MODEL WIN
len match memory: 71
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 49%|████▉     | 49/100 [05:48<08:01,  9.44s/it]

player 1 wins: 5
player 2 wins: 5
draws: 0
OLD MODEL WIN
len match memory: 78
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 50%|█████     | 50/100 [06:02<09:07, 10.95s/it]

player 1 wins: 4
player 2 wins: 5
draws: 1
OLD MODEL WIN
len match memory: 87
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 51%|█████     | 51/100 [06:19<10:19, 12.63s/it]

player 1 wins: 4
player 2 wins: 4
draws: 2
OLD MODEL WIN
len match memory: 96
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 52%|█████▏    | 52/100 [06:33<10:28, 13.08s/it]

player 1 wins: 7
player 2 wins: 2
draws: 1
OLD MODEL WIN
len match memory: 104
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 53%|█████▎    | 53/100 [06:48<10:47, 13.77s/it]

player 1 wins: 4
player 2 wins: 4
draws: 2
OLD MODEL WIN
len match memory: 113
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 54%|█████▍    | 54/100 [07:02<10:31, 13.74s/it]

player 1 wins: 4
player 2 wins: 5
draws: 1
OLD MODEL WIN
len match memory: 118
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 55%|█████▌    | 55/100 [07:16<10:25, 13.91s/it]

player 1 wins: 5
player 2 wins: 5
draws: 0
OLD MODEL WIN
len match memory: 126
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 56%|█████▌    | 56/100 [07:32<10:30, 14.32s/it]

player 1 wins: 5
player 2 wins: 4
draws: 1
OLD MODEL WIN
len match memory: 135
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 57%|█████▋    | 57/100 [07:47<10:32, 14.71s/it]

player 1 wins: 5
player 2 wins: 3
draws: 2
OLD MODEL WIN
len match memory: 143
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 58%|█████▊    | 58/100 [08:03<10:23, 14.84s/it]

player 1 wins: 7
player 2 wins: 2
draws: 1
OLD MODEL WIN
len match memory: 152
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 59%|█████▉    | 59/100 [08:19<10:25, 15.26s/it]

player 1 wins: 3
player 2 wins: 6
draws: 1
NEW MODEL WIN/DRAW
len match memory: 0


 68%|██████▊   | 68/100 [08:34<01:10,  2.21s/it]

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 69%|██████▉   | 69/100 [08:45<02:30,  4.84s/it]

player 1 wins: 7
player 2 wins: 2
draws: 1
OLD MODEL WIN
len match memory: 68
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 70%|███████   | 70/100 [08:56<03:26,  6.87s/it]

player 1 wins: 5
player 2 wins: 3
draws: 2
OLD MODEL WIN
len match memory: 73
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 71%|███████   | 71/100 [09:08<04:04,  8.44s/it]

player 1 wins: 5
player 2 wins: 4
draws: 1
OLD MODEL WIN
len match memory: 80
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 72%|███████▏  | 72/100 [09:21<04:31,  9.71s/it]

player 1 wins: 6
player 2 wins: 4
draws: 0
OLD MODEL WIN
len match memory: 85
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 73%|███████▎  | 73/100 [09:35<04:56, 10.97s/it]

player 1 wins: 4
player 2 wins: 4
draws: 2
OLD MODEL WIN
len match memory: 90
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 74%|███████▍  | 74/100 [09:47<04:56, 11.40s/it]

player 1 wins: 5
player 2 wins: 5
draws: 0
OLD MODEL WIN
len match memory: 95
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 75%|███████▌  | 75/100 [10:00<04:50, 11.63s/it]

player 1 wins: 5
player 2 wins: 5
draws: 0
OLD MODEL WIN
len match memory: 100
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 76%|███████▌  | 76/100 [10:13<04:50, 12.12s/it]

player 1 wins: 4
player 2 wins: 6
draws: 0
NEW MODEL WIN/DRAW
len match memory: 0


 84%|████████▍ | 84/100 [10:27<00:36,  2.28s/it]

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 85%|████████▌ | 85/100 [10:39<01:19,  5.27s/it]

player 1 wins: 3
player 2 wins: 6
draws: 1
NEW MODEL WIN/DRAW
len match memory: 0


 94%|█████████▍| 94/100 [10:53<00:09,  1.61s/it]

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


 95%|█████████▌| 95/100 [11:04<00:22,  4.57s/it]

player 1 wins: 2
player 2 wins: 7
draws: 1
NEW MODEL WIN/DRAW
len match memory: 0


100%|██████████| 100/100 [11:13<00:00,  6.74s/it]


In [10]:
# Persist the current global `model` so a later session can reload it with
# `keras.models.load_model('model.h5')` instead of retraining.
model.save('model.h5')

In [14]:
# vs random player: the network (greedy on MCTS visit counts) plays as `nnet_player`
# against an opponent that picks uniformly among the legal moves.
nnet_player = 1
n_simulations = 300
player = 1
game = Game()
while not game.is_game_over():
    print(game.board)
    print('\n')
    if player == nnet_player:
        root = MCTSNode(game, 1)
        for _ in range(n_simulations):
            current = root
            while not current.is_leaf():
                current = current.select_child(nnet=model)
            if current.N > 0:
                current.set_children(nnet=model)
                current = current.select_child(nnet=model)
            # Explicit nnet=model: use the model as currently bound (the one just saved).
            current.rollout(nnet=model)

        # Play the most-visited move (first one on ties), no sampling at play time.
        child = max(root.children, key=lambda node: node.N)
        assert player == game.turn()
        game.push(child.action, player)
    else:
        game.push(random.choice(game.legal_moves()), player)
    player *= -1

print(game.board)
print('\n')
print('result:', game.result())

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 1.]]


[[ 0.  0.  0.]
 [ 0. -1.  0.]
 [ 0.  0.  1.]]


[[ 0.  0.  1.]
 [ 0. -1.  0.]
 [ 0.  0.  1.]]


[[-1.  0.  1.]
 [ 0. -1.  0.]
 [ 0.  0.  1.]]


[[-1.  0.  1.]
 [ 0. -1.  1.]
 [ 0.  0.  1.]]


result: 1.0


In [12]:
# vs human, network playing first: the network (greedy on MCTS visit counts) is
# player 1; the human enters row/column indices (0-2) for player -1.
nnet_player = 1
n_simulations = 300
player = 1
game = Game()
while not game.is_game_over():
    print(game.board)
    print('\n')
    if player == nnet_player:
        root = MCTSNode(game, 1)
        for _ in range(n_simulations):
            current = root
            while not current.is_leaf():
                current = current.select_child(nnet=model)
            if current.N > 0:
                current.set_children(nnet=model)
                current = current.select_child(nnet=model)
            # Explicit nnet=model: use the model as currently bound.
            current.rollout(nnet=model)

        child = max(root.children, key=lambda node: node.N)
        assert player == game.turn()
        game.push(child.action, player)
    else:
        # Re-prompt instead of aborting the game on a typo or an occupied cell.
        while True:
            try:
                r = int(input('insert a row: '))
                c = int(input('insert a column: '))
            except ValueError:
                print('row and column must be integers between 0 and 2')
                continue
            if (r, c) in game.legal_moves():
                break
            print(f'({r}, {c}) is not a legal move, try again')
        print('\n')
        game.push((r, c), player)
    player *= -1

print(game.board)
print('\n')
print('result:', game.result())

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


[[0. 0. 0.]
 [0. 0. 1.]
 [0. 0. 0.]]


insert a row: 0
insert a column: 0


[[-1.  0.  0.]
 [ 0.  0.  1.]
 [ 0.  0.  0.]]


[[-1.  0.  0.]
 [ 1.  0.  1.]
 [ 0.  0.  0.]]


insert a row: 1
insert a column: 1


[[-1.  0.  0.]
 [ 1. -1.  1.]
 [ 0.  0.  0.]]


[[-1.  0.  0.]
 [ 1. -1.  1.]
 [ 0.  0.  1.]]


insert a row: 0
insert a column: 2


[[-1.  0. -1.]
 [ 1. -1.  1.]
 [ 0.  0.  1.]]


[[-1.  0. -1.]
 [ 1. -1.  1.]
 [ 1.  0.  1.]]


insert a row: 2
insert a column: 1


[[-1.  0. -1.]
 [ 1. -1.  1.]
 [ 1. -1.  1.]]


[[-1.  1. -1.]
 [ 1. -1.  1.]
 [ 1. -1.  1.]]


result: 0


In [13]:
# vs human, network playing second: the human moves first as player 1 by entering
# row/column indices (0-2); the network (greedy on MCTS visit counts) is player -1.
nnet_player = -1
n_simulations = 300
player = 1
game = Game()
while not game.is_game_over():
    print(game.board)
    print('\n')
    if player == nnet_player:
        root = MCTSNode(game, 1)
        for _ in range(n_simulations):
            current = root
            while not current.is_leaf():
                current = current.select_child(nnet=model)
            if current.N > 0:
                current.set_children(nnet=model)
                current = current.select_child(nnet=model)
            # Explicit nnet=model: use the model as currently bound.
            current.rollout(nnet=model)

        child = max(root.children, key=lambda node: node.N)
        assert player == game.turn()
        game.push(child.action, player)
    else:
        # Re-prompt instead of aborting the game on a typo or an occupied cell.
        while True:
            try:
                r = int(input('insert a row: '))
                c = int(input('insert a column: '))
            except ValueError:
                print('row and column must be integers between 0 and 2')
                continue
            if (r, c) in game.legal_moves():
                break
            print(f'({r}, {c}) is not a legal move, try again')
        print('\n')
        game.push((r, c), player)
    player *= -1

print(game.board)
print('\n')
print('result:', game.result())

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


insert a row: 0
insert a column: 0


[[1. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


[[ 1.  0.  0.]
 [ 0. -1.  0.]
 [ 0.  0.  0.]]


insert a row: 0
insert a column: 2


[[ 1.  0.  1.]
 [ 0. -1.  0.]
 [ 0.  0.  0.]]


[[ 1. -1.  1.]
 [ 0. -1.  0.]
 [ 0.  0.  0.]]


insert a row: 2
insert a column: 1


[[ 1. -1.  1.]
 [ 0. -1.  0.]
 [ 0.  1.  0.]]


[[ 1. -1.  1.]
 [ 0. -1.  0.]
 [-1.  1.  0.]]


insert a row: 1
insert a column: 2


[[ 1. -1.  1.]
 [ 0. -1.  1.]
 [-1.  1.  0.]]


[[ 1. -1.  1.]
 [ 0. -1.  1.]
 [-1.  1. -1.]]


insert a row: 1
insert a column: 0


[[ 1. -1.  1.]
 [ 1. -1.  1.]
 [-1.  1. -1.]]


result: 0
