In [1]:
import sys
from pathlib import Path

# Dynamically add the 'src' directory to sys.path
project_root = Path().resolve().parent  # Get the parent directory of 'notebooks'
src_path = project_root / "src"
sys.path.append(str(src_path))

In [3]:
import torch
import torch.nn as nn
from transformers import GPT2Tokenizer, GPT2Model, GPT2Config, GPT2LMHeadModel

from torch.utils.data import DataLoader
import torch.optim as optim
from tokenizer import ChessTokenizer

# from dataa import ChessDataset
import gc
import chess
import numpy as np
from model import ChessGPTModel
from chessdata import ChessDataset

In [2]:
tok = ChessTokenizer()
mps_device = torch.device("mps")
cpu_device = torch.device("cpu")

In [9]:
hidden_size = 768
seq_len = 256
n_layer = 16
n_head = 8
# vocab_size = tok.vocabulary_size()
vocab_size = tok.vocabulary_size() #1000
config = GPT2Config(vocab_size=vocab_size, n_positions=seq_len, n_ctx=seq_len, n_embd=hidden_size, n_layer=n_layer, n_head=n_head)

In [10]:
model = GPT2LMHeadModel(config)
model.to(mps_device)

num_of_parameters = sum(map(torch.numel, model.parameters()))
num_of_parameters / 1e6

115.117824

In [5]:
def accuracy(logits, labels):
    preds = torch.argmax(logits, dim=1)
    correct = (preds == labels).sum().item()
    accuracy = correct / labels.size(0)
    return accuracy

def top_k_accuracy(logits, labels, k=5):
    topk_values, topk_indices = torch.topk(logits, k, dim=-1)  # [batch_size, k]
    correct_in_topk = (topk_indices == labels.unsqueeze(-1)).any(dim=-1)
    topk_accuracy = correct_in_topk.float().mean().item()
    return topk_accuracy

def compute_topk_accuracy(logits: torch.Tensor, labels: torch.Tensor, attention_mask: torch.Tensor, pad_token_id: int = 0, k: int = 1) -> float:
    batch_size, seq_len, vocab_size = logits.shape

    # 1) Flatten all except vocab
    logits_flat = logits.view(-1, vocab_size)         # shape: [batch_size*seq_len, vocab_size]
    labels_flat = labels.view(-1)                     # shape: [batch_size*seq_len]
    mask_flat = attention_mask.view(-1).bool()        # shape: [batch_size*seq_len], True=real token, False=pad

    # 2) Further exclude positions where label == pad_token_id
    #    We'll combine both conditions into a single boolean mask
    valid_positions = mask_flat & (labels_flat != pad_token_id)

    if valid_positions.sum() == 0:
        # If there are no valid positions, return 0 or NaN
        return 0.0

    valid_logits = logits_flat[valid_positions]       # shape: [#valid, vocab_size]
    valid_labels = labels_flat[valid_positions]       # shape: [#valid]

    # 3) Get top-k indices along vocab dimension
    topk_values, topk_indices = torch.topk(valid_logits, k, dim=-1)  # [#valid, k]

    # 4) Check if the correct label is in top-k predictions
    correct_in_topk = (topk_indices == valid_labels.unsqueeze(-1)).any(dim=-1)  # [#valid] bool
    accuracy = correct_in_topk.float().mean().item()

    return accuracy

criterion = nn.CrossEntropyLoss()  # Example for classification
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [8]:
num_epochs = 100
model.train()

print_every = 1



for epoch in range(num_epochs):
    
    running_loss = 0.0
    top_1_accuracy = 0.0
    top_5_accuracy = 0.0

    dataset = ChessDataset(file_name="./data/trainingmedium.pgn", tokenizer=tok, max_seq_len=seq_len)
    dataloader = DataLoader(dataset, batch_size=16)

    for i, (cur_token_ids, cur_attn_mask, cur_legal_mask, cur_labels) in enumerate(dataloader):
        optimizer.zero_grad()
        
        cur_attn_mask = cur_attn_mask.to(mps_device)
        cur_token_ids = cur_token_ids.to(mps_device)
        cur_legal_mask = cur_legal_mask.to(mps_device)
        cur_labels = cur_labels.to(mps_device)

        # print(cur_token_ids.shape)
        # print(cur_attn_mask.shape)
        # print(cur_legal_mask.shape)
        # print(cur_labels.shape)

        outputs = model(input_ids = cur_token_ids, attention_mask = cur_attn_mask).logits
        masked_logits = outputs.masked_fill(~cur_legal_mask, float('-1e10'))
        vocab_size = masked_logits.size(-1)
        
        top_1_accuracy += compute_topk_accuracy(masked_logits, cur_labels, cur_attn_mask, k=1)
        top_5_accuracy += compute_topk_accuracy(masked_logits, cur_labels, cur_attn_mask, k=5)

        masked_logits = masked_logits.view(-1, vocab_size)
        labels_flat = cur_labels.view(-1)
        loss = loss_fn(masked_logits, labels_flat)
        running_loss += loss.item()

        del outputs, cur_token_ids, cur_attn_mask, cur_legal_mask, cur_labels

        gc.collect()
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # running_loss += loss.item()
        if (epoch + 1) % 1 == 0 and (i + 1) % print_every == 0:
            print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {loss / print_every:.4f}, Top-1 Accuracy: {top_1_accuracy / print_every:.4f}, Top-5 Accuracy: {top_5_accuracy / print_every:.4f}')
            running_loss = 0.0
            top_1_accuracy = 0.0
            top_5_accuracy = 0.0

