# Monte-Carlo Tree Search


In this notebook, we'll implement MCTS planning and use it to solve a Gym environment.

![image.png](https://i.postimg.cc/6QmwnjPS/image.png)

__How does it work?__
We start with an empty tree and expand it step by step. Each iteration repeats four common procedures.

__1) Selection__
Starting from the root, recursively select the node that corresponds to the tree policy.  

There are several options for the tree policy, which we saw earlier as exploration strategies: epsilon-greedy, Thompson sampling, UCB-1. UCB-1 has been shown to work well in MCTS, so we will use it below, but you can try the others.

Following the UCB-1 tree policy, we choose an action that, on the one hand, we expect to have the highest return and, on the other hand, we haven't explored much.

$$
\DeclareMathOperator*{\argmax}{arg\,max}
$$

$$
\dot{a} = \argmax_{a} \dot{Q}(s, a)
$$

$$
\dot{Q}(s, a) = Q(s, a) + C_p \sqrt{\frac{2 \log {N}}{n_a}}
$$

where: 
- $N$ - number of times we have visited state $s$,
- $n_a$ - number of times we have taken action $a$ in state $s$,
- $C_p$ - exploration parameter, which controls the balance between exploration and exploitation.

Using the Hoeffding inequality for rewards $R \in [0,1]$, it can be shown that $C_p = 1/\sqrt{2}$ is a sound choice. For rewards outside this range, the parameter should be tuned. We'll be using 10, but you can experiment with other values.
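
To make the formula concrete, here is a minimal sketch that scores the actions of a single state. The function name `ucb1_scores` and the statistics passed to it are made up for illustration, and we assume that $N$ equals the sum of the per-action counts. In the `Node` class later in this notebook, the same role is played by the parent's visit counter.

```python
import math


def ucb1_scores(q_values, action_counts, c_p=1 / math.sqrt(2)):
    """Return the UCB-1 score of every action in one state.

    q_values[i]      : mean return observed so far for action i
    action_counts[i] : n_a, how many times action i was taken
    """
    total_visits = sum(action_counts)  # N (assumption: N = sum of all n_a)
    scores = []
    for q, n_a in zip(q_values, action_counts):
        if n_a == 0:
            # an untried action gets an infinite score, so it is tried first
            scores.append(math.inf)
        else:
            bonus = c_p * math.sqrt(2 * math.log(total_visits) / n_a)
            scores.append(q + bonus)
    return scores


# Hypothetical statistics: action 0 looks slightly better but is well explored.
scores = ucb1_scores(q_values=[0.6, 0.5], action_counts=[90, 10])
best_action = max(range(len(scores)), key=scores.__getitem__)
print(scores, best_action)
```

With these numbers the rarely tried action 1 wins, because its exploration bonus outweighs the small gap in mean return.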

__2) Expansion__
After the selection procedure, we reach either a leaf node or a node whose actions have not all been tried. In this case, we expand the tree with the feasible actions and get new state nodes. 

__3) Simulation__
How can we estimate node Q-values? The idea is to estimate action values for a given _rollout policy_ by averaging the returns of many simulated trajectories from the current node. In the simplest case, we play with a random policy, but we can also use a hand-crafted policy or a model that estimates the value directly.
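
As a rough illustration of averaging rollouts, the sketch below estimates the value of the start state of a toy game under a uniform random policy. The game and the names `toy_rollout` and `monte_carlo_value` are hypothetical and are not used elsewhere in this notebook.

```python
import random


def toy_rollout(rng, t_max=1000):
    """Play one episode of a toy game with a uniform random policy.

    Hypothetical game: action 1 earns reward 1 and the episode continues,
    action 0 ends the episode with no reward.
    """
    total = 0.0
    for _ in range(t_max):
        action = rng.choice((0, 1))
        if action == 0:
            break
        total += 1.0
    return total


def monte_carlo_value(n_rollouts, seed=0):
    """Average the returns of many independent rollouts."""
    rng = random.Random(seed)
    returns = [toy_rollout(rng) for _ in range(n_rollouts)]
    return sum(returns) / n_rollouts


for n in (10, 1000, 20000):
    print(n, round(monte_carlo_value(n), 3))
```

For this game the true expected return is 1 (it solves $V = 0.5 \cdot (1 + V)$), so the printed estimates should move closer to 1 as the number of rollouts grows.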

__4) Backpropagation__
The return of the last simulation is backed up through the traversed nodes, propagating Q-value estimates upwards to the root.

$$
Q({\text{parent}}, a) = r + \gamma \cdot Q({\text{child}}, a)
$$
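
The backup rule can be traced by hand on a short path. The helper below is only a sketch: `backup` is a hypothetical name, and it works on a plain list of rewards rather than on the tree nodes defined later.

```python
def backup(path_rewards, rollout_return, gamma=1.0):
    """Back up a rollout return from the leaf to the top of the path.

    path_rewards   : immediate reward of each node on the path, top node first
    rollout_return : return of the simulation started at the leaf
    Returns the backed-up value of each node, in the same order.
    """
    values = []
    value = rollout_return
    for reward in reversed(path_rewards):  # walk from the leaf upwards
        value = reward + gamma * value     # Q(parent) = r + gamma * Q(child)
        values.append(value)
    values.reverse()
    return values


print(backup([1.0, 1.0, 1.0], rollout_return=5.0, gamma=0.5))
# [2.375, 2.75, 3.5]
```

The leaf gets $1 + 0.5 \cdot 5 = 3.5$, its parent $1 + 0.5 \cdot 3.5 = 2.75$, and the top node $1 + 0.5 \cdot 2.75 = 2.375$.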

In [None]:
import sys, os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

We first need to make a wrapper for Gym environments to allow saving and loading game states to facilitate backtracking.

In [None]:
import gym
from gym.core import Wrapper
from pickle import dumps, loads
from collections import namedtuple


ActionResult = namedtuple(
    "action_result", ("snapshot", "observation", "reward", "is_done", "info"))


class WithSnapshots(Wrapper):
    # Creates a wrapper that supports saving and loading environment states.
    # Required for planning algorithms.


    def get_snapshot(self, render=False):
        # returns environment state that can be loaded with load_snapshot.
        
        
        if render:
            self.render()  
            self.close()
            
        if self.unwrapped.viewer is not None:
            self.unwrapped.viewer.close()
            self.unwrapped.viewer = None
        return dumps(self.env)
    

    def load_snapshot(self, snapshot, render=False):
        # loads snapshot as current env state.
        
        assert not hasattr(self, "_monitor") or hasattr(
            self.env, "_monitor"), "can't backtrack while recording"

        if render:
            self.render()  
            self.close()
        self.env = loads(snapshot)
        

    def get_result(self, snapshot, action):
        # returns next snapshot and everything that env.step would have returned.
        
        self.load_snapshot(snapshot, render=False)
        next_observation, reward, is_done, info = self.step(action)
        next_snapshot = self.get_snapshot()
        return ActionResult(next_snapshot, next_observation, reward, is_done, info)

### Try out snapshots

In [None]:
# make env
env = WithSnapshots(gym.make("CartPole-v0"))
env.reset()

n_actions = env.action_space.n

In [None]:
print("initial_state:")
plt.imshow(env.render('rgb_array'))
env.close()

# create first snapshot
snap0 = env.get_snapshot()

In [None]:
# play without making snapshots (faster)
while True:
    is_done = env.step(env.action_space.sample())[2]
    if is_done:
        print("Whoops! We died!")
        break

print("final state:")
plt.imshow(env.render('rgb_array'))
env.close()

In [None]:
# reload initial state
env.load_snapshot(snap0)

print("\n\nAfter loading snapshot")
plt.imshow(env.render('rgb_array'))
env.close()

In [None]:
# get outcome (snapshot, observation, reward, is_done, info)
res = env.get_result(snap0, env.action_space.sample())

snap1, observation, reward = res[:3]

# second step
res2 = env.get_result(snap1, env.action_space.sample())

# MCTS: Monte-Carlo Tree Search


We now implement the `Node` class, a simple class that acts as an MCTS node and supports some of the MCTS algorithm steps.

This MCTS implementation makes some assumptions about the environment; you can find them _in the notes section at the end of the notebook_.

### Implementation Overview

In the following cell, we'll implement the `Node` class which represents each state in the MCTS tree. Each node stores:

- **snapshot**: Environment state for backtracking
- **observation**: The observation at this state  
- **immediate_reward**: Reward received when reaching this node
- **is_done**: Whether this is a terminal state
- **qvalue_sum**: Sum of all Q-values from rollouts passing through this node
- **times_visited**: Number of times this node has been visited
- **children**: Set of child nodes (one per action)

The key insight of MCTS is that we don't need to expand the entire game tree. Instead, we focus computational resources on the most promising paths using the UCB-1 exploration strategy.


In [None]:
assert isinstance(env, WithSnapshots)

In [None]:
class Node:
    # a tree node for MCTS.
    
    parent = None  # parent Node
    qvalue_sum = 0.  # sum of Q-values from all visits 
    times_visited = 0  # counter of visits 

    
    def __init__(self, parent, action):
        # Creates an empty node with no children.
        
        self.parent = parent
        self.action = action
        self.children = set()  # set of child nodes

        # get action outcome and save it
        res = env.get_result(parent.snapshot, action)
        self.snapshot, self.observation, self.immediate_reward, self.is_done, _ = res
        

        
    def is_leaf(self):
        return len(self.children) == 0
    
    

    def is_root(self):
        return self.parent is None
    
    

    def get_qvalue_estimate(self):
        return self.qvalue_sum / self.times_visited if self.times_visited != 0 else 0
    

    
    def ucb_score(self, scale=10, max_value=1e100):
        # param scale: multiplier for the exploration bonus. From the Hoeffding inequality, assumes the reward range to be [0, scale].
        # param max_value: a value that represents infinity (for unvisited nodes).

        if self.times_visited == 0:
            return max_value

        # ==================================== Your Code (Begin) ==================================
        
        # Calculate UCB-1 score: Q(s,a) + C_p * sqrt(2 * log(N) / n_a)
        # where N is parent's visit count and n_a is this node's visit count
        
        q_value = self.get_qvalue_estimate()
        
        # Calculate exploration bonus
        exploration_bonus = scale * np.sqrt(2 * np.log(self.parent.times_visited) / self.times_visited)
        
        ucb_score = q_value + exploration_bonus
        
        return ucb_score
        
        # ==================================== Your Code (End) ====================================

        
        
        
    
    def select_best_leaf(self):
        
        
        
        # ==================================== Your Code (Begin) ==================================
        
        # Return the leaf with the highest priority to expand.
        # Recursively pick nodes with the best UCB-1 score until it reaches a leaf.
        
        # If this is a leaf node, return itself
        if self.is_leaf():
            return self
        
        # Otherwise, select the child with the best UCB score and recurse
        best_child = max(self.children, key=lambda child: child.ucb_score())
        
        return best_child.select_best_leaf()
        
        # ==================================== Your Code (End) ====================================

        
        
        

    def expand(self):
        # expands the current node by creating all possible child nodes.
        # returns one of those children.
        
        assert not self.is_done, "can't expand from terminal state"

        for action in range(n_actions):
            self.children.add(Node(self, action))

        return self.select_best_leaf()
    
    

    def rollout(self, t_max=10**4):
        
        # set env into the appropriate state
        env.load_snapshot(self.snapshot)
        obs = self.observation
        is_done = self.is_done
        
        
        # ==================================== Your Code (Begin) ==================================
        
        # If node is terminal, just return 0
        if self.is_done:
            return 0
        
        # Play the game from this state to the end (done) or for t_max steps.
        # On each step, pick action at random.
        
        rollout_reward = 0
        
        for _ in range(t_max):
            if is_done:
                break
            
            # Take a random action
            action = env.action_space.sample()
            obs, reward, is_done, _ = env.step(action)
            
            # Accumulate the reward
            rollout_reward += reward
        
        return rollout_reward  

        
        # ==================================== Your Code (End) ====================================


        
        
        
        
    def propagate(self, child_qvalue):
        # Uses child Q-value (sum of rewards) to update parents recursively.
        
        # compute node Q-value
        my_qvalue = self.immediate_reward + child_qvalue

        # update qvalue_sum and times_visited
        self.qvalue_sum += my_qvalue
        self.times_visited += 1

        # propagate upwards
        if not self.is_root():
            self.parent.propagate(my_qvalue)

            
            
    def safe_delete(self):
        # safe delete to prevent memory leak in some python versions 
        del self.parent
        for child in self.children:
            child.safe_delete()
            del child

### Understanding the Node Class Implementation

The `Node` class implements the core MCTS data structure with four key methods:

**1. `ucb_score()`**: Computes the Upper Confidence Bound (UCB-1) score to balance exploration and exploitation. The score is the sum of two terms:
- Q-value (exploitation): average return from past visits
- Exploration bonus: `C_p * sqrt(2 * log(N) / n_a)`, which encourages visiting less-explored nodes

**2. `select_best_leaf()`**: Recursively traverses the tree by selecting children with the highest UCB scores until reaching a leaf node (no children yet).

**3. `rollout()`**: Simulates a random playthrough from the current node to estimate its value. This uses a simple uniform random policy to play until the episode ends or `t_max` steps have passed.

**4. `propagate()`**: Backpropagates the rollout reward up the tree, updating Q-values and visit counts for all ancestor nodes.

Together with `expand()`, which creates a child node for every action, these methods cover the four phases of MCTS: Selection, Expansion, Simulation, and Backpropagation.


In [None]:
class Root(Node):
    def __init__(self, snapshot, observation):
        # creates special node that acts like tree root
        
        self.parent = self.action = None
        self.children = set()  # set of child nodes

        # root: load snapshot and observation
        self.snapshot = snapshot
        self.observation = observation
        self.immediate_reward = 0
        self.is_done = False

    @staticmethod
    def from_node(node):
        # initializes node as root
        root = Root(node.snapshot, node.observation)
        # copy data
        copied_fields = ["qvalue_sum", "times_visited", "children", "is_done"]
        for field in copied_fields:
            setattr(root, field, getattr(node, field))
        return root

## Main MCTS Loop


In [None]:
def plan_mcts(root, n_iters=10):
    
    # build tree with monte-carlo tree search for n_iters iterations
    
    
    # ==================================== Your Code (Begin) ==================================
    
    # For n_iters iterations, perform the MCTS algorithm:
    # 1. Selection: Select best leaf using UCB-1
    # 2. Expansion: Expand the leaf (if not done)
    # 3. Simulation: Perform rollout from the new node
    # 4. Backpropagation: Propagate the results upwards
    
    for _ in range(n_iters):
        # 1. Selection: Select the best leaf node to expand
        leaf = root.select_best_leaf()
        
        # 2. Expansion: If the leaf is not a terminal state, expand it
        if not leaf.is_done:
            leaf = leaf.expand()
        
        # 3. Simulation: Perform a rollout from the selected/expanded node
        rollout_reward = leaf.rollout()
        
        # 4. Backpropagation: Propagate the rollout reward back up the tree
        leaf.propagate(rollout_reward)
    
    # ==================================== Your Code (End) ====================================

### Understanding the MCTS Planning Loop

The `plan_mcts()` function implements the main MCTS algorithm that builds a search tree over multiple iterations:

**Each iteration performs 4 steps:**

1. **Selection**: Start from root and use UCB-1 to select the most promising path down to a leaf node
2. **Expansion**: If the leaf is not terminal, create child nodes for all possible actions
3. **Simulation (Rollout)**: Play out the game randomly from the expanded node to estimate its value
4. **Backpropagation**: Update Q-values and visit counts along the path from leaf to root

**Why this works:**
- UCB-1 ensures we explore promising but uncertain areas of the tree
- More iterations = better estimates of action values
- The tree grows asymmetrically toward high-reward regions
- After planning, we can select actions by choosing the child with the highest average reward (pure exploitation)
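
To see the four steps work together without Gym, here is a compact sketch of the same loop on a toy problem. Everything in it is an assumption made for illustration: `ToyNode`, `HORIZON`, and the game itself (5 steps, where action 1 pays reward 1 and action 0 pays nothing) are not part of this notebook's implementation.

```python
import math
import random

HORIZON = 5        # toy episode length (assumption)
ACTIONS = (0, 1)   # action 1 pays reward 1, action 0 pays nothing


class ToyNode:
    def __init__(self, parent=None, action=None):
        self.parent, self.action = parent, action
        self.depth = 0 if parent is None else parent.depth + 1
        self.reward = 0.0 if action is None else float(action)
        self.is_done = self.depth >= HORIZON
        self.children = []
        self.value_sum, self.visits = 0.0, 0

    def mean_value(self):
        return self.value_sum / self.visits if self.visits else 0.0

    def ucb(self, scale):
        if self.visits == 0:
            return math.inf  # unvisited children are tried first
        bonus = math.sqrt(2 * math.log(self.parent.visits) / self.visits)
        return self.mean_value() + scale * bonus


def select_leaf(node, scale):
    # 1. Selection: follow the best UCB-1 score down to a leaf
    while node.children:
        node = max(node.children, key=lambda child: child.ucb(scale))
    return node


def rollout(node, rng):
    # 3. Simulation: uniform random actions for the remaining steps
    return sum(rng.choice(ACTIONS) for _ in range(HORIZON - node.depth))


def backpropagate(node, value):
    # 4. Backpropagation: add immediate rewards on the way up (gamma = 1)
    while node is not None:
        value += node.reward
        node.value_sum += value
        node.visits += 1
        node = node.parent


def plan(root, n_iters, rng, scale=HORIZON):
    for _ in range(n_iters):
        leaf = select_leaf(root, scale)
        if not leaf.is_done:
            # 2. Expansion: one child per action, then pick one of them
            leaf.children = [ToyNode(leaf, a) for a in ACTIONS]
            leaf = select_leaf(leaf, scale)
        backpropagate(leaf, rollout(leaf, rng))


root = ToyNode()
plan(root, n_iters=500, rng=random.Random(0))
best = max(root.children, key=ToyNode.mean_value)
print(best.action, [round(c.mean_value(), 2) for c in root.children])
```

After planning, the child reached by action 1 should have the higher mean value, since it collects one more unit of reward than its sibling on otherwise identical subtrees.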


### Test the Implementation

Now let's test our MCTS implementation on the CartPole environment. We'll:
1. Initialize the environment and create a root node
2. Run MCTS planning to build the search tree
3. Execute actions by selecting children with the highest Q-values
4. Continue planning at each step to refine our estimates

The visualization will show how the agent balances the pole using MCTS planning. The total reward indicates how long the agent kept the pole balanced.


## Plan and execute

Here we use our MCTS implementation to plan actions and play an episode.

In [None]:
env = WithSnapshots(gym.make("CartPole-v0"))
root_observation = env.reset()
root_snapshot = env.get_snapshot()
root = Root(root_snapshot, root_observation)

In [None]:
# plan from root:
plan_mcts(root, n_iters=1000)

### Execution Strategy: Re-planning at Each Step

In this execution loop, we use MCTS in a **receding horizon** approach:

1. **Plan**: Build a search tree from the current state using MCTS
2. **Execute**: Select and execute the action with the highest average Q-value (exploitation only)
3. **Observe**: Receive the next state and reward from the environment
4. **Re-root**: Make the selected child the new root of the tree
5. **Re-plan**: Build the tree further from the new root

**Why re-plan at each step?**
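- In general, the environment may be stochastic or partially observable, so an old plan can go stale (this implementation itself assumes a deterministic MDP, see the notes at the end)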
- We get more accurate estimates as we get closer to a state
- Allows the agent to adapt to new information
- Pruning unrealized branches keeps memory manageable

**Note:** In the code, we grow the tree with 100 additional iterations at each step to maintain good action estimates as we progress through the episode.


In [None]:
from IPython.display import clear_output
from itertools import count
from gym.wrappers import Monitor

total_reward = 0  # sum of rewards
test_env = loads(root_snapshot)  # env used to show progress

for i in count():

    # ==================================== Your Code (Begin) ==================================
    
    # Select child with the highest mean Q-value (exploitation only, no exploration)
    best_child = max(root.children, key=lambda child: child.get_qvalue_estimate())
        
    # ==================================== Your Code (End) ====================================

    # take action
    s, r, done, _ = test_env.step(best_child.action)

    # show image
    clear_output(True)
    plt.title("step %i" % i)
    plt.imshow(test_env.render('rgb_array'))
    plt.show()

    total_reward += r
    if done:
        print("Finished with reward = ", total_reward)
        break

    # discard unrealized part of the tree (because not every child matters :()
    for child in root.children:
        if child != best_child:
            child.safe_delete()

    # declare best child a new root
    root = Root.from_node(best_child)

    assert not root.is_leaf(), \
        "We ran out of tree! Need more planning! Try growing the tree right inside the loop."
    plan_mcts(root,n_iters=100)

## Notes


#### Assumptions

The full list of assumptions is:

* __Finite number of actions__: we enumerate all actions in `expand`.
* __Episodic (finite) MDP__: while technically it works for infinite MDPs, we perform a rollout for $10^4$ steps. If you know your MDP is infinite, please adjust `t_max` to something more reasonable.
* __Deterministic MDP__: `Node` represents the single outcome of taking `self.action` in `self.parent`, and does not support the situation where taking an action in a state may lead to different rewards and next states.
* __No discounted rewards__: we assume $\gamma=1$. If that isn't the case, you only need to change two lines in `rollout()` and use `my_qvalue = self.immediate_reward + gamma * child_qvalue` for `propagate()`.
* __Pickleable env__: won't work if e.g. your env is connected to a web browser surfing the internet. For custom envs, you may need to modify `get_snapshot`/`load_snapshot` in `WithSnapshots`.
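
For the discounted case, the change inside `rollout()` could look roughly like the accumulation below. This is a sketch on a plain list of rewards. The name `discounted_return` and the `discount` variable are hypothetical, and the comments mark the lines that would differ from the undiscounted loop.

```python
def discounted_return(rewards, gamma):
    """Sum rewards with weights 1, gamma, gamma**2, ..."""
    total, discount = 0.0, 1.0
    for reward in rewards:
        total += discount * reward  # instead of `rollout_reward += reward`
        discount *= gamma           # extra line: later rewards count less
    return total


print(discounted_return([1.0, 1.0, 1.0], gamma=0.5))  # 1 + 0.5 + 0.25 = 1.75
```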

#### On `select_best_leaf` and `expand` functions

This MCTS implementation only selects leaf nodes for expansion.
This doesn't break anything because `expand` adds all possible actions. Hence, all non-leaf nodes are fully expanded by design and shouldn't be selected.

If you want to add only a few random actions on each expand, you will also have to modify `select_best_leaf` to consider returning non-leaf nodes.
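
One possible shape for such a variant is sketched below. It is not the notebook's implementation: `PartialNode`, `expand_one`, and `select_expandable` are hypothetical names, the scoring function is passed in by the caller, and terminal states are ignored to keep the example short.

```python
import random


class PartialNode:
    """Toy node that remembers which actions it has not tried yet."""

    def __init__(self, actions, parent=None, action=None):
        self.parent, self.action = parent, action
        self.all_actions = tuple(actions)
        self.untried_actions = list(actions)
        self.children = []

    def is_fully_expanded(self):
        return not self.untried_actions

    def expand_one(self, rng):
        # add a single child for one randomly chosen untried action
        index = rng.randrange(len(self.untried_actions))
        action = self.untried_actions.pop(index)
        child = PartialNode(self.all_actions, parent=self, action=action)
        self.children.append(child)
        return child


def select_expandable(node, score):
    # Descend only through fully expanded nodes, so an inner node
    # that still has untried actions can be returned for expansion.
    while node.is_fully_expanded() and node.children:
        node = max(node.children, key=score)
    return node


rng = random.Random(0)
root = PartialNode(actions=(0, 1, 2))
for _ in range(3):
    node = select_expandable(root, score=lambda child: child.action)
    node.expand_one(rng)  # the root is returned until all its actions are tried
picked = select_expandable(root, score=lambda child: child.action)
print(picked.action)
```

In a real tree the `score` argument would be the UCB-1 score, and the stand-in used here simply prefers the child with the largest action index.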

#### Rollout policy

We use a simple uniform policy for rollouts. This introduces a negative bias against good situations that a single bad random action can ruin completely. As a simple example, if you act according to a uniform random policy, you had better not handle sharp knives or walk near cliffs.

You can improve on that by integrating a reinforcement _learning_ algorithm with a computationally light agent. You can even train this agent on the policy found by the tree search.
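
Short of training an agent, even a hand-written heuristic can make rollouts less pessimistic. The sketch below is one hypothetical option for CartPole, not something this notebook uses. It assumes the standard observation layout (cart position, cart velocity, pole angle, pole angular velocity) and the standard actions (0 pushes left, 1 pushes right), and the function name and `epsilon` parameter are made up.

```python
import random


def heuristic_rollout_action(observation, rng, epsilon=0.1):
    """Push the cart towards the side the pole is falling to.

    With probability epsilon a uniform random action is taken instead,
    so rollouts keep some diversity.
    """
    _, _, pole_angle, pole_velocity = observation
    if rng.random() < epsilon:
        return rng.choice((0, 1))
    # a positive angle or angular velocity means the pole leans or falls right
    return 1 if pole_angle + pole_velocity > 0 else 0


rng = random.Random(0)
print(heuristic_rollout_action((0.0, 0.0, 0.05, 0.0), rng, epsilon=0.0))   # 1
print(heuristic_rollout_action((0.0, 0.0, -0.05, 0.0), rng, epsilon=0.0))  # 0
```

To try it, you would replace `env.action_space.sample()` in `rollout()` with a call to this function on the latest observation.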