In [1]:
from mdp import *
import math, time, random
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt

#### Define a base Class for the tree nodes

In [6]:
class Node:

    # static counter for node IDs
    next_node_id = 0

    # static dictionary for recording the number of times each node in the tree has been visited
    visits = defaultdict(lambda: 0)

    def __init__(self, mdp, parent, state, qfunction, bandit, reward=0.0, action=None):
        self.mdp = mdp
        self.parent = parent
        self.state = state
        self.id = Node.next_node_id
        Node.next_node_id += 1

        # initialize Q function
        self.qfunction = qfunction

        # multi-armed bandit for node selection
        self.bandit = bandit

        # immediate reward received for transitioning into this state
        self.reward = reward

        # action that generated this node
        self.action = action

    # select a node that hasn't been fully expanded, i.e. leaf node
    def select(self): abstractmethod

    # expand a node if it is non-terminal state
    def expand(self): abstractmethod

    # backpropagate accumulate reward to the root node
    def backpropagate(self): abstractmethod

    # return V of this node
    def get_value(self):
        _, value = self.qfunction.get_maxQ(self.state, self.mdp.get_actions(self.state))



#### Define a class implementing the MCTS algorithm

In [7]:
class MCTS:
    def __init__(self, mdp, qfunction, bandit):
        self.mdp = mdp
        self.qfunction = qfunction
        self.bandit = bandit


    # performs mcts from specified root node (timeout in seconds)
    def mcts(self, timeout=1, root_node=None):
        # create a root node if none provided
        if root_node == None:
            root_node = self.create_root_node()

        # start the timer
        start_time = time.time()
        current_time = time.time()
        num_iterations = 0

        # perform mcts iterations until timeout
        while current_time < start_time + timeout:

            # select node for expansion
            selected_node = root_node.select()
            
            if not (self.mdp.is_exit(selected_node.state)):

                # expand the selected node to generate a child node (if the node is not a terminal state)
                child = selected_node.expand()

                # run simulation to get a reward
                reward = self.simulate(child)

                # backpropagate the reward to root node
                selected_node.backpropagate(reward, child) 

            current_time = time.time()      
            num_iterations += 1    


        print(f"MCTS iterations: {num_iterations}")

        # update value function and display the table
        self.qfunction.update_V_from_Q()
        self.qfunction.display()  

        return root_node


    # createa a root node representing the initial state
    def create_root_node(self): abstractmethod       


    # choose a random action for monte carlo simulation (can use a heuristic to choose actions instead of picking at random)
    def choose(self, state):
        actions = self.mdp.get_actions(state)
        
        # choose actions randomly
        next_action = random.choice(actions)
        
        # heuristic based:  choose action that can lead to a state with highest immediate reward
        '''
        max_reward = float("-inf")
        max_actions = [] 
        for action in actions:
            transitions = self.mdp.get_transitions(state, action)
            for (next_state, _) in transitions:
                reward = self.mdp.get_rewards(state, action, next_state)
                if reward > max_reward:
                    max_reward = reward
                    max_actions = [action]
                elif reward == max_reward:
                    max_actions.append(action)

        # if multiple actions lead to best reward state,. pick one randomly
        next_action = random.choice(max_actions)                 
        '''
        
        return next_action
    
    
    # run simulation until terminal state reached (can be stopped after a fixed number of time steps instead of running until reaching terminal state)
    def simulate(self, node):
        state = node.state
        cumulative_reward = 0.0
        depth = 0

        while not self.mdp.is_exit(state):
            # choose an action to execute
            action = self.choose(state)
            # transition to next state
            (next_state, reward) = self.mdp.execute(state, action)
            # discount the reward
            cumulative_reward += pow(self.mdp.gamma, depth) * reward 
            depth += 1
            state = next_state

        return cumulative_reward    
    



#### Implement derived classes for tree node and MCTS for single-agent MCTS

