# Connect 4 with an AI trained by reinforcement learning

### Import

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random


import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import tensorflow as tf
import numpy as np

from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts

2025-06-16 00:00:37.875043: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-16 00:00:37.886973: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-06-16 00:00:37.971206: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-06-16 00:00:38.042388: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750028438.113451    1032 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750028438.13

### Build the Connect4 environment

In [None]:
class Connect4Error(Exception):
    """Common base class for the errors raised by the Connect 4 game logic.

    Catching this type handles every game-rule violation below without
    resorting to a blanket ``except Exception``.
    """


class InvalidMoveException(Connect4Error):
    """Raised when a piece is dropped into a column with no empty cell."""


class InvalidPlayerException(Connect4Error):
    """Raised when the player identifier is neither ``1`` nor ``-1``."""


class InvalidColumnException(Connect4Error):
    """Raised when the column index lies outside the board."""


class InvalidTurnException(Connect4Error):
    """Raised when a player tries to move while it is the other player's turn."""


class GameOverException(Connect4Error):
    """Raised when a move is attempted after the game has finished."""


class Connect4Board:
    """Game state and rules for a 6x7 Connect 4 board.

    Cells of ``board`` hold ``0`` (empty), ``1`` or ``-1`` (the two players).
    Row 0 is the bottom of the board: a piece dropped into a column lands in
    the lowest-index empty row of that column.
    """

    players = [1, -1]

    def __init__(self):
        self.reset_board()

    def print_board(self):
        """Print the board with the top row first.

        Player ``1`` is shown as ``O``, player ``-1`` as ``X`` and empty
        cells as ``-``.
        """
        symbols = np.array(["X", "-", "O"])  # indexed by cell value + 1
        print(symbols[np.flip(self.board, 0).astype(np.intp) + 1])

    def get_board(self):
        """Return the live board array (not a copy)."""
        return self.board

    def reset_board(self):
        """Clear the board and pick the starting player at random."""
        self.board = np.zeros((6, 7), np.int8)
        # randint(0, 1) is 0 or 1, which maps onto player -1 or 1.
        self.next_player = 2 * random.randint(0, 1) - 1
        self.is_game_over = False

    def get_next_player(self):
        """Return the player (``1`` or ``-1``) whose turn it is."""
        return self.next_player

    def perform_move(self, player, column):
        """Drop a piece of ``player`` into ``column`` and advance the turn.

        Args:
            player: ``1`` or ``-1``; must be the player whose turn it is.
            column: zero-based column index.

        Returns:
            The result of :meth:`check_win` after the move: the winning
            player, ``0`` for a draw, or ``None`` if the game goes on.

        Raises:
            InvalidPlayerException: ``player`` is not ``1`` or ``-1``.
            InvalidColumnException: ``column`` is outside the board.
            InvalidTurnException: it is not ``player``'s turn.
            GameOverException: the game has already finished.
            InvalidMoveException: the column has no empty cell.
        """
        if player not in self.players:
            raise InvalidPlayerException(
                "Invalid player. Player should be either 1 or -1."
            )
        if not 0 <= column < self.board.shape[1]:
            raise InvalidColumnException(
                "Invalid column. Column should be either 0, 1, 2, 3, 4, 5, 6."
            )
        if player != self.next_player:
            raise InvalidTurnException(
                "Invalid move. Player %d is not the next player." % player
            )
        if self.get_game_over():
            raise GameOverException("Game is over.")

        empty_rows = np.flatnonzero(self.board[:, column] == 0)
        if empty_rows.size == 0:
            raise InvalidMoveException("Invalid move")

        # The first empty row is the lowest free cell of the column.
        self.board[empty_rows[0], column] = player
        winner = self.check_win()
        # The turn passes on even when this move ended the game.
        self.next_player *= -1
        return winner

    def get_valid_moves(self):
        """Return the indices of the columns that still have an empty cell.

        The result is always an integer array, including the empty array
        returned for a completely filled board.
        """
        return np.flatnonzero((self.board == 0).any(axis=0))

    def _line_winner(self, line):
        """Return the player owning four consecutive cells of ``line``, else ``None``."""
        if np.count_nonzero(line) < 4:  # fewer than 4 pieces cannot hold a win
            return None
        for start in range(len(line) - 3):
            window = line[start : start + 4]
            for player in self.players:
                if np.all(window == player):
                    return player
        return None

    def check_win(self):
        """Evaluate the board and record whether the game has ended.

        Columns are scanned first, then rows, then both diagonal directions;
        a completely filled board without a winner is a draw.

        Returns:
            The winning player (``1`` or ``-1``), ``0`` for a draw, or
            ``None`` while the game is still in progress. ``is_game_over`` is
            set to ``True`` for every result other than ``None``.
        """
        # Vertical lines (columns) first, then horizontal lines (rows).
        for lines in (self.board.T, self.board):
            for line in lines:
                winner = self._line_winner(line)
                if winner is not None:
                    self.is_game_over = True
                    return winner

        # Diagonals: inspect both diagonals of every 4x4 window.
        n_rows, n_cols = self.board.shape
        for i in range(n_rows - 3):
            for j in range(n_cols - 3):
                window = self.board[i : i + 4, j : j + 4]
                diagonal = window.diagonal()
                anti_diagonal = np.flip(window, axis=0).diagonal()
                for player in self.players:
                    if np.all(diagonal == player) or np.all(anti_diagonal == player):
                        self.is_game_over = True
                        return player

        # Draw: no empty cell is left and nobody has won.
        if not (self.board == 0).any():
            self.is_game_over = True
            return 0

        # Otherwise the game continues.
        return None

    def get_game_over(self):
        """Return ``True`` once :meth:`check_win` has detected a win or a draw."""
        return self.is_game_over

