# Connect 4

In [1]:
from collections import Counter
import numpy as np
import time

In [2]:
# Board geometry: boards are created with shape (NUM_COLUMNS, COLUMN_HEIGHT),
# so the first array axis selects a column and the second a cell in it.
NUM_COLUMNS = 7
COLUMN_HEIGHT = 6
# Number of aligned discs that `four_in_a_row` looks for.
FOUR = 4

# Board can be initialized with `board = np.zeros((NUM_COLUMNS, COLUMN_HEIGHT), dtype=np.byte)`
# Nota bene: Connect 4 "columns" are actually NumPy "rows"

# Basic Functions

In [3]:
def valid_moves(board):
    """List the column indices that can still accept a disc.

    A column is playable while its topmost cell (index ``COLUMN_HEIGHT - 1``)
    still holds ``0``. Columns are reported in increasing order.
    """
    top_cell = COLUMN_HEIGHT - 1
    playable = []
    for column in range(NUM_COLUMNS):
        if board[column, top_cell] == 0:
            playable.append(column)
    return playable

def play(board, column, player):
    """Drop a disc of `player` into `column`, updating `board` in place.

    The disc lands in the first cell of the column that still holds ``0``.

    Raises:
        ValueError: if the column has no empty cell left. (A bare
            ``StopIteration`` used to escape here, which Python converts into
            an unrelated ``RuntimeError`` when it happens inside a generator.)
    """
    empty_cells = np.flatnonzero(board[column] == 0)
    if empty_cells.size == 0:
        raise ValueError(f"column {column} is full")
    board[column, empty_cells[0]] = player

def take_back(board, column):
    """Undo the latest disc in `column` by clearing its highest non-empty cell.

    `board` is modified in place. An ``IndexError`` is raised when the column
    holds no disc at all.
    """
    occupied_cells = np.flatnonzero(board[column])
    topmost = occupied_cells[-1]
    board[column, topmost] = 0

def four_in_a_row(board, player):
    """Tell whether `player` owns `FOUR` aligned cells anywhere on the board.

    Lines are searched along a column, across columns, and along both
    diagonals. Returns ``True`` as soon as one complete line is found,
    ``False`` otherwise.
    """
    owned = board == player
    steps = np.arange(FOUR)
    column_starts = range(NUM_COLUMNS - FOUR + 1)
    cell_starts = range(COLUMN_HEIGHT - FOUR + 1)

    # Each entry: (column step, cell step, starting columns, starting cells).
    line_families = (
        (0, 1, range(NUM_COLUMNS), cell_starts),                  # within one column
        (1, 0, column_starts, range(COLUMN_HEIGHT)),              # across columns
        (1, 1, column_starts, cell_starts),                       # rising diagonal
        (1, -1, column_starts, range(FOUR - 1, COLUMN_HEIGHT)),   # falling diagonal
    )
    for column_step, cell_step, first_columns, first_cells in line_families:
        for column in first_columns:
            for cell in first_cells:
                if owned[column + column_step * steps, cell + cell_step * steps].all():
                    return True
    return False

In [4]:
def display(board):
    """Print the board to stdout, top cells first, followed by a dashed rule.

    Cells equal to ``1`` are shown as ``X``, cells equal to ``-1`` as ``O`` and
    anything else as ``.``; every symbol is followed by two spaces.
    """
    for cell in reversed(range(COLUMN_HEIGHT)):
        symbols = []
        for column in range(NUM_COLUMNS):
            disc = board[column, cell]
            symbols.append('X' if disc == 1 else 'O' if disc == -1 else '.')
        print(''.join(symbol + '  ' for symbol in symbols))
    print('---' * (NUM_COLUMNS - 1) + '-')

# MinMax Algorithm + Montecarlo Evaluation

In [5]:
def _mc(board, player):
    """Finish the game on `board` with uniformly random moves.

    `player` moves first and the sides then alternate; `board` is modified in
    place. Returns the side that completes a line, or ``0`` when the board
    fills up without a winner.
    """
    mover = player
    while True:
        options = valid_moves(board)
        if not options:
            return 0
        column = np.random.choice(options)
        play(board, column, mover)
        if four_in_a_row(board, mover):
            return mover
        mover = -mover

def montecarlo(board, player, montecarlo_samples=50):
    """Score `board` from random playouts, with `player` to move.

    Runs `montecarlo_samples` random games on copies of `board` and returns
    ``(wins of 1 - player * draws - wins of -1) / montecarlo_samples``, i.e.
    draws count against the side to move.
    """
    outcomes = Counter()
    for _ in range(montecarlo_samples):
        outcomes[_mc(np.copy(board), player)] += 1
    balance = outcomes[1] - player * outcomes[0] - outcomes[-1]
    return balance / montecarlo_samples

