### Connect Four implementation

In [1]:
class ConnectFour:

    EMPTY_CELL = 0
    PLAYER_ONE = 1
    PLAYER_TWO = 2

    def __init__(self, rows: int, cols: int) -> None:
        self.rows = rows
        self.cols = cols
        self.board = [[ConnectFour.EMPTY_CELL for _ in range(cols)] for _ in range(rows)]
        self.player_to_move = ConnectFour.PLAYER_ONE
        self.is_terminal = False
        self.winner = None

    def clone(self):
        result = ConnectFour(self.rows, self.cols)
        result.board = [row[:] for row in self.board]
        return result

    # places the move in the correct column and updates the board,
    # if it cant be played then it returns an error, otherwise it
    # returns the boolean value of whether or not sequence of four
    # same colored disks for the input player is in the board
    def move(self, col: int) -> bool:
        
        if not (0 <= col < self.cols):
            raise ValueError(f'Poor column value: {col}')

        for row in range(self.rows):
            if self.board[row][col] == ConnectFour.EMPTY_CELL:
                self.board[row][col] = self.player_to_move
                self.player_to_move = 3 - self.player_to_move
                if len(self.get_valid_moves()) == 0:
                    self.is_terminal = True
                    return False
                if self.check_connect_four((row, col), 3 - self.player_to_move):
                    self.winner = 3 - self.player_to_move
                    self.is_terminal = True
                    return True
                else:
                    return False
                
        
        raise ValueError(f'Column {col} is full')
    
    # this checks if we have created a connect four by expanding from the current
    # position in each opposing direction, so we can check verticals, horizontals,
    # and both diagonals.
    # TODO: refactor for less code tbh
    def check_connect_four(self, position: tuple, player: int) -> bool:

        north = self._expand(position, (1, 0), player)
        south = self._expand(position, (-1, 0), player)

        vertical = north + south - 1
        if vertical >= 4:
            return True

        west = self._expand(position, (0, -1), player)
        east = self._expand(position, (0, 1), player)

        horizontal = west + east - 1
        if horizontal >= 4:
            return True

        northeast = self._expand(position, (1, 1), player)
        southwest = self._expand(position, (-1, -1), player)

        diag_nesw = northeast + southwest - 1
        if diag_nesw >= 4:
            return True

        northwest = self._expand(position, (1, -1), player)
        southeast = self._expand(position, (-1, 1), player)

        diag_nwse = northwest + southeast - 1
        if diag_nwse >= 4:
            return True

        return False


    # position is passed in as a tuple of (row, col)
    # player is either PLAYER_ONE or PLAYER_TWO
    def _expand(self, position: tuple, delta: tuple, player: int) -> int:
        row, col = position
        row_delta, col_delta = delta
        count: int = 0

        # i think we cna remove the try catch finally now that were checking
        # for the bounds (this was taken from the other code)
        try:
            while (0 <= row < self.rows) and (0 <= col < self.cols) \
                and self.board[row][col] == player:

                count += 1
                row += row_delta
                col += col_delta
        
        finally:
            return count


    # if you think about it we only have to check the last value of each row
    # to know what columns are available to be played on
    def get_valid_moves(self):
        return [i for i, v in enumerate(self.board[self.rows - 1]) if v == ConnectFour.EMPTY_CELL]

Code to run a simulation between two players

In [2]:
def run_simulation(iterations, players):
    wins = {}
    wins[players[0]] = 0
    wins[players[1]] = 0
    
    for i in range(iterations):
        if i % 100 == 0:
            print(str(i) + " games done")
        game = ConnectFour(6, 7)
        if i % 2:
            game.player_to_move = 2
        while not game.is_terminal:
            player = players[game.player_to_move - 1]
            action = player.move(game, game.player_to_move)
            if game.move(action):
                wins[player] += 1
                break
    return wins

Abstract Class that different players will inherit

In [3]:
from abc import ABC, abstractmethod

class BasePlayer:
    def __init__(self):
        pass
    @abstractmethod
    def move(self, game, player):
        pass


Random player which randomly chooses a move from the valid moves

In [4]:
import random

class RandomPlayer(BasePlayer):
    def __init__(self):
        pass

    def move(self, game, player):
        return random.choice(game.get_valid_moves())

Quick test to make sure that the game works correctly - the wins for each player should be roughly half and there should be very little draws

In [5]:
p1 = RandomPlayer()
p2 = RandomPlayer()
wins = run_simulation(1000, [p1, p2])
print("Random Player 1 wins: " + str(wins[p1]))
print("Random Player 2 wins: " + str(wins[p2]))
print("Draws: " + str(1000 - wins[p1] - wins[p2]))

