# Scratchpad for profiling rgizero code.


In [None]:

import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

# Game and players
from rgi.rgizero.games.connect4 import Connect4Game
from rgi.rgizero.players import alphazero
from rgi.rgizero.players.alphazero import AlphazeroPlayer
from rgi.rgizero.players.alphazero import play_game


from rgi.rgizero.common import TOKENS

import notebook_utils
from notebook_utils import reload_local_modules

print("✅ Imports successful")

# Pick the compute device via the local helper; require_accelerator=True presumably
# makes it fail instead of falling back to CPU — confirm in notebook_utils.
device = notebook_utils.detect_device(require_accelerator=True)

# Allow asyncio to work with jupyter notebook: nest_asyncio lets asyncio.run(...) be
# called while Jupyter's own event loop is running (later cells rely on this).
import nest_asyncio
nest_asyncio.apply()

# Increase numpy print width
np.set_printoptions(linewidth=300)

# Output directory ../data/rgizero-e2e relative to the current working directory; created if missing.
DATA_DIR = Path.cwd().parent / "data" / "rgizero-e2e"
os.makedirs(DATA_DIR, exist_ok=True)

# Enables the %lprun line-profiling magic used in later cells.
%load_ext line_profiler

In [None]:
# Section toggles: each flag gates one group of benchmark cells below.
RANDOM_MODEL_BENCHMARK = True    # MCTS self-play with RandomEvaluator (no network)
SERIAL_NNET_BENCHMARK = True     # MCTS self-play, one transformer evaluation at a time
PARALLEL_EVAL_BENCHMARK = True   # serial vs batched vs queued vs async evaluation
PARALLEL_PLAY_BENCHMARK = True   # many concurrent games sharing one evaluator

# Model / search settings.
MODEL_SIZE = "tiny"  # "tiny" or "small" or "large" or "xl"
NUM_SIMULATIONS = 50  # MCTS simulations per player in the parallel-play benchmark


## Step 1: Load game from registry.

In [None]:
from rgi.rgizero.games import game_registry

# The registry returns the game wrapped with history tracking (hence `.base_game` below).
game = game_registry.create_game("connect4")
# Longest possible game: one move per cell of the 7x6 board.
max_game_length = 7 * 6
state_0 = game.initial_state()
# Context length: longest action history plus 2 extra positions (presumably the
# start-of-game token and one spare — confirm against ActionHistoryTransformerEvaluator).
block_size = max_game_length + 2
print(f"Game: {game.base_game.__class__.__name__}, Players: {game.num_players(state_0)}, Actions: {list(game.all_actions())}")


## Step 2: Confirm we can self-play a game with a Random Evaluator.


In [None]:
reload_local_modules(verbose=False)

from rgi.rgizero.players.alphazero import AlphazeroPlayer, play_game, NetworkEvaluatorResult, NetworkEvaluator
from typing import override, Any

class RandomEvaluator(NetworkEvaluator):
    """Network stand-in that scores moves and players with seeded uniform [0, 1) draws.

    Lets MCTS be profiled without any neural-network cost. For a given seed and call
    sequence the outputs are reproducible.
    """

    def __init__(self, seed: int = 42):
        self.rng = np.random.default_rng(seed)

    @override
    def evaluate(self, game, state, legal_actions: list[Any]):
        # Policy is drawn before values so the RNG stream order matches every call.
        n_actions = len(legal_actions)
        policy = self.rng.random(n_actions)
        n_players = game.num_players(state)
        return NetworkEvaluatorResult(policy, self.rng.random(n_players))

def play_deterministic_game(seed, evaluator=None, player=None, verbose=False):
    """Self-play one game where a single AlphazeroPlayer takes both seats.

    By default the evaluator is a RandomEvaluator and the player RNG is seeded from `seed`,
    so a fixed seed is intended to replay the same game. Returns the result mapping from
    play_game (it contains "action_history").
    """
    if not evaluator:
        evaluator = RandomEvaluator(seed=seed)
    if not player:
        player = AlphazeroPlayer(game, evaluator, rng=np.random.default_rng(seed))
    result = play_game(game, [player, player])
    if verbose:
        history = result["action_history"]
        print(f'game length: {len(history)}, simulations={player.simulations}')
        print(history)
    return result


game_result = play_deterministic_game(42, verbose=True)

In [None]:
# 0.37s to play single game with RandomEvaluator (no nnet). simulations=800

