# Chess RL - AlphaZero Training

Train a lightweight AlphaZero-style chess engine using TensorFlow/Keras.

**Features:**
- 781-dimensional input (bitboards + castling + en passant + side to move)
- Lightweight Dense network (~1.5M parameters)
- MCTS with PUCT selection
- Self-play training

## 1. Setup

In [None]:
# Mount Google Drive so that checkpoints written under /content/drive/MyDrive
# (see the configuration cell) are stored on Drive rather than the VM disk.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Clone the repository (or upload the files manually instead).
# Replace YOUR_USERNAME with the account hosting the repo. Assuming the notebook
# starts in /content (Colab's default working directory), the checkout lands in
# /content/chess-rl, which is the path the configuration cell adds to sys.path.
!git clone https://github.com/YOUR_USERNAME/chess-rl.git
%cd chess-rl

In [None]:
# Install dependencies: python-chess (imported as `chess`) and tqdm.
# TensorFlow is not installed here; the next cell imports it and expects it to be
# preinstalled in the runtime.
!pip install -q python-chess tqdm

In [None]:
# Verify GPU is available: report the TensorFlow version and the visible GPUs,
# and warn explicitly when none is found instead of just printing an empty list.
import tensorflow as tf

print(f"TensorFlow version: {tf.__version__}")
gpus = tf.config.list_physical_devices('GPU')
print(f"GPU available: {gpus}")
if not gpus:
    # An empty list means TensorFlow will run on the CPU, which makes self-play
    # and training considerably slower.
    print("WARNING: No GPU detected. In Colab, enable one via Runtime > Change runtime type.")

## 2. Configuration

In [None]:
# Make the cloned repository importable, then load the default configuration
# and scale it down for a Colab session.
import sys

REPO_ROOT = '/content/chess-rl'
sys.path.insert(0, REPO_ROOT)

from config import Config

config = Config()

# Colab-specific overrides, applied in order. Checkpoints go to the mounted
# Google Drive; the workload settings are reduced for Colab's resources.
colab_overrides = {
    'checkpoint_dir': '/content/drive/MyDrive/chess-rl/checkpoints',
    'buffer_size': 50000,         # smaller replay buffer for Colab RAM
    'games_per_iteration': 50,    # fewer self-play games per iteration
    'training_steps': 200,        # fewer training steps per iteration
    'num_simulations': 200,       # fewer MCTS simulations per move
}
for attr_name, attr_value in colab_overrides.items():
    setattr(config, attr_name, attr_value)

print("Configuration loaded!")
print(f"Checkpoints will be saved to: {config.checkpoint_dir}")

## 3. Initialize Training Components

In [None]:
import os
import numpy as np
from tqdm.notebook import tqdm

from src.model.network import ChessNetwork
from src.training.trainer import Trainer
from src.training.self_play import SelfPlay
from src.training.replay_buffer import ReplayBuffer
from src.game.chess_game import ChessGame
from src.mcts.mcts import MCTS

# Ensure the checkpoint directory exists (an existing one is reused), then build
# the trainer around it.
ckpt_dir = config.checkpoint_dir
os.makedirs(ckpt_dir, exist_ok=True)
trainer = Trainer(config, ckpt_dir)

# Report the model size, followed by the per-layer summary.
net = trainer.network
print(f"Network has {net.trainable_params:,} trainable parameters")
net.summary()

## 4. Training Loop

In [None]:
# Training parameters: number of outer training iterations to run.
# Adjust based on the session time available.
NUM_ITERATIONS = 50

print(
    f"Starting training for {NUM_ITERATIONS} iterations...",
    f"Each iteration: {config.games_per_iteration} games, {config.training_steps} training steps",
    sep="\n",
)

In [None]:
# Run training: each call to trainer.run_iteration() performs one iteration and
# returns a stats dict, which is printed below.
separator = '=' * 50  # hoisted: identical banner for every iteration

for iteration in range(1, NUM_ITERATIONS + 1):
    print(f"\n{separator}")
    print(f"Iteration {iteration}/{NUM_ITERATIONS}")
    print(separator)

    stats = trainer.run_iteration(show_progress=True)

    # Print stats
    print("\nStats:")
    print(f"  Games played: {stats['num_games']}")
    print(f"  Examples generated: {stats['num_examples']}")
    print(f"  Buffer size: {stats['buffer_size']}")

    # Loss keys are optional in the stats dict. NOTE(review): presumably absent
    # when no gradient steps ran this iteration (e.g. buffer still filling); confirm.
    if 'avg_total_loss' in stats:
        print(f"  Total loss: {stats['avg_total_loss']:.4f}")
        print(f"  Policy loss: {stats['avg_policy_loss']:.4f}")
        print(f"  Value loss: {stats['avg_value_loss']:.4f}")

    if stats.get('checkpoint_saved'):
        print("  Checkpoint saved!")

## 5. Test the Model

In [None]:
# Play a test game against itself
def _resolve_legal_move(board, move):
    """Return the legal move on ``board`` that the decoded ``move`` refers to.

    A move decoded from a policy index may lack its promotion piece. If ``move``
    is not legal as-is, the legal moves sharing its from/to squares are
    considered. When several exist (an under-specified promotion), the one with
    the highest promotion ``piece_type`` is chosen, which is the queen in
    python-chess.

    Raises:
        ValueError: if no legal move shares ``move``'s from/to squares.
    """
    if move in board.legal_moves:
        return move
    candidates = [
        legal for legal in board.legal_moves
        if legal.from_square == move.from_square and legal.to_square == move.to_square
    ]
    if not candidates:
        raise ValueError(f"Decoded move {move.uci()} is not legal in position {board.fen()}")
    return max(candidates, key=lambda legal: legal.promotion or 0)


def play_test_game(network, num_simulations=100, max_moves=100, temperature=0.1):
    """Play one MCTS self-play game and record its moves in SAN.

    Args:
        network: Model handed to ``MCTS`` to evaluate positions.
        num_simulations: MCTS simulations per move.
        max_moves: The game is cut off once ``game.move_count`` reaches this
            value, even if it is not over.
        temperature: Temperature passed to ``MCTS.get_action``.

    Returns:
        ``(moves, outcome)``: the SAN strings of the moves played, and
        ``game.get_outcome()`` evaluated when the loop stops. The game may be
        unfinished at that point if ``max_moves`` was hit.
    """
    game = ChessGame()
    mcts = MCTS(network, num_simulations=num_simulations)

    moves = []
    while not game.is_terminal() and game.move_count < max_moves:
        action, _ = mcts.get_action(game, temperature=temperature)
        # The SAN string is for display only; the game itself is advanced
        # through the action index below. SAN must be computed before the move
        # is applied.
        move = _resolve_legal_move(game.board, game.move_encoder.decode(action))
        moves.append(game.board.san(move))
        game.apply_move_index(action)

    return moves, game.get_outcome()

moves, outcome = play_test_game(trainer.network)

# Map the outcome's sign to a label. NOTE(review): a game cut off by the move
# limit is also reported here; confirm what get_outcome() returns for an
# unfinished game (it may show up as "Draw").
if outcome > 0:
    result_label = 'White wins'
elif outcome < 0:
    result_label = 'Black wins'
else:
    result_label = 'Draw'
print(f"Game result: {result_label}")

# Show at most the first 40 moves, with a trailing ellipsis when truncated.
shown_moves = ' '.join(moves[:40])
ellipsis = '...' if len(moves) > 40 else ''
print(f"Moves: {shown_moves}{ellipsis}")

In [None]:
# Play a fresh game (not the one from the previous cell) for up to 20 MCTS
# moves, one apply_move_index() call each, stopping early if it ends. Then
# print the resulting position.
game = ChessGame()
mcts = MCTS(trainer.network, num_simulations=100)

plies_played = 0
while plies_played < 20 and not game.is_terminal():
    action, _ = mcts.get_action(game, temperature=0.1)
    game.apply_move_index(action)
    plies_played += 1

print(game)

## 6. Save Final Model

In [None]:
# Save the final network in two forms inside the checkpoint directory: via
# ChessNetwork.save() at 'model_final', and as a full Keras model at
# 'model_final.keras' via save_full_model().
final_path = os.path.join(config.checkpoint_dir, 'model_final')
keras_path = f"{final_path}.keras"

trainer.network.save(final_path)
print(f"Final model saved to: {final_path}")

trainer.network.save_full_model(keras_path)
print(f"Full Keras model saved to: {keras_path}")

## 7. Resume Training (Optional)

In [None]:
# To resume training from a checkpoint:
# trainer.load_checkpoint()  # Loads latest checkpoint
# trainer.train(num_iterations=10)  # Continue training
# NOTE(review): the training cell above drives Trainer.run_iteration() directly;
# confirm Trainer.train(num_iterations=...) exists, or resume with the same loop:
# for _ in range(10):
#     trainer.run_iteration(show_progress=True)