# Chess Deep Learning Agent: Supervised Training

Train policy and value networks via imitation learning from expert games.

In [1]:
import sys
sys.path.append('../src')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import json

# NEW: Import shard-based dataset loader
from data.shard_dataset import create_shard_dataloaders
from model.nets import MLPPolicy, MLPPolicyValue, CNNPolicyValue, MiniResNetPolicyValue, initialize_weights
from model.loss import PolicyValueLoss
from utils.metrics import policy_top_k_accuracy
from utils.plotting import plot_training_curves
from utils.seeds import set_seed

set_seed(42)

# Directories
DATA_DIR = Path('../artifacts/data')
SHARD_DIR = DATA_DIR / 'shards'  # NEW: Shard directory
WEIGHTS_DIR = Path('../artifacts/weights')
WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
LOGS_DIR = Path('../artifacts/logs')
LOGS_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR = Path('../reports/figures')
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

## Configuration

In [2]:
# Training config - UPDATED for strength improvements
CONFIG = {
    # Model architecture
    'model_type': 'miniresnet',  # 'mlp', 'cnn', 'miniresnet'
    'num_blocks': 6,  # Consider 8 for more capacity with 1M+ data
    'channels': 64,   # Consider 96 for more capacity with 1M+ data
    'train_value_head': True,

    # Training
    'batch_size': 256,
    'num_epochs': 10,  # CHANGED: Reduced from 20 (more data = fewer epochs needed)
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    'lr_schedule': 'cosine_warmup',  # CHANGED: Added warmup
    'warmup_epochs': 2,  # NEW: Warmup for first 2 epochs

    # Loss - UPDATED values
    'policy_smoothing': 0.05,
    'value_weight': 0.35,  # CHANGED: from 0.7 to 0.35 (focus on policy first)

    # Optimization - NEW
    'use_amp': True,  # Automatic Mixed Precision for MPS
    'channels_last': True,  # Memory format optimization for CNNs
    'gradient_clip': 1.0,

    # Data - NEW
    'use_shards': True,  # Use shard-based loading
    'phase_balanced': True,  # Phase-balanced sampling
    'augment_train': True,  # Augmentation (file flips)

    # Device
    'device': 'mps' if torch.backends.mps.is_available() else 'cpu',
}

print("=" * 70)
print("TRAINING CONFIGURATION")
print("=" * 70)
print(f"Device: {CONFIG['device']}")
print(f"Model: {CONFIG['model_type']} ({CONFIG['num_blocks']}×{CONFIG['channels']})")
print(f"Value head: {'ENABLED' if CONFIG['train_value_head'] else 'DISABLED'}")
print(f"Epochs: {CONFIG['num_epochs']} (warmup: {CONFIG['warmup_epochs']})")
print(f"Batch size: {CONFIG['batch_size']}")
print(f"Learning rate: {CONFIG['learning_rate']}")
print(f"Policy smoothing (ε): {CONFIG['policy_smoothing']}")
print(f"Value loss weight (λ): {CONFIG['value_weight']}")
print(f"AMP: {CONFIG['use_amp']}")
print(f"Channels-last: {CONFIG['channels_last']}")
print(f"Phase-balanced sampling: {CONFIG['phase_balanced']}")
print(f"Augmentation: {CONFIG['augment_train']}")
print("=" * 70)

device = torch.device(CONFIG['device'])

TRAINING CONFIGURATION
Device: cpu
Model: miniresnet (6×64)
Value head: ENABLED
Epochs: 10 (warmup: 2)
Batch size: 256
Learning rate: 0.001
Policy smoothing (ε): 0.05
Value loss weight (λ): 0.35
AMP: True
Channels-last: True
Phase-balanced sampling: True
Augmentation: True


<cell_type>markdown</cell_type>## Load Data (Shard-Based or CSV)

In [3]:
# Load data: Check if shards exist, otherwise fall back to CSV
if CONFIG['use_shards'] and (SHARD_DIR / 'train').exists() and (SHARD_DIR / 'val').exists():
    print("Loading from sharded data...")
    
    train_loader, val_loader = create_shard_dataloaders(
        train_shard_dir=SHARD_DIR / 'train',
        val_shard_dir=SHARD_DIR / 'val',
        batch_size=CONFIG['batch_size'],
        num_workers=0,  # Mac compatibility
        phase_balanced=CONFIG['phase_balanced'],
        augment_train=CONFIG['augment_train'],
        pin_memory=True,
    )
    
    print(f"✓ Loaded shard-based dataloaders")
    print(f"  Train batches: {len(train_loader)}")
    print(f"  Val batches: {len(val_loader)}")
    
elif (DATA_DIR / 'train.csv.gz').exists() and (DATA_DIR / 'val.csv.gz').exists():
    print("Shards not found. Loading from CSV (legacy mode)...")
    print("⚠️  Warning: CSV mode uses smaller dataset (100k positions)")
    print("   For best results, run stream sampler to create shards with 1M+ positions")
    
    from data.dataset import create_dataloaders
    
    train_df = pd.read_csv(DATA_DIR / 'train.csv.gz', compression='gzip')
    val_df = pd.read_csv(DATA_DIR / 'val.csv.gz', compression='gzip')
    
    print(f"Train: {len(train_df):,} positions")
    print(f"Val:   {len(val_df):,} positions")
    
    train_loader, val_loader = create_dataloaders(
        train_df,
        val_df,
        batch_size=CONFIG['batch_size'],
        num_workers=0,
        include_auxiliary=False,
    )
    
else:
    raise FileNotFoundError(
        f"No training data found!\n"
        f"Expected either:\n"
        f"  1. Shards: {SHARD_DIR}/train/ and {SHARD_DIR}/val/\n"
        f"  2. CSV: {DATA_DIR}/train.csv.gz and {DATA_DIR}/val.csv.gz\n\n"
        f"Run stream sampler to create shards:\n"
        f"  python -m src.data.stream_sampler --pgn-dir <path> --target 1000000"
    )

Loading from sharded data...
CREATING SHARD-BASED DATALOADERS
Found 57 shards in ../artifacts/data/shards/train
Total positions: 2,829,318
Found 18 shards in ../artifacts/data/shards/val
Total positions: 893,556
Creating phase-balanced sampler...
Train loader: 11052 batches
Val loader: 3491 batches
Phase-balanced sampling: True
Augmentation: True
✓ Loaded shard-based dataloaders
  Train batches: 11052
  Val batches: 3491


## Create Model

In [4]:
# Create model with value head support
if CONFIG['model_type'] == 'mlp':
    if CONFIG['train_value_head']:
        model = MLPPolicyValue(
            hidden_dims=(1024, 512, 512),
            policy_head_hidden=512,
            value_head_hidden=256,
            dropout=0.3,
        )
    else:
        model = MLPPolicy()
elif CONFIG['model_type'] == 'cnn':
    model = CNNPolicyValue(
        num_channels=CONFIG['channels'],
        num_layers=4,
        policy_head_hidden=512,
        value_head_hidden=256,
    )
elif CONFIG['model_type'] == 'miniresnet':
    model = MiniResNetPolicyValue(
        num_blocks=CONFIG['num_blocks'],
        channels=CONFIG['channels'],
        policy_head_hidden=512,
        value_head_hidden=256,
        dropout=0.1,
    )
else:
    raise ValueError(f"Unknown model type: {CONFIG['model_type']}")

initialize_weights(model)
model = model.to(device)

# NEW: Apply channels_last memory format for CNNs (faster on MPS)
if CONFIG['channels_last'] and CONFIG['model_type'] in ['cnn', 'miniresnet']:
    model = model.to(memory_format=torch.channels_last)
    print("✓ Using channels_last memory format")

print(f"\nModel: {CONFIG['model_type']}")
print(f"Architecture: {CONFIG['num_blocks']} blocks × {CONFIG['channels']} channels")
print(f"Value head: {'YES' if CONFIG['train_value_head'] else 'NO'}")
if hasattr(model, 'count_parameters'):
    print(f"Parameters: {model.count_parameters():,}")
else:
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

✓ Using channels_last memory format

Model: miniresnet
Architecture: 6 blocks × 64 channels
Value head: YES
Parameters: 5,995,265


## Training Setup

In [5]:
# Loss function
criterion = PolicyValueLoss(
    value_weight=CONFIG['value_weight'],
    policy_smoothing=CONFIG['policy_smoothing'],
)

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)

# NEW: Learning rate scheduler with warmup
if CONFIG['lr_schedule'] == 'cosine_warmup':
    def warmup_cosine_schedule(epoch):
        """Warmup for first few epochs, then cosine decay."""
        if epoch < CONFIG['warmup_epochs']:
            # Linear warmup
            return (epoch + 1) / CONFIG['warmup_epochs']
        else:
            # Cosine annealing after warmup
            progress = (epoch - CONFIG['warmup_epochs']) / (CONFIG['num_epochs'] - CONFIG['warmup_epochs'])
            return 0.5 * (1.0 + np.cos(np.pi * progress))
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_schedule)
    print(f"✓ Using cosine schedule with {CONFIG['warmup_epochs']}-epoch warmup")
    
elif CONFIG['lr_schedule'] == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=CONFIG['num_epochs'],
    )
    print("✓ Using cosine annealing schedule")
    
elif CONFIG['lr_schedule'] == 'step':
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=3,
        gamma=0.5,
    )
    print("✓ Using step schedule")
else:
    scheduler = None
    print("⚠️  No LR schedule")

# NEW: AMP scaler for mixed precision training
scaler = None
if CONFIG['use_amp']:
    try:
        scaler = torch.cuda.amp.GradScaler()
        print("✓ Using Automatic Mixed Precision (AMP)")
    except:
        print("⚠️  AMP not available on this device")
        CONFIG['use_amp'] = False

✓ Using cosine schedule with 2-epoch warmup
✓ Using Automatic Mixed Precision (AMP)




## Training Loop

In [6]:
# UPDATED training functions with AMP and channels_last support
def train_epoch(model, loader, criterion, optimizer, device, use_amp=False, scaler=None, channels_last=False):
    model.train()
    total_loss = 0.0
    total_policy_loss = 0.0
    total_value_loss = 0.0
    total_top1_acc = 0.0
    total_top5_acc = 0.0

    pbar = tqdm(loader, desc="Training")
    for batch in pbar:
        board, move_target, value_target = batch
        board = board.to(device)
        move_target = move_target.to(device)
        value_target = value_target.to(device)
        
        # Apply channels_last if enabled
        if channels_last:
            board = board.to(memory_format=torch.channels_last)

        optimizer.zero_grad()

        # Forward pass with optional AMP
        if use_amp and scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(board, return_value=CONFIG['train_value_head'])
                if isinstance(outputs, tuple):
                    policy_logits, _, value_pred = outputs
                else:
                    policy_logits = outputs
                    value_pred = None

                loss, loss_dict = criterion(
                    policy_logits,
                    value_pred if CONFIG['train_value_head'] else None,
                    move_target,
                    value_target if CONFIG['train_value_head'] else None,
                )

            # Backward with scaling
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['gradient_clip'])
            scaler.step(optimizer)
            scaler.update()
        else:
            # Standard training without AMP
            outputs = model(board, return_value=CONFIG['train_value_head'])
            if isinstance(outputs, tuple):
                policy_logits, _, value_pred = outputs
            else:
                policy_logits = outputs
                value_pred = None

            loss, loss_dict = criterion(
                policy_logits,
                value_pred if CONFIG['train_value_head'] else None,
                move_target,
                value_target if CONFIG['train_value_head'] else None,
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['gradient_clip'])
            optimizer.step()

        # Metrics
        total_loss += loss_dict['total_loss']
        total_policy_loss += loss_dict['policy_loss']
        if 'value_loss' in loss_dict:
            total_value_loss += loss_dict['value_loss']

        top1_acc = policy_top_k_accuracy(policy_logits, move_target, k=1)
        top5_acc = policy_top_k_accuracy(policy_logits, move_target, k=5)
        total_top1_acc += top1_acc
        total_top5_acc += top5_acc

        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'top1': f"{top1_acc:.3f}",
        })

    n = len(loader)
    return {
        'loss': total_loss / n,
        'policy_loss': total_policy_loss / n,
        'value_loss': total_value_loss / n if total_value_loss > 0 else 0.0,
        'top1_acc': total_top1_acc / n,
        'top5_acc': total_top5_acc / n,
    }

def evaluate(model, loader, criterion, device, channels_last=False):
    model.eval()
    total_loss = 0.0
    total_top1_acc = 0.0
    total_top3_acc = 0.0  # Added top-3 for better tracking

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            board, move_target, value_target = batch
            board = board.to(device)
            move_target = move_target.to(device)
            value_target = value_target.to(device)
            
            if channels_last:
                board = board.to(memory_format=torch.channels_last)

            outputs = model(board, return_value=CONFIG['train_value_head'])
            if isinstance(outputs, tuple):
                policy_logits, _, value_pred = outputs
            else:
                policy_logits = outputs
                value_pred = None

            loss, loss_dict = criterion(
                policy_logits,
                value_pred if CONFIG['train_value_head'] else None,
                move_target,
                value_target if CONFIG['train_value_head'] else None,
            )

            total_loss += loss_dict['total_loss']
            top1_acc = policy_top_k_accuracy(policy_logits, move_target, k=1)
            top3_acc = policy_top_k_accuracy(policy_logits, move_target, k=3)
            total_top1_acc += top1_acc
            total_top3_acc += top3_acc

    n = len(loader)
    return {
        'loss': total_loss / n,
        'top1_acc': total_top1_acc / n,
        'top3_acc': total_top3_acc / n,
    }

In [None]:
# Training history
history = {
    'train_loss': [],
    'val_loss': [],
    'train_acc': [],
    'val_acc': [],
    'learning_rates': [],  # NEW: Track learning rate
}

best_val_acc = 0.0

print("\n" + "=" * 70)
print("STARTING TRAINING")
print("=" * 70)

# Train
for epoch in range(1, CONFIG['num_epochs'] + 1):
    print(f"\n=== Epoch {epoch}/{CONFIG['num_epochs']} ===")

    # Train
    train_metrics = train_epoch(
        model, train_loader, criterion, optimizer, device,
        use_amp=CONFIG['use_amp'],
        scaler=scaler,
        channels_last=CONFIG['channels_last']
    )
    print(f"Train - Loss: {train_metrics['loss']:.4f}, "
          f"Top-1: {train_metrics['top1_acc']:.4f}, "
          f"Top-5: {train_metrics['top5_acc']:.4f}")

    # Validate
    val_metrics = evaluate(model, val_loader, criterion, device, channels_last=CONFIG['channels_last'])
    print(f"Val   - Loss: {val_metrics['loss']:.4f}, "
          f"Top-1: {val_metrics['top1_acc']:.4f}, "
          f"Top-3: {val_metrics['top3_acc']:.4f}")

    # Update history
    history['train_loss'].append(train_metrics['loss'])
    history['val_loss'].append(val_metrics['loss'])
    history['train_acc'].append(train_metrics['top1_acc'])
    history['val_acc'].append(val_metrics['top1_acc'])
    history['learning_rates'].append(optimizer.param_groups[0]['lr'])

    # Save best model
    if val_metrics['top1_acc'] > best_val_acc:
        best_val_acc = val_metrics['top1_acc']
        torch.save(model.state_dict(), WEIGHTS_DIR / 'best_model.pth')
        print(f"✓ Saved best model (val_acc={best_val_acc:.4f})")

    # LR schedule
    if scheduler is not None:
        scheduler.step()
        print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")

# Save final model
torch.save(model.state_dict(), WEIGHTS_DIR / 'final_model.pth')

print("\n" + "=" * 70)
print("TRAINING COMPLETE")
print("=" * 70)
print(f"Best validation accuracy: {best_val_acc:.4f} ({best_val_acc*100:.2f}%)")
print(f"Model saved to: {WEIGHTS_DIR / 'best_model.pth'}")
print("=" * 70)


STARTING TRAINING

=== Epoch 1/10 ===


Training:   0%|▏                                                                                                                  | 19/11052 [14:17<134:38:08, 43.93s/it, loss=7.8936, top1=0.004]

## Visualize Training

In [None]:
plot_training_curves(
    history['train_loss'],
    history['val_loss'],
    history['train_acc'],
    history['val_acc'],
    save_path=FIGURES_DIR / 'training_curves.png',
)

## Save Training Log

In [None]:
# Save training log and history
log = {
    'config': CONFIG,
    'history': history,
    'best_val_acc': best_val_acc,
}

with open(LOGS_DIR / 'training_log.json', 'w') as f:
    json.dump(log, f, indent=2)

# Also save history separately for easy loading in report notebook
with open(LOGS_DIR / 'training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

print("✓ Saved training log")
print(f"✓ Best validation accuracy: {best_val_acc:.4f}")
print(f"✓ Model saved to: {WEIGHTS_DIR / 'best_model.pth'}")

## Summary

✓ Trained policy+value network
✓ Achieved top-1 accuracy: {best_val_acc:.2%}
✓ Saved best model to artifacts/weights/

**Next**: Notebook 03 - Integrate search and test playing