# Transformer Backgammon - Colab Training (TPU / GPU / CPU)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wmhowell18/claude-code/blob/main/transformer-backgammon/colab_tpu_training.ipynb)

Train a transformer-based backgammon AI on Google Colab.

**For best performance:** Runtime → Change runtime type → **TPU** (bfloat16 mixed precision is enabled automatically)

Works on GPU and CPU too, just slower.

## 1. Install Dependencies

In [None]:
import jax

backend = jax.default_backend()
print(f"Detected backend: {backend}")

if backend == "tpu":
    # TPU runtime: install TPU-specific JAX wheels
    !pip install -q "jax[tpu]>=0.4.26" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
else:
    # GPU or CPU: standard JAX is fine
    !pip install -q "jax>=0.4.26" "jaxlib>=0.4.26"

# Install Flax and Optax, then the backgammon package from the repo
!pip install -q "flax>=0.8.0" "optax>=0.1.7"
!pip install -q "git+https://github.com/wmhowell18/claude-code.git@main#subdirectory=transformer-backgammon"

# Verify the package installed correctly
try:
    import backgammon
    print(f"\n✅ Installed! Backend: {jax.default_backend()}, Devices: {jax.device_count()}")
except ImportError:
    print("\n❌ Install failed. Try restarting the runtime (Runtime > Restart runtime) and re-running this cell")

## 2. Verify Hardware & bfloat16

In [None]:
import jax
import jax.numpy as jnp
import time

backend = jax.default_backend()
devices = jax.devices()
print(f"JAX version:  {jax.__version__}")
print(f"Backend:      {backend}")
print(f"Device count: {len(devices)}")
for d in devices:
    print(f"  {d}")

# Determine if bfloat16 is beneficial on this hardware
USE_BFLOAT16 = backend in ("tpu", "gpu")

# bfloat16 matmul benchmark
x32 = jnp.ones((1024, 1024), dtype=jnp.float32)
x16 = jnp.ones((1024, 1024), dtype=jnp.bfloat16)

# Warmup
jnp.dot(x32, x32).block_until_ready()
jnp.dot(x16, x16).block_until_ready()

t0 = time.time()
for _ in range(100):
    jnp.dot(x32, x32).block_until_ready()
t32 = time.time() - t0

t0 = time.time()
for _ in range(100):
    jnp.dot(x16, x16).block_until_ready()
t16 = time.time() - t0

speedup = t32 / t16
# t32 and t16 are total seconds over 100 runs, so multiplying by 10 gives ms per run
print(f"\nfloat32 matmul:  {t32*10:.1f} ms/iter")
print(f"bfloat16 matmul: {t16*10:.1f} ms/iter")
print(f"Speedup:         {speedup:.2f}x")
print(f"\n✅ Will use: {'bfloat16' if USE_BFLOAT16 else 'float32'}")
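
To see what the bfloat16 trade-off means numerically, the sketch below rounds a float32 value to bfloat16 using only the standard library. It is illustrative and not part of the training package: `to_bfloat16` is a hypothetical helper, and it does not handle NaN or infinity. bfloat16 keeps float32's 8 exponent bits but only 7 explicit mantissa bits, so values near 1.0 are spaced 1/128 apart.

```python
import struct

def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (hypothetical helper, no NaN/inf handling)."""
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, on the 16 bits that bfloat16 drops.
    bits += 0x7FFF + ((bits >> 16) & 1)
    # Keep only the top 16 bits (sign, 8 exponent bits, 7 mantissa bits).
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for value in (1.0, 1.001, 1.01, 0.3):
    print(f"{value!r:>6} -> {to_bfloat16(value)!r}")
```

A value such as 1.001 collapses to 1.0, which is why float32 can be worth the extra time when small differences matter.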

## 3. Mount Google Drive

In [None]:
from google.colab import drive
from pathlib import Path

# Mount Google Drive
drive.mount('/content/drive')

# Create directories for this training run
SAVE_DIR = Path('/content/drive/MyDrive/backgammon_training')
CHECKPOINT_DIR = SAVE_DIR / 'checkpoints'
LOG_DIR = SAVE_DIR / 'logs'

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR.mkdir(parents=True, exist_ok=True)

print(f"✅ Checkpoints: {CHECKPOINT_DIR}")
print(f"✅ Logs: {LOG_DIR}")

## 4. Configure Training

Two presets:
- **Quick** (~2,500 games, ~15-30 min on TPU): validates the pipeline works
- **Full** (~16,000 games, ~4-8 hrs on TPU): real training run

In [None]:
#@title Training preset { run: "auto" }
PRESET = "quick"  #@param ["quick", "full"]

from backgammon.training.train import TrainingConfig, v6e_quick_training_config

if PRESET == "quick":
    # Quick validation run: uses the v6e-optimized config
    # ~2,500 games, small model (~500K params), bfloat16
    config = v6e_quick_training_config()
else:
    # Full training run: larger model, more games
    config = TrainingConfig(
        warmstart_games=1000,
        early_phase_games=5000,
        mid_phase_games=5000,
        late_phase_games=5000,
        embed_dim=128,
        num_heads=8,
        num_layers=4,
        ff_dim=512,
        games_per_batch=50,
        training_batch_size=512,
        train_steps_per_game_batch=10,
        train_policy=False,
        compute_dtype='bfloat16' if USE_BFLOAT16 else None,
        replay_buffer_size=100_000,
        replay_buffer_min_size=1000,
        checkpoint_every_n_batches=100,
        log_every_n_batches=10,
        eval_every_n_batches=50,
        eval_num_games=50,
        seed=42,
    )

# Fall back to float32 when bfloat16 is not beneficial (this also overrides the quick preset's bfloat16 default)
if not USE_BFLOAT16:
    config.compute_dtype = None

# Point checkpoints/logs to Google Drive
config.checkpoint_dir = str(CHECKPOINT_DIR)
config.log_dir = str(LOG_DIR)

total_games = (config.warmstart_games + config.early_phase_games +
               config.mid_phase_games + config.late_phase_games)

print(f"Preset:       {PRESET}")
print(f"Total games:  {total_games:,}")
print(f"Model:        {config.num_layers}L / {config.embed_dim}d / {config.num_heads}H / {config.ff_dim}ff")
print(f"Batch size:   {config.training_batch_size}")
print(f"Dtype:        {'bfloat16' if config.compute_dtype else 'float32'}")
print(f"Checkpoints:  {config.checkpoint_dir}")
print(f"\n✅ Ready to train!")
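
As a rough sanity check on the full preset, the arithmetic below estimates how many self-play batches, gradient steps, and checkpoints a run produces. It assumes the trainer processes `games_per_batch` games per batch and runs `train_steps_per_game_batch` gradient steps after each one, which the parameter names suggest but this notebook does not confirm. `estimate_schedule` is a hypothetical helper, not part of the `backgammon` package.

```python
def estimate_schedule(total_games, games_per_batch, steps_per_batch, checkpoint_every):
    """Estimate batch, gradient-step, and checkpoint counts (assumed semantics)."""
    # Ceiling division: a final partial batch still counts as a batch.
    batches = -(-total_games // games_per_batch)
    return {
        "batches": batches,
        "train_steps": batches * steps_per_batch,
        "checkpoints": batches // checkpoint_every,
    }

# Values copied from the full preset above.
full = estimate_schedule(
    total_games=1000 + 5000 + 5000 + 5000,
    games_per_batch=50,
    steps_per_batch=10,
    checkpoint_every=100,
)
print(full)
```

If these assumptions hold, the full preset runs 320 batches with a checkpoint every 100, so a disconnect could cost up to about 100 batches of progress.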

## 5. Run Training

| Preset | Games | TPU time | GPU time | CPU time |
|--------|-------|----------|----------|----------|
| Quick  | 2,500 | ~15-30 min | ~1 hr | ~3 hrs |
| Full   | 16,000 | ~4-8 hrs | ~12+ hrs | not recommended |

Checkpoints are saved to Google Drive, so you won't lose progress if disconnected.

In [None]:
from backgammon.training.train import train
import time

print(f"Starting {PRESET} training run...")
print(f"  dtype: {'bfloat16' if config.compute_dtype else 'float32'}")
print(f"  backend: {jax.default_backend()}")
print()

t0 = time.time()
try:
    train(config)
    elapsed = time.time() - t0
    print(f"\n🎉 Training complete! ({elapsed/60:.1f} minutes)")
except KeyboardInterrupt:
    elapsed = time.time() - t0
    print(f"\n⚠️ Interrupted after {elapsed/60:.1f} min. Checkpoints saved to Google Drive.")
except Exception as e:
    elapsed = time.time() - t0
    print(f"\n❌ Error after {elapsed/60:.1f} min: {e}")
    raise

## 6. Monitor Progress

In [None]:
import json

log_file = LOG_DIR / 'training_log.jsonl'

if log_file.exists():
    with open(log_file) as f:
        lines = f.readlines()
    
    print(f"Total entries: {len(lines)}\n")
    print("Last 20 batches:")
    print("=" * 80)
    
    for line in lines[-20:]:
        e = json.loads(line)
        phase = e.get('phase', 'unknown')
        batch = e.get('batch_num', 0)
        loss = e.get('loss', 0.0)
        games = e.get('total_games', 0)
        print(f"[{phase:8s}] Batch {batch:4d} | Games: {games:5d} | Loss: {loss:.4f}")
else:
    print("No logs yet. Start training first!")
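
If training was interrupted mid-write (for example by a disconnect), the last line of the log may be incomplete and `json.loads` would raise. A minimal sketch of a more tolerant reader follows; `read_jsonl` is a hypothetical helper, and the demo field names are the ones the cell above reads.

```python
import json
import tempfile
from pathlib import Path

def read_jsonl(path):
    """Return parsed entries, skipping blank or truncated lines (hypothetical helper)."""
    entries = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entries.append(json.loads(line))
            except json.JSONDecodeError:
                # Most likely a partially written final line; ignore it.
                continue
    return entries

# Demo on a throwaway file whose last line is cut off.
with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "training_log.jsonl"
    demo.write_text('{"batch_num": 1, "loss": 0.9}\n\n{"batch_num": 2, "lo')
    demo_entries = read_jsonl(demo)
print(demo_entries)
```

In the notebook you could call `read_jsonl(log_file)` in place of the `readlines()` and `json.loads` pair.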

## 7. Plot Training Curves

In [None]:
import json
import matplotlib.pyplot as plt

log_file = LOG_DIR / 'training_log.jsonl'

if log_file.exists():
    with open(log_file) as f:
        entries = [json.loads(line) for line in f]
    
    batches = [e['batch_num'] for e in entries]
    losses = [e.get('loss', 0) for e in entries]
    win_rates = [e.get('white_win_rate', 0) for e in entries]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Loss
    ax1.plot(batches, losses, 'b-', alpha=0.7)
    ax1.set_xlabel('Batch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss')
    ax1.grid(True, alpha=0.3)
    
    # Win rate
    ax2.plot(batches, win_rates, 'g-', alpha=0.7)
    ax2.axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='50%')
    ax2.set_xlabel('Batch')
    ax2.set_ylabel('Win Rate')
    ax2.set_title('White Win Rate')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n📊 Summary:")
    print(f"  Batches: {len(entries)}")
    print(f"  Final loss: {losses[-1]:.4f}")
    print(f"  Final win rate: {win_rates[-1]:.1%}")
else:
    print("No logs found yet.")
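
Per-batch loss is usually noisy, so a smoothed curve can make the trend easier to read. The sketch below is a plain trailing moving average using only the standard library; `moving_average` and the window size are illustrative choices, not part of the `backgammon` package.

```python
from collections import deque

def moving_average(values, window=20):
    """Trailing mean over the last `window` values (shorter at the start)."""
    recent = deque(maxlen=window)
    running_sum = 0.0
    smoothed = []
    for v in values:
        if len(recent) == window:
            # The oldest value is about to be evicted by the append below.
            running_sum -= recent[0]
        recent.append(v)
        running_sum += v
        smoothed.append(running_sum / len(recent))
    return smoothed

print(moving_average([4.0, 2.0, 6.0, 8.0], window=2))
```

In the plotting cell, something like `ax1.plot(batches, moving_average(losses), 'r-')` would overlay the smoothed curve on the raw one.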

## 8. Test Trained Model

In [None]:
import jax
from flax.training import checkpoints
from backgammon.training.train import create_train_state
from backgammon.core.game import GameEngine
from backgammon.evaluation.network_agent import NeuralAgent
from backgammon.evaluation.agents import PipCountAgent

# Load latest checkpoint
print("Loading latest checkpoint...")
rng = jax.random.PRNGKey(42)
state = create_train_state(config, rng)
state = checkpoints.restore_checkpoint(ckpt_dir=str(CHECKPOINT_DIR), target=state)
print(f"✅ Model loaded (step {int(state.step)})")

# Play test games: Neural vs Pip Count
print("\nPlaying 10 test games: Neural vs Pip Count...")
neural = NeuralAgent(state=state, temperature=0.0, name="Neural")
pip_agent = PipCountAgent(name="PipCount")

engine = GameEngine()
wins = 0
for i in range(10):
    result = engine.play_game(neural, pip_agent, seed=42 + i)
    winner = "Neural" if result.winner == 1 else "PipCount"
    wins += 1 if result.winner == 1 else 0
    print(f"  Game {i+1}: {winner} wins ({len(result.move_history)} moves)")

print(f"\nNeural win rate: {wins}/10 ({wins*10}%)")
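
Ten games is a small sample, so the reported win rate carries wide uncertainty. As a rough guide, the sketch below computes a 95% Wilson score interval for a win count; `wilson_interval` is a hypothetical helper added for illustration, not part of the `backgammon` package.

```python
import math

def wilson_interval(wins, games, z=1.96):
    """Wilson score interval for a binomial win rate (z=1.96 gives about 95%)."""
    if games == 0:
        return (0.0, 1.0)
    p = wins / games
    denom = 1 + z * z / games
    center = (p + z * z / (2 * games)) / denom
    half = z * math.sqrt(p * (1 - p) / games + z * z / (4 * games * games)) / denom
    return (max(0.0, center - half), min(1.0, center + half))

low, high = wilson_interval(wins=7, games=10)
print(f"7/10 wins: 95% interval {low:.2f} to {high:.2f}")
```

For 7 wins out of 10 the interval runs from roughly 40% to 89%, so raising the game count in the loop above gives a far more trustworthy estimate.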

---

## Tips

### Presets
- **Quick** (`v6e_quick_training_config`): 2,500 games, small model, bfloat16. Good for validating the pipeline.
- **Full**: 16,000 games, medium model. Real training.

### Tuning
- **Bigger model**: In the full preset, set `embed_dim=256`, `num_layers=6`, `ff_dim=1024`
- **More games**: Increase `late_phase_games` (that's where the model improves most)
- **No bfloat16**: Set `config.compute_dtype = None` (slower but higher precision)
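
To get a feel for how much the bigger-model settings grow the network, the sketch below estimates the weight count of the transformer blocks alone. It assumes a standard encoder block (four `embed_dim × embed_dim` attention projections plus a two-layer feed-forward network) and ignores biases, layer norms, embeddings, and output heads; the actual architecture in the `backgammon` package may differ.

```python
def approx_block_params(embed_dim, ff_dim, num_layers):
    """Rough weight count for standard transformer blocks (assumed architecture)."""
    attention = 4 * embed_dim * embed_dim      # Q, K, V, and output projections
    feed_forward = 2 * embed_dim * ff_dim      # up- and down-projection
    return num_layers * (attention + feed_forward)

full_preset = approx_block_params(embed_dim=128, ff_dim=512, num_layers=4)
bigger = approx_block_params(embed_dim=256, ff_dim=1024, num_layers=6)
print(f"full preset:  ~{full_preset:,} weights")
print(f"bigger model: ~{bigger:,} weights ({bigger / full_preset:.1f}x)")
```

Under these assumptions the bigger settings are about six times larger, so expect slower steps and higher memory use.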

### Troubleshooting
| Problem | Fix |
|---------|-----|
| No TPU detected | Runtime → Change runtime type → TPU |
| OOM on GPU | Reduce `training_batch_size` to 128 |
| Slow on CPU | Use the "quick" preset, or switch to TPU/GPU runtime |
| Session timeout | Checkpoints are on Drive: re-run cells 1-4, then cell 5, and training resumes |

### Resuming after disconnect
Checkpoints are saved to Google Drive. Re-run cells 1-4 (install, verify, mount, config), then run cell 5. The `train()` function picks up from the latest checkpoint.
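
Before resuming, it can help to confirm which checkpoints actually reached Drive. The sketch below lists entries by step number, assuming they use Flax's default `checkpoint_<step>` naming; if `train()` uses a different prefix or layout, adjust the pattern. `list_checkpoint_steps` is a hypothetical helper, not part of the `backgammon` package.

```python
import re
import tempfile
from pathlib import Path

def list_checkpoint_steps(ckpt_dir, prefix="checkpoint_"):
    """Return sorted step numbers of entries named <prefix><step> (assumed naming)."""
    pattern = re.compile(re.escape(prefix) + r"(\d+)$")
    steps = []
    for entry in Path(ckpt_dir).iterdir():
        match = pattern.match(entry.name)
        if match:
            steps.append(int(match.group(1)))
    return sorted(steps)

# Demo on a throwaway directory; in Colab, pass CHECKPOINT_DIR instead.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("checkpoint_100", "checkpoint_2000", "checkpoint_300", "notes.txt"):
        (Path(tmp) / name).touch()
    demo_steps = list_checkpoint_steps(tmp)
print(demo_steps)
```

Sorting numerically rather than by filename matters here, since `checkpoint_2000` would otherwise sort before `checkpoint_300`.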

---

**Repo:** [github.com/wmhowell18/claude-code/transformer-backgammon](https://github.com/wmhowell18/claude-code/tree/main/transformer-backgammon)