# cProfile one random-evaluator game: -s cumulative sorts by cumulative time,
# -l 30 keeps the top 30 rows, -r returns the stats object.
if RANDOM_MODEL_BENCHMARK:
    %prun -r -l 30 -s cumulative game_result = play_deterministic_game(42, verbose=True)

In [None]:
from rgi.rgizero.games import connect4
from rgi.rgizero.games import history_wrapper
# Line-profile the MCTS hot spots (action selection and backup) plus both next_state
# layers (raw Connect4Game and the HistoryTrackingGame wrapper) during one game.
if RANDOM_MODEL_BENCHMARK:
    %lprun \
        -f alphazero.MCTSNode.select_action_index \
        -f alphazero.MCTSNode.backup \
        -f connect4.Connect4Game.next_state \
        -f history_wrapper.HistoryTrackingGame.next_state \
        game_result = play_deterministic_game(42, verbose=True)



In [None]:
# Wall-clock timing of one random-evaluator game. NOTE: verbose=True makes every %timeit
# loop print the game too, so those prints are included in the measured time.
if RANDOM_MODEL_BENCHMARK:
    %timeit game_result = play_deterministic_game(42, verbose=True)

# Optimization history (same seed, so the same game each time):
# Original %%timeit - 26.3 seconds.
# 3.28 s ± 60.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Vectorized calculation of select_action_index
# 2.4 s ± 38.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# numba version
# 1.79 s ± 28.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Use numba for MCTS.backup
# 1.35 s ± 16.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Use numba for Connect4Game.next_state
# 898 ms ± 6.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Latest results:
# game length: 19, simulations=800
# [6, 3, 7, 5, 1, 7, 6, 6, 2, 3, 5, 7, 5, 7, 3, 6, 5, 2, 5]
# 236 ms ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Play deterministic game with neural network evaluator


In [None]:
reload_local_modules(verbose=False)

from rgi.rgizero.data.trajectory_dataset import Vocab
from rgi.rgizero.common import TOKENS

# Token vocabulary: a start-of-game token followed by every action of the Connect 4 game above.
vocab = Vocab(itos=[TOKENS.START_OF_GAME] + list(game.all_actions()))

# The model context is kept tied to the evaluator's block_size (max_game_length + 2) so the
# two cannot drift apart; presumably the model must cover every position the evaluator feeds.
n_max_context = block_size

from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformer, ActionHistoryTransformerEvaluator
from rgi.rgizero.models.transformer import TransformerConfig

# Make model initialization deterministic
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)  # Also seeds the np.random.choice moves used to build benchmark states later.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    # Prefer reproducible cuDNN kernels over autotuned (possibly nondeterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Model sizes selectable via MODEL_SIZE; all share the same context length.
model_config_dict = {
    "tiny": TransformerConfig(n_max_context=n_max_context, n_layer=2, n_head=2, n_embd=8),
    "small": TransformerConfig(n_max_context=n_max_context, n_layer=4, n_head=4, n_embd=32),
    "large": TransformerConfig(n_max_context=n_max_context, n_layer=8, n_head=8, n_embd=128),
    "xl": TransformerConfig(n_max_context=n_max_context, n_layer=16, n_head=16, n_embd=256),
}

model_config = model_config_dict[MODEL_SIZE]
model = ActionHistoryTransformer(config=model_config, action_vocab_size=vocab.vocab_size, num_players=game.num_players(state_0))
model.to(device)

# TODO: Skip compiling for now ... doesn't help much with performance for this model and makes profiling harder.
# model.compile()

# Single-request evaluator over the model; later cells wrap it for batched / queued / async use.
model_evaluator = ActionHistoryTransformerEvaluator(model, device=device, block_size=block_size, vocab=vocab)

In [None]:
def play_deterministic_game_nnet(seed, verbose=False, simulations=800):
    """Self-play one game against the shared transformer `model_evaluator`.

    A single AlphazeroPlayer (RNG seeded from `seed`, `simulations` passed through) takes
    both seats. Returns the result mapping from play_game (it contains "action_history").
    """
    rng = np.random.default_rng(seed)
    player = AlphazeroPlayer(game, model_evaluator, rng=rng, simulations=simulations)
    outcome = play_game(game, [player, player])
    if verbose:
        moves = outcome["action_history"]
        print(f'game length: {len(moves)}, simulations={player.simulations}')
        print(moves)
    return outcome

