<a href="https://colab.research.google.com/github/sile16/pgx/blob/master/colab/benchmark_variants.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PGX Performance Benchmark: 2048 & Backgammon Variants

This notebook benchmarks GPU-optimized variants of 2048 and Backgammon.

**Optimizations tested:**
- **Branchless operations**: Replace `jax.lax.cond` with `jnp.where`
- **No rotations**: Eliminate `jnp.rot90` memory shuffles
- **Fast observation**: Minimal observation (34 vs 86 elements)
- **Combined**: All optimizations together
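
As an illustration of the first optimization, the sketch below contrasts a `jax.lax.cond` branch with an equivalent `jnp.where` selection. The function names and the toy "double the tile if it can merge" rule are hypothetical and are not taken from the PGX variants. They only show the general pattern, assuming both branches are cheap enough to compute unconditionally.

```python
import jax
import jax.numpy as jnp


def merge_value_cond(can_merge, tile):
    # Branching form: lax.cond selects one of two functions based on a scalar predicate.
    return jax.lax.cond(
        can_merge,
        lambda t: 2 * t,              # taken when can_merge is True
        lambda t: jnp.zeros_like(t),  # taken when can_merge is False
        tile,
    )


def merge_value_where(can_merge, tile):
    # Branchless form: evaluate both outcomes and select elementwise.
    return jnp.where(can_merge, 2 * tile, 0)


# Toy batch of three independent "games" (illustrative values only).
can_merge = jnp.array([True, False, True])
tiles = jnp.array([2, 4, 8])

# lax.cond needs a scalar predicate, so the batch dimension is added with vmap.
batched_cond = jax.vmap(merge_value_cond)(can_merge, tiles)
# jnp.where already broadcasts over the batch.
batched_where = merge_value_where(can_merge, tiles)

# Both forms agree on every element.
assert bool(jnp.array_equal(batched_cond, batched_where))
```

The trade-off is that the branchless form always pays for both outcomes, so it tends to help most when each branch is small relative to the cost of control flow on the accelerator.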

**Instructions:**
1. Select GPU runtime: `Runtime > Change runtime type > GPU`
2. Run all cells: `Runtime > Run all`

## 1. Setup

In [None]:
# Clone the repository with optimized variants (master branch)
!git clone --branch master https://github.com/sile16/pgx.git
%cd pgx

In [None]:
# Install pgx from local source and monitoring dependencies
!pip install . gputil psutil

In [None]:
# Detect and display device information
import jax
import subprocess
import os

print("=" * 70)
print("DEVICE INFORMATION")
print("=" * 70)

# JAX info
print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")
print(f"JAX devices: {jax.devices()}")
print()

# Get detailed device info
device = jax.devices()[0]
device_kind = device.device_kind
platform = device.platform

if platform == 'gpu':
    # Get NVIDIA GPU details; fall back to JAX's device kind if nvidia-smi is unavailable
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=name,memory.total,driver_version', '--format=csv,noheader'],
            capture_output=True, text=True, check=True
        )
        # nvidia-smi prints one line per GPU; report the first one
        gpu_info = result.stdout.strip().splitlines()[0]
        gpu_name, gpu_memory, driver = [x.strip() for x in gpu_info.split(',')]
        print(f"GPU Name: {gpu_name}")
        print(f"GPU Memory: {gpu_memory}")
        print(f"NVIDIA Driver: {driver}")
    except (OSError, subprocess.CalledProcessError, IndexError, ValueError):
        print(f"GPU: {device_kind}")

elif platform == 'tpu':
    # TPU info
    print(f"TPU Type: {device_kind}")
    print(f"TPU Devices: {len(jax.devices())}")
    # TPU_NAME is not set in every runtime, so default to 'Unknown'
    tpu_name = os.environ.get('TPU_NAME', 'Unknown')
    print(f"TPU Name: {tpu_name}")
else:
    print(f"Device: {device_kind} ({platform})")

print()
print("=" * 70)

# Store device info for results
DEVICE_SUMMARY = f"{platform}:{device_kind}"
print(f"\nDevice Summary: {DEVICE_SUMMARY}")

In [None]:
# Start GPU/resource monitoring (optional - prints GPU load and memory usage periodically)
# Stats are printed to the notebook output only; nothing is sent to an external service
try:
    import GPUtil
    import psutil
    from threading import Thread
    from time import sleep

    class SimpleGPUMonitor:
        """Simple GPU monitor that prints stats periodically."""
        def __init__(self, interval=30):
            self.interval = interval
            self._running = False
            self._thread = None

        def _loop(self):
            while self._running:
                try:
                    gpus = GPUtil.getGPUs()
                    if gpus:
                        gpu = gpus[0]
                        print(f"[GPU Monitor] {gpu.name}: {gpu.load*100:.1f}% load, "
                              f"{gpu.memoryUsed:.0f}/{gpu.memoryTotal:.0f} MB ({gpu.memoryUtil*100:.1f}%)")
                except Exception:
                    # Ignore transient query failures so the monitor thread keeps running
                    pass
                sleep(self.interval)

        def start(self):
            self._running = True
            self._thread = Thread(target=self._loop, daemon=True)
            self._thread.start()
            print(f"GPU monitoring started (updates every {self.interval}s)")
            return self

        def stop(self):
            self._running = False

    # Start monitoring
    gpu_monitor = SimpleGPUMonitor(interval=30).start()
except Exception as e:
    print(f"GPU monitoring not available: {e}")

## 2. Run Benchmark

The benchmark will:
1. Test all 2048 variants (original, branchless, no_rotate, all)
2. Test all Backgammon variants (original, fast_obs, branchless, all)
3. Output speedup comparisons
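
The internals of `benchmarks/benchmark_all_variants.py` are not reproduced in this notebook. As a rough guide to how a games-per-second figure is commonly measured, here is a minimal sketch; `measure_throughput` and `run_batch` are hypothetical names, and the warm-up call is an assumption about good practice rather than a description of the script.

```python
import time


def measure_throughput(run_batch, batch_size, num_batches):
    """Return games per second over `num_batches` timed calls of `run_batch`.

    `run_batch` is assumed to play `batch_size` games to completion and to
    block until the result is ready (with JAX, by calling
    `.block_until_ready()` on the returned array before returning).
    """
    run_batch()  # warm-up call so one-time compilation is not timed
    start = time.perf_counter()
    for _ in range(num_batches):
        run_batch()
    # Guard against a zero reading from a coarse clock.
    elapsed = max(time.perf_counter() - start, 1e-9)
    return batch_size * num_batches / elapsed


# Usage with a stand-in workload (not a real game).
rate = measure_throughput(lambda: sum(range(1000)), batch_size=4000, num_batches=5)
print(f"{rate:,.1f} games/sec (stand-in workload)")
```

Without the blocking call, JAX's asynchronous dispatch would let the timer stop before the device has finished, which inflates the reported throughput.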

In [None]:
# Run the full benchmark
# - batch_size=4000 works well for T4/V100/A100 GPUs
# - num_batches=5 for 2048 (fast game)
# - bg_batches=2 for backgammon (slower, uses fewer batches to keep <15min total)
# - skip_validation speeds up the benchmark (validation already done on CPU)

!python benchmarks/benchmark_all_variants.py \
    --batch-size 4000 \
    --num-batches 5 \
    --bg-batches 2 \
    --skip-validation \
    --output-json benchmark_results.json

In [None]:
# Display results as a formatted table
import json
import subprocess

with open('benchmark_results.json', 'r') as f:
    results = json.load(f)

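# Get GPU name for header; fall back to the device recorded in the results
# when nvidia-smi is missing, fails, or returns nothing (e.g. on TPU or CPU)
try:
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'],
        capture_output=True, text=True, check=True
    )
    gpu_name = result.stdout.strip() or results['device']
except (OSError, subprocess.CalledProcessError):
    gpu_name = results['device']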

print("=" * 70)
print(f"BENCHMARK RESULTS - {gpu_name}")
print("=" * 70)
print(f"Device: {results['device']}")
print(f"JAX: {results['jax_version']}")
print(f"Batch size: {results['config']['batch_size']}")
print()

# 2048 Results
print("2048 Performance:")
print(f"{'Variant':<15} {'Games/sec':>12} {'Steps/sec':>14} {'Speedup':>10}")
print("-" * 55)

baseline_2048 = results['benchmarks']['2048_original']['games_per_second']
for variant in ['original', 'branchless', 'no_rotate', 'all']:
    key = f'2048_{variant}'
    if key in results['benchmarks']:
        data = results['benchmarks'][key]
        speedup = data['games_per_second'] / baseline_2048
        print(f"{variant:<15} {data['games_per_second']:>12,.1f} {data['steps_per_second']:>14,.1f} {speedup:>9.2f}x")

print()

# Backgammon Results
print("Backgammon Performance:")
print(f"{'Variant':<15} {'Games/sec':>12} {'Steps/sec':>14} {'Speedup':>10}")
print("-" * 55)

baseline_bg = results['benchmarks']['backgammon_original']['games_per_second']
for variant in ['original', 'fast_obs', 'branchless', 'all']:
    key = f'backgammon_{variant}'
    if key in results['benchmarks']:
        data = results['benchmarks'][key]
        speedup = data['games_per_second'] / baseline_bg
        print(f"{variant:<15} {data['games_per_second']:>12,.1f} {data['steps_per_second']:>14,.1f} {speedup:>9.2f}x")

print()
print("=" * 70)

# Print copy-pasteable summary
print("\n📋 COPY-PASTE SUMMARY:")
print("-" * 70)
print(f"Device: {gpu_name}")
print(f"Batch size: {results['config']['batch_size']}")
print()
print("2048:")
for variant in ['original', 'branchless', 'no_rotate', 'all']:
    key = f'2048_{variant}'
    if key in results['benchmarks']:
        data = results['benchmarks'][key]
        speedup = data['games_per_second'] / baseline_2048
        print(f"  {variant}: {data['games_per_second']:,.0f} games/sec ({speedup:.2f}x)")

print()
print("Backgammon:")
for variant in ['original', 'fast_obs', 'branchless', 'all']:
    key = f'backgammon_{variant}'
    if key in results['benchmarks']:
        data = results['benchmarks'][key]
        speedup = data['games_per_second'] / baseline_bg
        print(f"  {variant}: {data['games_per_second']:,.0f} games/sec ({speedup:.2f}x)")
print("-" * 70)
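
The table and the copy-paste summary repeat the same speedup calculation for each game. One way to factor it out is sketched below; `speedup_rows` is a hypothetical helper, and the `sample` dictionary holds placeholder numbers chosen for illustration, not measured results. It assumes the same `results['benchmarks']` layout used in the cell above.

```python
def speedup_rows(benchmarks, game, variants, baseline="original"):
    """Return (variant, games_per_second, speedup) tuples for one game.

    Variants missing from `benchmarks` are skipped, matching the
    `if key in results['benchmarks']` checks in the display cell.
    """
    base = benchmarks[f"{game}_{baseline}"]["games_per_second"]
    rows = []
    for variant in variants:
        data = benchmarks.get(f"{game}_{variant}")
        if data is None:
            continue
        gps = data["games_per_second"]
        rows.append((variant, gps, gps / base))
    return rows


# Placeholder numbers for illustration only.
sample = {
    "2048_original": {"games_per_second": 100.0},
    "2048_all": {"games_per_second": 250.0},
}
rows = speedup_rows(sample, "2048", ["original", "branchless", "all"])
for variant, gps, speedup in rows:
    print(f"  {variant}: {gps:,.0f} games/sec ({speedup:.2f}x)")
```

With the real file loaded, the call would be `speedup_rows(results['benchmarks'], '2048', [...])` and likewise for `'backgammon'`.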

## 3. Custom Benchmark (Optional)

Run with different batch sizes to find optimal throughput:

In [None]:
# Uncomment to run with different batch sizes
# !python benchmarks/benchmark_all_variants.py --batch-size 1000 --num-batches 3 --skip-validation
# !python benchmarks/benchmark_all_variants.py --batch-size 2000 --num-batches 3 --skip-validation
# !python benchmarks/benchmark_all_variants.py --batch-size 8000 --num-batches 3 --skip-validation
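
The sweep above can also be scripted so that each batch size writes its own results file. The sketch below only builds the command lines; `sweep_commands` and the `benchmark_results_<size>.json` file names are hypothetical, while the flags are the ones already used in this notebook.

```python
def sweep_commands(batch_sizes, num_batches=3):
    """Build one benchmark command per batch size, each with its own JSON output."""
    return [
        [
            "python", "benchmarks/benchmark_all_variants.py",
            "--batch-size", str(size),
            "--num-batches", str(num_batches),
            "--skip-validation",
            "--output-json", f"benchmark_results_{size}.json",  # hypothetical file name
        ]
        for size in batch_sizes
    ]


commands = sweep_commands([1000, 2000, 8000])
for cmd in commands:
    print(" ".join(cmd))

# To execute the sweep from the pgx directory:
#   import subprocess
#   for cmd in commands:
#       subprocess.run(cmd, check=True)
```

Keeping one output file per batch size makes it possible to compare games per second across runs afterwards instead of overwriting `benchmark_results.json`.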

## 4. Individual Game Benchmarks (Optional)

Run the original individual benchmarks for comparison:

In [None]:
# 2048 original benchmark
# !python benchmarks/benchmark_2048.py --batch-sizes 1000,2000,4000 --num-batches 3

In [None]:
# Backgammon original benchmark
# !python benchmarks/benchmark_backgammon.py --batch-sizes 1000,2000,4000 --num-batches 3 --short-game

## 5. TPU Benchmark (if available)

To run on TPU:
1. Change runtime to TPU: `Runtime > Change runtime type > TPU`
2. Run the cells below

In [None]:
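# Uncomment for TPU setup on older Colab TPU runtimes only;
# recent JAX releases on TPU VMs detect the TPU without this call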
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()
# print(f"TPU devices: {jax.devices()}")

In [None]:
# Uncomment for TPU benchmark (use smaller batch size)
# !python benchmarks/benchmark_all_variants.py --batch-size 2000 --num-batches 5 --skip-validation