# Connect 4

---

Author: S. Menary [sbmenary@gmail.com]

Date  : 2023-01-03, last edit 2023-01-14

Brief : Develop a simple Connect 4 game environment and implement a bot using Monte Carlo Tree Search (MCTS)

---

### Summary

- Connect 4 is a two-player, fully-observable, zero-sum game. 
- The game states may be represented as a tree structure, so we can implement a bot using tree-search algorithms. We choose Connect 4 because it is simple, and therefore provides a launch-pad for more complex games such as checkers or chess.
- Initially we implement vanilla MCTS with no machine learning. We expect this to be limited by (i) the stochastic rollout of the tree and (ii) the simplicity of the simulation policy.
- To introduce ML, we would perform alternate steps of MCTS evaluation and simulation policy improvement. In this way, the simulated games will _hopefully_ begin to approach "good play", and the final MCTS values will reflect the behaviour of good players.
- MCTS configuration:
    + Tree-traversal policy is:
        1. From the current node, uniformly-randomly select a non-expanded child if one is available
        2. Otherwise select child with highest UCB-1 score, traverse to this node and repeat
    + Resulting node is expanded by adding all possible children and selecting one by performing a uniformly-random action
    + Simulation policy is to select a uniformly-random action
- The UCB-1 score is designed to balance exploration and exploitation optimally (its regret grows only logarithmically) for stationary multi-armed bandits. Strictly speaking, we are applying it in a non-stationary environment, because the reward distribution for each action changes as the downstream tree evolves. This makes UCB-1 theoretically sub-optimal here, although it is widely used in MCTS nonetheless.
- When playing an actual move (i.e. inference time), greedily select the action with the max average score from its MCTS visits (do not use UCB-1 since we are no longer exploring).
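The tree-traversal rule and the inference-time rule above can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in `connect4.MCTS`: `ChildStats`, `select_child` and `best_action` are hypothetical names, the parent visit count is approximated by the sum of the children's visits, and the exploration constant c = √2 is the textbook UCB-1 choice rather than a value taken from this project.

```python
import math
import random
from dataclasses import dataclass
from typing import List


@dataclass
class ChildStats:
    """Hypothetical per-child statistics: visit count N and total score T."""
    num_visits: int = 0
    total_score: float = 0.0

    def mean(self) -> float:
        return self.total_score / self.num_visits


def ucb1(child: ChildStats, parent_visits: int, c: float = math.sqrt(2)) -> float:
    # Exploitation term (mean score) plus an exploration bonus that shrinks
    # as this child is visited more often relative to its parent
    return child.mean() + c * math.sqrt(math.log(parent_visits) / child.num_visits)


def select_child(children: List[ChildStats], rng: random.Random) -> int:
    # Rule 1: uniformly-random choice among children never visited so far
    unvisited = [i for i, ch in enumerate(children) if ch.num_visits == 0]
    if unvisited:
        return rng.choice(unvisited)
    # Rule 2: otherwise take the child with the highest UCB-1 score
    parent_visits = sum(ch.num_visits for ch in children)
    return max(range(len(children)), key=lambda i: ucb1(children[i], parent_visits))


def best_action(children: List[ChildStats]) -> int:
    # Inference time: greedy on the mean score, no exploration bonus
    visited = [i for i, ch in enumerate(children) if ch.num_visits > 0]
    return max(visited, key=lambda i: children[i].mean())
```

With statistics such as `[ChildStats(100, 70.0), ChildStats(3, 1.5), ChildStats(1, -1.0)]`, `select_child` prefers the rarely-visited child 1 (large exploration bonus) while `best_action` returns child 0 (highest mean), which is exactly the difference between searching and playing.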

Observations:
- Strength of decision-making depends on how many iterations of MCTS we perform:
    1. When the tree is shallow, we effectively assume that future play is random, which means we will choose options with the greatest number of permutations of winning. We may therefore neglect to defend against an imminent loss, favouring a different move with many win permutations (bad behaviour).
    2. When the tree is deep and the UCB-1 estimates converge towards the true means, at least for the best moves, we effectively assume that future play is optimal. As the visit count goes to infinity, our scores become unbiased.
    3. For finite but sufficient run-time, we assume optimal play, but using mean scores that are biased by the fact that our early simulations used random play instead of optimal play.
- This explains why even random simulation MCTS is pretty good - we end up doing most of our simulations with pretty effective play, at least for the next few moves where our tree is sufficiently grown.
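A small worked example of observation 3, assuming (hypothetically) a move whose true value under strong play is a loss (-1), but whose first 100 visits came from a shallow tree with near-random play in which it always appeared to win (+1). The node's mean score is the plain average over all visits, so those early visits keep biasing it even after many strong-play visits.

```python
def running_mean_value(num_early: int, early_score: float,
                       num_late: int, late_score: float) -> float:
    """Mean score of a node whose first `num_early` visits scored `early_score`
    (shallow tree, near-random play) and whose next `num_late` visits scored
    `late_score` (deeper tree, stronger play)."""
    total = num_early * early_score + num_late * late_score
    return total / (num_early + num_late)


# The move truly loses (-1), but 100 early random-play visits scored +1
for num_late in (0, 100, 900, 9900):
    value = running_mean_value(100, +1.0, num_late, -1.0)
    print(f"{num_late:>5} strong-play visits -> mean {value:+.2f}")
```

The mean moves from +1.00 towards -1 but is still -0.80 after 900 strong-play visits and -0.98 after 9900: biased for any finite budget, unbiased only in the limit.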


## Imports

In [1]:
###
###  Required imports
###  - all imports should be placed here
###


##  Must be the first statement in the cell, otherwise Python raises a SyntaxError
from __future__ import annotations

##  Python core libs
import sys, time
from enum import IntEnum
from abc  import ABC, abstractmethod

##  PyPI libs
import numpy as np

##  Local packages
from connect4.enums     import BinaryPlayer, DebugLevel, GameResult
from connect4.GameBoard import GameBoard
from connect4.MCTS      import Node


In [2]:
###
###  Print version for reproducibility
###

print(f"Python version is {sys.version}")
print(f"Numpy  version is {np.__version__}")

Python version is 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ]
Numpy  version is 1.23.2


##  MCTS

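The `Node` class (imported from `connect4.MCTS`) handles the tree search. Here we implement helper functions that drive it through a number of MCTS iterations.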

In [3]:
###
###  Methods for MCTS
###  - Implement methods which interact with the Node class to perform a number of MCTS iterations
###


def one_step_MCTS(root_node, max_turns=-1, debug_lvl=DebugLevel.MUTE) :
    """
    Perform a single MCTS iteration on the root_node provided.
    """
    
    ##  Select and expand from the root node
    leaf_node = root_node.select_and_expand(recurse=True, debug_lvl=debug_lvl)
    
    ##  Simulate and backprop from the selected child
    leaf_node.simulate_and_backprop(max_turns=max_turns, debug_lvl=debug_lvl)
    
    ##  Print updated tree if debug level is HIGH
    debug_lvl.message(DebugLevel.HIGH, f"Updated tree is:\n{root_node.tree_summary()}")
    
    
def multi_step_MCTS(root_node, num_steps, max_turns=-1, debug_lvl=DebugLevel.MUTE) :
    """
    Perform many MCTS iterations on the root_node provided.
    """
    
    ##  Call one_step_MCTS a number of times equal to num_steps
    for idx in range(num_steps) :
        debug_lvl.message(DebugLevel.MEDIUM, f"Running MCTS step {idx}")
        one_step_MCTS(root_node, max_turns=max_turns, debug_lvl=debug_lvl)
        debug_lvl.message(DebugLevel.MEDIUM, f"")
        
        
def timed_MCTS(root_node, duration, max_turns=-1, debug_lvl=DebugLevel.MUTE) :
    """
    Perform MCTS iterations on the root_node until duration (in seconds) has elapsed.
    After this time, MCTS will finish its current iteration, so total execution time is > duration.
    Returns the number of iterations performed.
    """
    
    ##  Keep calling one_step_MCTS until required duration has elapsed
    ##  (time.monotonic is unaffected by system clock changes, unlike time.time)
    start_time   = time.monotonic()
    current_time = start_time
    num_itr = 0
    while current_time - start_time < duration :
        one_step_MCTS(root_node, max_turns=max_turns, debug_lvl=debug_lvl)
        current_time = time.monotonic()
        num_itr += 1
    return num_itr


def get_bot_action(game_board, duration=1, max_turns=-1, debug_lvl=DebugLevel.MUTE) :
    """
    Create a root_node from the current game state, and perform a timed MCTS to choose a move.
    """
    
    ##  Create root node from current game board
    root_node = Node(game_board)
    
    ##  Call timed_MCTS to update tree values 
    num_itr       = timed_MCTS(root_node, duration=duration, max_turns=max_turns, debug_lvl=debug_lvl)
    chosen_action = root_node.get_best_action()
    
    ##  Print debug info
    debug_lvl.message(DebugLevel.HIGH, 
          root_node.tree_summary())
    debug_lvl.message(DebugLevel.LOW, 
          "Action values are:  " + " ".join([f"{x.get_action_value():.2f}".ljust(6) if x else "N/A   " for x in root_node.children]))
    debug_lvl.message(DebugLevel.LOW, 
          "Visit counts are:   " + " ".join([f"{x.num_visits}".ljust(6) if x else "N/A   " for x in root_node.children]))
    debug_lvl.message(DebugLevel.LOW, 
          f"Selecting action {chosen_action}")
        
    ##  Return best action from tree evaluation, and the number of MCTS iterations executed
    return chosen_action, root_node, num_itr
    
    
def take_move(game_board, my_action, duration=1, max_turns=-1, debug_lvl=DebugLevel.MUTE) :
    """
    Apply a human move.
    Print the game board.
    Use MCTS to find a responding bot move.
    Apply the bot move.
    Print the game board.
    """
    
    ##  Apply the human move.
    print(f"Human takes move {my_action}")
    game_board.apply_action(my_action)
    print(game_board)
    print()
    
    ##  If game has ended then return
    if game_board.get_result() :
        return
    
    ##  Use timed MCTS to obtain a bot action
    bot_action, _, num_itr = get_bot_action(game_board, 
                                            duration=duration, 
                                            max_turns=max_turns,
                                            debug_lvl=debug_lvl)
    
    ##  Apply the bot move
    print(f"Bot takes move {bot_action} ({num_itr} iterations)")
    game_board.apply_action(bot_action)
    print(game_board)
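The `Node` methods called above (`select_and_expand`, `simulate_and_backprop`, `get_best_action`) live in `connect4.MCTS` and are not reproduced here. The following self-contained sketch shows one way the same select, expand, simulate and backpropagate loop could look, using a toy Nim game (take 1 to 3 stones; whoever takes the last stone wins) instead of Connect 4. `NimNode`, `rollout`, `run_mcts` and the other names are hypothetical. This simplified variant expands one child per visit rather than adding all children at once, and stores each node's score from the viewpoint of the player who moved into it, so the sign flips at every level during backpropagation.

```python
import math
import random
from typing import Dict, List, Optional


class NimNode:
    """Hypothetical MCTS node for a toy Nim game (take 1-3 stones, taking
    the last stone wins). Stands in for connect4.MCTS.Node, not shown here."""

    def __init__(self, stones: int, parent: Optional["NimNode"] = None, action: int = 0) -> None:
        self.stones = stones
        self.parent = parent
        self.action = action                     # stones taken to reach this node
        self.children: Dict[int, "NimNode"] = {}
        self.num_visits = 0                      # N
        self.total_score = 0.0                   # T, for the player who moved INTO this node

    def legal_actions(self) -> List[int]:
        return [a for a in (1, 2, 3) if a <= self.stones]

    def ucb1(self, c: float = math.sqrt(2)) -> float:
        exploit = self.total_score / self.num_visits
        explore = c * math.sqrt(math.log(self.parent.num_visits) / self.num_visits)
        return exploit + explore


def select_and_expand(root: NimNode, rng: random.Random) -> NimNode:
    node = root
    while node.stones > 0:
        # Uniformly-random choice among actions without a child node yet
        unexpanded = [a for a in node.legal_actions() if a not in node.children]
        if unexpanded:
            action = rng.choice(unexpanded)
            child = NimNode(node.stones - action, parent=node, action=action)
            node.children[action] = child
            return child
        # All children exist (and were visited on creation): follow UCB-1
        node = max(node.children.values(), key=lambda ch: ch.ucb1())
    return node                                  # terminal node


def rollout(stones: int, rng: random.Random) -> float:
    """Uniformly-random playout, starting with the opponent of the player who
    just moved. Returns +1 if that player takes the last stone, else -1."""
    just_moved_wins = True                       # 0 stones left: last mover already won
    while stones > 0:
        stones -= rng.randint(1, min(3, stones))
        just_moved_wins = not just_moved_wins    # the mover alternates each turn
    return 1.0 if just_moved_wins else -1.0


def backprop(node: Optional[NimNode], score: float) -> None:
    while node is not None:
        node.num_visits += 1
        node.total_score += score
        score = -score                           # the parent was reached by the other player
        node = node.parent


def run_mcts(stones: int, num_steps: int, seed: int = 0) -> int:
    rng = random.Random(seed)
    root = NimNode(stones)
    for _ in range(num_steps):
        leaf = select_and_expand(root, rng)
        backprop(leaf, rollout(leaf.stones, rng))
    # Greedy choice at inference time: highest mean score, no exploration bonus
    best = max(root.children.values(), key=lambda ch: ch.total_score / ch.num_visits)
    return best.action


print(run_mcts(5, num_steps=3000))
```

In this Nim variant a pile that is a multiple of 4 is lost for the player to move, so from 5 stones the winning move is to take 1; with a few thousand iterations the search should settle on it, just as the Connect 4 bot below settles on blocking moves.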


##  Test MCTS

In [4]:
###
###  Setup a small game
###  - 4x4 grid
###  - line of 3 needed to win
###

##  Create game board
game_board = GameBoard(4, 4, 3)

##  Show initial game board
print(game_board)


+---+---+---+---+
| . | . | . | . |
| . | . | . | . |
| . | . | . | . |
| . | . | . | . |
+---+---+---+---+
| 0 | 1 | 2 | 3 |
+---+---+---+---+
Game result is: NONE


In [5]:
###
###  Play a few initial moves
###  - transitions into a critical state where the O player needs to be careful not to 
###    blunder a win for X
###

##  Play moves
game_board.apply_action(1)
game_board.apply_action(2)
game_board.apply_action(1)

##  Show updated game state
print(game_board)


+---+---+---+---+
| . | . | . | . |
| . | . | . | . |
| . | X | . | . |
| . | X | O | . |
+---+---+---+---+
| 0 | 1 | 2 | 3 |
+---+---+---+---+
Game result is: NONE
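The `GameBoard` implementation lives in `connect4.GameBoard` and is not shown here. As a hedged sketch of what such an environment needs (a gravity-based `apply_action` and a k-in-a-row check), here is a hypothetical `SimpleBoard`; its names and storage layout are assumptions, not the project's API. Replaying the moves above shows why this state is critical: if O plays elsewhere (e.g. column 3), X completes a vertical line of 3 in column 1.

```python
from typing import List, Optional

X, O = "X", "O"


class SimpleBoard:
    """Hypothetical minimal Connect-N board (not the connect4.GameBoard API)."""

    def __init__(self, num_cols: int = 7, num_rows: int = 6, target: int = 4) -> None:
        self.num_cols, self.num_rows, self.target = num_cols, num_rows, target
        # grid[col] lists the pieces in that column, bottom piece first
        self.grid: List[List[str]] = [[] for _ in range(num_cols)]
        self.to_play = X                         # X moves first, as in the notebook

    def legal_actions(self) -> List[int]:
        return [c for c in range(self.num_cols) if len(self.grid[c]) < self.num_rows]

    def apply_action(self, col: int) -> None:
        if col not in self.legal_actions():
            raise ValueError(f"Column {col} is full or out of range")
        self.grid[col].append(self.to_play)      # the piece drops to the lowest free row
        self.to_play = O if self.to_play == X else X

    def piece(self, col: int, row: int) -> Optional[str]:
        if 0 <= col < self.num_cols and 0 <= row < len(self.grid[col]):
            return self.grid[col][row]
        return None

    def winner(self) -> Optional[str]:
        # Treat every piece as the start of a line in four directions:
        # horizontal, vertical, and the two diagonals
        for col in range(self.num_cols):
            for row in range(len(self.grid[col])):
                player = self.grid[col][row]
                for dc, dr in ((1, 0), (0, 1), (1, 1), (1, -1)):
                    if all(self.piece(col + i * dc, row + i * dr) == player
                           for i in range(self.target)):
                        return player
        return None


board = SimpleBoard(4, 4, 3)
for action in (1, 2, 1):                         # same opening as the cell above
    board.apply_action(action)
board.apply_action(3)                            # O fails to block column 1 ...
board.apply_action(1)                            # ... and X completes three in a column
print(board.winner())
```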


In [6]:
###
###  Perform a few MCTS steps
###  - start from the critical state above and print the tree after each step
###

##  Create a root node at the current game state
root_node = Node(game_board, label="ROOT")

##  Print the initial value tree (should be a ROOT node with no children)
print("Initial tree:")
print(root_node.tree_summary())
print()

##  Perform several MCTS steps with a HIGH debug level
multi_step_MCTS(root_node, num_steps=10, max_turns=-1, debug_lvl=DebugLevel.HIGH)

##  Print the updated value tree 
print("Updated tree:")
print(root_node.tree_summary())
print()


Initial tree:
> [0: ROOT] N=0, T=0, E=inf, Q=-inf
     > None
     > None
     > None
     > None

Running MCTS step 0
Select unvisited action -1:0
Simulation ended with result X
Simulated trajectory was: 1:3 -1:0 1:3 -1:2 1:3
Node -1:0 with parent=O, N=0, T=0.00 receiving score -1.00 for game ending in result X
Node ROOT with parent=NONE, N=0, T=0.00 receiving score 0.00 for game ending in result X
Updated tree is:
> [0: ROOT] N=1, T=0.0, E=nan, Q=0.000
     > [1: -1:0] N=1, T=-1.0, E=-1.000, Q=-1.000
          > None
          > None
          > None
          > None
     > None
     > None
     > None

Running MCTS step 1
Select unvisited action -1:1
Simulation ended with result X
Simulated trajectory was: 1:3 -1:0 1:2 -1:1 1:2 -1:2 1:0
Node -1:1 with parent=O, N=0, T=0.00 receiving score -1.00 for game ending in result X
Node ROOT with parent=NONE, N=1, T=0.00 receiving score 0.00 for game ending in result X
Updated tree is:
> [0: ROOT] N=2, T=0.0, E=nan, Q=0.000
     > [1: -1:0] N

In [7]:
###
###  Use MCTS to play a move
###

##  Use MCTS to search for an optimal action
bot_action, _, num_itr = get_bot_action(game_board, 
                                        duration=1, 
                                        debug_lvl=DebugLevel.LOW)
print(f"Bot chooses action {bot_action} after {num_itr} MCTS iterations")

##  Play bot move
game_board.apply_action(bot_action)

##  Show updated game state
print(game_board)


Action values are:  -0.81  -0.59  -0.80  -0.84 
Visit counts are:   145    543    149    125   
Selecting action 1
Bot chooses action 1 after 962 MCTS iterations
+---+---+---+---+
| . | . | . | . |
| . | O | . | . |
| . | X | . | . |
| . | X | O | . |
+---+---+---+---+
| 0 | 1 | 2 | 3 |
+---+---+---+---+
Game result is: NONE


## Connect 4

Play a game of connect 4 against our bot!

Just add new calls to `take_move(game_board, column_index, duration)` to play a move in column `column_index`. Turning up the `duration` parameter will improve the bot by allowing it to search for longer.

In [8]:
##  Create a new game

game_board = GameBoard()
print(game_board)


+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE


In [9]:
##  Play a move in column index 3

take_move(game_board, 3, duration=5, max_turns=30)


Human takes move 3
+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | X | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE

Bot takes move 3 (1025 iterations)
+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | O | . | . | . |
| . | . | . | X | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE


---

... and so on, we keep calling `take_move` until the game is complete!

---

## Bot-only game

In [10]:
#  Play a bot game!

game_board = GameBoard()
result     = game_board.get_result()
print(game_board)

while not result :
    chosen_action, _, _ = get_bot_action(game_board)
    game_board.apply_action(chosen_action)
    result = game_board.get_result()
    print(game_board)


+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE
+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | X | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE
+---+---+---+---+---+---+---+
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | . | . | . | . |
| . | . | . | O | X | . | . |
+---+---+---+---+---+---+---+
| 0 | 1 | 2 | 3 | 4 | 5 | 6 |
+---+---+---+---+---+---+---+
Game result is: NONE
+---+---+-