<a href="https://colab.research.google.com/github/sile16/pgx/blob/master/colab/benchmark_variants.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PGX Performance Benchmark: 2048 & Backgammon Variants

This notebook benchmarks GPU-optimized variants of 2048 and Backgammon.

**Optimizations tested:**
- **Branchless operations**: Replace `jax.lax.cond` with `jnp.where`
- **No rotations**: Eliminate `jnp.rot90` memory shuffles
- **Fast observation**: Minimal observation (34 vs 86 elements)
- **Combined**: All optimizations together
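To illustrate the branchless idea, the sketch below contrasts the two styles on a hypothetical 2048 tile-merge rule. It assumes tiles are stored as exponents, so an equal, non-empty pair merges into one tile of the next exponent. It is not the code from the pgx variants. When `jax.lax.cond` is vmapped with a batched predicate, JAX turns it into a select that evaluates both branches anyway. The `jnp.where` form states that computation directly.

```python
import jax
import jax.numpy as jnp


def merge_pair_cond(a, b):
    # Branching style: lax.cond picks one of two functions from a scalar predicate.
    # Hypothetical rule: equal non-empty exponents merge into (a + 1, 0).
    return jax.lax.cond(
        (a == b) & (a != 0),
        lambda ab: (ab[0] + 1, jnp.zeros_like(ab[1])),
        lambda ab: ab,
        (a, b),
    )


def merge_pair_where(a, b):
    # Branchless style: compute the predicate once, then select element-wise.
    can_merge = (a == b) & (a != 0)
    return jnp.where(can_merge, a + 1, a), jnp.where(can_merge, 0, b)


# Batch of four tile pairs (exponents); 0 means an empty cell.
a = jnp.array([1, 2, 0, 3], dtype=jnp.int32)
b = jnp.array([1, 3, 0, 3], dtype=jnp.int32)

out_cond = jax.vmap(merge_pair_cond)(a, b)
out_where = jax.vmap(merge_pair_where)(a, b)
print(out_cond)
print(out_where)
```

Both versions return the same values. The difference the variants target is how the computation is expressed and compiled, not what it computes.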

**Instructions:**
1. Select GPU runtime: `Runtime > Change runtime type > GPU`
2. Run all cells: `Runtime > Run all`

## 1. Setup

In [None]:
# Clone the repository with optimized variants
!git clone https://github.com/sile16/pgx.git
%cd pgx

In [None]:
# Install pgx from local source
!pip install -e . -q

In [None]:
# Check JAX device
import jax
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")
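CPU numbers are not comparable with the GPU results this notebook targets. A small guard like the following (an addition, not part of the original notebook) can flag a missing accelerator before the long benchmark starts.

```python
import jax

backend = jax.default_backend()
if backend == "cpu":
    # No accelerator visible to JAX: the benchmark would measure CPU throughput.
    print("Warning: running on CPU. Select a GPU runtime "
          "(Runtime > Change runtime type > GPU) before benchmarking.")
else:
    print(f"Accelerator detected: {backend} ({jax.device_count()} device(s))")
```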

## 2. Run Benchmark

The benchmark will:
1. Test all 2048 variants (original, branchless, no_rotate, all)
2. Test all Backgammon variants (original, fast_obs, branchless, all)
3. Output speedup comparisons
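The script's internal timing code is not shown in this notebook. For reference, the sketch below shows one common way to measure the throughput of a batched JAX step function. Compilation happens in a warm-up call outside the timed region. Because JAX dispatches work asynchronously, `block_until_ready()` is called before the clock stops. `toy_step` and `steps_per_second` are hypothetical names, and the toy step only stands in for an environment transition.

```python
import time

import jax
import jax.numpy as jnp


def toy_step(state, key):
    # Hypothetical stand-in for an environment step (NOT pgx's step):
    # increments each cell with probability 0.5 and returns a fresh key.
    key, subkey = jax.random.split(key)
    noise = jax.random.uniform(subkey, state.shape)
    return jnp.where(noise > 0.5, state + 1, state), key


def steps_per_second(batch_size=4000, num_batches=5, steps_per_batch=100, seed=0):
    # jit(vmap(...)) runs the whole batch as one compiled call per step.
    step = jax.jit(jax.vmap(toy_step))
    state = jnp.zeros((batch_size, 16), dtype=jnp.int32)
    keys = jax.random.split(jax.random.PRNGKey(seed), batch_size)

    # Warm-up: the first call compiles, so keep it out of the timed region.
    state, keys = step(state, keys)
    state.block_until_ready()

    elapsed = 0.0
    for _ in range(num_batches):
        start = time.perf_counter()
        for _ in range(steps_per_batch):
            state, keys = step(state, keys)
        # Wait for the device to finish before stopping the clock.
        state.block_until_ready()
        elapsed += time.perf_counter() - start

    return batch_size * steps_per_batch * num_batches / elapsed


rate = steps_per_second(batch_size=256, num_batches=2, steps_per_batch=10)
print(f"{rate:,.0f} steps/sec")
```

Timing several batches and dividing total steps by total time smooths out per-call jitter. This is also why a few repeated batches, as with `--num-batches` below, give steadier numbers than a single run.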

In [None]:
# Run the full benchmark
# - --batch-size 4000 works well for T4/V100/A100 GPUs
# - --num-batches 5 gives stable measurements
# - --skip-validation shortens the run by skipping validation (already done on CPU)

!python benchmarks/benchmark_all_variants.py \
    --batch-size 4000 \
    --num-batches 5 \
    --skip-validation \
    --output-json benchmark_results.json

In [None]:
# Display results as a formatted table
import json
import pandas as pd

with open('benchmark_results.json', 'r') as f:
    results = json.load(f)

print(f"Device: {results['device']}")
print(f"JAX: {results['jax_version']}")
print(f"Batch size: {results['config']['batch_size']}")
print()

# Create DataFrames
data_2048 = []
data_bg = []

for key, val in results['benchmarks'].items():
    if key.startswith('2048_'):
        variant = key.replace('2048_', '')
        data_2048.append({
            'Variant': variant,
            'Games/sec': val['games_per_second'],
            'Steps/sec': val['steps_per_second'],
        })
    elif key.startswith('backgammon_'):
        variant = key.replace('backgammon_', '')
        data_bg.append({
            'Variant': variant,
            'Games/sec': val['games_per_second'],
            'Steps/sec': val['steps_per_second'],
        })

df_2048 = pd.DataFrame(data_2048)
df_bg = pd.DataFrame(data_bg)

# Calculate speedups
baseline_2048 = df_2048[df_2048['Variant'] == 'original']['Games/sec'].values[0]
baseline_bg = df_bg[df_bg['Variant'] == 'original']['Games/sec'].values[0]

df_2048['Speedup'] = df_2048['Games/sec'] / baseline_2048
df_bg['Speedup'] = df_bg['Games/sec'] / baseline_bg

# Per-column formatters: throughput with thousands separators, speedup with an "x" suffix
formatters = {
    'Games/sec': lambda x: f"{x:,.1f}",
    'Steps/sec': lambda x: f"{x:,.1f}",
    'Speedup': lambda x: f"{x:.2f}x",
}

print("=" * 60)
print("2048 RESULTS")
print("=" * 60)
print(df_2048.to_string(index=False, formatters=formatters))

print()
print("=" * 60)
print("BACKGAMMON RESULTS")
print("=" * 60)
print(df_bg.to_string(index=False, formatters=formatters))

## 3. Custom Benchmark (Optional)

Run with different batch sizes to find optimal throughput:

In [None]:
# Uncomment to run with different batch sizes
# !python benchmarks/benchmark_all_variants.py --batch-size 1000 --num-batches 3 --skip-validation
# !python benchmarks/benchmark_all_variants.py --batch-size 2000 --num-batches 3 --skip-validation
# !python benchmarks/benchmark_all_variants.py --batch-size 8000 --num-batches 3 --skip-validation
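The commented commands above can also be driven from Python so that each run keeps its own JSON file. The sketch below assumes the script accepts exactly the flags used elsewhere in this notebook and writes the same JSON layout that the results cell reads. `build_command`, `run_sweep`, `RUN_SWEEP`, and the `benchmark_results_bs*.json` file names are hypothetical.

```python
import json
import subprocess
import sys
from pathlib import Path

SCRIPT = Path("benchmarks/benchmark_all_variants.py")
RUN_SWEEP = False  # hypothetical switch: set to True to launch the sweep


def build_command(batch_size, num_batches=3, output_json=None):
    # Uses only the flags that appear elsewhere in this notebook.
    cmd = [sys.executable, str(SCRIPT),
           "--batch-size", str(batch_size),
           "--num-batches", str(num_batches),
           "--skip-validation"]
    if output_json is not None:
        cmd += ["--output-json", str(output_json)]
    return cmd


def run_sweep(batch_sizes=(1000, 2000, 4000, 8000)):
    # Returns {batch_size: {benchmark_key: games_per_second}}.
    summary = {}
    for bs in batch_sizes:
        out = Path(f"benchmark_results_bs{bs}.json")
        subprocess.run(build_command(bs, output_json=out), check=True)
        with open(out) as f:
            results = json.load(f)
        # Assumes the same 'benchmarks' layout read by the results cell.
        summary[bs] = {k: v["games_per_second"]
                       for k, v in results["benchmarks"].items()}
    return summary


if RUN_SWEEP and SCRIPT.exists():
    for bs, row in run_sweep().items():
        print(bs, row)
```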

## 4. Individual Game Benchmarks (Optional)

Run the original individual benchmarks for comparison:

In [None]:
# 2048 original benchmark
# !python benchmarks/benchmark_2048.py --batch-sizes 1000,2000,4000 --num-batches 3

In [None]:
# Backgammon original benchmark
# !python benchmarks/benchmark_backgammon.py --batch-sizes 1000,2000,4000 --num-batches 3 --short-game

## 5. TPU Benchmark (if available)

To run on TPU:
1. Change runtime to TPU: `Runtime > Change runtime type > TPU`
2. Run the cells below

In [None]:
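# Uncomment for TPU setup (setup_tpu() is only needed on legacy Colab TPU runtimes;
# newer TPU runtimes expose the TPU to JAX directly, so jax.devices() alone suffices)
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()
# print(f"TPU devices: {jax.devices()}")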

In [None]:
# Uncomment for TPU benchmark (use smaller batch size)
# !python benchmarks/benchmark_all_variants.py --batch-size 2000 --num-batches 5 --skip-validation