def alpha_beta_minmax(board, player, alpha=-np.inf, beta=np.inf, depth=3, montecarlo_samples=50, cache=None, starting_depth=None):
    """Pick a move with alpha-beta minimax and Monte Carlo leaf evaluation.

    Args:
        board: current position; it is never modified (moves are tried on copies).
        player: side to move; ``1`` maximizes, any other value is treated as
            the minimizing side ``-1``.
        alpha, beta: current pruning window.
        depth: remaining plies to search before falling back to `montecarlo`.
        montecarlo_samples: number of random playouts per leaf evaluation.
        cache: dict mapping ``str(board)`` to a previously computed score. A
            fresh dict is created when omitted. Entries are keyed by the board
            layout only, so a score computed at one remaining depth is reused
            at any other depth.
        starting_depth: depth of the root call; defaults to `depth`. It is
            forwarded unchanged through the recursion so that the
            "block the opponent" shortcut only applies at the root.

    Returns:
        ``(move, score)``. ``move`` is ``None`` when the position is evaluated
        directly (depth exhausted or no legal move). An immediate win scores
        ``2`` for player ``1`` and ``-2`` for player ``-1``; a forced block at
        the root is reported with score ``1`` / ``-1``.
    """
    if cache is None:
        cache = {}
    if starting_depth is None:
        starting_depth = depth

    valid = valid_moves(board)
    # Try the central column first, then the remaining columns in a fixed
    # priority order, so that alpha-beta pruning cuts as early as possible.
    preferred = [3, 1, 5, 2, 4, 0, 6]
    possible = [move for move in preferred if move in valid]
    possible += [move for move in valid if move not in preferred]

    # Return heuristic evaluation of the state
    if depth == 0 or not possible:
        return None, montecarlo(board, player, montecarlo_samples)

    maximizing = player == 1
    me = 1 if maximizing else -1

    # Immediate wins take priority over everything else, including blocking.
    children = []
    for ply in possible:
        b = np.copy(board)
        play(b, ply, me)
        if four_in_a_row(b, me):
            return (ply, 2 * me)
        children.append((ply, b))

    # At the root only: if the opponent could win by playing somewhere,
    # occupy that column ourselves.
    if depth == starting_depth:
        for ply in possible:
            b = np.copy(board)
            play(b, ply, -me)
            if four_in_a_row(b, -me):
                return (ply, me)

    best_eval = -np.inf if maximizing else np.inf
    best_move = possible[0]
    for ply, b in children:
        key = str(b)
        if key in cache:
            score = cache[key]
        else:
            _, score = alpha_beta_minmax(b, -me, alpha, beta, depth - 1, montecarlo_samples, cache, starting_depth)
            cache[key] = score
            # The left-right mirrored position has the same value.
            cache[str(np.flip(b, 0))] = score

        if maximizing:
            if score > best_eval:
                best_eval = score
                best_move = ply
            alpha = max(alpha, score)
        else:
            if score < best_eval:
                best_eval = score
                best_move = ply
            beta = min(beta, score)
        if beta <= alpha:
            break
    return (best_move, best_eval)

In [6]:
# Self-play demo: minimax with Monte Carlo evaluation drives both sides,
# sharing one evaluation cache for the whole game.
board = np.zeros((NUM_COLUMNS, COLUMN_HEIGHT), dtype=np.byte)
display(board)
cache = {}

player = -1
while valid_moves(board) and not four_in_a_row(board, 1) and not four_in_a_row(board, -1):
    tick = time.time()
    player *= -1  # player 1 moves first, then the sides alternate
    col, _ = alpha_beta_minmax(board, player, depth=5, montecarlo_samples=5, cache=cache)
    play(board, col, player)
    display(board)
    print(f'Last move time: {(time.time() - tick):.4f}s')

# Announce the result.
if four_in_a_row(board, 1):
    print('P1 (X) Won!')
elif four_in_a_row(board, -1):
    print('P2 (O) Won!')
else:
    print('Tie!')

.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
-------------------
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  X  .  .  .  
-------------------
Last move time: 11.6323s
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
-------------------
Last move time: 0.0270s
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  X  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
-------------------
Last move time: 0.0270s
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
-------------------
Last move time: 0.0195s
.  .  .  .  .  .  .  
.  .  .  X  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
-----------

# Montecarlo Tree Search

