In [89]:
import numpy as np
import random
from copy import deepcopy
from time import time

In [90]:
class Timer:
    def __init__(self, name=None):
        self.name = name if name else "Timer"

    def __enter__(self):
        self.start = time()
        return self

    def __exit__(self, *args):
        self.end = time()
        self.interval = self.end - self.start
        print(f"{self.name} time: {self.interval}")

# Board

In [91]:
# Gomoku
class Board():
    def __init__(self, size=19):
        self.size = size
        self.board = np.zeros((size, size))
        self.turn = 1
        self.winner = 0
        self.done = False
        self.move_count = 0
        self.move_history = []
        self.possible_neighbor_moves = []

        self.forward_diag_idx_x = []
        self.forward_diag_idx_y = []
        for i in range(self.size * 2 - 1):
            self.forward_diag_idx_x.append([])
            self.forward_diag_idx_y.append([])
            for j in range(i+1):
                if i < self.size and i-j < self.size:
                    self.forward_diag_idx_x[-1].append(j)
                    self.forward_diag_idx_y[-1].append(i-j)
        
        self.backward_diag_idx_x = []
        self.backward_diag_idx_y = []
        for i in range(-self.size + 1, self.size):
            self.backward_diag_idx_x.append([])
            self.backward_diag_idx_y.append([])
            for j in range(self.size):
                if i+j in range(self.size):
                    self.backward_diag_idx_x[-1].append(i+j)
                    self.backward_diag_idx_y[-1].append(j)

    def __str__(self):
        s = '  '
        for i in range(min(self.size, 10)):
            s += '  ' + str(i)
        for i in range(10, self.size):
            s += ' ' + str(i)
        s += '\n'
        # print markings
        for i in range(self.size):
            s += str(i) if i >= 10 else str(i) + ' '
            s += '|'
            for j in range(self.size):
                if self.board[i, j] == 0:
                    s += ' - '
                elif self.board[i, j] == 1:
                    s += ' X '
                else:
                    s += ' O '
            s += '|\n'
        return s

    def move(self, x, y):
        if self.board[x, y] == 0:
            self.board[x][y] = self.turn
            self.move_history.append((x, y))
            self.move_count += 1
            self.turn *= -1
            self.done, self.winner = self.check_done()
            return True
        else:
            return False

    # @property
    # def done(self):
    #     nb_tours = 0 #nombre de tours joués
    #     nb_streak = 0 #nombre de pions de la même couleur consécutifs
    #     prev = 0 #valeur du pion précédant


    #     #parcours par ligne et recherche etat partie
    #     for i in range(self.size - 4):
    #         for j in range(self.size - 4):
    #             if self.board[i, j] == 0:  # case est vide
    #                 nb_streak = 0  # reset le streak
    #                 prev = 0
    #             elif prev == self.board[i, j]:  # meme element sur la ligne
    #                 nb_streak += 1
    #                 nb_tours += 1
    #                 if nb_streak == 5:  # un joueur a gagne
    #                     return True, prev
    #                 if nb_tours == 120:  # nb maximum atteint
    #                     return True, prev
    #             else:
    #                 prev = self.board[i, j]  # le streak commence pour un joueur different
    #                 nb_streak = 1
    #                 nb_tours += 1
    #                 if nb_tours == 120: #nombre limite de tours atteints
    #                     return True, 0
    #         # changement de ligne
    #         nb_streak = 0
    #         prev = 0


    #     # pour les colonnes
    #     for j in range(self.size - 4):
    #         for i in range(self.size - 4):
    #             if self.board[i, j] == 0:  # case est vide
    #                 nb_streak = 0  # reset le streak
    #                 prev = 0
    #             elif prev == self.board[i, j]:  # meme element sur la colonne
    #                 nb_streak += 1
    #                 if nb_streak == 5:  # un joueur a gagné
    #                     return True, prev
    #             else:
    #                 prev = self.board[i, j]  # le streak commence pour un joueur different
    #                 nb_streak = 1
    #         # changement de colonne
    #         nb_streak = 0
    #         prev = 0


    #     # pour les diagonales sens \
    #     # triangle inferieur a diagonale '\'
    #     for i in range(self.size - 4 - 4):
    #         for k in range(self.size - 4 - i +1):
    #             if self.board[i + k, k] == 0:
    #                 nb_streak = 0
    #                 prev = 0
    #             elif prev == self.board[i + k, k]:
    #                 nb_streak += 1
    #                 if nb_streak == 5:
    #                     return True, prev
    #             else:
    #                 prev = self.board[i + k, k]
    #                 nb_streak = 1
    #         # changement de diagonale
    #         nb_streak = 0
    #         prev = 0

    #     # triangle superieur a diagonale '\'
    #     for j in range(1, self.size - 4 - 4):
    #         for k in range(self.size - 4 - j +1):
    #             if self.board[k, j + k] == 0:
    #                 nb_streak = 0
    #                 prev = 0
    #             elif prev == self.board[k, j + k]:
    #                 nb_streak += 1
    #                 if nb_streak == 5:
    #                     return True, prev
    #             else:
    #                 prev = self.board[k, j + k]
    #                 nb_streak = 1
    #         # changement de diagonale
    #         nb_streak = 0
    #         prev = 0


    #     # pour les diagonales sens /
    #     # triangle inferieur a diagonale '/'
    #     for i in range(self.size - 4 - 4):
    #         for k in range(self.size - 4 - i +1):
    #             if self.board[i + k, self.size - 4 - 1 - k] == 0:
    #                 nb_streak = 0
    #                 prev = 0
    #             elif prev == self.board[i + k, self.size - 4 - 1 - k]:
    #                 nb_streak += 1
    #                 if nb_streak == 5:
    #                     return True, prev
    #             else:
    #                 prev = self.board[i + k, self.size - 4 - 1 - k]
    #                 nb_streak = 1
    #         # changement de diagonale
    #         nb_streak = 0
    #         prev = 0

    #     # triangle superieur a diagonale '/'
    #     for j in range(4, self.size - 4 - 1):
    #         for k in range(j+1):
    #             if self.board[k, j - k] == 0:
    #                 nb_streak = 0
    #                 prev = 0
    #             elif prev == self.board[k, j - k]:
    #                 nb_streak += 1
    #                 if nb_streak == 5:
    #                     return True, prev
    #             else:
    #                 prev = self.board[k, j - k]
    #                 nb_streak = 1
    #         # changement de diagonale
    #         nb_streak = 0
    #         prev = 0

    #     return False, 0

    def check_done(self):
        #check lines
        line = self.board[self.move_history[-1][0], :]
        line = np.array2string(line)
        if line.count('1 1 1 1 1') == 1:
            return True, 1
        elif line.count('-1 -1 -1 -1 -1') == 1:
            return True, -1

        #check columns
        column = self.board[:, self.move_history[-1][1]]
        column = np.array2string(column)
        if column.count('1 1 1 1 1') == 1:
            return True, 1
        elif column.count('-1 -1 -1 -1 -1') == 1:
            return True, -1
        
        #check diagonals /
        for d_idxs_x, d_idxs_y in zip(self.forward_diag_idx_x, self.forward_diag_idx_y):
            for x, y in zip(d_idxs_x, d_idxs_y):
                if self.move_history[-1] == (x, y):
                    break
        diagonal = self.board[d_idxs_x, d_idxs_y]
        diagonal = np.array2string(diagonal)
        if diagonal.count('1 1 1 1 1') == 1:
            return True, 1
        elif diagonal.count('-1 -1 -1 -1 -1') == 1:
            return True, -1
        
        #check diagonals \
        for d_idxs_x, d_idxs_y in zip(self.backward_diag_idx_x, self.backward_diag_idx_y):
            for x, y in zip(d_idxs_x, d_idxs_y):
                if self.move_history[-1] == (x, y):
                    break
        diagonal = self.board[d_idxs_x, d_idxs_y]
        diagonal = np.array2string(diagonal)
        if diagonal.count('1 1 1 1 1') == 1:
            return True, 1
        elif diagonal.count('-1 -1 -1 -1 -1') == 1:
            return True, -1

        if self.move_count >= 120:
            return True, 0
        return False, 0

    def undo(self):
        if self.move_count > 0:
            self.move_count -= 1
            self.turn *= -1
            self.done, self.winner = self.check_winner()
            self.board[self.move_history[-1][0]][self.move_history[-1][1]] = 0
            self.move_history.pop()
            return True
        else:
            return False

    def reset(self):
        self.board = np.zeros((self.size, self.size))
        self.turn = 1
        self.winner = 0
        self.move_count = 0
        self.move_history = []
        self.done = False

    def get_possible_moves(self):
        possible_moves = []
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i, j] == 0:
                    possible_moves.append((i, j))
        return possible_moves

    def get_possible_neighbor_moves(self):
        possible_moves = []
        if self.move_history == []:
            possible_moves = [(i, j) for i in range(self.size) for j in range(self.size)]
        else:
            for move in self.move_history:
                for i in range(-1, 2):
                    for j in range(-1, 2):
                        if i == 0 and j == 0:
                            continue
                        if (move[0] + i) in range(self.size) and (move[1] + j) in range(self.size) and self.board[move[0] + i, move[1] + j] == 0:
                            possible_moves.append((move[0] + i, move[1] + j))

        possible_moves = list(set(possible_moves))
        return possible_moves