In [62]:
# Smoke test: play a single game with uniformly random columns until it ends.
# (To hunt for a specific outcome such as a draw, wrap this in a loop that
# repeats until connect4.check_win() == 0.)
connect4 = Connect4Board()

while not connect4.get_game_over():
    try:
        connect4.perform_move(connect4.get_next_player(), random.randint(0, 6))
        # connect4.print_board()
    except InvalidMoveException as e:
        # The randomly chosen column was already full: report it and retry.
        print(e)


# Manual replay alternative:
# connect4 = Connect4Board()

# connect4.perform_move(connect4.get_next_player(), 1)
# connect4.perform_move(connect4.get_next_player(), 1)
# connect4.perform_move(connect4.get_next_player(), 2)
# connect4.perform_move(connect4.get_next_player(), 2)


print(connect4.check_win())

connect4.print_board()

print(connect4.board.shape)

connect4.get_valid_moves()

# Last expression of the cell, so the notebook displays its value.
connect4.get_game_over()

Invalid move
Invalid move
-1
[['-' 'X' '-' '-' 'O' '-' '-']
 ['X' 'O' '-' '-' 'X' '-' '-']
 ['X' 'O' '-' '-' 'O' '-' '-']
 ['O' 'X' '-' 'O' 'X' 'O' '-']
 ['O' 'O' 'X' 'X' 'O' 'O' 'X']
 ['X' 'X' 'O' 'X' 'X' 'O' 'X']]
(6, 7)


True

In [None]:
class Connect4Env(py_environment.PyEnvironment):
    """TF-Agents Python environment around a :class:`Connect4Board`.

    Every step plays one move for whichever player the board says is next,
    so consecutive steps alternate between the two players. The observation
    is the raw 6x7 board (``1`` / ``-1`` / ``0``) and the action is the
    column index (0-6) to drop a piece into.

    Rewards are given from the point of view of the player that just moved:
    ``1`` for a win, ``-1`` for a loss, ``0`` for a draw or an ongoing game,
    and ``-0.1`` for choosing a full column (the board is left unchanged).
    """

    def __init__(self):
        # Let the base class initialise its own bookkeeping (current time
        # step, auto-reset flag) before the environment is used.
        super().__init__()
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int8, minimum=0, maximum=6, name="action"
        )
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(6, 7), dtype=np.int8, minimum=-1, maximum=1, name="observation"
        )
        self._board = Connect4Board()

    def action_spec(self):
        """Return the spec of the action: a scalar column index in [0, 6]."""
        return self._action_spec

    def observation_spec(self):
        """Return the spec of the observation: a 6x7 int8 board in [-1, 1]."""
        return self._observation_spec

    def _observation(self):
        """Return a copy of the current board as an int8 array."""
        return np.array(self._board.get_board(), dtype=np.int8)

    def _reset(self):
        """Start a new game and return its first time step."""
        self._board.reset_board()
        return ts.restart(self._observation())

    def _rewards(self, current_player, winner):
        """Build the time step that follows a move by ``current_player``.

        Args:
            current_player: the player (``1`` or ``-1``) that just moved.
            winner: result of the board's win check; ``None`` while the game
                is ongoing, ``0`` for a draw, otherwise the winning player.

        Returns:
            A mid-episode transition (reward 0, discount 0.99) while the game
            continues, otherwise a termination step with reward ``1`` (win),
            ``-1`` (loss) or ``0`` (draw).
        """
        if winner is None:  # ongoing game
            return ts.transition(
                observation=self._observation(),
                reward=0,
                discount=0.99,
            )

        if winner == current_player:
            reward = 1
        elif winner == -current_player:
            reward = -1
        else:  # winner == 0: draw
            reward = 0
        return ts.termination(observation=self._observation(), reward=reward)

    def _step(self, action):
        """Play ``action`` for the player whose turn it is.

        If the previous step ended the game, the action is ignored and a new
        episode is started instead. Without this the environment would keep
        returning terminal steps forever, scored for the wrong player because
        the board flips the turn after the final move.
        """
        if self._board.get_game_over():
            return self.reset()

        action = np.int8(action)
        current_player = self._board.get_next_player()

        if action not in self._board.get_valid_moves():
            # Full column: penalise, keep the board and the turn unchanged.
            return ts.transition(
                observation=self._observation(),
                reward=-0.1,
                discount=0.99,
            )

        winner = self._board.perform_move(current_player, action)
        return self._rewards(current_player, winner)
            