# Connect 4

---

Author: S. Menary [sbmenary@gmail.com]

Date  : 2023-01-03, last edit 2023-01-11

Brief : Develop a simple Connect 4 game environment and implement a bot using Monte Carlo Tree Search (MCTS)

---

### Summary

- Connect 4 is a two-player, fully-observable, zero-sum game. 
- The game states may be represented as a tree sturcture, We can therefore implement a bot using tree-search algorithms. We choose Connect 4 because it is simple, and therefore provides a launch-pad for more complex games such as checkers or chess.
- Initially we implement vanilla MCTS with no machine learning. We expect this to be limited by (i) the stochastic rollout of the tree and (ii) the simplicity of the simulation policy.
- To introduce ML, we would perform alternate steps of MCTS evaluation and simulation policy improvement. In this way, the simulated games will _hopefully_ begin to approach "good play", and the final MCTS values will reflect the behaviour of good players.
- MCTS configuration:
    + Tree-traversal policy is:
        1. From the current node, uniformly-randomly select a non-expanded child if one is available
        2. Otherwise select child with highest UCB-1 score, traverse to this node and repeat
    + Resulting node is expanded by adding all possible children and selecting one by performing a uniformly-random action
    + Simulation policy is to select a uniformly-random action
- The UCB-1 score is designed to optimally balance exploration/exploitation for static multi-arm bandits. Strictly speaking, we are applying this in a non-stationary environment because the reward-distribution for each action changes according to the evolution of the down-stream tree. This makes UCB-1 theoretically sub-optimal. However, it is often used nonetheless.
- When playing an actual move (i.e. inference time), greedily select the action with the max average score from its MCTS visits (do not use UCB-1 since we are no longer exploring).

Observations:
- Strength of decision-making depends on how many iterations of MCTS we perform:
    1. When tree is shallow, we effectively assume that future play is random, which means we will choose options with the greatest number of permutations of winning. We therefore may neglect to defend against an imminent loss, favouring a different move with many win permutations (bad behaviour).
    2. When tree is deep and UCB1 score converges towards true means, at least for the best moves, then we effectively assume that future play is optimal. As play-count goes to infinity, our scores become unbiased.
    3. For finite but sufficient run-time, we assume optimal play, but using mean scores that are biased by the fact that our early simulations used random play instead of optimal play.
- This explains why even random simulation MCTS is pretty good - we end up doing most of our simulations with pretty effective play, at least for the next few moves where our tree is sufficiently grown.


## Imports

In [4]:
###
###  Required imports
###  - all imports should be placed here
###


##  Python core libs
import sys, time
from enum import IntEnum
from __future__ import annotations

##  PyPI libs
import numpy as np


In [5]:
###
###  Print version for reproducibility
###

print(f"Python version is {sys.version}")
print(f"Numpy  version is {np.__version__}")

Python version is 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ]
Numpy  version is 1.23.2


## Prelims

Define some useful preliminary objects to help us later on.

In [6]:
###
###   BinaryPlayer class definition
###   - enumeration for the players in a two-player game
###   - enum names are X, 0, NONE
###   - enum values are +1, -1, 0 respectively
###


class BinaryPlayer(IntEnum):
    """
    An enumeration for the player in a two-player game
    Options are: X=1, 0=-1, NONE=0
    """
    NONE = 0
    X    = 1
    O    = -1
    def label(self) :
        """
        Returns a single character representation of the player.
            > X    = 'X'
            > 0    = '0'
            > NONE = '.'
        """
        if self.value == 0  : return '.'
        if self.value == 1  : return 'X'
        if self.value == -1 : return '0'
        raise NotImplementedError()
    

In [7]:
###
###   GameResult class definition
###   - enumeration for the result of a two-player game
###   - enum names are X, O, DRAW, NONE
###


