In [76]:
import random
from copy import deepcopy
import math
from src import TicTacToe

def player_formatter(player: int) -> str:
    symbol_mapping = {0: " ", 1: "X", -1: "O"}
    return symbol_mapping[player]

def state_formatter(state: tuple[int, ...]) -> str:
    size = int(len(state) ** 0.5)
    formatted_state = "\n"
    for i in range(size):
        formatted_state += "+---+---+---+\n"
        row = state[i*size:(i+1)*size]
        formatted_state += "| " + " | ".join(player_formatter(cell)for cell in row) + " |\n"
    formatted_state += "+---+---+---+"
    return formatted_state

game_cls = TicTacToe(default_state_formatter=state_formatter)

class TreeNode():
    def __init__(self, game, parent):
        self.game = game
        self.state = game.get_state()
        self.actions = game.get_actions()
        self.is_terminated = game.is_terminated()
        self.fully_expand = self.is_terminated
        self.parent = parent
        self.num_visited = 0
        self.total_reward = 0
        self.children = {}
        self.visited = False


class MCTS():
    def __init__(self, game_cls, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.game_cls = game_cls
        self.root = None
        self.iteration = 2000
        self.exploration_constant = 1 / math.sqrt(2)

    def select_child(self, node):
        while not node.is_terminated:
            if node.fully_expand:
                new_node = self.get_best_move(node, self.exploration_constant)
                if new_node is None:
                    raise Exception("No best move found")
                node = new_node
            else:
                return self.expand(node)
        return node

    def expand(self, node):
        actions = node.actions
        for action in actions:
            if action not in node.children:
                new_game = deepcopy(node.game)
                new_game.move(action)
                new_node = TreeNode(new_game, node)
                node.children[action] = new_node
                if len(actions) == len(node.children):
                    node.fully_expand = True
                return new_node
        raise Exception("No actions to expand")

    def simulated_game(self, node):
        node.visited = True
        if math.factorial(len(node.game.get_actions())) <= self.iteration:
            return self.simulate_by_brute_force(node)
        else:
            return self.simulate_by_sampling(node)

    def simulate_by_brute_force(self, node):
        if node.game.is_terminated():
            return self.calculate_reward(node)
        outcomes = []
        for action in node.game.get_actions():
            # new_game = deepcopy(node.game)
            # new_game.move(action)
            node.game.apply_action(action)
            child_node = TreeNode(node.game, node)
            outcome = self.simulated_game(child_node)
            node.game.undo_action()
            outcomes.append(outcome)
        return max(outcomes)

    def simulate_by_sampling(self, node):
        new_game = deepcopy(node.game)
        while not new_game.is_terminated():
            possible_actions = new_game.get_actions()
            action = random.choice(list(possible_actions))
            new_game.move(action)

        return self.calculate_reward(TreeNode(new_game, node))

    def calculate_reward(self, node):
        winner = node.game.get_winner()
        if winner == node.game.player:
            return -10
        elif winner == 0:
            return 0
        else:
            return 10

    def execute_round(self):
        node = self.select_child(self.root)
        reward = self.simulated_game(node)
        self.back_propogate(node, reward)

    def back_propogate(self, node, reward):
        while node is not None:
            node.num_visited += 1
            node.total_reward += reward
            reward = -reward
            node = node.parent

    def get_best_move(self, node, exploration_constant):
        best_value = float("-inf")
        best_nodes = []
        for action, child in node.children.items():
            node_value = (child.total_reward / child.num_visited
                          + exploration_constant * math.sqrt(2 * math.log(node.num_visited) / child.num_visited))
            if node_value > best_value:
                best_value = node_value
                best_nodes = [child]
            elif node_value == best_value:
                best_nodes.append(child)
        return random.choice(best_nodes) if best_nodes else None

    def get_all_Qs(self, state: tuple[int, ...], player: int, action_space: set[int]) -> dict[int, float]:
        self.root = TreeNode(self.game_cls.from_state(state, player), None)
        for _ in range(self.iteration):
            self.execute_round()
    
        action_scores = {}
        for action in action_space:
            child_node = self.root.children.get(action)
            if child_node is not None and child_node.num_visited > 0:
                average_reward = child_node.total_reward / child_node.num_visited
                action_scores[action] = average_reward
            else:
                action_scores[action] = 0
    
        return action_scores



In [77]:
mcts_policy = MCTS(game_cls)
game_cls.reset()

In [78]:
game_cls.move(4)
game_cls.render()


+---+---+---+
|   |   |   |
+---+---+---+
|   | X |   |
+---+---+---+
|   |   |   |
+---+---+---+


In [79]:
result = mcts_policy.get_all_Qs(game_cls.get_state(), game_cls.player, game_cls.get_actions())
result

{0: -7.5,
 1: -7.5,
 2: -0.051440329218107,
 3: -7.5,
 5: -7.5,
 6: -7.5,
 7: -7.5,
 8: -7.5}

In [80]:
game_cls.move(2)
game_cls.render()


+---+---+---+
|   |   | O |
+---+---+---+
|   | X |   |
+---+---+---+
|   |   |   |
+---+---+---+


In [81]:
game_cls.move(8)
game_cls.render()


+---+---+---+
|   |   | O |
+---+---+---+
|   | X |   |
+---+---+---+
|   |   | X |
+---+---+---+


In [82]:
result = mcts_policy.get_all_Qs(game_cls.get_state(), game_cls.player, game_cls.get_actions())
result

{0: 0.02594706798131811,
 1: -6.666666666666667,
 3: -7.142857142857143,
 5: -6.296296296296297,
 6: -6.666666666666667,
 7: -6.296296296296297}

In [83]:
game_cls.move(0)
game_cls.render()


+---+---+---+
| O |   | O |
+---+---+---+
|   | X |   |
+---+---+---+
|   |   | X |
+---+---+---+


In [84]:
game_cls.move(1)
game_cls.render()


+---+---+---+
| O | X | O |
+---+---+---+
|   | X |   |
+---+---+---+
|   |   | X |
+---+---+---+


In [85]:
result = mcts_policy.get_all_Qs(game_cls.get_state(), game_cls.player, game_cls.get_actions())
result

{3: -5.0, 5: -5.0, 6: -5.0, 7: -0.015090543259557344}

In [86]:
game_cls.move(7)
game_cls.render()


+---+---+---+
| O | X | O |
+---+---+---+
|   | X |   |
+---+---+---+
|   | O | X |
+---+---+---+


In [87]:
game_cls.move(3)
game_cls.render()


+---+---+---+
| O | X | O |
+---+---+---+
| X | X |   |
+---+---+---+
|   | O | X |
+---+---+---+


In [88]:
result = mcts_policy.get_all_Qs(game_cls.get_state(), game_cls.player, game_cls.get_actions())
result

{5: 0.0, 6: -3.3333333333333335}

In [89]:
game_cls.move(5)
game_cls.render()


+---+---+---+
| O | X | O |
+---+---+---+
| X | X | O |
+---+---+---+
|   | O | X |
+---+---+---+


In [90]:
game_cls.move(6)
game_cls.render()


+---+---+---+
| O | X | O |
+---+---+---+
| X | X | O |
+---+---+---+
| X | O | X |
+---+---+---+


In [91]:
result = mcts_policy.get_all_Qs(game_cls.get_state(), game_cls.player, game_cls.get_actions())
result

{}