# Scratchpad for profiling rgizero code.


In [None]:

import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

# Game and players
from rgi.rgizero.games.connect4 import Connect4Game
from rgi.rgizero.players import alphazero
from rgi.rgizero.players.alphazero import AlphazeroPlayer
from rgi.rgizero.players.alphazero import play_game


from rgi.rgizero.common import TOKENS

from notebook_utils import reload_local_modules

# Reaching this line means every import above resolved.
print("✅ Imports successful")

# Fail fast when torch cannot see a CUDA device.
assert torch.cuda.is_available()

print("✅ cuda available")

# Widen numpy's printed line width to 300 characters so long arrays wrap less.
np.set_printoptions(linewidth=300)

# Data directory is <parent of the current working directory>/data/rgizero-e2e;
# create it (and any missing parents) if it does not exist yet.
DATA_DIR = Path.cwd().parent / "data" / "rgizero-e2e"
os.makedirs(DATA_DIR, exist_ok=True)

%load_ext line_profiler

In [None]:
# Switches read by the `if ...:` guards around the optional benchmark cells below.
RANDOM_MODEL_BENCHMARK = False  # profile/time self-play that uses RandomEvaluator
SERIAL_NNET_BENCHMARK = False  # warm up and profile/time single-game play with the transformer evaluator
PARALLEL_BENCHMARK = False  # run the thread-pool multi-game smoke test
MODEL_SIZE = "big"  # "tiny" or "big"; key into model_config_dict

## Step 1: Set up history-wrapped game

In [None]:
from rgi.rgizero.games.history_wrapper import HistoryTrackingGame

# Connect5 to make it harder to connect! This helps test variable policy and longer games.
# max_game_length is 7*6 = 42 (presumably one move per cell of a 7x6 board — verify against Connect4Game).
base_game, max_game_length = Connect4Game(connect_length=5), 7*6

# Wrap the base game in HistoryTrackingGame; later cells play through this wrapper.
game = HistoryTrackingGame(base_game)
state_0 = game.initial_state()
# Later passed to ActionHistoryTransformerEvaluator as block_size: the longest game plus 2
# extra positions (purpose of the 2 is not shown here; presumably start token plus slack — verify).
block_size = max_game_length + 2

print("✅ Using HistoryTrackingGame from module")
print(f"Game: {base_game.__class__.__name__}, Players: {game.num_players(state_0)}, Actions: {list(game.all_actions())}")

## Step 2: Confirm we can self-play a game with a Random Evaluator.

In [None]:
reload_local_modules(verbose=False)

from rgi.rgizero.players.alphazero import AlphazeroPlayer, play_game, NetworkEvaluatorResult, NetworkEvaluator
from typing import override, Any

class RandomEvaluator(NetworkEvaluator):
    """Evaluator that returns seeded random policy and value arrays instead of running a network."""

    def __init__(self, seed: int = 42):
        # A private generator makes runs with the same seed reproducible.
        self.rng = np.random.default_rng(seed)

    @override
    def evaluate(self, game, state, legal_actions: list[Any]):
        """Draw one uniform [0, 1) sample per legal action and one per player.

        The policy is sampled before the values, so the random stream is
        always consumed in the same order.
        """
        action_count = len(legal_actions)
        random_policy = self.rng.random(action_count)
        random_values = self.rng.random(game.num_players(state))
        return NetworkEvaluatorResult(random_policy, random_values)

def play_deterministic_game(seed, evaluator=None, player=None, verbose=False):
    """Self-play one game with a seeded player and return the result of play_game.

    Args:
        seed: Seed for the default RandomEvaluator and for the default player's RNG.
        evaluator: Evaluator handed to the default player. When None, a
            RandomEvaluator seeded with ``seed`` is built. Ignored when
            ``player`` is supplied.
        player: Player used for both seats. When None, an AlphazeroPlayer is
            built from ``game``, the evaluator and a generator seeded with ``seed``.
        verbose: When True, print the game length, the player's simulation
            count and the action history.

    Returns:
        The object returned by play_game; it is indexed with "action_history"
        when ``verbose`` is set.
    """
    # Compare against None explicitly: an evaluator or player object that happens
    # to be falsy (e.g. defines __len__ or __bool__) must not be silently replaced.
    if player is None:
        if evaluator is None:
            evaluator = RandomEvaluator(seed=seed)
        player = AlphazeroPlayer(game, evaluator, rng=np.random.default_rng(seed))

    # The same player object sits in both seats.
    game_result = play_game(game, [player, player])
    if verbose:
        print(f'game length: {len(game_result["action_history"])}, simulations={player.simulations}')
        print(game_result['action_history'])
    return game_result

