In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import chess
import chess.svg
import json
import time
from pathlib import Path
from enum import Enum
from IPython.display import display, SVG

# --- Runtime / data configuration -------------------------------------------
# Use the GPU when torch can see one; otherwise run everything on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

BATCH_SIZE = 2048  # batch size used by get_test_loader()
TEST_PATH = Path("./dataset_bitmaps/bitboard_test.npz")  # NPZ archive with "X" and "y" arrays
RESULTS_DIR = Path("experiments/results")
LOGS_DIR = Path("experiments/logs")

# ==========================================
# Model under analysis: edit this section to switch checkpoints.
MODEL_NAME = "wide_mlp_v1"
MODEL_TYPE = "WideMLP"  # selects the architecture class built in the loading cell
MODEL_PATH = "wide_mlp_v1.pth"
INPUT_SHAPE = 775
HIDDEN_UNITS = 2048
OUTPUT_SHAPE = 7
# JSON file read by the failure-analysis cell; it lives under RESULTS_DIR/MODEL_NAME.
FAILURE_INDICES_PATH = RESULTS_DIR.joinpath(
    MODEL_NAME,
    "run_2026_01_05_wide_mlp_v1_failure_indices.json",
)
# ==========================================

In [None]:
class WideMLP(nn.Module):
    """Fully connected classifier with two equally wide hidden layers.

    Layer stack (held in ``self.network``, an ``nn.Sequential``):
    two repetitions of Linear -> BatchNorm1d -> ReLU -> Dropout(p=0.5),
    followed by a final Linear that maps ``hidden_units`` to ``output_shape``.

    The position of every layer inside the Sequential is part of the
    checkpoint format (state-dict keys are ``network.<index>.*``), so the
    construction order below must not change.

    Args:
        input_shape: Number of input features per sample.
        hidden_units: Width of both hidden layers.
        output_shape: Number of output values (one per class).
    """

    def __init__(self,
                 input_shape: int,
                 hidden_units: int,
                 output_shape: int) -> None:
        super().__init__()

        layers = []
        fan_in = input_shape
        # Two identical hidden blocks; only the first one sees the raw input width.
        for _ in range(2):
            layers.extend([
                nn.Linear(in_features=fan_in, out_features=hidden_units),
                nn.BatchNorm1d(hidden_units),
                nn.ReLU(),
                nn.Dropout(p=0.5),
            ])
            fan_in = hidden_units
        # Output head.
        layers.append(nn.Linear(hidden_units, output_shape))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw network output."""
        logits = self.network(x)
        return logits
    
class PyramidalMLP_v2(nn.Module):
    """Fully connected classifier whose hidden layers narrow step by step.

    Layer stack (held in ``self.network``, an ``nn.Sequential``):
      * four blocks of Linear -> BatchNorm1d -> ReLU -> Dropout(0.2) with
        output widths 1024, 600, 400 and 256;
      * one block of Linear(256, 128) -> ReLU -> Dropout(0.5) without batch norm;
      * a final Linear(128, output_shape).

    State-dict keys are ``network.<index>.*``, so the order in which the
    layers are appended must stay exactly as written.

    Args:
        input_shape: Number of input features per sample (default 775).
        output_shape: Number of output values (default 7).
    """

    def __init__(self, input_shape=775, output_shape=7):
        super().__init__()

        stages = []
        fan_in = input_shape
        # Batch-normalised blocks with light dropout.
        for fan_out in (1024, 600, 400, 256):
            stages += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(0.2),
            ]
            fan_in = fan_out

        # Last hidden block (no batch norm, heavier dropout) and the output head.
        stages += [
            nn.Linear(fan_in, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, output_shape),
        ]

        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw network output."""
        logits = self.network(x)
        return logits


In [3]:
class NPZChessDataset(Dataset):
    """In-memory dataset backed by a NumPy ``.npz`` archive.

    The archive must contain an ``"X"`` array (features, converted to
    float32) and a ``"y"`` array (labels, converted to int64). Both are
    loaded fully into tensors when the dataset is constructed.
    """

    def __init__(self, npz_path: Path):
        # The archive is closed as soon as both arrays have been copied into tensors.
        with np.load(npz_path) as archive:
            self.X = torch.tensor(archive["X"], dtype=torch.float32)
            self.y = torch.tensor(archive["y"], dtype=torch.long)

    def __len__(self):
        """Number of samples, i.e. the size of the first dimension of ``X``."""
        return self.X.size(0)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        features = self.X[idx]
        label = self.y[idx]
        return features, label

