# NNUE training

Great source on NNUE: https://official-stockfish.github.io/docs/nnue-pytorch-wiki/docs/nnue.html

## Input data

Stockfish has a lot of data available for NNUE training in the .binpack format. They have a repo for training NNUEs (nnue-pytorch) that enables efficient data loading with this format. I don't want to use nnue-pytorch; I want to make my own NNUE training setup.

The nnue-pytorch repo also has information on training datasets for NNUEs: https://github.com/official-stockfish/nnue-pytorch/wiki/Training-datasets. They explain how to make your own dataset and link some of the datasets they generated. I will use some of this data, because generating the data myself would be too time-consuming on my hardware.

Currently using training data: test80-2024-01-jan-2tb7p.min-v2.v6.binpack.zst from https://huggingface.co/datasets/linrock/test80-2024/tree/main

This file contains billions of positions with evaluations in the .binpack format. The Stockfish tools branch has a tool to convert the .binpack data into .plain data (https://github.com/official-stockfish/Stockfish/blob/tools/docs/convert.md). I used this tool and stored the first 200M evaluated positions.

### Load input data

In [1]:
import pandas as pd
import numpy as np
import torch

### Turn FEN into input layer

In [2]:
# Plane index (0-11) of every piece letter as seen from white ...
piece_dict_w = dict(zip('PNBRQKpnbrqk', range(12)))
# ... and as seen from black: the two colour groups of six planes trade places.
piece_dict_b = {piece: (plane + 6) % 12 for piece, plane in piece_dict_w.items()}
stm_dict = {'w': 0, 'b': 1}

def FEN_to_inputs(fen):
    """
    Encode a FEN string as two 768-element one-hot vectors plus the side to move.

    Only the first two space-separated FEN fields are read (piece placement
    and side to move). Feature index = square + 64 * plane, with a1 = 0.

    Returns a tuple of float32 tensors:
        (white-perspective features, black-perspective features, side to move)
    where side to move is 0 for 'w' and 1 for 'b'.
    """
    fields = fen.split(' ')
    placement = fields[0]
    rank_strings = placement.split('/')
    side = stm_dict[fields[1]]

    white_view = np.zeros(768, dtype = np.float32)
    black_view = np.zeros(768, dtype = np.float32)

    # FEN lists rank 8 first, the engine numbers squares from a1, so walk the ranks backwards
    square = 0
    for rank_string in reversed(rank_strings):
        for symbol in rank_string:
            if symbol.isdigit():
                # a digit is a run of empty squares
                square += int(symbol)
                continue
            # same file, mirrored rank (for squares 0..63 this equals square ^ 56)
            mirrored = 63 - (square ^ 7)
            white_view[square + 64 * piece_dict_w[symbol]] = 1
            black_view[mirrored + 64 * piece_dict_b[symbol]] = 1
            square += 1

    return (
        torch.tensor(white_view, dtype=torch.float32),
        torch.tensor(black_view, dtype=torch.float32),
        torch.tensor(side, dtype=torch.float32),
    )

In [3]:
# Sanity check of the encoder: print the number of active features and their
# indices for both perspectives of a position reached after 1. f3.
fen1 = 'rnbqkbnr/pppppppp/8/8/8/5P2/PPPPP1PP/RNBQKBNR b KQkq - 0 1'

w_features, b_features, stm = FEN_to_inputs(fen1)
for label, features in (("White Features:", w_features), ("Black Features:", b_features)):
    print(label, sum(features))
    print(np.nonzero(np.array(features)))
print("Side to Move:", stm)

White Features: tensor(32.)
(array([  8,   9,  10,  11,  12,  14,  15,  21,  65,  70, 130, 133, 192,
       199, 259, 324, 432, 433, 434, 435, 436, 437, 438, 439, 505, 510,
       570, 573, 632, 639, 699, 764]),)
Black Features: tensor(32.)
(array([  8,   9,  10,  11,  12,  13,  14,  15,  65,  70, 130, 133, 192,
       199, 259, 324, 429, 432, 433, 434, 435, 436, 438, 439, 505, 510,
       570, 573, 632, 639, 699, 764]),)
Side to Move: tensor(1.)


  print(np.nonzero(np.array(w_features)))
  print(np.nonzero(np.array(b_features)))


In [5]:
def _same_indices(left, right):
    # Two index lists describe the same feature set when they match after sorting.
    return np.array_equal(np.sort(left), np.sort(right))

# First pair of index lists to compare
test1 = [192, 65, 130, 259, 324, 133, 70, 199, 8, 9, 10, 11, 12, 14, 15, 21, 432, 433, 434, 435, 436, 437, 438, 439, 632, 505, 570, 699, 764, 573, 510, 639]
test2 = [ 8, 9,  10,  11,  12,  14,  15,  21,  65,  70, 130, 133, 192, 199, 259, 324, 432, 433, 434, 435, 436, 437, 438, 439, 505, 510, 570, 573, 632, 639, 699, 764]
print(_same_indices(test1, test2))

# Second pair of index lists to compare
test3 = [  8,   9,  10,  11,  12,  13,  14,  15,  65,  70, 130, 133, 192, 199, 259, 324, 429, 432, 433, 434, 435, 436, 438, 439, 505, 510,  570, 573, 632, 639, 699, 764]
test4 = [632, 505, 570, 699, 764, 573, 510, 639, 432, 433, 434, 435, 436, 438, 439, 429, 8, 9, 10, 11, 12, 13, 14, 15, 192, 65, 130, 259, 324, 133, 70, 199]
print(_same_indices(test3, test4))

True
True


## Model architecture

Input: a sparse, binary array of length 768. Each element of the array represents a possible combination of piece type (6), piece_color (2) and position (64) (6*2*64 = 768).

This is a very simple input feature set (the P feature set) that will be improved upon later (HalfKP).

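The fully connected feedforward network has 3 layers: a 768 -> 128 feature transformer that is shared by both perspectives (its two 128-value outputs are concatenated into 256 values), followed by 256 -> 32 and 32 -> 1.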

The output is a single scalar.

In [6]:
import torch
import torch.nn as nn

class Split_NNUE(nn.Module):
    """
    Two-perspective network: one shared 768 -> 128 layer applied to both feature
    vectors, then 256 -> 32 and 32 -> 1 on the concatenated result.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(768, 128)
        self.fc2 = nn.Linear(256, 32)
        self.fc3 = nn.Linear(32, 1)

    def forward(self, white_features, black_features, stm):
        """
        white_features, black_features: batched inputs for fc1 (last dim 768)
        stm: one value per batch row; 0 keeps the [white, black] order,
             1 selects the [black, white] order
        Returns a tensor with one output value per batch row.
        """
        white_acc = self.fc1(white_features)
        black_acc = self.fc1(black_features)

        # column vector so it broadcasts over the 256 concatenated activations
        stm = stm.to(dtype=white_acc.dtype).view(-1, 1)

        white_first = torch.cat((white_acc, black_acc), dim=1)
        black_first = torch.cat((black_acc, white_acc), dim=1)
        combined = (1 - stm) * white_first + stm * black_first

        # clipped ReLU (clamp to [0, 1]) after the first two layers
        hidden = self.fc2(combined.clamp(min = 0, max = 1))
        return self.fc3(hidden.clamp(min = 0, max = 1))


In [6]:
import csv
import torch
from torch.utils.data import IterableDataset, DataLoader

class Custom_Split_Dataset(IterableDataset):
    """
    Streams training samples for the two-perspective network from a CSV file.

    Each yielded sample is a 5-tuple of tensors:
        (white features, black features, side to move, score, result)
    """

    def __init__(self, csv_path, shuffle_buffer=0):
        """
        csv_path: path to CSV file with three columns: fen, score, result
        shuffle_buffer: size of in-memory shuffle buffer; 0 or 1 = no shuffle
        """
        super().__init__()
        self.csv_path = csv_path
        self.shuffle_buffer = shuffle_buffer

    def _row_stream(self):
        """
        Generator that yields (w_in, b_in, stm, score, result) tensors, one per CSV row.

        Empty rows and rows whose first field starts with '#' are skipped.
        When iterated inside a DataLoader worker, only every num_workers-th row
        (offset by the worker id) is produced, so that the workers together
        cover the file once instead of each replaying the complete file.
        """
        worker = torch.utils.data.get_worker_info()
        num_workers = 1 if worker is None else worker.num_workers
        worker_id = 0 if worker is None else worker.id

        with open(self.csv_path, newline='') as csvfile:
            for row_idx, row in enumerate(csv.reader(csvfile)):
                if row_idx % num_workers != worker_id:
                    continue
                if not row or row[0].startswith('#'):
                    continue
                w_in, b_in, stm = FEN_to_inputs(row[0].strip())
                score, result = float(row[1].strip()), float(row[2].strip())
                # 32002 is replaced by a neutral score
                # (presumably a "no evaluation" sentinel of the data source - TODO confirm)
                if score == 32002:
                    score = 0
                # remap the result from -1 / 0 / 1 to 0 / 0.5 / 1
                if result == -1:
                    result = 0
                elif result == 0:
                    result = 0.5
                yield w_in, b_in, stm, torch.tensor(score, dtype=torch.float32), torch.tensor(result, dtype=torch.float32)

    @staticmethod
    def _pop_random(buf):
        """Remove and return a uniformly random element of buf in O(1)."""
        idx = torch.randint(len(buf), (1,)).item()
        # move the chosen element to the end so pop() does not shift the list
        buf[idx], buf[-1] = buf[-1], buf[idx]
        return buf.pop()

    def __iter__(self):
        stream = self._row_stream()
        if self.shuffle_buffer <= 1:
            yield from stream
            return

        # shuffle buffer: once the buffer is full, every incoming sample
        # pushes one randomly chosen buffered sample out
        buf = []
        for sample in stream:
            buf.append(sample)
            if len(buf) >= self.shuffle_buffer:
                yield self._pop_random(buf)

        # drain remaining buffer
        while buf:
            yield self._pop_random(buf)

In [7]:
from torch.utils.data import DataLoader

# Stream the 10M-position CSV through a 100000-sample shuffle buffer and
# batch it with four loader workers.
csv_path = '/home/yvlaere/projects/yvl-chess/NNUE_training/training_data/sf_training_data_full_10M.csv'
dataset = Custom_Split_Dataset(csv_path=csv_path, shuffle_buffer=100000)
loader = DataLoader(dataset=dataset, batch_size=1024, num_workers=4, pin_memory=True)

In [8]:
def init_weights(m):
    """
    Re-initialise a linear layer in place; every other module type is left untouched.

    Intended for use with `model.apply(init_weights)`.
    """
    if not isinstance(m, nn.Linear):
        return
    # Kaiming uniform weights (suited to ReLU-like activations), zero biases
    nn.init.kaiming_uniform_(m.weight, a=0.0, nonlinearity='relu')
    nn.init.zeros_(m.bias)

In [9]:
# Training loop for Split_NNUE: MSE between the sigmoid of the scaled model
# output and a target that blends the game result with the sigmoid of the
# scaled dataset score.
import os

# tensorboard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir="runs/nnue_training_split_model_10M")

# hyperparameters
nr_epochs = 10000
learning_rate = 1e-3
weight_decay = 1e-5
scaling_factor = 400                # divisor applied to the model output before the sigmoid
ground_truth_scaling_factor = 400   # divisor applied to the dataset score before the sigmoid
lambda_ = 0.2                       # weight of the game result in the blended target
log_interval = 100
save_interval = 100000
step = 0
running_loss = 0.0
epsilon = 1e-10                     # only used by the commented-out log-loss below

# initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Split_NNUE().to(device)
#model.apply(init_weights)
optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate, weight_decay = weight_decay)
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=100, min_lr=1e-6)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1000000, gamma=0.5)
criterion = nn.MSELoss()
#criterion = nn.BCEWithLogitsLoss()

# Create the checkpoint folder up front, so the first save (which only happens
# after `save_interval` steps) cannot fail on a missing directory.
os.makedirs('saved_models', exist_ok = True)

for epoch in range(nr_epochs):
    for batch in loader:

        # get data from the dataloader
        batch_x_w, batch_x_b, stm, batch_y, result = batch
        batch_x_w = batch_x_w.to(device, non_blocking = True)
        batch_x_b = batch_x_b.to(device, non_blocking = True)
        stm = stm.to(device, non_blocking = True)
        batch_y = batch_y.to(device, non_blocking = True)
        result = result.to(device, non_blocking = True)
        pred = model(batch_x_w, batch_x_b, stm).squeeze(1)

        # Transform the CP scores to the WDL space
        wdl_batch_y = lambda_*result + (1 - lambda_) * torch.sigmoid(batch_y / ground_truth_scaling_factor)
        wdl_pred = torch.sigmoid(pred / scaling_factor)

        #loss = (wdl_batch_y * torch.log(wdl_batch_y + epsilon) + (1 - wdl_batch_y) * torch.log(1 - wdl_batch_y + epsilon)) -(wdl_batch_y * torch.log(wdl_pred   + epsilon) + (1 - wdl_batch_y) * torch.log(1 - wdl_pred   + epsilon))
        #loss = loss.mean()

        # calculate the loss
        loss = criterion(wdl_pred, wdl_batch_y)
        running_loss += loss.item()

        # make a step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #scheduler.step()
        step += 1

        # Log every `log_interval` steps
        if step % log_interval == 0:
            # Gradient norms are only reported here, so they are computed here:
            # every .item() call forces a host/device synchronisation, which
            # would otherwise happen once per parameter on every single step.
            grad_norms = {
                name: param.grad.norm(2).item()
                for name, param in model.named_parameters()
                if param.grad is not None
            }
            # L2 norm of all gradients combined
            total_grad_norm = sum(norm ** 2 for norm in grad_norms.values()) ** 0.5

            avg_loss = running_loss / log_interval
            print(f"Epoch {epoch+1} | Step {step} | Avg Loss: {avg_loss:.4f} | Grad Norm: {total_grad_norm:.8f}")
            running_loss = 0.0
            current_lr = optimizer.param_groups[0]['lr']
            writer.add_scalar("Loss/train", avg_loss, step)
            writer.add_scalar("LR", current_lr, step)
            writer.add_scalar("Grad Norm", total_grad_norm, step)
            writer.add_scalar("WDL Pred", torch.mean(wdl_pred).item(), step)
            writer.add_scalar("WDL BatchY", torch.mean(wdl_batch_y).item(), step)
            writer.add_scalar("Pred", torch.median(pred).item(), step)
            writer.add_scalar("BatchY", torch.median(batch_y).item(), step)

            # log separate grad norms
            for name, grad_norm in grad_norms.items():
                writer.add_scalar(f'GradNorm/{name}', grad_norm, step)

        # Save the model every `save_interval` steps
        if step % save_interval == 0:
            model_name = 'saved_models/split_model_10M_' + str(step) + ".pth"
            print("Saving model at step" + str(step))
            # store the loss as a plain float rather than a live tensor
            torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),}, model_name)

    #scheduler.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

Epoch 1 | Step 100 | Avg Loss: 0.0390 | Grad Norm: 0.00022215
Epoch 1 | Step 200 | Avg Loss: 0.0396 | Grad Norm: 0.00029383
Epoch 1 | Step 300 | Avg Loss: 0.0389 | Grad Norm: 0.00034422
Epoch 1 | Step 400 | Avg Loss: 0.0397 | Grad Norm: 0.00046375
Epoch 1 | Step 500 | Avg Loss: 0.0403 | Grad Norm: 0.00055304
Epoch 1 | Step 600 | Avg Loss: 0.0399 | Grad Norm: 0.00066189
Epoch 1 | Step 700 | Avg Loss: 0.0408 | Grad Norm: 0.00074175
Epoch 1 | Step 800 | Avg Loss: 0.0415 | Grad Norm: 0.00083143
Epoch 1 | Step 900 | Avg Loss: 0.0416 | Grad Norm: 0.00104168
Epoch 1 | Step 1000 | Avg Loss: 0.0418 | Grad Norm: 0.00116530
Epoch 1 | Step 1100 | Avg Loss: 0.0405 | Grad Norm: 0.00117744
Epoch 1 | Step 1200 | Avg Loss: 0.0387 | Grad Norm: 0.00125537
Epoch 1 | Step 1300 | Avg Loss: 0.0400 | Grad Norm: 0.00132723
Epoch 1 | Step 1400 | Avg Loss: 0.0399 | Grad Norm: 0.00138729
Epoch 1 | Step 1500 | Avg Loss: 0.0397 | Grad Norm: 0.00158825
Epoch 1 | Step 1600 | Avg Loss: 0.0391 | Grad Norm: 0.00146230
E

KeyboardInterrupt: 

### Postprocessing of model

In [10]:
# export the model
model = Split_NNUE()
# map_location='cpu': the training cell saves the state dict of a model that
# lives on CUDA when a GPU is available, while the export below calls .numpy(),
# which needs CPU tensors. Mapping to CPU also lets the checkpoint load on a
# machine without a GPU.
checkpoint = torch.load('saved_models/split_model_10M_100000.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

def save_layer(layer, name):
    """
    Write the parameters of a linear layer to two plain-text files.

    layer: module with 2-D `weight` and 1-D `bias` tensors (e.g. nn.Linear)
    name:  path prefix; produces `<name>_weights.txt` (one space-separated
           line per weight row) and `<name>_biases.txt` (a single
           space-separated line without trailing newline)

    The directory part of `name` is created if it does not exist yet.
    """
    import os

    # .cpu() makes the export work for layers that live on a GPU as well
    w = layer.weight.detach().cpu().numpy()
    b = layer.bias.detach().cpu().numpy()

    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(f"{name}_weights.txt", "w") as f:
        f.writelines(" ".join(map(str, row)) + "\n" for row in w)
    with open(f"{name}_biases.txt", "w") as f:
        f.write(" ".join(map(str, b)))

# Dump the three layers as model/layer1 ... model/layer3 text files.
for layer_no, layer in enumerate((model.fc1, model.fc2, model.fc3), start=1):
    save_layer(layer, f"model/layer{layer_no}")


In [14]:
# Reload the checkpoint written at step 100000 into a fresh model and
# evaluate three positions reached after one white move.
model = Split_NNUE()
checkpoint = torch.load('saved_models/split_model_10M_100000.pth')
model.load_state_dict(checkpoint['model_state_dict'])

model.eval()
fen1 = 'rnbqkbnr/pppppppp/8/8/8/5P2/PPPPP1PP/RNBQKBNR b KQkq - 0 1'
fen2 = 'rnbqkbnr/pppppppp/8/8/8/7N/PPPPPPPP/RNBQKB1R b KQkq - 1 1'
fen3 = 'rnbqkbnr/pppppppp/8/8/8/5N2/PPPPPPPP/RNBQKB1R b KQkq - 1 1'

start_fen = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'

with torch.no_grad():
    # Encode every position and give each tensor a leading batch dimension of 1.
    w1, b1, stm1 = (t.unsqueeze(0) for t in FEN_to_inputs(fen1))
    print(w1.shape, b1.shape, stm1.shape)
    w2, b2, stm2 = (t.unsqueeze(0) for t in FEN_to_inputs(fen2))
    w3, b3, stm3 = (t.unsqueeze(0) for t in FEN_to_inputs(fen3))

    pred1 = model(w1, b1, stm1)
    pred2 = model(w2, b2, stm2)
    pred3 = model(w3, b3, stm3)

    print(pred1.item())
    print(pred2.item())
    print(pred3.item())

    # Redo the first part of the forward pass by hand for the first position,
    # so intermediate values are available for inspection.
    w_accumulator = model.fc1(w1)
    b_accumulator = model.fc1(b1)

    cat_wb = torch.cat([w_accumulator, b_accumulator], dim=1)
    cat_bw = torch.cat([b_accumulator, w_accumulator], dim=1)

    accumulator = (1 - stm1) * cat_wb + stm1 * cat_bw

    x = model.fc2(torch.clamp(accumulator, min = 0, max = 1))

    print(model.fc2.bias.detach().numpy())

torch.Size([1, 768]) torch.Size([1, 768]) torch.Size([1])
75.44014739990234
60.845481872558594
-1.3742446899414062
[-0.10270691 -0.7508942  -0.73212284 -0.6711883   0.4690156  -0.8976217
 -0.97113603 -0.7924493  -0.5964821  -0.75407517 -0.8800155  -0.2378027
  0.70388985 -0.6717296  -0.9132408  -0.8965203  -0.89196694 -0.35501885
  0.00581718  1.0301418  -0.73850054 -0.30603865 -0.21337932 -1.2286891
 -0.00251864 -1.0047029  -0.3110618  -0.640365   -0.5807997  -0.14240061
  0.07644044 -0.7504747 ]


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Histogram of the scores file, built line by line so the file never has to
# fit in memory: 64 equal-width bins spanning [-32003, 32003].
num_bins = 64
bin_edges = np.linspace(-32003, 32003, num_bins + 1)
counts = np.zeros(num_bins, dtype=int)

filename = '/home/yvlaere/projects/yvl-chess/NNUE_training/training_data/scores.txt'

with open(filename, 'r') as f:
    for line in f:
        try:
            val = float(line.strip())
        except ValueError:
            # not a number: skip the line
            continue
        # index of the half-open bin [edge_i, edge_i+1) that contains val
        bin_idx = np.searchsorted(bin_edges, val, side='right') - 1
        if 0 <= bin_idx < num_bins:
            counts[bin_idx] += 1

# Collect (index, left edge, right edge) of every bin that stayed empty
empty_bins = [
    (i, bin_edges[i], bin_edges[i + 1])
    for i, count in enumerate(counts)
    if count == 0
]

# Print empty bin ranges
print("Empty bins:")
for i, left, right in empty_bins:
    print(f"Bin {i}: [{left}, {right})")

# Plot histogram
plt.bar(bin_edges[:-1], counts, width=np.diff(bin_edges), edgecolor='black', align='edge')
plt.title("Histogram (streamed)")
plt.xlabel("Scores")
plt.ylabel("Frequency")
plt.grid(True)
plt.show()

In [None]:
# CP to WDL conversion: print sigmoid(score / scaling_factor) for a few scores
scaling_factor = 400
for centipawns in (32000, 1000, -1000, 0):
    score = torch.tensor(centipawns, dtype=torch.float32)
    print(torch.sigmoid(score / scaling_factor))

In [None]:
in_2 = [192, 65, 130, 259, 324, 133, 70, 199, 8, 9, 10, 11, 12, 13, 14, 15, 432, 433, 434, 435, 436, 437, 438, 439, 632, 505, 570, 699, 764, 573, 510, 639]

# `in_start` had no live definition (its only assignment was commented out),
# so this cell raised a NameError. Rebuild it here as the active
# white-perspective feature indices of the starting position.
start_fen = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'
in_start = np.argwhere(FEN_to_inputs(start_fen)[0].numpy() == 1)

print(np.sort(in_start.reshape(1, 32)))
print(np.sort(in_2))


### HalfKP

In [None]:
piece_dict = {'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, 'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11}
stm_dict = {'w': 0, 'b': 1}


def FEN_to_HalfKP(fen):
    """
    Convert a FEN string to HalfKP input vectors for both perspectives.

    Only the first two space-separated FEN fields are read (piece placement
    and side to move). Every non-king piece sets one feature per perspective:

        index = square + (piece_type * 2 + is_opponent_piece + king_square * 10) * 64

    with a1 = 0. For the black perspective the board is mirrored vertically
    (square ^ 56) and the colours are swapped, the same orientation that
    FEN_to_inputs uses. A position and its colour-mirrored twin therefore
    activate the same features, which the shared first layer relies on.

    Returns a tuple of float32 tensors:
        (white features [40960], black features [40960], side to move)
    where side to move is 0 for 'w' and 1 for 'b'.
    """
    # Split the FEN string into its components
    sub_FEN = fen.split(' ')
    board = sub_FEN[0]
    stm = stm_dict[sub_FEN[1]]
    ranks = board.split('/')

    # Single pass over the board: remember both king squares and every other piece.
    # In the chess engine, position 0 corresponds to a1, so the ranks in the
    # FEN string need to be reversed.
    white_king_position = 0
    black_king_position = 0
    pieces = []
    position = 0
    for rank in ranks[::-1]:
        for char in rank:
            if char.isdigit():
                position += int(char)
                continue
            if char == 'K':
                white_king_position = position
            elif char == 'k':
                black_king_position = position
            else:
                pieces.append((position, char))
            position += 1

    white_input_layer = np.zeros(40960, dtype = np.float32)
    black_input_layer = np.zeros(40960, dtype = np.float32)

    # black looks at the board from the other side
    black_king_oriented = black_king_position ^ 56

    for square, char in pieces:
        piece_type = piece_dict[char] % 6
        is_black_piece = piece_dict[char] > 5
        # colour bit: 0 = own piece, 1 = opponent piece, per perspective
        white_piece_index = piece_type * 2 + int(is_black_piece)
        black_piece_index = piece_type * 2 + int(not is_black_piece)
        white_input_layer[square + (white_piece_index + white_king_position*10)*64] = 1
        black_input_layer[(square ^ 56) + (black_piece_index + black_king_oriented*10)*64] = 1

    return torch.tensor(white_input_layer, dtype=torch.float32), torch.tensor(black_input_layer, dtype=torch.float32), torch.tensor(stm, dtype=torch.float32)

In [None]:
import csv
import torch
from torch.utils.data import IterableDataset, DataLoader

class HalfKP_Dataset(IterableDataset):
    """
    Streams HalfKP training samples from a CSV file.

    Each yielded sample is a 4-tuple of tensors:
        (white features, black features, side to move, score)
    """

    def __init__(self, csv_path, shuffle_buffer=0):
        """
        csv_path: path to CSV file whose first two columns are: fen, score
        shuffle_buffer: size of in-memory shuffle buffer; 0 or 1 = no shuffle
        """
        super().__init__()
        self.csv_path = csv_path
        self.shuffle_buffer = shuffle_buffer

    def _row_stream(self):
        """
        Generator that yields (w_in, b_in, stm, score) tensors, one per CSV row.

        Empty rows and rows whose first field starts with '#' are skipped.
        When iterated inside a DataLoader worker, only every num_workers-th row
        (offset by the worker id) is produced, so that the workers together
        cover the file once instead of each replaying the complete file.
        """
        worker = torch.utils.data.get_worker_info()
        num_workers = 1 if worker is None else worker.num_workers
        worker_id = 0 if worker is None else worker.id

        with open(self.csv_path, newline='') as csvfile:
            for row_idx, row in enumerate(csv.reader(csvfile)):
                if row_idx % num_workers != worker_id:
                    continue
                if not row or row[0].startswith('#'):
                    continue
                w_in, b_in, stm = FEN_to_HalfKP(row[0].strip())
                score = float(row[1].strip())
                # 32002 is replaced by a neutral score
                # (presumably a "no evaluation" sentinel of the data source - TODO confirm)
                if score == 32002:
                    score = 0
                yield w_in, b_in, stm, torch.tensor(score, dtype=torch.float32)

    @staticmethod
    def _pop_random(buf):
        """Remove and return a uniformly random element of buf in O(1)."""
        idx = torch.randint(len(buf), (1,)).item()
        # move the chosen element to the end so pop() does not shift the list
        buf[idx], buf[-1] = buf[-1], buf[idx]
        return buf.pop()

    def __iter__(self):
        stream = self._row_stream()
        if self.shuffle_buffer <= 1:
            yield from stream
            return

        # shuffle buffer: once the buffer is full, every incoming sample
        # pushes one randomly chosen buffered sample out
        buf = []
        for sample in stream:
            buf.append(sample)
            if len(buf) >= self.shuffle_buffer:
                yield self._pop_random(buf)

        # drain remaining buffer
        while buf:
            yield self._pop_random(buf)

In [None]:
import torch
import torch.nn as nn

NUM_FEATURES = 40960
M = 1024
N = 32
K = 1

class HalfKPNNUE(nn.Module):
    """
    Two-perspective network on HalfKP features: one shared
    num_features -> m layer applied to both feature vectors, then
    2*m -> n and n -> k on the concatenated result.
    """

    def __init__(self, num_features=NUM_FEATURES, m=M, n=N, k=K):
        """
        num_features: length of each perspective's input vector
        m: outputs of the shared first layer (per perspective)
        n: outputs of the second layer
        k: outputs of the network

        The defaults reproduce the 40960 -> 1024 (x2) -> 32 -> 1 network.
        """
        super().__init__()
        # three fully connected layers
        self.fc1 = nn.Linear(num_features, m)
        self.fc2 = nn.Linear(2 * m, n)
        self.fc3 = nn.Linear(n, k)

    def forward(self, white_features, black_features, stm):
        """
        white_features, black_features: [B, num_features]
        stm: one value per batch row, 0 = white to move, 1 = black to move
             (the encoding of stm_dict)
        Returns a [B, k] tensor.
        """
        w = self.fc1(white_features)
        b = self.fc1(black_features)
        cat_wb = torch.cat([w, b], dim=1)  # [B, 2*M]
        cat_bw = torch.cat([b, w], dim=1)  # [B, 2*M]

        stm = stm.to(dtype=cat_wb.dtype).view(-1, 1)

        # Side to move first: stm == 0 (white) selects [w, b] and stm == 1
        # (black) selects [b, w], matching Split_NNUE.
        accumulator = (1 - stm) * cat_wb + stm * cat_bw

        x = torch.clamp(accumulator, min = 0.0, max = 1.0)
        x = torch.clamp(self.fc2(x), min = 0, max = 1)
        return self.fc3(x)

In [None]:
from torch.utils.data import DataLoader

# Stream the HalfKP training CSV through a 1000-sample shuffle buffer and
# batch it with four loader workers.
csv_path = '/home/yvlaere/projects/yvl-chess/NNUE_training/training_data/sf_training_data.csv'
dataset = HalfKP_Dataset(csv_path=csv_path, shuffle_buffer=1000)
loader = DataLoader(dataset=dataset, batch_size=128, num_workers=4, pin_memory=True)

In [None]:
# Training loop for HalfKPNNUE: MSE between the sigmoids of the scaled
# dataset score and the scaled model output; the MAE is tracked alongside.
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir="runs/halfKP")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nr_epochs = 500
model = HalfKPNNUE().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1, weight_decay=1e-4)
#optimizer = torch.optim.Adadelta(model.parameters(), lr = 0.05)
total_size = 200000000
batch_size = 128
# only used for the progress print below
steps_per_epoch = total_size // batch_size
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=100, min_lr=1e-6)
# stepped once per batch, so the learning rate halves every 100000 batches
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 100000, gamma=0.5)
criterion = nn.MSELoss()
MAE_loss = nn.L1Loss()
lowest_MAE = 10000

# Transform the CP scores to the WDL space
scaling_factor = 400

running_loss = 0.0
running_mae = 0.0
log_interval = 100
step = 0

for epoch in range(nr_epochs):
    for batch in loader:

        # get data from the dataloader
        batch_x_w, batch_x_b, stm, batch_y = batch

        # move data to GPU
        batch_x_w = batch_x_w.to(device, non_blocking = True)
        batch_x_b = batch_x_b.to(device, non_blocking = True)
        batch_y = batch_y.to(device, non_blocking = True)
        stm = stm.to(device, non_blocking = True)
        pred = model(batch_x_w, batch_x_b, stm).squeeze(1)  # remove the last dimension

        # Transform the CP scores to the WDL space
        wdl_batch_y = torch.sigmoid(batch_y / scaling_factor)
        wdl_pred = torch.sigmoid(pred / scaling_factor)

        # calculate the MSE loss
        loss = criterion(wdl_batch_y, wdl_pred)
        # The MAE is a metric only: detach it from the autograd graph and
        # accumulate a plain float, so the running sum does not hold on to
        # graph-attached tensors between logging steps.
        MAE = MAE_loss(wdl_batch_y, wdl_pred.detach())
        mae_value = MAE.item()
        running_loss += loss.item()
        running_mae += mae_value
        step += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        scheduler.step()

        # Log every `log_interval` steps
        if step % log_interval == 0:
            avg_loss = running_loss / log_interval
            avg_mae = running_mae / log_interval
            print(f"Epoch {epoch+1} | Step {step}/{steps_per_epoch} | Avg Loss: {avg_loss:.4f}")
            running_loss = 0.0
            running_mae = 0.0
            writer.add_scalar("Loss/train", avg_loss, step)
            writer.add_scalar("MAE/train", avg_mae, step)
            current_lr = optimizer.param_groups[0]['lr']
            writer.add_scalar("LR", current_lr, step)

        # Save a "best" checkpoint only when the batch MAE is below the 0.0002
        # gate AND lower than every MAE saved before. Previously any batch
        # under the gate overwrote both the file and lowest_MAE, even when it
        # was worse than the stored best.
        if mae_value < 0.0002 and mae_value < lowest_MAE:
            lowest_MAE = mae_value
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"New best model saved with MAE: {lowest_MAE:.4f}, loss: {loss.item():.4f}")

    #scheduler.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
    print(f"Epoch {epoch+1}, MAE: {MAE.item():.4f}, lowest MAE: {lowest_MAE:.4f}")

In [None]:
# Plane index (0-11) of every piece letter: white pieces first, then black.
piece_dict = dict(zip('PNBRQKpnbrqk', range(12)))

def FEN_to_input(fen):
    """
    Encode the piece placement of a FEN string as one 768-element one-hot vector.

    Only the first space-separated FEN field is read.
    Feature index = square + 64 * plane, with a1 = 0.
    Returns a float32 tensor of shape [768].
    """
    placement = fen.split(' ')[0]

    encoded = np.zeros(768, dtype = np.float32)

    # FEN lists rank 8 first, the engine numbers squares from a1, so walk the ranks backwards
    square = 0
    for rank_string in reversed(placement.split('/')):
        for symbol in rank_string:
            if symbol.isdigit():
                # a digit is a run of empty squares
                square += int(symbol)
                continue
            encoded[square + 64 * piece_dict[symbol]] = 1
            square += 1

    return torch.tensor(encoded, dtype=torch.float32)

In [None]:
import torch
import torch.nn as nn

class SimpleNNUE(nn.Module):
    """
    Single-perspective baseline: 768 -> 256 -> 32 -> 1 with ReLU after the
    first two layers.
    """

    def __init__(self):
        super().__init__()
        # three fully connected layers
        self.fc1 = nn.Linear(768, 256)
        self.fc2 = nn.Linear(256, 32)
        self.fc4 = nn.Linear(32, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map inputs with last dimension 768 to one output value each."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        # no activation on the output layer
        return self.fc4(hidden)


In [None]:
import csv
import torch
from torch.utils.data import IterableDataset, DataLoader

class Custom_Dataset(IterableDataset):
    """
    Streams single-perspective training samples from a CSV file.

    Each yielded sample is a 3-tuple of tensors: (features, score, result)
    """

    def __init__(self, csv_path, shuffle_buffer=0):
        """
        csv_path: path to CSV file with three columns: fen, score, result
        shuffle_buffer: size of in-memory shuffle buffer; 0 or 1 = no shuffle
        """
        super().__init__()
        self.csv_path = csv_path
        self.shuffle_buffer = shuffle_buffer

    def _row_stream(self):
        """
        Generator that yields (features, score, result) tensors, one per CSV row.

        Empty rows and rows whose first field starts with '#' are skipped.
        When iterated inside a DataLoader worker, only every num_workers-th row
        (offset by the worker id) is produced, so that the workers together
        cover the file once instead of each replaying the complete file.
        """
        worker = torch.utils.data.get_worker_info()
        num_workers = 1 if worker is None else worker.num_workers
        worker_id = 0 if worker is None else worker.id

        with open(self.csv_path, newline='') as csvfile:
            for row_idx, row in enumerate(csv.reader(csvfile)):
                if row_idx % num_workers != worker_id:
                    continue
                if not row or row[0].startswith('#'):
                    continue
                fen, score, result = FEN_to_input(row[0].strip()), float(row[1].strip()), float(row[2].strip())
                # 32002 is replaced by a neutral score
                # (presumably a "no evaluation" sentinel of the data source - TODO confirm)
                if score == 32002:
                    score = 0
                # remap the result from -1 / 0 / 1 to 0 / 0.5 / 1
                if result == -1:
                    result = 0
                elif result == 0:
                    result = 0.5
                yield fen, torch.tensor(score, dtype=torch.float32), torch.tensor(result, dtype=torch.float32)

    @staticmethod
    def _pop_random(buf):
        """Remove and return a uniformly random element of buf in O(1)."""
        idx = torch.randint(len(buf), (1,)).item()
        # move the chosen element to the end so pop() does not shift the list
        buf[idx], buf[-1] = buf[-1], buf[idx]
        return buf.pop()

    def __iter__(self):
        stream = self._row_stream()
        if self.shuffle_buffer <= 1:
            yield from stream
            return

        # shuffle buffer: once the buffer is full, every incoming sample
        # pushes one randomly chosen buffered sample out
        buf = []
        for sample in stream:
            buf.append(sample)
            if len(buf) >= self.shuffle_buffer:
                yield self._pop_random(buf)

        # drain remaining buffer
        while buf:
            yield self._pop_random(buf)