In [7]:
class Node:
    """One position in the Monte Carlo search tree.

    Attributes set by ``__init__``: ``state`` (board array), ``action`` (the
    column played to reach this state, ``None`` for the root), ``value``
    (accumulated rollout wins), ``parent``, ``num_visits`` and ``children``
    (``None`` until `expand` is called).
    """

    # Rollout counter shared by all nodes; used in the exploration term of `select`.
    total_visits = 0

    def __init__(self, state, action, value=0, parent=None):
        self.state = state
        self.action = action
        self.value = value
        self.parent = parent
        self.num_visits = 0
        self.children = None

    def expand(self, player):
        """Create one child per legal move of `player` from this state.

        ``children`` becomes an empty list when no move is available.
        """
        self.children = []
        for ply in valid_moves(self.state):
            b = np.copy(self.state)
            play(b, ply, player)
            self.children.append(Node(b, ply, parent=self))

    def simulate(self, moving_player, my_player, num_simulations=100):
        """Run random playouts from this state and record the results.

        `moving_player` makes the first move of every playout. Each playout
        adds one visit to this node and to ``Node.total_visits``; ``value`` is
        incremented only when `my_player` completes a line.
        """
        for _ in range(num_simulations):
            board = np.copy(self.state)
            Node.total_visits += 1
            self.num_visits += 1
            mover = moving_player
            while True:
                options = valid_moves(board)
                if not options:
                    break
                column = np.random.choice(options)
                play(board, column, mover)
                if four_in_a_row(board, mover):
                    if mover == my_player:
                        self.value += 1
                    break
                mover = -mover

    def select(self, C=2):
        """Return the child with the highest UCB score.

        The score is ``value / num_visits + C * sqrt(ln(total_visits) / num_visits)``.
        A child that has never been visited gets infinite priority instead of
        triggering a division by zero. Ties keep the earliest child. Returns
        ``None`` when there are no children.
        """
        best_child = None
        best_ucb = -np.inf
        for child in self.children:
            if child.num_visits == 0:
                ucb = np.inf
            else:
                exploitation = child.value / child.num_visits
                exploration = C * np.sqrt(np.log(Node.total_visits) / child.num_visits)
                ucb = exploitation + exploration
            if best_child is None or ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        return best_child

    def backprop(self, node, value, num_visits):
        """Add `value` and `num_visits` to every ancestor of `node`.

        `node` itself is left untouched; the walk starts at its parent and
        stops at the root. The receiver (`self`) is not used.
        """
        ancestor = node.parent
        while ancestor is not None:
            ancestor.value += value
            ancestor.num_visits += num_visits
            ancestor = ancestor.parent

def MCTS(board, player, depth=10, num_simulations=5, C=2):
    """Choose a column for `player` with Monte Carlo Tree Search.

    Args:
        board: current position; it is not modified.
        player: side to move (``1`` or ``-1``).
        depth: number of select/expand/simulate iterations to run.
        num_simulations: random playouts per newly created node.
        C: exploration constant passed to `Node.select`.

    Returns:
        The chosen column, or ``None`` when no move is available. An
        immediately winning move is returned first; otherwise a column where
        the opponent would win immediately is occupied; otherwise the tree
        search decides.
    """
    Node.total_visits = 0
    root = Node(np.copy(board), None)

    candidates = valid_moves(board)
    if not candidates:
        return None

    # First look for a move that wins on the spot, then for a column where
    # the opponent would win on the spot (which must be blocked).
    for side in (player, -player):
        for ply in candidates:
            b = np.copy(board)
            play(b, ply, side)
            if four_in_a_row(b, side):
                return ply

    root.expand(player)
    for child in root.children:
        child.simulate(-player, player, num_simulations=num_simulations)
        child.backprop(child, child.value, child.num_visits)

    remaining = depth
    while remaining > 0:
        remaining -= 1

        # Walk down from the root. `mover` is the side to move at `node`; it
        # is recomputed from scratch on every iteration so that nodes are
        # always expanded with the correct side.
        node = root
        mover = player
        while node.children:
            node = node.select(C)
            mover = -mover

        if node.children is None:
            node.expand(mover)
            for child in node.children:
                child.simulate(-mover, player, num_simulations=num_simulations)
                child.backprop(child, child.value, child.num_visits)
        else:
            # Already expanded but without legal moves: there is nothing to
            # descend into, so sample the node again and propagate only the
            # newly added statistics.
            value_before = node.value
            visits_before = node.num_visits
            node.simulate(mover, player, num_simulations=num_simulations)
            node.backprop(node, node.value - value_before, node.num_visits - visits_before)

    return root.select(C).action

In [8]:
# Self-play demo: MCTS chooses the moves of both sides.
board = np.zeros((NUM_COLUMNS, COLUMN_HEIGHT), dtype=np.byte)
display(board)

