In [1]:
import random, math, time, copy
from tqdm import tqdm


#### Below is an implementation of Multi Agent Monte Carlo Tree Search (MCTS) for the game of tic tac toe. The game is turn based and the search tree levels alternate between player and opponent nodes, i.e. all sucessors of a player node are opponent nodes and vice versa. The goal of the algorithm is to iteratively build up a search tree and at the end of the iterations, choose the best possible actions based on the values of the successor nodes of the root node. There are 4 main steps in the algorithm, always starting from the root node:

1) `Selection`: This step involves traversing down the tree until a leaf node is found (a leaf node is a node which has not been expanded before, i.e. it has no children). The traversal is done using an exploration-exploitation strategy (UCT) which balances nodes with higher value with random exploration. (The value of a node is the accumulated reward for the player at that node in proportion to the number of times that node has been visited/traversed.)

2) `Expansion`: The selcted leaf node is expanded (i.e. all of its children are generated), and the one of its children is picked at random and it's action is executed.

3) `Simulation`: Then a simulation is run from the game state resulting from executing the action of the random child. 

4) `Backpropagation`: Rewards from the simulation results are accumulated on every node along the path from the random child to the root. In this case the reward from the simuation is 1 for the winning player and 0 for the loser or 0 for both if it's a draw.  

After many iterations, we can then extract the best action by choosing the action of the child node of the root which has the best value, i.e. best accuumulated reward to number of visitc ratio. This type of strategy is called a "self-play" tree policy because the player and opponent both share the same state-action space, so the opponentn can be thought of as the player playing against itself.


In [74]:
class TicTacToe:
    def __init__(self, first_move = 'X') -> None:
        self.board = [' '] * 9 # initially empty board
        self.current_player = first_move # player X gets first turn

    # returns list of empty positions on the board
    def get_legal_moves(self):    
        return [i for i, cell in enumerate(self.board) if cell == ' ']
    

    # player move, marks a position on the board, then switches turn
    def make_move(self, move):
        self.board[move] = self.current_player
        # switch turn
        self.current_player = 'O' if self.current_player == 'X' else 'X'


    # checks if game is over
    def is_terminal(self):
        winning_combinations = [(0,1,2), (3,4,5), (6,7,8),  # rows
                                (0,3,6), (1,4,7), (2,5,8),  # cols
                                (0,4,8), (2,4,6)            # diagonals 
                                ]                  
        # check for winning combination        
        for combo in winning_combinations:
            if self.board[combo[0]] == self.board[combo[1]] == self.board[combo[2]] != ' ':
                return True 

        # check if board has no empty position
        if len(self.get_legal_moves()) == 0:
            return True
        
        return False


    # returns winner is there is one
    def get_winner(self):
        winning_combinations = [(0,1,2), (3,4,5), (6,7,8),  # rows
                                (0,3,6), (1,4,7), (2,5,8),  # cols
                                (0,4,8), (2,4,6)            # diagonals 
                                ]                  
        
        for combo in winning_combinations:
            if self.board[combo[0]] == self.board[combo[1]] == self.board[combo[2]] != ' ':
                return self.board[combo[0]] 
            
        return ' '    
 
    
    def get_opponent(self, player):
        if player == 'X':
            return 'O'
        elif player == 'O':
            return 'X'
 

    # display the game board
    def print_board(self):
        print('---------')
        for i in range(0, 9, 3):
            print(self.board[i], '|', self.board[i+1], '|', self.board[i+2])
        print('---------')



class Node:

    next_node_id = 0

    def __init__(self, player=None, move=None, parent=None) -> None:
        self.player = player
        self.move = move
        self.parent = parent
        self.children = []
        self.wins = 0
        self.losses = 0
        self.visits = 0
        self.id = Node.next_node_id
        Node.next_node_id += 1

    def add_child(self, child):
        self.children.append(child)    

    def update(self, win):
        # update visit stats
        self.visits += 1
        # accumulate rewards
        self.wins += win
        self.losses += (1-win) 



