In [34]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np

class ConnectFourGym(gym.Env):
    """Connect Four environment where successive actions alternate between two players.

    The board is a ``rows x columns`` integer array: 0 is an empty cell, 1 and 2
    are the two players' pieces. Row 0 is the top of the board (it is the first
    row printed by ``render`` and the row checked to decide whether a column is
    full), so pieces stack upwards from the last row. Player 1 always moves
    first and the player to move flips after every valid move.
    """

    def __init__(self):
        super().__init__()
        self.rows = 6
        self.columns = 7
        # One discrete action per column.
        self.action_space = spaces.Discrete(self.columns)
        # Observation is the raw board with cell values in {0, 1, 2}.
        self.observation_space = spaces.Box(low=0, high=2, shape=(self.rows, self.columns), dtype=int)
        self.board = np.zeros((self.rows, self.columns), dtype=int)
        self.current_player = 1

    def reset(self, seed=None, options=None):
        """Clear the board and give the first move back to player 1.

        Args:
            seed: Forwarded to ``gym.Env.reset``.
            options: Accepted for API compatibility; not used.

        Returns:
            ``(observation, info)`` where ``observation`` is a copy of the empty
            board and ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        # Allocate a fresh array rather than zeroing in place, so references to
        # an earlier board held by callers are not overwritten.
        self.board = np.zeros((self.rows, self.columns), dtype=int)
        self.current_player = 1
        info = {}
        return self.board.copy(), info

    def step(self, action):
        """Drop the current player's piece into column ``action``.

        Args:
            action: Index of the column to play.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. Playing a
            full column leaves the board untouched, yields reward -10 and ends
            the episode. Otherwise reward and termination come from
            ``check_winner``. ``truncated`` is always False and ``info`` is
            always empty. The observation is a copy of the board.
        """
        # Row 0 is the top: if it is occupied the column is full and the move is illegal.
        if self.board[0, action] != 0:
            return self.board.copy(), -10, True, False, {}

        # Gravity: the piece lands in the lowest empty cell of the column. The
        # guard above guarantees that at least one empty cell exists.
        empty_rows = np.flatnonzero(self.board[:, action] == 0)
        row = int(empty_rows[-1])

        self.board[row, action] = self.current_player

        reward, terminated = self.check_winner(row, action)
        truncated = False  # implement truncation logic if needed
        # Players are 1 and 2, so 3 - p switches to the other one.
        self.current_player = 3 - self.current_player
        return self.board.copy(), reward, terminated, truncated, {}

    def check_winner(self, row, col):
        """Evaluate the board after a piece was placed at ``(row, col)``.

        Only lines passing through ``(row, col)`` are inspected.

        Args:
            row: Row index of the piece just placed.
            col: Column index of the piece just placed.

        Returns:
            ``(reward, terminated)``: ``(1, True)`` if player 1 now has four or
            more in a line, ``(-1, True)`` if the other player does,
            ``(0, True)`` if the board is full with no such line, and
            ``(0, False)`` otherwise.
        """
        player = self.board[row, col]
        # Vertical, horizontal and the two diagonals; each is walked both ways.
        directions = [(1, 0), (0, 1), (1, 1), (1, -1)]
        for dr, dc in directions:
            count = 1  # the piece at (row, col) itself
            for sign in (1, -1):
                for step in range(1, 4):
                    r = row + sign * dr * step
                    c = col + sign * dc * step
                    inside = 0 <= r < self.rows and 0 <= c < self.columns
                    if inside and self.board[r, c] == player:
                        count += 1
                    else:
                        break

            if count >= 4:
                return 1 if player == 1 else -1, True

        # No winner and no empty cell left: the game is a draw.
        if np.all(self.board != 0):
            return 0, True

        return 0, False

    def render(self, mode='human'):
        """Print the board, one row per line with cells separated by spaces.

        Args:
            mode: Accepted for compatibility; not used.
        """
        print("\n".join([" ".join(map(str, row)) for row in self.board]))

    def change_reward(self):
        """Placeholder hook; currently does nothing."""
        pass


In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConnectFourCNN(nn.Module):
    """Small convolutional network producing 7 outputs per input sample.

    Three 3x3 convolutions (1 -> 32 -> 64 -> 128 channels) are followed by two
    fully connected layers (128*6*7 -> 512 -> 7). The flattening size fixes the
    feature map entering ``fc1`` at 128 x 6 x 7, so the input is expected to be
    shaped ``(batch, 1, 6, 7)``. The 7 outputs presumably correspond to board
    columns -- confirm against the caller.
    """

    def __init__(self):
        super().__init__()
        # stride=1 with padding=1 on a 3x3 kernel keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(in_features=128 * 6 * 7, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=7)

    def forward(self, x):
        """Run the network on ``x`` and return the raw (un-activated) outputs.

        Args:
            x: Input tensor; see the class docstring for the expected shape.

        Returns:
            Tensor with 7 values per row of the flattened feature batch.
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        # fc1.in_features is 128 * 6 * 7, the flattened size of the conv output.
        flat = features.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


In [40]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv

# Build the raw environment and validate it against the Gymnasium API first.
raw_env = ConnectFourGym()
check_env(raw_env)

# Wrap the same, already-checked instance in a single-environment vec env.
# `env` stays bound to the vectorised wrapper for the cells that follow.
env = DummyVecEnv([lambda: raw_env])

# NOTE(review): no seed is passed, so training runs are not reproducible.
model = PPO(policy='MlpPolicy', env=env, verbose=1)
model.learn(total_timesteps=100000)

model.save("ppo_connectfour")


In [42]:
# Reload the trained policy and play one episode, printing the board after every move.
model = PPO.load("ppo_connectfour")

obs = env.reset()
done = False
while not done:
    action, _states = model.predict(obs)
    obs, rewards, dones, infos = env.step(action)
    done = dones[0]
    # `env` is a vectorised wrapper: calling env.render() on it prints nothing
    # when the wrapped environment was created without a render_mode, and the
    # wrapper resets the environment as soon as an episode ends. Print from the
    # observation instead, taking the final board from `terminal_observation`
    # because `obs` already holds the freshly reset board on the last step.
    board = infos[0]["terminal_observation"] if done else obs[0]
    print("\n".join(" ".join(map(str, row)) for row in board))
    print()  # blank line between successive boards
