# LIBRARIES

In [88]:
import numpy as np
import pandas as pd
import chess
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  # For Jupyter Notebook
import torch.nn.functional as F
import torch.nn as nn

# DATASET CLASS

In [89]:
class ChessDataset(Dataset):
    """Map-style dataset yielding ``(board_tensor, policy_label, value_target)``.

    Each row supplies a FEN string (column ``'FEN'``), an evaluation
    (``'Eval'``) and a whitespace-separated move string (``'Move1'``) whose
    first token is used as the policy target.

    Args:
        dataframe: pandas DataFrame containing at least the columns
            ``'FEN'``, ``'Eval'`` and ``'Move1'``.
    """

    def __init__(self, dataframe):
        # Keep only the three needed columns as a NumPy array; rows are
        # unpacked positionally in __getitem__.
        self.data = dataframe[['FEN', 'Eval', 'Move1']].values
        self.move_lookup = self._build_move_lookup()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded sample at ``idx``.

        Returns:
            board_tensor: float32 tensor of shape (20, 8, 8) from ``fen_to_tensor``.
            policy_label: int class id of the first move (0 if the move is
                missing or not in ``move_lookup``).
            value_target: float32 tensor of shape (1,) with a value in [-1, 1].
        """
        fen, eval_score, moves = self.data[idx]
        move1 = self._first_move(moves)

        board_tensor = self.fen_to_tensor(fen)
        policy_label = self.move_to_index(move1)
        value_target = torch.tensor([self._eval_to_unit(eval_score)], dtype=torch.float32)

        return board_tensor, policy_label, value_target

    @staticmethod
    def _first_move(moves):
        """Return the first whitespace-separated token of ``moves``, or '' if there is none.

        Non-string cells (e.g. NaN for a missing CSV value) and blank strings
        yield '' instead of raising.
        """
        if not isinstance(moves, str):
            return ""
        tokens = moves.split()
        return tokens[0] if tokens else ""

    @staticmethod
    def _eval_to_unit(eval_score, scale=10.0):
        """Convert a raw evaluation into a value target clamped to [-1, 1].

        The parsed score is divided by ``scale`` and clamped. Values that
        cannot be parsed as a float, and NaN, map to 0.0.
        NOTE(review): mate notation such as '#+3' (if the Eval column contains
        it) is unparseable here and therefore maps to 0.0 — confirm the format.
        """
        import math

        try:
            score = float(eval_score)
        except (TypeError, ValueError, OverflowError):
            return 0.0
        if math.isnan(score):
            # NaN would otherwise slip through min()/max() and come out as +1.0.
            return 0.0
        return max(-1.0, min(1.0, score / scale))

    def fen_to_tensor(self, fen):
        """Encode a FEN position as a (20, 8, 8) float32 tensor.

        Planes:
            0-5   white P, N, B, R, Q, K (one-hot)
            6-11  black p, n, b, r, q, k (one-hot)
            12    side to move (1.0 when ``board.turn`` is white)
            13-16 castling rights: white K-side, white Q-side, black K-side, black Q-side
            17    1.0 if a legal en-passant capture exists
            18    halfmove clock / 50
            19    fullmove number / 100
        Row 0 of each piece plane is the 8th rank (python-chess numbers a1 as square 0).
        """
        board_tensor = torch.zeros((20, 8, 8), dtype=torch.float32)
        board = chess.Board(fen)
        piece_map = {
            'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
            'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
        }
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece:
                rank, file = 7 - square // 8, square % 8
                board_tensor[piece_map[piece.symbol()], rank, file] = 1
        board_tensor[12] = int(board.turn)
        board_tensor[13] = int(board.has_kingside_castling_rights(chess.WHITE))
        board_tensor[14] = int(board.has_queenside_castling_rights(chess.WHITE))
        board_tensor[15] = int(board.has_kingside_castling_rights(chess.BLACK))
        board_tensor[16] = int(board.has_queenside_castling_rights(chess.BLACK))
        board_tensor[17] = int(board.has_legal_en_passant())
        board_tensor[18] = board.halfmove_clock / 50.0
        board_tensor[19] = board.fullmove_number / 100.0
        return board_tensor

    def move_to_index(self, uci_move):
        """Return the class id of a UCI move string.

        Unknown moves fall back to 0, which is also the id of the real move
        'a1b1', so such samples are silently labelled as that move.
        """
        return self.move_lookup.get(uci_move, 0)

    def _build_move_lookup(self):
        """Build a UCI-string -> contiguous class-id mapping.

        Enumerates every (from, to) pair of distinct squares (64 * 63 = 4032
        ids). For pairs going from the 7th to the 8th rank, or from the 2nd to
        the 1st rank, it also adds the four promotion suffixes (128 pairs * 4 =
        512 ids). That gives 4544 ids in total: 0..4543.
        """
        lookup = {}
        label = 0
        for from_sq in chess.SQUARES:
            for to_sq in chess.SQUARES:
                if from_sq == to_sq:
                    continue
                # Normal moves
                uci = f"{chess.square_name(from_sq)}{chess.square_name(to_sq)}"
                lookup[uci] = label
                label += 1

                # Promotion moves: only between the last two ranks on either side
                # (any file pair is included, a superset of the legal ones).
                from_rank = chess.square_rank(from_sq)
                to_rank = chess.square_rank(to_sq)
                if (from_rank == 6 and to_rank == 7) or (from_rank == 1 and to_rank == 0):
                    for promo in ['q', 'r', 'b', 'n']:
                        uci_promo = uci + promo
                        lookup[uci_promo] = label
                        label += 1
        return lookup

In [90]:
from torch.utils.data import DataLoader
import pandas as pd

# Pre-split CSVs; each must contain the 'FEN', 'Eval' and 'Move1' columns read by ChessDataset.
TRAIN_CSV = "train_data.csv"
VAL_CSV = "val_data.csv"
BATCH_SIZE = 512

train_df = pd.read_csv(TRAIN_CSV, encoding="utf-8")
val_df = pd.read_csv(VAL_CSV, encoding="utf-8")

train_dataset = ChessDataset(train_df)
val_dataset = ChessDataset(val_df)

# Page-locked host memory only helps when copying to a CUDA device; on a
# CPU-only machine pin_memory=True just triggers a warning.
PIN_MEMORY = torch.cuda.is_available()

# num_workers=0 decodes FENs in the main process.
# NOTE(review): worker processes would speed this up, but classes defined inside
# a notebook may not be picklable under the 'spawn' start method — verify first.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)


# NETWORK ARCHITECTURE

In [91]:
class ResBlock(nn.Module):
    """Residual block computing relu(bn2(conv2(relu(bn1(conv1(x))))) + x).

    Both convolutions are 3x3, stride 1, padding 1, with ``channels`` in and
    out. The output therefore has the same shape as the input, and the
    identity shortcut needs no projection.
    """

    def __init__(self, channels):
        super().__init__()
        same_size = {"kernel_size": 3, "padding": 1}
        # Registration order (conv1, bn1, conv2, bn2) determines state_dict keys
        # and the order parameters are initialised in.
        self.conv1 = nn.Conv2d(channels, channels, **same_size)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **same_size)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        with_shortcut = self.bn2(self.conv2(hidden)) + x
        return F.relu(with_shortcut)

class DualHeadChessNet(nn.Module):
    """Residual CNN with a policy head (move logits) and a value head (tanh scalar).

    Input is a (batch, in_channels, 8, 8) stack of board planes. All convs use
    kernel 3 / padding 1, so the 8x8 spatial size is preserved until the heads
    flatten it. The defaults reproduce the original fixed architecture exactly,
    including parameter names and shapes, so previously saved checkpoints still load.

    Args:
        in_channels: number of input planes (ChessDataset.fen_to_tensor produces 20).
        channels: width of the stem conv and residual tower.
        num_res_blocks: number of ResBlock layers in the tower.
        num_moves: size of the policy output. NOTE: ChessDataset's move lookup
            only emits labels 0..4543, so with the default 4672 the last 128
            logits are never a training target.
    """

    def __init__(self, in_channels=20, channels=256, num_res_blocks=4, num_moves=4672):
        super().__init__()
        board_cells = 8 * 8
        policy_channels = 128
        value_channels = 64
        policy_hidden = num_moves * 2

        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.res_blocks = nn.Sequential(*[ResBlock(channels) for _ in range(num_res_blocks)])

        self.policy_head = nn.Sequential(
            nn.Conv2d(channels, policy_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(policy_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(policy_channels * board_cells, policy_hidden),
            nn.ReLU(),
            nn.Dropout(0.3),  # regularization for the large hidden layer
            nn.Linear(policy_hidden, num_moves)
        )

        self.value_head = nn.Sequential(
            nn.Conv2d(channels, value_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(value_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(value_channels * board_cells, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh()  # bounds the value prediction to [-1, 1], matching the targets
        )

    def forward(self, x):
        """Return ``(policy_logits, value)`` of shapes (batch, num_moves) and (batch, 1)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.res_blocks(x)
        policy = self.policy_head(x)
        value = self.value_head(x)
        return policy, value


