# Connect 4 with AI trained with Reinforcement Learning

### Import

In [263]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

### Build the Connect4 environment

In [563]:
class InvalidMoveException(Exception):
    """Raised when a piece is dropped into a column with no empty cell left."""


class InvalidPlayerException(Exception):
    """Raised when a move is attempted by a player id other than 1 or -1."""


class InvalidColumnException(Exception):
    """Raised when the requested column index lies outside the board."""


class InvalidTurnException(Exception):
    """Raised when a player tries to move while it is the other player's turn."""


class GameOverException(Exception):
    """Raised when a move is attempted after the game has been won or drawn."""


class Connect4Board:
    """A 6x7 Connect 4 board for two players identified as 1 and -1.

    The board is a float NumPy array in which 0 marks an empty cell and
    1 / -1 mark the pieces of the respective player. Row 0 is the bottom
    row, so pieces "fall" to the lowest free row index of a column.
    """

    players = [1, -1]

    # Number of aligned pieces required to win.
    _WIN_LENGTH = 4

    def __init__(self):
        """Create a board in its initial (empty) state."""
        self.reset_board()

    def print_board(self):
        """Print the board with the bottom row last.

        Player 1 is shown as "O", player -1 as "X" and empty cells as "-".
        """
        # Map by cell value rather than by the textual float representation,
        # so the output does not depend on how NumPy formats numbers.
        symbols = np.full(self.board.shape, "-")
        symbols[self.board == 1] = "O"
        symbols[self.board == -1] = "X"
        print(np.flip(symbols, 0))

    def get_board(self):
        """Return the underlying board array (not a copy)."""
        return self.board

    def reset_board(self):
        """Clear the board and pick the starting player at random."""
        self.board = np.zeros((6, 7))
        # Maps the random draw 0/1 onto -1.0/1.0.
        self.next_player = (random.randint(0, 1) - 0.5) * 2

    def get_next_player(self):
        """Return the id (1 or -1) of the player whose turn it is."""
        return self.next_player

    def perform_move(self, player, column):
        """Drop a piece of ``player`` into ``column`` and pass the turn.

        Args:
            player: Id of the moving player, 1 or -1.
            column: Zero-based index of the column to play.

        Returns:
            The board array after the move (the same object as
            ``get_board()``).

        Raises:
            InvalidPlayerException: ``player`` is not 1 or -1.
            InvalidColumnException: ``column`` is outside the board.
            InvalidTurnException: it is not ``player``'s turn.
            GameOverException: the game is already won or drawn.
            InvalidMoveException: the column has no empty cell.
        """
        if player not in self.players:
            raise InvalidPlayerException(
                "Invalid player. Player should be either 1 or -1."
            )
        # Derive the bound from the board itself instead of repeating "7".
        if not 0 <= column < self.board.shape[1]:
            raise InvalidColumnException(
                "Invalid column. Column should be either 0, 1, 2, 3, 4, 5, 6."
            )
        if player != self.next_player:
            raise InvalidTurnException(
                "Invalid move. Player %d is not the next player." % player
            )
        if self.check_win() is not None:
            raise GameOverException("Game is over.")

        # Row indices of the empty cells in this column, lowest row first.
        free_rows = np.flatnonzero(self.board[:, column] == 0)
        if free_rows.size == 0:
            raise InvalidMoveException("Invalid move")

        self.board[free_rows[0], column] = player
        self.next_player *= -1
        return self.board

    def _line_winner(self, line):
        """Return the player owning a full winning window in ``line``, else None."""
        needed = self._WIN_LENGTH
        # Cheap pre-check: a line with fewer occupied cells cannot hold a win.
        if np.sum(np.abs(line)) < needed:
            return None
        for start in range(len(line) - needed + 1):
            window = line[start : start + needed]
            for player in self.players:
                if np.all(window == player):
                    return player
        return None

    def check_win(self):
        """Evaluate the current position.

        Returns:
            1 or -1 if that player has four in a row (vertically,
            horizontally or diagonally), 0 if the board is full without a
            winner, and None if the game is still in progress.
        """
        # Columns are scanned first, then rows.
        for line in (*self.board.T, *self.board):
            winner = self._line_winner(line)
            if winner is not None:
                return winner

        # Both diagonals of every 4x4 window cover all diagonal alignments.
        needed = self._WIN_LENGTH
        n_rows, n_cols = self.board.shape
        for i in range(n_rows - needed + 1):
            for j in range(n_cols - needed + 1):
                window = self.board[i : i + needed, j : j + needed]
                diagonal = window.diagonal()
                anti_diagonal = np.flip(window, axis=0).diagonal()
                for player in self.players:
                    if np.all(diagonal == player) or np.all(
                        anti_diagonal == player
                    ):
                        return player

        # Draw: every cell is occupied and nobody has won.
        if np.sum(np.abs(self.board)) == self.board.size:
            return 0

        return None

    def get_game_over(self):
        """Return True once the game has been won or drawn, else False."""
        return self.check_win() is not None

In [598]:
# Play a single game in which both sides pick their column uniformly at random.
connect4 = Connect4Board()

while not connect4.get_game_over():
    try:
        connect4.perform_move(connect4.get_next_player(), random.randint(0, 6))
        # connect4.print_board()
    except InvalidMoveException as e:
        # The randomly chosen column was already full: report it and draw
        # again. Any other exception signals a real bug and is left to
        # propagate instead of being printed in an endless loop.
        print(e)


# Manual scenario kept for quick experiments:
# connect4 = Connect4Board()

# connect4.perform_move(connect4.get_next_player(), 1)
# connect4.perform_move(connect4.get_next_player(), 1)
# connect4.perform_move(connect4.get_next_player(), 2)
# connect4.perform_move(connect4.get_next_player(), 2)


# Result of the game: 1 or -1 for a win, 0 for a draw.
print(connect4.check_win())

connect4.print_board()

print(connect4.board.shape)

1
[['-' '-' '-' '-' '-' '-' '-']
 ['-' '-' '-' '-' '-' '-' '-']
 ['X' '-' '-' '-' '-' '-' '-']
 ['X' '-' '-' '-' 'X' '-' '-']
 ['O' '-' '-' '-' 'X' '-' '-']
 ['O' 'O' 'O' 'O' 'O' 'X' '-']]
(6, 7)
