# Transformer Backgammon - TPU Training on Google Colab

This notebook trains a transformer-based backgammon AI on Google Colab TPUs.

**Before running:**
1. Runtime → Change runtime type → TPU
2. Make sure you have your code pushed to GitHub
3. Update the `GITHUB_REPO` variable below with your repository URL

## 1. Setup - Install Dependencies

In [None]:
# Install TPU-specific JAX
# Quote the extra so the shell does not treat the brackets as a glob pattern
!pip install -q "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install other dependencies
!pip install -q flax optax

print("✅ Dependencies installed!")

## 2. Verify TPU is Available

In [None]:
import jax
import jax.numpy as jnp

# Force TPU backend
jax.config.update('jax_platform_name', 'tpu')

# Check devices
devices = jax.devices()
print(f"JAX backend: {jax.default_backend()}")
print(f"Available devices: {devices}")
print(f"Number of TPU cores: {len(devices)}")

# Quick test
x = jnp.ones((1000, 1000))
y = jnp.dot(x, x)
print(f"\n✅ TPU is working! Test computation result shape: {y.shape}")

## 3. Clone Your Repository

In [None]:
import os

# UPDATE THIS with your GitHub repository URL
GITHUB_REPO = "https://github.com/YOUR_USERNAME/transformer-backgammon.git"
BRANCH = "claude/sprint2-encoder-mjosid5da1v9bhq8-EaMvJ"  # Or your main branch

# Clone repository
if not os.path.exists('transformer-backgammon'):
    !git clone {GITHUB_REPO}
    %cd transformer-backgammon
    !git checkout {BRANCH}
else:
    %cd transformer-backgammon
    !git pull origin {BRANCH}

# Install the package
!pip install -q -e .

print("\n✅ Repository cloned and installed!")

## 4. Mount Google Drive (for saving checkpoints)

In [None]:
from google.colab import drive
from pathlib import Path

# Mount Google Drive
drive.mount('/content/drive')

# Create directories for this training run
SAVE_DIR = Path('/content/drive/MyDrive/backgammon_training')
CHECKPOINT_DIR = SAVE_DIR / 'checkpoints'
LOG_DIR = SAVE_DIR / 'logs'

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR.mkdir(parents=True, exist_ok=True)

print(f"✅ Checkpoints will be saved to: {CHECKPOINT_DIR}")
print(f"✅ Logs will be saved to: {LOG_DIR}")

## 5. Configure Training for TPU

In [None]:
from backgammon.training.train import TrainingConfig

# TPU-optimized configuration
config = TrainingConfig(
    # Training phases - adjust based on how long you want to train
    warmstart_games=1000,      # Start with simple pip-count games
    early_phase_games=5000,    # Neural self-play with simple variants
    mid_phase_games=5000,      # Mixed complexity
    late_phase_games=5000,     # Full complexity
    
    # Batch sizes - TPUs love large batches!
    games_per_batch=50,         # Generate 50 games per batch
    training_batch_size=256,    # Train on 256 positions at once (good for TPU)
    
    # Network architecture - using smaller default model
    embed_dim=128,              # Embedding dimension
    num_heads=8,                # Attention heads
    num_layers=4,               # Transformer layers
    ff_dim=512,                 # Feedforward dimension
    dropout_rate=0.1,
    
    # Training mode
    train_policy=True,          # Set to False for value-only training (simpler)
    
    # Replay buffer
    replay_buffer_size=100000,  # Keep 100k positions in memory
    replay_buffer_min_size=1000, # Start training after 1k positions
    train_steps_per_game_batch=10, # Train 10 times per game batch
    
    # Optimizer
    learning_rate=3e-4,
    
    # Self-play exploration
    neural_agent_temperature=1.0,
    
    # Checkpointing and logging
    checkpoint_every_n_batches=100,  # Save every 100 batches
    log_every_n_batches=10,          # Log every 10 batches
    checkpoint_dir=str(CHECKPOINT_DIR),
    log_dir=str(LOG_DIR),
    
    # Random seed
    seed=42,
)

print("Training Configuration:")
print(f"  Total games: {config.warmstart_games + config.early_phase_games + config.mid_phase_games + config.late_phase_games}")
print(f"  Batch size: {config.training_batch_size}")
print(f"  Model size: {config.embed_dim}d, {config.num_layers} layers")
print(f"  Training mode: {'Policy + Value' if config.train_policy else 'Value only'}")
print("\n✅ Configuration ready!")
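
For a rough sense of scale, the sketch below estimates the parameter count of a standard transformer encoder stack with the dimensions configured above. It assumes a vanilla layer (four attention projections, a two-layer feedforward block, two layer norms, biases everywhere) and ignores embeddings and output heads, so the repository's actual model may differ. `estimate_encoder_params` is a hypothetical helper, not part of the `backgammon` package.

```python
def estimate_encoder_params(embed_dim: int, num_layers: int, ff_dim: int) -> int:
    """Approximate parameter count of a vanilla transformer encoder stack.

    Assumption: each layer has Q/K/V/output projections, a two-layer
    feedforward block, and two layer norms, all with biases. Embeddings
    and output heads are not counted.
    """
    attention = 4 * (embed_dim * embed_dim + embed_dim)
    feedforward = (embed_dim * ff_dim + ff_dim) + (ff_dim * embed_dim + embed_dim)
    layer_norms = 2 * (2 * embed_dim)
    return num_layers * (attention + feedforward + layer_norms)


# The configuration used in this notebook, and a wider hypothetical variant
small = estimate_encoder_params(embed_dim=128, num_layers=4, ff_dim=512)
wide = estimate_encoder_params(embed_dim=256, num_layers=4, ff_dim=1024)
print(f"128d / 4 layers / ff 512:  ~{small:,} parameters")
print(f"256d / 4 layers / ff 1024: ~{wide:,} parameters")
```

Under these assumptions the number of attention heads does not change the count, and doubling both `embed_dim` and `ff_dim` roughly quadruples it.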

## 6. Run Training

**Note:** This will take a while (several hours depending on configuration). The training will:
1. Start with warmstart games (pip count vs pip count)
2. Progress to neural self-play
3. Save checkpoints to Google Drive every 100 batches
4. Log metrics every 10 batches
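
The checkpoint and logging cadence can be worked out from the configuration values. This is a small arithmetic sketch that assumes the batch counter counts game batches and runs continuously across all four phases, which the notebook does not confirm.

```python
# Values copied from the TrainingConfig cell above
total_games = 1000 + 5000 + 5000 + 5000   # warmstart + early + mid + late
games_per_batch = 50
checkpoint_every_n_batches = 100
log_every_n_batches = 10

num_batches = total_games // games_per_batch
num_checkpoints = num_batches // checkpoint_every_n_batches
num_log_entries = num_batches // log_every_n_batches

# Batches played after the last periodic checkpoint
batches_after_last_checkpoint = num_batches % checkpoint_every_n_batches

print(f"Game batches:      {num_batches}")
print(f"Checkpoints saved: {num_checkpoints}")
print(f"Log entries:       {num_log_entries}")
print(f"Batches after the last periodic checkpoint: {batches_after_last_checkpoint}")
```

With the default settings this gives 320 batches, 3 periodic checkpoints, and 32 log entries. The final 20 batches fall after the last periodic checkpoint, so lowering `checkpoint_every_n_batches` may be worthwhile if `train` does not also save at the end.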

In [None]:
from backgammon.training.train import train

# Run training!
try:
    train(config)
    print("\n🎉 Training completed successfully!")
except KeyboardInterrupt:
    print("\n⚠️ Training interrupted. Checkpoints are saved in Google Drive.")
except Exception as e:
    print(f"\n❌ Training failed with error: {e}")
    raise

## 7. Monitor Training Progress

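Run this cell to print the most recent log entries. Colab executes one cell at a time per runtime, so while the training cell is still running this cell waits in the queue; run it after training finishes or is interrupted, or open the log file in Google Drive to check progress mid-run.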

In [None]:
# View the last 20 log entries
import json
from pathlib import Path

log_file = LOG_DIR / "training_log.jsonl"

if log_file.exists():
    with open(log_file, 'r') as f:
        lines = f.readlines()
    
    print(f"Total log entries: {len(lines)}\n")
    print("Last 20 entries:")
    print("=" * 100)
    
    for line in lines[-20:]:
        entry = json.loads(line)
        phase = entry.get('phase', 'unknown')
        batch = entry.get('batch_num', 0)
        loss = entry.get('loss', 0.0)
        games = entry.get('total_games', 0)
        
        print(f"[{phase:8s}] Batch {batch:4d} | Games: {games:5d} | Loss: {loss:.4f}")
else:
    print("No logs yet. Training hasn't started or just started.")
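
If the log is read while training is still writing to it, the final line may be incomplete and `json.loads` would raise. The sketch below is one way to read the tail of a JSONL file while skipping lines that fail to parse. `read_jsonl_tail` is a hypothetical helper, and the demo file contents are made up for illustration.

```python
import json
import tempfile
from collections import deque
from pathlib import Path


def read_jsonl_tail(path, n=20):
    """Return up to the last n valid JSON records from a JSONL file.

    Lines that fail to parse (for example a half-written final line)
    are skipped instead of raising.
    """
    tail = deque(maxlen=n)  # keeps only the n most recent valid records
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                tail.append(json.loads(line))
            except json.JSONDecodeError:
                continue
    return list(tail)


# Demo with a temporary file whose last line is truncated
with tempfile.TemporaryDirectory() as tmp:
    demo_log = Path(tmp) / "training_log.jsonl"
    demo_log.write_text(
        '{"batch_num": 10, "loss": 0.9}\n'
        '{"batch_num": 20, "loss": 0.8}\n'
        '{"batch_num": 30, "lo'
    )
    demo_entries = read_jsonl_tail(demo_log, n=2)

print(demo_entries)
```

In the cell above, `read_jsonl_tail(log_file, n=20)` could stand in for the `readlines()` call and the `lines[-20:]` slice.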

## 8. Analyze Training Results

In [None]:
import json
import matplotlib.pyplot as plt
from pathlib import Path

log_file = LOG_DIR / "training_log.jsonl"

if log_file.exists():
    # Load all log entries
    with open(log_file, 'r') as f:
        entries = [json.loads(line) for line in f]
    
    # Extract metrics
    batches = [e['batch_num'] for e in entries]
    losses = [e.get('loss', 0) for e in entries]
    win_rates = [e.get('white_win_rate', 0) for e in entries]
    
    # Create plots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Loss plot
    ax1.plot(batches, losses, 'b-', alpha=0.7)
    ax1.set_xlabel('Batch Number')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss Over Time')
    ax1.grid(True, alpha=0.3)
    
    # Win rate plot
    ax2.plot(batches, win_rates, 'g-', alpha=0.7)
    ax2.set_xlabel('Batch Number')
    ax2.set_ylabel('White Win Rate')
    ax2.set_title('Win Rate Over Time')
    ax2.axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='50% (balanced)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print summary
    print(f"\nTraining Summary:")
    print(f"  Log entries: {len(entries)} (last logged batch: {batches[-1]})")
    print(f"  Final loss: {losses[-1]:.4f}")
    print(f"  Final win rate: {win_rates[-1]:.2%}")
else:
    print("No training logs found yet.")
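
Per-batch losses from self-play tend to be noisy, so a smoothed curve is often easier to read. The sketch below shows a plain trailing moving average. `moving_average` is a hypothetical helper and the sample values are invented for the demo.

```python
def moving_average(values, window=10):
    """Trailing moving average; early points average over fewer values."""
    smoothed = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        if i >= window:
            # Drop the value that just left the window
            running_sum -= values[i - window]
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed


# Invented sample data, standing in for the `losses` list
noisy_losses = [1.0, 0.6, 0.9, 0.5, 0.8, 0.4]
smoothed_losses = moving_average(noisy_losses, window=3)
print([round(v, 3) for v in smoothed_losses])
```

The result has the same length as the input, so it can be plotted against `batches` directly, for example `ax1.plot(batches, moving_average(losses), 'r-')`.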

## 9. Load and Test Trained Model

In [None]:
import jax
from flax.training import checkpoints
from backgammon.training.train import create_train_state
from backgammon.core.game import GameEngine
from backgammon.evaluation.network_agent import NeuralAgent
from backgammon.evaluation.agents import PipCountAgent

# Load the latest checkpoint
print("Loading checkpoint...")
rng = jax.random.PRNGKey(42)
state = create_train_state(config, rng)

# Restore from checkpoint
state = checkpoints.restore_checkpoint(
    ckpt_dir=str(CHECKPOINT_DIR),
    target=state,
)

print("✅ Model loaded!")

# Play a test game
print("\nPlaying test game: Neural Network vs Pip Count Agent...")

neural_agent = NeuralAgent(
    state=state,
    temperature=0.0,  # Greedy selection for evaluation
    name="NeuralNet",
)

pip_agent = PipCountAgent(name="PipCount")

engine = GameEngine()
result = engine.play_game(neural_agent, pip_agent, seed=42)

print(f"\n🎲 Game Result:")
print(f"  Winner: {result.winner}")
print(f"  Points: {result.points}")
print(f"  Moves: {len(result.move_history)}")

## 10. Download Checkpoints (Optional)

Your checkpoints are already saved in Google Drive, but you can also download them directly.

In [None]:
# Create a zip file of checkpoints
!cd /content/drive/MyDrive/backgammon_training && zip -r checkpoints.zip checkpoints/

print("✅ Checkpoints zipped!")
print(f"   Location: {SAVE_DIR}/checkpoints.zip")
print("\nYou can download this from your Google Drive or use the Colab file browser.")

---

## Notes & Tips

### Training Time Estimates
- Warmstart (1000 games): ~30 min
- Early phase (5000 games): ~3-4 hours
- Mid phase (5000 games): ~3-4 hours
- Late phase (5000 games): ~3-4 hours
- **Total: ~10-13 hours for 16,000 games**
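
The total follows from summing the per-phase estimates, and the same figures give a rough time per game. The sketch below only redoes that arithmetic with the numbers from the list above; it is not a measurement.

```python
# (low, high) estimates in hours, taken from the list above
phase_hours = {
    "warmstart": (0.5, 0.5),
    "early": (3.0, 4.0),
    "mid": (3.0, 4.0),
    "late": (3.0, 4.0),
}
total_games = 1000 + 5000 + 5000 + 5000

low_hours = sum(low for low, _ in phase_hours.values())
high_hours = sum(high for _, high in phase_hours.values())

# Implied average wall-clock time per game, in seconds
low_seconds_per_game = low_hours * 3600 / total_games
high_seconds_per_game = high_hours * 3600 / total_games

print(f"Total: {low_hours}-{high_hours} hours for {total_games:,} games")
print(f"Implied pace: {low_seconds_per_game:.1f}-{high_seconds_per_game:.1f} s per game")
```

Scaling the `*_phase_games` values up or down should change the total roughly in proportion, assuming the per-game pace stays similar.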

### TPU Optimization Tips
1. **Larger batches are better** - TPUs thrive on batch sizes of 128-512
2. **Save often** - Colab sessions can time out, so checkpoint frequently
3. **Use Google Drive** - Files on the Colab VM's local disk are lost when the session ends
4. **Monitor memory** - TPU memory is limited; reduce the batch or model size if you see OOM errors

### Adjusting Configuration
- **Faster training**: Reduce `*_phase_games` values
- **Better model**: Increase `embed_dim`, `num_layers`, or `ff_dim`
- **Value-only mode**: Set `train_policy=False` for simpler training
- **More exploration**: Increase `neural_agent_temperature` (e.g., 1.5)
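
To illustrate why a higher temperature means more exploration, the sketch below applies a softmax with temperature to a few made-up move scores. It assumes the agent samples moves from a softmax over scores divided by the temperature. `NeuralAgent`'s internals are not shown in this notebook, so treat this as an illustration rather than its implementation.

```python
import math


def softmax_with_temperature(scores, temperature):
    """Convert scores to probabilities; higher temperature flattens them."""
    if temperature <= 0:
        # Temperature 0 is usually handled separately as a greedy argmax
        raise ValueError("temperature must be positive")
    scaled = [s / temperature for s in scores]
    max_scaled = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - max_scaled) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up scores for three candidate moves
move_scores = [2.0, 1.0, 0.0]
for temperature in (0.5, 1.0, 1.5):
    probs = softmax_with_temperature(move_scores, temperature)
    print(f"T={temperature}: {[round(p, 3) for p in probs]}")
```

As the temperature rises, probability mass shifts from the best-scored move toward the alternatives, which is the extra exploration the tip refers to.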

### Resuming Training
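If your session disconnects, rerun cells 1-5 to restore the environment and `config`, then restore the latest checkpoint before continuing with cell 6. How the restored `state` is handed to `train` depends on the repository's training code:
```python
import jax
from flax.training import checkpoints
from backgammon.training.train import create_train_state

# Build a fresh state as the restore target, then load the latest checkpoint
state = create_train_state(config, jax.random.PRNGKey(42))
state = checkpoints.restore_checkpoint(
    ckpt_dir=str(CHECKPOINT_DIR),
    target=state,
)
```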
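
Flax's `restore_checkpoint` returns the target unchanged when the directory holds no checkpoint, so a resume can silently start from scratch. The sketch below is one way to check first. It assumes checkpoints are named with Flax's default `checkpoint_` prefix followed by a step number, which this repository may or may not use, and `find_latest_checkpoint` is a hypothetical helper.

```python
import re
import tempfile
from pathlib import Path


def find_latest_checkpoint(ckpt_dir, prefix="checkpoint_"):
    """Return the entry with the highest numeric step suffix, or None."""
    ckpt_dir = Path(ckpt_dir)
    if not ckpt_dir.is_dir():
        return None
    pattern = re.compile(re.escape(prefix) + r"(\d+)$")
    best_step, best_path = -1, None
    for path in ckpt_dir.iterdir():
        match = pattern.match(path.name)
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


# Demo with empty placeholder files in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    empty_result = find_latest_checkpoint(tmp)
    for step in (100, 200, 1000):
        (Path(tmp) / f"checkpoint_{step}").touch()
    latest_name = find_latest_checkpoint(tmp).name

print(empty_result, latest_name)
```

Steps are compared as integers rather than strings, so `checkpoint_1000` ranks above `checkpoint_200`. Calling `find_latest_checkpoint(CHECKPOINT_DIR)` before restoring would show whether there is anything to resume from.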

### Common Issues
- **TPU not available**: Make sure the runtime type is set to TPU (Runtime → Change runtime type → TPU)
- **Out of memory**: Reduce `training_batch_size` or `replay_buffer_size`
- **Slow training**: Increase `training_batch_size` (TPUs like big batches)
- **Session timeout**: Colab Pro offers longer sessions; either way, checkpoint often

---

**Good luck with your training! 🎲🤖**