In [1]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import json

with open('move_to_id.json', 'r') as f:
    move_to_id = json.load(f)
PAD_ID = move_to_id["<PAD>"]

# Load the single tensor
encoded_tensor = torch.load('encoded_games_test.pt')

# Shift for X and Y
X_test = encoded_tensor[:, :-1]
Y_test = encoded_tensor[:, 1:]

test_dataset = TensorDataset(X_test, Y_test)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


In [2]:
import torch.nn as nn

class ChessDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, max_len=200):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=1024)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x shape: [batch, seq_len] → transformer expects [seq_len, batch]
        x = x.transpose(0, 1)
        seq_len, batch_size = x.size()

        # Add embeddings
        positions = torch.arange(seq_len, device=x.device).unsqueeze(1)
        x = self.embed(x) + self.pos_embed(positions)

        # Decoder masking: prevent attention to future tokens
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()

        x = self.decoder(x, x, tgt_mask=mask)
        logits = self.fc_out(x)  # [seq_len, batch, vocab_size]
        return logits.transpose(0, 1)  # [batch, seq_len, vocab_size]

In [3]:
vocab_size = len(move_to_id)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'On device {device}')

with open("best_hparams.json", "r") as f:
    best_hparams = json.load(f)

params = best_hparams["best_params"]

# Rebuild model (same as in training)
model = ChessDecoder(
    vocab_size=vocab_size,
    d_model=params["d_model"],
    nhead=params["nhead"],
    num_layers=params["num_layers"],
    max_len=200 # X_test.size(1)
).to(device)

# Load weights if you saved them
model.load_state_dict(torch.load("500k_model.pt", map_location=device))

model.eval()


On device cuda


ChessDecoder(
  (embed): Embedding(11017, 512)
  (pos_embed): Embedding(200, 512)
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
  

In [4]:
import torch.nn.functional as F

all_preds = []
all_labels = []
total_loss = 0
total_tokens = 0

with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        y = y.to(device)

        logits = model(x)

        loss = F.cross_entropy(
            logits.reshape(-1, vocab_size),
            y.reshape(-1),
            ignore_index=PAD_ID,
            reduction='sum'
        )
        total_loss += loss.item()
        total_tokens += (y != PAD_ID).sum().item()

        preds = torch.argmax(logits, dim=-1)

        # mask out padding so it doesn't affect metrics
        mask = (y != PAD_ID).reshape(-1)
        all_preds.extend(preds.reshape(-1)[mask].cpu().numpy())
        all_labels.extend(y.reshape(-1)[mask].cpu().numpy())

avg_loss = total_loss / total_tokens

In [None]:
from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score

acc = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds, average='macro')  # or 'weighted'
kappa = cohen_kappa_score(all_labels, all_preds)

print(f"Test Loss: {avg_loss:.4f}")
print(f"Accuracy: {acc:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Cohen's Kappa: {kappa:.4f}")


Test Loss: 0.0019
Accuracy: 0.9996
F1 Score: 0.9969
Cohen's Kappa: 0.9996
