In [1]:
import copy
import math
import random

In [None]:
class TicTacToe:
    """Tic-tac-toe position: a 3x3 board plus the player whose turn it is.

    The board is a flat list of nine cells indexed 0-8 in row-major order.
    Each cell holds 'X', 'O', or ' ' (empty).
    """

    # Every triple of cell indices that forms a line of three.
    _WINNING_COMBINATIONS = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # cols
        (0, 4, 8), (2, 4, 6),             # diagonals
    )

    def __init__(self) -> None:
        self.board = [' '] * 9  # initially empty board
        self.current_player = 'X'  # player X gets first turn

    def get_legal_moves(self):
        """Return the indices of all empty cells, in ascending order."""
        return [i for i, cell in enumerate(self.board) if cell == ' ']

    def make_move(self, move):
        """Mark cell ``move`` for the current player, then pass the turn.

        No legality check is performed: the caller is expected to pass an
        index obtained from ``get_legal_moves``.
        """
        self.board[move] = self.current_player
        # switch turn
        self.current_player = 'O' if self.current_player == 'X' else 'X'

    def is_terminal(self):
        """Return True when the game is over (a win or a full board)."""
        # A full board with no winner is a draw, which also ends the game.
        return self.gets_winner() != ' ' or ' ' not in self.board

    def gets_winner(self):
        """Return 'X' or 'O' if that player has three in a line, else ' '."""
        for a, b, c in self._WINNING_COMBINATIONS:
            if self.board[a] == self.board[b] == self.board[c] != ' ':
                return self.board[a]

        return ' '

    def get_winner(self):
        """Alias of ``gets_winner`` under the conventional accessor name."""
        return self.gets_winner()



class Node:
    """One vertex of the search tree.

    Holds the move that led to this node, a link to its parent, its list of
    children, and the win/visit counters. Every instance also receives a
    sequential ``id`` drawn from the class-wide ``next_node_id`` counter.
    """

    # id that the next created Node will receive; shared by all instances
    next_node_id = 0

    def __init__(self, move=None, parent=None) -> None:
        self.move, self.parent = move, parent
        self.children = []
        # both statistics start at zero
        self.wins = self.visits = 0
        # take the current counter value, then advance the counter by one
        self.id = Node.next_node_id
        Node.next_node_id = self.id + 1

    def add_child(self, child):
        """Append ``child`` to this node's list of children."""
        self.children += [child]


