In [205]:
from mdp import *
import math, time, random
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt

#### Two-player Monte Carlo Tree Search: For turn based games where actions are alternatingly carried out by an agent and it's opponent. We will demonstrate monte carlo search using the Tic Tac Toe game

#### Define a class for a tic tac toe game

In [206]:
EMPTY = '-'
CIRCLE = 'O'
CROSS = 'X'

class TicTacToe:

    def __init__(self, gamma = 1.0):
        self.players = [CIRCLE, CROSS]
        self.CROSS = CROSS
        self.CIRCLE = CIRCLE
        self.gamma = 1.0

    # get list of players
    def get_players(self):
        return self.players
    
    
    def get_opponent(self, player):
        if player == CIRCLE:
            opponent = CROSS
        else:
            opponent = CIRCLE 
        return opponent       
    
    
    # get all valid actions from a state (a valid action is just an empty position on the board which can be marked by a player)
    def get_actions(self, state):
        board = state
        actions = []
        for i in range(len(board)):
            for j in range(len(board[i])):
                if board[i][j] == EMPTY:
                    actions += [(i,j)]
            
        return actions


    # make deep copy of a state
    def copy(self, state):
        next_state = []
        for i in range(len(state)):
            new_row = []
            for j in range(len(state[i])):
                new_row += [state[i][j]]
            next_state += [new_row]    
        
        return next_state
    

    # return the reward and state reulting from playing an action from a state
    def execute(self, state, action): 
        next_state = self.copy(state)
        next_state[action[0]][action[1]] = self.get_player_turn(state)
        
        return (next_state, self.get_reward(next_state))


    # return the state reulting from playing an action from a state
    def get_transition(self, state, action): 
        next_state = self.copy(state)
        next_state[action[0]][action[1]] = self.get_player_turn(state)
        
        return next_state


    # return all possible auccessors from the state
    def get_successors(self, state):
        if self.is_terminal(state):
            # if game is finished, there's no successor states
            return set()
        else:
            successors = []
            actions = self.get_actions(state)
            for action in actions:
                successors += [self.get_transition(state, action)]
            
            return successors


    # return the reward for transitioning into the state
    def get_reward(self, state):
        winner = self.get_winner(state)
        if winner == None:
            return {CROSS: 0, CIRCLE: 0}
        
        elif winner == CROSS:
            return {CROSS: 1, CIRCLE: 0}

        elif winner == CIRCLE:
            return {CROSS: 0, CIRCLE: 1}


    # count how many empty positions on the board
    def count_empty(self, board):
        empty = 0
        for i in range(len(board)):
            for j in range(len(board[i])):
                if board[i][j] == EMPTY:
                    empty += 1

        return empty
    
    # return true iff state is a terminal state of the game
    def is_terminal(self, state):
        #if self.get_winner(state) != None:
        #    print(f"The winner is : {self.get_winner(state)}")

        return self.count_empty(state)==0 or self.get_winner(state) != None
    
    # return player who gets to select the action at current state, i.e. whose turn it is
    def get_player_turn(self, state):
        board = state
        # cross always starts the game, so crosses turn if there are an odd number of empty cells (9 cells total)
        empty = self.count_empty(board)
        if empty %2 == 0:
            return CIRCLE
        else:
            return CROSS


    # initial state of the game (empty board)
    def initial_state(self):
        board = [[EMPTY, EMPTY, EMPTY], 
                 [EMPTY, EMPTY, EMPTY], 
                 [EMPTY, EMPTY, EMPTY]]

        return board

    def get_winner(self, state):
        board = state

        # check columns
        for i in range(len(board)):
            circles = 0
            crosses = 0
            for j in range(len(board[i])):
                if board[i][j] == CIRCLE:
                    circles += 1
                elif board[i][j] == CROSS:
                    crosses += 1
            if crosses == len(board[i]):
                return CROSS      
            elif circles == len(board[i]):
                return CIRCLE 
       
        # check rows
        for j in range(len(board[0])):
            circles = 0
            crosses = 0
            for i in range(len(board)):
                if board[i][j] == CIRCLE:
                    circles += 1
                elif board[i][j] == CROSS:
                    crosses += 1
            if crosses == len(board):
                return CROSS      
            elif circles == len(board):
                return CIRCLE 

        # check top-left to bottom-right diagonal
        if board[0][0] == CIRCLE and board[1][1] == CIRCLE and board[2][2] == CIRCLE:
            return CIRCLE
        elif board[0][0] == CROSS and board[1][1] == CROSS and board[2][2] == CROSS:
            return CROSS

        # check top-right to bottom-left diagonal
        if board[0][2] == CIRCLE and board[1][1] == CIRCLE and board[2][0] == CIRCLE:
            return CIRCLE
        elif board[0][2] == CROSS and board[1][1] == CROSS and board[2][0] == CROSS:
            return CROSS

        # no winner
        return None
    

    # format the board into a string
    def display_board(self, state):
        for row in state:
            for i in range(len(row)):        
                if i < len(row)-1:
                    print(row[i], end=' | ')
                else:
                    print(row[i], end='')    
            print()


    '''
    def game_tree(self):
        return self.state_to_node(self.get_initial_state())
    

    def state_to_node(self, state):
        if self.is_terminal(state):
            return GameNode(state, None, self.get_reward(state))
        
        player = self.get_player_turn(state)
        children = dict()
        for action in self.get_actions(state):
            next_state = self.get_transition(state, action)
            child_node = self.state_to_node(next_state)
            children[action] = child_node
        
        return GameNode(state, player, None, children)    
    '''

