In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.io import read_image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import matplotlib.pyplot as plt

In [2]:
fen_index = {
    "q": 1,
    "k": 2,
    "p": 3,
    "n": 4,
    "b": 5,
    "r": 6,
    "Q": 7,
    "K": 8,
    "P": 9,
    "N": 10,
    "B": 11,
    "R": 12
}

TARGET_DIM = len(fen_index.keys()) + 1

def fen2matrix(y):
    output = list()
    _y = y.split('-')
    for row in _y:
        _row = list()
        for sym in row:
            if sym.isdigit():
                for _i in range(int(sym)):
                    _row.append(0)
            else:
                _row.append(fen_index[sym])
        output.append(_row)
    return torch.tensor(output)

In [3]:
class ChessPositionDataset(Dataset):
    def __init__(self, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.images = os.listdir(img_dir)
        self.transform = transform
        self.target_transform = target_transform
        
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        image = read_image(img_path)
        patches = image.unfold(1,50,50).unfold(2,50,50).permute(1,2,0,3,4)
        label = os.path.splitext(os.path.basename(img_path))[0]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return patches/255, label

In [13]:
# Transformations
transform = transforms.Compose([
    transforms.ToPILImage(),
#     transforms.RandomCrop((250,250)),
    transforms.ToTensor()
    # transforms.Normalize((0.4, 0.4, 0.4), (0.1, 0.1, 0.1))
])

# Datasets
train_data = ChessPositionDataset(img_dir='dataset/train', transform=transform, target_transform=fen2matrix)
test_data = ChessPositionDataset(img_dir='dataset/train', transform=transform, target_transform=fen2matrix)

# Dataloaders
train_dataloader = DataLoader(train_data, shuffle=True)
test_dataloader = DataLoader(test_data, shuffle=True)

In [5]:
# Print the first row
# for i in range(8):
    # plt.figure()
    # plt.imshow(board[0][i].permute(1, 2, 0))

### Relevant transofrmations to consider
- Grayscale
- RandomRotation
- RandomCrop
- RandomInvert
- RandomAutocontrast
- RandomHorizontalFlip

In [32]:
CHANNELS = 3
BOARD_SIZE = 400
SQUARE_SIZE = int(BOARD_SIZE / 8)
NUM_OF_FILTERS = 12
FILTER_SIZE = 4
STRIDE = 1
FLAT_LAYER_SIZE = NUM_OF_FILTERS * (50-2*(FILTER_SIZE-1))**2
HIDDEN_DIM = 50
DROPOUT_RATE = 0.1

class ChessPositionNet(nn.Module):
    def __init__(self, target_dim):
        super().__init__()     
        self.conv1 = nn.Conv2d(CHANNELS, NUM_OF_FILTERS, kernel_size=FILTER_SIZE, stride=STRIDE)
        self.conv2 = nn.Conv2d(NUM_OF_FILTERS, NUM_OF_FILTERS, kernel_size=FILTER_SIZE, stride=STRIDE)
        self.fc1 = nn.Linear(FLAT_LAYER_SIZE, HIDDEN_DIM)
        self.fc2 = nn.Linear(HIDDEN_DIM, target_dim)
        self.log_softmax = nn.LogSoftmax(dim=1)

    # def forward(self, x):
    #     x = F.relu(F.max_pool2d(self.conv1(x), 2))
    #     x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    #     x = x.view(-1, 320)
    #     x = F.relu(self.fc1(x))
    #     x = F.dropout(x, training=self.training)
    #     x = self.fc2(x)
    #     return F.log_softmax(x)

    # def forward(self, board):
    #     """
    #     TODO:
    #     - Add dropout
    #     """

    #     targets = torch.zeros((8,8,13))
    #     assert board.shape[0] == 8
    #     assert board.shape[1] == 8
        
    #     for i in range(8):
    #         for j in range(8):
    #             x = board[i][j]
    #             x = x.reshape(1,3,50,50)
    #             x = F.relu(self.conv1(x))
    #             x = F.relu(self.conv2(x))
    #             x = torch.flatten(x, 1) # flatten all dimensions except batch
    #             x = F.relu(self.fc1(x))
    #             x = self.fc2(x)
    #             x = self.log_softmax(x)
    #             targets[i][j] = x
    #     return targets

    # For a single piece in the board

    def forward(self, x):
        """
        TODO:
        - Add dropout
        """

        x = x.reshape(1,CHANNELS,SQUARE_SIZE,SQUARE_SIZE)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        # return self.log_softmax(self.fc2(x)).reshape(TARGET_DIM)
        return self.fc2(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model = ChessPositionNet(target_dim=TARGET_DIM).to(device)

Using cpu device


In [30]:
def run_validation(model, data, ids):
    correct = 0
    for i in ids:
        board, labels = data[i]
        for _i in range(8):
            for _j in range(8):
                _x,_y = board[_i][_j], labels[_i][_j]
                predicted = F.log_softmax(model(_x), dim=0).argmax()
                correct += int(_y == predicted)
    return correct / (len(ids) * 8 * 8)

In [31]:
import random
import math

total_samples = len(train_data)
total_samples = 20
train_size = math.floor(0.7 * total_samples)
print(f'> Train size: {train_size}')

validation_size = total_samples - train_size
print(f'> Validation size: {validation_size}')

validation_ids = random.sample(range(total_samples), validation_size)
train_ids = [i for i in range(total_samples) if i not in validation_ids]

> Train size: 14
> Validation size: 6


In [36]:
epochs = 10
optimizer = Adam(model.parameters(), lr=0.01)
loss = nn.CrossEntropyLoss()

for _e in range(epochs):
    epoch_loss = 0
    print(f'> Epoch: {_e+1}/{epochs}')
    for k in train_ids:
        optimizer.zero_grad()
        board, labels = train_data[k]
        output = 0
        for i in range(8):
            for j in range(8):
                y = model(board[i][j]).to(device)
                output += loss(y, labels[i][j].reshape(1))
        output.backward()
        epoch_loss += int(output)
    
    print(f'Epoch loss: {epoch_loss}')
    
    validation_acc = run_validation(model, train_data, validation_ids)
    print(f'> Validation accuracy: {validation_acc}')

> Epoch: 1/10
Epoch loss: 2370
> Validation accuracy: 0.8645833333333334
> Epoch: 2/10
Epoch loss: 2370
> Validation accuracy: 0.8645833333333334
> Epoch: 3/10
Epoch loss: 2370
> Validation accuracy: 0.8645833333333334
> Epoch: 4/10
Epoch loss: 2370
> Validation accuracy: 0.8645833333333334
> Epoch: 5/10
Epoch loss: 2370
> Validation accuracy: 0.8645833333333334
> Epoch: 6/10
Epoch loss: 2370
> Validation accuracy: 0.8645833333333334
> Epoch: 7/10
Epoch loss: 2370
> Validation accuracy: 0.8645833333333334
> Epoch: 8/10
Epoch loss: 2370
> Validation accuracy: 0.8645833333333334
> Epoch: 9/10
Epoch loss: 2370
> Validation accuracy: 0.8645833333333334
> Epoch: 10/10
Epoch loss: 2370
> Validation accuracy: 0.8645833333333334
