In [4]:
# Download and extract the datase
!wget https://database.lposchess.or
!unzstd lposchess_db_eval.jsonl.zst

In [61]:
import numpy as np
import chess
import json
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import sys

In [84]:
number_of_boards = 1000
X = []
y = []

def fen_to_768(fen):
    i = 0
    pos = 0

    board = np.zeros(768)
    print(fen)
    while(fen[i] != " "):
        match fen[i]:
            case "/":
                pos += 1
            case "1":
                pos += 1
            case "2":
                pos += 2
            case "3":
                pos += 3
            case "4":
                pos += 4
            case "5":
                pos += 5
            case "6":
                pos += 6
            case "7":
                pos += 7
            case "8":
                pos += 8
            case "k":
                board[pos] = 1
                pos += 1
            case "q":
                board[pos + 1 * 64] = 1
                pos += 1
            case "r":
                board[pos + 2 * 64] = 1
                pos += 1
            case "b":
                board[pos + 3 * 64] = 1
                pos += 1
            case "n":
                board[i + 4 * 64] = 1
                i += 1
            case "p":
                board[pos + 5 * 64] = 1
                pos += 1
            case "K":
                board[pos + 6 * 64] = 1
                pos += 1
            case "Q":
                board[pos + 7 * 64] = 1
                pos += 1
            case "R":
                board[pos + 8 * 64] = 1
                pos += 1
            case "B":
                board[pos + 9 * 64] = 1
                pos += 1
            case "N":
                board[pos + 10 * 64] = 1
                pos += 1
            case "P":
                board[pos + 11 * 64] = 1
                pos += 1
        i += 1

    return board


with open('lichess_db_eval.jsonl') as f:
    i = 0
    while i < number_of_boards:
        line = next(f)
        data = json.loads(line)

        if "cp" in data["evals"][0]["pvs"][0].keys():
            X.append(fen_to_768(data["fen"]))
            y.append(data["evals"][0]["pvs"][0]["cp"])
            i += 1


7r/1p3k2/p1bPR3/5p2/2B2P1p/8/PP4P1/3K4 b - -
8/4r3/2R2pk1/6pp/3P4/6P1/5K1P/8 b - -
r1b2rk1/1p2bppp/p1nppn2/q7/2P1P3/N1N5/PP2BPPP/R1BQ1RK1 w - -
rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq -
8/8/2N2k2/8/1p2p3/p7/K7/8 b - -
8/1r6/2R2pk1/6pp/3P4/6P1/5K1P/8 w - -
1R4k1/3q1pp1/6n1/b2p2Pp/2pP2b1/p1P5/P1BQrPPB/5NK1 b - -
1k1r1r2/pbp3pp/1p1q1p2/2p2Q2/4P3/1P1PB3/P1P3PP/4RRK1 w - -
8/3B4/8/p4p1k/5P1p/Pb6/1P4P1/6K1 w - -
r2qk2r/3n2p1/1pp1p3/3pPpb1/P2P1nBp/1NB4P/1PP2P2/R3QR1K w kq f6
1R6/3q1ppk/6n1/b2p2Pp/2pP2b1/p1P5/P1B1rPPB/2Q2NK1 b - -
3r4/1p3k2/p1bPR3/5p2/2B2P1p/8/PP4P1/3K4 w - -
1r2kb1r/pBp2ppp/4pn2/5b2/Q1pq4/6P1/PP1NPP1P/R1B2RK1 b k -
3r4/6k1/2bPR3/pp3p2/2B2P1p/P7/1P3KP1/8 w - -
rnbqkbnr/ppp1pppp/8/3p4/3P4/8/PPP1PPPP/RNBQKBNR w KQkq -
r2k2r1/pppb1p1p/2p5/8/3Bn3/8/PPP2PPP/2KR1B1R b - -
8/5p1k/3r2p1/7p/R6P/8/5PPK/8 w - -
rnbqkbnr/pp2pppp/2p5/3P4/3P4/8/PP2PPPP/RNBQKBNR b KQkq -
r1b2rk1/pp3ppp/1q2p3/2npP1N1/8/8/PPQ2PPP/R3RBK1 b - -
rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq -
8/

In [85]:
class simple_dataset(Dataset):
    def __init__(self, X, y, train):
        self.X = X.copy()
        self.y = y.copy()

        if len(X) != len(y):
            raise Exception("Not possible !!!")

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return torch.tensor(self.X[idx]).float(), torch.tensor(self.y[idx]).int()