class GameResult(IntEnum):
    """
    An enumeration for the result of a two-player game
    NONE=0 enforced to allow "is GameResult" statements to return True only if a game has ended
    """
    NONE = 0
    DRAW = 1
    X    = 2
    O    = 3
    
    @classmethod
    def from_player(cls, player:BinaryPlayer, none_player_means_draw:bool=True) -> GameResult :
        """
        Instantiate game result from a BinaryPlayer instance.
        
        Inputs:
        
            >  player, BinaryPlayer
               instance of BinaryPlayer to be converted to GameResult
               
            >  none_player_means_draw, bool, default=True
               if True then interpret BinaryPlayer.NONE as GameResult.DRAW, otherwise GameResult.NONE
        """
        if player == BinaryPlayer.X    : return GameResult.X
        if player == BinaryPlayer.O    : return GameResult.O
        if player == BinaryPlayer.NONE : 
            if none_player_means_draw : return GameResult.DRAW
            return GameResult.NONE
        raise NotImplementedError(f"Could not cast {player} into a GameResult")
        
    @classmethod
    def from_piece_value(cls, value:int, none_player_means_draw:bool=True) -> GameResult :
        """
        Instantiate game result from a BinaryPlayer value. 
        
        Inputs:
        
            >  value, int
               value to be converted to BinaryPlayer instance, and then to GameResult
               
            >  none_player_means_draw, bool, default=True
               if True then interpret BinaryPlayer.NONE as GameResult.DRAW, otherwise GameResult.NONE
        """
        return cls.from_player(BinaryPlayer(value), none_player_means_draw=none_player_means_draw)
    
    def get_game_score_for_player(self, player:BinaryPlayer) -> float :
        """
        Return the game score for the given player.
        Score is 0 for a DRAW or NONE, +1 if the GameResult matches the BinaryPlayer, -1 otherwise
        
        Inputs:
        
            >  player, BinaryPlayer
               enum of the player to whom the score applies
        """
        
        ##  Require player to be resolved
        if player not in [BinaryPlayer.X, BinaryPlayer.O] :
            raise NotImplementedError(f"Cannot resolve score for player {player.name}")
        
        ##  Resolve NONE result
        if self == GameResult.NONE :
            return 0.
        
        ##  Resolve DRAW result
        if self == GameResult.DRAW :
            return 0.
        
        ##  Resolve X WIN result for X PLAYER
        if self == GameResult.X and player == BinaryPlayer.X :
            return 1.
        
        ##  Resolve O WIN result for O PLAYER
        if self == GameResult.O and player == BinaryPlayer.O :
            return 1.
        
        ##  If here then the specifid player must have lost
        return -1.
        
    

In [8]:
###
###   DebugLevel class definition
###   - enumeration for the verbosity level when debugging a function
###   - enum names are MUTE=0, LOW=1, MEDIUM=2, HIGH=3, ALL=4
###   - default is MUTE, which means that no debug information is printed unless requested
###   - deriving from IntEnum allows labelled int operations like "if lvl >= VerboseLevel.MEDIUM :"
###   - used as improvement over print statements for debugging
###  


class DebugLevel(IntEnum):
    """
    An enumeration for the verbosity level of debug statements
    Options are: MUTE=0, LOW=1, MEDIUM=2, HIGH=3, ALL=4
    Use of IntEnum allows labelled int operations like "if lvl >= DebugLevel.MEDIUM :"
    """
    MUTE   = 0
    LOW    = 1
    MEDIUM = 2
    HIGH   = 3
    ALL    = 4
    

In [9]:
###   
###   Provide method to resolve debug statements
###
    

def debug(min_lvl:VerboseLevel, debug_lvl:VerboseLevel, message:str) :
    """
    Print message only if debug_lvl >= min_lvl
    """
    
    ##  If debug_lvl >= min_lvl then print message and return True
    if debug_lvl >= min_lvl :
        print(message)
        return True
    
    ##  Other wise no message printed and return False
    return False
    

##  Connect 4 environment

Define a simple object to manipulate and print a game of Connect 4

In [10]:
###   
###   Define environment for Connect 4 game board.
###

