In [9]:
from typing import Tuple, List, Callable, Optional
from random import randint
# import matplotlib.pyplot as plt

from two_player_games.player import Player
from two_player_games.games.connect_four import ConnectFour, ConnectFourMove, ConnectFourState

In [10]:
ROW_COUNT = 6
COLUMN_COUNT = 7

In [11]:
class MinMaxSolver:

    def __init__(self, game: ConnectFour, heuristic: Callable):
        self.game: ConnectFour = game
        self.heuristic: Callable = heuristic

    def get_best_move(self, depth: int) -> int:
        column, score = self.minimax(depth, float('-inf'), float('inf'), True)
        return column

    def is_valid_move(self, col_index:int)->bool:
        return col_index >= 0 and col_index < COLUMN_COUNT and self.game.state.fields[col_index][-1] is None

    def minimax(self, depth, alpha:float, beta:float, is_maximizing_player:bool)-> Tuple[int, float]:
        """Returns column index and score"""
        if depth == 0 or self.game.state.is_finished():
            return None, self.evaluate_state(self.game.state)

        valid_moves = [column for column in range(COLUMN_COUNT) if self.is_valid_move(column)]
        best_move = None

        if is_maximizing_player:
            max_eval = float('-inf')
            for column in valid_moves:
                # Simulate the move
                saved_state = self.game.state
                self.game.make_move(ConnectFourMove(column))

                _, eval_score = self.minimax(depth - 1, alpha, beta, False)

                self.game.state = saved_state

                if eval_score > max_eval:
                    max_eval = eval_score
                    best_move = column

                alpha = max(alpha, max_eval)
                if beta <= alpha:
                    break

            return best_move, max_eval

        else:
            min_eval = float('inf')
            for column in valid_moves:
                # Simulate the move
                saved_state = self.game.state
                self.game.make_move(ConnectFourMove(column))

                _, eval_score = self.minimax(depth - 1, alpha, beta, True)

                self.game.state = saved_state

                if eval_score < min_eval:
                    min_eval = eval_score
                    best_move = column

                beta = min(beta, min_eval)
                if beta <= alpha:
                    break

            return best_move, min_eval

    def evaluate_state(self, state:ConnectFourState) -> float:
        """
        h(s) = {
                w(s)     , for s ∈ T
                heuristic, for remaining
                }
        """
        state_eval:float

        if state.is_finished():
            state_eval = self.payoff_function(state)
        else:
            state_eval = self.heuristic(state)

        return state_eval

    def payoff_function(self, state:ConnectFourState) -> float:
        """
        w(s) = {
                1000 , MAX wins
                0    , draw
                -1000, MIN wins
               }
        """
        if not self.game.get_winner():
            return 0

        # assuming MAX is labeled as 'a'
        if self.game.get_winner().char == self.game.state.get_current_player():
            return 100_000.0
        else:
            return -100_000.0


In [12]:
def heuristic_control(state: ConnectFourState) -> float:
    return 0

def heuristic_random(state: ConnectFourState) -> float:
    """
    Randomly rate the state of the game
    """
    return randint(0, 30)

def heuristic_center(state: ConnectFourState) -> float:
    """
    Prioritize center column:
        + 10 for each row already taken by the same player in the center column
        + 2  for each empty row in the center column
    """
    center_column = COLUMN_COUNT // 2
    score = 0

    for row in range(ROW_COUNT):
        if state.fields[center_column][row] is state.get_current_player():
            score += 10
        if state.fields[center_column][row] is None:
            score += 2

    return score

def heuristic_rows(state: ConnectFourState) -> float:
    """
    Prioritize longer horizontal lines:
        + 10 for each field in line times line length **
    """
    horizontal_lenghts = measure_horizontal_lines(state)
    score = 0
    for line_length in horizontal_lenghts:
        score += 10*(line_length ** line_length)

    return score

def heuristic_columns(state: ConnectFourState) -> float:
    """
    Prioritize longer vertical lines:
        + 10 for each field in line times line length **
    """
    vertical_lenghts = measure_vertical_lines(state)
    score = 0
    for line_length in vertical_lenghts:
        score += 10*(line_length ** line_length)

    return score

def heuristic_rows_and_columns(state: ConnectFourState) -> float:
    """
    Prioritize longer vertical lines over longer horizontal lines:
        + 10 for each field in vertical line times line length **
        + 10 for each field in horizontal line times line length **
    """
    vertical_value = heuristic_columns(state)
    horizontal_value = heuristic_rows(state)
    return vertical_value + horizontal_value