class MultiAgentMCTS:
    """Monte Carlo Tree Search for a two-player, alternating-turn game.

    The game state passed in must provide ``get_legal_moves()``,
    ``make_move(move)``, ``is_terminal()``, a ``current_player`` attribute and
    a winner accessor (``get_winner()`` or ``gets_winner()``).

    Each node created by ``expand`` is tagged with a ``player`` attribute: the
    player who made the move leading into that node. ``backpropagate`` uses it
    to credit a win to the right side.
    """

    def __init__(self, exploration_constant=1.4, iterations=10) -> None:
        self.exploration_constant = exploration_constant
        self.iterations = iterations

    @staticmethod
    def _untried_moves(node, state):
        """Return the legal moves in ``state`` that ``node`` has no child for."""
        tried = {child.move for child in node.children}
        return [move for move in state.get_legal_moves() if move not in tried]

    @staticmethod
    def _winner_of(state):
        """Return the state's winner via ``get_winner`` or ``gets_winner``."""
        # The game class in this file spells the accessor ``gets_winner``;
        # accept either spelling so the search works with both.
        getter = getattr(state, "get_winner", None) or getattr(state, "gets_winner")
        return getter()

    # UCB selection of successor node
    def select_child(self, node):
        """Return the child of ``node`` with the highest UCB1 score.

        An unvisited child is returned immediately: its score is unbounded,
        and computing it would divide by zero. Returns None when ``node`` has
        no children.
        """
        total_visits = sum(child.visits for child in node.children)
        # log(0) is undefined; with no visits every child is unvisited and the
        # early return below fires before this value is used.
        log_total_visits = math.log(total_visits) if total_visits > 0 else 0.0

        best_score = float("-inf")
        best_child = None

        # find child node with highest score
        for child in node.children:
            if child.visits == 0:
                return child
            exploit_term = child.wins / child.visits
            explore_term = self.exploration_constant * math.sqrt(log_total_visits / child.visits)
            score = exploit_term + explore_term
            if score > best_score:
                best_score = score
                best_child = child

        return best_child

    # traverses down the tree to a node that can still be expanded (or a leaf)
    def select(self, node, state):
        """Descend from ``node`` while the current node is fully expanded.

        ``state`` is advanced in place along the chosen path. The descent
        stops at the first node that still has an untried legal move, or that
        has no children at all. Returns ``(node, state)``.
        """
        # Descending only through fully expanded nodes lets every sibling be
        # created before any one of them is explored more deeply.
        while node.children and not self._untried_moves(node, state):
            child = self.select_child(node)
            state.make_move(child.move)
            node = child

        return node, state

    def expand(self, selected_node, state):
        """Add one child for a random untried move and step ``state`` into it.

        Returns ``(new_node, state)``. When the state is terminal or no untried
        move remains, nothing is added and ``(selected_node, state)`` is
        returned, so the caller always receives a node to backpropagate from.
        """
        if state.is_terminal():
            return selected_node, state

        unexplored_moves = self._untried_moves(selected_node, state)
        if not unexplored_moves:
            return selected_node, state

        # randomly pick one of the unexplored actions available to selected
        # node and generate a child/successor node from it
        move = random.choice(unexplored_moves)
        mover = state.current_player  # read before make_move passes the turn
        state.make_move(move)
        new_node = Node(move, selected_node)
        new_node.player = mover
        selected_node.add_child(new_node)

        return new_node, state

    # random/monte carlo simulation to terminal state
    def simulate(self, state):
        """Play uniformly random legal moves on ``state`` until the game ends.

        ``state`` is mutated in place and returned.
        """
        while not state.is_terminal():
            legal_moves = state.get_legal_moves()
            if not legal_moves:
                # nothing left to play, even if the state does not report itself as terminal
                break
            state.make_move(random.choice(legal_moves))

        return state

    # backpropagate the simulation rewards up to root node
    def backpropagate(self, node, state):
        """Update visit and win counts from ``node`` up to the root.

        Every node on the path gets one more visit. A node gets one more win
        only when the winner of ``state`` is the player who moved into it.
        Nodes without a ``player`` tag, such as the root, never gain wins.
        Games without a winner add no win to any node.
        """
        winner = self._winner_of(state)
        while node is not None:
            # update visit stats
            node.visits += 1
            # accumulate reward
            mover = getattr(node, "player", None)
            if mover is not None and winner == mover:
                node.wins += 1
            node = node.parent

    def search(self, game_state):
        """Run ``self.iterations`` MCTS iterations and return the best move.

        ``game_state`` is never modified: every iteration works on its own
        deep copy. The returned move is that of the most-visited root child.

        Raises:
            ValueError: if no move was explored, i.e. ``game_state`` is
                already terminal or ``iterations`` is less than 1.
        """
        # create a root node
        root = Node()
        # run MCTS iterations
        for _ in range(self.iterations):
            # every iteration must start from the untouched root position
            state = copy.deepcopy(game_state)

            # selection
            selected_node, state = self.select(root, state)

            # expansion
            new_child_node, state = self.expand(selected_node, state)

            # simulation
            state = self.simulate(state)

            # backpropagation
            self.backpropagate(new_child_node, state)

        if not root.children:
            raise ValueError("no move available: terminal state or no iterations run")

        # after iterations are done, return the best move
        return max(root.children, key=lambda child: child.visits).move






