In [1]:
from mdp import *
import math, time, random
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt

#### Define a base class for the tree nodes

In [2]:
class Node:
    """Base class for a node in the MCTS search tree.

    Subclasses must implement ``select``, ``expand`` and ``backpropagate``.
    Visit counts live in the class-level ``visits`` dictionary, so they are
    shared by every node (and every tree) created in this process.
    """

    # static counter for node IDs
    next_node_id = 0

    # static dictionary recording how many times each state and each
    # (state, action) pair has been visited; shared across all Node instances,
    # so clear it before starting an independent search
    visits = defaultdict(lambda: 0)

    def __init__(self, mdp, parent, state, qfunction, bandit, reward=0.0, action=None):
        """Create a tree node.

        Args:
            mdp: the MDP being solved; must provide ``get_actions(state)``.
            parent: parent node, or ``None`` for the root.
            state: the MDP state this node represents.
            qfunction: shared Q-function; must provide ``get_maxQ(state, actions)``.
            bandit: multi-armed bandit used to choose among already-expanded actions.
            reward: immediate reward received for transitioning into this state.
            action: action (taken in the parent's state) that generated this node.
        """
        self.mdp = mdp
        self.parent = parent
        self.state = state

        # unique, monotonically increasing id for this node
        self.id = Node.next_node_id
        Node.next_node_id += 1

        self.qfunction = qfunction
        self.bandit = bandit
        self.reward = reward
        self.action = action

    def select(self):
        """Return a descendant (or self) that has not been fully expanded yet."""
        raise NotImplementedError(f"{type(self).__name__} must implement select()")

    def expand(self):
        """Expand this node if it is non-terminal and return the resulting child."""
        raise NotImplementedError(f"{type(self).__name__} must implement expand()")

    def backpropagate(self, reward, child):
        """Propagate ``reward``, obtained via ``child``, back up towards the root."""
        raise NotImplementedError(f"{type(self).__name__} must implement backpropagate()")

    def get_value(self):
        """Return this node's state value as reported by ``qfunction.get_maxQ``.

        ``get_maxQ`` returns an (action, value) pair; only the value is returned
        here (presumably max_a Q(s, a) over this state's actions).
        """
        _, value = self.qfunction.get_maxQ(self.state, self.mdp.get_actions(self.state))
        return value

    




#### Define a class implementing the MCTS algorithm

In [3]:
class MCTS:
    """Generic Monte Carlo tree search driver.

    Subclasses supply ``create_root_node``; the node class supplies
    ``select``, ``expand`` and ``backpropagate``.
    """

    def __init__(self, mdp, qfunction, bandit):
        self.mdp = mdp
        self.qfunction = qfunction
        self.bandit = bandit

    def mcts(self, timeout=1, root_node=None, verbose=True):
        """Run MCTS iterations from ``root_node`` until ``timeout`` seconds elapse.

        Args:
            timeout: wall-clock budget in seconds.
            root_node: node to search from; a new root is created if ``None``.
            verbose: if True (default), print the iteration number and display
                the Q table after every iteration.

        Returns:
            The root node of the search tree.
        """
        if root_node is None:
            root_node = self.create_root_node()

        deadline = time.time() + timeout
        num_iterations = 0

        while time.time() < deadline:
            if verbose:
                print(f"MCTS iteration# {num_iterations}")

            # select a node that has not been fully expanded yet
            selected_node = root_node.select()

            # terminal states cannot be expanded; such an iteration does no work
            if not self.mdp.is_exit(selected_node.state):
                child = selected_node.expand()

                # NOTE(review): simulate() starts from child.state, so child.reward
                # (the immediate reward for the expanded action) is not part of the
                # value backed up into selected_node — confirm this is intended
                reward = self.simulate(child)
                selected_node.backpropagate(reward, child)

                self.qfunction.update_V_from_Q()
                if verbose:
                    self.qfunction.display()

            num_iterations += 1

        return root_node

    def create_root_node(self):
        """Create a root node representing the MDP's initial state."""
        raise NotImplementedError(f"{type(self).__name__} must implement create_root_node()")

    def choose(self, state):
        """Pick a rollout action uniformly at random (a heuristic could be used instead)."""
        actions = self.mdp.get_actions(state)
        return random.choice(actions)

    def simulate(self, node, max_depth=None):
        """Roll out the ``choose`` policy from ``node`` and return the discounted return.

        Args:
            node: node whose state the rollout starts from.
            max_depth: optional cap on the number of rollout steps; ``None``
                (default) runs until an exit state is reached.

        Returns:
            sum over steps t of gamma**t * r_t, with gamma read from ``self.mdp.gamma``.
        """
        state = node.state
        cumulative_reward = 0.0
        depth = 0

        while not self.mdp.is_exit(state) and (max_depth is None or depth < max_depth):
            action = self.choose(state)
            next_state, reward = self.mdp.execute(state, action)
            cumulative_reward += pow(self.mdp.gamma, depth) * reward
            depth += 1
            state = next_state

        return cumulative_reward
    



#### Implement derived node and solver classes for single-agent MCTS

