Dataset

In [14]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from chess import pgn
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

In [15]:
def load_pgn(file_path):
    games = []
    with open(file_path) as file:
        while True:
            game = pgn.read_game(file)
            if game is None:
                break
            games.append(game)
    return games


def load_all_pgns(files):
    games = []
    with ThreadPoolExecutor() as executor:
        results = list(tqdm(executor.map(load_pgn, files), total=len(files)))
        for result in results:
            games.extend(result)
    return games

files = [file for file in os.listdir("data") if file.endswith(".pgn")]
LIMIT_OF_FILES = min(26, len(files))

games = load_all_pgns([f"data/{file}" for file in files[:LIMIT_OF_FILES]])
del files
print(f"Loaded {len(games)} games")

100%|██████████| 26/26 [02:33<00:00,  5.91s/it]

Loaded 19330 games





In [17]:
from dataset import ChessDataset
from model import ChessModel
from auxiliary_func import create_nn_input

_, _, move_to_int = create_nn_input(games)
num_classes = len(move_to_int)
print(f"Number of classes (unique moves): {num_classes}")

# Create Dataset and DataLoader
dataset = ChessDataset(games)
del games
dataloader = DataLoader(dataset, batch_size=64, shuffle=True) 

# Check for GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Model Initialization
num_classes = dataset.__getitem__
print(f"Number of classes: {num_classes}")
model = ChessModel(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

ValueError: not enough values to unpack (expected 3, got 2)

In [5]:
num_epochs = 100
for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)  # Move data to GPU
        optimizer.zero_grad()

        outputs = model(inputs)  # Raw logits

        # Compute loss
        loss = criterion(outputs, labels)
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        running_loss += loss.item()
    end_time = time.time()
    epoch_time = end_time - start_time
    minutes: int = int(epoch_time // 60)
    seconds: int = int(epoch_time) - minutes * 60
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(dataloader):.4f}, Time: {minutes}m{seconds}s')

  0%|          | 0/24642 [00:00<?, ?it/s]

100%|██████████| 24642/24642 [07:18<00:00, 56.25it/s] 


Epoch 1/100, Loss: 3.9568, Time: 7m18s


100%|██████████| 24642/24642 [07:07<00:00, 57.71it/s] 


Epoch 2/100, Loss: 2.9318, Time: 7m7s


100%|██████████| 24642/24642 [04:17<00:00, 95.79it/s] 


Epoch 3/100, Loss: 2.6409, Time: 4m17s


100%|██████████| 24642/24642 [04:12<00:00, 97.66it/s] 


Epoch 4/100, Loss: 2.4787, Time: 4m12s


100%|██████████| 24642/24642 [04:46<00:00, 85.91it/s] 


Epoch 5/100, Loss: 2.3668, Time: 4m46s


100%|██████████| 24642/24642 [04:34<00:00, 89.65it/s] 


Epoch 6/100, Loss: 2.2798, Time: 4m34s


100%|██████████| 24642/24642 [04:05<00:00, 100.30it/s]


Epoch 7/100, Loss: 2.2079, Time: 4m5s


100%|██████████| 24642/24642 [04:41<00:00, 87.64it/s] 


Epoch 8/100, Loss: 2.1458, Time: 4m41s


100%|██████████| 24642/24642 [04:36<00:00, 88.99it/s] 


Epoch 9/100, Loss: 2.0906, Time: 4m36s


100%|██████████| 24642/24642 [04:11<00:00, 97.99it/s] 


Epoch 10/100, Loss: 2.0413, Time: 4m11s


100%|██████████| 24642/24642 [03:50<00:00, 106.94it/s]


Epoch 11/100, Loss: 1.9968, Time: 3m50s


100%|██████████| 24642/24642 [04:12<00:00, 97.70it/s] 


Epoch 12/100, Loss: 1.9558, Time: 4m12s


100%|██████████| 24642/24642 [03:58<00:00, 103.11it/s]


Epoch 13/100, Loss: 1.9184, Time: 3m59s


100%|██████████| 24642/24642 [04:19<00:00, 95.09it/s] 


Epoch 14/100, Loss: 1.8835, Time: 4m19s


100%|██████████| 24642/24642 [05:29<00:00, 74.80it/s]


Epoch 15/100, Loss: 1.8511, Time: 5m29s


100%|██████████| 24642/24642 [06:01<00:00, 68.21it/s]