In [8]:
class SingleAgentNode(Node):
    
    def __init__(self, mdp, parent, state, qfunction, bandit, reward=0, action=None):
        super().__init__(mdp, parent, state, qfunction, bandit, reward, action)

        # a dictionary from actions to node-probability pairs
        self.children = {}


    # checks if a node has been fully expanded
    def is_fully_expanded(self):    
        actions = self.mdp.get_actions(self.state)
        if(len(actions) == len(self.children)):
            return True
        else:
            return False

    # recursively traverse the tree and select a node that has not been fully expanded yet
    def select(self):

        if not self.is_fully_expanded() or self.mdp.is_exit(self.state):
            return self
        else:
            actions = list(self.children.keys())
            action =  self.bandit.select(self.state, actions, self.qfunction)
            return self.get_outcome_child(action).select()   


    # expand a node if it's a non terminal-state
    def expand(self):
        if not self.mdp.is_exit(self.state):
            # randomly select an unexpanded action to expand
            unexplored_actions = self.mdp.get_actions(self.state) - self.children.keys()
            action = random.choice(list(unexplored_actions))
            # create a slot for that action in the children dictionary
            self.children[action] = []
            return self.get_outcome_child(action)
        # for terminal state, can't expand further
        return self


    def get_outcome_child(self, action):
        # choose one outcome state based on transition probabilities
        (next_state, reward) = self.mdp.execute(self.state, action)

        # check if this outcome state is already in the child set
        for (child, _) in self.children[action]:
            if next_state == child.state:
                return child

        # if this outcome did not occur before from this state-action pair, then create a new child node
        new_child = SingleAgentNode(self.mdp, self, next_state, self.qfunction, self.bandit, reward, action)     
        probability = 0.0
        for (outcome, probability) in self.mdp.get_transitions(self.state, action):
            if outcome == next_state:
                self.children[action] += [(new_child, probability)]
                return new_child       


    # backpropagate reward back to root node (recursively update all nodes along the path to the root)
    def backpropagate(self, G, child):
        # get the action which generated the child
        action = child.action

        # update number of times visited for both the state (white) node and state-action (black) node
        Node.visits[self.state] = Node.visits[self.state] + 1
        Node.visits[(self.state, action)] = Node.visits[(self.state, action)] + 1

        # get current Q value 
        qvalue = self.qfunction.evaluate(self.state, action)
        # compute update delta
        delta = (G - self.qfunction.evaluate(self.state, action)) / Node.visits[(self.state, action)]
        # update the Q value
        self.qfunction.update(self.state, action, qvalue, delta)

        # recursively backpropagate until root node is reached
        if self.parent != None:
            self.parent.backpropagate(self.reward + G, self)


class SingleAgentMCTS(MCTS):

    def __init__(self, mdp, qfunction, bandit):
        super().__init__(mdp, qfunction, bandit)


    def create_root_node(self, root_state=None):
        if root_state == None:
            return SingleAgentNode(self.mdp, None, self.mdp.get_initial_state(), self.qfunction, self.bandit)
        else:
            return SingleAgentNode(self.mdp, None, root_state, self.qfunction, self.bandit)

#### Let's test the MCTS on the gridworld problem

In [5]:
'''
# instantiate grid world mdp object
gw = GridWorld(discount_factor=0.9)

# instantaiate Q table
qfunction = QTable(gw)

# instantiate a bandit
bandit = UCBBandit()

# instantiate MCTS solver
mcts_solver = SingleAgentMCTS(gw, qfunction, bandit)

# run the mcts solver from the root node
root_node = mcts_solver.mcts(1)
'''

'\n# instantiate grid world mdp object\ngw = GridWorld(discount_factor=0.9)\n\n# instantaiate Q table\nqfunction = QTable(gw)\n\n# instantiate a bandit\nbandit = UCBBandit()\n\n# instantiate MCTS solver\nmcts_solver = SingleAgentMCTS(gw, qfunction, bandit)\n\n# run the mcts solver from the root node\nroot_node = mcts_solver.mcts(1)\n'

#### Note that the best action according to our MCTS policy from the initial state (0,0) is up which is consistent with the optimal policy for this problem. Now lets use MCTS to execute a sequence of actions and check the policy at each step. 

In [9]:
# instantiate grid world mdp object
gw2 = GridWorld(discount_factor=0.9)

# instantaiate Q table
qfunction2 = QTablePartial(gw2)

# instantiate a bandit
bandit2 = UCBBandit()

# instantiate MCTS solver
mcts_solver2 = SingleAgentMCTS(gw2, qfunction2, bandit2)

# set the initial state of the mdp as the root node
node = mcts_solver2.create_root_node()

max_steps = 100
steps = 0

while not gw2.is_exit(node.state) and steps < max_steps:

    # run MCTS 
    print(f"Running MCTS...")
    node = mcts_solver2.mcts(0.5, node)
    best_action, _ = qfunction2.get_maxQ(node.state, gw2.get_actions(node.state))   

    # transition to next state using the best action
    (next_state, _) = gw2.execute(node.state, best_action) 
    print(f"s0 = {node.state}, best_action: {gw2.action_names[best_action]}, next state: {next_state}")

    
    ''' 
    # set the child node corresponding to the next state as new root node
    for (child, probability) in node.children[best_action]:
        print(f"Child state: {child.state}, Probability: {probability}")
        if child.state == next_state:
            node = child
            break
    # disconnect with old parent node    
    node.parent = None
    '''

    # generate a new root node with the next state
    node = mcts_solver2.create_root_node(root_state=next_state)

    steps += 1
    print(f"New root node state: {node.state}")
       