# Warm-up game so one-time costs stay out of the profiled / timed runs below.
if SERIAL_NNET_BENCHMARK:
    _ = play_deterministic_game_nnet(42, simulations=200)
    print("✅ Model warmed up")


In [None]:
# cProfile one 200-simulation game with the (already warmed-up) transformer evaluator.
if SERIAL_NNET_BENCHMARK:
    %prun -r -l 30 -s cumulative game_result = play_deterministic_game_nnet(42, simulations=200, verbose=True)

## Profile of tiny model.
# Reading it: evaluate (2974 calls) appears to delegate to evaluate_batch, and the two
# `.cpu()` calls per evaluation (5948 = 2 x 2974) take ~1.42s of ~3.3s total — presumably
# time spent waiting for the device, not copying.
#    ncalls  tottime  percall  cumtime  percall filename:lineno(function)
# 5948/2974    0.010    0.000    3.365    0.001 _contextlib.py:117(decorate_context)
#      2974    0.001    0.000    3.332    0.001 action_history_transformer.py:145(evaluate)
#      2974    0.018    0.000    3.312    0.001 grad_mode.py:279(__exit__)
#      2974    0.089    0.000    3.285    0.001 action_history_transformer.py:157(evaluate_batch)
# 98142/2974    0.037    0.000    1.465    0.000 module.py:1769(_wrapped_call_impl)
# 98142/2974    0.068    0.000    1.463    0.000 module.py:1777(_call_impl)
#      2974    0.033    0.000    1.456    0.000 action_history_transformer.py:52(forward)
#      5948    1.424    0.000    1.424    0.000 {method 'cpu' of 'torch._C.TensorBase' objects}
#      2974    0.022    0.000    1.039    0.000 transformer.py:152(forward)
#      5948    0.032    0.000    0.827    0.000 transformer.py:125(forward)
#      5948    0.029    0.000    0.475    0.000 transformer.py:77(forward)

In [None]:
from rgi.rgizero.games import connect4
from rgi.rgizero.games import history_wrapper
from rgi.rgizero.models import action_history_transformer
from rgi.rgizero.models import transformer

# Line-profile the model forward pass (top level, transformer, blocks, attention) and the
# evaluator's evaluate. `.__wrapped__` presumably reaches the undecorated function under a
# decorator (the profile above shows a _contextlib decorate_context wrapper) — confirm.
if SERIAL_NNET_BENCHMARK:
    %lprun \
        -f action_history_transformer.ActionHistoryTransformer.forward \
        -f action_history_transformer.ActionHistoryTransformerEvaluator.evaluate.__wrapped__ \
        -f transformer.Transformer.forward \
        -f transformer.Block.forward \
        -f transformer.CausalSelfAttention.forward \
        game_result = play_deterministic_game_nnet(42, verbose=True, simulations=200)

In [None]:
# Initial timing with simulations=200
# 4.86 s ± 38.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)    # tiny model
# 16.2 s ± 213 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)     # big model

# Latest timing
# 5.6 s ± 85.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each).    # small model.
# 3.26 s ± 68.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)    # tiny model.
# Wall-clock timing of a full 200-simulation game; verbose=True prints on every loop.
if SERIAL_NNET_BENCHMARK:
    %timeit game_result =  play_deterministic_game_nnet(42, verbose=True, simulations=200)

## Evaluate multiple states in parallel


In [None]:
# Walk 5 random moves from the opening, recording each visited state and its legal
# actions; these fixed inputs drive the evaluation benchmarks below.
state = game.initial_state()
state_list, legal_actions_list = [], []

for _ in range(5):
    legal_actions = game.legal_actions(state)
    state_list.append(state)
    legal_actions_list.append(legal_actions)
    action = np.random.choice(legal_actions)
    state = game.next_state(state, action)

# Print the serial evaluator's output for each recorded position.
for recorded_state, recorded_actions in zip(state_list, legal_actions_list):
    print(model_evaluator.evaluate(game, recorded_state, recorded_actions))


In [None]:
# Reload local modules (presumably to pick up source edits), then rebuild the serial
# evaluator from the re-imported class so later cells use the reloaded code.
reload_local_modules(verbose=False)
from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformerEvaluator
model_evaluator = ActionHistoryTransformerEvaluator(model, device=device, block_size=block_size, vocab=vocab)

