In [1]:
import math
import time
import random
from collections import defaultdict

## Offline Planning & Online Planning for MDPs

* Value iteration is an offline planning method since it solves the problem offline for all possible states and then uses the solution (a policy) online to act.

* In online planning (like AlphaZero), planning is undertaken immediately before executing an action.
    * For each state $s$ visited, the set of all available actions $A(s)$ is partially evaluated.
    * The quality of each action $a$ is approximated by averaging the expected reward of trajectories over $S$ obtained by repeated simulations, giving us an approximation for $Q(s,a)$.
    * The chosen action is $\text{argmax}_{a \in A(s)} Q(s,a)$

* In online planning, we need access to a simulator that approximates the transition function $P_a(s' \mid s)$ and reward function $r$ of our MDP.

## MCTS

#### Framework

The basic framework is to build up a tree using simulation. The states that have been evaluated are stored in a search tree. The set of evaluated states is **incrementally** built by iterating over the following four steps:

-   **Select**: Select a single node in the tree that is **not fully expanded**. By this, we mean at least one of its children is not yet explored.
    
-   **Expand**: Expand this node by applying one available action (as defined by the MDP) from the node.
    
-   **Simulate**: From one of the outcomes of the expanded node, perform a complete random simulation of the MDP to a terminating state. This therefore assumes that the simulation is finite, but versions of MCTS exist in which we just execute for some time and then estimate the outcome.
    
-   **Backpropagate**: Finally, the value of the node is backpropagated to the root node, updating the value of each ancestor node on the way using expected value.

#### Algorithm

1. **Selection**: The first loop progressively selects a branch in the tree using a multi-armed bandit algorithm using $Q(s,a)$. The outcome that occurs from an action is chosen according to $P_a(s' \mid s)$ defined in the MDP.

2. **Expansion**: Select an action $a$ to apply in state $s$, either randomly or using a heuristic. Get an outcome state $s'$ from applying action $a$ in state $s$ according to the probability distribution $P_a(s' \mid s)$. Expand a new environment node and a new state node for that outcome.

3. **Simulation**: Perform a randomised simulation of the MDP until we reach a terminating state. That is, at each choice point, randomly select a possible action from the MDP, and use transition probabilities $P_a(s' \mid s)$ to choose an outcome for each action. Heuristics can be used to improve the random simulation by guiding it towards more promising states. $G$ is the cumulative discounted reward received from the simulation starting at $s'$ until the simulation terminates.

4. **Backpropagation**: The reward from the simulation is backpropagated from the selected node to its ancestors recursively. We must not forget the discount factor! For each state $s$ and action $a$ selected in the Select step, update the cumulative reward of that state.

Once we have run out of computational time, we select the action that maximises our expected return, which is simply the one with the highest Q-value from our simulations:

$$\text{argmax}_{a \in A(s)} Q(s_0, a)$$ 

We execute that action and wait to see which outcome occurs for the action.

Once we see the outcome state, which we will call $s'$, we start the process all over again, except with $s_0 \leftarrow s'$.

## Implementation

In [2]:
class Node:
    """Base node of the MCTS search tree.

    Stores the MDP, the parent link, the state, and the Q-function and bandit
    objects handed to the constructor. ``select``, ``expand`` and
    ``back_propagate`` are placeholders at this level: they do nothing and
    return ``None``.
    """

    # Counter used to hand out a unique id per node, so that two nodes
    # holding the same state can still be told apart.
    next_node_id = 0

    # Visit counts shared by every node (class-level); unseen keys read as 0.
    visits = defaultdict(int)

    def __init__(self, mdp, parent, state, qfunction, bandit, reward=0.0, action=None):
        """Create a node and assign it the next unique id.

        Args:
            mdp: The MDP object; ``get_value`` calls its ``get_actions``.
            parent: The parent node, or ``None``.
            state: The state this node represents.
            qfunction: The Q function used to store state-action values.
            bandit: A multi-armed bandit for this node.
            reward: The immediate reward received for reaching this state,
                used for backpropagation.
            action: The action that generated this node.
        """
        # Claim the next free id, then advance the class-wide counter.
        self.id = Node.next_node_id
        Node.next_node_id += 1

        self.mdp = mdp
        self.parent = parent
        self.state = state
        self.qfunction = qfunction
        self.bandit = bandit
        self.reward = reward
        self.action = action

    def select(self):
        """Select a node that is not fully expanded (no-op in the base class)."""

    def expand(self):
        """Expand a node if it is not a terminal node (no-op in the base class)."""

    def back_propagate(self, reward, child):
        """Backpropagate the reward back to the parent node (no-op in the base class)."""

    def get_value(self):
        """Return the value of this node.

        This is the maximum Q-value reported by the Q function over the
        actions the MDP makes available in this node's state.
        """
        available_actions = self.mdp.get_actions(self.state)
        return self.qfunction.get_max_q(self.state, available_actions)

    def get_visits(self):
        """Get the number of visits recorded for this node's state."""
        return Node.visits[self.state]


class SingleAgentNode(Node):
    """MCTS tree node for a single-agent MDP.

    ``children`` maps each expanded action to a list of
    ``(child_node, probability)`` pairs, one pair per distinct outcome state
    that has been observed for that action so far.
    """

    def __init__(
        self,
        mdp,
        parent,
        state,
        qfunction,
        bandit,
        reward=0.0,
        action=None,
    ):
        super().__init__(mdp, parent, state, qfunction, bandit, reward, action)

        # A dictionary from actions to a list of (node, probability) pairs
        self.children = {}

    def is_fully_expanded(self):
        """Return True if and only if all child actions have been expanded."""
        valid_actions = self.mdp.get_actions(self.state)
        return len(valid_actions) == len(self.children)

    def select(self):
        """Select a node that is not fully expanded.

        Returns this node if it still has unexpanded actions or is terminal.
        Otherwise the bandit picks one of the expanded actions, an outcome of
        that action is sampled, and selection continues from that child.
        """
        if not self.is_fully_expanded() or self.mdp.is_terminal(self.state):
            return self

        actions = list(self.children)
        action = self.bandit.select(self.state, actions, self.qfunction)
        return self.get_outcome_child(action).select()

    def expand(self):
        """Expand a node if it is not a terminal node.

        Picks one not-yet-expanded action at random, registers it in
        ``children`` and returns the child for a sampled outcome of that
        action. A terminal node is returned unchanged.
        """
        if self.mdp.is_terminal(self.state):
            return self

        # Randomly select an unexpanded action to expand
        unexpanded = self.mdp.get_actions(self.state) - self.children.keys()
        action = random.choice(list(unexpanded))

        self.children[action] = []
        # Pass the single chosen action, not the whole set of candidates.
        return self.get_outcome_child(action)

    def back_propagate(self, reward, child):
        """Backpropagate the reward back to the parent node.

        Args:
            reward: The return observed below this node.
            child: The child node the reward came through; its ``action``
                identifies which state-action pair is updated.
        """
        action = child.action

        Node.visits[self.state] += 1
        Node.visits[(self.state, action)] += 1

        # Step towards the running average of observed returns:
        # delta = (reward - Q(s, a)) / N(s, a).
        # NOTE(review): presumably qfunction.update adds delta to the stored
        # Q-value; confirm against the Q-function implementation.
        q_value = self.qfunction.get_q_value(self.state, action)
        delta = (1 / Node.visits[(self.state, action)]) * (reward - q_value)
        self.qfunction.update(self.state, action, delta)

        if self.parent is not None:
            # Add the reward received on entering this node before passing up.
            self.parent.back_propagate(self.reward + reward, self)

    def get_outcome_child(self, action):
        """Simulate the outcome of an action, and return the child node.

        If the sampled outcome state already has a child under ``action``,
        that child is reused; otherwise a new child is created and recorded.
        ``action`` must already be a key of ``children``.
        """
        # Choose one outcome based on transition probabilities
        next_state, reward, _done = self.mdp.execute(self.state, action)

        # Find the corresponding state and return if this already exists
        for existing_child, _ in self.children[action]:
            if next_state == existing_child.state:
                return existing_child

        # This outcome has not occurred from this state-action pair previously
        new_child = SingleAgentNode(
            self.mdp, self, next_state, self.qfunction, self.bandit, reward, action
        )

        # Find the probability of this outcome (only possible for model-based)
        # for visualising tree. Falls back to 0.0 when the outcome is not
        # listed, so the child is always registered and returned.
        probability = 0.0
        for outcome, outcome_probability in self.mdp.get_transitions(self.state, action):
            if outcome == next_state:
                probability = outcome_probability
                break

        self.children[action].append((new_child, probability))
        return new_child

In [None]:
class MCTS:
    def __init__(self, mdp, qfunction, bandit):
        self.mdp = mdp
        self.qfunction = qfunction
        self.bandit = bandit

    """
    Execute the MCTS algorithm from ths initial state given, with timeout in seconds
    """
    
    def mcts(self, timeout=1, root_node=None):
        if root_node is None:
            root_node = self.create_root_node()
            
        start_time = time.time()
        current_time = time.time()
        while current_time < start_time + timeout:
            
            # Find a state node to expand
            selected_node = root_node.select()
            if not self.mdp.is_terminal(selected_node):
                
                