In [1]:
import random 
import math

#### Below is an implementation of Multi-Agent Monte Carlo Tree Search (MCTS) for the game of tic-tac-toe. The game is turn-based and the levels of the search tree alternate between player and opponent nodes, i.e. all successors of a player node are opponent nodes and vice versa. The goal of the algorithm is to iteratively build up a search tree and, at the end of the iterations, choose the best possible action based on the values of the successor nodes of the root node. There are 4 main steps in the algorithm, always starting from the root node:

1) `Selection`: This step involves traversing down the tree until a leaf node is found (a leaf node is a node which has not been expanded before, i.e. it has no children). The traversal is done using an exploration-exploitation strategy (UCT) which balances nodes with higher value with random exploration. (The value of a node is the accumulated reward for the player at that node in proportion to the number of times that node has been visited/traversed.)

2) `Expansion`: The selected leaf node is expanded (i.e. all of its children are generated), then one of its children is picked at random and its action is executed.

3) `Simulation`: Then a simulation is run from the game state resulting from executing the action of the random child. 

4) `Backpropagation`: Rewards from the simulation result are accumulated on every node along the path from the random child back up to the root. In this case the reward from the simulation is 1 for the winning player and 0 for the loser, or 0 for both players if the game is a draw.

After many iterations, we can then extract the best action by choosing the action of the child node of the root which has the best value, i.e. the best ratio of accumulated reward to number of visits. This type of strategy is called a "self-play" tree policy because the player and the opponent both share the same state-action space, so the opponent can be thought of as the player playing against itself.


In [None]:
class TicTacToe:
    """Minimal two-player tic-tac-toe game state.

    The board is a flat list of nine cells indexed row by row (0-8). Each
    cell holds 'X', 'O' or ' ' (empty). Player 'X' always moves first.
    """

    # every triple of cell indices that forms a winning line
    WINNING_COMBINATIONS = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # cols
        (0, 4, 8), (2, 4, 6),             # diagonals
    )

    def __init__(self) -> None:
        self.board = [' '] * 9  # initially empty board
        self.current_player = 'X'  # player X gets first turn

    def get_legal_moves(self) -> list:
        """Return the indices of all empty cells on the board."""
        return [i for i, cell in enumerate(self.board) if cell == ' ']

    def make_move(self, move) -> None:
        """Mark cell `move` for the current player, then switch turns.

        The move is not validated: an occupied cell is silently overwritten,
        so callers should only pass indices from `get_legal_moves()`.
        """
        self.board[move] = self.current_player
        # switch turn
        self.current_player = 'O' if self.current_player == 'X' else 'X'

    def is_terminal(self) -> bool:
        """Return True when the game is over, i.e. won or drawn.

        A game is over as soon as one player owns a complete line, or when
        no empty cell is left (a draw).
        """
        return self.get_winner() != ' ' or ' ' not in self.board

    def get_winner(self) -> str:
        """Return the winner's mark ('X' or 'O'), or ' ' if there is none."""
        for a, b, c in self.WINNING_COMBINATIONS:
            if self.board[a] == self.board[b] == self.board[c] != ' ':
                return self.board[a]

        return ' '

    def gets_winner(self) -> str:
        """Return the winner's mark; kept as an alias of `get_winner`."""
        return self.get_winner()

    def print_board(self) -> None:
        """Print the board as three rows framed by separator lines."""
        print('---------')
        for i in range(0, 9, 3):
            print(self.board[i], '|', self.board[i+1], '|', self.board[i+2])
        print('---------')



class Node:
    """A single node of the search tree.

    Each node remembers the move that leads to it, its parent, its children,
    and two running statistics: the accumulated reward (`wins`) and the
    number of times it has been updated (`visits`). Every node also gets a
    unique, increasing `id` taken from the class-wide counter.
    """

    # id that will be handed to the next node created
    next_node_id = 0

    def __init__(self, move=None, parent=None) -> None:
        # claim the next free id, then advance the shared counter
        self.id = Node.next_node_id
        Node.next_node_id = self.id + 1

        # position in the tree
        self.parent = parent
        self.move = move
        self.children = []

        # statistics start out empty
        self.visits = 0
        self.wins = 0

    def add_child(self, child):
        """Append `child` to this node's list of children."""
        self.children.append(child)

    def update(self, win):
        """Record one more visit and add `win` to the accumulated reward."""
        self.wins += win
        self.visits += 1