# Node

In [92]:
class Node():
    def __init__(self, turn, is_root=False, root_moves=None):
        self.children = []
        self.parent = None
        self.value = 0
        self.num_visits = 0
        self.move = None
        self.turn = turn
        self.is_root = is_root
        self.root_moves = root_moves
        self.is_terminal = False

    def get_move_sequence(self):
        if self.is_root:
            return []
        else:
            return self.parent.get_move_sequence() + [self.move]

    def get_root_moves(self):
        if self.is_root:
            return self.root_moves
        else:
            return self.parent.get_root_moves()

    @property
    def possible_moves(self):
        if self.is_root:
            children_moves = [child.move for child in self.children]
            return [move for move in self.root_moves if move not in children_moves]
        else:
            used_moves = self.get_move_sequence()
            used_moves += [child.move for child in self.children]
            return [move for move in self.get_root_moves() if move not in used_moves]
        

    def add_child(self, child):
        self.children.append(child)
        child.parent = self

    def set_move(self, move):
        self.move = move
    
    def set_value(self, value):
        self.value = value

    def set_num_visits(self, num_visits):
        self.num_visits = num_visits

    def set_parent(self, parent):
        self.parent = parent

    def set_is_root(self, is_root):
        self.is_root = is_root
    
    

# Selection