def get_test_loader():
    """Build a DataLoader over the test set stored at ``TEST_PATH``.

    Samples are served in file order (no shuffling) in batches of
    ``BATCH_SIZE``, loaded in the main process (``num_workers=0``).
    """
    return DataLoader(
        NPZChessDataset(TEST_PATH),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )

# ==========================================
# Build the architecture selected by MODEL_TYPE and move it to DEVICE.
if MODEL_TYPE == "WideMLP":
    model = WideMLP(INPUT_SHAPE, HIDDEN_UNITS, OUTPUT_SHAPE).to(DEVICE)
elif MODEL_TYPE == "PyramidalMLP":
    model = PyramidalMLP_v2(INPUT_SHAPE, OUTPUT_SHAPE).to(DEVICE)
else:
    # Fail here with a clear message instead of a NameError on `model` further down.
    raise ValueError(
        f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected 'WideMLP' or 'PyramidalMLP'."
    )
# ==========================================

try:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
except FileNotFoundError:
    print("Weight file not found! Please check MODEL_PATH.")
    # Do not continue: every later cell would silently analyse random weights.
    raise

model.load_state_dict(state_dict)
print(f"Loaded weights from {MODEL_PATH}")

# Inference mode: disables dropout and makes batch norm use its running statistics.
model.eval()
print("Model ready.")

Loaded weights from wide_mlp_v1.pth
Model ready.


In [None]:
def generate_saliency_map(model, input_tensor, target_class):
    """
    Computes the gradient of the score for the target class w.r.t the input vector.

    Args:
        model: Network to inspect. It is switched to eval mode as a side effect.
        input_tensor: Batched input; only the first row (index 0) is analysed.
            The tensor itself is left untouched: gradients are taken on a
            detached copy, so repeated calls never accumulate into ``.grad``.
        target_class: Column of the model output whose score is differentiated.

    Returns:
        NumPy array holding d(score) / d(input) for the first sample, with the
        same shape as ``input_tensor[0]``.
    """
    model.eval()

    # Work on a private leaf tensor so the caller's tensor is not mutated and
    # stale gradients from an earlier call cannot leak into this result.
    probe = input_tensor.detach().clone().requires_grad_(True)

    # Forward pass
    pred = model(probe)
    score = pred[0, target_class]

    # Differentiate w.r.t. the input only; unlike score.backward() this does
    # not populate .grad on the model parameters.
    (input_grad,) = torch.autograd.grad(score, probe)

    gradients = input_grad.detach().cpu().numpy()[0]

    return gradients

def plot_chess_heatmap(gradients, fen_str, true_label, pred_label):
    """
    Aggregates gradients from the 12 bitboards (first 768 indices) into a 8x8 grid
    and draws it as a heatmap.

    Args:
        gradients: 1-D array-like with at least 768 entries. Entries 0-767 are
            read as 12 piece channels x 64 squares; the channel order used
            elsewhere in this notebook is P, N, B, R, Q, K, p, n, b, r, q, k.
            Anything after index 767 is ignored.
        fen_str: Optional FEN string. When truthy, the corresponding board is
            rendered below the heatmap.
        true_label: Ground-truth class, shown in the title.
        pred_label: Predicted class, shown in the title.

    Returns:
        None. The figure (and optionally the board) is displayed.
    """
    # Reshape first 768 features into (12 channels, 64 squares)
    piece_grads = np.asarray(gradients)[:768].reshape(12, 64)

    # Importance per square = sum of |gradient| over all piece channels -> (64,)
    saliency = np.abs(piece_grads).sum(axis=0)

    # Row r of the 8x8 grid holds squares 8r..8r+7. Assuming square 0 is a1
    # (the indexing vector_to_fen relies on), row 0 is rank 1; heatmaps draw
    # row 0 at the top, so flip vertically to put rank 8 at the top.
    heatmap = np.flipud(saliency.reshape(8, 8))

    # Plot
    fig, ax = plt.subplots(figsize=(6, 6))
    sns.heatmap(
        heatmap,
        cmap="viridis",
        alpha=0.7,
        zorder=2,
        ax=ax,
        cbar=False,
        # Label the axes with board coordinates that match the flipped grid.
        xticklabels=list("abcdefgh"),
        yticklabels=list("87654321"),
    )
    # The title does not depend on the FEN, so it is always set.
    ax.set_title(f"Saliency Map (True: {true_label}, Pred: {pred_label})")
    ax.set_xlabel("File")
    ax.set_ylabel("Rank")
    plt.show()

    if fen_str:
        # Display actual board below
        print("Actual Board Position:")
        display(chess.svg.board(chess.Board(fen_str), size=350))

