## Imports

In [1]:
from __future__ import annotations
from dataclasses import dataclass, field
import numpy as np
from numpy.typing import NDArray
import math

## Pseudo Code
Pseudocode from this [article](https://int8.io/monte-carlo-tree-search-beginners-guide/), which was linked in the assignment.
```python
def monte_carlo_tree_search(root):
    while resources_left(time, computational_power):
        leaf = traverse(root)  # leaf = unvisited node
        simulation_result = rollout(leaf)
        backpropagate(leaf, simulation_result)
    return best_child(root)

def traverse(node):
    while fully_expanded(node):
        node = best_uct(node)
    return pick_unvisited(node.children) or node  # in case no children are present or node is terminal

def rollout(node):
    while not is_terminal(node):
        node = rollout_policy(node)
    return result(node)

def rollout_policy(node):
    return pick_random(node.children)

def backpropagate(node, result):
    if node is None:
        return
    node.visits += 1
    node.total_reward += result
    backpropagate(node.parent, result)

def best_child(node):
    # pick child with highest number of visits
    return max(node.children, key=lambda c: c.visits)
```

## Python implementation
Here we write our own implementation of Monte Carlo tree search, based on the pseudocode above. We wrap it in a class called BetaTicTacToe, which is the actual program that you play against.

In [2]:
# eq=False: the generated field-by-field __eq__ would compare the NumPy
# boards with ==, which raises "truth value of an array is ambiguous".
# Tree nodes are therefore compared (and hashed) by identity.
@dataclass(eq=False)
class Node:
    """One node of the search tree: a board position plus search statistics."""

    # Square board; 0 marks an empty cell, 1 and -1 mark the two players.
    state: NDArray[np.int8]
    # Node this one was expanded from; None for a root.
    parent: Node | None
    # Side (1 or -1) whose move is next from `state`.
    team: int
    # Successor positions; None until the node has been expanded.
    children: list[Node] | None = None
    # Sum of the rollout rewards backpropagated through this node.
    q: float = 0.0
    # Number of times backpropagation has passed through this node.
    visits: int = 0


class MonteCarloTreeSearch:
    """Monte Carlo tree search over square boards of 0 / 1 / -1 cells.

    A board is a 2-D array where 0 is an empty cell and 1 / -1 are the two
    players. A player wins by filling a complete row, column or diagonal.
    Tree nodes must expose ``state``, ``parent``, ``team``, ``children``,
    ``q`` and ``visits`` attributes.
    """

    def __init__(self, exploration_scale=1.0, search_iterations=10000, seed=None):
        """
        Args:
            exploration_scale: Weight of the exploration term in UCT.
            search_iterations: Number of traverse/rollout/backpropagate
                cycles run by each call to ``search``.
            seed: Optional seed forwarded to ``np.random.default_rng`` so
                that searches can be reproduced. ``None`` keeps the
                non-deterministic default.
        """
        self.rng = np.random.default_rng(seed)
        self.exploration_scale = exploration_scale
        self.search_iterations = search_iterations

    def search(self, root):
        """Grow the tree below ``root`` for ``search_iterations`` iterations."""
        for _ in range(self.search_iterations):
            leaf = self.traverse(root)
            simulation_result = self.rollout(leaf)
            # Fix the perspective: ``result`` reports the winner as 1 / -1,
            # while a node's q is scored for the player who moved into it
            # (the parent's team), because the parent maximises over q.
            if leaf.parent is not None:
                simulation_result *= leaf.parent.team
            self.backpropagate(leaf, simulation_result)

    def fully_expanded(self, node):
        """Return True if ``node`` has children and all of them were visited."""
        if node.children:
            return all(c.visits > 0 for c in node.children)
        return False

    def traverse(self, node):
        """Select the node from which the next rollout starts.

        Follows the best-UCT child until reaching a node that is terminal
        or still has unvisited children. A terminal node is returned as is;
        otherwise the node is expanded (if needed) and a random unvisited
        child is returned.
        """
        while self.fully_expanded(node) and not self.is_terminal(node.state):
            node = self.best_uct(node)
        if self.is_terminal(node.state):
            return node
        self.expand(node)
        unexplored = [c for c in node.children if c.visits == 0]
        # Draw an index instead of handing the node list to ``rng.choice``,
        # which would first convert the list into a NumPy object array.
        return unexplored[self.rng.integers(len(unexplored))]

    def backpropagate(self, node, result):
        """Add ``result`` to ``node`` and all its ancestors.

        Every node on the path gets one more visit. The sign of the result
        flips at each level: a win for me is a loss for my parent.
        """
        while node is not None:
            node.visits += 1
            node.q += result
            result = -result
            node = node.parent

    def rollout(self, node):
        """Play random moves from ``node`` until the game ends.

        Works on a copy of the node's board, so the tree is not modified.

        Returns:
            The value of ``result`` for the final board.
        """
        state = node.state.copy()
        team = node.team
        while not self.is_terminal(state):
            # Query the policy for the next action.
            action = self.rollout_policy(state)
            # Make the move, then hand the turn to the other player.
            state[action] = team
            team = -team
        return self.result(state)

    def rollout_policy(self, state):
        """Return a uniformly random empty cell of ``state`` as an index tuple."""
        legal_actions = np.argwhere(state == 0)
        idx = self.rng.integers(len(legal_actions))
        return tuple(legal_actions[idx])

    def is_terminal(self, state):
        """Return True if ``state`` has a winner or no empty cell left."""
        if self.result(state) != 0:
            return True
        return not np.any(state == 0)

    def result(self, state):
        """Return 1 or -1 if that player owns a full line, otherwise 0.

        Rows and columns are checked first, then both diagonals.
        """
        side_size = state.shape[0]
        diag = np.diag(state)
        anti_diag = np.diag(np.fliplr(state))

        # Check rows and columns
        for i in range(side_size):
            if np.all(state[i, :] == 1) or np.all(state[:, i] == 1):
                return 1
            if np.all(state[i, :] == -1) or np.all(state[:, i] == -1):
                return -1

        # Check diagonals
        if np.all(diag == 1) or np.all(anti_diag == 1):
            return 1
        if np.all(diag == -1) or np.all(anti_diag == -1):
            return -1

        return 0

    def best_uct(self, node):
        """Return the child of ``node`` with the highest UCT score.

        A node without children is returned unchanged.
        """
        if node.children:
            return max(node.children, key=lambda c: self.uct(node, c))
        return node

    def uct(self, parent, child):
        """Return the UCT score of ``child`` as seen from ``parent``.

        The score is the mean reward ``q / visits`` plus
        ``exploration_scale * sqrt(ln(parent.visits) / child.visits)``.
        An unvisited child scores ``inf`` so it is always tried first
        instead of triggering a division by zero.
        """
        if child.visits == 0:
            return math.inf
        exploitation = child.q / child.visits
        # Clamp to 1 so a parent without recorded visits yields log(1) == 0
        # rather than a math domain error.
        exploration = self.exploration_scale * math.sqrt(
            math.log(max(parent.visits, 1)) / child.visits
        )
        return exploitation + exploration

    def expand(self, node):
        """Create one child of ``node`` per empty cell.

        Each child holds a copy of the board with ``node.team`` placed on
        that cell and has the opposite team to move. Nodes that already
        have children are left untouched.
        """
        if node.children:
            return

        node.children = []
        legal_actions = np.argwhere(node.state == 0)
        for cell in legal_actions:
            new_board = node.state.copy()
            new_board[tuple(cell)] = node.team
            child = Node(state=new_board, parent=node, team=-node.team)
            node.children.append(child)


class BetaTicTacToe:
    """Tic-tac-toe opponent that picks its moves with Monte Carlo tree search."""

    def __init__(self, exploration_scale=1.0, search_iterations=10000):
        # Both settings are handed straight to the underlying search engine.
        engine = MonteCarloTreeSearch(exploration_scale, search_iterations)
        self.mcts = engine

    def play(self, root):
        """
        Run a search from ``root`` and return the node for the AI's move.

        The most visited child of ``root`` is chosen (the first one on
        ties). If ``root`` has no children after the search, ``root``
        itself is returned.
        """
        self.mcts.search(root)

        most_visited = None
        for candidate in root.children or ():
            # Strict comparison keeps the earliest child on equal visits.
            if most_visited is None or candidate.visits > most_visited.visits:
                most_visited = candidate

        if most_visited is None:
            return root
        return most_visited


# Small test example: let BetaTicTacToe choose the opening move on an empty
# 3x3 board, with player 1 to move.
bttt = BetaTicTacToe()
initial_board = np.zeros((3, 3), dtype=np.int8)
root = Node(initial_board, None, 1)
# Named next_node rather than `next` so the builtin next() is not shadowed
# for the rest of the notebook session.
next_node = bttt.play(root)
display("First move played by BetaTicTacToe", next_node.state)

'First move played by BetaTicTacToe'

array([[0, 0, 0],
       [0, 1, 0],
       [0, 0, 0]], dtype=int8)

# Interactive game loop

In [3]:
# Ask until the user gives a clear yes/no answer.
play_game = None
while play_game not in ["yes", "no"]:
    play_game = input(
        "Do you wish to play Tic-Tac-Toe against BetaTicTacToe? (yes/no): "
    ).lower()

if play_game == "yes":
    bttt = BetaTicTacToe()

    # 1. Select side (1 starts, -1 goes second)
    user_side = 0
    while user_side not in [1, -1]:
        try:
            user_side = int(
                input(
                    "Choose your side. Enter 1 to go first (X), or -1 to go second (O): "
                )
            )
        except ValueError:
            # Not a number: ask again.
            continue

    # 2. Create start state (empty board, player 1 moves first)
    current_node = Node(np.zeros((3, 3), dtype=np.int8), team=1, parent=None)

    # 3. Run game loop
    while not bttt.mcts.is_terminal(current_node.state):
        print("Board state:")
        print(current_node.state)

        # It is the human's turn when the side to move is the side they
        # chose. Both values are 1 or -1, so they are compared directly.
        is_human_turn = current_node.team == user_side

        if is_human_turn:
            print("--- Your Turn ---")
            board_size = current_node.state.shape[0]
            move_made = False
            while not move_made:
                try:
                    row = int(input("Enter row (0-2): "))
                    col = int(input("Enter col (0-2): "))
                except ValueError:
                    print("Invalid input. Use 0, 1, or 2.")
                    continue
                # Explicit range check: NumPy would otherwise accept
                # negative indices and silently wrap them around.
                if not (0 <= row < board_size and 0 <= col < board_size):
                    print("Invalid input. Use 0, 1, or 2.")
                    continue
                if current_node.state[row, col] != 0:
                    print("Spot already taken!")
                    continue

                # Manually expand to find the child node corresponding to this move
                if not current_node.children:
                    bttt.mcts.expand(current_node)

                # The chosen cell is empty on the current board, so the
                # only child with a mark there is the one for this move.
                for child in current_node.children:
                    if child.state[row, col] != 0:
                        current_node = child
                        move_made = True
                        break
        else:
            print("--- BetaTicTacToe Turn ---")
            current_node = bttt.play(current_node)

        # Reset parent to save memory and keep current_node as the new root
        current_node.parent = None

    # Final Result
    print("\nFinal Board:")
    print(current_node.state)
    res = bttt.mcts.result(current_node.state)

    if res == 0:
        print("It's a draw!")
    elif res == user_side:
        print("Human wins!")
    else:
        print("BetaTicTacToe wins!")