# Chess RL - AlphaZero Training

Train a lightweight AlphaZero-style chess engine using TensorFlow/Keras.

**Features:**
- 781-dimensional input (bitboards + castling + en passant + side to move)
- Lightweight Dense network (~1.4M parameters)
- MCTS with PUCT selection
- Parallel self-play with batched GPU inference

## 1. Setup

In [5]:
# Sanity check: report the TensorFlow version and which GPUs TensorFlow can see
# before any expensive self-play starts.
import tensorflow as tf

tf_version = tf.__version__
print(f"TensorFlow version: {tf_version}")
visible_gpus = tf.config.list_physical_devices('GPU')
print(f"GPU available: {visible_gpus}")

TensorFlow version: 2.18.0
GPU available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


## 2. Configuration

In [None]:
import sys, os

# Make the project root (parent of the notebook's working directory) importable
# for `config` and `src.*`. Guarded so re-running this cell does not keep
# appending duplicate entries to sys.path.
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from config import Config

# Configuration optimized for Colab GPU
config = Config()

# Checkpoint location: ../checkpoint, relative to the notebook's working directory
config.checkpoint_dir = '../checkpoint'

# Training parameters
config.num_simulations = 50          # MCTS simulations per move
config.games_per_iteration = 64      # Games per iteration (more = better GPU usage)
config.training_steps = 200          # Training steps per iteration
config.buffer_size = 100000          # Replay buffer size
config.max_moves = 200               # Max moves per game

# Warmup settings (first 10 iterations). The warmup length itself is not set
# here; presumably it is a Config default — confirm in config.py.
# NOTE(review): a recorded run generated 64 games in iteration 1, matching
# warmup_games; confirm which of games_per_iteration / warmup_games / main_games
# the trainer reads in each phase.
config.warmup_simulations = 30
config.warmup_games = 64
config.main_games = 64

# Parallel self-play settings
NUM_PARALLEL = 32                    # Games to run in parallel
NUM_ITERATIONS = 50                  # Total training iterations

print("Configuration:")
print(f"  Simulations per move: {config.num_simulations}")
print(f"  Games per iteration: {config.games_per_iteration}")
print(f"  Parallel games: {NUM_PARALLEL}")
print(f"  Total iterations: {NUM_ITERATIONS}")
print(f"  Checkpoints: {config.checkpoint_dir}")

Configuration:
  Simulations per move: 50
  Games per iteration: 128
  Parallel games: 32
  Total iterations: 50
  Checkpoints: ../checkpoint


## 3. Initialize Trainer

In [7]:
import os
from src.training.trainer import Trainer

# Create the checkpoint directory (no-op if it already exists)
os.makedirs(config.checkpoint_dir, exist_ok=True)

# Initialize the trainer with parallel self-play. The number of concurrent
# games comes from NUM_PARALLEL in the configuration cell.
trainer = Trainer(
    config,
    config.checkpoint_dir,
    num_parallel=NUM_PARALLEL,    # concurrent self-play games (see config cell)
    use_parallel=True,            # enable batched GPU inference
)

print(f"Network has {trainer.network.trainable_params:,} trainable parameters")
trainer.network.summary()

Network has 1,371,141 trainable parameters


## 4. Training Loop

In [None]:
# Run the training loop; the iteration count comes from NUM_ITERATIONS in the
# configuration cell.
# NOTE(review): the "Each iteration" line reports config.games_per_iteration,
# but during warmup the trainer may generate config.warmup_games instead
# (a recorded run logged 64 games in iteration 1) — confirm.
print(f"Starting training for {NUM_ITERATIONS} iterations...")
print(f"Each iteration: {config.games_per_iteration} games, {config.training_steps} training steps")
print(f"Using parallel self-play with {NUM_PARALLEL} concurrent games\n")

trainer.train(num_iterations=NUM_ITERATIONS, show_progress=True)

Starting training for 50 iterations...
Each iteration: 128 games, 200 training steps
Using parallel self-play with 32 concurrent games

Starting training for 50 iterations
Network has 1,371,141 trainable parameters

Iteration 1: Generating 64 self-play games (parallel)...




KeyboardInterrupt: 

## 5. Test the Model

In [None]:
from src.game.chess_game import ChessGame
from src.mcts.mcts import MCTS

# Play a test game against itself
def play_test_game(network, num_simulations=100, max_moves=100, temperature=0.1):
    """Play one MCTS-driven game from the starting position and record the moves.

    Args:
        network: Network passed to ``MCTS`` for move selection.
        num_simulations: MCTS simulations per move.
        max_moves: Stop once ``game.move_count`` reaches this value, even if
            the game is not over.
        temperature: Temperature passed to ``MCTS.get_action``.

    Returns:
        ``(moves, outcome)``: ``moves`` is a list of SAN strings (falling back
        to UCI when SAN cannot be produced); ``outcome`` is whatever
        ``game.get_outcome()`` returns for the final position.
    """
    game = ChessGame()
    mcts = MCTS(network, num_simulations=num_simulations)

    moves = []
    while not game.is_terminal() and game.move_count < max_moves:
        action, _ = mcts.get_action(game, temperature=temperature)
        if action < 0:
            # Negative action is treated as "no move chosen" — presumably the
            # MCTS sentinel; confirm in src/mcts.
            break
        move = game.move_encoder.decode(action)

        # SAN must be computed against the board *before* the move is applied.
        # parse_uci raises ValueError for malformed or illegal moves; only that
        # case falls back to UCI, so interrupts and genuine bugs still surface.
        try:
            san = game.board.san(game.board.parse_uci(move.uci()))
        except ValueError:
            san = move.uci()
        moves.append(san)
        game.apply_move_index(action)

    return moves, game.get_outcome()

moves, outcome = play_test_game(trainer.network, num_simulations=100)

# NOTE(review): the labels assume a positive outcome means White won and a
# negative one means Black won — confirm get_outcome()'s sign convention.
if outcome > 0:
    result_label = 'White wins'
elif outcome < 0:
    result_label = 'Black wins'
else:
    result_label = 'Draw'
print(f"Game result: {result_label}")
print(f"Total moves: {len(moves)}")

shown_moves = ' '.join(moves[:50])
truncation_marker = '...' if len(moves) > 50 else ''
print(f"Moves: {shown_moves}{truncation_marker}")

In [None]:
# Show the position reached after up to NUM_SAMPLE_MOVES MCTS-chosen moves
# (fewer if the game ends or MCTS returns no move first).
NUM_SAMPLE_MOVES = 20

game = ChessGame()
mcts = MCTS(trainer.network, num_simulations=100)

for _move_idx in range(NUM_SAMPLE_MOVES):
    if game.is_terminal():
        break
    action, _ = mcts.get_action(game, temperature=0.1)
    if action < 0:
        # Same "no move chosen" check as play_test_game: stop instead of
        # re-running MCTS on the unchanged position.
        break
    game.apply_move_index(action)

print(game)

## 6. Save Final Model

In [None]:
# Persist the trained network in two forms under the checkpoint directory.
final_path = os.path.join(config.checkpoint_dir, 'model_final')
keras_path = os.path.join(config.checkpoint_dir, 'model_final.keras')

# NOTE(review): the message assumes network.save() appends '.weights.h5'
# to the given path — confirm in the network class.
trainer.network.save(final_path)
print(f"Final model saved to: {final_path}.weights.h5")

# Full Keras model (architecture + weights) in the native .keras format
trainer.network.save_full_model(keras_path)
print(f"Full Keras model saved to: {keras_path}")

## 7. Resume Training (Optional)

In [None]:
# To resume training from the latest checkpoint, set RESUME_TRAINING = True and
# run this cell. It is kept behind a flag (rather than as commented-out code)
# so it stays syntax-checked and a "Run All" never retrains by accident.
RESUME_TRAINING = False
RESUME_ITERATIONS = 20  # additional iterations to run after loading

if RESUME_TRAINING:
    trainer.load_checkpoint()  # Loads latest checkpoint
    trainer.train(num_iterations=RESUME_ITERATIONS)  # Continue training