## Imports

In [64]:
from __future__ import annotations
from dataclasses import dataclass, field
import numpy as np
from numpy.typing import NDArray
import math

## Pseudo Code
Pseudocode found in this [article](https://int8.io/monte-carlo-tree-search-beginners-guide/), which was linked to in the assignment.
```python
def monte_carlo_tree_search(root):
    """Repeat traverse -> rollout -> backpropagate from `root` while resources remain.

    Returns `best_child(root)` once `resources_left(...)` reports False.
    """
    while resources_left(time, computational_power):
        leaf = traverse(root)  # leaf = unvisited node
        simulation_result = rollout(leaf)
        backpropagate(leaf, simulation_result)
    return best_child(root)

def traverse(node):
    """Follow `best_uct` while the node is fully expanded, then return an
    unvisited child of the node reached, or that node itself."""
    while fully_expanded(node):
        node = best_uct(node)
    return pick_unvisited(node.children) or node  # in case no children are present or node is terminal

def rollout(node):
    """Apply `rollout_policy` repeatedly until a terminal node is reached and
    return `result` of that node."""
    while not is_terminal(node):
        node = rollout_policy(node)
    return result(node)

def rollout_policy(node):
    """Choose the next node of a rollout: one of `node.children`, selected by `pick_random`."""
    return pick_random(node.children)

def backpropagate(node, result):
    """Add one visit and `result` to `node`, then recurse into `node.parent`.

    The recursion ends when it is called with `None`.
    """
    if node is None:
        return
    node.visits += 1
    node.total_reward += result
    backpropagate(node.parent, result)

def best_child(node):
    """Return the child of `node` with the largest `visits` count."""
    # pick child with highest number of visits
    return max(node.children, key=lambda c: c.visits)
```

## Python implementation
Here we make our own implementation of Monte Carlo tree search based on the pseudocode supplied above. We wrap it in a class called BetaTicTacToe, which is intended as the actual program that you play against.

In [65]:
class BetaTicTacToe:
    """Monte Carlo tree search (MCTS) player for tic-tac-toe.

    Conventions used throughout the class:

    * The board is a square array; 0 is an empty cell, 1 is a mark of the
      ``player=True`` side and -1 a mark of the ``player=False`` side.
    * ``State.player`` is the side that moves next from that state.
    * Rewards are absolute: +1 when the 1-marks complete a line, -1 when the
      -1-marks do, 0 otherwise.  ``uct`` flips the sign for the side that is
      choosing, so the same statistics serve both players.
    """

    def __init__(self, exploration_scale=1.0, seed=None):
        """Create a searcher.

        Args:
            exploration_scale: Weight of the exploration term in ``uct``.
            seed: Optional seed forwarded to ``np.random.default_rng`` so that
                searches can be made reproducible.  ``None`` keeps the
                original behaviour (fresh OS entropy).
        """
        self.rng = np.random.default_rng(seed)
        self.exploration_scale = exploration_scale

    @dataclass
    class State:
        # Board contents.  The dtype must be able to hold -1, so an unsigned
        # dtype such as uint8 is not suitable; signed ints and floats work.
        s: NDArray[np.int8]
        # Side to move next: True places 1, False places -1.
        player: bool

    # eq=False: nodes are compared by identity.  A field-wise comparison would
    # have to compare board arrays, which does not yield a single bool.
    @dataclass(eq=False)
    class Node:
        s: BetaTicTacToe.State
        # parent/children are left out of the repr; printing a node would
        # otherwise dump the whole tree.
        parent: BetaTicTacToe.Node | None = field(repr=False)
        children: list[BetaTicTacToe.Node] = field(default_factory=list, repr=False)
        # Sum of the (absolute) rollout rewards backpropagated through this node.
        q: float = 0.0
        # Number of rollouts backpropagated through this node.
        visits: int = 0

    def search(self, root, max_iter=1000):
        """Run ``max_iter`` MCTS iterations from ``root`` and return the best child.

        Each iteration selects a leaf, expands it (unless terminal), runs a
        random rollout from one of the new children and backpropagates the
        reward.

        Raises:
            ValueError: If ``root`` has no children after searching, i.e. it
                is a terminal position or no iteration was run.
        """
        for _ in range(max_iter):
            leaf = self.traverse(root)
            # A non-terminal leaf always has at least one legal move, so the
            # expansion yields at least one unvisited child to roll out from.
            if not self.is_terminal(leaf):
                self.expand(leaf)
                leaf = self.pick_unvisited(leaf.children)
            simulation_result = self.rollout(leaf)
            self.backpropagate(leaf, simulation_result)
        if not root.children:
            raise ValueError(
                "cannot choose a move: root is terminal or no iteration was run"
            )
        return self.best_child(root)

    def fully_expanded(self, node):
        """Return True when every child of ``node`` has been visited at least once."""
        return all(c.visits > 0 for c in node.children or ())

    def pick_unvisited(self, children):
        """Return a uniformly random child that has zero visits.

        Raises:
            ValueError: If every child has already been visited.
        """
        unexplored = [c for c in children if c.visits == 0]
        # Draw an index rather than passing node objects to rng.choice, which
        # would first convert the list into a NumPy object array.
        return unexplored[self.rng.integers(len(unexplored))]

    def traverse(self, node):
        """Select the node to work on next.

        Descends by highest UCT score while all children have been visited;
        returns an unvisited child as soon as one exists, or the childless
        node reached at the bottom of the tree.
        """
        while node.children:
            if not self.fully_expanded(node):
                return self.pick_unvisited(node.children)
            node = self.best_uct(node)
        return node

    def backpropagate(self, node, result):
        """Add one visit and ``result`` to ``node`` and to each of its ancestors."""
        while node is not None:
            node.visits += 1
            node.q += result
            node = node.parent

    def rollout(self, node):
        """Play uniformly random moves from ``node`` until the game ends.

        Returns the absolute reward of the final position (see ``result``).
        The tree is not modified.
        """
        while not self.is_terminal(node):
            node = self.rollout_policy(node)
        return self.result(node)

    def rollout_policy(self, node):
        """Return a detached node reached by one uniformly random legal move.

        The side to move (``node.s.player``) places its own mark, exactly as
        in ``expand``; the returned state has the other side to move.
        """
        legal_actions = np.argwhere(node.s.s == 0)
        # rng.choice picks one (row, col) pair along axis 0.
        action = self.rng.choice(legal_actions)
        board = node.s.s.copy()
        # tuple(...) addresses a single cell.  Indexing with the raw length-2
        # array would be fancy indexing and overwrite two whole rows.
        board[tuple(action)] = 1 if node.s.player else -1
        new_state = self.State(s=board, player=not node.s.player)
        # Not attached to the tree: rollouts only need the final reward.
        return self.Node(new_state, None)

    def is_terminal(self, node):
        """Return True when the game is over: a line is complete or the board is full."""
        return self.result(node) != 0 or not np.any(node.s.s == 0)  # win/draw

    def result(self, node):
        """Return the absolute outcome of the position in ``node``.

        Returns:
            1 if a row, column or diagonal consists only of 1-marks, -1 if one
            consists only of -1-marks, and 0 otherwise (draw or unfinished).
            The value does not depend on whose turn it is.  Assumes a square
            board.
        """
        board = node.s.s
        side_size = board.shape[0]
        # A line sums to +/- side_size only if every cell holds the same mark.
        line_sums = np.concatenate((
            board.sum(axis=1),                                # rows
            board.sum(axis=0),                                # columns
            [np.trace(board), np.trace(np.fliplr(board))],    # both diagonals
        ))
        if np.any(line_sums == side_size):
            return 1
        if np.any(line_sums == -side_size):
            return -1
        return 0

    def best_child(self, node):
        """Return the most visited child of ``node`` (the move finally played)."""
        return max(node.children, key=lambda c: c.visits)

    def best_uct(self, node):
        """Return the child of ``node`` with the highest UCT score."""
        return max(node.children, key=lambda c: self.uct(node, c))

    def uct(self, parent, child):
        """Return the UCT score of ``child`` as seen by the side moving at ``parent``.

        The mean reward is multiplied by +1 when ``parent.s.player`` is True
        and by -1 otherwise, because rewards are stored in absolute terms.
        An unvisited child scores ``math.inf`` so that it is tried first.
        """
        if child.visits == 0:
            return math.inf
        perspective = 1 if parent.s.player else -1
        exploitation = perspective * (child.q / child.visits)
        # max(..., 1) keeps log() defined if the parent itself has no visits.
        exploration = self.exploration_scale * math.sqrt(
            math.log(max(parent.visits, 1)) / child.visits
        )
        return exploitation + exploration

    def expand(self, node):
        """Create one child of ``node`` per empty cell; no-op if children exist."""
        if node.children:
            return

        node.children = []
        s = node.s.s
        player = node.s.player
        mark = 1 if player else -1
        for cell in np.argwhere(s == 0):
            new_board = s.copy()
            new_board[tuple(cell)] = mark
            # The other side moves next in the resulting state.
            new_state = self.State(s=new_board, player=not player)
            node.children.append(self.Node(s=new_state, parent=node))

In [66]:
# Demo: starting from an empty 3x3 board with player True to move, search for
# a move, then search again from the chosen position, and show both boards.
MCTS = BetaTicTacToe()
root = BetaTicTacToe.Node(
    s=BetaTicTacToe.State(s=np.zeros(shape=(3, 3)), player=True),
    parent=None,
)
result = MCTS.search(root=root)
result2 = MCTS.search(root=result)
display(result.s.s)
display(result2.s.s)

array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 1., 0.]])

array([[ 0.,  0.,  0.],
       [ 0.,  0., -1.],
       [ 0.,  1.,  0.]])