player = -1
while valid_moves(board) and not four_in_a_row(board, 1) and not four_in_a_row(board, -1):
    tick = time.time()
    player *= -1  # player 1 moves first, then the sides alternate
    col = MCTS(board, player, depth=120, num_simulations=1, C=2)
    play(board, col, player)
    display(board)
    print(f'Last move time: {(time.time() - tick):.4f}s')

# Announce the result.
if four_in_a_row(board, 1):
    print('P1 (X) Won!')
elif four_in_a_row(board, -1):
    print('P2 (O) Won!')
else:
    print('Tie!')

.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
-------------------
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  X  .  .  .  .  
-------------------
Last move time: 14.5843s
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  X  .  O  .  .  
-------------------
Last move time: 14.9505s
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  X  .  .  .  .  
.  .  X  .  O  .  .  
-------------------
Last move time: 14.6360s
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  X  .  .  .  .  
.  .  X  O  O  .  .  
-------------------
Last move time: 12.4400s
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  X  .  X  .  .  
.  .  X  O  O  .  .  
--------

# Proposed solution

As we can see, MCTS is not very good at detecting traps, even though it requires fewer evaluations than minimax. \
We could crank up depth and num_simulations in minimax, but that would make it really, really slow (and Python doesn't help either!). \
So we can consult a book containing all openings (I'll do it via an API so that I don't have to save any additional file; however, an internet connection is required!) \
and then use minimax to compute the less CPU-intensive moves. \
\
Note: I'm testing my implementation of minimax against an optimal player that knows how to win from any position (but since it is player 2, it can't win if \
player 1 plays perfectly).

In [9]:
# Optimal CHEAT LIKE A PRO - Enjoy ;-))
import requests
import json

def get_optimal_move(prev_moves):
    """Ask the connect4.gamesolver.org API for the best next column.

    Args:
        prev_moves: 0-based columns played so far, in order; they are sent to
            the service as a string of 1-based digits.

    Returns:
        The 0-based index of the highest entry in the response's ``score``
        list, ignoring entries equal to ``100`` (presumably the service's
        marker for an unplayable column -- TODO confirm against the API).
        The earliest index wins ties; ``0`` is returned when every entry
        is ``100``.

    Raises:
        requests.RequestException: on connection failure, timeout, or an HTTP
            error status.
    """
    pos = ''.join(str(move + 1) for move in prev_moves)
    url = 'https://connect4.gamesolver.org/solve?pos=' + pos
    headers = {'user-agent': 'my-app/0.0.1'}
    # Without a timeout a stalled connection would block the game loop forever.
    response = requests.get(url, headers=headers, timeout=10)
    # Fail loudly on an error page instead of trying to parse it.
    response.raise_for_status()
    scores = response.json()['score']

    playable = [i for i, score in enumerate(scores) if score != 100]
    if not playable:
        return 0
    # max() keeps the first of several equally good columns.
    return max(playable, key=lambda i: scores[i])
    
# Demo: player 1 follows the online solver for its first `lookup_num_moves`
# moves and then switches to minimax; player 2 always follows the solver.
board = np.zeros((NUM_COLUMNS, COLUMN_HEIGHT), dtype=np.byte)
display(board)
prev_moves = []
cache_minimax = {}

lookup_num_moves = 12

player = -1
while valid_moves(board) and not four_in_a_row(board, 1) and not four_in_a_row(board, -1):
    tick = time.time()
    player *= -1  # player 1 moves first, then the sides alternate

    if player == 1 and lookup_num_moves <= 0:
        # Book budget exhausted: compute the move locally.
        col, _ = alpha_beta_minmax(board, player, depth=24, montecarlo_samples=1, cache=cache_minimax)
    else:
        if player == 1:
            lookup_num_moves -= 1
        col = get_optimal_move(prev_moves)

    play(board, col, player)
    prev_moves.append(col)
    display(board)
    print(f'Last move time: {(time.time() - tick):.4f}s')

# Announce the result.
if four_in_a_row(board, 1):
    print('P1 (X) Won!')
elif four_in_a_row(board, -1):
    print('P2 (O) Won!')
else:
    print('Tie!')

.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
-------------------
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  X  .  .  .  
-------------------
Last move time: 0.6519s
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
-------------------
Last move time: 0.6434s
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  X  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
-------------------
Last move time: 0.6407s
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
-------------------
Last move time: 0.5580s
.  .  .  .  .  .  .  
.  .  .  X  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
.  .  .  O  .  .  .  
.  .  .  X  .  .  .  
------------