In [93]:
def get_all_MCTSleaf_nodes(node):
    #leaf nodes are nodes with possible moves
    if node.is_terminal:
        return []
    elif node.possible_moves == []:
        l = []
        for child in node.children:
            l.extend(get_all_MCTSleaf_nodes(child))
        return l
    else:
        l = [node]
        for child in node.children:
            l.extend(get_all_MCTSleaf_nodes(child))
        return l

def selection(root):
    # gather all leaf nodes
    leaf_nodes = get_all_MCTSleaf_nodes(root)

    # get max score leaf node
    score = -float('inf')
    for l in leaf_nodes:
        num_visits = l.num_visits
        if num_visits == 0:
            num_visits += 1
        if l.is_root:
            parent_visits = num_visits
        else:
            parent_visits = l.parent.num_visits
        l_score = l.value / num_visits + 1.41 * np.sqrt(np.log(parent_visits) / num_visits)
        if l_score > score:
            score = l_score
            max_leaf = l
        elif l_score == score:
            max_leaf = random.choice([l, max_leaf])
    return max_leaf

# Expansion

In [94]:
def expansion(board, max_leaf):
    # Uniformly select a move from the possible moves
    move = random.choice(max_leaf.possible_moves)
    # move = max_leaf.possible_moves[0]

    # Create a new node with the move
    new_node = Node(max_leaf.turn * -1)
    new_node.set_move(move)

    # Check if the new node is terminal
    board.move(*move)
    if board.done:
        new_node.is_terminal = True
    
    # Add the new node to the parent's children
    max_leaf.add_child(new_node)
    return new_node

# Simulation

In [95]:
def simulation(board, leaf, policyplayer, policyadv):
    # Get move sequence
    move_sequence = leaf.get_move_sequence()
    
    for move in move_sequence:
        board.move(*move)
    while not board.done:
        if board.turn == 1:
            move = policyplayer(board)
            board.move(*move)
        else:
            move = policyadv(board)
            board.move(*move)
    return board.winner


# Backpropagation

In [96]:
def backpropagation(leaf, winner):
    leaf.set_value(leaf.value + winner * leaf.turn)
    leaf.set_num_visits(leaf.num_visits + 1)
    if leaf.parent != None:
        backpropagation(leaf.parent, winner)