class GameBoard :
    
    def __init__(self, 
                 horizontal_size: int=7, 
                 vertical_size  : int=6,
                 target_length  : int=4) -> None :
        """
        Class GameBoard
        
        Stores the current state of a game. Allows actions to be played, modifying the internal state. Provides
        simple ASCII visualisation of the board. Internal state stored as numpy array of int objects, using
        BinaryPlayer class to define the IntEnum values.
        
        Inputs:
        
            >  horizontal_size, int, default=7
               horizontal size of the game board
               
            >  vertical_size, int, default=6
               vertical size of the game board
               
            >  target_length, int, default=4
               number of connected pieces required to win the game
        """
        
        ##  Make sure inputs are correctly typed
        if type(horizontal_size) is not int : 
            raise TypeError(f"Expected argument horizontal_size of type int but {type(horizontal_size)} provided")
        if type(vertical_size) is not int : 
            raise TypeError(f"Expected argument vertical_size of type int but {type(vertical_size)} provided")
        if type(target_length) is not int : 
            raise TypeError(f"Expected argument vertical_size of type int but {type(target_length)} provided")
            
        ##  Store game configuration
        self.horizontal_size   = horizontal_size
        self.vertical_size     = vertical_size
        self.target_length     = target_length
        
        ##   Initialise game board
        self.board = np.full(shape      = (horizontal_size, vertical_size), 
                             fill_value = BinaryPlayer.NONE.value, 
                             dtype      = np.int8)
        self.X_to_play       = True
        self.applied_actions = []
        
    
    def __eq__(self, other:GameBoard) -> bool :
        """
        Overload comparison operator comparing two GameBoard objects.
        """
        
        ##  Check games are equally configured
        if self.horizontal_size != other.horizontal_size : return False
        if self.vertical_size   != other.vertical_size   : return False
        if self.target_length   != other.target_length   : return False
        
        ##  Check it's the same player's turn in both games
        if self.X_to_play != other.X_to_play : return False
        
        ##  Check game boards are identical
        if (self.board != other.board).any() : return False
        
        ##  If here then games are identical!
        return True
        
        
    def __str__(self) -> str :
        """
        Return a string representation of the game board.
        String is a simple ASCII picture of the game.
        """
        ##  Populate multi-line string with the following steps:
        ##  1. Create empty string to iteratively add lines to
        ##  2. Add upper boundary line
        ##  3. Add graphic for every row in the game board, with the (0,0) located at bottom-left
        ##     - get piece ASCII character representation using label() method for each BinaryPlayer(p) token 
        ##  4. Add middle boundary
        ##  5. Add numerical label for each column
        ##  6. Add lower boundary
        ##  7. Add result label
        ret = ""
        ret += "+-"   +    '-+-'.join(["-" for p in range(self.horizontal_size)]) + "-+"
        ret += "\n| " + ' |\n| '.join([' | '.join([BinaryPlayer(p).label() for p in row]) for row in self.board.T[::-1]]) + " |"
        ret += "\n+-" +    '-+-'.join(["-" for p in range(self.horizontal_size)]) + "-+"
        ret += "\n| " +     '| '.join([f"{p}".ljust(2) for p in range(self.horizontal_size)]) + "|"
        ret += "\n+-" +    '-+-'.join(["-" for p in range(self.horizontal_size)]) + "-+"
        ret += f"\nGame result is: {self.get_result().name}"
        ##  Return complete multi-line str
        return ret
    
    
    def apply_action(self, column_idx: int) -> None :
        """
        Play a new piece at the specified column index. Player is determined by internal state, 
        which keeps track of whose turn it is. If column is full then throw error.
        """
        
        ##  Check that input is correctly typed
        if not np.issubdtype(type(column_idx), int) :
            raise TypeError(f"column_idx of type {type(column_idx)} where int expected")
            
        ##  Check that game has not already finished
        if self.get_result() :
            raise RuntimeError("Cannot play new moves because the game is a terminal state")
            
        ##  Get column from internal numpy array
        column = self.board[column_idx]
        
        ##  Find the smallest unoccupied row index
        row_idx = 0
        while column[row_idx] :
            row_idx += 1
            
        ##  Add piece to the specified (row, column) indices
        column[row_idx] = BinaryPlayer.X.value if self.X_to_play else BinaryPlayer.O.value
        self.applied_actions.append((column_idx, row_idx, column[row_idx]))
        
        ##  Update whose turn it is to play
        self.X_to_play = not self.X_to_play
        
        
    def deep_copy(self) -> GameBoard :
        """
        Create a deep copy of the game board.
        """
        
        ##  Initialise a new GameBoard object with the same configuration
        ##    then perform a deep copy of the internal numpy array, set the X_to_play flag, and return
        new_gameboard                 = GameBoard(self.horizontal_size, self.vertical_size, self.target_length)
        new_gameboard.board           = self.board.copy()
        new_gameboard.X_to_play       = self.X_to_play
        new_gameboard.applied_actions = self.applied_actions
        return new_gameboard
        
        
    @classmethod
    def from_gameboard(cls, gameboard) -> GameBoard :
        """
        Create a new GameBoard object as a deep copy of one provided.
        """
        return gameboard.deep_copy()
            
    
    def get_available_actions(self) -> list[int] :
        """
        Get list of available actions. Action corresponds to a column index. Action is considered 
        available if the game has not ended and column is not already full.
        """
        
        ##  Check whether game has ended
        if self.get_result() :
            return []
        
        ##  Otherwise return list of all unfilled column indices
        return self.get_unfilled_columns()
            
    
    def get_unfilled_columns(self) -> list[int] :
        """
        Get list of unfilled columns
        """
        return [a for a in range(self.horizontal_size) if self.board[a,-1] == 0]
    
    
    def get_result(self) -> GameResult :
        """
        Check whether the game board has reached a terminal state.
        """
        
        ##  Check for vertical win condition
        for column in self.board :
            last_piece, counter, row_idx = 999, 1, 0
            while last_piece and row_idx < self.vertical_size :
                piece = column[row_idx]
                if piece == last_piece :
                    counter += 1
                else :
                    counter = 1
                if counter == self.target_length :
                    return GameResult.from_piece_value(BinaryPlayer(piece))
                row_idx += 1
                last_piece = piece
                
        ##  Check for horizontal win condition
        for row in self.board.T :
            last_piece, counter = 999, 1
            for col_idx in range(self.horizontal_size) :
                piece = row[col_idx]
                if piece == last_piece :
                    counter += 1
                else :
                    counter = 1
                if piece != 0 and counter == self.target_length :
                    return GameResult.from_piece_value(piece)
                last_piece = piece
                
        ##  Check for diagonal win condition
        for col_idx in range(self.horizontal_size) :
            for row_idx in range(self.vertical_size) :
                piece = self.board[col_idx, row_idx]
                if not piece : continue
                for col_dir, row_dir in [[1,1], [1,-1], [-1,1], [-1,-1]] :
                    is_winning_sequence = True
                    for seq_idx in range(self.target_length) :
                        check_col_idx, check_row_idx = col_idx + col_dir*seq_idx, row_idx + row_dir*seq_idx
                        if (check_col_idx < 0 or check_col_idx >= self.horizontal_size or 
                            check_row_idx < 0 or check_row_idx >= self.vertical_size or 
                            self.board[check_col_idx, check_row_idx] != piece) :
                            is_winning_sequence = False
                            break
                    if is_winning_sequence :
                        return GameResult.from_piece_value(piece)
                    
        ##  Check for draw
        if len(self.get_unfilled_columns()) == 0 :
            return GameResult.DRAW
        
        ##  If here then game has not finished
        return GameResult.NONE
    
    def undo_action(self) -> None :
        """
        Undo the most recent action stored in the internal record. 
        A full game record is maintained, allowing many moves to be undone.
        """
        
        ##  Check that an action exists
        if len(self.applied_actions) == 0 :
            raise RuntimeError(f"No actions to undo.")
            
        ##  Remove and return the last item from the record of applied actions
        column, row, _ = self.applied_actions.pop()
        
        ##  Return the specific index to its 0 state, indicating that no piece is present
        self.board[column, row] = 0
        