# TRAINING

In [93]:
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm  # notebook widget inside Jupyter, console bar elsewhere
import numpy as np

# --- configuration -----------------------------------------------------------
SEED = 42
LEARNING_RATE = 1e-3
EPOCHS = 20
patience = 5  # epochs without val-accuracy improvement before early stopping
BEST_CKPT_PATH = "chess_dualhead_best.pth"
FINAL_CKPT_PATH = "chess_dualhead_final.pth"

# Weight init and DataLoader shuffling both draw from torch's global RNG.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DualHeadChessNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion_policy = nn.CrossEntropyLoss()
criterion_value = nn.MSELoss()

# Halve the LR when validation accuracy (a quantity to maximize) plateaus.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

# -inf (not 0.0) so the first epoch always writes a "best" checkpoint, even if
# its accuracy is exactly 0%.
best_val_acc = float("-inf")
early_stop_counter = 0


def _to_device(*tensors):
    """Move each tensor to ``device``; non_blocking overlaps the copy when memory is pinned."""
    return [t.to(device, non_blocking=True) for t in tensors]


def _compute_losses(inputs, policy_targets, value_targets):
    """Run ``model`` on one batch already on ``device``.

    Returns ``(policy_logits, value_preds, loss_policy, loss_value, loss)``,
    where ``loss`` is the unweighted sum of the policy cross-entropy and the value MSE.
    """
    policy_logits, value_preds = model(inputs)
    loss_policy = criterion_policy(policy_logits, policy_targets)
    # Predictions and targets are both (batch, 1). Flatten explicitly rather than
    # squeeze(), so a size-1 batch keeps its batch dimension.
    loss_value = criterion_value(value_preds.view(-1), value_targets.view(-1))
    return policy_logits, value_preds, loss_policy, loss_value, loss_policy + loss_value


