In [11]:
import sys
import numpy as np
from math import sqrt, log
import random

# Game board for tic-tac-toe
class Board:

    def __init__(self):
        self.state = np.zeros([3,3])
#         self.state = np.zeros([5,5])
        self.player = 0 

    def copy(self):
        copy = Board()
        copy.player = self.player
        copy.state = np.copy(self.state)
        return copy
    
    # Go to next state
    def move(self, move): 
        # If the postion is already occupied, jump out.
        if np.any(self.state[move[0],move[1]]): 
            print('Already occupied!')
            sys.exit()
        else:
            self.state[move[0],move[1]] = 8 if self.player == 0 else 7
            self.player ^= 1 

    def vacancy_on_board(self): 
        return np.argwhere(self.state == 0).tolist()
    

    def result(self): 
        if self.player^1 == 0:
            col_sum = np.any(np.sum(self.state,axis=0)== 24) 
            row_sum = np.any(np.sum(self.state,axis=1)==24)
            d1_sum  = np.any(np.trace(self.state)==24)
            d2_sum  = np.any(np.trace(np.flip(self.state,1))==24)
#             col_sum = np.any(np.sum(self.state,axis=0)== 40) 
#             row_sum = np.any(np.sum(self.state,axis=1)==40)
#             d1_sum  = np.any(np.trace(self.state)==40)
#             d2_sum  = np.any(np.trace(np.flip(self.state,1))==40)
        else:
            col_sum = np.any(np.sum(self.state,axis=0)== 21) 
            row_sum = np.any(np.sum(self.state,axis=1)==21)
            d1_sum  = np.any(np.trace(self.state)==21)
            d2_sum  = np.any(np.trace(np.flip(self.state,1))==21)
#             col_sum = np.any(np.sum(self.state,axis=0)== 35) 
#             row_sum = np.any(np.sum(self.state,axis=1)==35)
#             d1_sum  = np.any(np.trace(self.state)==35)
#             d2_sum  = np.any(np.trace(np.flip(self.state,1))==35)
        return col_sum or row_sum or d1_sum or d2_sum

    # Check if the game is terminated
    def is_terminal(self):
        if self.result():
            return True
        else:
            return False
        
# Track the current node
class Node:

    def __init__(self, parent=None, action=None, board=None):
        self.parent = parent
        self.board = board
        self.children = []
        self.wins = 0
        self.visits = 0
        self.unexplored = board.vacancy_on_board()
        self.action = action
        
    def expand(self, action, board):
        child = Node(parent=self, action=action, board=board)
        self.unexplored.remove(action)
        self.children.append(child)
        return child    
       
    def best_child(self):
        s = sorted(self.children, key=lambda c:c.wins/c.visits+0.2*sqrt(2*log(self.visits)/c.visits))
        return s[-1] 
    
    def update(self, result):
        self.visits += 1
        self.wins += result

# Main function of MC algorithm
def monte_carlo_tree_search(rootstate, iteration):
    root = Node(board=rootstate)
    for i in range(iteration):
        node = root
        board = rootstate.copy()
        node,board = tree_policy(node,board)
        board = default_policy(node,board)
        backup(node,board)
    s = sorted(root.children, key=lambda c:c.wins/c.visits)
    if len(s) == 0:
        print("Draw")
        sys.exit()
    else:
        return s[-1].action

# Search and expand function
def tree_policy(node,board):
    # if fully expanded and not leaf node, find the best child
    while node.unexplored == [] and node.children != []: 
        node = node.best_child()
        board.move(node.action)
    # if leaf, expand to new node
    if node.unexplored != []:
        a = random.choice(node.unexplored)
        board.move(a)
        node = node.expand(a, board.copy())
    return node, board

# Simulate function
def default_policy(node,board):
    while board.vacancy_on_board() != [] and not board.result():
        board.move(random.choice(board.vacancy_on_board()))
    return board

# Backpropagation function
def backup(node,board):      
    while node is not None:
        result = board.result()
        if result:
            if node.board.player==board.player:
                result = 1
            else: result = -1
        else: result = 0
        node.update(result)
        node = node.parent

In [12]:
b = Board() 
# Play with the computer. Me as 8, computer as 7
while b.vacancy_on_board() != [] and not b.is_terminal():
    line = input("It's your turn: ")
    my_turn = line[1:len(line)-1]
    my_turn = my_turn.split(',')
    my_turn = [int(i) for i in my_turn]
    b.move(my_turn)
    machine_turn = monte_carlo_tree_search(b,1000)  
    b.move(machine_turn)    
    print(b.state)   

It's your turn: [1,1]
[[7. 0. 0.]
 [0. 8. 0.]
 [0. 0. 0.]]
It's your turn: [0,1]
[[7. 8. 0.]
 [0. 8. 0.]
 [0. 7. 0.]]
It's your turn: [1,0]
[[7. 8. 0.]
 [8. 8. 7.]
 [0. 7. 0.]]
It's your turn: [2,0]
[[7. 8. 7.]
 [8. 8. 7.]
 [8. 7. 0.]]
It's your turn: [2,2]
Draw


SystemExit: 