##  MCTS

Implement Node class to handle the tree search.

In [11]:
class Node :
    
    def __init__(self, game_board:GameBoard, parent:Node=None, UCB_c:float=2., shallow_copy_board:bool=False, 
                 label=None) :
        """
        Class Node
        
        - Used as part of MCTS algorithm. 
        - Stores total score and number of visits
        - Stores a list of children and a reference to the parent node
        - Provides methods for node selection, expansion, simulation, backpropagation
        
        Inputs:
        
            > game_board, GameBoard
              state of the game at this node
              
            > parent, None, default=None
              reference to the parent node, only equals None if this is a root node
              
            > UCB_c, float, default=2.
              hyper-parameter controlling strength of exploration vs exploitation in UCB algorithm
              
            > shallow_copy_board, bool, default=False
              whether to only create a shallow copy of the game board - caution: improves memory efficiency 
              but may lead to undefined behaviour if either one of the referenced objects is updated
              
            > label, str, default=None
              label for the node, used when generating summary strings
        """
                
        self.game_board  = game_board.deep_copy()
        self.actions     = game_board.get_available_actions()
        self.player      = BinaryPlayer.X if game_board.X_to_play else BinaryPlayer.O
        self.is_terminal = True if len(self.actions) == 0 else False
        self.children    = [None for a_idx in range(len(self.actions))]
        self.parent      = parent
        self.total_score = 0
        self.num_visits  = 0
        self.UCB_c       = UCB_c
        self.label       = label
        
        
    def __str__(self) -> str :
        """
        Return a string representation of the current node.
        """
        
        ##  Figure out parent / children info
        is_root              = False if self.parent else True
        num_children         = len(self.children)
        num_visited_children = len([c for c in self.children if c])
        
        ##  Begin str with node label if one provided
        ret  = f"[{self.label}] " if self.label else ""
        
        ##  Add some node information
        ret += f"N={self.num_visits}, T={self.total_score}, is_root={is_root}, is_leaf={self.is_terminal}"
        ret += f", num_children={num_children}, num_visited_children={num_visited_children}"
        
        ##  Return str
        return ret
        
        
    def get_best_action(self) -> int :
        """
        Return the optimal action based on the currently stored values.
        """
        
        ##  If this is a terminal node then no actions available
        if self.is_terminal : 
            return None
        
        ##  Find the index of the best child node, and return the corresponding action
        ##  - if no actions evaluated then argmax will return first action by default
        child_scores = [c.get_score() if c else -np.inf for c in self.children]
        best_a_idx   = np.argmax(child_scores)
        return self.actions[best_a_idx]
        
        
    def get_score(self) -> float :
        """
        Return node score.
        """
        
        ##  If node has not been visited then return -inf
        if self.num_visits == 0 :
            return -np.inf
        
        ##  Otherwise return mean reward per visit
        return self.total_score / self.num_visits
        
        
    def get_UCB_score(self) -> float :
        """
        Returns the UCB score of this node
        """
                
        ##  If node is un-visited then the UCB score is infinite
        if self.UCB_c != 0 and self.num_visits == 0 :
            return np.inf
        
        ##  If node has no parent then no UCB score exists
        if not self.parent :
            return np.nan
        
        ##  Calculate mean score from past games
        mean_score = self.total_score / self.num_visits
        
        ##  Otherwise calculate UCB score
        return mean_score + self.UCB_c * np.sqrt(np.log(self.parent.num_visits) / self.num_visits)
        
        
    def select_and_expand(self, recurse:bool=False, debug_lvl:DebugLevel=DebugLevel.MUTE) -> Node :
        """
        Select from node children according to tree traversal policy. If next state is None then create a 
        new child and return this.
        
        Inputs:
        
            > recurse, bool, default=False
              whether to recursively iterate through tree until a new leaf node is found.
              
            > debug_lvl, DebugLevel, default=MUTE
              level at which to print debug statements to help understand algorithm behaviour.
        """
        
        ##  If leaf node then nothing to expand
        if self.is_terminal :
            debug(DebugLevel.MEDIUM, debug_lvl, f"Leaf node found")
            return self
                
        ##  Uniformly randomly expand from un-visited children
        unvisited_children = [c_idx for c_idx, c in enumerate(self.children) if not c]
        if len(unvisited_children) > 0 :
            a_idx = np.random.choice(unvisited_children)
            new_game_board = self.game_board.deep_copy()
            node_label = f"{'X' if self.game_board.X_to_play else 'O'}:{self.actions[a_idx]}"
            debug(DebugLevel.MEDIUM, debug_lvl, f"Select unvisited action {node_label}")
            new_game_board.apply_action(self.actions[a_idx])
            self.children[a_idx] = Node(new_game_board, parent=self, UCB_c=self.UCB_c, shallow_copy_board=True, label=node_label)
            return self.children[a_idx]
        
        ##  Otherwise best child is that with highest UCB score
        a_idx = np.argmax([c.get_UCB_score() for c in self.children])
        best_child = self.children[a_idx]
        debug(DebugLevel.MEDIUM, debug_lvl, f"Select known action {'X' if self.game_board.X_to_play else 'O'}:{self.actions[a_idx]}")
        
        ##  If recurse then also select_and_expand from the child node
        if recurse :
            debug(DebugLevel.MEDIUM, debug_lvl, "... iterating to next level ...")
            return best_child.select_and_expand(recurse=recurse, debug_lvl=debug_lvl)
        
        ##  Otherwise return selected child
        return best_child
    
    
    def simulate(self, max_turns:int=-1, debug_lvl:DebugLevel=DebugLevel.MUTE) -> GameResult :
        """
        Simulate a game starting from this node.
        Assumes that both players act according to a uniform-random policy.
        
        Inputs:
        
            > max_turns, int, default=-1
              if positive then determines how many moves to play before declaring a drawn game
              
            > debug_lvl, DebugLevel, default=MUTE
              level at which to print debug statements to help understand algorithm behaviour.
              
        Returns:
        
            > float
              the score of the simulation, defined as +1 for a win, -1 for a loss, 0 for a draw
        """
        
        ##  Check if game has already been won
        ##  - if so then return score
        ##  - score is -1 if target player has lost, +1 if they've won, and 0 for a draw
        result = self.game_board.get_result()
        if result :
            debug(DebugLevel.MEDIUM, debug_lvl, f"Leaf node found with result {result.name}")
            return result
                
        ##  Create copy of game board to play simulation
        simulated_game = self.game_board.deep_copy()
        
        ##  Keep playing moves until one of terminating conditions is reached:
        ##  1. game is won by a player
        ##  2. no further moves are possible, game is considered a draw
        ##  3. maximum move limit is reached, game is considered a draw
        turn_idx, is_terminal, result = 0, False, GameResult.NONE
        trajectory = []
        while not is_terminal :
            turn_idx += 1
            action = np.random.choice(simulated_game.get_unfilled_columns())
            trajectory.append(f"{'X' if simulated_game.X_to_play else 'O'}:{action}")
            simulated_game.apply_action(action)
            result = simulated_game.get_result()
            if result :
                is_terminal = True
                  
        ##  Debug trajectory
        debug(DebugLevel.MEDIUM, debug_lvl, f"Simulation ended with result {result.name}")
        debug(DebugLevel.HIGH  , debug_lvl, f"Simulated trajectory was: {' '.join(trajectory)}")
                                
        ##  Return score
        return result
    
    
    def simulate_and_backprop(self, max_turns:int=-1, 
                              debug_lvl:DebugLevel=DebugLevel.MUTE) -> None :
        """
        Simulate a game starting from this node. Backpropagate the resulting score up the whole tree.
        
        Inputs:
        
            > max_turns, int, default=-1
              if positive then determines how many moves to play before declaring a drawn game
              
            > debug_lvl, DebugLevel, default=MUTE
              level at which to print debug statements to help understand algorithm behaviour.
        """
        
        ##  Simulated game and obtain instance of GameResult
        result = self.simulate(max_turns=max_turns, debug_lvl=debug_lvl)
        
        ##  Update this node and backprop up the tree
        self.update_and_backprop(result, debug_lvl=debug_lvl)
        
        
    def tree_summary(self, indent_level:int=0) :
        """
        Return a multi-line str summarising every node in the tree.
        """
        
        ##  Summarise this node
        ret = ("     "*indent_level +
               f"> [{indent_level}{f': {self.label}' if self.label else ''}] N={self.num_visits}, T={self.total_score}, " +
               f"UCB={self.get_UCB_score():.3f}, s={self.get_score():.3f}")
        
        ##  Recursively add the summary of each child node, iterating the indent level to reflect tree depth
        for a, c in zip(self.actions, self.children) :
            if c :
                ret += f"\n{c.tree_summary(indent_level+1)}"
            else :
                ret += "\n" + "     "*(indent_level+1) + "> None"
                
        ##  Return
        return ret
        
        
    def update(self, result:GameResult, debug_lvl:DebugLevel=DebugLevel.MUTE) -> None :
        """
        Update the score and visit counts for this node.
        """
        
        ##  Resolve score for this node given the game result
        ##  - score is from the viewpoint of the parent, since this is the one deciding whether to come here!
        ##  - if no parent exists then this is a ROOT node, and we assign a score of 0. by default
        if self.parent :
            score = result.get_game_score_for_player(self.parent.player)
        else :
            score = 0.
        debug(DebugLevel.MEDIUM, debug_lvl, 
              f"Node {self.label} with parent={self.parent.player.name if self.parent else 'NONE'}, N={self.num_visits}, T={self.total_score:.2f} receiving score {score:.2f} for game ending in result {result.name}")
        
        ##  Update total score and number of visits for this node
        self.total_score += score
        self.num_visits  += 1
        
        
    def update_and_backprop(self, result:GameResult, 
                            debug_lvl:DebugLevel=DebugLevel.MUTE) -> None :
        """
        Update the score and visit counts for this node and backprop to all parents.
        """
        
        ##  Update this node
        self.update(result, debug_lvl=debug_lvl)
        
        ##  Recursively update all parent nodes
        if self.parent :
            self.parent.update_and_backprop(result, debug_lvl=debug_lvl)
        