def vector_to_fen(vector):
    """
    Rebuilds a FEN string from a 775-dim bitboard feature vector.

    Layout read by this function:
      * indices 0-767: twelve 64-square planes in the order
        P, N, B, R, Q, K, p, n, b, r, q, k; a value of 1 places that piece
        on the square whose index is the offset inside the plane;
      * index 768: side to move (1 means White, anything else Black);
      * indices 769-772: castling flags K, Q, k, q in that order.
    Indices 773 and 774 are not read.

    Returns:
        The FEN of the reconstructed ``chess.Board``.
    """
    plane_symbols = "PNBRQKpnbrqk"

    # Start from an empty board and drop pieces onto it plane by plane.
    board = chess.Board(None)
    for plane, symbol in enumerate(plane_symbols):
        plane_start = plane * 64
        for square in range(64):
            if vector[plane_start + square] == 1:
                board.set_piece_at(square, chess.Piece.from_symbol(symbol))

    if vector[768] == 1:
        board.turn = chess.WHITE
    else:
        board.turn = chess.BLACK

    # Collect the castling letters whose flag is set; "-" means no rights at all.
    castling_flags = (("K", 769), ("Q", 770), ("k", 771), ("q", 772))
    castling_fen = "".join(letter for letter, pos in castling_flags if vector[pos] == 1)
    if not castling_fen:
        castling_fen = "-"
    board.set_castling_fen(castling_fen)

    return board.fen()

In [None]:
# Load the failure indices recorded for this model. The JSON is indexed below
# by error magnitude (string key) and then by position in that list.
with open(FAILURE_INDICES_PATH, "r") as f:
    failure_data = json.load(f)

# ==========================================
ERROR_MAGNITUDE = "6"  # Choose from '3', '4', '5', '6'
FAILURE_IDX_IN_LIST = 0 # Pick the Nth failure in that category
# ==========================================

try:
    dataset_idx = failure_data[ERROR_MAGNITUDE][FAILURE_IDX_IN_LIST]
except (KeyError, IndexError) as err:
    # Say which selection was invalid instead of surfacing a bare KeyError/IndexError.
    raise LookupError(
        f"No failure #{FAILURE_IDX_IN_LIST} recorded for error magnitude "
        f"{ERROR_MAGNITUDE!r} in {FAILURE_INDICES_PATH}"
    ) from err

# Load that specific sample
dataset = NPZChessDataset(TEST_PATH)
X_sample, y_sample = dataset[dataset_idx]

# Unsqueeze so [775] > [1, 775] as model expects batch
X_input = X_sample.unsqueeze(0).to(DEVICE)

# Predict. No gradients are needed for the label itself, so skip graph construction.
model.eval()
with torch.no_grad():
    pred_logits = model(X_input)
pred_label = torch.argmax(pred_logits, dim=1).item()

# Generate Saliency to see why it predicted the wrong class
grads = generate_saliency_map(model, X_input, pred_label)

print(grads)

# # Reconstruct FEN for visualization
# reconstructed_fen = vector_to_fen(X_sample)

# print(f"Analyzing Test Sample Index: {dataset_idx}")
# print(f"FEN: {reconstructed_fen}")
# print(f"True Label: {y_sample.item()} | Predicted: {pred_label}")
# print(f"Error Magnitude: {abs(y_sample.item() - pred_label)}")