class MultiAgentMCTS:
    def __init__(self, exploration_constant=0.5, iterations=10) -> None:
        self.exploration_constant = exploration_constant
        self.iterations = iterations


    # UCB selection of successor node
    def select_child_uct(self, node):

        total_visits = node.visits    # sum(child.visits for child in node.children)
        best_score = float("-inf")
        best_child = None

        # find child node with highest score
        for child in node.children:
            # avoid division by zero
            if child.visits == 0 or total_visits == 0:
                score = float("inf")
            else:    
                log_total_visits = math.log(total_visits)
                exploit_term = child.wins/child.visits
                explore_term = self.exploration_constant * math.sqrt(2.0*log_total_visits/child.visits)         
                score = exploit_term + explore_term
                
            if score > best_score:
                best_score = score
                best_child = [child]
            elif score == best_score:
                 best_child += [child]   

            # if there are multiple best children, pick one randomly
        return random.choice(best_child)
        

    # traverses down the tree and selects an unexplored/unexpanded child node which is not a termninal state
    def select(self, node, state):
        while node.children and not state.is_terminal():
            # select best child according to UCB bandit
            child = self.select_child_uct(node)
            # execute it's move
            state.make_move(child.move)
            node = child
        
        return node, state    
    

    '''
    def expand_partial(self, selected_node, state):
        legal_moves = state.get_legal_moves()
        unexplored_moves = [move for move in legal_moves if not any(child.move == move for child in selected_node.children)]
        if unexplored_moves:
            # randomly pick one of the unexplored actions available to selected node and generate a child/successor node from it
            move = random.choice(unexplored_moves) 
            state.make_move(move)
            new_node = Node(move, node)
            node = node.add_child(new_node)

        return new_node, state
    '''


    # generates all children of a node given the game state represented by that node
    def expand(self, node, game_state, init=False):
        # get all available actions for this node
        #if init:
        #    print(f"initial expansion of root node: Initial game state board:")
        #    game_state.print_board()
        
        legal_moves = game_state.get_legal_moves()

        if len(legal_moves) == 0:
            #print(f"ERROR!!!!")
            game_state.print_board()
            raise RuntimeError("Error! No legal moves found from this state!")

        # generate all successors
        for move in legal_moves:
            new_child = Node(player=game_state.get_opponent(node.player), move=move, parent=node)
            node.add_child(new_child)    

        return node


    # random/monte carlo simulation to terminal state
    def simulate(self, state):
        while not state.is_terminal():
            legal_moves = state.get_legal_moves()
            #state.print_board()
            #print(f"Available moves for simulation")
            move = random.choice(legal_moves)
            state.make_move(move)

        #print(f"Simulation completed:")
        #state.print_board()
        
        return state    


    # backpropagate the simulation rewards up to root node
    def backpropagate(self, node, state):
        winner =  state.get_winner()
        #print(f"Winner: {winner if winner != ' ' else 'draw'}")
        while node is not None: 
            if winner == node.player:
                win = 1 
            elif winner == ' ': 
                win = 0.5
            else:
                win = 0

            #win = 0 if winner == node.move else 1
            #print(f"Backpropagation, node_id = {node.id}, move = {node.move}, win = {win}, parent_id = {node.parent.id if node.parent != None else None}")
            # update node stats
            node.update(win)
            node = node.parent   
    
    def get_best_move(self, root, greedy=True):
        best_child = []
        best_score = float("-inf")
        scores = []


        for child in root.children:
            if child.visits > 0:
                score = child.wins/child.visits
            else: 
                score = float("-inf")

            scores.append(score)
            if score > best_score:
                best_score = score
                best_child = [child]
            elif score == best_score:
                best_child.append(child)


        # greedy always picks the child node with highest value
        if greedy:
            # if multiple best child, pick one at random
            return random.choice(best_child).move    

        # softmax/temperature chooses according to a probability distribution
        else:

            temperature = 0.5 #0.9
            # exponentiate the scores
            scores_exp = []
            for score in scores:
                if score > float("-inf"):
                    scores_exp.append(math.exp(score/temperature))
                else:
                    scores_exp.append(0.0)

            sum_exp = sum(scores_exp)
            # if all scores are zero, then just pick a child at random
            if sum_exp == 0.0:
                return random.choice(root.children).move
                
            probabilities = [score_exp/sum_exp for score_exp in scores_exp]   
            # add a tiny bit of random noise to these probabilities
            noise_magnitude = 0.00001 * (min(probabilities) + max (probabilities))
            probabilities = [(p + noise_magnitude*random.random()) for p in probabilities]

            # sample a child node according to thsi probability distribution
            child_index = random.choices(list(enumerate(root.children)), weights=probabilities, k=1)[0][0]
            #print(f"probabilities: {probabilities}, child_index = {child_index}")
            child_move = root.children[child_index].move
            return child_move



    # performs monte carlo tree search iterations (we create a new tree every time, could maybe reuse parts of the same tree)
    def search(self, game_state, timeout=1.0):
        # create a root node 
        root = Node(player=game_state.current_player, move=None, parent=None)
        # expand the root node
        root = self.expand(root, game_state, init=True)

        # start the timer
        start_time = time.time()
        current_time = time.time()
        num_iterations = 0

        #print(f"Game state before MCTS, player turn: {game_state.current_player}, Board: ")
        #game_state.print_board()


        # run MCTS iteration until time out
        while current_time < start_time + timeout and num_iterations < 200:
            node = root
            # make a copy of the initial game state
            state = TicTacToe()
            state.board = copy.deepcopy(game_state.board)
            state.current_player = game_state.current_player

            #print(f"Game state board at beginning of iteration:")
            #state.print_board()

            # select leaf node
            selected_node, state = self.select(node, state)
        
            #print(f"Game state board after node selection:")
            #state.print_board()

            # carry out steps 2-4 if the selected node is not a terminal state
            if not state.is_terminal():
                # expand leaf node and pick one of its successors at random and execute it's move
                expanded_node = self.expand(selected_node, state)
                random_child = random.choice(expanded_node.children)
                state.make_move(random_child.move)

                # simulation 
                state = self.simulate(state)

                # backpropagation
                self.backpropagate(random_child, state)

            #print(f"Iterations# {num_iterations}, Root node_id :{root.id}, num vists = {root.visits}, wins = {root.wins}, value = {root.wins/root.visits}")
            
            current_time = time.time()  
            num_iterations += 1


        # after iterations are done, return the best move (i.e. the child with highest value)
        best_move = self.get_best_move(root) 
        #print(f"# of Iterations {num_iterations}, Root node_id :{root.id}, Player: {root.move}, num vists = {root.visits}, wins = {root.wins}, value = {(root.wins/root.visits) if root.visits > 0 else None}")
        #print(f"Num iterations: {num_iterations}, Best move: {best_move}")
        
        
        #print(f"Game state after MCTS, player turn: {game_state.current_player}, Board: ")
        #game_state.print_board()

        return best_move           