Epoch 16/100, Loss: 1.8207, Time: 6m1s


100%|██████████| 24642/24642 [05:58<00:00, 68.70it/s]


Epoch 17/100, Loss: 1.7921, Time: 5m58s


100%|██████████| 24642/24642 [06:10<00:00, 66.56it/s]


Epoch 18/100, Loss: 1.7655, Time: 6m10s


100%|██████████| 24642/24642 [07:11<00:00, 57.17it/s]


Epoch 19/100, Loss: 1.7405, Time: 7m11s


100%|██████████| 24642/24642 [07:21<00:00, 55.86it/s]


Epoch 20/100, Loss: 1.7166, Time: 7m21s


100%|██████████| 24642/24642 [06:14<00:00, 65.88it/s]


Epoch 21/100, Loss: 1.6939, Time: 6m14s


100%|██████████| 24642/24642 [06:32<00:00, 62.83it/s] 


Epoch 22/100, Loss: 1.6725, Time: 6m32s


100%|██████████| 24642/24642 [03:22<00:00, 121.81it/s]


Epoch 23/100, Loss: 1.6520, Time: 3m22s


100%|██████████| 24642/24642 [06:03<00:00, 67.81it/s]


Epoch 24/100, Loss: 1.6321, Time: 6m3s


100%|██████████| 24642/24642 [08:02<00:00, 51.12it/s]


Epoch 25/100, Loss: 1.6135, Time: 8m2s


100%|██████████| 24642/24642 [08:12<00:00, 50.07it/s]


Epoch 26/100, Loss: 1.5953, Time: 8m12s


100%|██████████| 24642/24642 [08:23<00:00, 48.91it/s]


Epoch 27/100, Loss: 1.5775, Time: 8m23s


100%|██████████| 24642/24642 [08:11<00:00, 50.11it/s]


Epoch 28/100, Loss: 1.5612, Time: 8m11s


100%|██████████| 24642/24642 [08:26<00:00, 48.65it/s]


Epoch 29/100, Loss: 1.5446, Time: 8m26s


100%|██████████| 24642/24642 [08:22<00:00, 49.07it/s]


Epoch 30/100, Loss: 1.5289, Time: 8m22s


100%|██████████| 24642/24642 [08:08<00:00, 50.41it/s]


Epoch 31/100, Loss: 1.5143, Time: 8m8s


100%|██████████| 24642/24642 [04:42<00:00, 87.32it/s] 


Epoch 32/100, Loss: 1.4993, Time: 4m42s


100%|██████████| 24642/24642 [06:09<00:00, 66.78it/s]


Epoch 33/100, Loss: 1.4851, Time: 6m9s


100%|██████████| 24642/24642 [05:40<00:00, 72.47it/s] 


Epoch 34/100, Loss: 1.4720, Time: 5m40s


100%|██████████| 24642/24642 [05:52<00:00, 69.82it/s]


Epoch 35/100, Loss: 1.4586, Time: 5m52s


100%|██████████| 24642/24642 [05:41<00:00, 72.10it/s]


Epoch 36/100, Loss: 1.4461, Time: 5m41s


100%|██████████| 24642/24642 [04:50<00:00, 84.93it/s] 


Epoch 37/100, Loss: 1.4341, Time: 4m50s


100%|██████████| 24642/24642 [02:02<00:00, 201.22it/s]


Epoch 38/100, Loss: 1.4222, Time: 2m2s


100%|██████████| 24642/24642 [02:04<00:00, 197.68it/s]


Epoch 39/100, Loss: 1.4109, Time: 2m4s


100%|██████████| 24642/24642 [02:03<00:00, 198.81it/s]


Epoch 40/100, Loss: 1.3997, Time: 2m3s


100%|██████████| 24642/24642 [02:06<00:00, 194.34it/s]


Epoch 41/100, Loss: 1.3888, Time: 2m6s


100%|██████████| 24642/24642 [02:03<00:00, 200.00it/s]


Epoch 42/100, Loss: 1.3785, Time: 2m3s


100%|██████████| 24642/24642 [02:03<00:00, 199.23it/s]


Epoch 43/100, Loss: 1.3682, Time: 2m3s


100%|██████████| 24642/24642 [02:03<00:00, 199.53it/s]


Epoch 44/100, Loss: 1.3585, Time: 2m3s