0 games done
100 games done
200 games done
300 games done
400 games done
500 games done
600 games done
700 games done
800 games done
900 games done
Random Player 1 wins: 485
Random Player 2 wins: 508
Draws: 7


# MCTS Implementation

In [6]:
import math
import time
import copy

class Node:
    def __init__(self, game, game_copy, predecessor):
        self.game = game 
        self.game_copy = game_copy
        self.predecessor = predecessor
        self.actions = game.get_valid_moves()
        self.successors = []
        self.payoff = 0
        self.num_paths = 0
    
class MCTSPlayer(BasePlayer):
    def __init__(self, time_limit):
        self.time_limit = time_limit

    def move(self, game, player):
        game_cpy = game.clone()
        root = Node(game_cpy, None, None)
        time_end = time.time() + self.time_limit
        while time.time() < time_end:
            curr_node = root
            while not curr_node.game.is_terminal:
                if len(curr_node.actions) != 0:
                    break
                vals = []
                for successor in curr_node.successors:
                    if successor.predecessor.game.player_to_move == 1:
                        vals.append(successor.payoff / successor.num_paths - 2 * math.sqrt(2) * math.sqrt(math.log(successor.predecessor.num_paths) / successor.num_paths))
                    else:
                        vals.append(successor.payoff / successor.num_paths + 2 * math.sqrt(2) * math.sqrt(math.log(successor.predecessor.num_paths) / successor.num_paths))
                if(curr_node.game.player_to_move == 1):
                    i = vals.index(min(vals))
                else:
                    i = vals.index(max(vals))
                curr_node = curr_node.successors[i]
            if not curr_node.game.is_terminal:
                action = curr_node.actions.pop()
                game_copy = curr_node.game
                game_copy.move(action)
                next_node = Node(game_copy, action, curr_node)
                curr_node.successors.append(next_node)
                curr_state = next_node
                while not curr_state.game.is_terminal:
                    action = random.choice(curr_state.actions)
                    game_copy = curr_node.game
                    game_copy.move(action)
                    new_state = Node(game_copy, action, curr_state)
                    curr_state = new_state
                if curr_state.game.winner is None:
                    payoff = 0
                elif curr_state.game.winner == player:
                    payoff = 1
                else:
                    payoff = -1
                tmp = next_node
                while(tmp is not None):
                    tmp.num_paths += 1
                    tmp.payoff += payoff
                    tmp = tmp.predecessor
            else:
                if curr_node.game.winner is None:
                    payoff = 0
                elif curr_node.game.winner == player:
                    payoff = 1
                else:
                    payoff = -1
                tmp = curr_node
                while(tmp is not None):
                    tmp.num_paths += 1
                    tmp.payoff += payoff
                    tmp = tmp.predecessor
        exploitations = []
        for successor in root.successors:
            exploitations.append(successor.payoff / successor.num_paths)
        if(root.game.player_to_move == 1):
            i = exploitations.index(min(exploitations))
        else:
            i = exploitations.index(max(exploitations))
        return root.successors[i].game_copy          

Results of MCTS vs random player

In [7]:
p1 = MCTSPlayer(0.01)
p2 = RandomPlayer()
wins = run_simulation(1000, [p1, p2])
print("MCTS Player wins: " + str(wins[p1]))
print("Random Player wins: " + str(wins[p2]))
print("Draws: " + str(1000 - wins[p1] - wins[p2]))

0 games done
100 games done
200 games done
300 games done
400 games done
500 games done
600 games done
700 games done
800 games done
900 games done
MCTS Player wins: 785
Random Player wins: 215
Draws: 0


In [None]:
import keras
from keras.models import Sequential, load_model
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
import os

class NeuralNetwork:
    def __init__(self):
        self.model_file = 'model.h5'
        model = Sequential()
        model.add(Conv2D(64, (4,4), input_shape=(6, 7, 1)))
        model.add(Activation('relu'))
        model.add(Conv2D(64, (2, 2)))
        model.add(Activation('relu'))
        model.add(Conv2D(64, (2, 2)))
        model.add(Activation('relu'))
        model.add(Flatten())
        model.add(Dense(64))
        model.add(Activation('relu'))
        model.add(Dense(1))
        model.compile(loss = 'mean_squared_error', optimizer = keras.optimizers.Adagrad(), metrics=['accuracy'])
        self.model = model
        if os.path.isfile(model_file):
            self.model = load_model(model_file)
        