[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sile/pgx/blob/master/colab/benchmark_backgammon_pair.ipynb)


# Backgammon vs Backgammon2P Benchmark (Colab)

Run quick throughput benchmarks for the `backgammon` and `backgammon2p` environments and plot the results side by side.

- To benchmark your own fork, set `PGX_REPO_URL` in the environment (e.g. `%env PGX_REPO_URL=https://github.com/<you>/pgx.git`) before running the setup cell. It defaults to `https://github.com/sile/pgx.git`.
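- Attach a GPU runtime in Colab for best results. For a CPU-only run, set `JAX_PLATFORMS=cpu` (e.g. `%env JAX_PLATFORMS=cpu`) before running the benchmark cell. The benchmark runs as a subprocess, so it inherits the notebook's environment variables.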


In [None]:
# Setup: clone repo (skip if already cloned) and install dependencies
import os

REPO_URL = os.environ.get("PGX_REPO_URL", "https://github.com/sile/pgx.git")
if not os.path.exists("pgx"):
    !git clone $REPO_URL pgx
%cd pgx
!pip install -q -r requirements/requirements.txt
import jax; jax.devices()

In [None]:
# Quick benchmark for both envs (backgammon and backgammon2p).
# --profile prints warmup and per-batch timings.
# --output-json writes results to benchmarks/benchmark_results.json, the file
# the plotting cell reads (it expects a top-level "benchmark_runs" list there).
# This runs as a subprocess of the kernel, so it inherits the notebook's
# environment variables (e.g. JAX_PLATFORMS); relative paths resolve against
# the repo root that the setup cell `%cd`s into.
# NOTE(review): --quick presumably reduces batch sizes / iterations — confirm
# against benchmarks/benchmark_backgammon.py.
!python benchmarks/benchmark_backgammon.py --envs backgammon,backgammon2p --quick --profile --output-json benchmarks/benchmark_results.json

In [None]:
# Plot games/sec and batch times from benchmark_results.json
import json
import pathlib
import matplotlib.pyplot as plt

path = pathlib.Path('benchmarks/benchmark_results.json')
if not path.exists():
    raise SystemExit(f'{path} not found; run the benchmark cell first.')
# Explicit encoding: read_text() otherwise uses the platform locale encoding.
data = json.loads(path.read_text(encoding='utf-8'))
runs = data.get('benchmark_runs', [])
if not runs:
    raise SystemExit(f'{path} contains no benchmark_runs; run the benchmark cell first.')

# Take the latest run for each env. This assumes runs are appended in
# chronological order, so a later entry for the same env overwrites an
# earlier one in the dict.
latest_per_env = {run.get('env', 'unknown'): run for run in runs}

fig, (ax_gps, ax_batch) = plt.subplots(1, 2, figsize=(12, 4))

for env, run in sorted(latest_per_env.items()):
    # Sort by batch size so each line is drawn left-to-right even if the
    # JSON lists results in a different order.
    res = sorted(run.get('results', []), key=lambda r: r['batch_size'])
    batch_sizes = [r['batch_size'] for r in res]
    games_sec = [r['games_per_second'] for r in res]
    # A missing timing becomes NaN (drawn as a gap) rather than 0, which
    # would look like a real, impossibly fast measurement.
    batch_avg = [r.get('batch_time_avg', float('nan')) for r in res]
    ax_gps.plot(batch_sizes, games_sec, marker='o', label=env)
    ax_batch.plot(batch_sizes, batch_avg, marker='o', label=env)

# Both panels share the same x axis label, legend and grid styling.
for ax, title, ylabel in (
    (ax_gps, 'Games/sec', 'Games/sec'),
    (ax_batch, 'Avg batch time (s)', 'Seconds'),
):
    ax.set_title(title)
    ax.set_xlabel('Batch size')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.4)

plt.tight_layout()
plt.show()