def heuristic_diagonals(state: ConnectFourState) -> float:
    """
    Prioritize longer diagonal lines:
        + 10 for each field in line times line length **
    """
    diagonal_lengths = measure_diagonal_lines(state)
    score = 0
    for line_length in diagonal_lengths:
        score += 10 * (line_length ** line_length)

    return score

def heuristic_all_lines(state: ConnectFourState) -> float:
    """
    Prioritize longer horizontal, vertical, and diagonal lines:
        + 10 for each field in line times line length squared
    """
    horizontal_value = heuristic_rows_and_columns(state)
    vertical_value = heuristic_columns(state)
    diagonal_value = heuristic_diagonals(state)
    return horizontal_value + vertical_value + diagonal_value


In [None]:
def measure_horizontal_lines(state: ConnectFourState) -> List[int]:
    """
    Length of each line
    Line is at least 2 adjacent fields
    """
    line_lengths = []
    current_player = state.get_current_player()
    for row in range(ROW_COUNT):
        previous = None
        current_length = 0
        for column in range(COLUMN_COUNT):
            current_field = state.fields[column][row]
            if previous is current_player and current_field is current_player:
                current_length += 1
                if current_length == 7:
                    line_lengths.append(current_length)
            elif current_length > 1:
                line_lengths.append(current_length)
                current_length = 0
            previous = state.fields[column][row]

    return line_lengths

def measure_vertical_lines(state: ConnectFourState) -> List[int]:
    """
    Length of each line
    Line is at least 2 adjacent fields
    """
    line_lengths = []
    current_player = state.get_current_player()

    for column in range(COLUMN_COUNT):
        previous = None
        current_length = 0
        for row in range(ROW_COUNT):
            current_field = state.fields[column][row]
            if previous is current_player and current_field is current_player:
                current_length += 1
                if current_length == 7:
                    line_lengths.append(current_length)
            elif current_length != 0:
                line_lengths.append(current_length)
                current_length = 0
            if current_field is None:
                continue
            previous = state.fields[column][row]

    return line_lengths

def measure_diagonal_lines(state: ConnectFourState) -> List[int]:
    """
    Length of each diagonal line.
    Line is at least 2 adjacent fields.
    """
    line_lengths = []
    current_player = state.get_current_player()

    starting_fields_1: List[Tuple[int, int]] = [(0, i) for i in range(COLUMN_COUNT)] + [(i, 0) for i in range(ROW_COUNT)]
                                            #  ^                                            ^
                                            #  column 0                                     row 0
    starting_fields_1.pop(0)
    # remove duplicated (0, 0)

    starting_fields_2: List[Tuple[int, int]] = [(0, i) for i in range(COLUMN_COUNT)] + [(i, ROW_COUNT-1) for i in range(ROW_COUNT)]
                                            #    ^                                          ^
                                            #    column 0                                   top row
    starting_fields_2.pop(COLUMN_COUNT-1)


    # Bottom-left to top-right diagonals
    for starting_field in starting_fields_1:
        previous = None
        current_length = 0
        col, row = starting_field
        while col < COLUMN_COUNT and row < ROW_COUNT:
            current_field = state.fields[col][row]
            if previous is current_player and current_field is current_player:
                current_length += 1
            elif current_length != 0:
                line_lengths.append(current_length)
                current_length = 0
            previous = current_field
            col += 1
            row += 1
        if current_length != 0:
            line_lengths.append(current_length)

    # Top-left to bottom-right diagonals
    for starting_field in starting_fields_2:
        previous = None
        current_length = 0
        col, row = starting_field
        while col < COLUMN_COUNT and row >= 0:
            current_field = state.fields[col][row]
            if previous is current_player and current_field is current_player:
                current_length += 1
            elif current_length != 0:
                line_lengths.append(current_length)
                current_length = 0
            previous = current_field
            col += 1
            row -= 1
        if current_length != 0:
            line_lengths.append(current_length)

    return line_lengths

In [14]:
# def make_graph(graph_title: str, x_coords: list[int], y_coords: list[int]) -> None:
#     plt.clf()
#     plt.plot(x_coords, y_coords, marker='o')
#     plt.xlabel('Depth')
#     plt.ylabel('Win Ratio')
#     plt.title(graph_title)
#     plt.grid(True)
#     plt.savefig(str(randint(0, 10000000)))
#     # plt.show()