In [12]:
###
###  Methods for MCTS
###  - Implement methods which interact with the Node class to perform a number of MCTS iterations
###


def one_step_MCTS(root_node, max_turns=-1, debug_lvl=DebugLevel.MUTE) :
    """
    Perform a single MCTS iteration on the root_node provided.
    """
    
    ##  Select and expand from the root node
    leaf_node = root_node.select_and_expand(recurse=True, debug_lvl=debug_lvl)
    
    ##  Simulate and backprop from the selected child
    leaf_node.simulate_and_backprop(max_turns=max_turns, debug_lvl=debug_lvl)
    
    ##  Print updated tree if debug level is HIGH
    debug(DebugLevel.HIGH, debug_lvl, f"Updated tree is:\n{root_node.tree_summary()}")
    
    
def multi_step_MCTS(root_node, num_steps, max_turns=-1, debug_lvl=DebugLevel.MUTE) :
    """
    Perform a many MCTS iterations on the root_node provided.
    """
    
    ##  Call one_step_MCTS a number of times equal to num_steps
    for idx in range(num_steps) :
        debug(DebugLevel.MEDIUM, debug_lvl, f"Running MCTS step {idx}")
        one_step_MCTS(root_node, max_turns=max_turns, debug_lvl=debug_lvl)
        debug(DebugLevel.MEDIUM, debug_lvl, f"")
        
        