100%|██████████| 24642/24642 [02:07<00:00, 192.85it/s]


Epoch 45/100, Loss: 1.3492, Time: 2m7s


100%|██████████| 24642/24642 [02:07<00:00, 192.70it/s]


Epoch 46/100, Loss: 1.3397, Time: 2m7s


100%|██████████| 24642/24642 [02:01<00:00, 203.10it/s]


Epoch 47/100, Loss: 1.3309, Time: 2m1s


100%|██████████| 24642/24642 [02:08<00:00, 192.38it/s]


Epoch 48/100, Loss: 1.3223, Time: 2m8s


100%|██████████| 24642/24642 [02:10<00:00, 188.95it/s]


Epoch 49/100, Loss: 1.3138, Time: 2m10s


100%|██████████| 24642/24642 [02:08<00:00, 192.32it/s]


Epoch 50/100, Loss: 1.3054, Time: 2m8s


100%|██████████| 24642/24642 [02:06<00:00, 194.06it/s]


Epoch 51/100, Loss: 1.2970, Time: 2m6s


100%|██████████| 24642/24642 [02:12<00:00, 186.42it/s]


Epoch 52/100, Loss: 1.2891, Time: 2m12s


100%|██████████| 24642/24642 [02:25<00:00, 169.35it/s]


Epoch 53/100, Loss: 1.2812, Time: 2m25s


100%|██████████| 24642/24642 [02:32<00:00, 161.94it/s]


Epoch 54/100, Loss: 1.2736, Time: 2m32s


100%|██████████| 24642/24642 [02:50<00:00, 144.42it/s]


Epoch 55/100, Loss: 1.2665, Time: 2m50s


100%|██████████| 24642/24642 [02:13<00:00, 184.35it/s]


Epoch 56/100, Loss: 1.2593, Time: 2m13s


100%|██████████| 24642/24642 [02:25<00:00, 169.73it/s]


Epoch 57/100, Loss: 1.2522, Time: 2m25s


100%|██████████| 24642/24642 [02:20<00:00, 175.10it/s]


Epoch 58/100, Loss: 1.2454, Time: 2m20s


100%|██████████| 24642/24642 [02:39<00:00, 154.82it/s]


Epoch 59/100, Loss: 1.2382, Time: 2m39s


100%|██████████| 24642/24642 [02:24<00:00, 170.35it/s]


Epoch 60/100, Loss: 1.2317, Time: 2m24s


100%|██████████| 24642/24642 [02:27<00:00, 166.57it/s]


Epoch 61/100, Loss: 1.2255, Time: 2m27s


100%|██████████| 24642/24642 [02:20<00:00, 174.84it/s]


Epoch 62/100, Loss: 1.2196, Time: 2m20s


100%|██████████| 24642/24642 [02:25<00:00, 169.41it/s]


Epoch 63/100, Loss: 1.2132, Time: 2m25s


100%|██████████| 24642/24642 [02:25<00:00, 169.70it/s]


Epoch 64/100, Loss: 1.2074, Time: 2m25s


100%|██████████| 24642/24642 [02:35<00:00, 158.65it/s]


Epoch 65/100, Loss: 1.2014, Time: 2m35s


100%|██████████| 24642/24642 [02:30<00:00, 163.37it/s]


Epoch 66/100, Loss: 1.1957, Time: 2m30s


100%|██████████| 24642/24642 [02:44<00:00, 149.93it/s]


Epoch 67/100, Loss: 1.1900, Time: 2m44s


100%|██████████| 24642/24642 [02:27<00:00, 167.07it/s]


Epoch 68/100, Loss: 1.1849, Time: 2m27s


100%|██████████| 24642/24642 [02:26<00:00, 168.71it/s]


Epoch 69/100, Loss: 1.1789, Time: 2m26s


100%|██████████| 24642/24642 [02:19<00:00, 176.02it/s]


Epoch 70/100, Loss: 1.1736, Time: 2m20s


100%|██████████| 24642/24642 [02:30<00:00, 164.04it/s]


Epoch 71/100, Loss: 1.1692, Time: 2m30s


100%|██████████| 24642/24642 [02:26<00:00, 168.49it/s]


Epoch 72/100, Loss: 1.1637, Time: 2m26s


100%|██████████| 24642/24642 [02:21<00:00, 173.89it/s]


