# Step-by-step run of alphazero self-play & training.


In [1]:
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

# Game and players
from rgi.rgizero.games.connect4 import Connect4Game
from rgi.rgizero.players import alphazero
from rgi.rgizero.players.alphazero import AlphazeroPlayer
from rgi.rgizero.players.alphazero import play_game


from rgi.rgizero.common import TOKENS

from notebook_utils import reload_local_modules

print("✅ Imports successful")

if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
  device = 'mps'
else:
    device = 'cpu'
print(f'Using device: {device}')
assert device in ('cuda', 'mps'), f"No accelerator available, device={device}"

# Allow asyncio to work with jupyter notebook
import nest_asyncio
nest_asyncio.apply()

# Increase numpy print width
np.set_printoptions(linewidth=300)

DATA_DIR = Path.cwd().parent / "data" / "rgizero-e2e"
os.makedirs(DATA_DIR, exist_ok=True)

%load_ext line_profiler

✅ Imports successful
Using device: mps


In [2]:
MODEL_SIZE = "small"  # "tiny" or "small" or"large" or "xl"
NUM_SIMULATIONS = 50

## Step 1: Set up history-wrapped game


In [3]:
from rgi.rgizero.games.history_wrapper import HistoryTrackingGame

base_game, max_game_length = Connect4Game(connect_length=4), 7*6

game = HistoryTrackingGame(base_game)
state_0 = game.initial_state()
block_size = max_game_length + 2

print("✅ Using HistoryTrackingGame from module")
print(f"Game: {base_game.__class__.__name__}, Players: {game.num_players(state_0)}, Actions: {list(game.all_actions())}")

✅ Using HistoryTrackingGame from module
Game: Connect4Game, Players: 2, Actions: [1, 2, 3, 4, 5, 6, 7]


## Step 2: Confirm we can self-play a game with a Random Evaluator.


In [4]:
reload_local_modules(verbose=False)

from rgi.rgizero.players.alphazero import AlphazeroPlayer, play_game, NetworkEvaluatorResult, NetworkEvaluator
from typing import override, Any

class RandomEvaluator(NetworkEvaluator):
    def __init__(self, seed: int = 42):
        self.rng = np.random.default_rng(seed)

    @override
    def evaluate(self, game, state, legal_actions: list[Any]):
        policy = self.rng.random(len(legal_actions))
        values = self.rng.random(game.num_players(state))
        return NetworkEvaluatorResult(policy, values)

evaluator = RandomEvaluator()
player = AlphazeroPlayer(game, evaluator)
game_result = play_game(game, [player, player])

game_result