def timed_MCTS(root_node, duration, max_turns=-1, debug_lvl=DebugLevel.MUTE) :
    """
    Perform a MCTS iterations on the root_node until duration (in seconds) has elapsed.
    After this time, MCTS will finish its current iteration, so total execution time is > duration.
    """
    
    ##  Keep calling one_step_MCTS until required duration has elapsed
    start_time   = time.time()
    current_time = start_time
    num_itr = 0
    while current_time - start_time < duration :
        one_step_MCTS(root_node, max_turns=max_turns, debug_lvl=debug_lvl)
        current_time = time.time()
        num_itr += 1
    return num_itr


def get_bot_action(game_board, duration=1, max_turns=-1, debug_lvl=DebugLevel.MUTE) :
    """
    Create a root_node from the current game state, and perform a timed MCTS to choose a move.
    """
    
    ##  Create root node from current game board
    root_node = Node(game_board)
    
    ##  Call timed_MCTS to update tree values 
    num_itr       = timed_MCTS(root_node, duration=duration, max_turns=max_turns, debug_lvl=debug_lvl)
    chosen_action = root_node.get_best_action()
    
    ##  Print debug info
    debug(DebugLevel.HIGH, debug_lvl, 
          root_node.tree_summary())
    debug(DebugLevel.LOW, debug_lvl, 
          "Action values are:  " + " ".join([f"{x.get_score():.2f}".ljust(6) if x else "N/A   " for x in root_node.children]))
    debug(DebugLevel.LOW, debug_lvl, 
          "Visit counts are:   " + " ".join([f"{x.num_visits}".ljust(6) if x else "N/A   " for x in root_node.children]))
    debug(DebugLevel.LOW, debug_lvl, 
          f"Selecting action {chosen_action}")
        
    ##  Return best action from tree evaluation, and the number of MCTS iterations executed
    return chosen_action, root_node, num_itr
    
    