def cycle(original_list, n=1):
    """Return the first `n` items of `original_list` repeated end to end.

    Example: cycle([a, b], 5) -> [a, b, a, b, a]. Raises ZeroDivisionError when
    `original_list` is empty and n > 0.
    """
    period = len(original_list)
    return [original_list[k % period] for k in range(n)]

def eval_serial(state_list, legal_actions_list, n=1):
    """Evaluate `n` positions one call at a time with `model_evaluator`, cycling the inputs."""
    n_states, n_legal = len(state_list), len(legal_actions_list)
    for i in range(n):
        model_evaluator.evaluate(game, state_list[i % n_states], legal_actions_list[i % n_legal])

def eval_batch(state_list, legal_actions_list, n=1):
    """Evaluate `n` positions (inputs cycled) in one `model_evaluator.evaluate_batch` call."""
    model_evaluator.evaluate_batch(
        game=game,
        states_list=cycle(state_list, n),
        legal_actions_list=cycle(legal_actions_list, n),
    )

# Compare per-call vs single-batch evaluation of the same 5 positions.
if PARALLEL_EVAL_BENCHMARK:
    %timeit eval_serial(state_list, legal_actions_list, n=5)
    %timeit eval_batch(state_list, legal_actions_list, n=5)

# serial
# 5.53 ms ± 147 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# batch
# 1.13 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [None]:
# Historical profile notes; eval_batch_v2 is not defined in the visible notebook.
# %prun -r -l 30 -s cumulative eval_batch(state_list, legal_actions_list, n=1000)    # 0.256 action_history_transformer.py:146(evaluate_batch)
# %prun -r -l 30 -s cumulative eval_batch_v2(state_list, legal_actions_list, n=1000) # 0.622 action_history_transformer.py:186(evaluate_batch_v2)

# Line-profile one 5000-position batched evaluation.
if PARALLEL_EVAL_BENCHMARK:
    %lprun -f action_history_transformer.ActionHistoryTransformerEvaluator.evaluate_batch.__wrapped__ eval_batch(state_list, legal_actions_list, n=5000)

In [None]:

def go(n):
    """Print %timeit results for eval_serial vs eval_batch on `n` cycled positions.

    NOTE: a later cell rebinds `go` to an async profiling helper; re-run this cell first
    if you need this version again.
    """
    print(f"\nn={n}")
    %timeit eval_serial(state_list, legal_actions_list, n=n)
    %timeit eval_batch(state_list, legal_actions_list, n=n)

# Sweep n. In the results below serial cost grows roughly linearly (~1.05 ms per position),
# while the single batched call stays nearly flat (1.05 ms at n=1 to 3.72 ms at n=1000).
if PARALLEL_EVAL_BENCHMARK:
    go(1)
    go(2)
    go(5)
    go(10)
    go(50)
    go(100)
    go(1000)

# n=1
# 1.06 ms ± 29.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 1.05 ms ± 32.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# n=2
# 2.08 ms ± 105 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 1.11 ms ± 41.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# n=5
# 5.46 ms ± 271 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 1.14 ms ± 26 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# n=10
# 10.8 ms ± 305 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 1.16 ms ± 40.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# n=50
# 53.1 ms ± 3.06 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 1.31 ms ± 47.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# n=100
# 110 ms ± 5.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 1.56 ms ± 84.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# n=1000
# 1.07 s ± 45.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 3.72 ms ± 45.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# Test QueuedNetworkEvaluator


In [None]:
reload_local_modules(verbose=False)
from rgi.rgizero.models.action_history_transformer import QueuedNetworkEvaluator

# Wrap the serial evaluator in a queued evaluator (max_batch_size=1024, max_latency_ms=0)
# and smoke-test a single request; Jupyter displays the returned result.
queued_evaluator = QueuedNetworkEvaluator(model_evaluator, max_batch_size=1024, max_latency_ms=0)

queued_evaluator.evaluate(game, state_list[0], legal_actions_list[0])

In [None]:
reload_local_modules(verbose=False)

import concurrent.futures
import asyncio
from rgi.rgizero.models.action_history_transformer import AsyncNetworkEvaluator, QueuedNetworkEvaluator


def eval_queued(state_list, legal_actions_list, n=1, num_threads=50):
    """Send `n` single evaluations to `queued_evaluator` from a fresh thread pool and wait.

    Concurrent requests presumably let the queued evaluator batch them. Results are
    discarded; this exists for timing. Errors from any request propagate.
    """
    requests = zip(cycle(state_list, n), cycle(legal_actions_list, n))
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool:
        pending = [pool.submit(queued_evaluator.evaluate, game, st, acts) for st, acts in requests]
        for done in concurrent.futures.as_completed(pending):
            done.result()