{'winner': 1,
 'rewards': array([1., 0.], dtype=float32),
 'action_history': [4, 2, 6, 2, 5, 6, 6, 3, 7],
 'legal_policies': [array([0.18874997, 0.14375001, 0.16500002, 0.12625   , 0.0125    , 0.19625002, 0.16750005], dtype=float32),
  array([0.20250002, 0.19125   , 0.14125001, 0.22874996, 0.08125002, 0.005     , 0.15      ], dtype=float32),
  array([0.115     , 0.02625   , 0.20749995, 0.11250001, 0.13000003, 0.23      , 0.17875004], dtype=float32),
  array([0.07874999, 0.22874996, 0.19500004, 0.13000003, 0.03      , 0.15749998, 0.18      ], dtype=float32),
  array([0.10749998, 0.11375   , 0.10499998, 0.0325    , 0.57500005, 0.03      , 0.03625   ], dtype=float32),
  array([0.14750001, 0.17750002, 0.13000003, 0.02375   , 0.04      , 0.18250003, 0.29874992], dtype=float32),
  array([0.015     , 0.005     , 0.5675    , 0.04624999, 0.0125    , 0.07124999, 0.2825    ], dtype=float32),
  array([0.2125    , 0.06250001, 0.22125003, 0.14499998, 0.01875   , 0.16500002, 0.17499998], dtype=float3

In [20]:
reload_local_modules(verbose=False)

from rgi.rgizero.games.history_wrapper import HistoryTrackingGame
from rgi.rgizero.data.trajectory_dataset import Vocab
from rgi.rgizero.common import TOKENS

# Connect5 to make it harder to connect! This helps test variable policy and longer games.
base_game, max_game_length = Connect4Game(connect_length=5), 7*6

game = HistoryTrackingGame(base_game)
state_0 = game.initial_state()
vocab = Vocab(itos=[TOKENS.START_OF_GAME] + list(base_game.all_actions()))

n_max_context = max_game_length + 2

print("✅ Using HistoryTrackingGame from module")
print(f"Game: {base_game.__class__.__name__}, Players: {game.num_players(state_0)}, Actions: {list(game.all_actions())}")


from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformer, ActionHistoryTransformerEvaluator
from rgi.rgizero.models.transformer import TransformerConfig

if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
  device = 'mps'
else:
    device = 'cpu'
print(f'Using device: {device}')


# Make model initialization deterministic
seed = 42
torch.manual_seed(seed)
np.random.seed(seed) # Ensure numpy operations are also seeded
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

model_config_dict = {
    "tiny": TransformerConfig(n_max_context=n_max_context, n_layer=2, n_head=2, n_embd=8),
    "small": TransformerConfig(n_max_context=n_max_context, n_layer=4, n_head=4, n_embd=32),
    "large": TransformerConfig(n_max_context=n_max_context, n_layer=8, n_head=8, n_embd=128),
    "xl": TransformerConfig(n_max_context=n_max_context, n_layer=16, n_head=16, n_embd=256),
}

model_config = model_config_dict[MODEL_SIZE]
model = ActionHistoryTransformer(config=model_config, action_vocab_size=vocab.vocab_size, num_players=game.num_players(state_0))
model.to(device)


✅ Using HistoryTrackingGame from module
Game: Connect4Game, Players: 2, Actions: [1, 2, 3, 4, 5, 6, 7]
Using device: mps


ActionHistoryTransformer(
  (action_embedding): Embedding(8, 32)
  (transformer): Transformer(
    (wpe): Embedding(44, 32)
    (dropout): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-3): 4 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=32, out_features=96, bias=False)
          (c_proj): Linear(in_features=32, out_features=32, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=32, out_features=128, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=128, out_features=32, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
  )
  (ln_f): LayerNorm()
  (policy_value_head): PolicyValueHead(
    (policy): Linear(in_features=32, out_features=8, bias=True)
    (valu

In [21]:
reload_local_modules(verbose=False)
import asyncio
import time

from rgi.rgizero.models.action_history_transformer import AsyncNetworkEvaluator, QueuedNetworkEvaluator, ActionHistoryTransformerEvaluator
from rgi.rgizero.players.alphazero import play_game_async

async def play_single_deterministic_game_nnet_async(seed, player=None, player_factory=None, verbose=False, simulations=800):
    player = player or player_factory()
    game_result = await play_game_async(game, [player, player])
    if verbose:
        print(f'game length: {len(game_result["action_history"])}, simulations={player.simulations}')
        print(game_result['action_history'])
    return game_result

async def play_multiple_deterministic_games_nnet_async(num_games: int, player_factory_factory=None, **kwargs):
    if player_factory_factory:
        player_factory = await player_factory_factory()
        kwargs['player_factory'] = player_factory
    tasks = [play_single_deterministic_game_nnet_async(**kwargs) for _ in range(num_games)]
    results = await asyncio.gather(*tasks)
    return results

PARALLEL_PLAY_BENCHMARK = True
if PARALLEL_PLAY_BENCHMARK:
    simulations = NUM_SIMULATIONS
    
    t0 = time.time()
    print(f"model_size={MODEL_SIZE}, simulations={simulations}")
    serial_evaluator = ActionHistoryTransformerEvaluator(model, device=device, block_size=block_size, vocab=vocab)
    queued_evaluator = QueuedNetworkEvaluator(serial_evaluator, max_batch_size=1024, max_latency_ms=0, verbose=False)
    async_evaluator = AsyncNetworkEvaluator(base_evaluator=serial_evaluator, max_batch_size=1024, start=True, verbose=False)
    async_evaluator_factory = lambda: AsyncNetworkEvaluator(base_evaluator=serial_evaluator, max_batch_size=1024, start=True, verbose=False)
    await async_evaluator.start()
    serial_player_factory = lambda: AlphazeroPlayer(game, serial_evaluator, rng=np.random.default_rng(seed), simulations=simulations)
    queued_player_factory = lambda: AlphazeroPlayer(game, queued_evaluator, rng=np.random.default_rng(seed), simulations=simulations)
    async_player_factory = lambda: AlphazeroPlayer(game, async_evaluator, rng=np.random.default_rng(seed), simulations=simulations)
    async def async_player_factory_factory():
        async_evaluator = async_evaluator_factory()
        await async_evaluator.start()
        return lambda: AlphazeroPlayer(game, async_evaluator, rng=np.random.default_rng(seed), simulations=simulations)

    results = asyncio.run(play_multiple_deterministic_games_nnet_async(num_games=50, seed=42, player_factory=async_player_factory, verbose=True)) # 10.4s

    await async_evaluator.stop()





model_size=small, simulations=50
game length: 22, simulations=50
[6, 3, 6, 5, 1, 7, 6, 6, 1, 3, 3, 7, 5, 6, 4, 2, 4, 1, 6, 4, 5, 5]
game length: 22, simulations=50
[6, 3, 6, 5, 1, 7, 6, 6, 1, 3, 3, 7, 5, 6, 4, 2, 4, 1, 6, 4, 5, 5]
game length: 22, simulations=50
[6, 3, 6, 5, 1, 7, 6, 6, 1, 3, 3, 7, 5, 6, 4, 2, 4, 1, 6, 4, 5, 5]
game length: 22, simulations=50
[6, 3, 6, 5, 1, 7, 6, 6, 1, 3, 3, 7, 5, 6, 4, 2, 4, 1, 6, 4, 5, 5]
game length: 22, simulations=50
[6, 3, 6, 5, 1, 7, 6, 6, 1, 3, 3, 7, 5, 6, 4, 2, 4, 1, 6, 4, 5, 5]
game length: 22, simulations=50
[6, 3, 6, 5, 1, 7, 6, 6, 1, 3, 3, 7, 5, 6, 4, 2, 4, 1, 6, 4, 5, 5]
game length: 22, simulations=50
[6, 3, 6, 5, 1, 7, 6, 6, 1, 3, 3, 7, 5, 6, 4, 2, 4, 1, 6, 4, 5, 5]
game length: 22, simulations=50
[6, 3, 6, 5, 1, 7, 6, 6, 1, 3, 3, 7, 5, 6, 4, 2, 4, 1, 6, 4, 5, 5]
game length: 22, simulations=50
[6, 3, 6, 5, 1, 7, 6, 6, 1, 3, 3, 7, 5, 6, 4, 2, 4, 1, 6, 4, 5, 5]
game length: 22, simulations=50
[6, 3, 6, 5, 1, 7, 6, 6, 1, 3, 3, 7, 5, 6, 4

## Step 3: Confirm we can read & write to trajectory_dataset


In [5]:
from rgi.rgizero.data.trajectory_dataset import Vocab

all_actions = game.all_actions()
num_players = game.num_players(state_0)

vocab = Vocab(itos=(TOKENS.START_OF_GAME,) + all_actions)
print(f"Vocab: {vocab}")

Vocab: Vocab(vocab_size=8, itos=('START_OF_GAME', 1, 2, 3, 4, 5, 6, 7), stoi={'START_OF_GAME': 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7})


In [6]:
from rgi.rgizero.data.trajectory_dataset import TrajectoryDatasetBuilder, TrajectoryDataset, build_trajectory_loader
reload_local_modules(verbose=False)

def add_trajectory(game_result, vocab, td_builder):
    action_history = game_result['action_history']
    trajectory_length = len(action_history)
    legal_policies = game_result['legal_policies']
    legal_action_idx = game_result['legal_action_idx']
    rewards = game_result['rewards']

    # Translation key for converting legal_action_ids to vocab_action_idx.
    action_idx_to_vocab_idx = vocab.encode(all_actions)

    fixed_width_policies = np.zeros((trajectory_length, vocab.vocab_size))
    for i in range(trajectory_length):
        vocab_action_idx = action_idx_to_vocab_idx[legal_action_idx[i]]
        fixed_width_policies[i, vocab_action_idx] = legal_policies[i]

    encoded_action_history = vocab.encode(action_history)
    tiled_rewards = np.tile(rewards, (trajectory_length, 1))  # shape (num_players,) -> (num_moves, num_players)
    
    td_builder.add_trajectory(actions=encoded_action_history, fixed_width_policies=fixed_width_policies, values=tiled_rewards)

td_builder = TrajectoryDatasetBuilder(vocab)
add_trajectory(game_result, vocab, td_builder)

td_builder.save(DATA_DIR, 'train_1000')

## Load dataset


In [7]:
td = TrajectoryDataset(DATA_DIR, 'train_1000', 5)
td[0]


TrajectoryTuple(action=tensor([4, 2, 6, 2, 5]), policy=tensor([[0.0000, 0.1887, 0.1438, 0.1650, 0.1262, 0.0125, 0.1963, 0.1675],
        [0.0000, 0.2025, 0.1912, 0.1413, 0.2287, 0.0813, 0.0050, 0.1500],
        [0.0000, 0.1150, 0.0263, 0.2075, 0.1125, 0.1300, 0.2300, 0.1788],
        [0.0000, 0.0787, 0.2287, 0.1950, 0.1300, 0.0300, 0.1575, 0.1800],
        [0.0000, 0.1075, 0.1137, 0.1050, 0.0325, 0.5750, 0.0300, 0.0362]],
       dtype=torch.float64), value=tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]]))

In [8]:
loader = build_trajectory_loader(
    DATA_DIR, 'train_1000', block_size=5, batch_size=1,
    device_is_cuda=False, workers=4)

for batch in loader:
    print(batch)
    break

(tensor([[4, 2, 6, 2, 5]]), tensor([[[0.0000, 0.1887, 0.1438, 0.1650, 0.1262, 0.0125, 0.1963, 0.1675],
         [0.0000, 0.2025, 0.1912, 0.1413, 0.2287, 0.0813, 0.0050, 0.1500],
         [0.0000, 0.1150, 0.0263, 0.2075, 0.1125, 0.1300, 0.2300, 0.1788],
         [0.0000, 0.0787, 0.2287, 0.1950, 0.1300, 0.0300, 0.1575, 0.1800],
         [0.0000, 0.1075, 0.1137, 0.1050, 0.0325, 0.5750, 0.0300, 0.0362]]],
       dtype=torch.float64), tensor([[[1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.]]]))


## Create model (random weights) and play a single game.


In [9]:
reload_local_modules(verbose=False)

from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformer, ActionHistoryTransformerEvaluator
from rgi.rgizero.models.transformer import TransformerConfig

device = 'cuda' if torch.cuda.is_available() else 'cpu'
tiny_config: TransformerConfig = TransformerConfig(n_max_context=100, n_layer=2, n_head=2, n_embd=8)
tiny_model = ActionHistoryTransformer(config=tiny_config, action_vocab_size=vocab.vocab_size, num_players=num_players)
tiny_model.to(device)
tiny_evaluator = ActionHistoryTransformerEvaluator(tiny_model, device=device, block_size=5, vocab=vocab)


In [10]:
legal_actions = game.legal_actions(state_0)
tiny_evaluator.evaluate(game, state_0, legal_actions)

NetworkEvaluatorResult(legal_policy=array([0.13419195, 0.14256692, 0.14048576, 0.15470095, 0.13918312, 0.14442852, 0.1444428 ], dtype=float32), player_values=array([ 0.03081834, -0.0308184 ], dtype=float32))

In [11]:
tiny_player = AlphazeroPlayer(game, tiny_evaluator)
tiny_game_result = play_game(game, [tiny_player, tiny_player])

tiny_game_result

{'winner': 1,
 'rewards': array([1., 0.], dtype=float32),
 'action_history': [2, 6, 4, 1, 3, 7, 5],
 'legal_policies': [array([0.13124998, 0.145     , 0.13875002, 0.15499999, 0.135     , 0.14625002, 0.14874996], dtype=float32),
  array([0.14499998, 0.14625001, 0.13875   , 0.14874996, 0.14125   , 0.14125   , 0.13875   ], dtype=float32),
  array([0.13499999, 0.14249998, 0.13625002, 0.15125   , 0.13625002, 0.15125   , 0.1475    ], dtype=float32),
  array([0.13625   , 0.13625   , 0.13875   , 0.16375   , 0.14999999, 0.14125   , 0.13374996], dtype=float32),
  array([0.09624999, 0.09874998, 0.45499998, 0.09750002, 0.07874998, 0.09      , 0.08375002], dtype=float32),
  array([0.14375   , 0.15      , 0.14874998, 0.15875003, 0.1025    , 0.15125002, 0.14499998], dtype=float32),
  array([0.00375   , 0.005     , 0.005     , 0.005     , 0.97124994, 0.005     , 0.005     ], dtype=float32)],
 'final_state': HistoryTrackingGameState(base_state=Connect4State(board=array([[0, 0, 0, 0, 0, 0, 0],
        [

In [12]:
a_row = torch.rand(1, 2)
b_row = torch.rand(1, 2)
print(f'a_row: {a_row}')
print(f'b_row: {b_row}')
print(f'loss: {F.cross_entropy(a_row, b_row)}')

tile_shape = (3, 1)
a_tiled = torch.tile(a_row, tile_shape)
b_tiled = torch.tile(b_row, tile_shape)
print(f'a_tiled: {a_tiled}')
print(f'b_tiled: {b_tiled}')
print(f'loss: {F.cross_entropy(a_tiled, b_tiled)}')

print(f'loss: {F.cross_entropy(a_tiled, b_tiled, reduction="sum")}')


a_row: tensor([[0.4045, 0.5594]])
b_row: tensor([[0.9683, 0.1765]])
loss: 0.8582934141159058
a_tiled: tensor([[0.4045, 0.5594],
        [0.4045, 0.5594],
        [0.4045, 0.5594]])
b_tiled: tensor([[0.9683, 0.1765],
        [0.9683, 0.1765],
        [0.9683, 0.1765]])
loss: 0.858293354511261
loss: 2.5748801231384277
