In [1]:
import pandas as pd
import numpy as np
from enum import Enum
import chess
from pathlib import Path
from typing import Tuple
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

In [2]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    # Report the CUDA toolkit version and the GPU model that will be used.
    print("CUDA version:", torch.version.cuda)
    print("GPU:", torch.cuda.get_device_name())

device

CUDA version: 12.9
GPU: NVIDIA GeForce RTX 3070


'cuda'

In [3]:
"""
v1 - 775-1024-600-400-200-100
"""

'\nv1 - 775-1024-600-400-200-100\n'

In [4]:
# Identifier for this run. NOTE(review): not referenced in the visible cells —
# presumably used for bookkeeping elsewhere; confirm.
RUN_ID = "run_2026_01_19_probs_mlp_v1"
# Base name used for the saved checkpoint file (models/<name>.pth) and for the
# experiment log / result directories defined further down.
model_save_name = "probs_mlp_v1"

In [5]:
class ChessDataset(Dataset):
    """Memory-mapped position dataset that yields soft 7-class targets.

    Three arrays are read from ``root_dir``: ``{split}_X.npy`` (features),
    ``{split}_y.npy`` (loaded but not used by ``__getitem__``) and
    ``{split}_scores.npy`` (per-position scores, treated as centipawns).
    Each score is mapped to a continuous class index in [0, 6] and turned
    into a Gaussian-shaped probability distribution over the 7 classes.

    Args:
        root_dir: Directory containing the ``.npy`` files.
        split: File-name prefix, e.g. ``"train"``, ``"val"`` or ``"test"``.
        sigma: Standard deviation (in class-index units) of the Gaussian
            soft target. Must be positive.

    Raises:
        ValueError: If ``sigma`` is not strictly positive.
    """

    def __init__(self, root_dir: Path, split: str, sigma: float = 0.6):
        # A zero sigma would divide by zero when building the soft target and
        # a negative one is meaningless, so fail before touching the disk.
        if not sigma > 0:
            raise ValueError(f"sigma must be positive, got {sigma!r}")

        self.root_dir = Path(root_dir)
        self.split = split
        self.sigma = sigma
        self.num_classes = 7
        self.class_indices = torch.arange(self.num_classes, dtype=torch.float32)
        # mmap_mode='r' keeps the arrays on disk; rows are read lazily per item.
        self.X = np.load(self.root_dir / f"{self.split}_X.npy", mmap_mode='r')
        self.y = np.load(self.root_dir / f"{self.split}_y.npy", mmap_mode='r')
        self.scores = np.load(self.root_dir / f"{self.split}_scores.npy", mmap_mode='r')

    def __len__(self) -> int:
        """Number of samples (rows of the feature array)."""
        return self.X.shape[0]

    def score_to_continuous_index(self, score: float) -> float:
        """
        Maps a centipawn score to a continuous class index in [0.0, 6.0].

        The mapping is piecewise linear with one class per 200cp and is
        saturated at both ends, e.g. 0cp -> 3.0, 300cp -> 1.5, 400cp -> 1.0,
        >=700cp -> 0.0 and <=-700cp -> 6.0. A NaN score fails every
        comparison and falls through to the neutral index 3.0.
        """

        if score >= 500:
            # Fade from 0.5 (at 500) to 0.0 (at 700)
            return max(0.0, 0.5 - (score - 500) / 200.0)

        if score <= -500:
            # Fade from 5.5 (at -500) to 6.0 (at -700)
            return min(6.0, 5.5 + (-500 - score) / 200.0)

        # Interpolate the middle classes
        # 300 to 500  -> Maps to 1.5 to 0.5
        if score >= 300:
            return 1.5 - (score - 300) / 200.0
        # 100 to 300  -> Maps to 2.5 to 1.5
        if score >= 100:
            return 2.5 - (score - 100) / 200.0
        # -100 to 100 -> Maps to 3.5 to 2.5
        if score >= -100:
            return 3.5 - (score - (-100)) / 200.0
        # -300 to -100 -> Maps to 4.5 to 3.5
        if score >= -300:
            return 4.5 - (score - (-300)) / 200.0
        # -500 to -300 -> Maps to 5.5 to 4.5
        if score > -500:
            return 5.5 - (score - (-500)) / 200.0

        # Only reachable when every comparison above is False (NaN score).
        return 3.0

    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(features, soft_target)`` for sample ``idx``.

        ``features`` is the ``idx``-th row of X as float32; ``soft_target``
        is a length-7 float32 distribution that sums to 1.
        """
        score = self.scores[idx].item()
        target_idx = self.score_to_continuous_index(score)

        # Gaussian log-weights centred at target_idx. Normalising with softmax
        # equals exp(logits) / exp(logits).sum() but subtracts the max first,
        # so a small sigma cannot underflow every weight to 0 and yield NaN.
        logits = -((self.class_indices - target_idx) ** 2) / (2 * self.sigma ** 2)
        soft_target = torch.softmax(logits, dim=0)

        x_tensor = torch.tensor(self.X[idx], dtype=torch.float32)

        return x_tensor, soft_target

In [6]:
BATCH_SIZE = 512
num_workers = 0
ROOT_DIR = Path("./dataset_bitmaps_cp/")


def _build_split(split: str, shuffle: bool):
    """Create the ChessDataset for `split` and a DataLoader wrapping it."""
    dataset = ChessDataset(root_dir=ROOT_DIR, split=split)
    loader = DataLoader(dataset=dataset,
                        batch_size=BATCH_SIZE,
                        num_workers=num_workers,
                        shuffle=shuffle,
                        pin_memory=True)
    return dataset, loader


# Only the training split is shuffled; validation and test keep file order.
train_dataset, train_dataloader = _build_split("train", shuffle=True)
val_dataset, val_dataloader = _build_split("val", shuffle=False)
test_dataset, test_dataloader = _build_split("test", shuffle=False)

In [7]:
import time

# Time how long it takes to pull batches from the training loader.
# The counter is incremented *after* a batch has been fetched and the loop
# stops right there, so exactly N_TIMED_BATCHES batches are loaded (or fewer,
# if the loader is shorter) and the average divides by the real count.
N_TIMED_BATCHES = 100

start = time.perf_counter()
batches_loaded = 0
for X, y in train_dataloader:
    batches_loaded += 1
    if batches_loaded == N_TIMED_BATCHES:
        break
elapsed = time.perf_counter() - start
print("Avg batch load time:", elapsed / max(batches_loaded, 1))


Avg batch load time: 0.03596267938613892


In [8]:
# Pull a single training batch and report the shape/dtype of both tensors.
Xb, yb = next(iter(train_dataloader))
for tensor_name, tensor in (("X", Xb), ("y", yb)):
    print(f"{tensor_name} batch shape:", tensor.shape, "dtype:", tensor.dtype)

X batch shape: torch.Size([512, 775]) dtype: torch.float32
y batch shape: torch.Size([512, 7]) dtype: torch.float32


In [9]:
class PositionLabel(Enum):
    """Seven ordered evaluation classes, from White winning (0) to Black winning (6).

    NOTE(review): the integer values presumably line up with the 7 class
    indices used for the model targets/outputs (ChessDataset maps high
    positive scores towards index 0 and high negative scores towards
    index 6) — the enum is not referenced in the visible cells; confirm.
    """
    WHITE_WINNING = 0
    WHITE_DECISIVE = 1
    WHITE_BETTER = 2
    EQUAL = 3
    BLACK_BETTER = 4
    BLACK_DECISIVE = 5
    BLACK_WINNING = 6

In [15]:
class MLP(nn.Module):
    """Fully connected classifier: five hidden stages followed by a linear head.

    Every hidden stage is Linear -> BatchNorm1d -> ReLU -> Dropout. The hidden
    widths are 1024, 600, 400, 256 and 128; the last stage uses a dropout
    probability of 0.5, the others 0.2. The head emits raw logits.
    """

    def __init__(self, input_shape=775, output_shape=7):
        super().__init__()

        # (hidden width, dropout probability) for each stage, in forward order.
        hidden_stages = [(1024, 0.2), (600, 0.2), (400, 0.2), (256, 0.2), (128, 0.5)]

        layers = []
        in_features = input_shape
        for width, drop_p in hidden_stages:
            layers += [
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(drop_p),
            ]
            in_features = width

        # Output head: one logit per class, no activation.
        layers.append(nn.Linear(in_features, output_shape))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, output_shape)."""
        return self.network(x)

