# DATA MANIPULATION

In [1]:
import numpy as np
import pandas as pd
import chess
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  # For Jupyter Notebook
import torch.nn.functional as F
# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'
print(f"Using device: {device}")

Using device: cuda


# DATASET CLASS

In [2]:
class ChessDataset(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe.values
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        fen, move = self.data[idx]
        
        # Convert FEN to 8x8x20 tensor
        board_tensor = torch.zeros((20, 8, 8), dtype=torch.float32)
        board = chess.Board(fen)
        
        # Piece channels (0-11)
        piece_map = {
            'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
            'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
        }
        
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece:
                rank, file = 7 - square // 8, square % 8
                board_tensor[piece_map[piece.symbol()], rank, file] = 1
        
        # Metadata channels (12-19)
        board_tensor[12] = int(board.turn)  # White to move
        board_tensor[13] = int(board.has_kingside_castling_rights(chess.WHITE))
        board_tensor[14] = int(board.has_queenside_castling_rights(chess.WHITE))
        board_tensor[15] = int(board.has_kingside_castling_rights(chess.BLACK))
        board_tensor[16] = int(board.has_queenside_castling_rights(chess.BLACK))
        board_tensor[17] = int(board.has_legal_en_passant())
        board_tensor[18] = board.halfmove_clock / 50.0
        board_tensor[19] = board.fullmove_number / 100.0
        
        # Convert move to label (simplified version)
        move_label = self.move_to_index(move)
        
        return board_tensor, move_label
    
    def move_to_index(self, uci_move):
        """Convert UCI move to integer label (0-4671)"""
        # Create all possible moves once and cache
        if not hasattr(self, 'move_lookup'):
            self.move_lookup = {}
            label = 0
            for from_sq in chess.SQUARES:
                for to_sq in chess.SQUARES:
                    if from_sq == to_sq:
                        continue
                    for promo in [None, 'q', 'r', 'b', 'n']:
                        self.move_lookup[f"{chess.square_name(from_sq)}{chess.square_name(to_sq)}{promo if promo else ''}"] = label
                        label += 1
        return self.move_lookup.get(uci_move, 0)

# Create datasets
train_df = pd.read_csv("train_data.csv", encoding = "utf-8")
val_df = pd.read_csv("val_data.csv", encoding = "utf-8")
train_dataset = ChessDataset(train_df)
val_dataset = ChessDataset(val_df)

#print(train_dataset.shape())

# MODEL ARCH

In [3]:
import torch.nn as nn
class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)
        
    def forward(self, x):
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        return F.relu(x)

class ChessResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(20, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            ResBlock(256),
            ResBlock(256),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128*8*8, 4672)
        )
        
    def forward(self, x):
        return self.net(x)

model = ChessResNet().to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Model parameters: 40,982,208


# MODEL TRAINING - DON'T RUN IF TESTING

In [6]:
BATCH_SIZE = 512
LR = 0.001
EPOCHS = 20
# from tqdm import tqdm  # Use this if you're not in Jupyter

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)

# Optimization
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)

# Training history
history = {'train_loss': [], 'val_acc': []}

for epoch in range(EPOCHS):
    # Training phase with progress bar
    model.train()
    train_loss = 0
    train_progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS} [Train]', leave=False)
    
    for inputs, labels in train_progress:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * inputs.size(0)
        
        # Update progress bar
        train_progress.set_postfix({
            'loss': f"{loss.item():.4f}",
            'lr': f"{optimizer.param_groups[0]['lr']:.2e}"
        })
    
    # Validation phase with progress bar
    model.eval()
    correct = 0
    total = 0
    val_progress = tqdm(val_loader, desc=f'Epoch {epoch+1}/{EPOCHS} [Val]', leave=False)
    
    with torch.no_grad():
        for inputs, labels in val_progress:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Update progress bar
            val_progress.set_postfix({
                'acc': f"{correct/total:.2%}"
            })
    
    # Calculate epoch metrics
    epoch_loss = train_loss / len(train_loader.dataset)
    epoch_acc = correct / total
    history['train_loss'].append(epoch_loss)
    history['val_acc'].append(epoch_acc)
    
    # Learning rate adjustment
    scheduler.step(epoch_acc)
    
    # Epoch summary
    print(f"\nEpoch {epoch+1}/{EPOCHS} Summary:")
    print(f"Train Loss: {epoch_loss:.4f} | Val Acc: {epoch_acc:.2%}")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
    print("-" * 50)

Epoch 1/20 [Train]:   0%|          | 0/899 [00:00<?, ?it/s]

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


# RESULT VISUALIZATION

In [4]:
model = ChessResNet()
model.load_state_dict(torch.load('chess_resnet.pth', weights_only=True))
#model = torch.load(model.state_dict(), 'chess_resnet.pth')

<All keys matched successfully>

In [5]:
#plt.figure(figsize=(12, 5))
#plt.subplot(1, 2, 1)
#plt.plot(history['train_loss'], label='Train Loss')
#plt.title('Training Loss')
#plt.xlabel('Epoch')
#plt.legend()