print("Reached termimnal state")


Running MCTS...
MCTS iterations: 2252
-----------------------
 0.08  0.13  0.28  0.00 
 0.07  0.00  0.02  0.00 
 0.04  0.00 -0.27  0.00 
-----------------------
-----------------------
right  up     right  end    
up     None   up     end    
up     left   left   up     
-----------------------
s0 = (0, 0), best_action: up, next state: (0, 1)
New root node state: (0, 1)
Running MCTS...
MCTS iterations: 4206
-----------------------
 0.10  0.14  0.28  0.00 
 0.09  0.00  0.09  0.00 
 0.04  0.00 -0.27  0.00 
-----------------------
-----------------------
right  right  right  end    
up     None   up     end    
up     left   left   up     
-----------------------
s0 = (0, 1), best_action: up, next state: (0, 2)
New root node state: (0, 2)
Running MCTS...
MCTS iterations: 3187
-----------------------
 0.10  0.14  0.28  0.00 
 0.09  0.00  0.09  0.00 
 0.04  0.00 -0.27  0.00 
-----------------------
-----------------------
right  down   right  end    
up     None   up     end    
up     left

#### Now testing on CliffWorld problem 

In [7]:
'''
# instantiate grid world mdp object
cw = CliffWorld(discount_factor=0.9, noise=0.0)

# instantaiate Q table
qfunction = QTable(cw)

# instantiate a bandit
bandit = UCBBandit()

# instantiate MCTS solver
mcts_solver = SingleAgentMCTS(cw, qfunction, bandit)

# run the mcts solver
root_node = mcts_solver.mcts(1)
'''

'\n# instantiate grid world mdp object\ncw = CliffWorld(discount_factor=0.9, noise=0.0)\n\n# instantaiate Q table\nqfunction = QTable(cw)\n\n# instantiate a bandit\nbandit = UCBBandit()\n\n# instantiate MCTS solver\nmcts_solver = SingleAgentMCTS(cw, qfunction, bandit)\n\n# run the mcts solver\nroot_node = mcts_solver.mcts(1)\n'

In [11]:
# instantiate grid world mdp object
cw = CliffWorld(discount_factor=0.9, noise=0.1)

# instantaiate Q table
qfunction3 = QTablePartial(cw)

# instantiate a bandit
bandit3 = UCBBandit()

# instantiate MCTS solver
mcts_solver3 = SingleAgentMCTS(cw, qfunction3, bandit3)

# set the initial state of the mdp as the root node
node = mcts_solver3.create_root_node()

max_steps = 100
steps = 0

while not cw.is_exit(node.state) and steps < max_steps:

    # run MCTS for 1 second
    print(f"Running MCTS...")
    node = mcts_solver3.mcts(1.0, node)
    best_action, _ = qfunction3.get_maxQ(node.state, cw.get_actions(node.state))   

    # transition to next state using the best action
    (next_state, _) = cw.execute(node.state, best_action) 
    print(f"s0 = {node.state}, best_action: {cw.action_names[best_action]}, next state: {next_state}")

    # set the child node corresponding to the next state as new root node
    '''
    for (child, probability) in node.children[best_action]:
        print(f"Child state: {child.state}, Probability: {probability}")
        if child.state == next_state:
            node = child
            break
    # disconnect with old parent node    
    node.parent = None    
    '''

    # generate a new root node with the next state
    node = mcts_solver3.create_root_node(root_state=next_state)
    # reset Q function to zeros
    #qfunction3.reset()

    steps += 1
       
print("Reached termimnal state")


Running MCTS...
MCTS iterations: 7500
-----------------------
-2.43 -2.54 -2.37 -2.04  0.07  0.00 
-1.92 -0.57 -0.32 -2.05  0.37  0.00 
-2.83 -4.34 -1.24 -5.34  0.26  0.00 
-4.22  0.00  0.00  0.00  0.00  0.00 
-----------------------
-----------------------
up     up     up     up     down   up     
up     up     up     up     down   up     
up     up     up     down   left   up     
up     end    end    end    end    end    
-----------------------
s0 = (0, 0), best_action: up, next state: (0, 1)
Running MCTS...
MCTS iterations: 7212
-----------------------
-2.43 -2.54 -2.37 -2.04  0.07  0.00 
-1.92 -0.57 -0.32 -2.05  0.37  0.00 
-2.83 -4.34 -1.24 -5.34  0.26  0.00 
-4.22  0.00  0.00  0.00  0.00  0.00 
-----------------------
-----------------------
up     up     up     up     up     up     
right  up     up     up     right  up     
up     down   up     down   right  up     
up     end    end    end    end    end    
-----------------------
s0 = (0, 1), best_action: up, next state: (