class MultiAgentMCTS:
    """Monte Carlo Tree Search with UCB child selection for TicTacToe.

    Each iteration of `search` runs the four classic steps: selection,
    expansion, random simulation and backpropagation. Both players share
    one tree; a node's `wins` counts the games won by the player who made
    the move leading into that node.

    Args:
        exploration_constant: weight of the exploration term in the UCB
            score used by `select_child`.
        iterations: number of MCTS iterations run by one call to `search`.
    """

    def __init__(self, exploration_constant=1.4, iterations=10) -> None:
        self.exploration_constant = exploration_constant
        self.iterations = iterations

    @staticmethod
    def _is_over(state):
        """Return True when `state` has a winner or no legal move is left."""
        # Also treat a board without legal moves as finished, so that a
        # drawn game never reaches random.choice() with an empty list.
        return bool(state.is_terminal()) or not state.get_legal_moves()

    def select_child(self, node):
        """Pick the child of `node` to descend into, using the UCB score.

        Children that were never visited are tried first (chosen at random
        among them), since their score is undefined: it would divide by a
        zero visit count. Otherwise the child with the highest score
        ``wins/visits + c * sqrt(2 * ln(parent visits) / visits)`` wins,
        the first one in case of a tie.

        Returns:
            The selected child, or None if `node` has no children.
        """
        unvisited = [child for child in node.children if child.visits == 0]
        if unvisited:
            return random.choice(unvisited)

        # clamp so that a parent with no recorded visits cannot hit log(0)
        log_total_visits = math.log(max(node.visits, 1))

        def ucb_score(child):
            exploit_term = child.wins / child.visits
            explore_term = self.exploration_constant * math.sqrt(
                2.0 * log_total_visits / child.visits
            )
            return exploit_term + explore_term

        return max(node.children, key=ucb_score, default=None)

    def select(self, node, state):
        """Walk down the tree from `node` until a leaf or a finished game.

        Every selected child's move is executed on `state`, so on return
        `state` is the game position that belongs to the returned node.

        Returns:
            Tuple of (reached node, updated state).
        """
        while node.children and not self._is_over(state):
            # select best child according to UCB bandit
            child = self.select_child(node)
            # execute its move
            state.make_move(child.move)
            node = child

        return node, state

    def expand(self, node, game_state):
        """Generate one child of `node` for every legal move in `game_state`.

        Returns:
            The same `node`, now holding the new children.

        Raises:
            RuntimeError: if `game_state` has no legal moves (the board is
                printed first).
        """
        # get all available actions for this node
        legal_moves = game_state.get_legal_moves()

        if not legal_moves:
            game_state.print_board()
            raise RuntimeError("Error! No legal moves found from this state!")

        # generate all successors
        for move in legal_moves:
            node.add_child(Node(move=move, parent=node))

        return node

    def simulate(self, state):
        """Play uniformly random moves on `state` until the game is over.

        Returns:
            The same `state` object, now in a finished position.
        """
        while not self._is_over(state):
            state.make_move(random.choice(state.get_legal_moves()))

        return state

    def backpropagate(self, node, state):
        """Update the statistics of `node` and all its ancestors.

        Every node on the path gets one more visit. It additionally gets a
        reward of 1 if the player who made the move into that node won the
        finished game `state`, and 0 otherwise (loss or draw). The root,
        which has no move, always gets reward 0.
        """
        # gets_winner() is the winner lookup that TicTacToe provides
        winner = state.gets_winner()
        while node is not None:
            # Cells are written once per game, so the mark left on the
            # final board at node.move identifies who made that move.
            mover = state.board[node.move] if node.move is not None else None
            win = 1 if winner != ' ' and winner == mover else 0
            # update node stats
            node.update(win)
            node = node.parent

    def search(self, game_state):
        """Run MCTS from `game_state` and return the best move found.

        `game_state` itself is not modified; every iteration works on a
        fresh copy of its board and of the player to move.

        Returns:
            The move (cell index) of the root child with the highest
            wins/visits ratio. Children that were never visited rank last.

        Raises:
            RuntimeError: if `game_state` is already finished, so there is
                no move to search for.
        """
        # create a root node and expand it (raises if no move is possible)
        root = Node()
        self.expand(root, game_state)
        if game_state.is_terminal():
            raise RuntimeError("Error! Game is already over, no move to search for!")

        # run MCTS iterations
        for _ in range(self.iterations):
            # make a copy of the initial game state, including whose turn it is
            state = TicTacToe()
            state.board = list(game_state.board)
            state.current_player = game_state.current_player

            # select leaf node
            leaf, state = self.select(root, state)

            # A finished game cannot be expanded or simulated any further;
            # its result is backpropagated from the selected node directly.
            if not self._is_over(state):
                # expand leaf, pick one successor at random and execute its move
                self.expand(leaf, state)
                leaf = random.choice(leaf.children)
                state.make_move(leaf.move)

                # simulation
                state = self.simulate(state)

            # backpropagation
            self.backpropagate(leaf, state)

        def value(child):
            # unvisited children have no estimate yet, so rank them last
            if child.visits == 0:
                return float("-inf")
            return child.wins / child.visits

        # after iterations are done, return the move of the best-valued child
        return max(root.children, key=value).move