def eval_queued_v2(state_list, legal_actions_list, n=1, executor=None):
    """Like eval_queued, but submits to a caller-owned `executor` instead of a new pool.

    `executor` is required in practice: passing None raises AttributeError once there
    is at least one request to submit.
    """
    pairs = zip(cycle(state_list, n), cycle(legal_actions_list, n))
    pending = [executor.submit(queued_evaluator.evaluate, game, st, acts) for st, acts in pairs]
    for done in concurrent.futures.as_completed(pending):
        done.result()



async def eval_async(state_list, legal_actions_list, n=1, async_evaluator=None):
    """Issue `n` concurrent `evaluate_async` requests (inputs cycled) and await them all."""
    pairs = zip(cycle(state_list, n), cycle(legal_actions_list, n))
    coros = [async_evaluator.evaluate_async(game, st, acts) for st, acts in pairs]
    await asyncio.gather(*coros)


if PARALLEL_EVAL_BENCHMARK:
    # %timeit eval_serial(state_list, legal_actions_list, n=1000) # 2.24 s ± 30.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    # print(f"\nEval eval_batch performance (n=1000, max_workers=100):")
    # %timeit eval_batch(state_list, legal_actions_list, n=1000) # 49.3 ms ± 379 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

    # print(f"\nEval queued v1 performance (n=1000, max_workers=100):")
    # %timeit eval_queued(state_list, legal_actions_list, n=1000) # 160 ms ± 3.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

    # with concurrent.futures.ThreadPoolExecutor(max_workers=100) as executor:
    #     print(f"\nEval queued v2 (threaded) performance (n=1000, max_workers=100):")
    #     %timeit eval_queued_v2(state_list, legal_actions_list, n=1000, executor=executor)  # 180 ms ± 2.98 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

    async_evaluator = AsyncNetworkEvaluator(model_evaluator, max_batch_size=1024)
    await async_evaluator.start()

    print(f"\nEval async performance (n=1000, max_batch_size=1024):")
    %timeit asyncio.run(eval_async(state_list, legal_actions_list, n=1000, async_evaluator=async_evaluator)) # 32.4 ms ± 412 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    await async_evaluator.stop()


In [None]:
if PARALLEL_EVAL_BENCHMARK:
    %prun -r -l 30 -s cumulative game_result = eval_batch(state_list, legal_actions_list, n=1000)

In [None]:
if PARALLEL_EVAL_BENCHMARK:
    async_evaluator = AsyncNetworkEvaluator(model_evaluator, max_batch_size=1024)
    await async_evaluator.start()
    %prun -r -l 30 -s cumulative asyncio.run(eval_async(state_list, legal_actions_list, n=1000, async_evaluator=async_evaluator))
    await async_evaluator.stop()


In [None]:
async def go():
    """Profiling driver: 10 rounds of 1000 concurrent async evaluations on a fresh evaluator.

    NOTE: rebinds the name `go` (the earlier timing helper) in the notebook namespace.
    """
    async_evaluator = AsyncNetworkEvaluator(model_evaluator, max_batch_size=1024)
    await async_evaluator.start()
    try:
        for _ in range(10):
            await eval_async(state_list, legal_actions_list, n=1000, async_evaluator=async_evaluator)
    finally:
        # Stop the evaluator even if an evaluation fails, so it is not left running.
        await async_evaluator.stop()

# Profile the async driver end to end (evaluator start, 10 x 1000 evaluations, stop).
if PARALLEL_EVAL_BENCHMARK:
    %prun -r -l 30 -s cumulative asyncio.run(go())


In [None]:
# Line-profile the threaded queued path: the batched evaluate (via __wrapped__) and the
# queue's _run_once step, driven by eval_queued with 1000 requests.
if PARALLEL_EVAL_BENCHMARK:
    %lprun \
        -f action_history_transformer.ActionHistoryTransformerEvaluator.evaluate_batch.__wrapped__ \
        -f action_history_transformer.QueuedNetworkEvaluator._run_once \
        eval_queued(state_list, legal_actions_list, n=1000)

## Play Multiple Games in Parallel with Neural Network Evaluator


In [None]:
reload_local_modules(verbose=False)

from rgi.rgizero.models.action_history_transformer import QueuedNetworkEvaluator
from rgi.rgizero.players.alphazero import play_game_async