In [4]:
class SingleAgentNode(Node):
    """Tree node for single-agent MCTS over a (possibly stochastic) MDP."""

    def __init__(self, mdp, parent, state, qfunction, bandit, reward=0, action=None):
        super().__init__(mdp, parent, state, qfunction, bandit, reward, action)

        # maps each expanded action to a list of (child node, transition probability) pairs
        self.children = {}

    def is_fully_expanded(self):
        """Return True if every action available in this state has an entry in ``children``."""
        return len(self.mdp.get_actions(self.state)) == len(self.children)

    def select(self):
        """Descend the tree via the bandit until a non-fully-expanded or exit node is found."""
        if not self.is_fully_expanded() or self.mdp.is_exit(self.state):
            return self

        actions = list(self.children.keys())
        action = self.bandit.select(self.state, actions, self.qfunction)
        return self.get_outcome_child(action).select()

    def expand(self):
        """Expand one random unexplored action and return the sampled child.

        Exit states cannot be expanded, so ``self`` is returned for them.
        """
        if self.mdp.is_exit(self.state):
            return self

        # keep get_actions() order (rather than a set difference) so that
        # random.choice is reproducible under a fixed random seed
        unexplored_actions = [a for a in self.mdp.get_actions(self.state) if a not in self.children]
        action = random.choice(unexplored_actions)

        # create a slot for that action in the children dictionary
        self.children[action] = []
        return self.get_outcome_child(action)

    def backpropagate(self, G, child):
        """Update Q(s, a) for the action that produced ``child``, then recurse to the root.

        Args:
            G: return observed after taking ``child.action`` in this node's state.
            child: the child node the return was obtained through.
        """
        action = child.action

        # update visit counts for both the state node and the state-action pair
        Node.visits[self.state] += 1
        Node.visits[(self.state, action)] += 1

        # incremental-mean update with step size 1 / N(s, a)
        qvalue = self.qfunction.evaluate(self.state, action)
        delta = (G - qvalue) / Node.visits[(self.state, action)]
        self.qfunction.update(self.state, action, qvalue, delta)

        # NOTE(review): the return is not discounted by gamma as it moves up the
        # tree, unlike the rollout in simulate(); confirm this is intended
        if self.parent is not None:
            self.parent.backpropagate(self.reward + G, self)

    def get_outcome_child(self, action):
        """Sample an outcome of ``action`` and return the matching child node.

        Reuses an existing child when the sampled state was seen before;
        otherwise creates, records, and returns a new child.
        """
        next_state, reward = self.mdp.execute(self.state, action)

        for child, _ in self.children[action]:
            if child.state == next_state:
                return child

        new_child = SingleAgentNode(self.mdp, self, next_state, self.qfunction, self.bandit, reward, action)

        # record the outcome's transition probability; fall back to 0.0 rather than
        # dropping the child (and returning None) if get_transitions does not list it
        probability = next(
            (p for outcome, p in self.mdp.get_transitions(self.state, action) if outcome == next_state),
            0.0,
        )
        self.children[action].append((new_child, probability))
        return new_child



class SingleAgentMCTS(MCTS):
    """MCTS solver for a single agent acting in an MDP.

    The search loop, rollout policy and constructor are inherited from MCTS;
    only the root-node factory is specialised so the tree is built from
    SingleAgentNode objects.
    """

    def create_root_node(self):
        """Return a parentless SingleAgentNode for the MDP's initial state."""
        initial_state = self.mdp.get_initial_state()
        return SingleAgentNode(
            self.mdp,
            None,
            initial_state,
            self.qfunction,
            self.bandit,
        )

#### Let's test the MCTS on the gridworld problem

In [5]:
# seed for the random number generators used during the search
SEED = 42

# instantiate grid world mdp object
gw = GridWorld(discount_factor=0.9)

# instantiate Q table
qfunction = QTable(gw)

# instantiate a bandit
bandit = UCBBandit()

# Node.visits is a class-level dict shared by every tree; clear it so stale counts
# from an earlier run do not skew the 1/N step sizes applied to this fresh Q table
Node.visits.clear()

# seed Python's RNG (used for expansion and rollout action choice) and NumPy's
# (in case the mdp module samples with it) so runs are repeatable; the number of
# iterations still depends on the wall-clock timeout
random.seed(SEED)
np.random.seed(SEED)

# instantiate MCTS solver
mcts_solver = SingleAgentMCTS(gw, qfunction, bandit)

In [6]:
# search from a freshly created root node with a 7-second wall-clock budget
root_node = mcts_solver.mcts(timeout=7)

MCTS iteration# 0
-----------------------
 0.00  0.00  0.00  0.00 
 0.00  0.00  0.00  0.00 
 0.00  0.00  0.00  0.00 
-----------------------
-----------------------
up     up     up     end    
up     None   up     end    
up     up     up     up     
-----------------------
MCTS iteration# 1
-----------------------
 0.00  0.00  0.00  0.00 
 0.00  0.00  0.00  0.00 
 0.00  0.00  0.00  0.00 
-----------------------
-----------------------
up     up     up     end    
up     None   up     end    
up     up     up     up     
-----------------------
MCTS iteration# 2
-----------------------
 0.00  0.00  0.00  0.00 
 0.00  0.00  0.00  0.00 
 0.21  0.00  0.00  0.00 
-----------------------
-----------------------
up     up     up     end    
up     None   up     end    
up     up     up     up     
-----------------------
MCTS iteration# 3
-----------------------
 0.00  0.00  0.00  0.00 
 0.00  0.00  0.00  0.00 
 0.21  0.00  0.00  0.00 
-----------------------
-----------------------
up     