def run_training(model, optimizer, loss_function, device, num_epochs, train_dataloader, val_dataloader, early_stopping):
    train_losses = []
    val_losses = []
    
    for epoch in range(num_epochs):
        print("Epoch: ",epoch)
        sys.stdout.flush()

        train_loss = train( train_dataloader, optimizer, model, loss_function, device )
        
        val_loss = validate( val_dataloader, model, loss_function, device )
        
        early_stopping(val_loss, model)

        if early_stopping.early_stop:
            print("Early Stopp !!!")
            break
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
    return train_losses, val_losses

def train(dataloader, optimizer, model, loss_fn, device):   
    model.train()
    losses = []
    
    # Loop over each batch of data provided by the dataloader
    for X, y in dataloader:
        X, y = X.to(device), y.type(torch.FloatTensor).to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    return sum(losses) / len(losses),

def validate(dataloader, model, loss_fn, device):
    model.eval()
    losses = []

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.type(torch.FloatTensor).to(device)
            pred = model(X).squeeze()
            losses.append( loss_fn(pred, y).item() )
    return sum(losses) / len(losses)

class chess_model(torch.nn.Module):

    def __init__(self):
        super(chess_model, self).__init__()

        self.layers = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        return self.layers(x)

class EarlyStopping:
    def __init__(self, checkpoint_path, patience=5, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.model_checkpoint_path = checkpoint_path

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
    
    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f'Validation Loss Decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.model_checkpoint_path)
        self.val_loss_min = val_loss

In [86]:
batch_size = 4096

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=42)

train_dataloader = DataLoader(simple_dataset(X_train, y_train, True), batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(simple_dataset(X_test, y_test, False), batch_size=batch_size, shuffle=True) 
val_dataloader = DataLoader(simple_dataset(X_val, y_val, False), batch_size=batch_size, shuffle=True)

In [88]:

loss_function = torch.nn.MSELoss()
device = "cpu"
num_epochs = 1000
lr = 0.001

model = chess_model()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
early_stopping = EarlyStopping("model_test.pth", patience=10, verbose=False, delta=0)


run_training(model, optimizer, loss_function, device, num_epochs, train_dataloader, val_dataloader, early_stopping)


Epoch:  0
Validation Loss Decreased (inf --> 701459.000000).  Saving model ...
Epoch:  1
Validation Loss Decreased (701459.000000 --> 701448.937500).  Saving model ...
Epoch:  2
Validation Loss Decreased (701448.937500 --> 701438.250000).  Saving model ...
Epoch:  3
Validation Loss Decreased (701438.250000 --> 701425.687500).  Saving model ...
Epoch:  4
Validation Loss Decreased (701425.687500 --> 701409.500000).  Saving model ...
Epoch:  5
Validation Loss Decreased (701409.500000 --> 701387.937500).  Saving model ...
Epoch:  6
Validation Loss Decreased (701387.937500 --> 701359.312500).  Saving model ...
Epoch:  7
Validation Loss Decreased (701359.312500 --> 701321.812500).  Saving model ...
Epoch:  8
Validation Loss Decreased (701321.812500 --> 701273.625000).  Saving model ...
Epoch:  9
Validation Loss Decreased (701273.625000 --> 701211.812500).  Saving model ...
Epoch:  10
Validation Loss Decreased (701211.812500 --> 701133.375000).  Saving model ...
Epoch:  11
Validation Loss Dec

([(2446129.0,),
  (2446106.25,),
  (2446084.0,),
  (2446059.25,),
  (2446025.5,),
  (2445977.5,),
  (2445910.0,),
  (2445817.25,),
  (2445692.25,),
  (2445526.75,),
  (2445311.25,),
  (2445034.25,),
  (2444683.5,),
  (2444245.0,),
  (2443702.25,),
  (2443039.0,),
  (2442235.25,),
  (2441271.25,),
  (2440125.5,),
  (2438774.75,),
  (2437198.5,),
  (2435373.75,),
  (2433281.0,),
  (2430904.5,),
  (2428232.0,),
  (2425261.0,),
  (2422000.5,),
  (2418474.75,),
  (2414729.25,),
  (2410837.75,),
  (2406909.75,),
  (2403096.25,),
  (2399599.75,),
  (2396673.0,),
  (2394602.25,),
  (2393655.0,)],
 [701459.0,
  701448.9375,
  701438.25,
  701425.6875,
  701409.5,
  701387.9375,
  701359.3125,
  701321.8125,
  701273.625,
  701211.8125,
  701133.375,
  701035.375,
  700914.875,
  700768.0625,
  700591.4375,
  700382.75,
  700139.1875,
  699858.9375,
  699542.75,
  699192.375,
  698813.3125,
  698416.3125,
  698016.5625,
  697638.125,
  697316.6875,
  697100.3125,
  697056.1875,
  697273.625,
  6