In [75]:
'''# instantiate a game oject
game = TicTacToe()

# instantiate mcts solver
mcts_solver = MultiAgentMCTS(iterations=300)

mcts_solver.search(game)'''

'# instantiate a game oject\ngame = TicTacToe()\n\n# instantiate mcts solver\nmcts_solver = MultiAgentMCTS(iterations=300)\n\nmcts_solver.search(game)'

In [76]:
# a game of tictac toe played by MCTS agent against itself
def self_play(mcts_timeout=1.0, random_X=False, random_O=False):
    # instantiate a game oject
    game_state = TicTacToe()
    # instantiate mcts solver for player 1
    mcts_solver1 = MultiAgentMCTS()
    # instantiate mcts solver for player 2
    mcts_solver2 = MultiAgentMCTS()

    #game_state.print_board()

    while not game_state.is_terminal():
        # alternating player-ooponent move
        #print(f"Player {game_state.current_player} makes a move:")

        if game_state.current_player == 'X':
            if random_X:
                best_move = random.choice(game_state.get_legal_moves())
            else:
                best_move = mcts_solver1.search(game_state, mcts_timeout)
            

        else:
            if random_O:
                best_move = random.choice(game_state.get_legal_moves())
            else:
                best_move = mcts_solver2.search(game_state, mcts_timeout)

            
        game_state.make_move(best_move)
        
        #game_state.print_board()
 

    winner =  game_state.get_winner()
    #print(f"Winner: {winner if winner != ' ' else 'draw'}")
    
    return winner

In [77]:
#self_play()

In [78]:
# perform many self-plays and compute win rates for each player
def win_rate(num_games=200, random_X=False, random_O=True):
    playerX = 'X'
    playerO = 'O'
    draw = ' '
    wins = {playerX: 0, playerO: 0, draw: 0}
    total_games = 0

    with tqdm(total=num_games, ncols=80, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}') as pbar:
        for i in range(num_games):
            winner = self_play(random_X=random_X, random_O=random_O)
            wins[winner] =  wins[winner] + 1
            total_games += 1
            #first_move = 'O' if first_move == 'X' else 'X'
            pbar.update(1)

    print(f"Games played = {total_games}, Player {playerX} win rate = {wins[playerX]/total_games}, Player {playerO} win rate = {wins[playerO]/total_games}, draw rate = {wins[draw]/total_games}")    

In [80]:
win_rate(num_games=300,random_X=False, random_O=True)

  0%|                                                                    | 0/300

100%|██████████████████████████████████████████████████████████████████| 300/300

Games played = 300, Player X win rate = 0.17333333333333334, Player O win rate = 0.72, draw rate = 0.10666666666666667