In [11]:
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               scaler: torch.amp.GradScaler,
               device=device) -> tuple[float, float]:
    """
    Performs one training epoch for the given model using mixed precision.

    Args:
        model: Network to train (switched to train mode here).
        dataloader: Yields (X, y) batches; y holds one probability row per
            sample and its argmax is used as the "true" class for accuracy.
        loss_fn: Loss applied to (logits, y).
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scaler: Gradient scaler used for the backward pass and the step.
        device: Target device; a string such as "cuda", "cuda:0" or "cpu",
            or a ``torch.device``.

    Returns:
        (loss, accuracy), each the unweighted mean of the per-batch values.
    """

    # Put model in train mode
    model.train()

    # autocast expects the bare device *type* ("cuda"/"cpu"). Normalising via
    # torch.device keeps indexed devices ("cuda:0") and torch.device objects working.
    autocast_device_type = torch.device(device).type

    train_loss, train_acc = 0, 0

    for X, y in dataloader:
        # X: (batch, num_features), y: (batch, num_classes)
        X, y = X.to(device), y.to(device)

        # Reset gradients
        optimizer.zero_grad()

        # Forward pass under autocast
        with torch.amp.autocast(autocast_device_type):
            y_pred = model(X)
            loss = loss_fn(y_pred, y)

        # Backpropagation on the scaled loss
        scaler.scale(loss).backward()

        # Update weights, then let the scaler adjust its scale factor
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

        # Accuracy: argmax over the class dimension of the logits ...
        y_pred_class = torch.argmax(y_pred, dim=-1)  # shape (batch,)

        # ... compared with the argmax of the soft target.
        # Remove for non prob ablation
        y_true_class = torch.argmax(y, dim=-1)
        train_acc += (y_pred_class == y_true_class).sum().item()/len(y_pred)

    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)

    return train_loss, train_acc

In [12]:
def eval_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device=device) -> tuple[float, float]:
    """
    Scores the model on a validation or test dataloader; no weights are updated.

    The model is switched to eval mode and every batch is run under
    ``torch.inference_mode()``. Accuracy compares the argmax of the logits
    with the argmax of the target rows.

    Returns:
        (loss, accuracy), each the unweighted mean of the per-batch values.
    """
    model.eval()

    running_loss = 0
    running_acc = 0

    with torch.inference_mode():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)

            logits = model(X)
            running_loss += loss_fn(logits, y).item()

            predicted = torch.argmax(logits, dim=1)
            # Remove for non prob ablation
            expected = torch.argmax(y, dim=1)
            running_acc += (predicted == expected).sum().item()/len(predicted)

    n_batches = len(dataloader)
    return running_loss / n_batches, running_acc / n_batches

In [13]:
from tqdm import tqdm
import copy