In [15]:
def heuristic_1v1(heuristic1: Callable, heuristic2: Callable, repetitions: int, depth1: int, depth2: int) -> Tuple[int, int, int]:
    """
    Test efficieny of one heuristic against another.
    The first move is always random*
    """

    max_wins = 0
    min_wins = 0
    draws = 0
    # 1 for MAX win, 0 for draw, -1 for MIN win

    p1 = Player("a")
    p2 = Player("b")
    game = ConnectFour(size=(COLUMN_COUNT, ROW_COUNT), first_player=p1, second_player=p2)
    solver1 = MinMaxSolver(game=game, heuristic=heuristic1)
    solver2 = MinMaxSolver(game=game, heuristic=heuristic2)
    initial_state: ConnectFourState = game.state

    for _ in range(repetitions):
        game.state = initial_state

        first_move = randint(0, COLUMN_COUNT - 1)
        game.make_move(ConnectFourMove(first_move))

        while not game.state.is_finished():
            move1 = solver1.get_best_move(depth=depth1)
            game.make_move(ConnectFourMove(move1))
            if game.state.is_finished():
                break
            move2 = solver2.get_best_move(depth=depth2)
            game.make_move(ConnectFourMove(move2))

        if not game.get_winner():
            draws += 1
        elif game.get_winner() == p1:
            max_wins += 1
        else:
            min_wins += 1

    return max_wins, draws, min_wins

In [19]:
def main():
    """ Main V1 """
    # p1 = Player("a")
    # p2 = Player("b")
    # game = ConnectFour(size=(COLUMN_COUNT, ROW_COUNT), first_player=p1, second_player=p2)
    # solver0 = MinMaxSolver(game = game, heuristic= heuristic_control)
    # solver1 = MinMaxSolver(game = game, heuristic= heuristic_center)
    # solver2 = MinMaxSolver(game = game, heuristic= heuristic_rows)
    # solver3 = MinMaxSolver(game = game, heuristic= heuristic_columns)
    # solver4 = MinMaxSolver(game = game, heuristic= heuristic_rows_and_columns)
    # solver5 = MinMaxSolver(game = game, heuristic= heuristic_diagonals)
    # solver6 = MinMaxSolver(game = game, heuristic= heuristic_all_lines)

    # game.make_move(ConnectFourMove(randint(0, COLUMN_COUNT-1)))
    # print(game)


    # while not game.state.is_finished():
    #     move = solver0.get_best_move(depth=2)
    #     game.make_move(ConnectFourMove(move))
    #     print(game)

    #     if game.state.is_finished():
    #         break
    #     move = solver4.get_best_move(depth=2)
    #     game.make_move(ConnectFourMove(move))
    #     print(game)

    # if game.get_winner():
    #     print(f"game winner: {game.get_winner().char}")
    # else:
    #     print("draw")


    """ Main V2 """
    REPETITIONS = 50
    #DEPTHS = range(1, 6)

    # duels = [(heuristic_control, heuristic_center, "max: control, min: center"),
    #         (heuristic_all_lines, heuristic_all_lines, "max: control, min: all lines"),
    #         (heuristic_center, heuristic_all_lines, "max: center, min: all lines"),
    #         (heuristic_diagonals, heuristic_all_lines, "max: diagonals, min: all lines"),
    #         (heuristic_rows_and_columns, heuristic_diagonals, "max: rows and columns, min: diagonals")]

    # for duel in duels:
    #    heuristic1, heuristic2, graph_title = duel
    #
    #    win_ratios = []
    #    for depth in DEPTHS:
    #       max_wins, draws, min_wins = heuristic_1v1(heuristic1, heuristic2, REPETITIONS, depth1=depth, depth2=depth)
    #       win_ratio = max_wins / REPETITIONS
    #       win_ratios.append(win_ratio)
    #
    #    make_graph(graph_title, list(DEPTHS), win_ratios)

    h = heuristic_center
    depths = [(1, 5), (2, 5), (3, 5), (4, 5), (5, 5)]
    win_ratios = []

    for d_pair in depths:
        d1, d2 = d_pair
        max_wins, draws, min_wins = heuristic_1v1(h, h, REPETITIONS, d1, d2)
        win_ratio = max_wins / REPETITIONS
        win_ratios.append(win_ratio)

    # make_graph("max: all lines (d∈[1,5]), min: all lines (d=5)", [1,2,3,4,5], win_ratios)
    print(win_ratios)

if __name__ == "__main__":
    main()

[0.8, 0.0, 0.54, 0.3, 0.0]
