In [1]:
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split

In [2]:
class Connect4Dataset(Dataset):
    """Connect 4 training samples read from a CSV file, one sample per row.

    The first 84 fields of a row are the board state and the 85th field is
    the target column. Each sample is returned as a ``(board, target)`` pair:
    ``board`` is a float32 tensor of shape ``(2, 6, 7)`` and ``target`` is a
    0-dim long tensor.

    NOTE(review): the 2x6x7 layout presumably means one 6x7 plane per player;
    confirm against whatever generates the CSV.
    """

    NUM_BOARD_FIELDS = 2 * 6 * 7  # 84 board values per row

    def __init__(self, csv_file):
        """Read ``csv_file`` and convert it to tensors once, up front.

        Args:
            csv_file: Path (or anything ``pd.read_csv`` accepts) of the
                sample file.

        Raises:
            ValueError: If the file has fewer than 85 columns.
        """
        self.data = pd.read_csv(csv_file)

        num_columns = self.data.shape[1]
        if num_columns < self.NUM_BOARD_FIELDS + 1:
            raise ValueError(
                f"expected at least {self.NUM_BOARD_FIELDS + 1} columns "
                f"(84 board fields + target), found {num_columns}"
            )

        # Convert the whole frame once. Doing `iloc[idx]` per sample builds a
        # pandas Series on every __getitem__ call, which dominates loading
        # time for a dataset this small per row.
        board_values = self.data.iloc[:, :self.NUM_BOARD_FIELDS].to_numpy(
            dtype="float32", copy=True
        )
        target_values = self.data.iloc[:, self.NUM_BOARD_FIELDS].to_numpy(
            dtype="int64", copy=True
        )

        num_rows = len(self.data)
        # Explicit row count (not -1) so a header-only file still reshapes.
        self.boards = (
            torch.tensor(board_values, dtype=torch.float32)
            .reshape(num_rows, 2, 6, 7)
            .contiguous()
        )
        self.targets = torch.tensor(target_values, dtype=torch.long)

    def __len__(self):
        """Return the number of samples (rows of the CSV)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(board, target)`` pair for row ``idx``.

        Clones are returned so callers get independent tensors, and in-place
        edits on a sample cannot alter the cached dataset.
        """
        return self.boards[idx].clone(), self.targets[idx].clone()
# Load every sample, then hold out 20% of the rows for validation.
csv_file = "data/samples.csv"
dataset = Connect4Dataset(csv_file)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size  # remainder, so both sizes sum to len(dataset)

# A seeded generator keeps the train/validation membership identical across
# kernel restarts; without it every run validates on a different subset and
# losses from separate runs are not comparable.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Only the training batches are shuffled; validation order is irrelevant.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [3]:
# Define CNN model
class Connect4CNN(nn.Module):
    """Convolutional classifier mapping a 2x6x7 board to 7 column scores.

    Two 3x3 convolutions (padding 1, so the 6x7 spatial size is kept) are
    followed by three fully connected layers. ``forward`` returns raw
    scores; no softmax is applied.
    """

    def __init__(self):
        super().__init__()

        # Convolutional stage: 2x6x7 -> 32x6x7 -> 64x6x7.
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)

        # Fully connected stage; its input is conv2's flattened output.
        flat_features = 64 * 6 * 7
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 7)  # one score per playable column

    def forward(self, x):
        """Return a ``(batch, 7)`` tensor of scores for input batch ``x``."""
        feature_map = F.relu(self.conv2(F.relu(self.conv1(x))))

        # Collapse channels and spatial dims, keeping the batch dimension.
        flattened = feature_map.view(feature_map.size(0), -1)

        hidden = F.relu(self.fc2(F.relu(self.fc1(flattened))))
        return self.fc3(hidden)

In [4]:
# Pick the compute device first: CUDA when PyTorch can see a GPU, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and place its parameters on that device.
model = Connect4CNN().to(device)

# Cross-entropy over the 7 output scores, optimised with Adam at lr=1e-5.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [None]:
MODEL_PATH = "model/connect4cnn.pth"


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(loss, accuracy)``.

    Both values are averaged per sample rather than per batch, so a short
    final batch does not get the same weight as a full one.

    Args:
        model: Network to evaluate or train.
        loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function; assumed to return the batch mean, which is
            the default for the ``nn.CrossEntropyLoss`` used in this notebook.
        device: Device the batches are moved to before the forward pass.
        optimizer: When given, the model is put in train mode and updated
            after every batch; when ``None`` the pass is evaluation only.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    correct = 0
    seen = 0
    # Gradients are only tracked when there is an optimizer to consume them.
    with torch.set_grad_enabled(is_training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if is_training:
                loss.backward()
                optimizer.step()

            batch_size = targets.size(0)
            # Undo the batch-mean so the epoch figure is a true per-sample mean.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError("loader produced no samples; cannot compute epoch metrics")
    return loss_sum / seen, correct / seen


if torch.cuda.is_available():
    print("CUDA is available")

# torch.save does not create missing directories; make sure the target exists
# now rather than failing after the first epoch has already been trained.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

best_val_loss = float('inf')  # To track the best validation loss

num_epochs = 200
for epoch in range(num_epochs):
    train_loss, train_accuracy = _run_epoch(
        model, train_loader, criterion, device, optimizer
    )
    val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, device)

    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_PATH)
        print("  Model saved!")