Copyright **`(c)`** 2021 Giovanni Squillero `<squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free for personal or classroom use; see 'LICENCE.md' for details.

# Connect 4

In [1]:
from collections import Counter
import numpy as np
import time

In [2]:
NUM_COLUMNS = 7  # number of columns on the board (first array axis)
COLUMN_HEIGHT = 6  # number of cells in each column (second array axis)
FOUR = 4  # length of the line a player needs to win

# A board can be initialized with `board = np.zeros((NUM_COLUMNS, COLUMN_HEIGHT), dtype=np.byte)`
# Note: Connect 4 "columns" are actually NumPy "rows"

## Basic Functions

In [3]:
def valid_moves(board):
    """List the columns that can still accept a disc.

    A column is playable while its last cell (index ``COLUMN_HEIGHT - 1``)
    is still 0. Returns a list of column indices in increasing order.
    """
    top_cells = board[:, COLUMN_HEIGHT - 1]
    return [column for column in range(NUM_COLUMNS) if not top_cells[column]]


def play(board, column, player):
    """Drop a disc for `player` into `column`, modifying `board` in place.

    The disc lands in the first cell of the column that is still 0.

    Args:
        board: 2-D NumPy array indexed as ``board[column, height]``.
        column: index of the column to play in.
        player: value written into the cell (1 or -1 elsewhere in this file).

    Raises:
        ValueError: if `column` has no empty cell left.
    """
    empty_cells = np.flatnonzero(board[column] == 0)
    if empty_cells.size == 0:
        # Fail with an explicit error: a bare StopIteration escaping from here
        # would be turned into an unrelated RuntimeError when this function is
        # called from inside a generator.
        raise ValueError(f"column {column} is full")
    board[column, empty_cells[0]] = player


def take_back(board, column):
    """Undo the latest drop in `column` by clearing its highest non-zero cell.

    `board` is modified in place. An empty column raises IndexError.
    """
    occupied = np.flatnonzero(board[column])
    board[column, occupied[-1]] = 0


def four_in_a_row(board, player):
    """Tell whether `player` owns FOUR consecutive cells in some direction.

    Every window of FOUR cells that fits inside a NUM_COLUMNS x COLUMN_HEIGHT
    board is examined: along a single column, across adjacent columns at the
    same height, and along both diagonals. Returns True as soon as one window
    is entirely `player`, otherwise False.
    """
    owned = board == player
    steps = np.arange(FOUR)
    fitting_columns = range(NUM_COLUMNS - FOUR + 1)
    fitting_heights = range(COLUMN_HEIGHT - FOUR + 1)
    # Each scan: (column step, height step, start columns, start heights).
    scans = (
        (0, 1, range(NUM_COLUMNS), fitting_heights),  # within one column
        (1, 0, fitting_columns, range(COLUMN_HEIGHT)),  # across columns
        (1, 1, fitting_columns, fitting_heights),  # rising diagonal
        (1, -1, fitting_columns, range(FOUR - 1, COLUMN_HEIGHT)),  # falling diagonal
    )
    for column_step, height_step, start_columns, start_heights in scans:
        for column in start_columns:
            for height in start_heights:
                window = owned[column + column_step * steps, height + height_step * steps]
                if window.all():
                    return True
    return False

## Monte Carlo Evaluation

In [4]:
def _mc(board, player):
    """Finish the game on `board` with uniformly random moves.

    `player` moves first and the two sides then alternate. `board` is
    modified in place. Returns the side that completes a line of four, or 0
    when no playable column is left.
    """
    mover = player
    open_columns = valid_moves(board)
    while open_columns:
        column = np.random.choice(open_columns)
        play(board, column, mover)
        if four_in_a_row(board, mover):
            return mover
        # Hand the turn over and refresh the playable columns.
        mover = -mover
        open_columns = valid_moves(board)
    return 0


def montecarlo(board, player, montecarlo_samples=50):
    """Estimate the value of `board` by averaging random playouts.

    Each playout runs on a private copy of `board`, so the argument is left
    untouched.

    Args:
        board: position to evaluate.
        player: side that moves first in every playout.
        montecarlo_samples: number of playouts to run (default 50).

    Returns:
        (playouts won by 1 - playouts won by -1) / montecarlo_samples,
        a float in [-1, 1].

    Raises:
        ValueError: if `montecarlo_samples` is not positive.
    """
    if montecarlo_samples <= 0:
        raise ValueError("montecarlo_samples must be a positive integer")
    cnt = Counter(_mc(np.copy(board), player) for _ in range(montecarlo_samples))
    return (cnt[1] - cnt[-1]) / montecarlo_samples

def eval_terminal(board):
    """Return the winner already present on `board`, if any.

    Yields 1 when Alice (player 1) has four in a row, -1 when Bob
    (player -1) has, and 0 when neither does. Alice is checked first.
    """
    for side in (1, -1):
        if four_in_a_row(board, side):
            return side
    return 0

def eval_board(board, player):
    """Score `board`: exact when a side has already won, sampled otherwise.

    Returns 1 or -1 if Alice or Bob already has four in a row; otherwise the
    random-playout estimate from `montecarlo`, with `player` moving first.
    """
    outcome = eval_terminal(board)
    if outcome != 0:
        return outcome
    # Nobody has won yet: fall back to random playouts.
    return montecarlo(board, player)

## Monte Carlo Tree Search

In [5]:
class Node:
    """One position in the Monte Carlo search tree."""

    def __init__(self, move=None, parent=None, state=None, player=None):
        """Create a node for the position `state`.

        Args:
            move: column whose disc produced this position (None for a root).
            parent: node this one was expanded from (None for a root).
            state: board of this position; required despite the default,
                because it is copied and queried for its playable columns.
            player: side to move in this position.
        """
        self.state = state.copy()  # private snapshot of the position
        self.parent = parent
        self.move = move
        self.untriedMoves = valid_moves(state)  # columns not expanded yet
        self.childNodes = []
        self.wins = 0  # playouts won by the side that played `move`
        self.visits = 0  # playouts that went through this node
        self.player = player

    def selection(self):
        """Return the child with the largest UCT value.

        A child that has never been visited is preferred over any visited
        one, which also avoids dividing by zero.
        """

        def uct(child):
            if child.visits == 0:
                return float("inf")
            exploitation = child.wins / child.visits
            exploration = np.sqrt(2 * np.log(self.visits) / child.visits)
            return exploitation + exploration

        return max(self.childNodes, key=uct)

    def expand(self, move, state):
        """Add and return the child reached by playing `move`.

        `state` must be the board after `move` has been played. The move is
        removed from this node's untried moves and the child is given the
        opposite side to move.
        """
        child = Node(move=move, parent=self, state=state, player=-(self.player))
        self.untriedMoves.remove(move)
        self.childNodes.append(child)
        return child

    def update(self, result):
        """Record one playout whose winner is `result` (1, -1, or 0 for none).

        `self.player` is the side to move at this node, so the disc leading
        here was dropped by `-self.player`. The win is credited to that side:
        a parent picking its best child by `wins / visits` then favours moves
        that are good for the side making them.
        """
        if self.player is not None and result == -self.player:
            self.wins += 1
        self.visits += 1

def MCTS(currentState, player, itermax):
    """Choose a column for `player` with Monte Carlo tree search.

    Args:
        currentState: board to move from; it is copied, never modified.
        player: side to move in `currentState`.
        itermax: number of select / expand / rollout / backpropagate rounds.

    Returns:
        The move of the root child with the highest `wins / visits` ratio.

    Raises:
        ValueError: if the root ends up without children (no iteration was
            run, the game is already decided, or no column is playable).
    """
    rootnode = Node(state=currentState, player=player)

    for _ in range(itermax):
        node = rootnode
        state = currentState.copy()

        # Selection: descend through fully expanded nodes along the best UCT
        # child. The move stored in a child was made by the side to move at
        # its parent, so it is replayed with the parent's player.
        while not node.untriedMoves and node.childNodes:
            mover = node.player
            node = node.selection()
            play(state, node.move, mover)

        # A position that is already won must not be expanded or played on.
        winner = eval_terminal(state)

        # Expansion: try one new move from the reached node.
        if winner == 0 and node.untriedMoves:
            m = np.random.choice(node.untriedMoves)
            play(state, m, node.player)
            node = node.expand(m, state)
            winner = eval_terminal(state)

        # Rollout: random playout started by the side to move at `node`.
        if winner == 0:
            winner = _mc(state, node.player)

        # Backpropagation: report the outcome to every node on the path.
        while node is not None:
            node.update(winner)
            node = node.parent

    if not rootnode.childNodes:
        raise ValueError("MCTS found no move to evaluate")
    best_child = max(rootnode.childNodes, key=lambda child: child.wins / child.visits)
    return best_child.move

## Minimax with Alpha-Beta Pruning and Monte Carlo Simulation

In [6]:
# Number of plies searched before falling back to `eval_board`.
MAX_DEPTH = 2


def minmax(board, player, alpha, beta, depth=0):
    """Depth-limited minimax search with alpha-beta pruning.

    Player 1 maximizes the score and any other player minimizes it. `board`
    is modified while searching, but every move played is taken back.

    Args:
        board: position to search from.
        player: side to move.
        alpha: best score the maximizer is already assured of.
        beta: best score the minimizer is already assured of.
        depth: current ply, counted from the root call.

    Returns:
        A `(column, score)` pair. `column` is None when the position is not
        searched further: depth limit reached, no playable column, or a side
        has already won.
    """
    moves = valid_moves(board)
    if depth == MAX_DEPTH or not moves:
        return None, eval_board(board, player)
    outcome = eval_terminal(board)
    if outcome in (-1, 1):
        return None, outcome

    maximizing = player == 1
    chosen = -1
    # Start outside the [-1, 1] score range so the first move always improves.
    best = -2 if maximizing else 2
    for move in moves:
        play(board, move, player)
        _, score = minmax(board, -player, alpha, beta, depth + 1)
        take_back(board, move)
        if maximizing:
            if score > best:
                best, chosen = score, move
            if best >= beta:
                break
            alpha = max(best, alpha)
        else:
            if score < best:
                best, chosen = score, move
            if best <= alpha:
                break
            beta = min(best, beta)
    return chosen, best

## Example

In [7]:
# Demo position: Alice (1) has played columns 3 and 4, Bob (-1) column 1 twice.
board = np.zeros((NUM_COLUMNS, COLUMN_HEIGHT), dtype=np.byte)
play(board, 3, 1)
play(board, 1, -1)
play(board, 4, 1)
play(board, 1, -1)
print(board)

# Alice to move. The (-2, 2) window lies outside the [-1, 1] range of scores.
# The score is kept in `score` so the builtin `eval` is not shadowed.
best_play, score = minmax(board, 1, -2, 2)
print(best_play, score)
play(board, best_play, 1)

# Alternative: pick the move with Monte Carlo tree search instead of minmax.
#best_play = MCTS(board, 1, 3000)
#print(best_play)
#play(board, best_play, 1)

print(board)

[[ 0  0  0  0  0  0]
 [-1 -1  0  0  0  0]
 [ 0  0  0  0  0  0]
 [ 1  0  0  0  0  0]
 [ 1  0  0  0  0  0]
 [ 0  0  0  0  0  0]
 [ 0  0  0  0  0  0]]
5 0.52
[[ 0  0  0  0  0  0]
 [-1 -1  0  0  0  0]
 [ 0  0  0  0  0  0]
 [ 1  0  0  0  0  0]
 [ 1  0  0  0  0  0]
 [ 1  0  0  0  0  0]
 [ 0  0  0  0  0  0]]
