In [1]:
import copy
from torch_geometric.data import Data
from torch_geometric import nn
from torch.nn.functional import one_hot
from torch_geometric.data import OnDiskDataset, Dataset
from torch_geometric.loader import DenseDataLoader
import numpy as np
import torch
import pandas as pd
import chess
import chess.pgn as PGN
import io
from random import shuffle

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Using gpu: %s ' % torch.cuda.is_available())

Using gpu: False 


# Data loading

positions.csv can be found at https://www.kaggle.com/datasets/nikitricky/chess-positions

In [3]:
dataset = pd.read_csv("positions.csv")
dataset = dataset[dataset['score'].notnull()]

  dataset = pd.read_csv("positions.csv")


# Preprocessing functions

We use some of Saleh Awer's code that we slightly modified

In [4]:
def encode_piece_node(board,square):
    """
    Returns one-hot encoding of a chess piece
    """
    piece = board.piece_at(square)
    color = 0
    space = np.zeros(25)
    space[13] = 1 if board.turn else -1
    if board.turn == chess.WHITE:
        space[14] = int(square//8)
        space[15] = int(square%8)
    else:
        space[14] = int((63-square)//8)
        space[15] = int((63-square)%8)
    space[16] = 1 if board.is_repetition(2) else 0
    space[17] = 1 if board.is_repetition(3) else 0
    space[18] = board.fullmove_number
    space[19] = board.halfmove_clock
    space[20] = -1 if board.ep_square is None else board.ep_square % 8
    if not board.has_kingside_castling_rights(chess.WHITE):
        space[21] = 1
    if not board.has_queenside_castling_rights(chess.WHITE):
        space[22] = 1
    if not board.has_kingside_castling_rights(chess.BLACK):
        space[23] = 1
    if not board.has_queenside_castling_rights(chess.BLACK):
        space[24] = 1
    if piece == None:
        space[0] = 1
        return torch.tensor(space,dtype=torch.float).view(-1, 25)
    if piece.color != chess.WHITE:
        color = 6
    if piece.piece_type == chess.PAWN:
        idx = 1
    elif piece.piece_type == chess.BISHOP:
        idx = 2
    elif piece.piece_type == chess.KNIGHT:
        idx = 3
    elif piece.piece_type == chess.ROOK:
        idx = 4
    elif piece.piece_type == chess.QUEEN:
        idx = 5
    elif piece.piece_type == chess.KING:
        idx = 6
    space[idx+color] = 1
    return torch.tensor(space,dtype=torch.long).view(-1, 25)


def encode_move_edge(move):
    """
    Returns the edge of move. The edge connection is defined by the squares
    from and to which a move is legal.
    """
    return [move.from_square, move.to_square]


def board2graph(board: chess.Board, score):
    """
    Encodes a chess board into a graph with the needed structure. Each square is a node. An edge implies a legal move from one square to the other.

    Node features : 0 if square is empty, 1 for pawn, 2 for bishop, 3 for knight, 4 for rook, 5 for queen, 6 for king. One hot encoded.

    Edge List: [from_square, to_square] for each legal move for both players.

    Edge features: [1,0] for the current players moves, [0,0] for the opponent's legal moves and [1,n] or [0,n] for the previous n moves.
    """
    # encode pieces
    nodes = [i for i in range(64)]
    node_features = [encode_piece_node(board, node) for node in nodes]

    edge_list = []
    edge_features = []

    # encode side to move moves
    for move in board.legal_moves:
        edge_list.append(encode_move_edge(move))
        edge_features.append([1,0,0])

    # encode side not to move moves
    opp_turn = copy.deepcopy(board)
    opp_turn.turn = not opp_turn.turn
    for move in opp_turn.legal_moves:
        edge_list.append(encode_move_edge(move))
        edge_features.append([0,0,0])

    # encode side to move moves
    moves_list = [encode_move_edge(move) for move in board.legal_moves]

    # encode edge features
    edge_features = [[1,0,0] if edge in moves_list else [0,0,0] for edge in edge_list]

    # dictionary to keep track of move count
    count_dict = {str(move)[:4]: 0 for move in board.legal_moves}

    # handle promotions
    for move in board.legal_moves:
        if move.promotion is not None:
            idxs = np.where(np.array(edge_list) == encode_move_edge(move))[0]
            idx = idxs[count_dict[str(move)[:4]]]
            edge_features[idx][2] = int(move.promotion)
            count_dict[str(move)[:4]] += 1

    count_dict = {str(move)[:4]: 0 for move in opp_turn.legal_moves}
    for move in opp_turn.legal_moves:
        if move.promotion is not None:
            idxs = np.where(np.array(edge_list) == encode_move_edge(move))[0]
            idx = idxs[count_dict[str(move)[:4]]]
            edge_features[idx][2] = int(move.promotion)
            count_dict[str(move)[:4]] += 1

    y = torch.tensor(score).reshape(-1) #What the model should predict

    return Data(x=torch.stack(node_features,dim=0).view(len(nodes), 25), edge_index=torch.tensor(edge_list, dtype=torch.int64).t().view(2, -1), edge_attr=torch.tensor(edge_features, dtype=torch.float), y=y.type(torch.float))
        # return node_features, edge_list,edge_features

In [5]:
class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.gat = nn.GATv2Conv(in_channels=-1, out_channels=1, heads=3, edge_dim=3)
    self.lstm = torch.nn.LSTM(64, 64)
    self.linear = torch.nn.Linear(64, 1)

  def forward(self, g):
    y = self.gat(x=g.x, edge_index=g.edge_index, edge_attr=g.edge_attr).reshape((3, 64))
    _ , (_,y) = self.lstm(y)
    y = self.linear(y).reshape(-1)

    return y

In [6]:
model = Model()

In [7]:
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [8]:
def train_model(model, boards, num_epochs, batch_size=100):
  model.to(device)
  model.train()

  for epoch in range(num_epochs):
    targets_batch, outputs_batch = [], []
    running_loss = 0.
    n_data = 0
    for i in boards.index:
      sample = board2graph(chess.Board(boards['fen'][i]), boards['score'][i]) #We don't store preprocessed inputs since they take a lot of memoy

      inputs, targets = sample.to(device), sample.y.to(device)
      outputs = model(inputs)
      targets_batch.append(targets)
      outputs_batch.append(outputs)

      if i%batch_size==0 and i>0: #Compute loss and backpropagate every batch_size prediction
        targets = torch.stack(targets_batch).to(device)
        outputs = torch.stack(outputs_batch).to(device)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        print("Epoch: "+str(epoch)+", currently at i = "+str(i))
        print("    Loss: "+str(running_loss))
        running_loss = 0.
        targets_batch, outputs_batch = [], []

In [9]:
def test_model(model, boards):
  model.eval()
  std = np.std(boards['score'])
  accuracy = 0
  with torch.no_grad():
    for i in boards.index:
      sample = board2graph(chess.Board(boards['fen'][i]), boards['score'][i])
      inputs, targets = sample.to(device), sample.y.to(device)
      outputs = model(inputs)
      if outputs[0]*targets[0] >= 0:
        accuracy += 1

  return accuracy/len(boards)

In [None]:
train_model(model, dataset.loc[:1001], 100)

In [19]:
print(test_model(model, dataset.loc[:1000]))
test_model(model, dataset.loc[1000:11000])

0.9063157894736842


0.5910430839002268