game_result = play_deterministic_game(42, verbose=True)

In [None]:
# 3.3s to play single game. simulations=800

if RANDOM_MODEL_BENCHMARK:
    %prun -r -l 30 -s cumulative game_result = play_deterministic_game(42, verbose=True)

# 96532    0.650    0.000    0.994    0.000 alphazero.py:114(select_action_index)
# 96532    0.106    0.000    0.106    0.000 alphazero.py:139(select_action_index)   # numba

In [None]:
from rgi.rgizero.games import connect4
from rgi.rgizero.games import history_wrapper
if RANDOM_MODEL_BENCHMARK:
    %lprun \
        -f alphazero.MCTSNode.select_action_index \
        -f alphazero.MCTSNode.backup \
        -f connect4.Connect4Game.next_state \
        -f history_wrapper.HistoryTrackingGame.next_state \
        game_result = play_deterministic_game(42, verbose=True)



In [None]:
if RANDOM_MODEL_BENCHMARK:
    %timeit game_result = play_deterministic_game(42, verbose=True)

# Original %%timeit - 26.3 seconds.
# 3.28 s ± 60.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Vectorized calculation of select_action_index
# 2.4 s ± 38.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# numba version
# 1.79 s ± 28.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Use numba for MCTS.backup
# 1.35 s ± 16.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Use numba for Connect4Game.next_state
# 898 ms ± 6.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

## Play deterministic game with neural network evaluator


In [None]:
reload_local_modules(verbose=False)

from rgi.rgizero.games.history_wrapper import HistoryTrackingGame
from rgi.rgizero.data.trajectory_dataset import Vocab
from rgi.rgizero.common import TOKENS

# Connect5 to make it harder to connect! This helps test variable policy and longer games.
# Same construction as the Step 1 cell, repeated here after reload_local_modules
# (presumably so the objects come from the freshly reloaded classes — verify).
base_game, max_game_length = Connect4Game(connect_length=5), 7*6

game = HistoryTrackingGame(base_game)
state_0 = game.initial_state()
# Token table: index 0 is the start-of-game token, followed by every action of the base game.
vocab = Vocab(itos=[TOKENS.START_OF_GAME] + list(base_game.all_actions()))

# Context length for the transformer config; same formula as block_size in the Step 1 cell.
n_max_context = max_game_length + 2

print("✅ Using HistoryTrackingGame from module")
print(f"Game: {base_game.__class__.__name__}, Players: {game.num_players(state_0)}, Actions: {list(game.all_actions())}")


from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformer, ActionHistoryTransformerEvaluator
from rgi.rgizero.models.transformer import TransformerConfig

# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Make model initialization deterministic
seed = 42
torch.manual_seed(seed)
np.random.seed(seed) # Ensure numpy operations are also seeded
if torch.cuda.is_available():
    # Seed every CUDA device and switch cuDNN to deterministic mode with autotuning off.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Available model sizes, selected by the MODEL_SIZE flag.
model_config_dict = {
    "tiny": TransformerConfig(n_max_context=n_max_context, n_layer=2, n_head=2, n_embd=8),
    "big": TransformerConfig(n_max_context=n_max_context, n_layer=8, n_head=8, n_embd=128),
}

model_config = model_config_dict[MODEL_SIZE]
model = ActionHistoryTransformer(config=model_config, action_vocab_size=vocab.vocab_size, num_players=game.num_players(state_0))
model.to(device)