def take_move(game_board, my_action, duration=1, max_turns=-1, debug_lvl=DebugLevel.MUTE) :
    """
    Apply a human move.
    Print the game board.
    Use MCTS to find a responding bot move.
    Apply the bot move.
    Print the game board.
    """
    
    ##  Apply the human move.
    print(f"Human takes move {my_action}")
    game_board.apply_action(my_action)
    print(game_board)
    print()
    
    ##  If game has ended then return
    if game_board.get_result() :
        return
    
    ##  Use timed MCTS to obtain a bot action
    bot_action, _, num_itr = get_bot_action(game_board, 
                                            duration=duration, 
                                            debug_lvl=debug_lvl)
    
    ##  Apply the bot move
    print(f"Bot takes move {bot_action} ({num_itr} iterations)")
    game_board.apply_action(bot_action)
    print(game_board)


##  Test MCTS

In [13]:
###
###  Setup a small game
###  - 4x4 grid
###  - line of 3 needed to win
###

##  Create game board
game_board = GameBoard(4, 4, 3)

##  Show initial game board
print(game_board)


+---+---+---+---+
| . | . | . | . |
| . | . | . | . |
| . | . | . | . |
| . | . | . | . |
+---+---+---+---+
| 0 | 1 | 2 | 3 |
+---+---+---+---+
Game result is: NONE