In [97]:
def policyPlayer(board):
    return random.choice(board.get_possible_neighbor_moves())
    # return board.get_possible_moves()[0]

def policyAdv(board):
    return random.choice(board.get_possible_neighbor_moves())
    # return board.get_possible_moves()[0]

In [98]:
def MCTS(board, tree, policyplayer, policyadv):
    # Selection
    with Timer('Selection'):
        leaf = selection(tree)
    
    # Expansion
    with Timer('Expansion'):
        leaf = expansion(deepcopy(board), leaf)
    
    # Simulation
    with Timer('Simulation'):
        winner = simulation(deepcopy(board), leaf, policyplayer, policyadv)

    # Backpropagation
    with Timer('Backpropagation'):
        backpropagation(leaf, winner)

    return tree
    

# Inference

In [99]:
def inference(tree):
    score = float('inf')
    for node in tree.children:
        if node.value / node.num_visits < score:
            score = node.value / node.num_visits
            max_node = node
    return max_node

# Init board

In [100]:
board = Board(15)
tree = Node(1, is_root=True, root_moves=board.get_possible_neighbor_moves())
len(tree.possible_moves)

225

In [101]:
board.move(0,0)

True

# Run MCTS

In [102]:
delay = 4
simulations = 0
start = time()
while time() - start < delay:
    MCTS(board, tree, policyPlayer, policyAdv)
    simulations += 1
print(f'Num simulations: {simulations}')

#get depth
def get_max_depth(root):
    if root.children == []:
        return 0
    else:
        max_depth = 0
        for child in root.children:
            max_depth = max(max_depth, get_max_depth(child))
        return max_depth + 1
print(f'Max depth: {get_max_depth(tree)}')

print(f'root: value: {tree.value}, num_visits: {tree.num_visits}, num_children: {len(tree.children)}')

for child in tree.children:
    print(f'child #1: value: {child.value}, num_visits: {child.num_visits}, num_children: {len(child.children)}, move: {child.move}')
    for child2 in child.children:
            print(f'child #2: value: {child2.value}, num_visits: {child2.num_visits}, num_children: {len(child2.children)}, move: {child2.move}')

Selection time: 0.00015354156494140625
Expansion time: 0.008122444152832031
Simulation time: 0.21791863441467285
Backpropagation time: 8.106231689453125e-06
Selection time: 0.0001373291015625
Expansion time: 0.0014908313751220703
Simulation time: 0.15798735618591309
Backpropagation time: 5.245208740234375e-06
Selection time: 0.000118255615234375
Expansion time: 0.0012128353118896484
Simulation time: 0.15755462646484375
Backpropagation time: 7.3909759521484375e-06
Selection time: 0.0001747608184814453
Expansion time: 0.0018901824951171875
Simulation time: 0.16501951217651367
Backpropagation time: 6.9141387939453125e-06
Selection time: 0.0002338886260986328
Expansion time: 0.0015604496002197266
Simulation time: 0.16466689109802246
Backpropagation time: 8.58306884765625e-06
Selection time: 0.00027251243591308594
Expansion time: 0.0015952587127685547
Simulation time: 0.16525721549987793
Backpropagation time: 1.3589859008789062e-05
Selection time: 0.0004260540008544922
Expansion time: 0.001

In [103]:
best_node = inference(tree)
print(best_node.move)
board.move(*best_node.move)

print(board)

(12, 3)
    0  1  2  3  4  5  6  7  8  9 10 11 12 13 14
0 | X  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
1 | -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
2 | -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
3 | -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
4 | -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
5 | -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
6 | -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
7 | -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
8 | -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
9 | -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
10| -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
11| -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
12| -  -  -  O  -  -  -  -  -  -  -  -  -  -  - |
13| -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |
14| -  -  -  -  -  -  -  -  -  -  -  -  -  -  - |



In [104]:
moved = False
while not moved:
    player_move = (int(input('Enter row: ')), int(input('Enter column: ')))
    moved = board.move(*player_move)
    if not moved:
        print('Invalid move')

print(board)

ValueError: invalid literal for int() with base 10: ''

In [None]:
is_new_child = True
for child in best_node.children:
    if child.move == player_move:
        tree = child
        tree.is_root = True
        tree.root_moves = board.get_possible_neighbor_moves()
        is_new_child = False
        break

if is_new_child:
    tree = Node(1, is_root=True, root_moves=board.get_possible_neighbor_moves())