#plt.subplot(1, 2, 2)
#plt.plot(history['val_acc'], label='Val Accuracy')
#plt.title('Validation Accuracy')
#plt.xlabel('Epoch')
#plt.legend()

#plt.tight_layout()
#plt.show()

# Save model
#torch.save(model.state_dict(), 'chess_resnet.pth')

# CODE TO PLAY WITH THE MODEL

In [8]:
import torch
import re

class FenToTensor:
    def __init__(self):
        # Piece type mapping to channel indices
        self.piece_map = {
            'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
            'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
        }
    
    def convert(self, fen):
        """Convert FEN string to 20x8x8 tensor"""
        board_tensor = torch.zeros((20, 8, 8), dtype=torch.float32)
        
        # Parse FEN components
        parts = fen.split()
        board_part = parts[0]
        active_color = parts[1]
        castling = parts[2]
        en_passant = parts[3]
        halfmove_clock = int(parts[4])
        fullmove_number = int(parts[5])
        
        # Process board position
        self._process_board(board_part, board_tensor)
        
        # Process metadata
        self._process_metadata(
            board_tensor, active_color, castling, 
            en_passant, halfmove_clock, fullmove_number
        )
        
        return board_tensor
    
    def _process_board(self, fen_board, tensor):
        """Process the board part of FEN into tensor channels 0-11"""
        rank = 0  # 0 is top rank (a8-h8), 7 is bottom rank (a1-h1)
        file = 0
        
        for char in fen_board:
            if char == '/':
                # Move to next rank
                rank += 1
                file = 0
            elif char.isdigit():
                # Skip empty squares
                file += int(char)
            else:
                # Place the piece
                channel = self.piece_map[char]
                tensor[channel, rank, file] = 1
                file += 1
    
    def _process_metadata(self, tensor, active_color, castling, 
                         en_passant, halfmove_clock, fullmove_number):
        """Process metadata into channels 12-19"""
        # Channel 12: Turn (1 for white, 0 for black)
        tensor[12] = 1 if active_color == 'w' else 0
        
        # Channels 13-16: Castling rights
        tensor[13] = 1 if 'K' in castling else 0  # White kingside
        tensor[14] = 1 if 'Q' in castling else 0  # White queenside
        tensor[15] = 1 if 'k' in castling else 0  # Black kingside
        tensor[16] = 1 if 'q' in castling else 0  # Black queenside
        
        # Channel 17: En passant target (1 if exists)
        tensor[17] = 1 if en_passant != '-' else 0
        
        # Channel 18: Halfmove clock (normalized)
        tensor[18] = min(halfmove_clock / 50.0, 1.0)
        
        # Channel 19: Fullmove number (normalized)
        tensor[19] = min(fullmove_number / 100.0, 1.0)

In [50]:
import random
def index_to_uci(index: int) -> str:
    """Convert model's output index (0-4671) back to UCI move string."""
    # Generate all possible moves in the same order as ChessDataset
    moves = []
    for from_sq in chess.SQUARES:
        for to_sq in chess.SQUARES:
            if from_sq == to_sq:
                continue  # Skip null moves
            # Non-promotion case
            moves.append(f"{chess.square_name(from_sq)}{chess.square_name(to_sq)}")
            # Promotion cases (q, r, b, n)
            for promo in ['q', 'r', 'b', 'n']:
                moves.append(f"{chess.square_name(from_sq)}{chess.square_name(to_sq)}{promo}")
    return moves[index] if index < len(moves) else "0000"

def get_best_legal_move(model_output, fen):
    board = chess.Board(fen)
    if not board.legal_moves:
        return None  # Game is over
    
    probs = torch.softmax(model_output, dim=1)[0]
    legal_moves = []
    
    # Check all possible moves (not just top k)
    for move_idx in range(len(probs)):
        uci_move = index_to_uci(move_idx)
        try:
            move = chess.Move.from_uci(uci_move)
            if move in board.legal_moves:
                legal_moves.append((uci_move, probs[move_idx].item()))
        except:
            continue
    
    if not legal_moves:
        print("RANDOM MOVE")
        return random.choice([move.uci() for move in board.legal_moves])
    
    # Return move with highest probability
    legal_moves.sort(key=lambda x: x[1], reverse=True)
    return legal_moves[0][0]

converter = FenToTensor()
model = model.eval()

In [59]:
TEMPERATURE = 0.75  # Try 0.5-1.0 for output sharpening
fen = '6k1/1p3ppp/rr2p3/1Kpp1b2/p7/P3b1P1/4P1BP/8 w - - 7 34'

tensor = converter.convert(fen)
output = model(tensor.unsqueeze(0)) 
logits = output / TEMPERATURE

output = torch.softmax(logits, dim=1)           # OK
best_move = get_best_legal_move(output, fen)    # NE RADI 

print(f"Best legal move: {best_move}")

Best legal move: None