#### During self play, the player who gets the first move seems to always have a disproportionately higher win rate!

In [35]:
import random
import math
import copy

class Node:
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.wins = 0
        self.visits = 0

    def calculate_uct(self, exploration_constant):
        if self.visits == 0:
            return float("inf")
        return self.wins / self.visits + exploration_constant * math.sqrt(math.log(self.parent.visits) / self.visits)

class MCTSTicTacToe:
    def __init__(self, exploration_constant=1.0):
        self.exploration_constant = exploration_constant

    def find_best_move(self, state, num_simulations):
        root = Node(state)
        for _ in range(num_simulations):
            node = self.selection(root)
            if not node.state.is_game_over():
                node = self.expansion(node)
            result = self.simulation(node)
            self.backpropagation(node, result)
        
        best_child = max(root.children, key=lambda child: child.visits)
        return best_child.state.last_move

    def selection(self, node):
        while node.children:
            node = max(node.children, key=lambda child: child.calculate_uct(self.exploration_constant))
        return node

    def expansion(self, node):
        unexplored_moves = node.state.get_valid_moves()
        if unexplored_moves:
            move = random.choice(unexplored_moves)
            new_state = node.state.get_updated_state(move)
            new_node = Node(new_state, parent=node)
            node.children.append(new_node)
            return new_node
        return node

    def simulation(self, node):
        current_state = node.state
        while not current_state.is_game_over():
            move = random.choice(current_state.get_valid_moves())
            current_state = current_state.get_updated_state(move)
        return current_state.get_game_result()

    def backpropagation(self, node, result):
        while node:
            node.visits += 1
            if result == node.state.get_player():
                node.wins += 1
            node = node.parent

class TicTacToeState:
    def __init__(self):
        self.board = [[None, None, None] for _ in range(3)]
        self.last_move = None
        self.current_player = "X"

    def get_valid_moves(self):
        moves = []
        for i in range(3):
            for j in range(3):
                if self.board[i][j] is None:
                    moves.append((i, j))
        return moves

    def get_updated_state(self, move):
        i, j = move
        new_state = TicTacToeState()
        new_state.board = [row.copy() for row in self.board]
        new_state.board[i][j] = self.current_player
        new_state.last_move = move
        new_state.current_player = "O" if self.current_player == "X" else "X"
        return new_state

    def is_game_over(self):
        return self.get_game_result() is not None or len(self.get_valid_moves()) == 0

    def get_game_result(self):
        
        for row in self.board:
            if row[0] == row[1] == row[2] != None:
                return row[0]
        
        for j in range(3):
            if self.board[0][j] == self.board[1][j] == self.board[2][j] != None:
                return self.board[0][j]
            
        if self.board[0][0] == self.board[1][1] == self.board[2][2] != None:
            return self.board[0][0]
        if self.board[0][2] == self.board[1][1] == self.board[2][0] != None:
            return self.board[0][2]
        
        return None

    def get_player(self):
        return self.current_player

    # display the game board
    def print_board(self):
        new_board = copy.deepcopy(self.board)
        for i in range(3):
            for j in range(3):
                if self.board[i][j] == None:
                    new_board[i][j] = '-'

        print('---------')
        for i in range(3):
            print(new_board[i][0], '|', new_board[i][1], '|', new_board[i][2])
        print('---------')    


In [36]:
'''state = TicTacToeState()
mcts = MCTSTicTacToe()
best_move = mcts.find_best_move(state, num_simulations=1000)
'''

'state = TicTacToeState()\nmcts = MCTSTicTacToe()\nbest_move = mcts.find_best_move(state, num_simulations=1000)\n'

In [61]:
def play_game(random_X=False, random_O=False):
    state = TicTacToeState()
    mcts_X = MCTSTicTacToe()
    mcts_O = MCTSTicTacToe()

    #state.print_board()
    #print(f"Starting game...")

    while not state.is_game_over():
        if state.current_player == "X":
            if not random_X:
                best_move = mcts_X.find_best_move(state, num_simulations=5000)
            else:
                best_move = random.choice(state.get_valid_moves())
        else:
            if not random_O:
                best_move = mcts_O.find_best_move(state, num_simulations=5000)
            else:
                best_move = random.choice(state.get_valid_moves())
    
        #print(f"Current player: {state.get_player()}")
        state = state.get_updated_state(best_move)
        #state.print_board()

    #print(f"Game over.")
    return state.get_game_result()


