In [1]:
import numpy as np # type: ignore
import time
import torch
import torch.nn as nn # type: ignore
import torch.optim as optim # type: ignore
from torch.utils.data import random_split, DataLoader # type: ignore
from torch.optim.lr_scheduler import MultiStepLR
from chess import pgn # type: ignore
from tqdm import tqdm # type: ignore
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from auxiliary_func import check_memory, load_dataset, encode_moves
from dataset import ChessDataset
from model import ChessModel
from model2 import ChessModel2
from model3 import ChessModel3
from model4 import ChessModel4
from model5 import ChessModel5
import pickle

2025-11-05 12:57:28.053986: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-11-05 12:57:28.075474: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-11-05 12:57:28.075509: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-11-05 12:57:28.088178: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
total_mem = check_memory()
print(total_mem)


X, y, games_parsed, files_parsed = load_dataset(data_folder="../../data/Lichess_Elite_Database", pgn_memory_mark=1.0, file_limit=2)


X, y = np.array(X, dtype=np.float32), np.array(y)

y, move_to_int = encode_moves(y)
num_classes = len(move_to_int)

8.866966247558594


  0%|          | 0/79 [00:00<?, ?it/s]

Completed sampling limit of files with 8.860027313232422 remaining


  1%|▏         | 1/79 [00:00<00:04, 18.41it/s]


In [3]:

X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)


In [None]:
# Create Dataset
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Compute split sizes
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Then create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}', flush=True)

# Model Initialization
model = ChessModel(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

scheduler = MultiStepLR(optimizer, milestones=[50000, 250000, 400000], gamma=0.2)

Using device: cuda


In [7]:
model.train()
running_loss = 0.0
for inputs, labels in tqdm(train_loader):
    inputs, labels = inputs.to(device), labels.to(device)  # Move data to GPU
    optimizer.zero_grad()

    outputs = model(inputs)  # Raw logits

    print(outputs)
    # Compute loss
    loss = criterion(outputs, labels)
    loss.backward()
    
    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    optimizer.step()
    scheduler.step()
    print(loss.item())
    print(loss.item())
    running_loss += loss.item()

100%|██████████| 1/1 [00:00<00:00,  5.97it/s]

tensor([[ 0.1393,  1.6369, -3.3265,  ...,  0.8556, -1.0527, -1.0251],
        [-1.2652, -0.0954, -1.9578,  ..., -0.5304,  1.1013, -0.2666],
        [ 1.1467,  0.2864, -1.2602,  ...,  0.0129, -2.3730, -0.0494],
        ...,
        [-0.0803, -0.5543, -0.1809,  ..., -1.3569, -0.3243, -2.2262],
        [ 5.9492, -1.0146, -1.1220,  ..., -1.1028,  0.1902, -0.0175],
        [-1.3070,  1.1778, -1.7877,  ..., -0.8435, -2.9171,  1.0776]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
0.5079324841499329
0.5079324841499329