Epoch 1, Batch 1, Loss: 1.0164, Top-1 Accuracy: 0.8203, Top-5 Accuracy: 0.9335
Epoch 1, Batch 2, Loss: 7.2724, Top-1 Accuracy: 0.1123, Top-5 Accuracy: 0.3510
Epoch 1, Batch 3, Loss: 6.0269, Top-1 Accuracy: 0.1050, Top-5 Accuracy: 0.3446
Epoch 1, Batch 4, Loss: 4.5995, Top-1 Accuracy: 0.0990, Top-5 Accuracy: 0.3216
Epoch 1, Batch 5, Loss: 3.9492, Top-1 Accuracy: 0.0986, Top-5 Accuracy: 0.3243
Epoch 1, Batch 6, Loss: 3.5850, Top-1 Accuracy: 0.0875, Top-5 Accuracy: 0.3143
Epoch 1, Batch 7, Loss: 3.3668, Top-1 Accuracy: 0.1008, Top-5 Accuracy: 0.3423
Epoch 1, Batch 8, Loss: 3.1903, Top-1 Accuracy: 0.1109, Top-5 Accuracy: 0.3641
Epoch 1, Batch 9, Loss: 3.0954, Top-1 Accuracy: 0.1352, Top-5 Accuracy: 0.4013
Epoch 1, Batch 10, Loss: 3.0108, Top-1 Accuracy: 0.1426, Top-5 Accuracy: 0.4198
Epoch 1, Batch 11, Loss: 3.2845, Top-1 Accuracy: 0.1168, Top-5 Accuracy: 0.3719
Epoch 1, Batch 12, Loss: 3.0997, Top-1 Accuracy: 0.1273, Top-5 Accuracy: 0.4177
Epoch 1, Batch 13, Loss: 3.1335, Top-1 Accuracy: 

KeyboardInterrupt: 

In [7]:
torch.save(model.state_dict(), "./modelv0.pth")

In [None]:
def get_model_output(model, move_list, tokenizer: ChessTokenizer, seq_len=256):
    board = chess.Board()

    token_ids = [0] * seq_len
    attn_mask = [0] * seq_len

    for i in range(len(move_list)):
        move = move_list[i]
        new_move = chess.Move.from_uci(move)
        board.push(new_move)
        token_ids[i] = tokenizer.tokens_to_ids_single(move)
        attn_mask[i] = 1


    legal_moves = list(board.legal_moves)
    legal_moves = [str(move) for move in legal_moves]
    legal_mask = [0] * tokenizer.vocabulary_size()

    legal_moves_ids = tokenizer.tokens_to_ids_vect(legal_moves)

    for id in legal_moves_ids:
        legal_mask[id] = 1
    
    token_ids = torch.tensor(token_ids)
    attn_mask = torch.tensor(attn_mask)
    legal_mask = torch.tensor(legal_mask)

    token_ids = torch.reshape(token_ids, (1, -1))
    attn_mask = torch.reshape(attn_mask, (1, -1))
    legal_mask = torch.reshape(legal_mask, (1, -1))

    token_ids = token_ids.to(mps_device)
    attn_mask = attn_mask.to(mps_device)
    legal_mask = legal_mask.to(mps_device)

    outputs = model(token_ids, attn_mask, legal_mask)

    softmax = nn.Softmax(dim=-1)
    outputs = softmax(outputs)

    outputs = outputs.cpu().detach().numpy()
    outputs = outputs[0]

    to_sort = [(outputs[i], i) for i in range(len(outputs))]
    to_sort.sort(reverse=True)

    print(to_sort)

    for i in range(10):
        print(tokenizer.ids_to_tokens_single(to_sort[i][1]), to_sort[i][0])

    return tokenizer.ids_to_tokens_single(to_sort[0][1])



get_model_output(model, ["e2e4", "e7e5", "f1c4", "b8c6", "b1c3", "g7g6", "g1f3", "d7d5", "d2d3"], tok, seq_len=256)

[(0.63404244, 762), (0.30915236, 1032), (0.00054985494, 1302), (0.00013388555, 29), (0.00010157549, 1301), (9.4130475e-05, 1031), (6.9529495e-05, 495), (6.3696054e-05, 223), (5.403922e-05, 761), (5.3789186e-05, 1567), (5.3702457e-05, 1808), (4.0073075e-05, 30), (3.728253e-05, 1530), (3.5871017e-05, 249), (3.153475e-05, 496), (3.1181306e-05, 248), (2.8857305e-05, 213), (2.8393255e-05, 1970), (2.8393255e-05, 1969), (2.8393255e-05, 1968), (2.8393255e-05, 1967), (2.8393255e-05, 1966), (2.8393255e-05, 1965), (2.8393255e-05, 1964), (2.8393255e-05, 1963), (2.8393255e-05, 1962), (2.8393255e-05, 1961), (2.8393255e-05, 1960), (2.8393255e-05, 1959), (2.8393255e-05, 1958), (2.8393255e-05, 1957), (2.8393255e-05, 1956), (2.8393255e-05, 1955), (2.8393255e-05, 1954), (2.8393255e-05, 1953), (2.8393255e-05, 1952), (2.8393255e-05, 1951), (2.8393255e-05, 1950), (2.8393255e-05, 1949), (2.8393255e-05, 1948), (2.8393255e-05, 1947), (2.8393255e-05, 1946), (2.8393255e-05, 1945), (2.8393255e-05, 1944), (2.83932

'd2d4'