In [4]:
from mdp import *
import math, time, random
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt

#### Define a base class for the tree nodes

In [6]:
class Node:
    """Base class for the nodes of the search tree.

    Stores the bookkeeping shared by all node types (MDP, parent link, state,
    Q-function, bandit, incoming reward and action). Subclasses must implement
    ``select``, ``expand`` and ``backpropagate``.
    """

    # static counter for node IDs
    next_node_id = 0

    # static dictionary for recording the number of times the node has been visited
    visits = defaultdict(int)

    def __init__(self, mdp, parent, state, qfunction, bandit, reward=0.0, action=None):
        """Create a node and assign it the next free node ID.

        Args:
            mdp: the MDP being searched; must provide ``get_actions(state)``.
            parent: the parent node (``None`` is accepted, e.g. for a root).
            state: the state this node represents.
            qfunction: Q-function object; must provide ``get_maxQ(state, actions)``.
            bandit: multi-armed bandit used for node selection.
            reward: immediate reward received for transitioning into this state.
            action: action that generated this node.
        """
        self.mdp = mdp
        self.parent = parent
        self.state = state

        # IDs are handed out from the class-level counter, so they are unique
        # across every Node (and subclass) instance created in this kernel.
        self.id = Node.next_node_id
        Node.next_node_id += 1

        # initialize Q function
        self.qfunction = qfunction

        # multi-armed bandit for node selection
        self.bandit = bandit

        # immediate reward received for transitioning into this state
        self.reward = reward

        # action that generated this node
        self.action = action

    def select(self):
        """Select a node that hasn't been fully expanded, i.e. a leaf node.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("select() must be implemented by a subclass")

    def expand(self):
        """Expand this node if it is a non-terminal state.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("expand() must be implemented by a subclass")

    def backpropagate(self, reward=0.0, child=None):
        """Backpropagate the accumulated reward to the root node.

        The parameters mirror the ``backpropagate(reward, child)`` call made by
        the MCTS loop in this notebook; both are defaulted so the original
        zero-argument call form remains valid.

        Args:
            reward: the reward to propagate.
            child: the child node the reward was obtained from.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("backpropagate() must be implemented by a subclass")

    def get_value(self):
        """Return V of this node.

        Returns:
            The second element of ``qfunction.get_maxQ(state, actions)``, where
            ``actions`` are those the MDP reports for this node's state
            (presumably the maximum Q-value -- confirm against the Q-function
            implementation).
        """
        actions = self.mdp.get_actions(self.state)
        _, value = self.qfunction.get_maxQ(self.state, actions)
        return value

    




#### Define a class implementing the MCTS algorithm

In [7]:
class MCTS:
    """Monte Carlo tree search driver.

    Repeatedly selects a node, expands it, simulates from the new child and
    backpropagates the result. Subclasses must implement ``create_root_node``.
    """

    def __init__(self, mdp, qfunction, bandit):
        """Store the collaborators used by the search.

        Args:
            mdp: the MDP to search; must provide ``is_exit``, ``get_actions``,
                ``execute`` and ``gamma``.
            qfunction: Q-function object (stored for use by subclasses).
            bandit: multi-armed bandit (stored for use by subclasses).
        """
        self.mdp = mdp
        self.qfunction = qfunction
        self.bandit = bandit

    def mcts(self, timeout=1, root_node=None):
        """Run MCTS iterations until the time budget is exhausted.

        Args:
            timeout: time budget in seconds, measured with ``time.time()``.
            root_node: tree to continue searching; a new root is created via
                ``create_root_node()`` when omitted.

        Returns:
            The root node of the search tree.
        """
        # create a root node if none provided
        if root_node is None:
            root_node = self.create_root_node()

        # start the timer
        start_time = time.time()
        current_time = time.time()

        # perform mcts iterations until timeout
        while current_time < start_time + timeout:

            # select node for expansion
            selected_node = root_node.select()

            if not self.mdp.is_exit(selected_node.state):

                # expand the selected node to generate a child node (if the node is not a terminal state)
                child = selected_node.expand()

                # run simulation to get a reward
                reward = self.simulate(child)

                # backpropagate the reward to root node
                selected_node.backpropagate(reward, child)

            current_time = time.time()

        return root_node

    def create_root_node(self):
        """Create a root node representing the initial state.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("create_root_node() must be implemented by a subclass")

    def choose_action(self, state):
        """Choose a random action for the Monte Carlo simulation.

        A heuristic could be used to choose actions instead of picking at random.

        Args:
            state: the state to act in.

        Returns:
            One of ``mdp.get_actions(state)``, picked with ``random.choice``.
        """
        actions = self.mdp.get_actions(state)
        return random.choice(actions)

    def simulate(self, node):
        """Run a simulation from ``node`` until a terminal state is reached.

        The rollout could be stopped after a fixed number of time steps instead
        of running until a terminal state is reached.

        Args:
            node: the node whose ``state`` the rollout starts from.

        Returns:
            The discounted cumulative reward: the sum over rollout steps of
            ``mdp.gamma ** depth * reward``; ``0.0`` if the start state is
            already an exit state.
        """
        state = node.state
        cumulative_reward = 0.0
        depth = 0

        while not self.mdp.is_exit(state):
            # choose an action to execute
            action = self.choose_action(state)
            # transition to next state
            (next_state, reward) = self.mdp.execute(state, action)
            # discount the reward
            cumulative_reward += pow(self.mdp.gamma, depth) * reward
            depth += 1
            state = next_state

        return cumulative_reward



#### Implement a derived tree-node class for single-agent MCTS

In [None]:
class SingleAgentNode(Node):
    
        