In [None]:
import gym
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

## Seminar: Monte-Carlo tree search

In [None]:
from gym.core import Wrapper
from pickle import dumps,loads
from collections import namedtuple

# Record returned by WithSnapshots.get_result. It unpacks like a plain tuple
# (snap, obs, r, done, info = res) but its fields can also be read by name (res.reward).
ActionResult = namedtuple(
    "action_result",
    ["snapshot", "observation", "reward", "is_done", "info"],
)


class WithSnapshots(Wrapper):
    """Environment wrapper that can save and restore the full state of the wrapped env.

    Snapshots are produced by pickling the wrapped environment object, so each
    snapshot is an independent copy that later steps cannot modify.
    """

    def __init__(self, env):
        """
        Creates a wrapper that supports saving and loading environment states.
        Required for planning algorithms.

        This class will have access to the core environment as self.env, e.g.:
        - self.env.reset()           #reset original env
        - self.env.ale.cloneState()  #make snapshot for atari. load with .restoreState()
        - ...

        You can also use reset, step and render directly for convenience.
        - s, r, _, _ = self.step(action)   #step, same as self.env.step(action)
        - self.render(close=True)          #close window, same as self.env.render(close=True)

        """
        Wrapper.__init__(self, env)

    def get_snapshot(self):
        """
        :returns: environment state that can be loaded with load_snapshot
        Snapshots guarantee same env behaviour each time they are loaded.

        Warning! Snapshots can be arbitrary things (strings, integers, json, tuples)
        Don't count on them being pickle strings when implementing MCTS.

        Developer Note: Make sure the object you return will not be affected by
        anything that happens to the environment after it's saved.
        You shouldn't, for example, return self.env.
        In case of doubt, use pickle.dumps or deepcopy.

        """
        self.render(close=True)  # close popup windows since we can't pickle them
        return dumps(self.env)

    def load_snapshot(self, snapshot):
        """
        Loads snapshot as current env state.
        Should not change snapshot inplace (in case of doubt, deepcopy).

        Security note: snapshots are unpickled, which can execute arbitrary code.
        Only load snapshots produced by get_snapshot, never untrusted data.
        """
        # Backtracking would corrupt a recording made by a monitor wrapper.
        assert not hasattr(self, "_monitor") or hasattr(self.env, "_monitor"), "can't backtrack while recording"

        self.render(close=True)  # close popup windows since we can't load into them
        self.env = loads(snapshot)

    def get_result(self, snapshot, action):
        """
        A convenience function that
        - loads snapshot,
        - commits action via self.step,
        - and takes snapshot again :)

        :returns: next snapshot, next_observation, reward, is_done, info

        Basically it returns next snapshot and everything that env.step would have returned.
        """
        self.load_snapshot(snapshot)
        # self.step forwards to self.env, which load_snapshot has just replaced
        obs, r, done, info = self.step(action)
        new_snapshot = self.get_snapshot()

        return ActionResult(new_snapshot, obs, r, done, info)
    


### Try out snapshots


In [None]:
# Wrap CartPole-v0 so its full state can be saved and restored (planning needs this).
env = WithSnapshots(gym.make("CartPole-v0"))

# Discrete action space: MCTS expands one child node per action.
n_actions = env.action_space.n

env.reset()

In [None]:
print("initial_state:")

# draw the current environment state as an RGB image
plt.imshow(env.render('rgb_array'))

# create the first snapshot: a saved copy of the env state that load_snapshot can restore later
snap0 = env.get_snapshot()

In [None]:
# Take random actions without saving snapshots (cheaper) until the episode terminates.
is_done = False
while not is_done:
    is_done = env.step(env.action_space.sample())[2]
print("Whoops! We died!")

print("final state:")
plt.imshow(env.render('rgb_array'))
plt.show()

# Restore the state saved in snap0: the picture below should match the initial one.
env.load_snapshot(snap0)
print("\n\nAfter snapshot")
plt.imshow(env.render('rgb_array'))
plt.show()

In [None]:
# get outcome (snapshot, observation, reward, is_done, info)
res = env.get_result(snap0, env.action_space.sample())

# like tuple:
snap1, observation, reward = res[:3]

# like class:
is_done = res.is_done

# second step, continuing from the snapshot returned by the first one.
# get_result is a method of the wrapper, so it has to be called on env.
res2 = env.get_result(snap1, env.action_space.sample())

# MCTS: Monte-Carlo tree search

In this section, we'll implement the vanilla MCTS algorithm with UCB1-based node selection.

We will start by implementing the `Node` class - a simple class that acts as an MCTS node and implements the individual steps of the algorithm: selection, expansion, simulation (rollout) and backpropagation.

__assumptions__:
* Deterministic state transition
* Deterministic rewards

In [None]:
# sanity check: the MCTS code below calls env.get_result / env.load_snapshot, which the WithSnapshots wrapper provides
assert isinstance(env,WithSnapshots)

In [None]:
class Node:
    """A tree node for MCTS.

    A node stores the outcome of committing `action` from its parent's state
    (snapshot, observation, immediate reward, terminal flag), plus the visit
    statistics used for UCB-1 selection and backpropagation.

    Relies on the notebook-level `env` (a WithSnapshots wrapper) and `n_actions`.
    `children` is a set, so child iteration order (and tie-breaking in max) is arbitrary.
    """

    def __init__(self, parent, action):
        """
        Creates an empty node with no children.
        Does so by committing an action from the parent's state and recording the outcome.

        :param parent: parent Node. Must not be None (parent.snapshot is read);
            use Root for the top of the tree.
        :param action: action to commit from parent state
        """
        self.parent = parent
        self.action = action

        # metadata:
        self.value_sum = 0.         # sum of returns from all visits (numerator)
        self.times_visited = 0      # number of visits (denominator)

        self.children = set()       # set of child nodes
        self.visited_children = 0   # number of children selected for their first visit

        # get action outcome and save it
        res = env.get_result(parent.snapshot, action)
        self.snapshot, self.observation, self.immediate_reward, self.is_done, _ = res

    def is_leaf(self):
        return len(self.children) == 0

    def is_root(self):
        return self.parent is None

    def get_mean_value(self):
        """Average return over all visits, or 0 for a node that was never visited."""
        return self.value_sum / self.times_visited if self.times_visited != 0 else 0

    def ucb_score(self, max_value=1e100, scale=1.0):
        """
        Computes the UCB-1 upper bound from the visit counts of this node and its parent:
            mean_value + scale * sqrt(2 * ln(parent_visits) / own_visits)

        :param max_value: score given to never-visited nodes, so every child is tried once
        :param scale: exploration coefficient. scale=1 is plain UCB-1, which is tuned for
            returns in [0, 1]. Returns here are summed rewards and may be much larger than 1,
            so a bigger scale may be needed for meaningful exploration.
        """
        if self.times_visited == 0:
            return max_value

        # propagate() always updates the parent together with the child, so
        # parent.times_visited >= self.times_visited >= 1 and the log is finite.
        U = scale * np.sqrt(2 * np.log(self.parent.times_visited) / self.times_visited)

        return self.get_mean_value() + U

    # MCTS steps

    def select_best_leaf(self):
        """
        Picks the leaf with highest priority to expand.
        Does so by recursively picking nodes with best UCB-1 score until it reaches the leaf.
        """
        if self.is_leaf():
            return self

        best_node = max(self.children, key=lambda child: child.ucb_score())

        # update visit counts
        if best_node.times_visited == 0:
            self.visited_children += 1

        return best_node.select_best_leaf()

    def expand(self):
        """Spawns one child per action, then returns the best leaf under this node."""

        assert not self.is_done, "can't expand from terminal state"

        for action in range(n_actions):
            self.children.add(Node(self, action))

        return self.select_best_leaf()

    def simulate(self, t_max=10**4):
        """
        Rolls out a random policy from this node's state for at most t_max steps.

        :returns: sum of rewards collected during the rollout (0 if the node is terminal)
        """
        # set env into appropriate state
        env.load_snapshot(self.snapshot)
        is_done = self.is_done

        rollout_reward = 0  # sum of all rewards from rollout

        for _ in range(t_max):
            if is_done:
                break
            _, r, is_done, _ = env.step(env.action_space.sample())
            rollout_reward += r

        return rollout_reward

    def propagate(self, child_value):
        """Propagates the return upward: adds this node's immediate reward, records it, recurses into the parent."""

        my_value = self.immediate_reward + child_value

        self.value_sum += my_value
        self.times_visited += 1

        if not self.is_root():
            self.parent.propagate(my_value)

    def safe_delete(self):
        """Detaches this node and its whole subtree to break parent/child reference cycles.

        The node must not be used afterwards (its `parent` attribute is deleted).
        """
        # `del child` inside a loop only unbinds the loop variable, so recurse
        # explicitly, then drop our references to the children.
        for child in self.children:
            child.safe_delete()
        self.children.clear()
        del self.parent

In [None]:
class Root(Node):
    """Special node at the top of the search tree: it stores a given state instead of committing an action."""

    def __init__(self, snapshot, observation):
        """
        Creates a special node that acts as the root.

        :snapshot: snapshot (from env.get_snapshot) to start planning from
        :observation: last environment observation
        """
        self.parent = self.action = None

        # metadata: same fields as Node
        self.value_sum = 0.         # sum of returns from all visits (numerator)
        self.times_visited = 0      # number of visits (denominator)

        self.children = set()       # set of child nodes
        self.visited_children = 0   # number of children selected for their first visit

        # root: load snapshot and observation
        self.snapshot = snapshot
        self.observation = observation
        # any reward for reaching this state has already been collected, so it is not re-counted
        self.immediate_reward = 0
        self.is_done = False

    @staticmethod
    def from_node(node):
        """Initializes a root from an existing node, keeping its subtree and statistics."""
        root = Root(node.snapshot, node.observation)

        copied_fields = ["value_sum", "times_visited", "children", "visited_children", "is_done"]

        for field in copied_fields:
            setattr(root, field, getattr(node, field))

        # Re-attach the children to the new root. Otherwise their `parent` still points
        # at the old node: propagate() would skip this root (freezing its statistics) and
        # keep walking up into the discarded tree, keeping it alive in memory.
        for child in root.children:
            child.parent = root

        return root

## Main MCTS loop

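With everything implemented above, the main MCTS loop boils down to a few lines: select the most promising leaf, expand it, roll out from one of its new children, and propagate the result back to the root. A terminal leaf simply propagates a zero return.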

In [None]:
def plan_mcts(root, n_iters=10):
    """
    Builds the tree with Monte-Carlo tree search for n_iters iterations.

    Each iteration selects the most promising leaf by UCB-1. A terminal leaf
    just propagates a zero return. Otherwise the leaf is expanded, a random
    rollout is run from the best new child, and the rollout return is
    propagated up the tree.

    :param root: tree node to plan from
    :param n_iters: how many select-expand-simulate-propagate loops to make
    """
    for _ in range(n_iters):

        node = root.select_best_leaf()

        if node.is_done:
            # terminal state: nothing to expand or roll out
            node.propagate(0)

        else:  # node is not terminal
            best_leaf = node.expand()
            rollout_reward = best_leaf.simulate()
            best_leaf.propagate(rollout_reward)

    # planning loaded snapshots into and stepped the shared env; reset it afterwards
    env.reset()

## Plan and execute

In [None]:
# Start a fresh episode and make its initial state the root of the search tree.
root_observation = env.reset()
root_snapshot = env.get_snapshot()
root = Root(snapshot=root_snapshot, observation=root_observation)

In [None]:
# plan from root: grow the search tree before taking any real action
plan_mcts(root, n_iters=1000)

In [None]:
from IPython.display import clear_output
from itertools import count

# Put the real env into the state the tree was planned from.
env.load_snapshot(root_snapshot)
total_reward = 0  # sum of rewards

for step in count():
    assert not root.is_leaf(), "We ran out of tree! Need more planning!"

    # get best child: the action with the highest average return found during planning
    best_child = max(root.children, key=lambda child: child.get_mean_value())

    # take action
    s, r, done, _ = env.step(best_child.action)

    # show image
    clear_output(True)
    plt.imshow(env.render('rgb_array'))
    plt.title("step %d" % step)
    plt.show()

    total_reward += r
    if done:
        print("Finished with reward = ", total_reward)
        break

    # discard unrealized part of the tree [because not every child matters :(]
    for child in root.children:
        if child != best_child:
            child.safe_delete()

    # declare best child a new root
    root = Root.from_node(best_child)
    

#### This is but a seminar...

Bonus assignments will come shortly.

__If you're eager,__
* Try updating the tree on the fly (e.g. keep planning from the new root after every real step).
* Try rewriting the WithSnapshots wrapper to work with Atari.
  * Atari has a special interface for snapshots:
    ```
    snapshot = self.env.ale.cloneState()
    ...
    self.env.ale.restoreState(snapshot)
    ```
  * Debug on ```gym.make("MsPacman-ramDeterministic-v0")```
 
 
 
Inspired by [this gist](https://gist.github.com/blole/dfebbec182e6b72ec16b66cc7e331110)