Epoch 73/100, Loss: 1.1585, Time: 2m21s


100%|██████████| 24642/24642 [02:25<00:00, 169.67it/s]


Epoch 74/100, Loss: 1.1539, Time: 2m25s


100%|██████████| 24642/24642 [02:28<00:00, 166.21it/s]


Epoch 75/100, Loss: 1.1485, Time: 2m28s


100%|██████████| 24642/24642 [02:28<00:00, 166.40it/s]


Epoch 76/100, Loss: 1.1441, Time: 2m28s


100%|██████████| 24642/24642 [02:23<00:00, 172.03it/s]


Epoch 77/100, Loss: 1.1394, Time: 2m23s


100%|██████████| 24642/24642 [02:25<00:00, 169.94it/s]


Epoch 78/100, Loss: 1.1353, Time: 2m25s


100%|██████████| 24642/24642 [02:27<00:00, 166.54it/s]


Epoch 79/100, Loss: 1.1303, Time: 2m27s


100%|██████████| 24642/24642 [02:29<00:00, 165.17it/s]


Epoch 80/100, Loss: 1.1261, Time: 2m29s


100%|██████████| 24642/24642 [02:32<00:00, 161.78it/s]


Epoch 81/100, Loss: 1.1220, Time: 2m32s


100%|██████████| 24642/24642 [02:27<00:00, 167.01it/s]


Epoch 82/100, Loss: 1.1173, Time: 2m27s


100%|██████████| 24642/24642 [02:31<00:00, 162.65it/s]


Epoch 83/100, Loss: 1.1129, Time: 2m31s


100%|██████████| 24642/24642 [02:28<00:00, 165.45it/s]


Epoch 84/100, Loss: 1.1090, Time: 2m28s


100%|██████████| 24642/24642 [02:30<00:00, 163.35it/s]


Epoch 85/100, Loss: 1.1053, Time: 2m30s


100%|██████████| 24642/24642 [02:24<00:00, 170.25it/s]


Epoch 86/100, Loss: 1.1010, Time: 2m24s


100%|██████████| 24642/24642 [02:19<00:00, 177.20it/s]


Epoch 87/100, Loss: 1.0968, Time: 2m19s


100%|██████████| 24642/24642 [02:45<00:00, 149.30it/s]


Epoch 88/100, Loss: 1.0934, Time: 2m45s


100%|██████████| 24642/24642 [02:16<00:00, 180.29it/s]


Epoch 89/100, Loss: 1.0895, Time: 2m16s


100%|██████████| 24642/24642 [02:15<00:00, 181.39it/s]


Epoch 90/100, Loss: 1.0855, Time: 2m15s


100%|██████████| 24642/24642 [02:13<00:00, 184.67it/s]


Epoch 91/100, Loss: 1.0816, Time: 2m13s


100%|██████████| 24642/24642 [02:24<00:00, 170.86it/s]


Epoch 92/100, Loss: 1.0778, Time: 2m24s


100%|██████████| 24642/24642 [02:21<00:00, 173.62it/s]


Epoch 93/100, Loss: 1.0743, Time: 2m21s


100%|██████████| 24642/24642 [02:19<00:00, 176.97it/s]


Epoch 94/100, Loss: 1.0709, Time: 2m19s


100%|██████████| 24642/24642 [02:16<00:00, 180.25it/s]


Epoch 95/100, Loss: 1.0676, Time: 2m16s


100%|██████████| 24642/24642 [02:25<00:00, 169.86it/s]


Epoch 96/100, Loss: 1.0643, Time: 2m25s


100%|██████████| 24642/24642 [02:13<00:00, 184.88it/s]


Epoch 97/100, Loss: 1.0604, Time: 2m13s


100%|██████████| 24642/24642 [02:19<00:00, 177.11it/s]


Epoch 98/100, Loss: 1.0573, Time: 2m19s


100%|██████████| 24642/24642 [02:17<00:00, 178.74it/s]


Epoch 99/100, Loss: 1.0540, Time: 2m17s


100%|██████████| 24642/24642 [02:16<00:00, 180.62it/s]

Epoch 100/100, Loss: 1.0507, Time: 2m16s





In [6]:
# Save the model
torch.save(model.state_dict(), "models/TORCH_100EPOCHS.pth")
import pickle

with open("models/100_move_to_int", "wb") as file:
    pickle.dump(move_to_int, file)