#### Define a base class for game tree node

In [207]:
class GameNode:

    # static counter for node IDs
    next_node_id = 0

    # static dictionary for recording the number of times each node in the tree has been visited
    visits = defaultdict(lambda: 0)

    def __init__(self, game, state, parent, player_turn, bandit, value=0.0, reward=0.0, is_best_action=False, children=dict(), action=None):
        self.game = game
        self.state = state
        self.parent = parent            # pointer to parent node
        self.player_turn = player_turn  # marks which player's turn is on that state
        self.value = value
        self.reward = reward
        self.bandit = bandit
        self.is_best_action = is_best_action
        self.children = children
        self.action = action # action which generated the child
        self.accumulated_rewards = 0.0

        self.id = GameNode.next_node_id
        GameNode.next_node_id += 1
        

    # recursively traverse the tree and select a node that has not been fully expanded yet
    def select(self):

        if not self.is_fully_expanded() or self.game.is_terminal(self.state):
            # stop recursion when we've found node which hasn't been fully expanded or is terminal state
            return self
        else:
            # use the bandit to select the next child
            actions = list(self.children.keys())
            values = {action:self.children[action].get_value() for action in actions}
            parent_visits = GameNode.visits[self.id]
            children_visits = {action:GameNode.visits[self.children[action].id] for action in actions}
            action =  self.bandit.select(actions, values, parent_visits, children_visits)
            return self.children[action].select()  


    # checks if a node has been fully expanded
    def is_fully_expanded(self):    
        actions = self.game.get_actions(self.state)
        if(len(actions) == len(self.children)):
            return True
        else:
            return False


    # expand a node (only generate one child, not all children) if it's a non terminal-state
    def expand(self):
        if not self.game.is_terminal(self.state):

            # randomly select an unexpanded action to expand
            valid_actions = self.game.get_actions(self.state)
            #print(f"Available actions: {valid_actions}, already expanded children actions: {list(self.children.keys())}")

            unexplored_actions = list(self.game.get_actions(self.state) - self.children.keys())
            #print(f"Unexplored actions: {unexplored_actions}")

            if len(unexplored_actions) > 0:
                action = random.choice(unexplored_actions)
            else:
                action = random.choice(list(self.children.keys()))

            #print(f"randomly chosen action = {action}")
            # create a slot for that action in the children dictionary
            #self.children[action] = []

            # create a new child node and add to children dictionary
            (child_state, reward) = self.game.execute(self.state, action)
            new_child = GameNode(self.game, child_state, self, self.game.get_opponent(self.player_turn), self.bandit, reward=reward, action=action, children=dict())    
            self.children[action] = new_child  # each action leads to only one child 
            #print(f"generated child id: {new_child.id}, expanded children actions: {list(new_child.children.keys())}")
            return new_child                   
        
        # for terminal state, can't expand further
        return self
    

    # backpropagate reward back to root node (recursively update all nodes along the path to the root)
    def backpropagate(self, G):

        # update number of times visited for the state
        GameNode.visits[self.id] = GameNode.visits[self.id] + 1

        # apply discount to the backpropagated reward
        discounted_G = {}
        for player in G:
            discounted_G[player] = G[player] * self.game.gamma   

        # update the accumulated reward
        simulation_reward = discounted_G[self.player_turn]
        self.accumulated_rewards += simulation_reward 

        print(f"Backpropagation. Node id: {self.id}, visited: {GameNode.visits[self.id]}, accumulated reward: {self.accumulated_rewards}")

        # recursively backpropagate until root node is reached
        if self.parent != None:
            self.parent.backpropagate(discounted_G)


    # return value of this node
    def get_value(self):
        return self.accumulated_rewards / GameNode.visits[self.id] 