# plot_chess_heatmap(grads, fen_str=reconstructed_fen, true_label=y_sample.item(), pred_label=pred_label)

[-4.33430448e-03 -4.13293578e-03  5.56480885e-03  1.18819419e-02
  7.69415498e-03  1.03726494e-03  3.68698314e-03  1.63372443e-03
 -1.92809328e-01  4.83172089e-02 -3.46954525e-01  1.25636905e-01
 -8.77027273e-01 -5.59846818e-01 -3.79317999e-01 -2.21662521e-02
  7.05590770e-02  3.75913471e-01  2.97038198e-01 -5.80576956e-02
 -2.35117644e-01 -5.60871661e-01 -3.97791743e-01 -2.16712549e-01
  1.05204411e-01 -1.92649662e-04  5.52715778e-01  1.25195235e-02
 -1.13929912e-01 -8.12322497e-01 -1.29273385e-02 -7.43571520e-02
  9.18054208e-02 -6.78124309e-01 -1.10529214e-02 -1.95739940e-01
 -2.20500350e-01 -9.61963534e-02 -8.89158845e-02  1.33558095e-01
 -6.96933150e-01 -8.65940690e-01 -2.07026213e-01  1.81053713e-01
 -4.76596415e-01 -9.33789611e-02  1.24928236e-01  2.81348228e-01
 -1.37934184e+00 -1.81829214e+00 -5.98194003e-01 -5.54945469e-01
 -1.17258024e+00 -9.92037177e-01  4.20202821e-01 -2.96234190e-01
  4.99806646e-03  4.82312031e-03 -2.48831068e-03  9.61325038e-03
  9.80768818e-04  3.02116

: 

In [6]:
# PIECE_ACC_LOG = LOGS_DIR / "accuracy_vs_piece_count.csv"

# def compute_piece_accuracy(model, dataloader):
#     print("Computing Accuracy vs Piece Count (this may take a minute)...")
#     results = []
#     model.eval()
    
#     with torch.no_grad():
#         for X, y in dataloader:
#             X, y = X.to(DEVICE), y.to(DEVICE)
#             preds = model(X).argmax(dim=1)
            
#             # X shape is (Batch, 775). First 768 are bitboards.
#             # Summing X[:, :768] gives total number of pieces on board for each sample.
#             # (Assuming bitboards are 1 for piece, 0 for empty)
#             piece_counts = X[:, :768].sum(dim=1).cpu().numpy()
#             correct = (preds == y).cpu().numpy()
            
#             for count, is_correct in zip(piece_counts, correct):
#                 results.append({"piece_count": int(count), "correct": int(is_correct)})
                
#     df = pd.DataFrame(results)
#     # Group by piece count and calc mean accuracy
#     summary = df.groupby("piece_count")["correct"].mean().reset_index()
#     summary.rename(columns={"correct": "accuracy"}, inplace=True)
#     return summary

# # Check if exists
# if PIECE_ACC_LOG.exists():
#     print(f"Loading existing log from {PIECE_ACC_LOG}...")
#     df_piece_acc = pd.read_csv(PIECE_ACC_LOG)
# else:
#     test_loader = get_test_loader()
#     df_piece_acc = compute_piece_accuracy(model, test_loader)
#     df_piece_acc.to_csv(PIECE_ACC_LOG, index=False)
#     print(f"Saved new log to {PIECE_ACC_LOG}")

# # Plot
# plt.figure(figsize=(10, 5))
# sns.lineplot(data=df_piece_acc, x="piece_count", y="accuracy", marker="o")
# plt.title(f"Model Accuracy vs. Number of Pieces ({MODEL_TYPE})")
# plt.xlabel("Total Pieces on Board")
# plt.ylabel("Accuracy")
# plt.grid(True, alpha=0.3)
# plt.show()

In [7]:
# PHASE_ACC_LOG = LOGS_DIR / "accuracy_vs_phase.csv"

# def compute_phase_accuracy(model, dataloader):
#     print("Computing Accuracy vs Game Phase...")
    
#     results = []
#     model.eval()
    
#     with torch.no_grad():
#         for X, y in dataloader:
#             X_dev, y_dev = X.to(DEVICE), y.to(DEVICE)
#             preds = model(X_dev).argmax(dim=1)
#             correct = (preds == y_dev).cpu().numpy()
            
#             # Calculate phase score manually from bitboards
#             # Sum bits in specific ranges * weights
#             X_cpu = X.cpu().numpy()
            
#             # White + Black Material
#             pawns = X_cpu[:, 0:64].sum(1) + X_cpu[:, 384:448].sum(1)
#             knights = X_cpu[:, 64:128].sum(1) + X_cpu[:, 448:512].sum(1)
#             bishops = X_cpu[:, 128:192].sum(1) + X_cpu[:, 512:576].sum(1)
#             rooks = X_cpu[:, 192:256].sum(1) + X_cpu[:, 576:640].sum(1)
#             queens = X_cpu[:, 256:320].sum(1) + X_cpu[:, 640:704].sum(1)
            
#             # Total material score (Standard: Max ~78 excl Kings)
#             material = (pawns * 1) + (knights * 3) + (bishops * 3) + (rooks * 5) + (queens * 9)
            
#             # Binning: Opening (>60), Mid (30-60), End (<30) - adjust as needed
#             for mat, is_corr in zip(material, correct):
#                 results.append({"material_score": mat, "correct": int(is_corr)})

#     df = pd.DataFrame(results)
#     # Binning
#     df['phase_bin'] = pd.cut(df['material_score'], bins=[0, 30, 60, 200], labels=['Endgame', 'Middlegame', 'Opening'])
#     summary = df.groupby("phase_bin")["correct"].mean().reset_index()
#     return summary

# if PHASE_ACC_LOG.exists():
#     print(f"Loading existing log from {PHASE_ACC_LOG}...")
#     df_phase = pd.read_csv(PHASE_ACC_LOG)
# else:
#     test_loader = get_test_loader()
#     df_phase = compute_phase_accuracy(model, test_loader)
#     df_phase.to_csv(PHASE_ACC_LOG, index=False)
#     print(f"Saved new log to {PHASE_ACC_LOG}")

# # Plot
# plt.figure(figsize=(8, 5))
# sns.barplot(data=df_phase, x="phase_bin", y="correct", palette="viridis")
# plt.title("Accuracy vs Game Phase")
# plt.ylabel("Accuracy")
# plt.ylim(0, 1.0)
# plt.show()

In [8]:
# POSITIONAL_LOG = LOGS_DIR / "positional_report.json"

# def generate_positional_report(model, dataloader):
#     print("Generating Positional Understanding Report...")
#     model.eval()
    
#     equal_positions = 0
#     correct_equal = 0
#     predicted_draw_when_not = 0
#     total_samples = 0
    
#     with torch.no_grad():
#         for X, y in dataloader:
#             X, y = X.to(DEVICE), y.to(DEVICE)
#             preds = model(X).argmax(dim=1)
            
#             # Focus on Class 3 (Equal)
#             is_equal_label = (y == 3)
#             equal_positions += is_equal_label.sum().item()
#             correct_equal += (preds[is_equal_label] == 3).sum().item()
            
#             # Where model predicted Draw (3) but label was NOT 3
#             is_predicted_equal = (preds == 3)
#             # Mask out where it was actually equal
#             false_alarms = is_predicted_equal & (~is_equal_label)
#             predicted_draw_when_not += false_alarms.sum().item()
            
#             total_samples += len(y)

#     report = {
#         "total_samples_analyzed": total_samples,
#         "total_equal_positions": equal_positions,
#         "accuracy_on_equal_positions": correct_equal / equal_positions if equal_positions > 0 else 0,
#         "draw_blindness_rate": 1 - (correct_equal / equal_positions) if equal_positions > 0 else 0,
#         "draw_hallucination_count": predicted_draw_when_not
#     }
#     return report

# if POSITIONAL_LOG.exists():
#     with open(POSITIONAL_LOG, "r") as f:
#         report = json.load(f)
#     print("Loaded existing positional report.")
# else:
#     test_loader = get_test_loader()
#     report = generate_positional_report(model, test_loader)
#     with open(POSITIONAL_LOG, "w") as f:
#         json.dump(report, f, indent=4)
#     print("Created new positional report.")

# print(json.dumps(report, indent=4))