In [2]:
from mdp import *
import math, time, random
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt

#### Two-player Monte Carlo Tree Search: for turn-based games where actions are carried out alternately by an agent and its opponent. We will demonstrate Monte Carlo tree search using the game of tic-tac-toe.

#### Define a base class for game tree nodes

In [8]:
class GameNode:
    """Base class for a node of a two-player game tree.

    Attributes:
        state: the game state this node represents.
        player_turn: the player who moves from this state.
        value: the value stored for this node (returned by get_value()).
        is_best_action: flag stored as given by the caller.
        children: dict of child nodes; a fresh dict per node unless one is
            passed in explicitly.
        id: unique, monotonically increasing integer assigned at creation.
    """

    # static counter for node IDs
    next_node_id = 0

    # static dictionary for recording the number of times each node in the tree has been visited
    visits = defaultdict(lambda: 0)

    def __init__(self, state, player_turn, value, is_best_action=False, children=None):
        self.state = state
        self.player_turn = player_turn  # marks which player's turn is on that state
        self.value = value
        self.is_best_action = is_best_action
        # A mutable default (children=dict()) would be created once and shared
        # by every node built without explicit children, so build it per node.
        self.children = dict() if children is None else children

        self.id = GameNode.next_node_id
        GameNode.next_node_id += 1

    def select(self):
        """Select a node that hasn't been fully expanded, i.e. a leaf node.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("select() must be implemented by a subclass")

    def expand(self):
        """Expand this node if it is a non-terminal state.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("expand() must be implemented by a subclass")

    def backpropagate(self):
        """Backpropagate the accumulated reward towards the root node.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("backpropagate() must be implemented by a subclass")

    def get_value(self):
        """Return the value stored on this node."""
        return self.value



#### Define a class for the tic tac toe game

In [27]:
# Symbols stored in the cells of the board
EMPTY = ' '
CIRCLE = 'O'
CROSS = 'X'


class TicTacToe:
    """Rules of tic-tac-toe.

    A state is a square board stored as a list of rows, each row a list of
    EMPTY / CIRCLE / CROSS symbols. The methods never modify a state that is
    passed in; transitions return a new board.
    """

    def __init__(self):
        self.players = [CIRCLE, CROSS]

    def get_players(self):
        """Return the list of players."""
        return self.players

    def get_actions(self, state):
        """Return every valid action from `state` in row-major order.

        A valid action is the (row, column) position of an empty cell.
        """
        return [
            (i, j)
            for i, row in enumerate(state)
            for j, cell in enumerate(row)
            if cell == EMPTY
        ]

    def copy(self, state):
        """Return a copy of `state` whose rows are new lists."""
        return [list(row) for row in state]

    def get_transition(self, state, action):
        """Return the state resulting from the current player playing `action`.

        `action` is a (row, column) pair; `state` itself is left untouched.
        """
        next_state = self.copy(state)
        next_state[action[0]][action[1]] = self.get_player_turn(state)
        return next_state

    def get_successors(self, state):
        """Return the list of all successor states of `state`.

        A finished game has no successors, so a terminal state yields an
        empty list (the same type as the non-terminal case).
        """
        if self.is_terminal(state):
            return []
        return [self.get_transition(state, action) for action in self.get_actions(state)]

    def get_reward(self, state):
        """Return the reward of each player for transitioning into `state`.

        The result maps each player to 1 if that player has won and 0
        otherwise, so draws and unfinished games give 0 to both players.
        """
        winner = self.get_winner(state)
        return {CROSS: int(winner == CROSS), CIRCLE: int(winner == CIRCLE)}

    def count_empty(self, board):
        """Count the empty positions on the board."""
        return sum(1 for row in board for cell in row if cell == EMPTY)

    def is_terminal(self, state):
        """Return True iff the board is full or one of the players has won."""
        return self.count_empty(state) == 0 or self.get_winner(state) is not None

    def get_player_turn(self, state):
        """Return the player who selects the action at `state`.

        Cross always starts the game, so it is cross's turn whenever the
        number of empty cells is odd (9 cells in total).
        """
        if self.count_empty(state) % 2 == 0:
            return CIRCLE
        return CROSS

    def initial_state(self):
        """Return the initial state of the game: an empty 3x3 board."""
        return [[EMPTY] * 3 for _ in range(3)]

    def get_winner(self, state):
        """Return CROSS or CIRCLE if that player owns a complete line, else None.

        Rows are checked first, then columns, then the two diagonals. The
        board is assumed to be square; its size is taken from the state
        rather than hard-coded.
        """
        board = state
        if not board:
            return None
        size = len(board)

        lines = [list(row) for row in board]
        lines += [[board[i][j] for i in range(size)] for j in range(len(board[0]))]
        # top-left to bottom-right diagonal, then top-right to bottom-left
        lines.append([board[i][i] for i in range(size)])
        lines.append([board[i][size - 1 - i] for i in range(size)])

        for line in lines:
            for player in (CROSS, CIRCLE):
                if line and all(cell == player for cell in line):
                    return player

        # no winner
        return None

    def display_board(self, state):
        """Print the board, one row per line, cells separated by ' | '."""
        for row in state:
            print(' | '.join(str(cell) for cell in row))

    def game_tree(self):
        """Build and return the complete game tree rooted at the empty board."""
        return self.state_to_node(self.initial_state())

    def state_to_node(self, state):
        """Recursively convert `state` and all of its descendants into GameNodes.

        Terminal nodes carry the reward dictionary as their value and have no
        player to move; every other node stores its children keyed by action.
        """
        if self.is_terminal(state):
            return GameNode(state, None, self.get_reward(state), children={})

        player = self.get_player_turn(state)
        children = dict()
        for action in self.get_actions(state):
            next_state = self.get_transition(state, action)
            children[action] = self.state_to_node(next_state)

        # children must be passed by keyword: the fourth positional parameter
        # of GameNode is is_best_action, not children.
        return GameNode(state, player, None, children=children)


#### Now define an MCTS class for the tic tac toe game 

In [None]:
class MCTS:
    """Monte Carlo tree search (UCT) for a two-player, turn-based game.

    The `game` object must provide is_terminal(state), get_actions(state),
    get_transition(state, action), get_player_turn(state), initial_state()
    and get_reward(state), where get_reward returns a dict mapping each
    player to that player's reward.

    Tree nodes only need a `state` attribute and must be hashable; the search
    keeps all of its statistics in its own dictionaries:

        Q[node]                  total reward collected through `node`, from
                                 the point of view of the player who moved
                                 into it
        visited[node] / N[node]  number of times `node` was visited
        children[node]           list of child nodes of an expanded node
                                 (empty for terminal nodes)
        parent[node]             parent of a node created by expand()
        action_from_parent[node] action that leads from the parent to `node`
    """

    def __init__(self, game, exploration_weight=1.0, node_factory=None):
        """
        Args:
            game: the game rules (see the class docstring).
            exploration_weight: weight of the exploration term in UCB1.
            node_factory: optional callable (state, player_turn, value) used
                to create nodes; GameNode is used when it is None.
        """
        self.game = game

        # dictionary for storing value/total reward at each node
        self.Q = defaultdict(int)
        # dictionary for storing visit counts for each node
        self.visited = defaultdict(int)
        # the search methods refer to the visit counts as N; keep both names
        # bound to the same dictionary
        self.N = self.visited
        # dictionary for storing children of expanded nodes
        self.children = dict()
        # links created by expand(), used to walk back up and to name moves
        self.parent = dict()
        self.action_from_parent = dict()
        self.exploration_weight = exploration_weight
        self.node_factory = node_factory

    def _new_node(self, state):
        """Create a tree node for `state` (terminal nodes have no player to move)."""
        factory = GameNode if self.node_factory is None else self.node_factory
        player = None if self.game.is_terminal(state) else self.game.get_player_turn(state)
        return factory(state, player, None)

    def choose_move(self, node):
        """Return the action to play from `node`.

        An expanded node yields the action leading to the child with the
        highest average reward; a node the search has not expanded yet yields
        a random valid action.

        Raises:
            RuntimeError: if `node` is a terminal state.
        """
        # make sure the state is not terminal
        if self.game.is_terminal(node.state):
            raise RuntimeError(f"choose called on terminal state: {node.state}")

        # get all valid actions
        actions = self.game.get_actions(node.state)

        if node not in self.children:
            return random.choice(actions)

        def score(n):
            if self.N[n] == 0:
                return float("-inf")  # unseen moves have lowest possible value
            return self.Q[n] / self.N[n]  # average reward

        best_child = max(self.children[node], key=score)
        return self.action_from_parent[best_child]

    def select(self, node):
        """Descend from `node` to a node that has not been expanded yet.

        Returns the list of nodes visited on the way, starting with `node`;
        the last element is the selected node.
        """
        path = []
        while True:
            path.append(node)
            if node not in self.children or not self.children[node]:
                # node is either unexplored/unexpanded or terminal (i.e. node has no children)
                return path
            # children that have never been expanded themselves
            unexplored = [c for c in self.children[node] if c not in self.children]
            if unexplored:
                path.append(unexplored[0])
                return path

            # if all children expanded, need to go another level deeper, select child using UCB bandit
            node = self.ucb_select(node)

    def ucb_select(self, node):
        """Return the child of the expanded `node` with the highest UCB1 score."""
        # max(..., 1) keeps log() defined for a node that was never visited
        log_N = math.log(max(self.N[node], 1))

        def ucb(n):
            if self.N[n] == 0:
                return float("inf")  # always try an unvisited child first
            return self.Q[n] / self.N[n] + self.exploration_weight * math.sqrt(log_N / self.N[n])

        return max(self.children[node], key=ucb)

    def expand(self, node):
        """Record the children of `node` in the children dictionary.

        Does nothing if the node is already expanded; a terminal node gets an
        empty list of children.
        """
        if node in self.children:
            return

        kids = []
        if not self.game.is_terminal(node.state):
            for action in self.game.get_actions(node.state):
                child = self._new_node(self.game.get_transition(node.state, action))
                self.parent[child] = node
                self.action_from_parent[child] = action
                kids.append(child)
        self.children[node] = kids

    def simulate(self, node):
        """Play uniformly random moves from `node` until the game ends.

        Returns the reward dictionary (player -> reward) of the terminal
        state that was reached.
        """
        state = node.state
        while not self.game.is_terminal(state):
            action = random.choice(self.game.get_actions(state))
            state = self.game.get_transition(state, action)
        return self.game.get_reward(state)

    def mcts(self, timeout=1, root_node=None, max_iterations=None):
        """Run MCTS iterations from `root_node` and return the root.

        Args:
            timeout: time budget in seconds.
            root_node: node to search from; a node for the initial state of
                the game is created when it is None.
            max_iterations: optional cap on the number of iterations; the
                search stops at whichever limit is reached first.
        """
        # create a root node if none provided
        if root_node is None:
            root_node = self._new_node(self.game.initial_state())

        deadline = time.time() + timeout
        num_iterations = 0

        # perform mcts iterations until timeout
        while time.time() < deadline:
            if max_iterations is not None and num_iterations >= max_iterations:
                break

            # select node for expansion
            leaf = self.select(root_node)[-1]
            # register its children (a terminal node gets none)
            self.expand(leaf)
            # run simulation to get a reward
            reward = self.simulate(leaf)
            # backpropagate the reward to root node
            self.backpropagate(reward, leaf)

            num_iterations += 1

        print(f"MCTS iterations: {num_iterations}")

        return root_node

    def backpropagate(self, G, child):
        """Propagate the rollout result from `child` up to the top of the tree.

        Args:
            G: reward dictionary (player -> reward) of the finished rollout.
            child: node from which the rollout was started.

        Every node on the way up gets one more visit. Its Q value grows by
        the reward of the player who moved into it, so each parent compares
        its children from its own point of view.
        """
        node = child
        while node is not None:
            self.visited[node] += 1
            parent = self.parent.get(node)
            if parent is not None:
                mover = self.game.get_player_turn(parent.state)
                self.Q[node] += G.get(mover, 0)
            node = parent