# upper confidence bounds (UCB) bandit
class UCBBandit:
    def __init__(self):
        self.total = 0
        # dictionary for recording number of times each action has been chosen
        self.times_selected = {}


    def select(self, actions, values, parent_visits, children_visits):
        
        # first, make sure each action has been executed once
        for action in actions:
            if action not in self.times_selected.keys():
                self.times_selected[action] = 1
                self.total += 1
                return action

        max_actions = []
        max_value = float("-inf")
        for action in actions:
            value = values[action] + math.sqrt(2.0*math.log(parent_visits)/children_visits[action])
            if value > max_value:
                max_value = value
                max_actions = [action]
            elif value == max_value:
                max_actions.append(action)

        # if multiple actions with max value, pick one randomly
        selected_action = random.choice(max_actions)
        self.times_selected[selected_action] = self.times_selected[selected_action] + 1
        self.total += 1
        
        return selected_action    

#### Class for MCTS

In [208]:
class MCTS:
    def __init__(self, game, bandit):
        self.game = game
        self.bandit = bandit


    # performs mcts from specified root node (timeout in seconds)
    def mcts(self, timeout=1, root_node=None):
        # create a root node if none provided
        if root_node == None:
            root_node = self.create_root_node()

        print(f"root node id: {root_node.id}")

        # start the timer
        start_time = time.time()
        current_time = time.time()
        num_iterations = 0

        # perform mcts iterations until timeout
        while current_time < start_time + timeout and num_iterations<600:

            # select node for expansion
            selected_node = root_node.select()
            print(f"selected node id: {selected_node.id}")

            if not (self.game.is_terminal(selected_node.state)):

                # expand the selected node to generate a child node (if the node is not a terminal state)
                child = selected_node.expand()
                print(f"Child node id: {child.id}")

                # run simulation to get a reward
                reward = self.simulate(child)
                print(f"Simulation rewards: {reward}")

                # backpropagate the reward to root node
                child.backpropagate(reward) 

       

            current_time = time.time()      
            num_iterations += 1    


        print(f"MCTS iterations: {num_iterations}")


        return root_node


    # create a root node representing the initial state
    def create_root_node(self, root_state=None, player_turn=None):
        if root_state == None:
            return GameNode(self.game, self.game.initial_state(), None, self.game.CROSS, self.bandit) # cross gets first turn
        else:
            return GameNode(self.game, root_state, None, player_turn, self.bandit)  


    # choose a random action for monte carlo simulation (can use a heuristic to choose actions instead of picking at random)
    def choose(self, state):
        actions = self.game.get_actions(state)
        
        # choose actions randomly
        next_action = random.choice(actions)
                
        return next_action
    
    
    # run simulation until terminal state reached (can be stopped after a fixed number of time steps instead of running until reaching terminal state)
    def simulate(self, node):
        state = node.state
        # the second entry in the reward vector is for opponent agent
        player = node.player_turn
        opponent = self.game.get_opponent(node.player_turn)
        cumulative_reward = {player:0.0, opponent:0.0}
        depth = 0

        print(f"simulation initial state: ")
        self.game.display_board(state)
        reward = None
        while not self.game.is_terminal(state):
            # choose an action to execute
            action = self.choose(state)
            # transition to next state
            (next_state, reward) = self.game.execute(state, action)
            # discount the rewards for my agent and opponent agent
            cumulative_reward[player] = cumulative_reward[player] + pow(self.game.gamma, depth) * reward[player] 
            cumulative_reward[opponent] = cumulative_reward[opponent] +pow(self.game.gamma, depth) * reward[opponent] 
            depth += 1
            state = next_state

            print(f"Next state: ")
            self.game.display_board(state)

        print(f"Terminal state reached! Simulation completed. Rewards: {reward}")

        return cumulative_reward    
    



