In [1]:
import copy
import numpy as np

In [2]:
class TreeNode:
    """
    A tree node in the MCTS. Each node keeps track of its own value Q, prior probability P,
    and its visit-count-adjusted prior score u.

    NOTE: returns are not discounted (gamma = 1.0).
    """

    def __init__(self, parent, prior_p):
        self._parent = parent      # None for the root node
        self._children = {}        # action -> TreeNode
        self._n_visits = 0
        self._Q = 0                # running mean of the values backed up through this node
        self._P = prior_p          # prior probability given to this node by its parent's expand()
        self._u = 0                # exploration bonus from the most recent get_value() call

    def expand(self, action_priors):
        """
        Expand the tree by creating new children.

        action_priors: an iterable of (action, prior probability) pairs, e.g. produced by the
            policy function. Actions that already have a child are left untouched.
        """
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

    def select(self, c_puct):
        """
        Select the child that maximises its action value Q plus bonus u(P).

        Returns: an (action, child_node) tuple.
        Raises: ValueError if this node has no children (max() of an empty sequence).
        """
        return max(self._children.items(), key=lambda act_node: act_node[1].get_value(c_puct))

    def update(self, G):
        """
        Update this node's statistics from a Monte-Carlo evaluation with return G.

        Q is maintained as the running mean of every return seen so far:
        Q <- Q + (G - Q) / n, where n is the visit count after incrementing.
        """
        self._n_visits += 1
        self._Q += (G - self._Q) / self._n_visits

    def update_recursive(self, G):
        """
        Update this node and all of its ancestors, starting from the root.

        The sign of G is flipped at each step up the tree, so a parent is credited with the
        negation of its child's value (alternating turns).
        """
        if self._parent is not None:
            self._parent.update_recursive(-G)
        self.update(G)

    def get_value(self, c_puct):
        """
        Calculate and return the value for this node.

        It is a combination of the leaf evaluation Q and this node's prior adjusted for its
        visit count, u = c_puct * P * sqrt(N_parent) / (1 + N).
        c_puct: a number in (0, inf) controlling the relative impact of value Q and prior
            probability P on this node's score.
        Must only be called on non-root nodes, since it reads the parent's visit count.
        """
        self._u = (c_puct * self._P * np.sqrt(self._parent._n_visits)) / (1 + self._n_visits)
        return self._Q + self._u

    def is_leaf(self):
        """Return True if this node has no children yet (i.e. it has not been expanded)."""
        return not self._children

    def is_root(self):
        """Return True if this node has no parent."""
        return self._parent is None

In [3]:
class MCTS:
    """A simple implementation of Monte Carlo Tree Search."""

    def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
        """
        policy_value_fn: a function that takes in a board state and returns action priors.
            Its output is passed unchanged to TreeNode.expand, which unpacks
            (action, probability) pairs. NOTE(review): if the policy instead yields
            (action, probability, q) triples, expand will fail to unpack them -- confirm.
        c_puct: a number in (0, inf) that controls how quickly exploration converges to the
            maximum-value policy.
        n_playout: number of playouts run by each call to get_move.
        """
        self._root = TreeNode(parent=None, prior_p=1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct
        self._n_playout = n_playout

    def _playout(self, state):
        """
        Run a single playout from the root to a leaf, get a value at the leaf and propagate it
        back through its parents. State is modified in-place, so a copy must be provided.

        NOTE(review): assumes `state` exposes do_move(action) and game_end() returning
        (end, winner); confirm against the board class.
        """
        node = self._root
        # Selection: descend greedily until reaching a node that has not been expanded yet.
        # (_children is checked directly, as get_move/update_with_move already do.)
        while node._children:
            action, node = node.select(self._c_puct)
            state.do_move(action)

        # Expansion: only non-terminal leaves get children.
        end, _ = state.game_end()
        if not end:
            action_probs = self._policy(state)
            node.expand(action_probs)

        # Simulation + backpropagation. The leaf value is negated before backing up,
        # matching the sign convention of TreeNode.update_recursive.
        leaf_value = self._evaluate_rollout(state)
        node.update_recursive(-leaf_value)

    def _evaluate_rollout(self, state, limit=1000):
        """
        Use the rollout policy to play until the end of the game (at most `limit` moves),
        returning 1 if the current player wins, -1 if the opponent wins and 0 if it is a tie.

        Not implemented yet. Raising here, rather than silently returning None, makes the gap
        explicit. A None result would otherwise fail with an opaque TypeError at `-leaf_value`
        in _playout.
        """
        raise NotImplementedError("MCTS._evaluate_rollout: rollout policy is not implemented yet")

    def get_move(self, state):
        """
        Run all playouts sequentially and return the most visited action at the root.

        state: the current game state. Each playout works on its own deep copy, so `state`
            itself is never modified.
        Raises: ValueError if the root still has no children after the playouts
            (e.g. n_playout == 0).
        """
        for _ in range(self._n_playout):
            state_copy = copy.deepcopy(state)
            self._playout(state_copy)
        return max(self._root._children.items(), key=lambda act_node: act_node[1]._n_visits)[0]

    def update_with_move(self, last_move):
        """
        Step forward in the tree, keeping everything we already know about the subtree.

        If last_move is not a child of the current root (e.g. -1), the tree is discarded
        and search restarts from a fresh root.
        """
        if last_move in self._root._children:
            self._root = self._root._children[last_move]
            self._root._parent = None
        else:
            self._root = TreeNode(None, 1.0)

In [4]:
class Agent:
    """
    Game-playing agent whose moves are chosen by a Monte Carlo Tree Search.
    """

    def __init__(self, policy_value_fn, c_puct=5, n_playout=2000):
        # The search object owns the tree and every search hyperparameter.
        self.mcts = MCTS(policy_value_fn, c_puct, n_playout)

    def set_player_ind(self, p):
        """Remember which player this agent plays as (stored in self.player)."""
        self.player = p

    def reset_player(self):
        """
        Restart the search by asking it to step to move -1.

        Presumably -1 is never a legal action, so update_with_move starts over from a
        fresh tree -- confirm against the action encoding.
        """
        self.mcts.update_with_move(-1)

    def get_action(self, board):
        """
        Search from `board` and return the chosen move.

        When board.availables is empty, a warning is printed and None is returned.
        After every successful move the search tree is reset via update_with_move(-1).
        """
        if not len(board.availables):
            print("WARNING: the board is full")
            return None
        chosen = self.mcts.get_move(board)
        self.mcts.update_with_move(-1)
        return chosen