def run_experiment(model: torch.nn.Module,
                   model_save_name: str,
                   train_dataloader: torch.utils.data.DataLoader,
                   val_dataloader: torch.utils.data.DataLoader,
                   loss_fn: torch.nn.Module,
                   optimizer: torch.optim.Optimizer,
                   scaler: torch.amp.GradScaler,
                   epochs: int,
                   patience: int,
                   device=device):
    """
    Trains ``model`` for up to ``epochs`` epochs with early stopping on
    validation accuracy and a cosine-annealed learning rate.

    Whenever validation accuracy strictly improves, the weights are kept in
    memory and written to ``models/{model_save_name}.pth`` (the directory is
    created if missing). Training stops early after ``patience`` consecutive
    epochs without improvement. On return the model holds the best weights
    seen, if any epoch improved on the initial 0.0.

    Returns:
        Dict with the per-epoch lists "train_loss", "train_acc", "val_loss"
        and "val_acc".
    """

    results = {"train_loss": [],
               "train_acc": [],
               "val_loss": [],
               "val_acc": []}

    best_val_acc = 0.0
    best_model_weights = None
    patience_counter = 0

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=epochs,
        eta_min=1e-6
    )

    # Make sure the checkpoint directory exists up front, so the first
    # improvement cannot fail at save time after a full epoch of training.
    checkpoint_path = Path("models") / f"{model_save_name}.pth"
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Starting Training: {model_save_name}")

    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer,
                                           device=device,
                                           scaler=scaler)
        val_loss, val_acc = eval_step(model=model,
                                      dataloader=val_dataloader,
                                      loss_fn=loss_fn,
                                      device=device)

        # Cosine schedule advances once per epoch, independent of metrics.
        scheduler.step()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Deep copy: state_dict() returns references to the live tensors.
            best_model_weights = copy.deepcopy(model.state_dict())
            patience_counter = 0

            print(f"Epoch: {epoch} | New Best Val Acc: {val_acc:.4f} (Saved)")
            torch.save(model.state_dict(), checkpoint_path)
        else:
            patience_counter += 1
            print(f"Epoch: {epoch} | No improvement. Patience {patience_counter}/{patience}")

        print(f"Epoch: {epoch} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["val_loss"].append(val_loss)
        results["val_acc"].append(val_acc)

        if patience_counter >= patience:
            print(f"\n[Early Stopping] No improvement for {patience} epochs. Stopping.")
            break

    if best_model_weights is not None:
        model.load_state_dict(best_model_weights)
        print(f"\nLoaded best model weights with Val Acc: {best_val_acc:.4f}")

    return results

In [17]:
from torchinfo import summary

# Seed *before* building the model: the Linear layers draw their initial
# weights from the global RNG at construction time, and the only other seed
# call in the notebook lives in the later training cell, after the model
# already exists.
torch.manual_seed(42)

model = MLP(input_shape=775,
            output_shape=7).to(device)

# Layer-by-layer overview for one batch of BATCH_SIZE samples with 775 features.
summary(model, input_size=(BATCH_SIZE,775))


Layer (type:depth-idx)                   Output Shape              Param #
MLP                                      [512, 7]                  --
├─Sequential: 1-1                        [512, 7]                  --
│    └─Linear: 2-1                       [512, 1024]               794,624
│    └─BatchNorm1d: 2-2                  [512, 1024]               2,048
│    └─ReLU: 2-3                         [512, 1024]               --
│    └─Dropout: 2-4                      [512, 1024]               --
│    └─Linear: 2-5                       [512, 600]                615,000
│    └─BatchNorm1d: 2-6                  [512, 600]                1,200
│    └─ReLU: 2-7                         [512, 600]                --
│    └─Dropout: 2-8                      [512, 600]                --
│    └─Linear: 2-9                       [512, 400]                240,400
│    └─BatchNorm1d: 2-10                 [512, 400]                800
│    └─ReLU: 2-11                        [512, 400]            

In [None]:
# Fix the RNG state right before training starts.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

NUM_EPOCHS = 100

# Loss, optimizer and gradient scaler for the mixed-precision training loop.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.001, weight_decay=0.01)
scaler = torch.amp.GradScaler("cuda")

# patience is larger than the epoch budget, so early stopping never triggers.
result = run_experiment(
    model=model,
    model_save_name=model_save_name,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scaler=scaler,
    epochs=NUM_EPOCHS,
    patience=NUM_EPOCHS + 1,
    device=device,
)

Starting Training: probs_mlp_v1


  1%|          | 1/100 [06:08<10:07:37, 368.25s/it]

Epoch: 0 | New Best Val Acc: 0.4803 (Saved)
Epoch: 0 | Train Loss: 1.4833 | Val Loss: 1.4093 | Val Acc: 0.4803


  2%|▏         | 2/100 [12:20<10:05:28, 370.70s/it]

Epoch: 1 | New Best Val Acc: 0.5092 (Saved)
Epoch: 1 | Train Loss: 1.3955 | Val Loss: 1.3664 | Val Acc: 0.5092


  3%|▎         | 3/100 [18:32<9:59:50, 371.04s/it] 

Epoch: 2 | New Best Val Acc: 0.5237 (Saved)
Epoch: 2 | Train Loss: 1.3596 | Val Loss: 1.3467 | Val Acc: 0.5237


  4%|▍         | 4/100 [24:47<9:56:10, 372.61s/it]

Epoch: 3 | New Best Val Acc: 0.5304 (Saved)
Epoch: 3 | Train Loss: 1.3381 | Val Loss: 1.3336 | Val Acc: 0.5304


  5%|▌         | 5/100 [30:54<9:46:44, 370.58s/it]

Epoch: 4 | New Best Val Acc: 0.5365 (Saved)
Epoch: 4 | Train Loss: 1.3236 | Val Loss: 1.3268 | Val Acc: 0.5365


  6%|▌         | 6/100 [37:03<9:39:56, 370.17s/it]

Epoch: 5 | New Best Val Acc: 0.5395 (Saved)
Epoch: 5 | Train Loss: 1.3130 | Val Loss: 1.3231 | Val Acc: 0.5395


  7%|▋         | 7/100 [43:14<9:34:13, 370.47s/it]

Epoch: 6 | New Best Val Acc: 0.5447 (Saved)
Epoch: 6 | Train Loss: 1.3046 | Val Loss: 1.3165 | Val Acc: 0.5447


  8%|▊         | 8/100 [49:25<9:28:03, 370.47s/it]

Epoch: 7 | New Best Val Acc: 0.5456 (Saved)
Epoch: 7 | Train Loss: 1.2980 | Val Loss: 1.3131 | Val Acc: 0.5456


  9%|▉         | 9/100 [55:36<9:22:08, 370.64s/it]

Epoch: 8 | New Best Val Acc: 0.5509 (Saved)
Epoch: 8 | Train Loss: 1.2927 | Val Loss: 1.3070 | Val Acc: 0.5509


 10%|█         | 10/100 [1:01:47<9:16:17, 370.86s/it]

Epoch: 9 | New Best Val Acc: 0.5532 (Saved)
Epoch: 9 | Train Loss: 1.2877 | Val Loss: 1.3033 | Val Acc: 0.5532


 11%|█         | 11/100 [1:07:59<9:10:36, 371.20s/it]

Epoch: 10 | New Best Val Acc: 0.5544 (Saved)
Epoch: 10 | Train Loss: 1.2838 | Val Loss: 1.3016 | Val Acc: 0.5544


 12%|█▏        | 12/100 [1:14:14<9:06:08, 372.37s/it]

Epoch: 11 | New Best Val Acc: 0.5549 (Saved)
Epoch: 11 | Train Loss: 1.2802 | Val Loss: 1.3012 | Val Acc: 0.5549


 13%|█▎        | 13/100 [1:20:26<8:59:51, 372.31s/it]

Epoch: 12 | New Best Val Acc: 0.5559 (Saved)
Epoch: 12 | Train Loss: 1.2770 | Val Loss: 1.2981 | Val Acc: 0.5559


 14%|█▍        | 14/100 [1:26:39<8:53:43, 372.37s/it]

Epoch: 13 | New Best Val Acc: 0.5568 (Saved)
Epoch: 13 | Train Loss: 1.2744 | Val Loss: 1.2977 | Val Acc: 0.5568


 15%|█▌        | 15/100 [1:32:51<8:47:20, 372.24s/it]

Epoch: 14 | New Best Val Acc: 0.5579 (Saved)
Epoch: 14 | Train Loss: 1.2716 | Val Loss: 1.2963 | Val Acc: 0.5579


 16%|█▌        | 16/100 [1:39:05<8:41:54, 372.79s/it]

Epoch: 15 | New Best Val Acc: 0.5612 (Saved)
Epoch: 15 | Train Loss: 1.2697 | Val Loss: 1.2923 | Val Acc: 0.5612


 17%|█▋        | 17/100 [1:45:17<8:35:32, 372.68s/it]

Epoch: 16 | New Best Val Acc: 0.5620 (Saved)
Epoch: 16 | Train Loss: 1.2674 | Val Loss: 1.2895 | Val Acc: 0.5620


 18%|█▊        | 18/100 [1:51:31<8:29:54, 373.11s/it]

Epoch: No improvement. Patience 1/101
Epoch: 17 | Train Loss: 1.2652 | Val Loss: 1.2897 | Val Acc: 0.5611


 19%|█▉        | 19/100 [1:57:45<8:24:08, 373.43s/it]

Epoch: 18 | New Best Val Acc: 0.5625 (Saved)
Epoch: 18 | Train Loss: 1.2634 | Val Loss: 1.2882 | Val Acc: 0.5625


 20%|██        | 20/100 [2:03:51<8:14:36, 370.95s/it]

Epoch: 19 | New Best Val Acc: 0.5641 (Saved)
Epoch: 19 | Train Loss: 1.2617 | Val Loss: 1.2868 | Val Acc: 0.5641


 21%|██        | 21/100 [2:10:12<8:12:37, 374.14s/it]

Epoch: No improvement. Patience 1/101
Epoch: 20 | Train Loss: 1.2600 | Val Loss: 1.2885 | Val Acc: 0.5636


 22%|██▏       | 22/100 [2:16:24<8:05:23, 373.37s/it]

Epoch: 21 | New Best Val Acc: 0.5658 (Saved)
Epoch: 21 | Train Loss: 1.2580 | Val Loss: 1.2844 | Val Acc: 0.5658


 23%|██▎       | 23/100 [2:22:40<8:00:19, 374.28s/it]

Epoch: 22 | New Best Val Acc: 0.5659 (Saved)
Epoch: 22 | Train Loss: 1.2568 | Val Loss: 1.2840 | Val Acc: 0.5659


 24%|██▍       | 24/100 [2:28:57<7:55:03, 375.04s/it]

Epoch: No improvement. Patience 1/101
Epoch: 23 | Train Loss: 1.2551 | Val Loss: 1.2828 | Val Acc: 0.5658


 25%|██▌       | 25/100 [2:35:10<7:47:54, 374.33s/it]

Epoch: 24 | New Best Val Acc: 0.5669 (Saved)
Epoch: 24 | Train Loss: 1.2536 | Val Loss: 1.2813 | Val Acc: 0.5669


 26%|██▌       | 26/100 [2:41:30<7:44:00, 376.22s/it]

Epoch: 25 | New Best Val Acc: 0.5690 (Saved)
Epoch: 25 | Train Loss: 1.2524 | Val Loss: 1.2801 | Val Acc: 0.5690


 27%|██▋       | 27/100 [2:47:51<7:39:25, 377.61s/it]

Epoch: 26 | New Best Val Acc: 0.5693 (Saved)
Epoch: 26 | Train Loss: 1.2508 | Val Loss: 1.2790 | Val Acc: 0.5693


 28%|██▊       | 28/100 [2:53:56<7:28:27, 373.71s/it]

Epoch: No improvement. Patience 1/101
Epoch: 27 | Train Loss: 1.2496 | Val Loss: 1.2813 | Val Acc: 0.5677


 29%|██▉       | 29/100 [2:59:36<7:10:33, 363.85s/it]

Epoch: No improvement. Patience 2/101
Epoch: 28 | Train Loss: 1.2481 | Val Loss: 1.2790 | Val Acc: 0.5693


 30%|███       | 30/100 [3:05:54<7:09:06, 367.81s/it]

Epoch: 29 | New Best Val Acc: 0.5706 (Saved)
Epoch: 29 | Train Loss: 1.2466 | Val Loss: 1.2767 | Val Acc: 0.5706


 31%|███       | 31/100 [3:12:09<7:05:47, 370.25s/it]

Epoch: 30 | New Best Val Acc: 0.5718 (Saved)
Epoch: 30 | Train Loss: 1.2453 | Val Loss: 1.2765 | Val Acc: 0.5718


 32%|███▏      | 32/100 [3:18:24<7:00:58, 371.45s/it]

Epoch: 31 | New Best Val Acc: 0.5723 (Saved)
Epoch: 31 | Train Loss: 1.2440 | Val Loss: 1.2756 | Val Acc: 0.5723


 33%|███▎      | 33/100 [3:24:39<6:55:57, 372.51s/it]

Epoch: No improvement. Patience 1/101
Epoch: 32 | Train Loss: 1.2426 | Val Loss: 1.2751 | Val Acc: 0.5719


 34%|███▍      | 34/100 [3:30:53<6:50:21, 373.06s/it]

Epoch: 33 | New Best Val Acc: 0.5738 (Saved)
Epoch: 33 | Train Loss: 1.2414 | Val Loss: 1.2710 | Val Acc: 0.5738


 35%|███▌      | 35/100 [3:37:07<6:44:20, 373.24s/it]

Epoch: 34 | New Best Val Acc: 0.5741 (Saved)
Epoch: 34 | Train Loss: 1.2403 | Val Loss: 1.2718 | Val Acc: 0.5741


 36%|███▌      | 36/100 [3:43:21<6:38:24, 373.50s/it]

Epoch: No improvement. Patience 1/101
Epoch: 35 | Train Loss: 1.2388 | Val Loss: 1.2726 | Val Acc: 0.5737


 37%|███▋      | 37/100 [3:49:35<6:32:18, 373.63s/it]

Epoch: No improvement. Patience 2/101
Epoch: 36 | Train Loss: 1.2375 | Val Loss: 1.2725 | Val Acc: 0.5730


 38%|███▊      | 38/100 [3:55:38<6:23:00, 370.65s/it]

Epoch: 37 | New Best Val Acc: 0.5753 (Saved)
Epoch: 37 | Train Loss: 1.2361 | Val Loss: 1.2706 | Val Acc: 0.5753


 39%|███▉      | 39/100 [4:01:09<6:04:30, 358.53s/it]

Epoch: 38 | New Best Val Acc: 0.5759 (Saved)
Epoch: 38 | Train Loss: 1.2347 | Val Loss: 1.2687 | Val Acc: 0.5759


 40%|████      | 40/100 [4:07:18<6:01:46, 361.78s/it]

Epoch: No improvement. Patience 1/101
Epoch: 39 | Train Loss: 1.2333 | Val Loss: 1.2697 | Val Acc: 0.5752


 41%|████      | 41/100 [4:13:08<5:52:15, 358.24s/it]

Epoch: No improvement. Patience 2/101
Epoch: 40 | Train Loss: 1.2324 | Val Loss: 1.2706 | Val Acc: 0.5742


 42%|████▏     | 42/100 [4:18:50<5:41:26, 353.22s/it]

Epoch: 41 | New Best Val Acc: 0.5783 (Saved)
Epoch: 41 | Train Loss: 1.2310 | Val Loss: 1.2677 | Val Acc: 0.5783


 43%|████▎     | 43/100 [4:24:26<5:30:51, 348.26s/it]

Epoch: No improvement. Patience 1/101
Epoch: 42 | Train Loss: 1.2297 | Val Loss: 1.2687 | Val Acc: 0.5776


 44%|████▍     | 44/100 [4:30:34<5:30:33, 354.17s/it]

Epoch: 43 | New Best Val Acc: 0.5785 (Saved)
Epoch: 43 | Train Loss: 1.2283 | Val Loss: 1.2654 | Val Acc: 0.5785


 45%|████▌     | 45/100 [4:36:00<5:16:45, 345.56s/it]

Epoch: No improvement. Patience 1/101
Epoch: 44 | Train Loss: 1.2272 | Val Loss: 1.2681 | Val Acc: 0.5767


 46%|████▌     | 46/100 [4:41:22<5:04:43, 338.58s/it]

Epoch: 45 | New Best Val Acc: 0.5796 (Saved)
Epoch: 45 | Train Loss: 1.2257 | Val Loss: 1.2657 | Val Acc: 0.5796


 47%|████▋     | 47/100 [4:47:19<5:04:01, 344.19s/it]

Epoch: No improvement. Patience 1/101
Epoch: 46 | Train Loss: 1.2243 | Val Loss: 1.2663 | Val Acc: 0.5783


 48%|████▊     | 48/100 [4:53:07<4:59:06, 345.13s/it]

Epoch: 47 | New Best Val Acc: 0.5804 (Saved)
Epoch: 47 | Train Loss: 1.2228 | Val Loss: 1.2652 | Val Acc: 0.5804


 49%|████▉     | 49/100 [4:59:07<4:57:10, 349.61s/it]

Epoch: 48 | New Best Val Acc: 0.5807 (Saved)
Epoch: 48 | Train Loss: 1.2213 | Val Loss: 1.2642 | Val Acc: 0.5807


 50%|█████     | 50/100 [5:05:02<4:52:51, 351.42s/it]

Epoch: No improvement. Patience 1/101
Epoch: 49 | Train Loss: 1.2201 | Val Loss: 1.2678 | Val Acc: 0.5788


 51%|█████     | 51/100 [5:10:55<4:47:24, 351.93s/it]

Epoch: No improvement. Patience 2/101
Epoch: 50 | Train Loss: 1.2184 | Val Loss: 1.2637 | Val Acc: 0.5801


 52%|█████▏    | 52/100 [5:16:26<4:36:31, 345.65s/it]

Epoch: 51 | New Best Val Acc: 0.5814 (Saved)
Epoch: 51 | Train Loss: 1.2167 | Val Loss: 1.2624 | Val Acc: 0.5814


 53%|█████▎    | 53/100 [5:22:14<4:31:11, 346.20s/it]

Epoch: No improvement. Patience 1/101
Epoch: 52 | Train Loss: 1.2157 | Val Loss: 1.2634 | Val Acc: 0.5812


 54%|█████▍    | 54/100 [5:28:24<4:30:50, 353.28s/it]

Epoch: No improvement. Patience 2/101
Epoch: 53 | Train Loss: 1.2139 | Val Loss: 1.2644 | Val Acc: 0.5810


 55%|█████▌    | 55/100 [5:34:39<4:29:50, 359.79s/it]

Epoch: No improvement. Patience 3/101
Epoch: 54 | Train Loss: 1.2128 | Val Loss: 1.2645 | Val Acc: 0.5806


 56%|█████▌    | 56/100 [5:40:52<4:26:47, 363.80s/it]

Epoch: 55 | New Best Val Acc: 0.5830 (Saved)
Epoch: 55 | Train Loss: 1.2108 | Val Loss: 1.2612 | Val Acc: 0.5830


 57%|█████▋    | 57/100 [5:46:36<4:16:34, 358.02s/it]

Epoch: 56 | New Best Val Acc: 0.5838 (Saved)
Epoch: 56 | Train Loss: 1.2094 | Val Loss: 1.2628 | Val Acc: 0.5838


 58%|█████▊    | 58/100 [5:52:17<4:07:00, 352.87s/it]

Epoch: 57 | New Best Val Acc: 0.5846 (Saved)
Epoch: 57 | Train Loss: 1.2082 | Val Loss: 1.2602 | Val Acc: 0.5846


 59%|█████▉    | 59/100 [5:57:54<3:57:54, 348.15s/it]

Epoch: No improvement. Patience 1/101
Epoch: 58 | Train Loss: 1.2063 | Val Loss: 1.2601 | Val Acc: 0.5839


 60%|██████    | 60/100 [6:03:31<3:49:50, 344.76s/it]

Epoch: No improvement. Patience 2/101
Epoch: 59 | Train Loss: 1.2047 | Val Loss: 1.2606 | Val Acc: 0.5840


 61%|██████    | 61/100 [6:09:22<3:45:11, 346.44s/it]

Epoch: 60 | New Best Val Acc: 0.5855 (Saved)
Epoch: 60 | Train Loss: 1.2033 | Val Loss: 1.2598 | Val Acc: 0.5855


 62%|██████▏   | 62/100 [6:15:18<3:41:23, 349.55s/it]

Epoch: No improvement. Patience 1/101
Epoch: 61 | Train Loss: 1.2012 | Val Loss: 1.2587 | Val Acc: 0.5850


 63%|██████▎   | 63/100 [6:21:15<3:36:51, 351.68s/it]

Epoch: 62 | New Best Val Acc: 0.5867 (Saved)
Epoch: 62 | Train Loss: 1.1998 | Val Loss: 1.2576 | Val Acc: 0.5867


 64%|██████▍   | 64/100 [6:27:29<3:35:03, 358.44s/it]

Epoch: No improvement. Patience 1/101
Epoch: 63 | Train Loss: 1.1983 | Val Loss: 1.2601 | Val Acc: 0.5853


 65%|██████▌   | 65/100 [6:33:43<3:31:42, 362.94s/it]

Epoch: 64 | New Best Val Acc: 0.5868 (Saved)
Epoch: 64 | Train Loss: 1.1967 | Val Loss: 1.2569 | Val Acc: 0.5868


 66%|██████▌   | 66/100 [6:40:00<3:28:10, 367.38s/it]

Epoch: No improvement. Patience 1/101
Epoch: 65 | Train Loss: 1.1952 | Val Loss: 1.2588 | Val Acc: 0.5854


 67%|██████▋   | 67/100 [6:46:16<3:23:26, 369.90s/it]

Epoch: No improvement. Patience 2/101
Epoch: 66 | Train Loss: 1.1933 | Val Loss: 1.2588 | Val Acc: 0.5867


 68%|██████▊   | 68/100 [6:52:27<3:17:28, 370.25s/it]

Epoch: 67 | New Best Val Acc: 0.5870 (Saved)
Epoch: 67 | Train Loss: 1.1912 | Val Loss: 1.2586 | Val Acc: 0.5870


 69%|██████▉   | 69/100 [6:58:47<3:12:48, 373.19s/it]

Epoch: 68 | New Best Val Acc: 0.5870 (Saved)
Epoch: 68 | Train Loss: 1.1897 | Val Loss: 1.2584 | Val Acc: 0.5870


 70%|███████   | 70/100 [7:05:01<3:06:39, 373.31s/it]

Epoch: 69 | New Best Val Acc: 0.5885 (Saved)
Epoch: 69 | Train Loss: 1.1877 | Val Loss: 1.2567 | Val Acc: 0.5885


 71%|███████   | 71/100 [7:10:24<2:53:08, 358.21s/it]

Epoch: No improvement. Patience 1/101
Epoch: 70 | Train Loss: 1.1862 | Val Loss: 1.2597 | Val Acc: 0.5869


 72%|███████▏  | 72/100 [7:16:17<2:46:25, 356.63s/it]

Epoch: No improvement. Patience 2/101
Epoch: 71 | Train Loss: 1.1844 | Val Loss: 1.2603 | Val Acc: 0.5864


 73%|███████▎  | 73/100 [7:22:13<2:40:24, 356.45s/it]

Epoch: No improvement. Patience 3/101
Epoch: 72 | Train Loss: 1.1828 | Val Loss: 1.2591 | Val Acc: 0.5864


In [None]:
# Piece symbol -> bitboard plane index: the six uppercase symbols take planes
# 0-5 and the six lowercase symbols take planes 6-11, both in the order
# pawn, knight, bishop, rook, queen, king.
piece_to_index = {symbol: plane for plane, symbol in enumerate("PNBRQKpnbrqk")}

In [None]:
def fen_to_vector(fen: str) -> np.ndarray:
    """
    Encodes a FEN string as a 775-element uint8 vector.

    Layout:
        0-767  twelve 64-square planes, plane = piece_to_index[symbol]
        768    side to move (1 = White, 0 = Black)
        769    White kingside castling right
        770    White queenside castling right
        771    Black kingside castling right
        772    Black queenside castling right
        773    an en-passant square is set on the board
        774    side to move is in check
    """
    board = chess.Board(fen)
    vector = np.zeros(775, dtype=np.uint8)

    # One-hot piece placement.
    for square, piece in board.piece_map().items():
        plane = piece_to_index[piece.symbol()]
        vector[plane * 64 + square] = 1

    # Game-state flags, in the order of slots 768..774.
    state_flags = [
        board.turn == chess.WHITE,
        board.has_kingside_castling_rights(chess.WHITE),
        board.has_queenside_castling_rights(chess.WHITE),
        board.has_kingside_castling_rights(chess.BLACK),
        board.has_queenside_castling_rights(chess.BLACK),
        board.ep_square is not None,
        board.is_check(),
    ]
    vector[768:775] = state_flags

    return vector

In [None]:
def check_model_prediction(model: torch.nn.Module,
                           random_fen: str,
                           fen_class: int,
                           device=device):
    """
    Prints the model's predicted class for one FEN next to a reference class.

    Args:
        model: Trained classifier; its train/eval mode is restored on exit.
        random_fen: Position to evaluate, as a FEN string.
        fen_class: Reference class index supplied by the caller.
        device: Device the input tensor is moved to.
    """

    numpy_fen = fen_to_vector(random_fen)
    torch_fen = torch.tensor(numpy_fen, dtype=torch.float32).unsqueeze(0).to(device)

    # A single-sample batch must run in eval mode: BatchNorm1d raises on one
    # sample per channel while training, and active Dropout would make the
    # prediction random. Remember the current mode and restore it afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            pred = torch.argmax(model(torch_fen), dim=-1)
    finally:
        model.train(was_training)

    print("Model Prediction: ", pred.item())
    print("Stockfish Evaluation: ", fen_class)

# Position to spot-check. The reference class is not derived from the FEN:
# remember to update fen_class manually whenever random_fen changes.
random_fen = "r3kb1r/p2b1pp1/2p1pq1p/P2n4/3P4/1Q3N2/1PP2PPP/R1B2RK1 b kq - 2 14"
fen_class = 5

check_model_prediction(model=model,
                       random_fen=random_fen,
                       fen_class=fen_class,
                       device=device)

# Leave the board as the cell's last expression so the notebook displays it.
board = chess.Board(random_fen)
board

In [None]:
from sklearn.metrics import classification_report

# Collect hard predictions and hard labels over the whole validation set.
all_preds = []
all_labels = []

model.eval()
with torch.inference_mode():
    for X, y in val_dataloader:
        X = X.to(device)
        y = y.to(device)

        # Predicted class = argmax of the logits.
        batch_preds = torch.argmax(model(X), dim=1)
        # True class = argmax of the soft target row.
        batch_labels = torch.argmax(y, dim=1)

        all_preds.extend(batch_preds.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())

print(classification_report(all_labels, all_preds))


In [None]:
import json
import time

# Default output directory for save_config_metadata and save_training_logs.
LOGS_DIR = f"experiments/logs/{model_save_name}"

def save_config_metadata(experiment_name: str, 
                         model: torch.nn.Module, 
                         hyperparams: dict, 
                         dataset_paths: dict,
                         save_dir: str = LOGS_DIR):
    """
    Saves all 'static' setup details: Model architecture, parameter counts, 
    datasets used, and hyperparameters.

    Args:
        experiment_name: Prefix of the output file name.
        model: Model whose class name, parameter counts and ``str(model)``
            are recorded.
        hyperparams: Free-form hyperparameter dict; "input_shape" and
            "output_shape" are also echoed into the architecture section
            ("unknown" when absent).
        dataset_paths: Free-form description of the datasets used.
        save_dir: Output directory, created if missing.

    The result is written to ``{save_dir}/{experiment_name}_config.json``.
    Values JSON cannot encode natively (e.g. ``Path`` objects) are stored
    as their ``str()`` form.
    """
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Model Metadata
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    config_data = {
        "experiment_name": experiment_name,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "model_architecture": {
            "class_name": model.__class__.__name__,
            "total_parameters": total_params,
            "trainable_parameters": trainable_params,
            "input_dim": hyperparams.get("input_shape", "unknown"),
            "output_dim": hyperparams.get("output_shape", "unknown"),
            "structure_summary": str(model)
        },
        "datasets": dataset_paths,
        "hyperparameters": hyperparams,
        "device": torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu"
    }

    # Serialise fully before opening the file, so a failure cannot leave a
    # truncated config on disk. default=str covers Path and similar values.
    payload = json.dumps(config_data, indent=4, default=str)

    file_path = f"{save_dir}/{experiment_name}_config.json"
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(payload)
    
    print(f"[Config] Saved metadata to {file_path}")
    
def save_training_logs(experiment_name: str, 
                       results_dict: dict, 
                       save_dir: str = LOGS_DIR):
    """
    Writes the per-epoch training history to
    ``{save_dir}/{experiment_name}_learning_curves.csv``.

    ``results_dict`` is expected to be the dict returned by run_experiment
    (equal-length lists keyed by metric name). When it has no "epoch" key,
    a 1-based "epoch" column is appended. ``save_dir`` is created if missing.
    """
    output_dir = Path(save_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    history = pd.DataFrame(results_dict)

    has_epoch_column = "epoch" in history.columns
    if not has_epoch_column:
        n_epochs = len(history)
        history["epoch"] = range(1, n_epochs + 1)

    file_path = f"{save_dir}/{experiment_name}_learning_curves.csv"
    history.to_csv(file_path, index=False)

    print(f"[Logs] Saved training history to {file_path}")

In [None]:
import numpy as np
from pathlib import Path
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from typing import List, Dict, Tuple

# Default output directory for save_test_results.
RESULTS_DIR = f"experiments/results/{model_save_name}"

def calculate_ordinal_metrics(preds: np.ndarray,
                              labels: np.ndarray) -> Dict[str, float]:
    """
    Summarises how far predictions land from the true class on an ordinal scale.

    Returns a dict with:
        "mae": mean absolute class distance,
        "off_by_one_accuracy": fraction of samples at distance <= 1,
        "off_by_two_accuracy": fraction of samples at distance <= 2.
    """
    abs_diffs = np.abs(preds - labels)

    metrics = {"mae": float(abs_diffs.mean())}
    for key, tolerance in (("off_by_one_accuracy", 1), ("off_by_two_accuracy", 2)):
        within_tolerance = abs_diffs <= tolerance
        metrics[key] = float(within_tolerance.mean())

    return metrics

def categorize_failures(preds: np.ndarray, 
                        labels: np.ndarray) -> Dict[int, List[int]]:
    """
    Groups large misclassifications by how many classes they are off.

    Only distances 3, 4, 5 and 6 are reported (6 is the largest possible gap
    between seven classes). Each key maps to the list of sample positions
    with exactly that distance; distances with no samples are omitted.
    """
    abs_diffs = np.abs(preds - labels)

    failure_dict = {}
    for magnitude in (3, 4, 5, 6):
        matching = np.nonzero(abs_diffs == magnitude)[0].tolist()
        if len(matching) > 0:
            failure_dict[magnitude] = matching

    return failure_dict

def run_inference(model: torch.nn.Module, 
                  dataloader: torch.utils.data.DataLoader, 
                  device: str) -> Tuple[np.ndarray, np.ndarray, float]:
    """
    Runs inference and tracks latency. Returns predictions, true labels, and avg latency per sample (ms).

    Args:
        model: Network whose output has one score per class along dim 1.
        dataloader: Iterable yielding ``(X, y)`` batches. ``y`` is reduced with
            ``argmax(dim=1)``, i.e. it is expected to be a per-class target vector
            per sample (presumably the soft label distribution — confirm against the dataset).
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(preds, labels, avg_latency_ms)`` where ``preds`` and ``labels`` are 1-D arrays
        of class indices and ``avg_latency_ms`` is the wall-clock time of the whole loop
        (batch loading, transfer and forward pass) divided by the number of samples.
        If the dataloader yields no samples, both arrays are empty and the latency is 0.0.
    """
    model.eval()
    pred_batches = []
    label_batches = []
    
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    
    with torch.inference_mode():
        for X, y in dataloader:
            X = X.to(device)
            # Keep one array per batch instead of extending a Python list element by element.
            pred_batches.append(model(X).argmax(dim=1).cpu().numpy())

            # Targets are per-class vectors; argmax recovers the hard class index.
            # They only need to end up on the host, so they are not moved to `device`.
            label_batches.append(y.argmax(dim=1).cpu().numpy())
            
    total_time = time.perf_counter() - start_time

    if pred_batches:
        all_preds = np.concatenate(pred_batches)
        all_labels = np.concatenate(label_batches)
    else:
        # np.concatenate rejects an empty list; return well-typed empty arrays instead.
        all_preds = np.array([], dtype=np.int64)
        all_labels = np.array([], dtype=np.int64)

    num_samples = len(all_labels)
    # Guard against division by zero when nothing was evaluated.
    avg_latency_ms = (total_time / num_samples) * 1000 if num_samples else 0.0
    
    return all_preds, all_labels, avg_latency_ms

def save_test_results(experiment_name: str, 
                      model: torch.nn.Module, 
                      test_dataloader: torch.utils.data.DataLoader, 
                      device: str,
                      save_dir: str = RESULTS_DIR):
    """
    Evaluates ``model`` on ``test_dataloader`` and writes the results under ``save_dir``.

    Files written (all prefixed with ``experiment_name``):
        - ``_metrics.json``: accuracy, latency, ordinal metrics, failure counts and
          the scikit-learn classification report.
        - ``_confusion_matrix.npy``: the confusion matrix.
        - ``_failure_indices.json``: dataset indices of large errors, grouped by magnitude.

    A short summary is also printed. Nothing is returned.

    Args:
        experiment_name: Prefix used for every output file name.
        model: Model handed to ``run_inference``.
        test_dataloader: Batches handed to ``run_inference``.
        device: Device handed to ``run_inference``.
        save_dir: Output directory; created (with parents) if missing.
    """
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    output_prefix = f"{save_dir}/{experiment_name}"

    y_pred, y_true, ms_per_sample = run_inference(model, test_dataloader, device)

    # Standard classification metrics.
    accuracy = accuracy_score(y_true, y_pred)
    report = classification_report(y_true, y_pred, output_dict=True)
    cm = confusion_matrix(y_true, y_pred)

    # Metrics that take the class ordering into account.
    ordinal = calculate_ordinal_metrics(y_pred, y_true)
    failures = categorize_failures(y_pred, y_true)
    failure_counts = {magnitude: len(idx) for magnitude, idx in failures.items()}

    summary = {
        "experiment_name": experiment_name,
        "global_accuracy": accuracy,
        "inference_latency_ms": ms_per_sample,
        "ordinal_metrics": ordinal,
        "catastrophic_failure_counts": failure_counts,
        "classification_report": report
    }

    with open(f"{output_prefix}_metrics.json", "w") as fh:
        json.dump(summary, fh, indent=4)

    np.save(f"{output_prefix}_confusion_matrix.npy", cm)

    # Failure indices are kept in their own file for later inspection of specific samples.
    with open(f"{output_prefix}_failure_indices.json", "w") as fh:
        json.dump(failures, fh)

    separator = "-" * 60
    print(separator)
    print(f"[Results] Accuracy:        {accuracy*100:.2f}%")
    print(f"[Results] Off-by-1 Acc:    {ordinal['off_by_one_accuracy']*100:.2f}%")
    print(f"[Results] MAE:             {ordinal['mae']:.4f}")
    print(f"[Results] Latency:         {ms_per_sample:.4f} ms/sample")
    print("[Results] Catastrophic Failures (Count):")
    for magnitude in sorted(failures):
        print(f"   - Off by {magnitude}: {failure_counts[magnitude]} samples")
    print(f"[Results] Saved all metrics to {save_dir}")
    print(separator)


In [None]:
# Hyperparameters recorded in the run's config metadata.
# NOTE(review): learning_rate and optimizer are literals here — confirm they match
# the optimizer that was actually used for training.
hyperparams = dict(
    epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    learning_rate=0.001,
    optimizer="AdamW",
    input_shape=775,
    output_shape=7,
)

# Path of the feature array for each split, as plain strings.
dataset_paths = {
    split: str(ROOT_DIR / f"{split}_X.npy")
    for split in ("train", "val", "test")
}

save_config_metadata(
    experiment_name=RUN_ID,
    model=model,
    hyperparams=hyperparams,
    dataset_paths=dataset_paths,
)

# Training history comes from the `result` variable produced by run_experiment.
save_training_logs(
    experiment_name=RUN_ID,
    results_dict=result,
)

# Evaluate on the test split and write metrics, confusion matrix and failure indices.
save_test_results(
    experiment_name=RUN_ID,
    model=model,
    test_dataloader=test_dataloader,
    device=device,
)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Load the per-epoch history CSV for this run.
file_path = Path(f"./experiments/logs/{model_save_name}/{RUN_ID}_learning_curves.csv")
df = pd.read_csv(file_path)

# One panel per metric, left to right: (display name, train column, validation column).
panels = [
    ("Loss", "train_loss", "val_loss"),
    ("Accuracy", "train_acc", "val_acc"),
]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for ax, (metric, train_col, val_col) in zip(axes, panels):
    ax.plot(df['epoch'], df[train_col], label=f'Train {metric}', color='blue')
    ax.plot(df['epoch'], df[val_col], label=f'Val {metric}', color='orange')
    ax.set_title(f'{metric} over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric)
    ax.legend()
    ax.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()