# TODO: Skip compiling for now ... doesn't help much with performance for this model and makes profiling harder.
# model.compile()

# NOTE(review): block_size comes from the Step 1 cell; it equals n_max_context
# defined in this cell (both are max_game_length + 2).
model_evaluator = ActionHistoryTransformerEvaluator(model, device=device, block_size=block_size, vocab=vocab)

In [None]:
def play_deterministic_game_nnet(seed, verbose=False, simulations=800):
    """Self-play one game where both seats use the shared model_evaluator.

    Args:
        seed: Seed for the player's numpy random generator.
        verbose: When True, print the game length, simulation count and action history.
        simulations: Simulation count passed to AlphazeroPlayer.

    Returns:
        The object returned by play_game.
    """
    nnet_player = AlphazeroPlayer(
        game,
        model_evaluator,
        rng=np.random.default_rng(seed),
        simulations=simulations,
    )
    # One player object occupies both seats.
    outcome = play_game(game, [nnet_player, nnet_player])
    if not verbose:
        return outcome

    print(f'game length: {len(outcome["action_history"])}, simulations={nnet_player.simulations}')
    print(outcome['action_history'])
    return outcome


# Run once to compile / load everything
# Uses 200 simulations and discards the result; only runs when SERIAL_NNET_BENCHMARK is set.
if SERIAL_NNET_BENCHMARK:
    _ = play_deterministic_game_nnet(42, simulations=200)
    print("✅ Model warmed up")


In [None]:
if SERIAL_NNET_BENCHMARK:
    %prun -r -l 30 -s cumulative game_result = play_deterministic_game_nnet(42, simulations=200, verbose=True)


In [None]:
from rgi.rgizero.games import connect4
from rgi.rgizero.games import history_wrapper
from rgi.rgizero.models import action_history_transformer
from rgi.rgizero.models import transformer

if SERIAL_NNET_BENCHMARK:
    %lprun \
        -f action_history_transformer.ActionHistoryTransformer.forward \
        -f action_history_transformer.ActionHistoryTransformerEvaluator.evaluate.__wrapped__ \
        -f transformer.Transformer.forward \
        -f transformer.Block.forward \
        -f transformer.CausalSelfAttention.forward \
        game_result = play_deterministic_game_nnet(42, verbose=True, simulations=200)

In [None]:

# Initial timing with simulations=200
# 4.86 s ± 38.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)    # tiny model
# 16.2 s ± 213 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)     # big model

if SERIAL_NNET_BENCHMARK:
    %timeit game_result =  play_deterministic_game_nnet(42, verbose=True, simulations=200)

## Play Multiple Games in Parallel with Neural Network Evaluator


In [None]:
import concurrent.futures

def play_multiple_deterministic_games_nnet(num_games: int, simulations: int = 200, verbose: bool = False):
    """Play ``num_games`` network-evaluated games concurrently, one worker thread per game.

    Every game is submitted with the same seed (42).

    Args:
        num_games: Number of games to play; also the size of the thread pool.
        simulations: Simulation count forwarded to play_deterministic_game_nnet.
        verbose: Forwarded to play_deterministic_game_nnet.

    Returns:
        A list with one game result per game, in completion order (not
        submission order). Empty when ``num_games`` is 0.

    Raises:
        ValueError: If ``num_games`` is negative.
    """
    if num_games < 0:
        raise ValueError(f"num_games must be non-negative, got {num_games}")
    if num_games == 0:
        # ThreadPoolExecutor rejects max_workers=0, so handle the empty case here.
        return []

    # Using ThreadPoolExecutor, as GPU operations can release the GIL
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_games) as executor:
        # Submit all games to the executor
        futures = [
            executor.submit(play_deterministic_game_nnet, 42, verbose=verbose, simulations=simulations)
            for _ in range(num_games)
        ]

        # Collect results as they complete; result() re-raises a worker's exception.
        return [future.result() for future in concurrent.futures.as_completed(futures)]


# Smoke test for the thread-pool runner; only runs when PARALLEL_BENCHMARK is set.
if PARALLEL_BENCHMARK:
    # Run a few games to warm up and test
    # results = play_multiple_deterministic_games_nnet(num_games=1, simulations=20, verbose=True)  # 1.8s (big model)
    results = play_multiple_deterministic_games_nnet(num_games=2, simulations=20, verbose=True)  # 4.0s (big model)


In [None]:
# Build a small benchmark fixture: starting from the initial position, record each
# visited state together with its legal actions, then advance by a randomly chosen action.
state = game.initial_state()
state_list, legal_actions_list = [], []

for _ in range(5):
    legal_actions = game.legal_actions(state)
    state_list.append(state)
    legal_actions_list.append(legal_actions)
    action = np.random.choice(legal_actions)
    state = game.next_state(state, action)

# Evaluate every recorded position once and print each result.
for recorded_state, recorded_actions in zip(state_list, legal_actions_list):
    print(model_evaluator.evaluate(game, recorded_state, recorded_actions))


In [None]:
# Reload local modules, then rebuild the evaluator around the existing model
# (presumably so it uses the freshly reloaded evaluator class — verify).
reload_local_modules(verbose=False)
from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformerEvaluator
model_evaluator = ActionHistoryTransformerEvaluator(model, device=device, block_size=block_size, vocab=vocab)

def cycle(original_list, n=1):
    """Return a list of the first n items of original_list repeated end to end.

    Indexing wraps around with modulo, so n may exceed len(original_list).
    """
    period = len(original_list)
    looped = []
    for position in range(n):
        looped.append(original_list[position % period])
    return looped

def eval_serial(state_list, legal_actions_list, n=1):
    """Evaluate n (state, legal_actions) pairs one call at a time with model_evaluator.

    Both inputs are repeated to length n with cycle(); the results are discarded.
    """
    looped_states = cycle(state_list, n)
    looped_actions = cycle(legal_actions_list, n)
    for idx, current_state in enumerate(looped_states):
        model_evaluator.evaluate(game, current_state, looped_actions[idx])

def eval_batch(state_list, legal_actions_list, n=1):
    """Evaluate n (state, legal_actions) pairs in a single evaluate_batch call.

    Both inputs are repeated to length n with cycle(); the result is discarded.
    """
    model_evaluator.evaluate_batch(
        states_list=cycle(state_list, n),
        legal_actions_list=cycle(legal_actions_list, n),
    )


# Run both evaluation paths once on five inputs before profiling or timing them.
eval_serial(state_list, legal_actions_list, n=5)
eval_batch(state_list, legal_actions_list, n=5)

In [None]:
# %prun -r -l 30 -s cumulative eval_batch(state_list, legal_actions_list, n=1000)    # 0.256 action_history_transformer.py:146(evaluate_batch)
# %prun -r -l 30 -s cumulative eval_batch_v2(state_list, legal_actions_list, n=1000) # 0.622 action_history_transformer.py:186(evaluate_batch_v2)

%lprun -f action_history_transformer.ActionHistoryTransformerEvaluator.evaluate_batch.__wrapped__ eval_batch(state_list, legal_actions_list, n=5000)

In [None]:

def go(n):
    """Print n, then time eval_serial and eval_batch on n cycled inputs with %timeit.

    NOTE(review): a later cell defines ``async def go()``, which rebinds this
    name once that cell has run.
    """
    print(f"\nn={n}")
    # IPython line magics: each times one call pattern over the shared fixture lists.
    %timeit eval_serial(state_list, legal_actions_list, n=n)
    %timeit eval_batch(state_list, legal_actions_list, n=n)

# go(1)
# go(2)
# go(5)
# go(10)
# go(50)
# go(100)
# go(1000)


# n=1
# 2.29 ms ± 63.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 2.35 ms ± 26.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# n=2
# 4.93 ms ± 210 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 2.43 ms ± 113 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# n=5
# 11.2 ms ± 128 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 2.22 ms ± 15.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# n=10
# 22.2 ms ± 379 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 2.35 ms ± 85.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# n=50
# 116 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 3.08 ms ± 32.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# n=100
# 212 ms ± 5.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 5.36 ms ± 52.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# n=1000
# 2.24 s ± 49.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 49.1 ms ± 566 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Speed comparison on large model, TransformerConfig(n_max_context=n_max_context, n_layer=8, n_head=8, n_embd=128)
# We should aim for batch size of 50+

# for _ in range(1000):
#     # %timeit eval_batch(state_list, legal_actions_list, n=10000) # 95% GPU usage, 469  ms loop, speedup=47 (large model)
#     # %timeit eval_batch(state_list, legal_actions_list, n=1000)  # 94% GPU usage, 49.5 ms loop, speedup=45 (large model)
#     # %timeit eval_batch(state_list, legal_actions_list, n=100)   # 91% GPU usage, 5.42 ms loop, speedup=41 (large model)
#     # %timeit eval_batch(state_list, legal_actions_list, n=50)    # 89% GPU usage, 3.07 ms loop, speedup=36 (large model)
#     # %timeit eval_batch(state_list, legal_actions_list, n=40)    # 87% GPU usage, 2.73 ms loop, speedup=32 (large model)
#     # %timeit eval_batch(state_list, legal_actions_list, n=30)    # 71% GPU usage, 2.58 ms loop, speedup=26 (large model)
#     # %timeit eval_batch(state_list, legal_actions_list, n=20)    # 63% GPU usage, 2.53 ms loop, speedup=17 (large model)
#     # %timeit eval_batch(state_list, legal_actions_list, n=10)    # 50% GPU usage, 2.29 ms loop, speedup=10 (large model)
#     # %timeit eval_batch(state_list, legal_actions_list, n=1)     # 35% GPU usage, 2.24 ms loop, speedup=1  (large model)

# Test QueuedNetworkEvaluator

In [None]:
reload_local_modules(verbose=False)
from rgi.rgizero.models.action_history_transformer import QueuedNetworkEvaluator

# Wrap the direct evaluator in a QueuedNetworkEvaluator.
# NOTE(review): max_latency_ms=0 presumably means "do not wait to fill a batch" — confirm.
queued_evaluator = QueuedNetworkEvaluator(model_evaluator, max_batch_size=1024, max_latency_ms=0)

# Single evaluation as a smoke test; as the cell's last expression its value is displayed.
queued_evaluator.evaluate(game, state_list[0], legal_actions_list[0])


In [None]:
import nest_asyncio
nest_asyncio.apply()


In [None]:
reload_local_modules(verbose=False)

import concurrent.futures
import asyncio
from rgi.rgizero.models.action_history_transformer import AsyncNetworkEvaluator, QueuedNetworkEvaluator


def eval_queued(state_list, legal_actions_list, n=1, num_threads=50):
    """Submit n evaluations to queued_evaluator from a fresh thread pool and wait for all.

    Args:
        state_list: States to evaluate; repeated to length n with cycle().
        legal_actions_list: Legal actions per state; repeated to length n with cycle().
        n: Number of evaluations to submit.
        num_threads: Size of the thread pool created for this call.
    """
    work_items = zip(cycle(state_list, n), cycle(legal_actions_list, n))

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool:
        pending = [
            pool.submit(queued_evaluator.evaluate, game, state, legal_actions)
            for state, legal_actions in work_items
        ]
        # Block until every future is done; result() re-raises a worker's exception.
        for finished in concurrent.futures.as_completed(pending):
            finished.result()

def eval_queued_v2(state_list, legal_actions_list, n=1, executor=None):
    """Submit n evaluations to queued_evaluator on a caller-supplied executor and wait for all.

    Unlike eval_queued, the pool can be created once by the caller and reused
    across calls, so pool start-up is not part of what gets timed.

    Args:
        state_list: States to evaluate; repeated to length n with cycle().
        legal_actions_list: Legal actions per state; repeated to length n with cycle().
        n: Number of evaluations to submit.
        executor: Executor with a ``submit`` method. When None, a temporary
            ThreadPoolExecutor with default sizing is created for this call
            and shut down afterwards.
    """
    state_list_cycled = cycle(state_list, n)
    legal_actions_list_cycled = cycle(legal_actions_list, n)

    # The default of None used to fail on ``executor.submit``; fall back to a
    # private pool that this call owns and cleans up.
    owns_executor = executor is None
    if owns_executor:
        executor = concurrent.futures.ThreadPoolExecutor()
    try:
        futures = [
            executor.submit(queued_evaluator.evaluate, game, state, legal_actions)
            for state, legal_actions in zip(state_list_cycled, legal_actions_list_cycled)
        ]
        # Wait for all futures to complete; result() re-raises a worker's exception.
        for future in concurrent.futures.as_completed(futures):
            future.result()
    finally:
        if owns_executor:
            executor.shutdown(wait=True)


async def eval_async(state_list, legal_actions_list, n=1, async_evaluator=None):
    """Issue n evaluations through async_evaluator and await them all together.

    Args:
        state_list: States to evaluate; repeated to length n with cycle().
        legal_actions_list: Legal actions per state; repeated to length n with cycle().
        n: Number of evaluations to issue.
        async_evaluator: Object whose ``evaluate(game, state, legal_actions)``
            returns an awaitable. Must be supplied: the None default is
            dereferenced directly.
    """
    pairs = zip(cycle(state_list, n), cycle(legal_actions_list, n))

    pending = []
    for state, legal_actions in pairs:
        pending.append(async_evaluator.evaluate(game, state, legal_actions))
    await asyncio.gather(*pending)


# %timeit eval_serial(state_list, legal_actions_list, n=1000) # 2.24 s ± 30.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# print(f"\nEval eval_batch performance (n=1000, max_workers=100):")
# %timeit eval_batch(state_list, legal_actions_list, n=1000) # 49.3 ms ± 379 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

# print(f"\nEval queued v1 performance (n=1000, max_workers=100):")
# %timeit eval_queued(state_list, legal_actions_list, n=1000) # 160 ms ± 3.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# with concurrent.futures.ThreadPoolExecutor(max_workers=100) as executor:
#     print(f"\nEval queued v2 (threaded) performance (n=1000, max_workers=100):")
#     %timeit eval_queued_v2(state_list, legal_actions_list, n=1000, executor=executor)  # 180 ms ± 2.98 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

async_evaluator = AsyncNetworkEvaluator(model_evaluator, max_batch_size=1024)
await async_evaluator.start()

print(f"\nEval async performance (n=1000, max_batch_size=1024):")
%timeit asyncio.run(eval_async(state_list, legal_actions_list, n=1000, async_evaluator=async_evaluator)) # 78.9 ms ± 2.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
await async_evaluator.stop()


In [None]:
%prun -r -l 30 -s cumulative game_result = eval_batch(state_list, legal_actions_list, n=1000)

In [None]:
async_evaluator = AsyncNetworkEvaluator(model_evaluator, max_batch_size=1024)
await async_evaluator.start()
%prun -r -l 30 -s cumulative asyncio.run(eval_async(state_list, legal_actions_list, n=1000, async_evaluator=async_evaluator))
await async_evaluator.stop()


In [None]:
async def go():
    """Run ten rounds of eval_async (n=1000 each) against one freshly started AsyncNetworkEvaluator.

    The evaluator is stopped in a ``finally`` block so that stop() runs even
    when one of the rounds raises.
    """
    async_evaluator = AsyncNetworkEvaluator(model_evaluator, max_batch_size=1024)
    await async_evaluator.start()
    try:
        for _ in range(10):
            await eval_async(state_list, legal_actions_list, n=1000, async_evaluator=async_evaluator)
    finally:
        await async_evaluator.stop()


%prun -r -l 30 -s cumulative asyncio.run(go())


In [None]:

%lprun \
    -f action_history_transformer.ActionHistoryTransformerEvaluator.evaluate_batch.__wrapped__ \
    -f action_history_transformer.QueuedNetworkEvaluator._run_once \
    eval_queued(state_list, legal_actions_list, n=1000)