In [None]:
import torch
import pandas as pd
import numpy as np
import chess

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
import psutil
# Query system memory once and report total vs. currently available RAM.
# Byte counts are divided by 1024**3 before being printed with a "GB" suffix.
ram = psutil.virtual_memory()
print(f"Total RAM: {ram.total / (1024**3):.2f} GB")
print(f"Available: {ram.available / (1024**3):.2f} GB")

In [None]:
# Example position used to explore the python-chess API. The FEN fields after
# the piece placement read: Black to move ("b"), no castling rights ("-"),
# no en-passant target ("-"), halfmove clock 0, fullmove number 28.
sample_fen = "6k1/1pp1n1p1/p2p1pnp/3P4/4QP1P/2B2BP1/1PP3K1/1q6 b - - 0 28"
board = chess.Board(sample_fen)
# Bare last expression so the notebook displays the board.
board

In [None]:
# Inspect the square -> piece mapping; the FEN encoders defined in this
# notebook iterate over exactly this dictionary.
board.piece_map()

In [None]:
# Maps each piece symbol to the index of its 64-square block in the encoded
# vector: upper-case symbols (White in FEN notation) take 0-5 and lower-case
# symbols (Black) take 6-11, both in pawn, knight, bishop, rook, queen, king
# order.
piece_to_index = {
    symbol: block for block, symbol in enumerate("PNBRQKpnbrqk")
}

In [None]:
def fen_to_vector(fen: str) -> np.ndarray:
    """
    Convert a FEN string into a 768-length binary vector.

    Each of the 12 piece types in ``piece_to_index`` has its own 64-slot block,
    one slot per square, so a piece is stored at
    ``piece_to_index[symbol] * 64 + square``.
    A 1 marks the presence of a piece on a square, 0 otherwise.

    Only piece placement is encoded: side to move, castling rights and the
    en-passant target are not part of the result. Use
    ``fen_to_vector_with_state`` when those flags are needed.

    Args:
        fen: Position in Forsyth-Edwards Notation.

    Returns:
        Array of shape ``(768,)`` and dtype ``uint8`` holding only 0s and 1s.

    Raises:
        ValueError: Propagated from ``chess.Board`` when the FEN cannot be
            parsed (behaviour per the python-chess documentation).
    """
    board = chess.Board(fen)
    vector = np.zeros(64 * 12, dtype=np.uint8)

    # `square` is the integer square index supplied by python-chess.
    for square, piece in board.piece_map().items():
        vector[piece_to_index[piece.symbol()] * 64 + square] = 1

    return vector

In [None]:
def fen_to_vector_with_state(fen: str) -> np.ndarray:
    """
    Encode a FEN position as a 774-element ``uint8`` vector
    (piece bitboards followed by game-state flags).

    Layout:
        0-767  piece placement, ``piece_to_index[symbol] * 64 + square``
        768    side to move (1 = White, 0 = Black)
        769    White kingside castling right
        770    White queenside castling right
        771    Black kingside castling right
        772    Black queenside castling right
        773    1 when the position has an en-passant target square
    """
    position = chess.Board(fen)
    encoded = np.zeros(774, dtype=np.uint8)

    # Piece placement: one bit per (piece type, square) pair.
    for sq, pc in position.piece_map().items():
        encoded[piece_to_index[pc.symbol()] * 64 + sq] = 1

    # Game-state flags, listed in the same order as slots 768..773.
    state_flags = (
        position.turn == chess.WHITE,
        position.has_kingside_castling_rights(chess.WHITE),
        position.has_queenside_castling_rights(chess.WHITE),
        position.has_kingside_castling_rights(chess.BLACK),
        position.has_queenside_castling_rights(chess.BLACK),
        position.ep_square is not None,
    )
    encoded[768:774] = [1 if flag else 0 for flag in state_flags]

    return encoded

In [None]:
import pandas as pd
import numpy as np
import glob
import os
from pathlib import Path
from tqdm import tqdm
import pyarrow.parquet as pq