for epoch in range(EPOCHS):
    # ---- training phase ----
    model.train()
    train_loss = 0.0
    policy_correct = 0
    policy_total = 0

    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]", leave=False)

    for inputs, policy_targets, value_targets in train_bar:
        inputs, policy_targets, value_targets = _to_device(inputs, policy_targets, value_targets)

        optimizer.zero_grad()
        policy_logits, _, loss_policy, loss_value, loss = _compute_losses(
            inputs, policy_targets, value_targets
        )
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.size(0)
        policy_correct += (policy_logits.argmax(dim=1) == policy_targets).sum().item()
        policy_total += policy_targets.size(0)

        train_bar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'policy': f"{loss_policy.item():.4f}",
            'value': f"{loss_value.item():.4f}",
            'acc': f"{100.*policy_correct/policy_total:.1f}%"
        })

    avg_train_loss = train_loss / len(train_loader.dataset)
    train_acc = 100. * policy_correct / policy_total

    # ---- validation phase ----
    model.eval()
    val_loss = 0.0
    val_policy_correct = 0
    val_policy_total = 0
    val_value_mae = 0.0

    val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]", leave=False)

    with torch.no_grad():
        for inputs, policy_targets, value_targets in val_bar:
            inputs, policy_targets, value_targets = _to_device(inputs, policy_targets, value_targets)

            policy_logits, value_preds, loss_policy, loss_value, loss = _compute_losses(
                inputs, policy_targets, value_targets
            )

            val_loss += loss.item() * inputs.size(0)
            val_policy_correct += (policy_logits.argmax(dim=1) == policy_targets).sum().item()
            val_policy_total += policy_targets.size(0)
            val_value_mae += (value_preds.view(-1) - value_targets.view(-1)).abs().sum().item()

            val_bar.set_postfix({
                'val_loss': f"{loss.item():.4f}",
                'val_acc': f"{100.*val_policy_correct/val_policy_total:.1f}%"
            })

    avg_val_loss = val_loss / len(val_loader.dataset)
    val_acc = 100. * val_policy_correct / val_policy_total
    val_value_mae = val_value_mae / val_policy_total

    # The scheduler tracks validation accuracy (mode='max').
    scheduler.step(val_acc)

    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val MAE: {val_value_mae:.4f}")
    print(f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")
    print("-" * 60)

    # Checkpoint on improvement; otherwise count toward early stopping.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), BEST_CKPT_PATH)
        print(f"New best model saved with val_acc: {val_acc:.2f}%")
        early_stop_counter = 0
    else:
        early_stop_counter += 1

    if early_stop_counter >= patience:
        print(f"Early stopping triggered at epoch {epoch+1}!")
        break

# The final checkpoint holds the last-epoch weights, which may differ from the best ones.
torch.save(model.state_dict(), FINAL_CKPT_PATH)
print("Training complete!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")

Epoch 1/20 [Train]:  60%|█████▉    | 915/1535 [05:31<05:55,  1.74it/s, loss=3.3537, policy=3.3037, value=0.0500, acc=12.3%]