import concurrent.futures

def play_single_deterministic_game_nnet(seed, player=None, player_factory=None, verbose=False, simulations=800):
    """Self-play one game synchronously, with one player taking both seats.

    Player precedence: `player`; else a new one from `player_factory()`; else a default
    AlphazeroPlayer on `model_evaluator` seeded with `seed` and using `simulations`.
    Returns the result mapping from play_game (it contains "action_history").
    """
    if not player and player_factory:
        player = player_factory()
    if not player:
        player = AlphazeroPlayer(game, model_evaluator, rng=np.random.default_rng(seed), simulations=simulations)
    result = play_game(game, [player, player])
    if verbose:
        moves = result["action_history"]
        print(f'game length: {len(moves)}, simulations={player.simulations}')
        print(moves)
    return result


def play_multiple_deterministic_games_nnet(num_games: int, **kwargs):
    """Play `num_games` games concurrently, one thread per game, and return their results.

    Each game runs play_single_deterministic_game_nnet(**kwargs), so kwargs must include
    `seed`. Typically pass `player_factory`; passing `player` would share one player object
    across all threads. Results are collected in completion order. Returns [] when
    num_games <= 0.
    """
    if num_games <= 0:
        # ThreadPoolExecutor rejects max_workers <= 0, and there is nothing to play anyway.
        return []

    game_results = []
    # Threads, as GPU operations can release the GIL and the evaluator can overlap requests.
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_games) as executor:
        futures = [executor.submit(play_single_deterministic_game_nnet, **kwargs) for _ in range(num_games)]
        for future in concurrent.futures.as_completed(futures):
            game_results.append(future.result())
    return game_results

async def play_single_deterministic_game_nnet_async(seed, player=None, player_factory=None, verbose=False, simulations=800):
    """Async counterpart of play_single_deterministic_game_nnet, built on play_game_async.

    Uses `player` if given, otherwise `player_factory()`. Unlike the sync version there is
    no default player; `seed` and `simulations` are accepted for signature parity only.
    Returns the result mapping from play_game_async (it contains "action_history").

    Raises:
        ValueError: if neither `player` nor `player_factory` is provided.
    """
    if not player:
        if player_factory is None:
            raise ValueError("play_single_deterministic_game_nnet_async requires `player` or `player_factory`")
        player = player_factory()
    game_result = await play_game_async(game, [player, player])
    if verbose:
        print(f'game length: {len(game_result["action_history"])}, simulations={player.simulations}')
        print(game_result['action_history'])
    return game_result


async def play_multiple_deterministic_games_nnet_async(num_games: int, player_factory_factory=None, **kwargs):
    """Play `num_games` games concurrently on the current event loop; return their results.

    If `player_factory_factory` is given, it is awaited once and its result becomes the
    `player_factory` (overriding any in kwargs) for every game in this call. Results come
    back in submission order via asyncio.gather.
    """
    if player_factory_factory:
        kwargs['player_factory'] = await player_factory_factory()
    games = (play_single_deterministic_game_nnet_async(**kwargs) for _ in range(num_games))
    return await asyncio.gather(*games)


def play_multiple_async_threads(num_games: int, num_threads=8, **kwargs):
    """Spread `num_games` over up to `num_threads` threads, each with its own event loop.

    Each thread runs play_multiple_deterministic_games_nnet_async on its share of games
    (shares differ by at most one), passing `kwargs` through. Returns a flat list with one
    result per game, in thread-completion order.

    Raises:
        ValueError: if num_threads <= 0.
    """
    if num_threads <= 0:
        raise ValueError(f"num_threads must be positive, got {num_threads}")

    # Even split: sum((num_games + i) // num_threads for i in range(num_threads)) == num_games.
    games_per_thread = [(num_games + i) // num_threads for i in range(num_threads)]
    # Skip shares of zero games rather than spinning up an idle thread and event loop.
    games_per_thread = [n for n in games_per_thread if n > 0]
    if not games_per_thread:
        return []

    def _run_in_new_loop(n):
        # Create a new event loop for this thread (asyncio.run() doesn't work in threads)
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(play_multiple_deterministic_games_nnet_async(n, **kwargs))
        finally:
            loop.close()

    game_results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(games_per_thread)) as executor:
        futures = [executor.submit(_run_in_new_loop, n) for n in games_per_thread]
        for future in concurrent.futures.as_completed(futures):
            # Each thread yields a list of game results; flatten to one entry per game.
            game_results.extend(future.result())
    return game_results



