# Chess Deep Learning - Google Colab Training

**Setup Instructions:**
1. Upload `shards_backup.tar.gz` and `src_code.tar.gz` to your Google Drive in a folder called `chess_training`
2. Runtime → Change runtime type → GPU (T4 or better)
3. Run all cells

**Expected training time:** 8–12 hours on a T4 GPU

In [None]:
# Mount Google Drive
# The mount point must stay '/content/drive': the extraction cell reads its
# archives from DRIVE_FOLDER, which is rooted at /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install dependencies
# NOTE(review): `chess` and `python-chess` appear to be the same library
# published under two PyPI names — confirm whether both are really needed.
!pip install chess python-chess tqdm matplotlib -q

In [None]:
# Extract data and code from Drive
import os
import subprocess
from pathlib import Path

# Set paths - MODIFY THIS if your folder name is different
DRIVE_FOLDER = '/content/drive/MyDrive/chess_training'

# Local (ephemeral) working copy of the project on the Colab VM.
PROJECT_ROOT = Path('/content/chess_project')


def _extract_archive(archive, destination):
    """Unpack a .tar.gz ``archive`` into ``destination``, failing loudly.

    Args:
        archive: path to the .tar.gz file to unpack.
        destination: existing directory that tar extracts into (``-C``).

    Raises:
        FileNotFoundError: if ``archive`` does not exist.
        RuntimeError: if ``tar`` exits with a non-zero status; the message
            carries tar's stderr so the cause is visible in the notebook.
    """
    archive = Path(archive)
    if not archive.is_file():
        raise FileNotFoundError(
            f"Missing archive: {archive}. Upload it to {DRIVE_FOLDER} "
            "(or update DRIVE_FOLDER) and re-run this cell."
        )
    # An argument list (no shell) keeps folder names with spaces intact.
    result = subprocess.run(
        ['tar', '-xzf', str(archive), '-C', str(destination)],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(f"tar failed for {archive}:\n{result.stderr.strip()}")


# Create working directories
for subdir in (
    'artifacts/data/shards',
    'artifacts/weights',
    'artifacts/logs',
    'reports/figures',
):
    (PROJECT_ROOT / subdir).mkdir(parents=True, exist_ok=True)

# Extract source code
print("Extracting source code...")
_extract_archive(Path(DRIVE_FOLDER) / 'src_code.tar.gz', PROJECT_ROOT)

# Extract shards
print("Extracting training data (this may take 2-3 minutes)...")
_extract_archive(Path(DRIVE_FOLDER) / 'shards_backup.tar.gz', PROJECT_ROOT)

# The later cells import from <project>/src and read shards from
# artifacts/data/shards/{train,val}; warn early if the archives unpacked
# into a different layout instead of failing minutes later.
for expected in ('src', 'artifacts/data/shards/train', 'artifacts/data/shards/val'):
    if not (PROJECT_ROOT / expected).exists():
        print(f"⚠️  WARNING: expected {PROJECT_ROOT / expected} after extraction, but it is missing")

print("✓ Setup complete!")
# Show the resulting layout (captured and printed so it appears in the cell output).
print(
    subprocess.run(
        ['ls', '-lh', str(PROJECT_ROOT)],
        capture_output=True,
        text=True,
    ).stdout,
    end='',
)

In [None]:
# Verify GPU is available
import torch

# Query CUDA once and reuse the answer for both the report and the branch.
gpu_ready = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {gpu_ready}")
if not gpu_ready:
    print("⚠️  WARNING: No GPU detected! Go to Runtime → Change runtime type → GPU")
else:
    # Report the name and total memory of device 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
# Import modules
import sys
sys.path.append('/content/chess_project/src')

import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from tqdm import tqdm
import json
import matplotlib.pyplot as plt

from data.shard_dataset import create_shard_dataloaders
from model.nets import MiniResNetPolicyValue, initialize_weights
from model.loss import PolicyValueLoss
from utils.metrics import policy_top_k_accuracy
from utils.seeds import set_seed

# NOTE(review): presumably seeds the Python/NumPy/PyTorch RNGs for
# reproducibility — confirm against utils/seeds.py.
set_seed(42)

# Directories
# Everything lives under the extracted project root; artifact folders share
# one parent, so derive them from it instead of repeating the prefix.
BASE_DIR = Path('/content/chess_project')
_artifacts_root = BASE_DIR / 'artifacts'

DATA_DIR = _artifacts_root / 'data'
SHARD_DIR = DATA_DIR / 'shards'
WEIGHTS_DIR = _artifacts_root / 'weights'
LOGS_DIR = _artifacts_root / 'logs'
FIGURES_DIR = BASE_DIR / 'reports' / 'figures'

print("✓ Imports successful")

In [None]:
# Configuration - OPTIMIZED FOR COLAB GPU
# Probe CUDA once so the device choice and the AMP switch can never disagree.
_cuda_available = torch.cuda.is_available()

CONFIG = {
    # Model
    'model_type': 'miniresnet',
    'num_blocks': 6,
    'channels': 64,
    'train_value_head': True,

    # Training
    'batch_size': 512,  # Larger batch size for GPU
    'num_epochs': 10,
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    'lr_schedule': 'cosine_warmup',
    'warmup_epochs': 2,

    # Loss
    'policy_smoothing': 0.05,
    'value_weight': 0.35,

    # Optimization
    # Mixed precision here goes through torch.cuda.amp, so it is only switched
    # on when a CUDA device is actually present; set to False to disable it.
    'use_amp': _cuda_available,
    'gradient_clip': 1.0,

    # Data
    'use_shards': True,
    'phase_balanced': True,
    'augment_train': True,

    # Device
    'device': 'cuda' if _cuda_available else 'cpu',
}

print("=" * 70)
print("TRAINING CONFIGURATION")
print("=" * 70)
print(f"Device: {CONFIG['device']}")
print(f"Model: {CONFIG['model_type']} ({CONFIG['num_blocks']}×{CONFIG['channels']})")
print(f"Batch size: {CONFIG['batch_size']} (optimized for GPU)")
print(f"Epochs: {CONFIG['num_epochs']}")
print(f"AMP: {CONFIG['use_amp']}")
print("=" * 70)

# torch.device handle used by the model, data transfer and training cells.
device = torch.device(CONFIG['device'])

In [None]:
# Load data
print("Loading training data...")

# Collect the loader options first so the factory call stays a one-liner.
loader_options = {
    'train_shard_dir': SHARD_DIR / 'train',
    'val_shard_dir': SHARD_DIR / 'val',
    'batch_size': CONFIG['batch_size'],
    'num_workers': 2,  # Colab has good I/O
    'phase_balanced': CONFIG['phase_balanced'],
    'augment_train': CONFIG['augment_train'],
    'pin_memory': True,
}
train_loader, val_loader = create_shard_dataloaders(**loader_options)

print(f"✓ Train batches: {len(train_loader)}")
print(f"✓ Val batches: {len(val_loader)}")

In [None]:
# Create model
# Backbone size comes from CONFIG; head widths and dropout are fixed here.
model_options = {
    'num_blocks': CONFIG['num_blocks'],
    'channels': CONFIG['channels'],
    'policy_head_hidden': 512,
    'value_head_hidden': 256,
    'dropout': 0.1,
}
model = MiniResNetPolicyValue(**model_options)

# Initialize weights first, then move the model to the target device.
initialize_weights(model)
model = model.to(device)

param_count = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {param_count:,}")

In [None]:
# Training setup
# Loss module, configured from the 'Loss' section of CONFIG.
loss_options = {
    'value_weight': CONFIG['value_weight'],
    'policy_smoothing': CONFIG['policy_smoothing'],
}
criterion = PolicyValueLoss(**loss_options)

# AdamW over every model parameter, with LR and weight decay from CONFIG.
optimizer_options = {
    'lr': CONFIG['learning_rate'],
    'weight_decay': CONFIG['weight_decay'],
}
optimizer = torch.optim.AdamW(model.parameters(), **optimizer_options)

# LR scheduler
def warmup_cosine_schedule(epoch, warmup_epochs=None, num_epochs=None):
    """Return the learning-rate multiplier for a 0-based ``epoch``.

    The schedule ramps linearly up to 1.0 over the warmup epochs, then
    follows a half cosine from 1.0 down to 0.0 at ``num_epochs``.

    Args:
        epoch: 0-based epoch index, as passed by ``LambdaLR``.
        warmup_epochs: number of warmup epochs; defaults to
            ``CONFIG['warmup_epochs']``.
        num_epochs: total number of epochs; defaults to
            ``CONFIG['num_epochs']``.

    Returns:
        The factor the base learning rate is multiplied by.
    """
    if warmup_epochs is None:
        warmup_epochs = CONFIG['warmup_epochs']
    if num_epochs is None:
        num_epochs = CONFIG['num_epochs']

    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs

    decay_epochs = num_epochs - warmup_epochs
    if decay_epochs <= 0:
        # No decay phase configured: the only calls that reach this point come
        # after the last training epoch, so treat the schedule as finished
        # instead of dividing by zero.
        progress = 1.0
    else:
        # Clamp so stepping past num_epochs holds the final value rather than
        # letting the cosine swing the learning rate back up.
        progress = min((epoch - warmup_epochs) / decay_epochs, 1.0)
    return 0.5 * (1.0 + np.cos(np.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_schedule)

# AMP scaler
# train_epoch takes its mixed-precision branch whenever `scaler` is not None,
# so only build a GradScaler when training really runs on a CUDA device.
# On a CPU fallback this leaves scaler as None and selects the plain
# full-precision path.
amp_enabled = CONFIG['use_amp'] and device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler() if amp_enabled else None

print("✓ Training setup complete")

In [None]:
# Training functions
def train_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Run one training pass over ``loader`` and return averaged metrics.

    Args:
        model: network called as ``model(board, return_value=True)``; it must
            return a 3-tuple whose first item is the policy logits and whose
            last item is the value prediction (the middle item is ignored).
        loader: iterable yielding ``(board, move_target, value_target)``
            batches.
        criterion: callable returning ``(loss, loss_dict)`` where
            ``loss_dict['total_loss']`` is the scalar logged for the batch.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        scaler: optional ``GradScaler``; when given, the forward pass runs
            under CUDA autocast and the scaler drives backward/step.

    Returns:
        dict with ``'loss'``, ``'top1_acc'`` and ``'top5_acc'`` as plain
        floats, each the mean of the per-batch values.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    use_amp = scaler is not None
    total_loss = 0.0
    total_top1_acc = 0.0
    total_top5_acc = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc="Training")
    for board, move_target, value_target in pbar:
        # non_blocking lets the copy overlap with compute when the loader
        # provides pinned memory; otherwise it behaves like a normal copy.
        board = board.to(device, non_blocking=True)
        move_target = move_target.to(device, non_blocking=True)
        value_target = value_target.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Single forward pass; autocast is a no-op when AMP is off.
        with torch.autocast(device_type='cuda', enabled=use_amp):
            policy_logits, _, value_pred = model(board, return_value=True)
            loss, loss_dict = criterion(policy_logits, value_pred, move_target, value_target)

        if use_amp:
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to the true
            # gradient magnitudes, not the loss-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['gradient_clip'])
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['gradient_clip'])
            optimizer.step()

        # float() keeps the running sums as plain Python numbers whether the
        # helpers return floats or 0-dim tensors, so the averages stay
        # JSON-serializable and never hold on to autograd state.
        total_loss += float(loss_dict['total_loss'])
        top1_acc = float(policy_top_k_accuracy(policy_logits, move_target, k=1))
        top5_acc = float(policy_top_k_accuracy(policy_logits, move_target, k=5))
        total_top1_acc += top1_acc
        total_top5_acc += top5_acc
        num_batches += 1

        pbar.set_postfix({'loss': f"{loss.item():.4f}", 'top1': f"{top1_acc:.3f}"})

    if num_batches == 0:
        raise ValueError("train_epoch() received a loader that produced no batches")

    # Average over the batches actually processed.
    return {
        'loss': total_loss / num_batches,
        'top1_acc': total_top1_acc / num_batches,
        'top5_acc': total_top5_acc / num_batches,
    }


def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    The forward pass runs in full precision (no autocast). Note that this
    reports top-3 accuracy, whereas ``train_epoch`` reports top-5.

    Args:
        model: network called as ``model(board, return_value=True)``; it must
            return a 3-tuple whose first item is the policy logits and whose
            last item is the value prediction (the middle item is ignored).
        loader: iterable yielding ``(board, move_target, value_target)``
            batches.
        criterion: callable returning ``(loss, loss_dict)`` where
            ``loss_dict['total_loss']`` is the scalar logged for the batch.
        device: device the batch tensors are moved to.

    Returns:
        dict with ``'loss'``, ``'top1_acc'`` and ``'top3_acc'`` as plain
        floats, each the mean of the per-batch values.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_top1_acc = 0.0
    total_top3_acc = 0.0
    num_batches = 0

    with torch.no_grad():
        for board, move_target, value_target in tqdm(loader, desc="Validation"):
            # non_blocking lets the copy overlap with compute when the loader
            # provides pinned memory; otherwise it behaves like a normal copy.
            board = board.to(device, non_blocking=True)
            move_target = move_target.to(device, non_blocking=True)
            value_target = value_target.to(device, non_blocking=True)

            policy_logits, _, value_pred = model(board, return_value=True)
            loss, loss_dict = criterion(policy_logits, value_pred, move_target, value_target)

            # float() keeps the sums as plain Python numbers whether the
            # helpers return floats or 0-dim tensors.
            total_loss += float(loss_dict['total_loss'])
            total_top1_acc += float(policy_top_k_accuracy(policy_logits, move_target, k=1))
            total_top3_acc += float(policy_top_k_accuracy(policy_logits, move_target, k=3))
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate() received a loader that produced no batches")

    # Average over the batches actually processed.
    return {
        'loss': total_loss / num_batches,
        'top1_acc': total_top1_acc / num_batches,
        'top3_acc': total_top3_acc / num_batches,
    }

print("✓ Training functions defined")

In [None]:
# Main training loop
# Per-epoch metrics; 'learning_rates' records the LR that was in effect
# during each epoch (captured before the scheduler steps).
history = {
    'train_loss': [],
    'val_loss': [],
    'train_acc': [],
    'val_acc': [],
    'learning_rates': [],
}

best_val_acc = 0.0
best_epoch = None  # None until the first checkpoint has been written

# The Colab disk disappears if the session is disconnected, so every new best
# checkpoint is also mirrored to Drive instead of waiting for the final cell.
drive_best_path = Path(DRIVE_FOLDER) / 'best_model.pth'

print("\n" + "=" * 70)
print("STARTING TRAINING")
print("=" * 70)

for epoch in range(1, CONFIG['num_epochs'] + 1):
    print(f"\n=== Epoch {epoch}/{CONFIG['num_epochs']} ===")

    # Train
    train_metrics = train_epoch(model, train_loader, criterion, optimizer, device, scaler)
    print(f"Train - Loss: {train_metrics['loss']:.4f}, Top-1: {train_metrics['top1_acc']:.4f}, Top-5: {train_metrics['top5_acc']:.4f}")

    # Validate
    val_metrics = evaluate(model, val_loader, criterion, device)
    print(f"Val   - Loss: {val_metrics['loss']:.4f}, Top-1: {val_metrics['top1_acc']:.4f}, Top-3: {val_metrics['top3_acc']:.4f}")

    # Update history
    history['train_loss'].append(train_metrics['loss'])
    history['val_loss'].append(val_metrics['loss'])
    history['train_acc'].append(train_metrics['top1_acc'])
    history['val_acc'].append(val_metrics['top1_acc'])
    history['learning_rates'].append(optimizer.param_groups[0]['lr'])

    # Save best model
    # The first epoch always writes a checkpoint, so best_model.pth exists for
    # the later cells even if validation accuracy never rises above 0.0.
    if best_epoch is None or val_metrics['top1_acc'] > best_val_acc:
        best_val_acc = val_metrics['top1_acc']
        best_epoch = epoch
        torch.save(model.state_dict(), WEIGHTS_DIR / 'best_model.pth')
        print(f"✓ Saved best model (val_acc={best_val_acc:.4f})")
        # Best-effort Drive backup: a Drive hiccup must not abort training.
        try:
            torch.save(model.state_dict(), drive_best_path)
        except (OSError, RuntimeError) as err:
            print(f"⚠️  WARNING: could not back up best model to Drive: {err}")

    # LR schedule
    scheduler.step()
    # After the step this is the LR the NEXT epoch will use.
    print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")

# Save final model
torch.save(model.state_dict(), WEIGHTS_DIR / 'final_model.pth')

print("\n" + "=" * 70)
print("TRAINING COMPLETE")
print("=" * 70)
print(f"Best validation accuracy: {best_val_acc:.4f} ({best_val_acc*100:.2f}%)")
print("=" * 70)

In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# One row per panel: axis, train series key, val series key, y label, title.
panels = (
    (ax1, 'train_loss', 'val_loss', 'Loss', 'Training Loss'),
    (ax2, 'train_acc', 'val_acc', 'Top-1 Accuracy', 'Training Accuracy'),
)
for axis, train_key, val_key, y_label, panel_title in panels:
    axis.plot(history[train_key], label='Train')
    axis.plot(history[val_key], label='Val')
    axis.set_xlabel('Epoch')
    axis.set_ylabel(y_label)
    axis.set_title(panel_title)
    axis.legend()
    axis.grid(True)

plt.tight_layout()
# Write the figure to disk before showing it.
plt.savefig(FIGURES_DIR / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Saved training curves")

In [None]:
# Save training log
# Bundle the run configuration, per-epoch history and best score in one file.
log_path = LOGS_DIR / 'training_log.json'
log = dict(config=CONFIG, history=history, best_val_acc=best_val_acc)

with log_path.open('w') as log_file:
    json.dump(log, log_file, indent=2)

print("✓ Saved training log")
print(f"\nFinal Results:")
print(f"  Best validation accuracy: {best_val_acc:.4f} ({best_val_acc*100:.2f}%)")
print(f"  Model saved to: {WEIGHTS_DIR / 'best_model.pth'}")

In [None]:
# Copy trained model back to Google Drive
import shutil

print("Copying trained model to Google Drive...")

# (source file, description shown in the download list)
artifacts = [
    (WEIGHTS_DIR / 'best_model.pth', 'trained model weights'),
    (WEIGHTS_DIR / 'final_model.pth', 'model weights after the final epoch'),
    (LOGS_DIR / 'training_log.json', 'training history'),
    (FIGURES_DIR / 'training_curves.png', 'visualization'),
]

drive_dir = Path(DRIVE_FOLDER)
copied = []
failed = []
for source, description in artifacts:
    # Copy each file independently so one missing or unwritable file does not
    # stop the others from reaching Drive. shutil works on path objects, so
    # folder names containing spaces are handled correctly.
    try:
        shutil.copy2(source, drive_dir / source.name)
    except OSError as err:
        failed.append((source, err))
    else:
        copied.append((source.name, description))

if failed:
    print("\n⚠️  WARNING: some files could NOT be copied to Google Drive:")
    for source, err in failed:
        print(f"  - {source}: {err}")
else:
    print("\n✓ Training complete! All files saved to Google Drive.")

if copied:
    print("\nDownload these files from your Drive:")
    for file_name, description in copied:
        print(f"  - {file_name} ({description})")