def self_play(num_games):
    results = {"X": 0, "O": 0, "Draw": 0}

    for _ in range(num_games):
        result = play_game(random_X=True, random_O=True)
        #print(f"Game {_}, result = {result}")
        if result == None:
            result = "Draw"
        results[result] += 1

    print(f"Results--> X win rate: {results['X']/num_games}, O win rate: {results['O']/num_games}, Draw rate: {results['Draw']/num_games}")



In [56]:
self_play(num_games=1000)

Results--> X win rate: 0.591, O win rate: 0.273, Draw rate: 0.136


In [60]:
self_play(num_games=1000)

Results--> X win rate: 0.564, O win rate: 0.306, Draw rate: 0.13


In [62]:
self_play(num_games=1000)

Results--> X win rate: 0.595, O win rate: 0.283, Draw rate: 0.122


In [17]:
import random
import math

class TicTacToeState:
    def __init__(self):
        self.board = [" " for _ in range(9)]
        self.current_player = "X"

    def get_valid_moves(self):
        return [i for i, val in enumerate(self.board) if val == " "]

    def make_move(self, move):
        self.board[move] = self.current_player
        self.current_player = "O" if self.current_player == "X" else "X"

    def is_game_over(self):
        winning_combinations = [
            [0, 1, 2], [3, 4, 5], [6, 7, 8],  # rows
            [0, 3, 6], [1, 4, 7], [2, 5, 8],  # columns
            [0, 4, 8], [2, 4, 6]  # diagonals
        ]

        for combination in winning_combinations:
            if self.board[combination[0]] == self.board[combination[1]] == self.board[combination[2]] != " ":
                return True

        return " " not in self.board

    def get_game_result(self):
        if self.is_game_over():
            if self.current_player == "X":
                return "O"
            elif self.current_player == "O":
                return "X"
            else:
                return "Draw"

        return None

class MCTSTicTacToe:
    def __init__(self, exploration_constant=1.4):
        self.exploration_constant = exploration_constant

    def find_best_move(self, state, num_simulations):
        root_node = Node(state)

        for _ in range(num_simulations):
            node = root_node
            current_state = state

            # Selection
            while not node.untried_moves and node.child_nodes:
                node = self.select_child(node, current_state)
                current_state.make_move(node.move)

            # Expansion
            if node.untried_moves:
                move = random.choice(node.untried_moves)
                current_state.make_move(move)
                node = node.add_child(move, current_state)

            # Simulation
            while not current_state.is_game_over():
                moves = current_state.get_valid_moves()
                random_move = random.choice(moves)
                current_state.make_move(random_move)

            # Backpropagation
            while node:
                result = current_state.get_game_result()
                node.update(result)
                node = node.parent

        best_child = root_node.get_best_child()  # Choose the best move based on the child node visits
        return best_child.move

    def select_child(self, node, state):
        total_visits = sum(child.visits for child in node.child_nodes)
        log_total_visits = math.log(total_visits)

        best_score = float("-inf")
        best_child = None

        for child in node.child_nodes:
            exploration_term = math.sqrt(log_total_visits / child.visits)
            score = child.wins / child.visits + self.exploration_constant * exploration_term

            if score > best_score:
                best_score = score
                best_child = child

        state.make_move(best_child.move)
        return best_child

class Node:
    def __init__(self, state, move=None, parent=None):
        self.state = state
        self.move = move
        self.parent = parent
        self.child_nodes = []
        self.untried_moves = state.get_valid_moves()
        self.visits = 0
        self.wins = 0

    def add_child(self, move, state):
        child = Node(state, move, self)
        self.untried_moves.remove(move)
        self.child_nodes.append(child)
        return child

    def update(self, result):
        self.visits += 1
        if result == self.state.current_player:
            self.wins += 1
        elif result == "Draw":
            self.wins += 0.5

    def get_best_child(self):
        return max(self.child_nodes, key=lambda child: child.visits)

def play_game():
    state = TicTacToeState()
    mcts_X = MCTSTicTacToe()
    mcts_O = MCTSTicTacToe()

    while not state.is_game_over():
        if state.current_player == "X":
            best_move = mcts_X.find_best_move(state, num_simulations=25000)
        else:
            best_move = mcts_O.find_best_move(state, num_simulations=25000)

        state.make_move(best_move)

    return state.get_game_result()

def self_play(num_games):
    results = {"X": 0, "O": 0, "Draw": 0}

    for _ in range(num_games):
        result = play_game()
        results[result] += 1

    print("Results:")
    print(f"X wins: {results['X']}")
    print(f"O wins: {results['O']}")
    print(f"Draws: {results['Draw']}")



In [18]:
self_play(num_games=500)

Results:
X wins: 161
O wins: 339
Draws: 0