import time

if PARALLEL_PLAY_BENCHMARK:
    simulations = NUM_SIMULATIONS
    
    t0 = time.time()
    print(f"model_size={MODEL_SIZE}, simulations={simulations}")
    serial_evaluator = ActionHistoryTransformerEvaluator(model, device=device, block_size=block_size, vocab=vocab)
    queued_evaluator = QueuedNetworkEvaluator(serial_evaluator, max_batch_size=1024, max_latency_ms=0, verbose=False)
    async_evaluator = AsyncNetworkEvaluator(base_evaluator=serial_evaluator, max_batch_size=1024, verbose=False)
    async_evaluator_factory = lambda: AsyncNetworkEvaluator(base_evaluator=serial_evaluator, max_batch_size=1024, start=True, verbose=False)
    await async_evaluator.start()
    serial_player_factory = lambda: AlphazeroPlayer(game, serial_evaluator, rng=np.random.default_rng(seed), simulations=simulations)
    queued_player_factory = lambda: AlphazeroPlayer(game, queued_evaluator, rng=np.random.default_rng(seed), simulations=simulations)
    async_player_factory = lambda: AlphazeroPlayer(game, async_evaluator, rng=np.random.default_rng(seed), simulations=simulations)
    async def async_player_factory_factory():
        async_evaluator = async_evaluator_factory()
        await async_evaluator.start()
        return lambda: AlphazeroPlayer(game, async_evaluator, rng=np.random.default_rng(seed), simulations=simulations)

    # Run a few games to warm up and test
    # results = play_multiple_deterministic_games_nnet(num_games=1, seed=42, player_factory=serial_player_factory, verbose=True)  # 2.8s, 1 game, simulations=20 (large model)
    # results = play_multiple_deterministic_games_nnet(num_games=10, seed=42, player_factory=serial_player_factory, verbose=True)  # 29.4s, 10 games, simulations=20 (large model)
    # results = play_multiple_deterministic_games_nnet(num_games=10, seed=42, player_factory=queued_player_factory, verbose=True)  # 5.4s 10 games, simulations=20 (large model)
    # results = play_multiple_deterministic_games_nnet(num_games=50, seed=42, player_factory=queued_player_factory, verbose=True)  # 11.7s for 100 games, 20 simulations (large model)
    # results = play_multiple_deterministic_games_nnet(num_games=100, seed=42, player_factory=queued_player_factory, verbose=True)  # 18.6s for 100 games, 20 simulations (large model)
    # results = play_multiple_deterministic_games_nnet(num_games=1000, seed=42, player_factory=queued_player_factory, verbose=True)  # 2m38s, 1000 games, 20 simulations (large model), typical batch-side ~500
    
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=10, seed=42, player_factory=async_player_factory, verbose=True)) # 4.5s, batch size=10
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=50, seed=42, player_factory=async_player_factory, verbose=True)) # 8.7s, batch size=50
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=100, seed=42, player_factory=async_player_factory, verbose=True)) # 14.6s, batch size=100
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=1000, seed=42, player_factory=async_player_factory, verbose=True)) # 1m37s, batch size=1000

    # %prun -r -l 60 -s cumulative results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=100, seed=42, player_factory=async_player_factory, verbose=False))
    # play_multiple_async_threads(num_games=1, seed=42, player_factory=async_player_factory, num_threads=1, verbose=False)
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=100, seed=42, player_factory=async_player_factory, verbose=True)) # large -> 8.5s, batch size=100, 60% GPU
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=1, seed=42, player_factory=async_player_factory, verbose=True)) # xl -> 3.1s, batch size=1
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=10, seed=42, player_factory=async_player_factory, verbose=True)) # xl -> 3.2s, batch size=10
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=100, seed=42, player_factory=async_player_factory, verbose=True)) # xl -> 16.5s, batch size=100, 86% GPU
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=200, seed=42, player_factory=async_player_factory, verbose=True)) # xl -> 30.5s, batch size=200

    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=1, seed=42, player_factory=async_player_factory, verbose=True)) # xl -> 3.2s, batch size=10

    # async_player = async_player_factory()
    # game_result = asyncio.run(play_game_async(game, agents = [async_player, async_player], reuse_tree=False))
    # game_result = asyncio.run(play_game_async(game, agents = [async_player, async_player]))
    # print(f'game length: {len(game_result["action_history"])}, simulations={async_player.simulations}, action_history={game_result["action_history"]}')
    
    # mac m4 results. simulations=50
    benchmark_name = "async_player_factory"
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=1, seed=42, player_factory=async_player_factory, verbose=True)) # 4.9s mps, 4.0 cpu
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=10, seed=42, player_factory=async_player_factory, verbose=True)) # 7.1s
    results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=50, seed=42, player_factory=async_player_factory, verbose=True)) # 10.4s
    # results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=200, seed=42, player_factory=async_player_factory, verbose=True)) # 29.2s

    ## Multiple async threads slower on m4? GPU possibly saturated?
    # results = play_multiple_async_threads(num_threads=1, num_games=50, seed=42, player_factory_factory=async_player_factory_factory, verbose=True) # 10.2s. 5 games/sec
    # results = play_multiple_async_threads(num_threads=2, num_games=50, seed=42, player_factory_factory=async_player_factory_factory, verbose=True) # 13.4s
    # results = play_multiple_async_threads(num_threads=8, num_games=1000, seed=42, player_factory_factory=async_player_factory_factory, verbose=True) # 2+ minutes?

    # %prun -r -l 60 -s cumulative results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=100, seed=42, player_factory=async_player_factory, verbose=False))
    

    
    # reuse_tree=False, elapsed=3.32, game length: 41, simulations=20, action_history=[6, 4, 7, 5, 1, 7, 6, 6, 1, 4, 4, 6, 6, 6, 3, 2, 4, 1, 7, 4, 5, 3, 7, 7, 7, 2, 2, 1, 1, 4, 5, 5, 2, 2, 2, 1, 3, 5, 5, 3, 3]
    # reuse_tree=True,  elapsed=3.7s, game length: 41, simulations=20, action_history=[6, 4, 6, 6, 1, 7, 6, 6, 1, 4, 4, 6, 5, 5, 4, 2, 4, 1, 7, 4, 5, 2, 7, 7, 7, 2, 3, 1, 1, 5, 5, 7, 2, 2, 2, 1, 3, 3, 5, 3, 3]

    await async_evaluator.stop()

    t1 = time.time()
    print(f"model_size={MODEL_SIZE}, simulations={simulations}, elapsed={t1-t0:.2f}s, games/sec={simulations/(t1-t0):.2f}, benchmark={benchmark_name}")



