# Connect 4

---

Author: S. Menary [sbmenary@gmail.com]

Date  : 2023-01-03, last edit 2023-01-14

Brief : Develop a simple Connect 4 game environment and implement a bot using Monte Carlo Tree Search (MCTS)

---

### Summary

- Connect 4 is a two-player, fully-observable, zero-sum game. 
- The game states may be represented as a tree structure, so we can implement a bot using tree-search algorithms. We choose Connect 4 because it is simple, and therefore provides a launch-pad for more complex games such as checkers or chess.
- Initially we implement vanilla MCTS with no machine learning. We expect this to be limited by (i) the stochastic rollout of the tree and (ii) the simplicity of the simulation policy.
- To introduce ML, we would perform alternate steps of MCTS evaluation and simulation policy improvement. In this way, the simulated games will _hopefully_ begin to approach "good play", and the final MCTS values will reflect the behaviour of good players.
- MCTS configuration:
    + Tree-traversal policy is:
        1. From the current node, uniformly-randomly select a non-expanded child if one is available
        2. Otherwise select child with highest UCB-1 score, traverse to this node and repeat
    + Resulting node is expanded by adding all possible children and selecting one by performing a uniformly-random action
    + Simulation policy is to select a uniformly-random action
- The UCB-1 score is designed to balance exploration and exploitation near-optimally for stationary multi-armed bandits. Strictly speaking, we are applying it in a non-stationary environment, because the reward distribution for each action changes as the downstream tree evolves. This makes UCB-1 theoretically sub-optimal here, although it is widely used in MCTS nonetheless.
- When playing an actual move (i.e. at inference time), greedily select the action with the highest mean score over its MCTS visits; do not use UCB-1, since we are no longer exploring.
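The tree policy and the inference rule above can be sketched in a few lines of Python. This is a minimal illustration, not the `connect4.MCTS` implementation: children are represented as hypothetical `(total_score, num_visits)` tuples, a child with zero visits stands in for a non-expanded child, and the exploration constant `c = sqrt(2)` is an assumption. For a child with total score $T$ and $N$ visits under a parent with $N_p$ visits, the UCB-1 score is $T/N + c\sqrt{\ln N_p / N}$.

```python
import math
import random


def ucb1_score(total_score: float, num_visits: int, parent_visits: int,
               c: float = math.sqrt(2)) -> float:
    """UCB-1 score: mean score plus an exploration bonus.

    total_score   -- summed backed-up score of the child (T in the tree summaries)
    num_visits    -- visit count of the child (N), must be > 0
    parent_visits -- visit count of the parent node
    c             -- exploration constant (sqrt(2) is the textbook value; assumed here)
    """
    mean = total_score / num_visits
    return mean + c * math.sqrt(math.log(parent_visits) / num_visits)


def select_child(children, parent_visits, rng=random):
    """Tree policy: a uniformly random unvisited child if any, else the max UCB-1 child.

    children -- hypothetical list of (total_score, num_visits) tuples, one per legal action
    Returns the index of the selected child.
    """
    unvisited = [i for i, (_, n) in enumerate(children) if n == 0]
    if unvisited:
        return rng.choice(unvisited)
    scores = [ucb1_score(t, n, parent_visits) for t, n in children]
    return max(range(len(children)), key=scores.__getitem__)


def best_action(children):
    """Inference-time policy: greedy on mean score, with no exploration bonus."""
    means = [t / n if n > 0 else -math.inf for t, n in children]
    return max(range(len(children)), key=means.__getitem__)
```

For example, with children `[(0.0, 1), (6.0, 10)]` and 11 parent visits, `select_child` returns 0 (the rarely visited child earns a large exploration bonus), while `best_action` returns 1 (the higher mean score).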

Observations:
- Strength of decision-making depends on how many iterations of MCTS we perform:
    1. When the tree is shallow, we effectively assume that future play is random, which means we will favour the moves with the greatest number of winning continuations. We may therefore neglect to defend against an imminent loss, preferring a different move with many winning continuations (bad behaviour).
    2. When the tree is deep and the mean scores converge towards their true values, at least for the best moves, we effectively assume that future play is optimal. As the visit count goes to infinity, our scores become unbiased.
    3. For a finite but sufficient run-time, we assume optimal play, but use mean scores that are biased by the fact that our early simulations used random play instead of optimal play.
- This explains why MCTS is fairly strong even with purely random simulations: most of our simulations end up using fairly effective play, at least for the next few moves, where our tree is sufficiently grown.
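The bias described in point 3 can be illustrated with a toy analogue: a one-level tree, i.e. a multi-armed bandit, searched with UCB-1. The sketch below is a simplification under stated assumptions (Bernoulli rewards, made-up arm means, `c = sqrt(2)`), not a Connect 4 experiment.

```python
import math
import random


def run_ucb1(arm_means, num_pulls, c=math.sqrt(2), seed=0):
    """Play a Bernoulli multi-armed bandit with UCB-1.

    arm_means -- win probability of each arm (made-up values, not Connect 4 data)
    num_pulls -- total number of pulls, analogous to MCTS iterations at the root
    Returns (counts, totals): per-arm visit counts and summed rewards.
    """
    rng = random.Random(seed)
    num_arms = len(arm_means)
    counts = [0] * num_arms
    totals = [0.0] * num_arms
    for _ in range(num_pulls):
        if 0 in counts:
            # Pull each arm once before UCB-1 scores are defined
            arm = counts.index(0)
        else:
            pulls_so_far = sum(counts)
            arm = max(
                range(num_arms),
                key=lambda i: totals[i] / counts[i]
                + c * math.sqrt(math.log(pulls_so_far) / counts[i]),
            )
        reward = 1.0 if rng.random() < arm_means[arm] else 0.0
        counts[arm] += 1
        totals[arm] += reward
    return counts, totals


counts, totals = run_ucb1([0.2, 0.5, 0.8], num_pulls=5000)
root_mean = sum(totals) / sum(counts)   # analogue of the root node's average score
best_mean = totals[2] / counts[2]       # empirical mean of the best arm
print(counts, round(root_mean, 3), round(best_mean, 3))
```

Most pulls should go to the 0.8 arm, while `root_mean` is typically dragged below `best_mean` by the early exploratory pulls of the weaker arms, mirroring the biased-but-improving estimates in point 3.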


## Imports

In [1]:
###
###  Required imports
###  - all imports should be placed here
###


##  Future imports must be the first statement in the cell
from __future__ import annotations

##  Python core libs
import sys, time
from enum import IntEnum
from abc  import ABC, abstractmethod, abstractstaticmethod

##  PyPI libs
import numpy as np

##  Local packages
from connect4.enums     import BinaryPlayer, DebugLevel, GameResult
from connect4.GameBoard import GameBoard
from connect4.MCTS      import Node_VanillaMCTS


In [2]:
###
###  Print version for reproducibility
###

print(f"Python version is {sys.version}")
print(f"Numpy  version is {np.__version__}")

Python version is 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ]
Numpy  version is 1.23.2


##  MCTS

Implement bot classes that wrap the MCTS node class (`Node_VanillaMCTS`, imported from `connect4.MCTS`) to search the game tree and choose moves.

In [3]:


class BaseBot(ABC) :
    
    def __init__(self) :
        """
        Abstract base class for MCTS bots.
        Subclasses implement create_root_node() to choose the MCTS node type.
        """
        self.root_node = None
    
    
    @abstractmethod
    def create_root_node(self, game_board:GameBoard) :
        """
        Create a root node for a given MCTS node type.
        """
        raise NotImplementedError()
        
        
    def choose_action(self, 
                      game_board:GameBoard = None, 
                      duration:int         = 1, 
                      max_sim_steps:int    = -1, 
                      debug_lvl:DebugLevel = DebugLevel.MUTE) :
        """
        Create a root_node from the current game state, and perform a timed MCTS to choose a move.
        """
        
        ##  If game has ended then cannot generate a new action
        ##  -  only query the result when a game_board is provided
        if game_board is not None :
            game_result = game_board.get_result()
            if game_result :
                raise RuntimeError(f"Game is in terminal state {game_result}")

        ##  Create root_node from game_board provided
        ##  -  fall back to stored root_node if game_board is None
        if game_board :
            root_node = self.create_root_node(game_board)
        elif self.root_node : 
            root_node = self.root_node
        else :
            raise RuntimeError("No game_board provided and no previous root node stored")

        ##  Call timed_MCTS to update tree values 
        root_node.timed_MCTS(duration      = duration     , 
                             max_sim_steps = max_sim_steps, 
                             debug_lvl     = debug_lvl    )
        chosen_action = root_node.get_best_action()

        ##  Print debug info
        debug_lvl.message(DebugLevel.HIGH, root_node.tree_summary())
        debug_lvl.message(DebugLevel.LOW, 
              "Action values are:  " + " ".join([f"{x.get_action_value():.2f}".ljust(6) if x else "N/A   " for x in root_node.children]))
        debug_lvl.message(DebugLevel.LOW, 
              "Visit counts are:   " + " ".join([f"{x.num_visits}".ljust(6) if x else "N/A   " for x in root_node.children]))
        debug_lvl.message(DebugLevel.LOW, 
              f"Selecting action {chosen_action}")
        
        ##  Store root node
        self.root_node = root_node

        ##  Return best action from tree evaluation
        return chosen_action
    
    
    def num_itr(self) :
        """
        Return number of MCTS iterations applied to the root node.
        """
        
        ##  If no root node created then answer is 0
        if not self.root_node :
            return 0
            
        ##  Otherwise query root node for num_visits
        return self.root_node.num_visits
    
    
    def take_move(self, 
                  game_board:GameBoard = None, 
                  duration:int         = 1, 
                  max_sim_steps:int    = -1, 
                  debug_lvl:DebugLevel = DebugLevel.MUTE) :
        """
        Use MCTS to choose a bot move, apply it to game_board,
        and return the updated game board.
        """

        ##  Use timed MCTS to obtain a bot action
        action = self.choose_action(game_board, duration, max_sim_steps, debug_lvl)

        ##  Apply the bot move
        game_board.apply_action(action)
        return game_board
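`BaseBot.choose_action` delegates the search to `root_node.timed_MCTS`, whose implementation lives in `connect4.MCTS` and is not shown here. A time-budgeted search loop could look something like the sketch below; `run_timed_search` and `step` are hypothetical names (`step` stands in for one MCTS iteration), and treating `duration` as seconds is an assumption.

```python
import time
from typing import Callable


def run_timed_search(step: Callable[[], None], duration: float) -> int:
    """Repeat step() until `duration` seconds of wall-clock time have elapsed.

    step     -- hypothetical callable running one MCTS iteration
                (selection, expansion, simulation, backpropagation)
    duration -- time budget, assumed here to be in seconds
    Returns the number of iterations completed.
    """
    deadline = time.perf_counter() + duration
    num_steps = 0
    # The clock is checked before each iteration, so the last one may overrun slightly
    while time.perf_counter() < deadline:
        step()
        num_steps += 1
    return num_steps
```

With a real tree, the returned count would correspond to what `bot.num_itr()` reports for a freshly created root node, i.e. the root's visit count.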
        

In [4]:

class Bot_VanillaMCTS(BaseBot) :
    
    def create_root_node(self, game_board:GameBoard) :
        """
        Create a vanilla MCTS root node at the current game state.
        """
        return Node_VanillaMCTS(game_board)
    

##  Test MCTS

In [5]:
###
###  Set up a small game
###  - 4x4 grid
###  - line of 3 needed to win
###

##  Create game board
game_board = GameBoard(4, 4, 3)

##  Show initial game board
print(game_board)


+---+---+---+---+
| . | . | . | . |
| . | . | . | . |
| . | . | . | . |
| . | . | . | . |
+---+---+---+---+
| 0 | 1 | 2 | 3 |
+---+---+---+---+
Game result is: NONE


In [6]:
###
###  Play a few initial moves
###  - transitions into a critical state where the O player needs to be careful not to
###    blunder a win for X
###

##  Play moves
game_board.apply_action(1)
game_board.apply_action(2)
game_board.apply_action(1)

##  Show updated game state
print(game_board)


+---+---+---+---+
| . | . | . | . |
| . | . | . | . |
| . | X | . | . |
| . | X | O | . |
+---+---+---+---+
| 0 | 1 | 2 | 3 |
+---+---+---+---+
Game result is: NONE


In [7]:
###
###  Perform a few MCTS steps
###  - run MCTS from the critical state above, printing the tree after each step
###    at a HIGH debug level
###

##  Create a root node at the current game state
root_node = Node_VanillaMCTS(game_board, label="ROOT")

##  Print the initial value tree (should be a ROOT node with no children)
print("Initial tree:")
print(root_node.tree_summary())
print()

##  Perform several MCTS steps with a HIGH debug level
root_node.multi_step_MCTS(num_steps=10, max_sim_steps=-1, debug_lvl=DebugLevel.HIGH)

##  Print the updated value tree 
print("Updated tree:")
print(root_node.tree_summary())
print()


Initial tree:
> [0: ROOT] N=0, T=0, E=inf, Q=-inf
     > None
     > None
     > None
     > None

Running MCTS step 0
Select unvisited action -1:2
Simulation ended with result X
Simulated trajectory was: 1:2 -1:3 1:3 -1:2 1:3 -1:0 1:0 -1:3 1:1
Node -1:2 with parent=O, N=0, T=0.00 receiving score -1.00 for game ending in result X
Node ROOT with parent=NONE, N=0, T=0.00 receiving score 0.00 for game ending in result X
Updated tree is:
> [0: ROOT] N=1, T=0.0, E=nan, Q=0.000
     > None
     > None
     > [1: -1:2] N=1, T=-1.0, E=-1.000, Q=-1.000
          > None
          > None
          > None
          > None
     > None

Running MCTS step 1
Select unvisited action -1:3
Simulation ended with result X
Simulated trajectory was: 1:2 -1:1 1:0 -1:1 1:0
Node -1:3 with parent=O, N=0, T=0.00 receiving score -1.00 for game ending in result X
Node ROOT with parent=NONE, N=1, T=0.00 receiving score 0.00 for game ending in result X
Updated tree is:
> [0: ROOT] N=2, T=0.0, E=nan, Q=0.000
     > No
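The debug log above suggests the sign convention used during backpropagation: node `-1:2` (an O move) is credited -1 when X wins, while `ROOT`, which has no mover, is credited 0. The sketch below reconstructs that convention from the log rather than from the `connect4.MCTS` source; the `Player` enum (with X = 1 and O = -1, matching the `1:` and `-1:` action labels), the dict-based nodes, and scoring a draw as 0 are all assumptions.

```python
from enum import IntEnum


class Player(IntEnum):
    """Hypothetical stand-in for the notebook's player/result enums."""
    NONE = 0
    X = 1
    O = -1


def backprop_score(mover: Player, winner: Player) -> float:
    """Score credited to a node after one simulated game.

    mover  -- player whose move led into the node (Player.NONE for the root)
    winner -- winner of the simulated game (Player.NONE for a draw, scored 0 by assumption)
    """
    if mover == Player.NONE or winner == Player.NONE:
        return 0.0
    return 1.0 if mover == winner else -1.0


def backpropagate(path, winner: Player) -> None:
    """Update visit count N and total score T on every node from root to leaf."""
    for node in path:
        node["N"] += 1
        node["T"] += backprop_score(node["mover"], winner)


# Replay MCTS step 0 from the log: O plays column 2, then X wins the rollout
root = {"mover": Player.NONE, "N": 0, "T": 0.0}
child = {"mover": Player.O, "N": 0, "T": 0.0}
backpropagate([root, child], winner=Player.X)
print(root["N"], root["T"], child["N"], child["T"])  # 1 0.0 1 -1.0
```

This reproduces the `N=1, T=0.0` root and `N=1, T=-1.0` child shown in the first updated tree.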

In [8]:
###
###  Use MCTS to play a move
###

##  Use MCTS to search for an optimal action
bot    = Bot_VanillaMCTS()
action = bot.choose_action(game_board, duration=1, debug_lvl=DebugLevel.LOW)
print(f"Bot chooses action {action} after {bot.num_itr()} MCTS iterations")

##  Play bot move
game_board.apply_action(action)

##  Show updated game state
print(game_board)


Action values are:  -0.80  -0.55  -0.77  -0.82 
Visit counts are:   120    535    138    111   
Selecting action 1
Bot chooses action 1 after 904 MCTS iterations
+---+---+---+---+
| . | . | . | . |
| . | O | . | . |
| . | X | . | . |
| . | X | O | . |
+---+---+---+---+
| 0 | 1 | 2 | 3 |
+---+---+---+---+
Game result is: NONE
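The printed statistics can be used to check the greedy inference rule by hand. The snippet below copies the action values and visit counts from the output above (nothing is re-run); the max-visit rule is included only as a common alternative for comparison, not something this notebook uses.

```python
import numpy as np

# Values and visit counts copied from the output above (columns 0-3)
action_values = np.array([-0.80, -0.55, -0.77, -0.82])
visit_counts  = np.array([120, 535, 138, 111])

greedy_action = int(np.argmax(action_values))   # inference rule used here: max mean score
most_visited  = int(np.argmax(visit_counts))    # common alternative: max visit count
total_visits  = int(visit_counts.sum())         # each root iteration descends into exactly one child

print(greedy_action, most_visited, total_visits)  # 1 1 904
```

Both rules agree on column 1 here, and the child visit counts sum to the 904 iterations reported by `bot.num_itr()`.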


## Connect 4

Play a game of Connect 4 against our bot!

To play a move in column `column_index`, call `game_board.apply_action(column_index)`, then call `bot.take_move(game_board, duration=...)` to let the bot reply. Turning up the `duration` parameter will improve the bot by allowing it to search for longer.

In [9]:
##  Create a new game

game_board = GameBoard()
bot        = Bot_VanillaMCTS()
print(game_board)


+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE


In [10]:
##  Play a move in column index 3

game_board.apply_action(3)
print(game_board)

if not game_board.get_result() :
    bot.take_move(game_board, duration=5, debug_lvl=DebugLevel.LOW)
    print(game_board)


+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | X | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE
Action values are:  -0.48  -0.41  -0.28  -0.20  -0.42  -0.28  -0.38 
Visit counts are:   77     98     169    271    95     172    107   
Selecting action 3
+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | O | . | . | . |
| . | . | . | X | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE


---

... and so on: we keep alternating `apply_action` (our move) and `take_move` (the bot's move) until the game is complete!

---

## Bot-only game

In [11]:
#  Play a bot game!

game_board = GameBoard()
bot        = Bot_VanillaMCTS()
print(game_board)

result = game_board.get_result()
while not result :
    bot.take_move(game_board, duration=1, debug_lvl=DebugLevel.LOW)
    result = game_board.get_result()
    print(game_board)


+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE
Action values are:  0.21   0.03   -0.25  0.43   0.00   -0.14  0.21  
Visit counts are:   38     29     16     88     26     21     43    
Selecting action 3
+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | X | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE
Action values are:  -0.32  -0.62  -0.28  -0.17  -0.34  -0.50  -0.50 
Visit counts are:   41     21     47     63     38     28     28    
Selecting action 3
+---+---+---+---+---+---+---+
| . |

Action values are:  -0.27  0.14   -0.15  -0.36  -0.02  0.05  
Visit counts are:   30     95     40     25     59     76    
Selecting action 1
+---+---+---+---+---+---+---+
| . | . | . | O | . | . | . |
| . | . | . | X | . | . | . |
| . | O | . | X | X | . | . |
| . | X | . | X | O | . | X |
| . | O | . | O | X | . | O |
| . | O | . | X | O | O | X |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE
Action values are:  -0.17  0.02   0.07   -0.17  -0.08  0.37  
Visit counts are:   30     48     56     29     36     174   
Selecting action 6
+---+---+---+---+---+---+---+
| . | . | . | O | . | . | . |
| . | . | . | X | . | . | . |
| . | O | . | X | X | . | X |
| . | [31mX[0m | . | [31mX[0m | [34mO[0m | . | [31