In [209]:
# instantiate ganme object
game = TicTacToe()

# instantiate bandit
bandit = UCBBandit()

# instantiate MCTS solver
mcts_solver = MCTS(game, bandit)

In [210]:
# run mcts solver
root_node = mcts_solver.create_root_node()
node = mcts_solver.mcts(1, root_node)

root node id: 0
selected node id: 0
Child node id: 1
simulation initial state: 
X | - | -
- | - | -
- | - | -
Next state: 
X | - | -
O | - | -
- | - | -
Next state: 
X | - | -
O | - | -
- | X | -
Next state: 
X | - | -
O | - | O
- | X | -
Next state: 
X | - | -
O | X | O
- | X | -
Next state: 
X | - | -
O | X | O
- | X | O
Next state: 
X | - | -
O | X | O
X | X | O
Next state: 
X | - | O
O | X | O
X | X | O
Terminal state reached! Simulation completed. Rewards: {'X': 0, 'O': 1}
Simulation rewards: {'O': 1.0, 'X': 0.0}
Backpropagation. Node id: 1, visited: 1, accumulated reward: 1.0
Backpropagation. Node id: 0, visited: 1, accumulated reward: 0.0
selected node id: 0
Child node id: 2
simulation initial state: 
- | - | -
X | - | -
- | - | -
Next state: 
O | - | -
X | - | -
- | - | -
Next state: 
O | - | -
X | - | -
X | - | -
Next state: 
O | - | -
X | - | O
X | - | -
Next state: 
O | - | -
X | - | O
X | X | -
Next state: 
O | O | -
X | - | O
X | X | -
Next state: 
O | O | -
X | X | O
X | 

#### Now define an MCTS class for the tic tac toe game 

In [211]:
'''class MCTS:
    def __init__(self, game, exploration_weight=1.0):
        self.game = game

        # dictionary for storing value/total reward at each node
        self.Q = defaultdict(int)
        # dictionary for storing visit counts for each node
        self.visited = defaultdict(int)
        # dictionary for storing children of expanded nodes
        self.children = dict()
        self.exploration_weight = exploration_weight

                    
    # choose a move in the game 
    def choose_move(self, node):
        # make sure the state is not terminal
        if self.game.is_terminal(node.state):
            raise RuntimeError(f"choose called on terminal state: {node.state}")
        
        if node not in self.children:
            # randomly pick a successor state
            return random.choice(self.game.get_successors(node.state))
        
        def score(n):
            if self.N[n] == 0:
                return float("-inf") # unseen moves have lowest possible value
            else:
                return self.Q[n] / self.N[n] # average reward

        # return child with best value
        return max(self.children[node], key=score)


    # traverse down from a node and select a descendent node which has not been fully expanded yet
    def select(self, node):
        path = []
        while True:
            path.append(node)
            if node not in self.children or not self.children[node]:
                # node is either unexplored/unexpanded or terminal (i.e. node has no children)
                return path
            # get unexpanded nodes 
            unexplored = self.children[node] - self.children.keys()
            if unexplored:
                n = unexplored.pop()
                path.append(n)
                return path
            
            # if all children expanded, need to go another level deeper, select child using UCB bandit
            node = self.ucb_select(node)
            

    # select successor node using ucb bandit
    def ucb_select(self, node):
        log_N = math.log(self.N[node])
        def ucb(n):
            return self.Q[n]/self.N[n] + self.exploration_weight * math.sqrt(log_N/self.N[n])
        
        return max(self.children[node], key=ucb)


    # update the children dictionary with the children of node
    def expand(self, node):
        # node already expanded
        if node in self.children:
            return
        self.children[node] = self.game.get_successors(node.state)
    

    # run simulation until terminal state reached (actions are chosen randomly)
    def simulate(self, node):
        state = node.state
        invert_reward = True

        while True:

            if self.game.is_terminal(node.state):
                reward = self.game.get_reward(state)
                return 1- reward if invert_reward else reward 
            
            # randomly pick a successor
            node = random.choice(self.game.get_successors(node.state))
            invert_reward = not invert_reward



    # performs mcts from specified root node (timeout in seconds)
    def mcts(self, timeout=1, root_node=None):
        # create a root node if none provided
        if root_node == None:
            root_node = self.create_root_node()

        # start the timer
        start_time = time.time()
        current_time = time.time()
        num_iterations = 0

        # perform mcts iterations until timeout
        while current_time < start_time + timeout:

            # select node for expansion
            selected_node = root_node.select()
            
            if not (self.mdp.is_exit(selected_node.state)):

                # expand the selected node to generate a child node (if the node is not a terminal state)
                child = selected_node.expand()

                # run simulation to get a reward
                reward = self.simulate(child)

                # backpropagate the reward to root node
                selected_node.backpropagate(reward, child) 

            current_time = time.time()      
            num_iterations += 1    


        print(f"MCTS iterations: {num_iterations}")

        # update value function and display the table
        self.qfunction.update_V_from_Q()
        self.qfunction.display()  

        return root_node
 


    # backpropagate reward back to root node (recursively update all nodes along the path to the root)
    def backpropagate(self, G, child):
        # get the action which generated the child
        action = child.action

        # update number of times visited for both the state (white) node and state-action (black) node
        Node.visits[self.state] = Node.visits[self.state] + 1
        Node.visits[(self.state, action)] = Node.visits[(self.state, action)] + 1

        # get current Q value 
        qvalue = self.qfunction.evaluate(self.state, action)
        # compute update delta
        delta = (G - self.qfunction.evaluate(self.state, action)) / Node.visits[(self.state, action)]
        # update the Q value
        self.qfunction.update(self.state, action, qvalue, delta)

        # recursively backpropagate until root node is reached
        if self.parent != None:
            self.parent.backpropagate(self.reward + G, self)



'''

'class MCTS:\n    def __init__(self, game, exploration_weight=1.0):\n        self.game = game\n\n        # dictionary for storing value/total reward at each node\n        self.Q = defaultdict(int)\n        # dictionary for storing visit counts for each node\n        self.visited = defaultdict(int)\n        # dictionary for storing children of expanded nodes\n        self.children = dict()\n        self.exploration_weight = exploration_weight\n\n                    \n    # choose a move in the game \n    def choose_move(self, node):\n        # make sure the state is not terminal\n        if self.game.is_terminal(node.state):\n            raise RuntimeError(f"choose called on terminal state: {node.state}")\n        \n        if node not in self.children:\n            # randomly pick a successor state\n            return random.choice(self.game.get_successors(node.state))\n        \n        def score(n):\n            if self.N[n] == 0:\n                return float("-inf") # unseen moves 