def consolidate_to_memmap(input_dir: Path, output_dir: Path) -> int:
    """
    Stream every ``*.parquet`` file in ``input_dir`` into two raw memmap files.

    For each file (processed in sorted path order) the ``fen`` column is
    encoded with ``fen_to_vector`` and the ``stockfish_label_depth_20`` column
    is cast to ``int16``. The results are written to:

    * ``output_dir / "temp_X.dat"`` - ``uint8``, shape ``(total_rows, 768)``
    * ``output_dir / "temp_y.dat"`` - ``int16``, shape ``(total_rows,)``

    Args:
        input_dir: Directory containing the Parquet parts.
        output_dir: Directory receiving the ``.dat`` files; created if missing.

    Returns:
        The total number of rows processed.

    Raises:
        ValueError: If no rows are found (no Parquet files, or only empty ones).
        RuntimeError: If the number of rows actually written differs from the
            count announced by the Parquet metadata.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    n_features = 768       # length of the vector returned by fen_to_vector
    flush_every = 100_000  # rows written between two periodic flushes

    # Count rows from the Parquet metadata so the memmaps can be sized up front
    files = sorted(input_dir.glob("*.parquet"))
    total_rows = sum(pq.read_metadata(f).num_rows for f in files)
    print(f"Total Dataset Size: {total_rows}")

    if total_rows == 0:
        # Nothing to write; a zero-length memmap would be useless downstream.
        raise ValueError(f"No rows found in *.parquet files under {input_dir}")

    # Backing files that reserve the space on disk
    temp_X_path = output_dir / "temp_X.dat"
    temp_y_path = output_dir / "temp_y.dat"

    X_shape = (total_rows, n_features)  # (sample, bitmap)
    y_shape = (total_rows,)             # (label,)

    # One uint8 (1 byte) per feature -> n_features bytes per row for X
    print(f"Allocating {total_rows * n_features / 1e9:.2f} GB on disk...")

    X_mmap = np.memmap(temp_X_path, dtype='uint8', mode='w+', shape=X_shape)
    # Labels are stored as int16 to save space; they can be widened (e.g. to
    # int64) when loaded for training.
    y_mmap = np.memmap(temp_y_path, dtype='int16', mode='w+', shape=y_shape)

    try:
        # Stream data
        current_idx = 0
        rows_since_flush = 0
        for f in tqdm(files, desc="Writing to Memmap"):
            # Only the two columns used below are read from each part.
            df = pd.read_parquet(f, columns=["fen", "stockfish_label_depth_20"])
            n = len(df)
            if n == 0:
                # np.stack raises on an empty list, so skip empty parts.
                continue

            # Vectorize
            x_chunk = np.stack([fen_to_vector(fen) for fen in df["fen"]])

            # Labeling
            # NOTE(review): astype(int16) wraps silently for values outside
            # [-32768, 32767] - assumes the labels fit; confirm against the data.
            y_chunk = df["stockfish_label_depth_20"].values.astype(np.int16)

            # Write to disk slot
            X_mmap[current_idx : current_idx + n] = x_chunk
            y_mmap[current_idx : current_idx + n] = y_chunk

            current_idx += n
            rows_since_flush += n

            # Periodic flush. Counting rows since the last flush (rather than
            # testing `current_idx % 100_000 == 0`) makes this fire for any
            # file size, not only when the running total hits an exact multiple.
            if rows_since_flush >= flush_every:
                X_mmap.flush()
                y_mmap.flush()
                rows_since_flush = 0

        # Final flush
        X_mmap.flush()
        y_mmap.flush()
    finally:
        # Drop the memmap objects to close the file handles, even on error
        del X_mmap
        del y_mmap

    if current_idx != total_rows:
        raise RuntimeError(
            f"Wrote {current_idx} rows but Parquet metadata announced {total_rows}"
        )

    print(f"Raw data stored in {output_dir}")
    return total_rows

In [None]:
def create_splits_from_memmap(data_dir: Path, 
                              total_rows: int, 
                              train_ratio: float = 0.8, 
                              val_ratio: float = 0.1, 
                              test_ratio: float = 0.1,
                              seed: "int | None" = None):
    """
    Loads the raw memmap, shuffles indices, and saves compressed .npz splits.

    Reads ``temp_X.dat`` (uint8, ``(total_rows, 768)``) and ``temp_y.dat``
    (int16, ``(total_rows,)``) from ``data_dir`` and writes
    ``chess_bitboard_train.npz``, ``chess_bitboard_val.npz`` and
    ``chess_bitboard_test.npz`` next to them, each holding arrays ``X`` and ``y``.

    Args:
        data_dir: Directory holding the ``.dat`` files; also the output directory.
        total_rows: Number of rows stored in the memmap files.
        train_ratio: Fraction of rows for the training split.
        val_ratio: Fraction of rows for the validation split.
        test_ratio: Fraction of rows for the test split (receives the remainder
            after the train/val boundaries are truncated to integers).
        seed: When given, the shuffle uses a dedicated
            ``np.random.default_rng(seed)`` and is reproducible. When ``None``
            (default) the shuffle uses NumPy's global RNG, as before.

    Returns:
        None. If either raw file is missing, an error is printed and nothing
        is written.

    Raises:
        AssertionError: If the three ratios do not sum to 1.0.
    """
    assert abs((train_ratio + val_ratio + test_ratio) - 1.0) < 1e-5, "Ratios must sum to 1.0"

    data_dir = Path(data_dir)
    temp_X_path = data_dir / "temp_X.dat"
    temp_y_path = data_dir / "temp_y.dat"

    # Both raw files are required; report the first one that is absent.
    missing = [path for path in (temp_X_path, temp_y_path) if not path.exists()]
    if missing:
        print(f"Error: Raw data not found at {missing[0]}.")
        return

    # Shuffle indices so each split is a random sample of the whole dataset
    if seed is None:
        indices = np.arange(total_rows)
        np.random.shuffle(indices)
    else:
        indices = np.random.default_rng(seed).permutation(total_rows)

    # Split boundaries (truncated towards zero)
    train_end = int(train_ratio * total_rows)
    val_end = int((train_ratio + val_ratio) * total_rows)

    splits = {
        "train": indices[:train_end],
        "val": indices[train_end:val_end],
        "test": indices[val_end:],
    }

    print(f"Splits: Train={len(splits['train'])}, Val={len(splits['val'])}, Test={len(splits['test'])}")

    X_mmap = np.memmap(temp_X_path, dtype='uint8', mode='r', shape=(total_rows, 768))
    y_mmap = np.memmap(temp_y_path, dtype='int16', mode='r', shape=(total_rows,))
    try:
        # Indexing with an index array copies the selected rows into memory
        # before they are compressed and written.
        for name, idx in splits.items():
            np.savez_compressed(data_dir / f"chess_bitboard_{name}.npz",
                                X=X_mmap[idx], y=y_mmap[idx])
    finally:
        # Drop the memmap objects to close the file handles
        del X_mmap
        del y_mmap

    print(f"Files saved to {data_dir}")

In [None]:
INPUT_PATH = Path("./dataset_parts")
OUTPUT_PATH = Path("./dataset_processed")
SEED = 42  # fixed so the train/val/test membership is identical on every re-run

# Step 1: encode every Parquet part into the raw memmap files.
total_rows = consolidate_to_memmap(INPUT_PATH, OUTPUT_PATH)

# Step 2: shuffle and split. create_splits_from_memmap shuffles with NumPy's
# global RNG, so seed it right before the call to make the split reproducible.
np.random.seed(SEED)
create_splits_from_memmap(
    data_dir=OUTPUT_PATH, 
    total_rows=total_rows, 
    train_ratio=0.8, 
    val_ratio=0.1, 
    test_ratio=0.1
)

In [None]:
FILE_PATH = Path("./dataset_processed/chess_bitboard_train.npz")

# Read both arrays out of the archive while it is open; they are ordinary
# in-memory arrays afterwards, so the report can run once the file is closed.
with np.load(FILE_PATH) as data:
    X, y = data['X'], data['y']

# Sanity report: array shapes plus the first label / feature vector.
print(f"Total Shapes:")
print(f"  X (Features): {X.shape}")
print(f"  y (Labels):   {y.shape}")
print(f"First Label: {y[0]}")
print(f"First sample vector: {X[0]}")