# Transformer Backgammon - TPU Training

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wmhowell18/claude-code/blob/main/transformer-backgammon/colab_tpu_training.ipynb)

Train a transformer-based backgammon AI on Google Colab TPUs.

**Setup:** Runtime → Change runtime type → **TPU**

## 1. Install Package

```python
# Install TPU-specific JAX (quoted so the shell does not treat [tpu] as a glob)
!pip install -q 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install backgammon package from subdirectory
!pip install -q 'git+https://github.com/wmhowell18/claude-code.git#subdirectory=transformer-backgammon'

print("✅ Package installed!")
```

## 2. Verify TPU

```python
import jax
import jax.numpy as jnp

# Force TPU backend
jax.config.update('jax_platform_name', 'tpu')

# Check devices
devices = jax.devices()
print(f"JAX backend: {jax.default_backend()}")
print(f"TPU cores: {len(devices)}")

# Quick test
x = jnp.ones((1000, 1000))
y = jnp.dot(x, x)
print(f"\n✅ TPU ready! Test result shape: {y.shape}")
```

## 3. Mount Google Drive

```python
from google.colab import drive
from pathlib import Path

# Mount Google Drive
drive.mount('/content/drive')

# Create directories for this training run
SAVE_DIR = Path('/content/drive/MyDrive/backgammon_training')
CHECKPOINT_DIR = SAVE_DIR / 'checkpoints'
LOG_DIR = SAVE_DIR / 'logs'

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR.mkdir(parents=True, exist_ok=True)

print(f"✅ Checkpoints: {CHECKPOINT_DIR}")
print(f"✅ Logs: {LOG_DIR}")
```

## 4. Configure Training

```python
from backgammon.training.train import TrainingConfig

# TPU-optimized configuration
config = TrainingConfig(
    # Training phases
    warmstart_games=1000,      # Pip count vs pip count
    early_phase_games=5000,    # Neural self-play (simple)
    mid_phase_games=5000,      # Mixed complexity
    late_phase_games=5000,     # Full complexity

    # Batch sizes (TPU-optimized)
    games_per_batch=50,
    training_batch_size=256,   # Large batches for TPU!

    # Training mode
    train_policy=True,         # False for value-only (simpler)

    # Replay buffer
    replay_buffer_size=100000,
    replay_buffer_min_size=1000,
    train_steps_per_game_batch=10,

    # Checkpointing
    checkpoint_every_n_batches=100,
    log_every_n_batches=10,
    checkpoint_dir=str(CHECKPOINT_DIR),
    log_dir=str(LOG_DIR),

    seed=42,
)

total_games = sum([config.warmstart_games, config.early_phase_games,
                   config.mid_phase_games, config.late_phase_games])
print(f"\n📊 Configuration:")
print(f"  Total games: {total_games:,}")
print(f"  Batch size: {config.training_batch_size}")
print(f"  Mode: {'Policy + Value' if config.train_policy else 'Value only'}")
print(f"\n✅ Ready to train!")
```

## 5. Run Training

⏱️ **Estimated time:** ~10-13 hours for 16,000 games on TPU v2-8

The training will:
- Start with warmstart (pip count agents)
- Progress to neural self-play
- Save checkpoints every 100 batches
- Log metrics every 10 batches
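The schedule implied by this configuration can be worked out by hand. The sketch below assumes that a "batch" in `checkpoint_every_n_batches` and `log_every_n_batches` means one batch of `games_per_batch` games counted across all phases, and it ignores the `replay_buffer_min_size` warm-up; if the package counts batches per phase instead, the checkpoint count will differ.

```python
# Back-of-the-envelope schedule for the section 4 configuration.
# Assumption: one "batch" = games_per_batch games, counted across all phases.
phase_games = {"warmstart": 1000, "early": 5000, "mid": 5000, "late": 5000}
games_per_batch = 50
train_steps_per_game_batch = 10
checkpoint_every_n_batches = 100
log_every_n_batches = 10

total_games = sum(phase_games.values())                           # 16,000
game_batches = total_games // games_per_batch                     # 320
gradient_steps = game_batches * train_steps_per_game_batch        # 3,200 (ignores buffer warm-up)
checkpoints_written = game_batches // checkpoint_every_n_batches  # 3
log_entries = game_batches // log_every_n_batches                 # 32

# Seconds per game implied by the 10-13 hour estimate, and the resulting
# wall-clock gap between consecutive checkpoints.
sec_per_game = [hours * 3600 / total_games for hours in (10, 13)]
hours_between_checkpoints = [
    s * games_per_batch * checkpoint_every_n_batches / 3600 for s in sec_per_game
]

print(f"{game_batches} game batches, ~{gradient_steps:,} gradient steps")
print(f"{checkpoints_written} checkpoints, {log_entries} log entries")
print(f"{sec_per_game[0]:.2f}-{sec_per_game[1]:.2f} s per game")
print(f"{hours_between_checkpoints[0]:.1f}-{hours_between_checkpoints[1]:.1f} h between checkpoints")
```

Under these assumptions a checkpoint is written only every three to four hours, so a Colab disconnect can discard several hours of self-play. Lowering `checkpoint_every_n_batches` (say, to 20) trades more Drive writes for less lost work.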

```python
from backgammon.training.train import train

try:
    train(config)
    print("\n🎉 Training complete!")
except KeyboardInterrupt:
    print("\n⚠️ Interrupted. Checkpoints saved so far are in Google Drive.")
except Exception as e:
    print(f"\n❌ Error: {e}")
    raise
```

## 6. Monitor Progress

```python
import json

log_file = LOG_DIR / 'training_log.jsonl'

if log_file.exists():
    with open(log_file) as f:
        lines = f.readlines()

    print(f"Total entries: {len(lines)}\n")
    print("Last 20 log entries:")
    print("=" * 80)

    for line in lines[-20:]:
        e = json.loads(line)
        phase = e.get('phase', 'unknown')
        batch = e.get('batch_num', 0)
        loss = e.get('loss', 0.0)
        games = e.get('total_games', 0)
        print(f"[{phase:8s}] Batch {batch:4d} | Games: {games:5d} | Loss: {loss:.4f}")
else:
    print("No logs yet. Start training first!")
```
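Because `train()` may still be appending to `training_log.jsonl` (and Drive may sync the file mid-write), the last line this cell reads can be incomplete, in which case `json.loads` raises. A more tolerant parser, sketched here as a hypothetical helper `parse_jsonl_lines` (not part of the package), skips blank and malformed lines and reports how many it dropped:

```python
import json


def parse_jsonl_lines(lines):
    """Parse JSON Lines, skipping blank lines and lines that fail to decode.

    Returns (entries, skipped), where skipped counts undecodable lines, e.g. a
    final line that was only partially written when the file was read.
    """
    entries, skipped = [], 0
    for line in lines:
        line = line.strip()
        if not line:
            continue
        try:
            entries.append(json.loads(line))
        except json.JSONDecodeError:
            skipped += 1
    return entries, skipped


# Usage in this notebook (LOG_DIR comes from section 3):
# with open(LOG_DIR / 'training_log.jsonl') as f:
#     entries, skipped = parse_jsonl_lines(f)
```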

## 7. Plot Training Curves

```python
import json
import matplotlib.pyplot as plt

log_file = LOG_DIR / 'training_log.jsonl'

entries = []
if log_file.exists():
    with open(log_file) as f:
        entries = [json.loads(line) for line in f if line.strip()]

if entries:
    batches = [e.get('batch_num', 0) for e in entries]
    losses = [e.get('loss', 0) for e in entries]
    win_rates = [e.get('white_win_rate', 0) for e in entries]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Loss
    ax1.plot(batches, losses, 'b-', alpha=0.7)
    ax1.set_xlabel('Batch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss')
    ax1.grid(True, alpha=0.3)

    # Win rate
    ax2.plot(batches, win_rates, 'g-', alpha=0.7)
    ax2.axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='50%')
    ax2.set_xlabel('Batch')
    ax2.set_ylabel('Win Rate')
    ax2.set_title('White Win Rate')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"\n📊 Summary:")
    print(f"  Log entries: {len(entries)} (last batch: {batches[-1]})")
    print(f"  Final loss: {losses[-1]:.4f}")
    print(f"  Final white win rate: {win_rates[-1]:.1%}")
else:
    print("No logs found yet.")
```
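Self-play loss curves tend to be noisy from one log entry to the next. One option, not part of the original notebook, is to overlay a trailing moving average. This hypothetical `moving_average` helper works on a plain list of numbers such as `losses` and assumes there are no missing values:

```python
def moving_average(values, window=10):
    """Trailing moving average over a list of numbers.

    The first window-1 outputs average however many points are available,
    so the result has the same length as the input.
    """
    if window < 1:
        raise ValueError("window must be >= 1")
    out, running = [], 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the point that left the window
        out.append(running / min(i + 1, window))
    return out
```

For example, `ax1.plot(batches, moving_average(losses, 10), 'k-', label='10-entry average')` followed by `ax1.legend()` draws the smoothed curve over the raw one. A trailing (rather than centered) window uses only past entries, at the cost of lagging the raw curve by roughly half a window.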

## 8. Test Trained Model

```python
import jax
from flax.training import checkpoints
from backgammon.training.train import create_train_state
from backgammon.core.game import GameEngine
from backgammon.evaluation.network_agent import NeuralAgent
from backgammon.evaluation.agents import PipCountAgent

# Load the latest checkpoint. restore_checkpoint returns the untrained
# `target` unchanged if no checkpoint exists, so check for one first.
print("Loading latest checkpoint...")
latest = checkpoints.latest_checkpoint(str(CHECKPOINT_DIR))
if latest is None:
    raise FileNotFoundError(f"No checkpoint found in {CHECKPOINT_DIR}")
rng = jax.random.PRNGKey(42)
state = create_train_state(config, rng)
state = checkpoints.restore_checkpoint(ckpt_dir=str(CHECKPOINT_DIR), target=state)
print(f"✅ Model loaded from {latest}")

# Play test game
print("\nPlaying test game: Neural vs Pip Count...")
neural = NeuralAgent(state=state, temperature=0.0, name="Neural")
pip = PipCountAgent(name="PipCount")

engine = GameEngine()
result = engine.play_game(neural, pip, seed=42)

print(f"\n🎲 Result:")
print(f"  Winner: {result.winner}")
print(f"  Points: {result.points}")
print(f"  Moves: {len(result.move_history)}")
```
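A single game says little about strength, since the dice dominate individual results. A more informative check plays many games with different `seed` values and reports the neural agent's win rate with a confidence interval. The sketch below computes a 95% Wilson score interval; counting wins is left to the loop around `engine.play_game`, because how `result.winner` identifies the neural agent depends on the package. The `neural_won` predicate in the usage comment is hypothetical.

```python
import math


def wilson_interval(wins, games, z=1.96):
    """Wilson score interval for a win rate; z=1.96 gives ~95% coverage."""
    if games <= 0:
        raise ValueError("games must be positive")
    if not 0 <= wins <= games:
        raise ValueError("wins must be between 0 and games")
    p = wins / games
    denom = 1 + z**2 / games
    center = (p + z**2 / (2 * games)) / denom
    half_width = z * math.sqrt(p * (1 - p) / games + z**2 / (4 * games**2)) / denom
    return center - half_width, center + half_width


# Hypothetical evaluation loop (engine, neural, pip from the cell above;
# neural_won is a placeholder for however you detect a neural win):
# wins = sum(neural_won(engine.play_game(neural, pip, seed=s)) for s in range(100))
# low, high = wilson_interval(wins, 100)
```

As a hypothetical example, 60 wins in 100 games gives an interval of roughly 0.50 to 0.69, which only barely excludes a coin flip. Halving the interval width takes about four times as many games.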

---

## Tips & Notes

### Adjusting Configuration
- **Faster training**: Reduce game counts (e.g., 100/500/500/500)
- **Bigger model**: Increase `embed_dim`, `num_layers`, `ff_dim`
- **Value-only**: Set `train_policy=False` (simpler, faster)
- **More exploration**: Increase `neural_agent_temperature` to 1.5

### Common Issues
- **No TPU**: Runtime → Change runtime type → TPU
- **Out of memory**: Reduce `training_batch_size` to 128
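- **Slow training**: Increase `training_batch_size`; TPUs are typically most efficient with batches of 256 or more
- **Timeout**: Free Colab sessions last up to 12 hours and Colab Pro up to 24, so the 10-13 hour run may not finish in a free session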

### Resuming Training
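If the runtime disconnects, the checkpoints written so far remain in Google Drive. Re-run sections 1-4 (install, TPU check, Drive mount, configuration), then load the latest checkpoint as in section 8 to continue.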
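To see how far a run got before a disconnect without rebuilding the model, the checkpoint directory names can be inspected directly. This sketch assumes Flax's default `checkpoint_<step>` naming (the default `prefix` in `flax.training.checkpoints`); the package may use a different scheme, and the step is whatever number the trainer saved with. Names with anything after the digits, such as leftovers from an interrupted write, are ignored.

```python
import re


def latest_checkpoint_step(names, prefix="checkpoint_"):
    """Return the highest step among names of the form '<prefix><digits>', or None."""
    pattern = re.compile(re.escape(prefix) + r"(\d+)")
    steps = [int(m.group(1)) for name in names if (m := pattern.fullmatch(name))]
    return max(steps, default=None)


# Usage in this notebook (CHECKPOINT_DIR comes from section 3):
# step = latest_checkpoint_step(p.name for p in CHECKPOINT_DIR.iterdir())
# print(f"Latest checkpoint step: {step}")
```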

---

**Repository:** [github.com/wmhowell18/claude-code/transformer-backgammon](https://github.com/wmhowell18/claude-code/tree/main/transformer-backgammon)

**Good luck! 🎲🤖**