In [None]:
# %% [markdown]
# # 🧠 TinyAlphaZero: A Minimal Chess Engine via Self-Play
#
# This notebook implements a simplified version of DeepMind's AlphaZero algorithm for chess.
# The model learns to play chess **entirely from self-play** - no human games, no opening books,
# no handcrafted evaluation functions.
#
# ---
#
# ## Architecture Overview
#
# ```
# ┌─────────────────────────────────────────────────────────────────┐
# │                     Board State (69 tokens)                     │
# │  [64 squares: piece/empty] + [turn] + [4 castling rights]       │
# └─────────────────────────────────────────────────────────────────┘
#                                  │
#                                  ▼
# ┌─────────────────────────────────────────────────────────────────┐
# │                   Token + Position Embeddings                   │
# │  • 23-token vocabulary (pieces + flags)                         │
# │  • Factored position encoding (rank + file, not flat)           │
# └─────────────────────────────────────────────────────────────────┘
#                                  │
#                                  ▼
# ┌─────────────────────────────────────────────────────────────────┐
# │                 Transformer Encoder (6 layers)                  │
# │  • Full attention (every square sees every other square)        │
# │  • No causal mask (this isn't autoregressive)                   │
# │  • Pre-norm for training stability                              │
# └─────────────────────────────────────────────────────────────────┘
#                                  │
#                                  ▼
#                        Mean Pool (aggregate)
#                                  │
#                ┌─────────────────┴─────────────┐
#                ▼                               ▼
# ┌─────────────────────────────┐ ┌─────────────────────────────┐
# │      Policy Head            │ │      Value Head             │
# │  "What move to play?"       │ │  "Who's winning?"           │
# │                             │ │                             │
# │  Output: 4096 logits        │ │  Output: 3 classes          │
# │  (64 from × 64 to squares)  │ │  [Loss, Draw, Win]          │
# └─────────────────────────────┘ └─────────────────────────────┘
# ```
#
# ### Key Design Choices (vs Original AlphaZero)
#
# | Aspect | AlphaZero | TinyAlphaZero | Why |
# |--------|-----------|---------------|-----|
# | **Board encoding** | 119 spatial planes (8×8×119) | 69 tokens | Simpler, works with transformers |
# | **Network** | ResNet (CNN) | Transformer | Learns spatial relationships via attention |
# | **Position encoding** | Implicit in CNN | Factored rank+file | Hardcodes grid structure |
# | **Value output** | Scalar [-1, +1] | 3 classes [L/D/W] | Sharper gradients, prevents collapse |
# | **Parameters** | ~80M | ~4M | Trainable on consumer GPU |
#
# ---
#
# ## Training Pipeline
#
# ### Phase 1: Supervised Learning ("Learn the Rules")
#
# ```
# Random Games → Model predicts moves → Cross-entropy loss
# ```
#
# **Goal**: Achieve >95% legal move accuracy
#
# The model learns:
# - How pieces move (bishops diagonal, knights L-shape, etc.)
# - Board state encoding/decoding
# - Basic position evaluation
#
# **Data**: ~10k random self-play games (~500k positions)
# - Optional: Biased sampling favors captures/checks for richer positions
#
# **Metrics to watch**:
# - `legal_acc`: % of predictions that are legal moves (target: >95%)
# - `exact_acc`: % matching the actual move played (~15-20% is fine)
# - `grad_norm policy/value`: Should be similar magnitude (within 10×)
#
# ---
#
# ### Phase 2: Self-Play Reinforcement Learning ("Learn to Win")
#
# ```
# ┌─────────────────────────────────────────────────────────────────┐
# │                         Self-Play Loop                          │
# │                                                                 │
# │  1. GENERATE: Model plays itself using MCTS                     │
# │     ┌─────────┐                                                 │
# │     │  MCTS   │ ← Neural network guides tree search             │
# │     │ Search  │ ← Explores moves, evaluates positions           │
# │     └────┬────┘                                                 │
# │          ▼                                                      │
# │     Game trajectory: [(board, MCTS_policy, outcome), ...]       │
# │                                                                 │
# │  2. TRAIN: Update network to match MCTS                         │
# │     • Policy → imitate MCTS visit distribution                  │
# │     • Value → predict game outcome                              │
# │                                                                 │
# │  3. EVALUATE: New model vs previous best                        │
# │     • If win rate > 55%: promote to new best                    │
# │     • Otherwise: keep training                                  │
# │                                                                 │
# │  4. REPEAT                                                      │
# └─────────────────────────────────────────────────────────────────┘
# ```
#
# **MCTS (Monte Carlo Tree Search)**:
# - Simulates many possible game continuations
# - Uses the neural network to evaluate leaf positions
# - Balances exploration (try new moves) vs exploitation (play good moves)
# - Produces a better policy than the raw network output
#
# **Anti-Collapse Measures** (critical for small models):
# - **Dirichlet noise**: Adds randomness at root to ensure exploration
# - **Random opening ply**: Forces random moves in first ~8 ply
# - **Temperature schedule**: High temp early (explore) → low temp late (exploit)
# - **Checkpoint pool**: Plays against diverse past versions, not just self
#
# ---
#
# ## Reading the Training Output
#
# ### Phase 1 Output
# ```
# Grad norms | policy 0.2639 | value 0.2376    ← Gradient magnitudes (should be similar)
# Epoch 010 | loss 1.234 | policy 1.100 | value 0.134 | legal 0.923 | exact 0.152
#                                                       ▲
#                                               Key metric! Target >0.95
# ```
#
# ### Phase 2 Output
# ```
# Generating 50 self-play games...
# Self-play throughput: 45.2 positions/sec     ← Speed (batched mode is faster)
# Generated 3200 positions
#
# Iter 0 Step 100/500 | policy 2.34 | value 0.89 | grad_p 0.15 | grad_v 0.12
#                       ▲             ▲
#                       │             └── Value loss (predicting winner)
#                       └── Policy loss (matching MCTS)
#
# 🚀 New Champion! Win Rate: 62.50%            ← Model improved!
# Champion holds. Win Rate: 45.00%             ← Model didn't improve this iteration
# ```
#
# ---
#
# ## Expected Training Time (Colab T4)
#
# | Phase | Duration | Notes |
# |-------|----------|-------|
# | Data generation (10k games) | ~5 min | One-time |
# | Phase 1 (30 epochs) | ~15-20 min | Until legal_acc > 95% |
# | Phase 2 (per iteration) | ~10-15 min | Run 10-50 iterations |
#
# ---
#
# ## What Success Looks Like
#
# **Phase 1 complete when**:
# - Legal move accuracy > 95%
# - Value loss is decreasing (not stuck)
# - Gradient norms are balanced (within 10×)
#
# **Phase 2 progress**:
# - Win rate vs Phase 1 model increases over iterations
# - Self-play games show sensible chess (not random moves)
# - Policy loss decreases as model matches MCTS better
#
# **Final model**:
# - Plays legal chess 100% of the time
# - Has learned basic tactics (captures hanging pieces)
# - Understands piece values and king safety
# - Won't beat Stockfish, but plays recognizable chess!
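
The exploration/exploitation balance described under MCTS above is commonly implemented with the PUCT selection rule used by AlphaZero-style searches. The sketch below is a minimal standalone illustration, not the repository's `mcts.tree` code: the `Node` class, the `select_child` helper, and the constant `c_puct=1.5` are hypothetical names and values chosen for this example.

```python
import math
from dataclasses import dataclass, field


@dataclass
class Node:
    """One child in the search tree (hypothetical structure for illustration)."""
    prior: float              # P(s, a): probability assigned by the policy head
    visit_count: int = 0      # N(s, a): how often search has gone through this child
    value_sum: float = 0.0    # Sum of backed-up values, seen from the parent's side
    children: dict = field(default_factory=dict)

    def q(self) -> float:
        """Mean backed-up value; 0.0 for a child that was never visited."""
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def puct_score(parent_visits: int, child: Node, c_puct: float = 1.5) -> float:
    """Exploitation term Q plus an exploration bonus that shrinks with visits."""
    bonus = c_puct * child.prior * math.sqrt(parent_visits) / (1 + child.visit_count)
    return child.q() + bonus


def select_child(parent: Node, c_puct: float = 1.5):
    """Return the (move, child) pair with the highest PUCT score."""
    return max(
        parent.children.items(),
        key=lambda item: puct_score(parent.visit_count, item[1], c_puct),
    )
```

An unvisited child with a high prior receives a large bonus and is tried first. Once its visit count grows and its mean value turns out to be poor, its score drops and the search shifts to other moves.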

In [None]:
# @title 1. Setup Environment & Dependencies { run: "auto" }
import os
import sys
import shutil
from pathlib import Path

# --- Google Drive Setup (do this FIRST) ---
DRIVE_FOLDER = "TinyAlphaZero"  # @param {type:"string"}
RESTORE_FROM_DRIVE = True  # @param {type:"boolean"}

from google.colab import drive
drive.mount('/content/drive')

drive_path = Path(f"/content/drive/MyDrive/{DRIVE_FOLDER}")
drive_checkpoints = drive_path / "checkpoints"

# --- Clone/update repository ---
REPO_URL = "https://github.com/tripptytrip/Tiny-AlphaZero"
REPO_DIR = "/content/Tiny-AlphaZero"  # Use absolute path

if not os.path.exists(REPO_DIR):
    print("📥 Cloning repository...")
    !git clone {REPO_URL} {REPO_DIR}
else:
    print("🔄 Updating repository...")
    !cd {REPO_DIR} && git pull --ff-only

# --- Restore checkpoints from Google Drive (BEFORE chdir) ---
local_checkpoints = Path(REPO_DIR) / "checkpoints"

if RESTORE_FROM_DRIVE and drive_checkpoints.exists():
    checkpoint_files = list(drive_checkpoints.rglob("*.pt"))
    if checkpoint_files:
        print(f"\n📂 Restoring {len(checkpoint_files)} checkpoints from Google Drive...")
        for ckpt in checkpoint_files:
            dest = local_checkpoints / ckpt.relative_to(drive_checkpoints)
            dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(ckpt, dest)
            print(f"   ✓ {dest.name}")
        print(f"✅ Restored from {drive_checkpoints}")
    else:
        print(f"\n📂 No .pt files found in {drive_checkpoints}")
else:
    print(f"\n📂 No Drive checkpoints to restore (fresh start)")

# Now change to repo directory
os.chdir(REPO_DIR)

# --- Install dependencies ---
!pip install -q chess pyyaml tqdm

# --- Verify hardware ---
import torch
print(f"\n{'='*50}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"🚀 GPU: {gpu_name} ({gpu_mem:.1f} GB)")

    if "T4" in gpu_name:
        RECOMMENDED_BATCH = 512
        RECOMMENDED_BATCHED_GAMES = 16
    elif "L4" in gpu_name or "A100" in gpu_name:
        RECOMMENDED_BATCH = 1024
        RECOMMENDED_BATCHED_GAMES = 32
    else:
        RECOMMENDED_BATCH = 256
        RECOMMENDED_BATCHED_GAMES = 8
    print(f"📊 Recommended batch size: {RECOMMENDED_BATCH}")
    print(f"📊 Recommended batched games: {RECOMMENDED_BATCHED_GAMES}")
else:
    print("⚠️  No GPU! Go to Runtime > Change runtime type > T4 GPU")
    RECOMMENDED_BATCH = 128
    RECOMMENDED_BATCHED_GAMES = 4
print(f"{'='*50}\n")

# --- Show restored checkpoints ---
if local_checkpoints.exists():
    all_ckpts = list(local_checkpoints.rglob("*.pt"))
    if all_ckpts:
        print("📊 Current checkpoints:")
        for ckpt in sorted(all_ckpts):
            print(f"   {ckpt.relative_to(local_checkpoints)}")

# Add src to path
sys.path.insert(0, os.path.join(REPO_DIR, "src"))
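
The GPU-specific recommendations in the setup cell can also be written as a small lookup table, which makes it easier to add another card. This is only a sketch that mirrors the `if/elif` chain above: `GPU_PRESETS` and `recommend_settings` are hypothetical names, and the numbers are copied from the cell rather than measured.

```python
# Each entry: (substrings to look for in the GPU name, (batch size, batched games)).
GPU_PRESETS = [
    (("T4",), (512, 16)),
    (("L4", "A100"), (1024, 32)),
]
DEFAULT_GPU_PRESET = (256, 8)   # Any other CUDA device
CPU_PRESET = (128, 4)           # No GPU available


def recommend_settings(gpu_name=None):
    """Return (batch_size, batched_games) for a GPU name, or the CPU preset for None."""
    if gpu_name is None:
        return CPU_PRESET
    for markers, preset in GPU_PRESETS:
        if any(marker in gpu_name for marker in markers):
            return preset
    return DEFAULT_GPU_PRESET
```

In the notebook this would be called with `torch.cuda.get_device_name(0)` when CUDA is available and with `None` otherwise.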

In [None]:
# =============================================================================
# CELL 2: Phase 1 - Generate Training Data
# =============================================================================

# @title 2. Phase 1: Generate Training Data { run: "auto" }

# --- Data Generation Parameters ---
NUM_GAMES = 10000  # @param {type:"integer"}
MAX_GAME_LEN = 150  # @param {type:"integer"}
USE_BIASED_SAMPLING = True  # @param {type:"boolean"}

# Biased sampling generates more tactical positions (captures, checks)
# This helps the model learn faster than pure random play

import time
from pathlib import Path

REPO_DIR = "/content/Tiny-AlphaZero"

start = time.time()

cmd = f"python3 {REPO_DIR}/scripts/generate_data.py --num-games {NUM_GAMES} --max-game-length {MAX_GAME_LEN} --output {REPO_DIR}/data/train_games.json"
if USE_BIASED_SAMPLING:
    cmd += " --biased --capture-weight 3.0 --check-weight 2.0"

print(f"🎲 Generating {NUM_GAMES} games...")
!{cmd}

# Convert to memmap (faster loading during training)
print("\n💾 Converting to memmap format...")
!python3 {REPO_DIR}/scripts/convert_to_memmap.py --input {REPO_DIR}/data/train_games.json --output-dir {REPO_DIR}/data/phase1

elapsed = time.time() - start
print(f"\n✅ Data generation complete in {elapsed:.1f}s")
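
To make the `--biased --capture-weight 3.0 --check-weight 2.0` flags concrete, here is a minimal sketch of weighted move sampling. It is an assumption about how such weights might be combined (multiplicatively, starting from 1.0), not a description of what `generate_data.py` actually does. `Candidate`, `move_weight`, and `sample_biased` are hypothetical names, and the moves are plain tuples so the example runs without `python-chess`.

```python
import random
from typing import NamedTuple, Sequence


class Candidate(NamedTuple):
    """A legal move plus the two flags the sampler cares about."""
    uci: str
    is_capture: bool
    gives_check: bool


def move_weight(move: Candidate, capture_weight: float = 3.0, check_weight: float = 2.0) -> float:
    """Assumed rule: start at 1.0 and multiply by each weight that applies."""
    weight = 1.0
    if move.is_capture:
        weight *= capture_weight
    if move.gives_check:
        weight *= check_weight
    return weight


def sample_biased(moves: Sequence[Candidate], rng: random.Random,
                  capture_weight: float = 3.0, check_weight: float = 2.0) -> Candidate:
    """Pick one move at random, favouring captures and checks."""
    weights = [move_weight(m, capture_weight, check_weight) for m in moves]
    return rng.choices(moves, weights=weights, k=1)[0]
```

With these defaults a quiet move has weight 1, a capture 3, a check 2, and a capturing check 6, so tactical positions appear more often than under uniform random play while every legal move keeps a nonzero chance.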




In [None]:
# =============================================================================
# CELL 3: Phase 1 - Train Supervised Model
# =============================================================================

# @title 3. Phase 1: Train Supervised Model { run: "auto" }

# --- Training Parameters ---
EPOCHS = 30  # @param {type:"integer"}
BATCH_SIZE = 512  # @param {type:"integer"}
LEARNING_RATE = 0.001  # @param {type:"number"}
VALUE_WEIGHT = 2.0  # @param {type:"number"}

# --- Drive Sync ---
DRIVE_FOLDER = "TinyAlphaZero"  # @param {type:"string"}
SYNC_TO_DRIVE = True  # @param {type:"boolean"}

import yaml
import shutil
from pathlib import Path

REPO_DIR = "/content/Tiny-AlphaZero"
drive_checkpoints = Path(f"/content/drive/MyDrive/{DRIVE_FOLDER}/checkpoints")
local_checkpoints = Path(REPO_DIR) / "checkpoints"

# Update config file with our parameters
config_path = f"{REPO_DIR}/config/phase1.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

config['training']['epochs'] = EPOCHS
config['training']['batch_size'] = BATCH_SIZE
config['training']['learning_rate'] = LEARNING_RATE
config['training']['value_weight'] = VALUE_WEIGHT
config['data']['num_workers'] = 2  # Colab has limited CPU cores

with open(config_path, 'w') as f:
    yaml.dump(config, f)

print(f"⚙️  Config: epochs={EPOCHS}, batch={BATCH_SIZE}, lr={LEARNING_RATE}, value_weight={VALUE_WEIGHT}")
print(f"\n🧠 Training Phase 1 Model...\n")

!cd {REPO_DIR} && python3 scripts/train_phase1.py --data-dir data/phase1 --epochs {EPOCHS}

print("\n✅ Phase 1 training complete!")
print(f"📁 Checkpoint saved to: {local_checkpoints}/phase1/best.pt")

# Sync to Drive
if SYNC_TO_DRIVE:
    phase1_local = local_checkpoints / "phase1"
    if phase1_local.exists():
        for ckpt in phase1_local.glob("*.pt"):
            dest = drive_checkpoints / "phase1" / ckpt.name
            dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(ckpt, dest)
            print(f"💾 Synced: {ckpt.name} → Drive")
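
The "legal accuracy above 95%" gate can be checked programmatically instead of by eye. The sketch below parses an epoch summary line in the format shown in the sample Phase 1 output of the overview. That format is taken from the sample only, so the real script's log may differ. `parse_epoch_line` and `phase1_ready` are hypothetical helpers.

```python
import re

# Matches lines such as: "Epoch 010 | loss 1.234 | policy 1.100 | ... | legal 0.923"
EPOCH_LINE = re.compile(r"^Epoch\s+(\d+)\s*\|(.*)$")


def parse_epoch_line(line):
    """Return {'epoch': int, '<metric>': float, ...}, or None for any other line."""
    match = EPOCH_LINE.match(line.strip())
    if match is None:
        return None
    metrics = {"epoch": int(match.group(1))}
    for field in match.group(2).split("|"):
        parts = field.split()
        if len(parts) != 2:
            continue  # Skip fields that are not "<name> <number>"
        name, value = parts
        try:
            metrics[name] = float(value)
        except ValueError:
            continue
    return metrics


def phase1_ready(metrics, legal_target=0.95):
    """True once the parsed line reports legal-move accuracy above the target."""
    return metrics is not None and metrics.get("legal", 0.0) > legal_target
```

Feeding each line of the training output through `parse_epoch_line` gives a simple way to stop early or to decide whether Phase 2 should start.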

In [None]:
# @title 4. Phase 1: Robust Model Validation (Visual) { run: "auto" }

import torch
import chess
import chess.svg
import sys
import os
import yaml
from IPython.display import display, HTML

# --- Setup ---
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
from model.transformer import ChessTransformer
from data.encoding import encode_board, encode_move, decode_move

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = "checkpoints/phase1/best.pt"

# --- Helper Functions ---
def get_model_from_checkpoint(path, device):
    if not os.path.exists(path):
        return None, f"❌ Checkpoint not found at {path}"

    try:
        ckpt = torch.load(path, map_location=device)
        # Attempt to read config from checkpoint, fallback to defaults if missing
        model_cfg = ckpt.get('config', {})
        # Flatten nested config if it exists (handles both 'model': {...} and flat)
        if 'model' in model_cfg: model_cfg = model_cfg['model']

        model = ChessTransformer(
            vocab_size=model_cfg.get('vocab_size', 23),
            num_moves=model_cfg.get('num_moves', 4096),
            d_model=model_cfg.get('d_model', 256),
            n_layers=model_cfg.get('n_layers', 6),
            n_heads=model_cfg.get('n_heads', 8)
        ).to(device)

        model.load_state_dict(ckpt['model_state_dict'])
        model.eval()
        return model, "✅ Model loaded successfully"
    except Exception as e:
        return None, f"❌ Error loading model: {e}"

def analyze_position(model, fen, title):
    board = chess.Board(fen)
    tokens = torch.tensor(encode_board(board), dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        policy, value_logits = model(tokens)
        probs = torch.softmax(policy, dim=-1).squeeze()
        value = model.get_value(value_logits).item()

    # Metrics
    legal_indices = [encode_move(m) for m in board.legal_moves]
    legal_mass = probs[legal_indices].sum().item()

    # Visualization
    svg = chess.svg.board(board, size=300)

    # HTML Output Construction
    html = f"""
    <div style="display: flex; align-items: flex-start; margin-bottom: 20px; border: 1px solid #444; padding: 10px; border-radius: 8px;">
        <div style="margin-right: 20px;">{svg}</div>
        <div>
            <h3>{title}</h3>
            <p><strong>Legal Move Probability Mass:</strong>
                <span style="color: {'lime' if legal_mass > 0.95 else 'orange' if legal_mass > 0.8 else 'red'}">
                {legal_mass:.2%}
                </span>
            </p>
            <p><strong>Value Prediction:</strong> {value:.3f} (Win Prob)</p>
            <h4>Top 5 Predictions:</h4>
            <ul style="list-style-type: none; padding: 0;">
    """

    top_k = probs.topk(5)
    for idx, p in zip(top_k.indices.tolist(), top_k.values.tolist()):
        try:
            # We use a trick to decode: create a dummy board or just rank/file logic
            from_sq, to_sq = idx // 64, idx % 64
            move_uci = chess.square_name(from_sq) + chess.square_name(to_sq)

            # Check legality
            is_legal = idx in legal_indices
            icon = "✅" if is_legal else "❌"
            color = "#EEE" if is_legal else "#F88"

            html += f"<li style='color: {color}; font-family: monospace;'>{move_uci}: {p:.2%} {icon}</li>"
        except Exception:
            html += f"<li>Error decoding move {idx}</li>"

    html += "</ul></div></div>"
    return html, legal_mass

# --- Execution ---
model, status = get_model_from_checkpoint(checkpoint_path, device)
print(status)

if model:
    test_cases = [
        (chess.STARTING_FEN, "1. Starting Position (Opening Knowledge)"),
        ("rnbqkbnr/pppp1ppp/8/4p3/6P1/5P2/PPPPP2P/RNBQKBNR b KQkq - 0 2", "2. Fool's Mate Pattern (Black to Move)"),
        ("4k3/8/8/8/8/8/8/4R1K1 b - - 0 1", "3. Check Evasion (Must Move King)")
    ]

    total_mass = 0
    full_html = "<h2>♟️ Model Diagnostics</h2>"

    for fen, title in test_cases:
        html_part, mass = analyze_position(model, fen, title)
        full_html += html_part
        total_mass += mass

    avg_mass = total_mass / len(test_cases)
    display(HTML(full_html))

    print(f"{'='*40}")
    # This is the softmax probability summed over legal moves, averaged across positions
    print(f"📊 Average legal-move probability mass: {avg_mass:.2%}")
    if avg_mass > 0.95:
        print("🎉 PASSED: Model is robust and ready for Phase 2.")
    elif avg_mass > 0.80:
        print("⚠️  WARNING: Model is shaky. Phase 2 might be inefficient.")
    else:
        print("❌ FAILED: Do not proceed to Phase 2. Train more on Phase 1.")
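
The "legal move probability mass" reported by this cell is the softmax probability summed over the indices of legal moves. The sketch below reproduces that computation in plain Python, together with the move indexing the cell relies on when it decodes `idx // 64` and `idx % 64`. The layout `index = from_square * 64 + to_square` is inferred from that decoding step and is an assumption about `encode_move`. All helper names here are hypothetical.

```python
import math

FILES = "abcdefgh"


def square_name(square):
    """0..63 -> 'a1'..'h8', using python-chess numbering (a1 = 0, h8 = 63)."""
    return FILES[square % 8] + str(square // 8 + 1)


def square_index(name):
    """'a1'..'h8' -> 0..63 (inverse of square_name)."""
    return (int(name[1]) - 1) * 8 + FILES.index(name[0])


def move_index(uci):
    """Assumed layout: from_square * 64 + to_square (promotion suffix ignored)."""
    return square_index(uci[:2]) * 64 + square_index(uci[2:4])


def index_to_uci(index):
    """Inverse of move_index for the 4096 from/to combinations."""
    return square_name(index // 64) + square_name(index % 64)


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def legal_mass(logits, legal_indices):
    """Share of policy probability that lands on legal moves."""
    probs = softmax(logits)
    return sum(probs[i] for i in set(legal_indices))
```

An untrained network with uniform logits puts only `len(legal_moves) / 4096` of its mass on legal moves, which is why this number climbing toward 1.0 is a useful signal that the rules have been learned.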

In [None]:
# =============================================================================
# CELL 5: Phase 2 - Self-Play Training Loop
# =============================================================================

# @title 5. Phase 2: Self-Play Training Loop { run: "auto" }

# --- Self-Play Parameters ---
NUM_ITERATIONS = 50  # @param {type:"integer"}
GAMES_PER_ITER = 50  # @param {type:"integer"}
MCTS_SIMS = 400  # @param {type:"integer"}
TRAINING_STEPS = 500  # @param {type:"integer"}

# --- Performance Tuning ---
USE_BATCHED = True  # @param {type:"boolean"}
BATCHED_GAMES = 32  # @param {type:"integer"}
BATCH_SIZE = 1024  # @param {type:"integer"}

# --- Anti-Collapse Settings ---
DIRICHLET_ALPHA = 0.3  # @param {type:"number"}
DIRICHLET_FRAC = 0.25  # @param {type:"number"}
RANDOM_OPENING_PLY = 8  # @param {type:"integer"}
RANDOM_OPENING_PROB = 0.5  # @param {type:"number"}

# --- Drive Sync ---
DRIVE_FOLDER = "TinyAlphaZero"  # @param {type:"string"}
SAVE_TO_DRIVE_EVERY = 1  # @param {type:"integer"}

import os
import shutil
from pathlib import Path

REPO_DIR = "/content/Tiny-AlphaZero"
drive_checkpoints = Path(f"/content/drive/MyDrive/{DRIVE_FOLDER}/checkpoints")
local_checkpoints = Path(REPO_DIR) / "checkpoints"

def sync_to_drive():
    """Copy local checkpoints to Google Drive."""
    if not local_checkpoints.exists():
        return
    for ckpt in local_checkpoints.rglob("*.pt"):
        dest = drive_checkpoints / ckpt.relative_to(local_checkpoints)
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(ckpt, dest)
    print(f"💾 Synced checkpoints to Google Drive")

def _iter_number(path):
    """Extract N from 'phase2_iter_N.pt' so checkpoints sort numerically."""
    return int(path.stem.split("_")[-1])

def _phase2_checkpoints():
    """Phase 2 checkpoints ordered by iteration number.

    A plain sorted() is lexicographic and would rank iter_9 after iter_10.
    """
    return sorted(local_checkpoints.glob("phase2/phase2_iter_*.pt"), key=_iter_number)

def find_latest_checkpoint():
    """Find the most recent checkpoint (prefer phase2, fall back to phase1)."""
    phase2_ckpts = _phase2_checkpoints()
    if phase2_ckpts:
        return str(phase2_ckpts[-1])

    phase1_best = local_checkpoints / "phase1" / "best.pt"
    if phase1_best.exists():
        return str(phase1_best)

    return None

def get_starting_iteration():
    """Determine which iteration to start from based on existing checkpoints."""
    phase2_ckpts = _phase2_checkpoints()
    if phase2_ckpts:
        return _iter_number(phase2_ckpts[-1]) + 1  # e.g. "phase2_iter_5" -> 6
    return 0

# Check for starting checkpoint
checkpoint = find_latest_checkpoint()
start_iteration = get_starting_iteration()

if checkpoint is None:
    print("❌ No checkpoint found! Run Phase 1 first.")
else:
    print(f"🚀 Starting Phase 2 Self-Play Training")
    print(f"{'='*50}")
    print(f"Starting from: {Path(checkpoint).name}")
    print(f"Starting iteration: {start_iteration}")
    print(f"Target iterations: {NUM_ITERATIONS}")
    print(f"Games/iter: {GAMES_PER_ITER} | MCTS sims: {MCTS_SIMS}")
    print(f"Batched: {USE_BATCHED} ({BATCHED_GAMES} parallel)")
    print(f"Auto-save to Drive: every {SAVE_TO_DRIVE_EVERY} iteration(s)")
    print(f"{'='*50}\n")

    for iteration in range(start_iteration, NUM_ITERATIONS):
        print(f"\n{'='*50}")
        print(f"🔄 ITERATION {iteration + 1}/{NUM_ITERATIONS}")
        print(f"{'='*50}")

        # Find latest checkpoint for this iteration
        ckpt = find_latest_checkpoint()

        cmd = f"""cd {REPO_DIR} && python3 scripts/train_phase2.py \
            --checkpoint {ckpt} \
            --num-games {GAMES_PER_ITER} \
            --mcts-sims {MCTS_SIMS} \
            --training-steps {TRAINING_STEPS} \
            --batch-size {BATCH_SIZE} \
            --dirichlet-alpha {DIRICHLET_ALPHA} \
            --dirichlet-frac {DIRICHLET_FRAC} \
            --random-opening-ply {RANDOM_OPENING_PLY} \
            --random-opening-prob {RANDOM_OPENING_PROB}"""

        if USE_BATCHED:
            cmd += f" --use-batched --batched-games {BATCHED_GAMES}"

        !{cmd}

        # Sync to Google Drive periodically
        if (iteration + 1) % SAVE_TO_DRIVE_EVERY == 0:
            sync_to_drive()

    # Final sync
    sync_to_drive()

    print(f"\n{'='*50}")
    print(f"✅ Phase 2 Training Complete!")
    print(f"{'='*50}")
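
The `DIRICHLET_ALPHA` and `DIRICHLET_FRAC` settings, together with the temperature schedule, control how much randomness enters self-play. The sketch below shows the usual AlphaZero-style formulation using only the standard library. It is an illustration under that assumption, not the code in `train_phase2.py`, and `sample_dirichlet`, `add_root_noise`, and `visit_distribution` are hypothetical names.

```python
import random


def sample_dirichlet(alpha, size, rng):
    """Symmetric Dirichlet(alpha) sample built from normalised Gamma(alpha, 1) draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(size)]
    total = sum(draws)
    return [d / total for d in draws]


def add_root_noise(priors, alpha=0.3, frac=0.25, rng=None):
    """Blend root priors with Dirichlet noise: (1 - frac) * prior + frac * noise."""
    rng = rng or random.Random()
    noise = sample_dirichlet(alpha, len(priors), rng)
    return [(1.0 - frac) * p + frac * n for p, n in zip(priors, noise)]


def visit_distribution(visit_counts, temperature):
    """Turn MCTS visit counts into move probabilities proportional to N ** (1 / T).

    Temperature near 0 picks the most-visited move; temperature 1 keeps the
    raw visit proportions. Assumes at least one visit was recorded.
    """
    peak = max(visit_counts)
    if peak <= 0:
        raise ValueError("visit_counts must contain at least one positive count")
    if temperature <= 1e-6:
        best = visit_counts.index(peak)
        return [1.0 if i == best else 0.0 for i in range(len(visit_counts))]
    # Divide by the peak first so large exponents cannot overflow.
    scaled = [(count / peak) ** (1.0 / temperature) for count in visit_counts]
    total = sum(scaled)
    return [s / total for s in scaled]
```

With `frac=0.25`, every root move keeps at least 75% of its original prior, so the noise nudges exploration without overriding the network. Lowering the temperature sharpens the visit distribution toward the most-visited move.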



In [None]:
# =============================================================================
# CELL 6: Manual Save to Google Drive (if needed)
# =============================================================================

# @title 6. Save Model to Google Drive (Manual) { run: "auto" }

DRIVE_FOLDER = "TinyAlphaZero"  # @param {type:"string"}

import shutil
from pathlib import Path

REPO_DIR = "/content/Tiny-AlphaZero"
drive_checkpoints = Path(f"/content/drive/MyDrive/{DRIVE_FOLDER}/checkpoints")
local_checkpoints = Path(REPO_DIR) / "checkpoints"

if local_checkpoints.exists():
    count = 0
    for ckpt in local_checkpoints.rglob("*.pt"):
        dest = drive_checkpoints / ckpt.relative_to(local_checkpoints)
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(ckpt, dest)
        print(f"📁 Saved: {ckpt.relative_to(local_checkpoints)}")
        count += 1
    print(f"\n✅ {count} checkpoints saved to Google Drive: {drive_checkpoints}")
else:
    print("❌ No local checkpoints found")
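
The same copy loop appears in the setup, Phase 1, Phase 2, and manual-save cells. One way to avoid the repetition is a single helper that mirrors a checkpoint tree in either direction. The version below is a sketch: `mirror_checkpoints` is a hypothetical name, and skipping files whose size and modification time already match is an added convenience that the original cells do not have.

```python
import shutil
from pathlib import Path


def mirror_checkpoints(src, dst, pattern="*.pt"):
    """Copy files matching `pattern` from src to dst, keeping sub-folders.

    Returns the relative paths that were actually copied. Files that already
    exist at the destination with the same size and an equal or newer
    modification time are skipped.
    """
    src, dst = Path(src), Path(dst)
    copied = []
    if not src.exists():
        return copied
    for ckpt in sorted(src.rglob(pattern)):
        rel = ckpt.relative_to(src)
        target = dst / rel
        if target.exists():
            s, t = ckpt.stat(), target.stat()
            if s.st_size == t.st_size and t.st_mtime >= s.st_mtime:
                continue  # Destination is already up to date
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(ckpt, target)  # copy2 also preserves the modification time
        copied.append(rel)
    return copied
```

Calling `mirror_checkpoints(local_checkpoints, drive_checkpoints)` would save to Drive, and swapping the arguments would restore from it.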


In [None]:
# @title 7. Benchmark & Analyze Model { run: "auto" }

# --- Configuration ---
NUM_ARENA_GAMES = 20  # @param {type:"integer"}
MCTS_SIMS = 200  # @param {type:"integer"}
SHOW_SAMPLE_GAME = True  # @param {type:"boolean"}

import sys
from pathlib import Path

REPO_DIR = "/content/Tiny-AlphaZero"
sys.path.insert(0, f"{REPO_DIR}/src")

import torch
import chess
import chess.svg
from IPython.display import display, HTML, clear_output
import time

from model.transformer import ChessTransformer
from data.encoding import encode_board, encode_move, decode_move
from mcts.tree import MCTS, MCTSConfig

# --- Find Best Checkpoint ---
local_checkpoints = Path(REPO_DIR) / "checkpoints"

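def _trailing_number(path):
    """Trailing integer of a checkpoint stem ('phase2_iter_12' -> 12), or -1 if absent."""
    suffix = path.stem.rsplit("_", 1)[-1]
    return int(suffix) if suffix.isdigit() else -1

def find_best_checkpoint():
    """Find the most recent/best checkpoint."""
    # Priority: best_generation > phase2_iter > phase1/best
    for pattern in ["phase2/best_generation_*.pt", "phase2/phase2_iter_*.pt", "phase1/best.pt"]:
        # Sort numerically: a plain sorted() would rank iter_9 after iter_10
        matches = sorted(local_checkpoints.glob(pattern), key=_trailing_number)
        if matches:
            return matches[-1]
    return None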

checkpoint_path = find_best_checkpoint()

if checkpoint_path is None:
    print("❌ No checkpoints found. Run training first.")
else:
    print(f"📊 TinyAlphaZero Benchmark Suite")
    print(f"{'='*60}")
    print(f"Checkpoint: {checkpoint_path.relative_to(local_checkpoints)}")

    # --- Load Model ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ChessTransformer().to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device)["model_state_dict"])
    model.eval()
    print(f"Device: {device}")
    print(f"{'='*60}\n")

    # =========================================================================
    # TEST 1: Legal Move Accuracy
    # =========================================================================
    print("🧪 TEST 1: Legal Move Accuracy")
    print("-" * 40)

    test_positions = [
        (chess.STARTING_FEN, "Starting position"),
        ("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1", "After 1.e4"),
        ("r1bqkb1r/pppp1ppp/2n2n2/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4", "Italian Game"),
        ("rnbqkb1r/pp2pppp/2p2n2/3p4/2PP4/2N5/PP2PPPP/R1BQKBNR w KQkq - 0 4", "Slav Defense"),
        ("r1bqr1k1/ppp2ppp/2np1n2/2b1p3/2B1P3/2NP1N2/PPP2PPP/R1BQR1K1 w - - 0 8", "Middlegame"),
        ("8/8/4k3/8/8/4K3/4P3/8 w - - 0 1", "King + Pawn endgame"),
        ("r3k2r/pppppppp/8/8/8/8/PPPPPPPP/R3K2R w KQkq - 0 1", "Castling available"),
    ]

    total_legal_mass = 0
    for fen, name in test_positions:
        board = chess.Board(fen)
        tokens = torch.tensor(encode_board(board), dtype=torch.long, device=device).unsqueeze(0)

        with torch.no_grad():
            policy, value_logits = model(tokens)
            probs = torch.softmax(policy, dim=-1).squeeze()
            value = model.get_value(value_logits).item()

        legal_indices = [encode_move(m) for m in board.legal_moves]
        legal_mass = probs[legal_indices].sum().item()
        total_legal_mass += legal_mass

        # Top move
        top_idx = probs.argmax().item()
        top_move = decode_move(top_idx, board)
        top_legal = "✅" if top_idx in legal_indices else "❌"

        print(f"  {name:25s} | Legal: {legal_mass:6.2%} | Value: {value:+.3f} | Top: {top_move.uci()} {top_legal}")

    avg_legal = total_legal_mass / len(test_positions)
    print(f"\n  Average legal move mass: {avg_legal:.2%}")

    if avg_legal > 0.98:
        print("  ✅ Excellent - model has mastered legal moves")
    elif avg_legal > 0.95:
        print("  ✅ Good - ready for self-play")
    elif avg_legal > 0.90:
        print("  ⚠️  Okay - may need more Phase 1 training")
    else:
        print("  ❌ Poor - needs more Phase 1 training")

    # =========================================================================
    # TEST 2: Arena vs Random Player
    # =========================================================================
    print(f"\n🧪 TEST 2: Arena vs Random Player ({NUM_ARENA_GAMES} games)")
    print("-" * 40)

    import random

    mcts = MCTS(model, config=MCTSConfig(num_simulations=MCTS_SIMS), device=device)

    wins, losses, draws = 0, 0, 0
    total_moves = 0

    for game_idx in range(NUM_ARENA_GAMES):
        board = chess.Board()
        model_is_white = (game_idx % 2 == 0)  # Alternate colors

        while not board.is_game_over() and board.fullmove_number < 150:
            if board.turn == chess.WHITE:
                is_model_turn = model_is_white
            else:
                is_model_turn = not model_is_white

            if is_model_turn:
                # Model plays with MCTS
                move_idx = mcts.select_move(board, temperature=0.1)
                move = decode_move(move_idx, board)
            else:
                # Random player
                move = random.choice(list(board.legal_moves))

            board.push(move)
            total_moves += 1

        result = board.result()
        if result == "1-0":
            if model_is_white:
                wins += 1
            else:
                losses += 1
        elif result == "0-1":
            if model_is_white:
                losses += 1
            else:
                wins += 1
        else:
            draws += 1

        # Progress indicator
        print(f"  Game {game_idx + 1:2d}/{NUM_ARENA_GAMES}: {result:7s} | Model as {'White' if model_is_white else 'Black'} | Moves: {board.fullmove_number}")

    win_rate = wins / NUM_ARENA_GAMES
    print(f"\n  Results: {wins}W - {draws}D - {losses}L")
    print(f"  Win rate: {win_rate:.1%}")
    print(f"  Avg game length: {total_moves / NUM_ARENA_GAMES:.1f} ply")

    if win_rate > 0.95:
        print("  🏆 Excellent - dominates random play")
    elif win_rate > 0.85:
        print("  ✅ Very good - strong tactical understanding")
    elif win_rate > 0.70:
        print("  ✅ Good - solid progress")
    elif win_rate > 0.50:
        print("  ⚠️  Okay - learning but needs more training")
    else:
        print("  ❌ Poor - not better than random yet")

    # =========================================================================
    # TEST 3: Tactical Puzzles
    # =========================================================================
    print(f"\n🧪 TEST 3: Tactical Puzzles")
    print("-" * 40)

    puzzles = [
        # (FEN, best_move_uci, description)
        ("r1bqkb1r/pppp1ppp/2n2n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 4 4", "h5f7", "Scholar's Mate"),
        # The earlier "Fork on f7" entry (c4f7) lost a bishop to ...Qxf7, so it is swapped for a forced mate
        ("6k1/5ppp/8/8/8/8/8/R5K1 w - - 0 1", "a1a8", "Back-rank mate"),
        ("rnbqkbnr/ppp2ppp/4p3/3pP3/3P4/8/PPP2PPP/RNBQKBNR w KQkq d6 0 4", "e5d6", "En passant capture"),
        ("r3k2r/pppppppp/8/8/8/8/PPPPPPPP/R3K2R w KQkq - 0 1", "e1g1", "Kingside castle"),
        ("8/5P2/8/8/8/8/8/4K2k w - - 0 1", "f7f8q", "Pawn promotion"),
    ]

    correct = 0
    for fen, best_uci, description in puzzles:
        board = chess.Board(fen)
        tokens = torch.tensor(encode_board(board), dtype=torch.long, device=device).unsqueeze(0)

        with torch.no_grad():
            policy, _ = model(tokens)
            probs = torch.softmax(policy, dim=-1).squeeze()

        # Mask illegal moves
        legal_indices = [encode_move(m) for m in board.legal_moves]
        masked_probs = torch.zeros_like(probs)
        masked_probs[legal_indices] = probs[legal_indices]

        top_idx = masked_probs.argmax().item()
        top_move = decode_move(top_idx, board)

        best_move = chess.Move.from_uci(best_uci)
        is_correct = (top_move == best_move)
        correct += int(is_correct)

        status = "✅" if is_correct else f"❌ (played {top_move.uci()})"
        print(f"  {description:20s} | Best: {best_uci} | {status}")

    puzzle_score = correct / len(puzzles)
    print(f"\n  Puzzle score: {correct}/{len(puzzles)} ({puzzle_score:.0%})")

    # =========================================================================
    # TEST 4: Sample Game Display
    # =========================================================================
    if SHOW_SAMPLE_GAME:
        print("\n🧪 TEST 4: Sample Game (Model vs Random)")
        print("-" * 40)

        board = chess.Board()
        moves = []

        while not board.is_game_over() and board.fullmove_number <= 40:
            if board.turn == chess.WHITE:
                # Model plays white with MCTS
                move_idx = mcts.select_move(board, temperature=0.3)
                move = decode_move(move_idx, board)
            else:
                # Random plays black
                move = random.choice(list(board.legal_moves))

            moves.append(move.uci())
            board.push(move)

        # Print moves as numbered UCI pairs (PGN-like layout, five full moves per line)
        print("  Moves:")
        move_str = ""
        for i, uci in enumerate(moves):
            if i % 2 == 0:
                move_str += f"  {i//2 + 1}. {uci}"
            else:
                move_str += f" {uci}"
            if (i + 1) % 10 == 0:
                print(move_str)
                move_str = ""
        if move_str:
            print(move_str)

        print(f"\n  Result: {board.result()}")
        print(f"  Final position FEN: {board.fen()}")

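        # Display final board (falls back to text when SVG rendering is unavailable)
        try:
            svg = chess.svg.board(board=board, size=350)
            display(HTML(f"<div style='margin: 20px 0;'>{svg}</div>"))
        except Exception:
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt and SystemExit still propagate.
            print("  (SVG display not available)")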

    # =========================================================================
    # SUMMARY
    # =========================================================================
    print(f"\n{'='*60}")
    print("📊 BENCHMARK SUMMARY")
    print(f"{'='*60}")
    print(f"  Legal move accuracy:  {avg_legal:.1%}")
    print(f"  Win rate vs random:   {win_rate:.1%}")
    print(f"  Tactical puzzles:     {correct}/{len(puzzles)}")
    print(f"{'='*60}")

    # Overall assessment: weighted blend of legal-move mass (30%),
    # win rate vs random (50%), and puzzle score (20%)
    score = (avg_legal * 0.3) + (win_rate * 0.5) + (puzzle_score * 0.2)
    if score > 0.90:
        print("🏆 Overall: Excellent! Model plays strong chess.")
    elif score > 0.75:
        print("✅ Overall: Good progress. Continue Phase 2 training.")
    elif score > 0.60:
        print("⚠️  Overall: Learning. Needs more iterations.")
    else:
        print("❌ Overall: Early stage. Keep training.")

📊 TinyAlphaZero Benchmark Suite
Checkpoint: phase2/best_generation_0.pt
Device: cuda

🧪 TEST 1: Legal Move Accuracy
----------------------------------------
  Starting position         | Legal: 99.86% | Value: -0.229 | Top: a2a4 ✅
  After 1.e4                | Legal: 99.95% | Value: +0.213 | Top: c7c5 ✅
  Italian Game              | Legal: 97.50% | Value: -0.000 | Top: f3e5 ✅
  Slav Defense              | Legal: 96.38% | Value: +0.001 | Top: g2g3 ✅
  Middlegame                | Legal: 95.44% | Value: -0.000 | Top: c4d5 ✅
  King + Pawn endgame       | Legal: 96.28% | Value: +0.000 | Top: e3f3 ✅
  Castling available        | Legal: 99.76% | Value: +0.403 | Top: f2f3 ✅

  Average legal move mass: 97.88%
  ✅ Good - ready for self-play

🧪 TEST 2: Arena vs Random Player (20 games)
----------------------------------------