## 10000 simulations, 
# Running 10000 simulations (previous total visits: 0)
# Running 8491 simulations (previous total visits: 1509)
# Running 8137 simulations (previous total visits: 1863)
# Running 8498 simulations (previous total visits: 1502)

## 1000 simulations, independent players.
# Running 1000 simulations (previous total visits: 0)
# Running 1000 simulations (previous total visits: 0)
# Running 955 simulations (previous total visits: 45)
# Running 966 simulations (previous total visits: 34)

## 1000 simulations, reusing player.
# Running 1000 simulations (previous total visits: 0)
# Running 826 simulations (previous total visits: 174)
# Running 879 simulations (previous total visits: 121)
# Running 830 simulations (previous total visits: 170)
# Running 709 simulations (previous total visits: 291)
# Running 845 simulations (previous total visits: 155)

# game length: 41, simulations=20
# [6, 4, 7, 5, 1, 7, 6, 6, 1, 4, 4, 6, 6, 6, 3, 2, 4, 1, 7, 4, 5, 3, 7, 7, 7, 2, 2, 1, 1, 4, 5, 5, 2, 2, 2, 1, 3, 5, 5, 3, 3]
# game length: 41, simulations=20
# [6, 4, 7, 5, 1, 7, 6, 6, 1, 4, 4, 6, 6, 6, 3, 2, 4, 1, 7, 4, 5, 3, 7, 7, 7, 2, 2, 1, 1, 4, 5, 5, 2, 2, 2, 1, 3, 5, 5, 3, 3]


In [None]:
# m4 benchmarks.
# model_size=tiny,  simulations=50, elapsed=3.20s, games/sec=15.60, benchmark=async_player_factory
# model_size=small, simulations=50, elapsed=4.43s, games/sec=11.28, benchmark=async_player_factory
# model_size=large, simulations=50, elapsed=11.54s, games/sec=4.33, benchmark=async_player_factory
# model_size=xl,    simulations=50, elapsed=32.59s, games/sec=1.53, benchmark=async_player_factory