In [14]:
###
###  Play a few initial moves
###  - transitions into a ciritical state where O player needs to be careful not to 
###    blunder a win for X
###

##  Play moves
game_board.apply_action(1)
game_board.apply_action(2)
game_board.apply_action(1)

##  Show updated game state
print(game_board)


+---+---+---+---+
| . | . | . | . |
| . | . | . | . |
| . | X | . | . |
| . | X | 0 | . |
+---+---+---+---+
| 0 | 1 | 2 | 3 |
+---+---+---+---+
Game result is: NONE


In [15]:
###
###  Perform a few MCTS steps
###  - transitions into a ciritical state where O player needs to be careful not to 
###    blunder a win for X
###

##  Create a root node at the current game state
root_node = Node(game_board, label="ROOT")

##  Print the initial value tree (should be a ROOT node with no children)
print("Initial tree:")
print(root_node.tree_summary())
print()

##  Perform several MCTS steps with a HIGH debug level
multi_step_MCTS(root_node, num_steps=10, max_turns=-1, debug_lvl=DebugLevel.HIGH)

##  Print the updated value tree 
print("Updated tree:")
print(root_node.tree_summary())
print()


Initial tree:
> [0: ROOT] N=0, T=0, UCB=inf, s=-inf
     > None
     > None
     > None
     > None

Running MCTS step 0
Select unvisited action O:0
Simulation ended with result X
Simulated trajectory was: X:1
Node O:0 with parent=O, N=0, T=0.00 receiving score -1.00 for game ending in result X
Node ROOT with parent=NONE, N=0, T=0.00 receiving score 0.00 for game ending in result X
Updated tree is:
> [0: ROOT] N=1, T=0.0, UCB=nan, s=0.000
     > [1: O:0] N=1, T=-1.0, UCB=-1.000, s=-1.000
          > None
          > None
          > None
          > None
     > None
     > None
     > None

Running MCTS step 1
Select unvisited action O:1
Simulation ended with result X
Simulated trajectory was: X:2 O:3 X:3
Node O:1 with parent=O, N=0, T=0.00 receiving score -1.00 for game ending in result X
Node ROOT with parent=NONE, N=1, T=0.00 receiving score 0.00 for game ending in result X
Updated tree is:
> [0: ROOT] N=2, T=0.0, UCB=nan, s=0.000
     > [1: O:0] N=1, T=-1.0, UCB=0.665, s=-1.000
   

In [16]:
###
###  Use MCTS to play a move
###

##  Use MCTS to search for an optimal action
bot_action, _, num_itr = get_bot_action(game_board, 
                                        duration=1, 
                                        debug_lvl=DebugLevel.LOW)
print(f"Bot chooses action {bot_action} after {num_itr} MCTS iterations")

##  Play bot move
game_board.apply_action(bot_action)

##  Show updated game state
print(game_board)


Action values are:  -0.82  -0.63  -0.80  -0.80 
Visit counts are:   147    489    163    167   
Selecting action 1
Bot chooses action 1 after 966 MCTS iterations
+---+---+---+---+
| . | . | . | . |
| . | 0 | . | . |
| . | X | . | . |
| . | X | 0 | . |
+---+---+---+---+
| 0 | 1 | 2 | 3 |
+---+---+---+---+
Game result is: NONE


## Connect 4

Play a game of connect 4 against our bot!

Just add new calls to `take_move(game_board, column_index, duration)` to play a move in column `column_index`. Turning up the `duration` parameter will improve the bot by allowing it to search for longer.

In [17]:
##  Create a new game

game_board = GameBoard()
print(game_board)


+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE


In [18]:
##  Play a move in column index 3

take_move(game_board, 3, duration=5, max_turns=30)


Human takes move 3
+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | X | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE

Bot takes move 1 (972 iterations)
+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | 0 | . | X | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE


---

... and so on, we keep calling `take_move` until the game is complete!

---

In [19]:
##  Play a move in column index 3

take_move(game_board, 2, duration=5, max_turns=30, debug_lvl=DebugLevel.LOW)


Human takes move 2
+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | 0 | X | X | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE

Action values are:  -0.36  -0.31  -0.12  -0.14  -0.16  -0.11  -0.48 
Visit counts are:   78     90     214    192    164    208    54    
Selecting action 5
Bot takes move 5 (1000 iterations)
+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | 0 | X | X | . | 0 | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE
