In [1]:
from mdp import *
import math, time, random
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

#### Two-player Monte Carlo Tree Search: For turn based games where actions are alternatingly carried out by an agent and it's opponent. We will demonstrate monte carlo search using the Tic Tac Toe game

#### Define a class for a tic tac toe game

In [2]:
EMPTY = '-'
CIRCLE = 'O'
CROSS = 'X'

class TicTacToe:

    def __init__(self, gamma = 1.0):
        self.players = [CIRCLE, CROSS]
        self.CROSS = CROSS
        self.CIRCLE = CIRCLE
        self.gamma = 1.0

    # get list of players
    def get_players(self):
        return self.players
    
    
    def get_opponent(self, player):
        if player == CIRCLE:
            opponent = CROSS
        else:
            opponent = CIRCLE 
        return opponent       
    
    
    # get all valid actions from a state (a valid action is just an empty position on the board which can be marked by a player)
    def get_actions(self, state):
        board = state
        actions = []
        for i in range(len(board)):
            for j in range(len(board[i])):
                if board[i][j] == EMPTY:
                    actions += [(i,j)]
            
        return actions


    # make deep copy of a state
    def copy(self, state):
        next_state = []
        for i in range(len(state)):
            new_row = []
            for j in range(len(state[i])):
                new_row += [state[i][j]]
            next_state += [new_row]    
        
        return next_state
    

    # return the reward and state reulting from playing an action from a state
    def execute(self, state, action): 
        next_state = self.copy(state)
        next_state[action[0]][action[1]] = self.get_player_turn(state)
        
        return (next_state, self.get_reward(next_state))


    # return the state reulting from playing an action from a state
    def get_transition(self, state, action): 
        next_state = self.copy(state)
        next_state[action[0]][action[1]] = self.get_player_turn(state)
        
        return next_state


    # return all possible successors from the state
    def get_successors(self, state):
        if self.is_terminal(state):
            # if game is finished, there's no successor states
            return set()
        else:
            successors = []
            actions = self.get_actions(state)
            for action in actions:
                successors += [self.get_transition(state, action)]
            
            return successors


    # return the reward for transitioning into the state
    def get_reward(self, state):
        winner = self.get_winner(state)
        if winner == None:
            #return {CROSS: 0.5, CIRCLE: 0.5}
            return {CROSS: 0.0, CIRCLE: 0.0}
        
        elif winner == CROSS:
            return {CROSS: 1.0, CIRCLE: -1.0}

        elif winner == CIRCLE:
            return {CROSS: -1.0, CIRCLE: 1.0}


    # count how many empty positions on the board
    def count_empty(self, board):
        empty = 0
        for i in range(len(board)):
            for j in range(len(board[i])):
                if board[i][j] == EMPTY:
                    empty += 1

        return empty
    
    # return true iff state is a terminal state of the game
    def is_terminal(self, state):
        #if self.get_winner(state) != None:
        #    print(f"The winner is : {self.get_winner(state)}")

        return self.count_empty(state)==0 or self.get_winner(state) != None
    
    # return player who gets to select the action at current state, i.e. whose turn it is
    def get_player_turn(self, state):
        board = state
        # cross always starts the game, so crosses turn if there are an odd number of empty cells (9 cells total)
        empty = self.count_empty(board)
        if empty %2 == 0:
            return CIRCLE
        else:
            return CROSS


    # initial state of the game (empty board)
    def initial_state(self):
        board = [[EMPTY, EMPTY, EMPTY], 
                 [EMPTY, EMPTY, EMPTY], 
                 [EMPTY, EMPTY, EMPTY]]

        return board

    def get_winner(self, state):
        board = state

        # check columns
        for i in range(len(board)):
            circles = 0
            crosses = 0
            for j in range(len(board[i])):
                if board[i][j] == CIRCLE:
                    circles += 1
                elif board[i][j] == CROSS:
                    crosses += 1
            if crosses == len(board[i]):
                return CROSS      
            elif circles == len(board[i]):
                return CIRCLE 
       
        # check rows
        for j in range(len(board[0])):
            circles = 0
            crosses = 0
            for i in range(len(board)):
                if board[i][j] == CIRCLE:
                    circles += 1
                elif board[i][j] == CROSS:
                    crosses += 1
            if crosses == len(board):
                return CROSS      
            elif circles == len(board):
                return CIRCLE 

        # check top-left to bottom-right diagonal
        if board[0][0] == CIRCLE and board[1][1] == CIRCLE and board[2][2] == CIRCLE:
            return CIRCLE
        elif board[0][0] == CROSS and board[1][1] == CROSS and board[2][2] == CROSS:
            return CROSS

        # check top-right to bottom-left diagonal
        if board[0][2] == CIRCLE and board[1][1] == CIRCLE and board[2][0] == CIRCLE:
            return CIRCLE
        elif board[0][2] == CROSS and board[1][1] == CROSS and board[2][0] == CROSS:
            return CROSS

        # no winner
        return None
    

    # format the board into a string
    def display_board(self, state):
        for row in state:
            for i in range(len(row)):        
                if i < len(row)-1:
                    print(row[i], end=' | ')
                else:
                    print(row[i], end='')    
            print()


    '''
    def game_tree(self):
        return self.state_to_node(self.get_initial_state())
    

    def state_to_node(self, state):
        if self.is_terminal(state):
            return GameNode(state, None, self.get_reward(state))
        
        player = self.get_player_turn(state)
        children = dict()
        for action in self.get_actions(state):
            next_state = self.get_transition(state, action)
            child_node = self.state_to_node(next_state)
            children[action] = child_node
        
        return GameNode(state, player, None, children)    
    '''

#### Define a base class for game tree node

In [19]:
class GameNode:

    # static counter for node IDs
    next_node_id = 0

    # static dictionary for recording the number of times each node in the tree has been visited
    visits = defaultdict(lambda: 0)

    def __init__(self, game, state, parent, player_turn, bandit, value=0.0, reward=0.0, is_best_action=False, children=dict(), action=None):
        self.game = game
        self.state = state
        self.parent = parent            # pointer to parent node
        self.player_turn = player_turn  # marks which player's turn is on that state
        self.value = value
        self.reward = reward
        self.bandit = bandit
        self.is_best_action = is_best_action
        self.children = dict()
        self.action = action # action which generated the child
        self.accumulated_rewards = 0.0

        self.id = GameNode.next_node_id
        GameNode.next_node_id += 1
        

    # recursively traverse the tree and select a node that has not been fully expanded yet
    def select(self):

        if not self.is_fully_expanded() or self.game.is_terminal(self.state):
            # stop recursion when we've found node which hasn't been fully expanded or is terminal state
            return self
        else:
            # use the bandit to select the next child
            actions = list(self.children.keys())
            values = {action:self.children[action].get_value() for action in actions}
            parent_visits = GameNode.visits[self.id]
            children_visits = {action:GameNode.visits[self.children[action].id] for action in actions}
            action =  self.bandit.select(actions, values, parent_visits, children_visits)
            return self.children[action].select()  


    # checks if a node has been fully expanded
    def is_fully_expanded(self):    
        valid_actions = self.game.get_actions(self.state)
        if(len(valid_actions) == len(self.children)):
            return True
        else:
            return False


    # expand a node (only generate one child, not all children) if it's a non terminal-state
    def expand(self):
        if not self.game.is_terminal(self.state):

            #print(f"Expanding node id: {self.id}")

            # randomly select an unexpanded action to expand
            valid_actions = self.game.get_actions(self.state)
            #print(f"Available actions: {valid_actions}, already expanded children actions: {list(self.children.keys())}")

            unexplored_actions = list(valid_actions - self.children.keys())
            #print(f"Unexplored actions: {unexplored_actions}")

            #if len(unexplored_actions) > 0:
            action = random.choice(unexplored_actions)
            #else:
            #    action = random.choice(list(self.children.keys()))

            #print(f"randomly chosen action = {action}")
            # create a slot for that action in the children dictionary
            #self.children[action] = []

            # create a new child node and add to children dictionary
            (child_state, reward) = self.game.execute(self.state, action)            
            player_turn = self.game.get_player_turn(child_state)
            new_child = GameNode(self.game, child_state, self, player_turn, self.bandit, reward=reward, action=action)    
            
            self.children[action] = new_child  # each action leads to only one child 
            #print(f"generated child id: {new_child.id}, expanded children actions: {list(new_child.children.keys())}")
            
            #print(f"Genereated child node id: {self.children[action].id}, action: {action}, board state:")
            #self.game.display_board(self.children[action].state)
            
            return new_child                   
        
        # for terminal state, can't expand further
        return self
    

    # backpropagate reward back to root node (recursively update all nodes along the path to the root)
    def backpropagate(self, G):

        # update number of times visited for the state
        GameNode.visits[self.id] = GameNode.visits[self.id] + 1

        # apply discount to the backpropagated reward
        discounted_G = {}
        for player in G:
            discounted_G[player] = G[player] * self.game.gamma   

        # update the accumulated reward
        #simulation_reward = discounted_G[self.player_turn]  
        simulation_reward = discounted_G[self.game.get_opponent(self.player_turn)]  # this seems to make it work, even though it's the opposite of what it should be ... 
        self.accumulated_rewards += simulation_reward 

        #print(f"Backpropagation. Node id: {self.id}, visited: {GameNode.visits[self.id]}, accumulated reward: {self.accumulated_rewards}")

        # recursively backpropagate until root node is reached
        if self.parent != None:
            self.parent.backpropagate(discounted_G)


    # return value of this node
    def get_value(self):
        return self.accumulated_rewards / GameNode.visits[self.id] 



# upper confidence bounds (UCB) bandit
class UCBBandit:
    def __init__(self, exploration_param = 1.0):
        self.total = 0
        # dictionary for recording number of times each action has been chosen
        self.times_selected = {}
        self.exploration_param = exploration_param


    def select(self, actions, values, parent_visits, children_visits):

        max_actions = []
        max_value = float("-inf")
        for action in actions:
            # avoid division by zero
            if children_visits[action] == 0 or parent_visits == 0:
                value = float("inf")
            else: 
                value = values[action] + self.exploration_param * math.sqrt(2.0*math.log(parent_visits)/children_visits[action])
                
            if value > max_value:
                max_value = value
                max_actions = [action]
            elif value == max_value:
                max_actions.append(action)

        # if multiple actions with max value, pick one randomly
        selected_action = random.choice(max_actions)
        
        return selected_action    
    

class MCTS:
    def __init__(self, game, bandit):
        self.game = game
        self.bandit = bandit

    # performs mcts from specified root node (timeout in seconds)
    def mcts(self, timeout=1, initial_root_node=None, root_state=None, player_turn=None):
        # create a root node if none provided
        #if root_node == None:
        root_node = self.create_root_node(root_state=root_state, player_turn=player_turn)
        #root_node = GameNode(self.game, root_state, None, player_turn, self.bandit)  
        #root_node.children = dict()
        #print(f"root node id: {root_node.id}, children: {list(root_node.children.keys())}")
        #print(f"Initial game state:")
        #self.game.display_board(root_state)

        # start the timer
        start_time = time.time()
        current_time = time.time()
        num_iterations = 0

        # perform mcts iterations until timeout
        while current_time < start_time + timeout and num_iterations<400:

            #print(f"MCTS iteration # {num_iterations}")
            # select node for expansion
            selected_node = root_node.select()
            #print(f"selected node id: {selected_node.id}")

            if not (self.game.is_terminal(selected_node.state)):

                # expand the selected node to generate a child node (if the node is not a terminal state)
                child = selected_node.expand()
                #print(f"Child node id: {child.id}")

                # run simulation to get a reward
                reward = self.simulate(child)
                #print(f"Simulation rewards: {reward}")

                # backpropagate the reward to root node
                child.backpropagate(reward) 

            current_time = time.time()      
            num_iterations += 1    

        #print(f"MCTS iterations: {num_iterations}")
        # after finishing mcts iterations, find the best action
        best_action = self.get_best_action(root_node)

        return best_action

    # get the best action from the root node
    def get_best_action(self, root_node):

        #print(f"choosing best action.. Root node id: {root_node.id}, Root node state:")
        #self.game.display_board(root_node.state)
        #print(f"Available children/actions: {list(root_node.children.keys())}")

        best_action = None
        best_value = float("-inf")
        values = []
        actions = []
        for action in root_node.children:

            actions.append(action)
            #print(f"Child node id: {root_node.children[action].id}, action: {action}, board state:")
            #self.game.display_board(root_node.children[action].state)

            value = root_node.children[action].get_value()
            values.append(value)
            if value > best_value:
                best_value = value
                best_action = [action]
            elif value == best_value:
                best_action.append(action)

        ''' 
        # softmax
        temperature = 0.1 #0.9
        # exponentiate the values
        values_exp = []
        for value in values:
            if value > float("-inf"):
                values_exp.append(math.exp(value/temperature))
            else:
                values_exp.append(0.0)

        sum_exp = sum(values_exp)
        # if all scores are zero, then just pick a child at random
        if sum_exp == 0.0:
            return random.choice(root_node.children.keys())
        # compute softmax distribution    
        probabilities = [value_exp/sum_exp for value_exp in values_exp]   
        # add a tiny bit of random noise to these probabilities
        noise_magnitude = 0.00001 * (min(probabilities) + max (probabilities))
        probabilities = [(p + noise_magnitude*random.random()) for p in probabilities]
        # sample a child node according to this probability distribution
        action_index = random.choices(range(len(probabilities)), probabilities)[0]
        #print(f"probabilities: {probabilities}, child_index = {child_index}")
        child_move = actions[action_index]
        return child_move    
        '''

        # if multiple best actions, use random sampling
        return random.choice(best_action)    


    # create a root node representing the initial state
    def create_root_node(self, root_state=None, player_turn=None):
        if root_state == None:
            return GameNode(self.game, self.game.initial_state(), None, self.game.CROSS, self.bandit) # cross gets first turn
        else:
            return GameNode(self.game, root_state, None, player_turn, self.bandit)  


    # choose a random action for monte carlo simulation (can use a heuristic to choose actions instead of picking at random)
    def choose(self, state):
        actions = self.game.get_actions(state)
        
        # choose actions randomly
        next_action = random.choice(actions)
                
        return next_action
    
    
    # run simulation until terminal state reached (can be stopped after a fixed number of time steps instead of running until reaching terminal state)
    def simulate(self, node):
        state = node.state
        # the second entry in the reward vector is for opponent agent
        player = node.player_turn
        opponent = self.game.get_opponent(node.player_turn)
        cumulative_reward = {player:0.0, opponent:0.0}
        depth = 0

        #print(f"simulation initial state: ")
        #self.game.display_board(state)
        reward = None
        while not self.game.is_terminal(state):
            # choose an action to execute
            action = self.choose(state)
            # transition to next state
            (next_state, reward) = self.game.execute(state, action)
            # discount the rewards for my agent and opponent agent
            cumulative_reward[player] = cumulative_reward[player] + pow(self.game.gamma, depth) * reward[player] 
            cumulative_reward[opponent] = cumulative_reward[opponent] +pow(self.game.gamma, depth) * reward[opponent] 
            depth += 1
            state = next_state

            #print(f"Next state: ")
            #self.game.display_board(state)
        #print(f"Terminal state reached! Simulation completed. Rewards: {reward}")
        return cumulative_reward    
    

# self play function
def self_play(num_games, random_X=False, random_O=False):

    # instantiate ganme object
    game = TicTacToe()
    # instantiate bandit
    bandit = UCBBandit()
    # instantiate MCTS solver
    mcts_solver1 = MCTS(game, bandit)
    mcts_solver2 = MCTS(game, bandit)

    wins = {'X':0, 'O':0, 'Draw':0}
    
    with tqdm(total=num_games, ncols=80, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}') as pbar:

        # first turn goes to X
        player_turn = 'X'
        for i in range(num_games):
            steps = 0
            game_state = game.initial_state()
            #print(f"Starting game # {i}")
            while not game.is_terminal(game_state):
                #print(f"\nPlayer {game.get_player_turn(game_state)} making a move:\n")
                # choose best action

                if player_turn == 'X':
                    if random_X:
                        best_action = random.choice(game.get_actions(game_state))
                    else:
                        best_action = mcts_solver1.mcts(root_state=game_state, player_turn=player_turn)
                
                elif player_turn == 'O': 
                    if random_O:
                        best_action = random.choice(game.get_actions(game_state))
                    else:    
                        best_action = mcts_solver2.mcts(root_state=game_state, player_turn=player_turn)
                

                #print(f"best action: {best_action}")
                # execute action 
                (game_state, reward) = game.execute(game_state, best_action)
                # player for next turn
                player_turn = game.get_player_turn(game_state) 
                #game.display_board(game_state)
                steps += 1
            
            winner = game.get_winner(game_state)
            winner = 'Draw' if winner == None else winner
            wins[winner] = wins[winner] + 1
            pbar.update(1)

    print(f'Game# {i} completed. Winner: {winner}, Player X win rate: {wins["X"]/num_games}, Player O win rate: {wins["O"]/num_games}, Draw rate: {wins["Draw"]/num_games}')

In [20]:
self_play(500, random_X=True, random_O=True)

100%|██████████████████████████████████████████████████████████████████| 500/500

Game# 499 completed. Winner: X, Player X win rate: 0.562, Player O win rate: 0.296, Draw rate: 0.142





In [21]:
self_play(300, random_X=False, random_O=False)

100%|██████████████████████████████████████████████████████████████████| 300/300

Game# 299 completed. Winner: X, Player X win rate: 0.27, Player O win rate: 0.3466666666666667, Draw rate: 0.38333333333333336





In [22]:
self_play(300, random_X=True, random_O=False)

100%|██████████████████████████████████████████████████████████████████| 300/300

Game# 299 completed. Winner: Draw, Player X win rate: 0.25, Player O win rate: 0.58, Draw rate: 0.17





In [23]:
self_play(300, random_X=False, random_O=True)

100%|██████████████████████████████████████████████████████████████████| 300/300

Game# 299 completed. Winner: Draw, Player X win rate: 0.7233333333333334, Player O win rate: 0.14666666666666667, Draw rate: 0.13





In [24]:
self_play(3000, random_X=False, random_O=True)

100%|████████████████████████████████████████████████████████████████| 3000/3000

Game# 2999 completed. Winner: X, Player X win rate: 0.693, Player O win rate: 0.13666666666666666, Draw rate: 0.17033333333333334





In [231]:
'''# instantiate ganme object
game = TicTacToe()

# instantiate bandit
bandit = UCBBandit()

# instantiate MCTS solver
mcts_solver = MCTS(game, bandit)'''

'# instantiate ganme object\ngame = TicTacToe()\n\n# instantiate bandit\nbandit = UCBBandit()\n\n# instantiate MCTS solver\nmcts_solver = MCTS(game, bandit)'

In [225]:
'''# run mcts solver
root_node = mcts_solver.create_root_node()
best_action = mcts_solver.mcts(1, root_node)'''

'# run mcts solver\nroot_node = mcts_solver.create_root_node()\nbest_action = mcts_solver.mcts(1, root_node)'