In [1]:
import numpy as np # type: ignore
import time
import torch
import torch.nn as nn # type: ignore
import torch.optim as optim # type: ignore
from torch.utils.data import random_split, DataLoader # type: ignore
from torch.optim.lr_scheduler import MultiStepLR
from chess import pgn # type: ignore
from tqdm import tqdm # type: ignore
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from auxiliary_func import check_memory, load_dataset, encode_moves
from dataset import ChessDataset
from model import ChessModel
from model2 import ChessModel2
from model3 import ChessModel3
from model4 import ChessModel4
from model5 import ChessModel5
from model7 import ChessModel7
import pickle

2025-11-06 14:20:33.383718: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-11-06 14:20:33.403353: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-11-06 14:20:33.403393: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-11-06 14:20:33.415385: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
total_mem = check_memory()
print(total_mem)


X, y, games_parsed, files_parsed = load_dataset(data_folder="../../data/Lichess_Elite_Database", pgn_memory_mark=1.0, file_limit=2)


X, y = np.array(X, dtype=np.float32), np.array(y)

y, move_to_int = encode_moves(y)
num_classes = len(move_to_int)

7.693897247314453


  0%|          | 0/79 [00:00<?, ?it/s]

Completed sampling limit of files with 7.693897247314453 remaining


  1%|▏         | 1/79 [00:00<00:05, 13.29it/s]


In [3]:

X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)


In [4]:
X.shape
# y.shape

torch.Size([171, 16, 8, 8])

In [5]:
# Create Dataset
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Compute split sizes
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Then create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}', flush=True)

# Model Initialization
model = ChessModel7(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

scheduler = MultiStepLR(optimizer, milestones=[50000, 250000, 400000], gamma=0.2)

Using device: cuda


In [6]:
model.train()
running_loss = 0.0
for inputs, labels in tqdm(train_loader):
    inputs, labels = inputs.to(device), labels.to(device)  # Move data to GPU
    optimizer.zero_grad()

    outputs = model(inputs)  # Raw logits

    print(outputs)
    # Compute loss
    loss = criterion(outputs, labels)
    loss.backward()
    
    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    optimizer.step()
    scheduler.step()
    print(loss.item())
    print(loss.item())
    running_loss += loss.item()

  0%|          | 0/3 [00:00<?, ?it/s]

torch.Size([64, 64, 8, 8])
torch.Size([64, 64, 1, 1])
torch.Size([64, 64])
torch.Size([64, 8])
torch.Size([64, 64])
torch.Size([64, 64, 1, 1])
tensor([[ 0.1359,  0.1158, -0.4285,  ...,  0.2883,  0.4095, -0.3727],
        [-0.8300,  0.1510,  0.5317,  ...,  0.8046,  1.0785,  1.6197],
        [ 0.1077, -1.0788,  0.4180,  ..., -0.2307,  0.3737, -0.2505],
        ...,
        [-0.4980, -0.5779,  0.5010,  ...,  0.3198,  0.6781,  0.6157],
        [-2.4249, -0.1408,  0.5513,  ..., -0.7549,  1.4830,  1.7898],
        [-0.9673,  0.7497,  0.2564,  ...,  2.6634,  1.3394, -0.3859]],
       device='cuda:0', grad_fn=<AddmmBackward0>)


100%|██████████| 3/3 [00:00<00:00,  4.16it/s]

5.417665004730225
5.417665004730225
torch.Size([64, 64, 8, 8])
torch.Size([64, 64, 1, 1])
torch.Size([64, 64])
torch.Size([64, 8])
torch.Size([64, 64])
torch.Size([64, 64, 1, 1])
tensor([[ 0.9467,  0.7286, -0.5864,  ...,  0.7267, -0.4543, -0.7226],
        [-0.8535, -0.0693, -0.2242,  ...,  0.5035, -0.2450, -0.0755],
        [-0.3670,  0.3299,  0.6909,  ...,  0.2754,  0.3421, -0.5678],
        ...,
        [-1.0111, -0.2984, -0.3384,  ...,  0.2664, -0.5644,  0.3098],
        [-0.7990, -0.2408,  0.1900,  ..., -0.1993,  0.1834,  0.4669],
        [ 0.0268, -1.6397,  1.2044,  ...,  1.1058,  0.0556,  1.0107]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
5.471574783325195
5.471574783325195
torch.Size([25, 64, 8, 8])
torch.Size([25, 64, 1, 1])
torch.Size([25, 64])
torch.Size([25, 8])
torch.Size([25, 64])
torch.Size([25, 64, 1, 1])
tensor([[-0.1775,  1.0130,  0.5060,  ..., -0.3446,  0.9891, -1.1812],
        [-0.6625, -0.7017,  0.0918,  ..., -0.3525,  0.5182,  0.0241],
        [-1.1861, 


