# Step-by-step run of alphazero self-play & training.


In [1]:
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

# Game and players
from rgi.rgizero.games.connect4 import Connect4Game
from rgi.rgizero.players.alphazero import AlphazeroPlayer
from rgi.rgizero.players.alphazero import play_game

from notebook_utils import reload_local_modules

print("✅ Imports successful")

if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
  device = 'mps'
else:
    device = 'cpu'
print(f'Using device: {device}')
assert device in ('cuda', 'mps'), f"No accelerator available, device={device}"

# Allow asyncio to work with jupyter notebook
import nest_asyncio
nest_asyncio.apply()

# Increase numpy print width
np.set_printoptions(linewidth=300)

DATA_DIR = Path.cwd().parent / "data" / "rgizero-e2e"
os.makedirs(DATA_DIR, exist_ok=True)

MODEL_DIR = Path.cwd().parent / "models" / "rgizero-e2e"
os.makedirs(MODEL_DIR, exist_ok=True)

%load_ext line_profiler

✅ Imports successful
Using device: mps


In [2]:
# --- Run configuration -------------------------------------------------------
DEBUG_MODE = True      # Debugger-friendly settings (e.g. no data-loader worker processes).
LOAD_MODEL = False
MODEL_SIZE = "small"   # One of "tiny", "small", "large" or "xl".
NUM_SIMULATIONS = 50

# --- Self-play ---------------------------------------------------------------
# When PLAY_GAMES is False, previously saved games are still loaded from disk.
PLAY_GAMES = False
NUM_GAMES = 1000
SAVE_FILE_NAME = "train_" + str(NUM_GAMES)

## Step 1: Set up history-wrapped game


In [3]:
from rgi.rgizero.games.history_wrapper import HistoryTrackingGame
from rgi.rgizero.data.trajectory_dataset import Vocab
from rgi.rgizero.common import TOKENS

# Connect 4 with a 7*6 = 42 move upper bound on game length.
base_game = Connect4Game(connect_length=4)
max_game_length = 7 * 6

# Wrap the base game in HistoryTrackingGame and take its initial state.
game = HistoryTrackingGame(base_game)
state_0 = game.initial_state()

# Action vocabulary: the START_OF_GAME token first, followed by every action of the base game.
action_vocab = Vocab(itos=[TOKENS.START_OF_GAME] + list(base_game.all_actions()))

# Both sequence-length settings are the maximum game length plus 2.
block_size = max_game_length + 2
n_max_context = max_game_length + 2

print("✅ Using HistoryTrackingGame from module")
print(f"Game: {base_game.__class__.__name__}, Players: {game.num_players(state_0)}, Actions: {list(game.all_actions())}")

✅ Using HistoryTrackingGame from module
Game: Connect4Game, Players: 2, Actions: [1, 2, 3, 4, 5, 6, 7]


## Step 2: Create or load model.


In [4]:
reload_local_modules(verbose=False)

from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformer, ActionHistoryTransformerEvaluator
from rgi.rgizero.models.transformer import TransformerConfig

# Named model sizes. Each row is (name, n_layer, n_head, n_embd); every size
# shares the same context length, n_max_context.
model_config_dict = {
    size_name: TransformerConfig(n_max_context=n_max_context, n_layer=n_layer, n_head=n_head, n_embd=n_embd)
    for size_name, n_layer, n_head, n_embd in (
        ("tiny", 2, 2, 8),
        ("small", 4, 4, 32),
        ("large", 8, 8, 128),
        ("xl", 16, 16, 256),
    )
}


def create_random_model(config: TransformerConfig, action_vocab_size, num_players,  seed: int):
    """Build a freshly initialised ActionHistoryTransformer with seeded RNGs.

    Seeds numpy and torch with ``seed`` before the model is constructed. When
    CUDA is available it also seeds every CUDA device and switches cuDNN to
    deterministic mode with benchmarking off. These are process-wide settings
    and remain in effect after the call.

    Args:
        config: Transformer configuration handed to the model.
        action_vocab_size: Forwarded as the model's ``action_vocab_size``.
        num_players: Forwarded as the model's ``num_players``.
        seed: Seed applied to every RNG touched here.

    Returns:
        The new model after it has been moved to the notebook-level ``device``.
    """
    # Seed the numpy and torch generators.
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    model = ActionHistoryTransformer(config=config, action_vocab_size=action_vocab_size, num_players=num_players)
    return model.to(device)

# Look up the transformer configuration for the requested size.
model_config = model_config_dict[MODEL_SIZE]

# A fixed seed makes the random initialisation reproducible; loading an existing model is not supported yet.
if not LOAD_MODEL:
    model = create_random_model(
        model_config,
        action_vocab_size=action_vocab.vocab_size,
        num_players=game.num_players(state_0),
        seed=42,
    )
else:
    raise NotImplementedError("Model loading not implemented")

## Step 3: Play games to generate training data


In [5]:
reload_local_modules(verbose=False)
import asyncio
from tqdm.asyncio import tqdm

from rgi.rgizero.models.action_history_transformer import AsyncNetworkEvaluator, QueuedNetworkEvaluator, ActionHistoryTransformerEvaluator
from rgi.rgizero.players.alphazero import play_game_async

async def play_games_async(num_games: int, player_factory):
    """Play ``num_games`` self-play games concurrently and return their results.

    Each game gets its own player from ``player_factory``; that single player
    is used for both seats of the game. All games are awaited together with
    ``tqdm.gather`` (``asyncio.gather`` plus a progress bar).

    Args:
        num_games: Number of games to play.
        player_factory: Zero-argument callable returning a new player.

    Returns:
        The list of game results, one per game. Each result additionally gets
        a ``'time'`` entry holding the elapsed seconds for that game. Because
        the games are awaited concurrently, this includes time spent suspended
        while other games make progress.
    """
    async def _play_one_game():
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments.
        started = time.perf_counter()
        player = player_factory()
        game_result = await play_game_async(game, [player, player])
        game_result['time'] = time.perf_counter() - started
        return game_result

    pending_games = [_play_one_game() for _ in range(num_games)]
    return await tqdm.gather(*pending_games)   # same as asyncio.gather, but with a progress bar


serial_evaluator = ActionHistoryTransformerEvaluator(model, device=device, block_size=block_size, vocab=action_vocab)
async_evaluator = AsyncNetworkEvaluator(base_evaluator=serial_evaluator, max_batch_size=1024, verbose=False)

master_rng = np.random.default_rng(42)
async_evaluator_factory = lambda: AsyncNetworkEvaluator(base_evaluator=serial_evaluator, max_batch_size=1024, verbose=False)
async_player_factory = lambda: AlphazeroPlayer(game, async_evaluator, rng=np.random.default_rng(master_rng.integers(0, 2**31)), add_noise=False, simulations=NUM_SIMULATIONS)

if PLAY_GAMES:
    print(f"Playing {NUM_GAMES} games, simulations={NUM_SIMULATIONS}, model_size={MODEL_SIZE}")
    await async_evaluator.start()
    results = asyncio.run(play_games_async(num_games=NUM_GAMES, player_factory=async_player_factory)) # 50 games, 4.3s
    await async_evaluator.stop()

In [6]:
from collections import defaultdict, Counter

if PLAY_GAMES:
    print("Winner by initial move:")
    # Tally winners, keyed by the first action of each game.
    dd = defaultdict(Counter)
    for result in results:
        first_action = result['action_history'][0]
        dd[first_action][result['winner']] += 1
    # One line per opening action: game count, share of games with winner == 1, raw counts.
    for action in sorted(dd):
        counts = dd[action]
        total = sum(counts.values())
        print(f"  a={action}: n={total:3} win[1]={100*counts[1]/total:.2f}% counts={counts}")

## Results from 1000 connect4 games with random model.
# Winner by initial move:
#   a=1: n=121 win[1]=59.50% counts=Counter({1: 72, 2: 49})
#   a=2: n=139 win[1]=58.99% counts=Counter({1: 82, 2: 57})
#   a=3: n= 95 win[1]=63.16% counts=Counter({1: 60, 2: 35})
#   a=4: n=160 win[1]=78.75% counts=Counter({1: 126, 2: 34})
#   a=5: n=146 win[1]=65.07% counts=Counter({1: 95, 2: 51})
#   a=6: n=176 win[1]=61.36% counts=Counter({1: 108, 2: 68})
#   a=7: n=163 win[1]=53.37% counts=Counter({1: 87, 2: 76})

## Step 4: Confirm we can read & write to trajectory_dataset


In [7]:
# Every action of the game. add_trajectory() reads this as a global when it
# builds its legal-action-index -> vocab-index translation.
all_actions = game.all_actions()
# num_players = game.num_players(state_0)
# Show the action vocabulary.
print(f"Vocab: {action_vocab}")

Vocab: Vocab(vocab_size=8, itos=['START_OF_GAME', 1, 2, 3, 4, 5, 6, 7], stoi={'START_OF_GAME': 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7})


In [8]:
from rgi.rgizero.data.trajectory_dataset import TrajectoryDatasetBuilder, TrajectoryDataset, build_trajectory_loader
reload_local_modules(verbose=False)

def add_trajectory(game_result, vocab, td_builder):
    """Convert one game result into a trajectory and add it to ``td_builder``.

    Reads ``'action_history'``, ``'legal_policies'``, ``'legal_action_idx'`` and
    ``'rewards'`` from ``game_result``. Also reads the notebook-level
    ``all_actions`` global.

    Args:
        game_result: Mapping with the four keys listed above.
        vocab: Vocabulary used to encode actions; supplies ``vocab_size``.
        td_builder: Builder whose ``add_trajectory`` receives the encoded data.
    """
    history = game_result['action_history']
    num_moves = len(history)
    per_move_policies = game_result['legal_policies']
    per_move_legal_idx = game_result['legal_action_idx']
    final_rewards = game_result['rewards']

    # Lookup from a legal-action index to the matching vocab index.
    # assumes vocab.encode returns something that supports array indexing — TODO confirm
    legal_to_vocab = vocab.encode(all_actions)

    # One row per move, one column per vocab entry; entries not assigned below stay 0.
    policy_table = np.zeros((num_moves, vocab.vocab_size))
    for move in range(num_moves):
        vocab_columns = legal_to_vocab[per_move_legal_idx[move]]
        policy_table[move, vocab_columns] = per_move_policies[move]

    td_builder.add_trajectory(
        actions=vocab.encode(history),
        fixed_width_policies=policy_table,
        # Repeat the rewards once per move: shape (num_players,) -> (num_moves, num_players).
        values=np.tile(final_rewards, (num_moves, 1)),
    )

# Only when fresh games were played in this run: add every game result to a new
# TrajectoryDatasetBuilder and save it to DATA_DIR under SAVE_FILE_NAME.
if PLAY_GAMES:
    td_builder = TrajectoryDatasetBuilder(action_vocab)
    for game_result in results:
        add_trajectory(game_result, action_vocab, td_builder)

    td_builder.save(DATA_DIR, SAVE_FILE_NAME)

## Load dataset


In [9]:
# Load the saved trajectories back from disk.
td = TrajectoryDataset(DATA_DIR, SAVE_FILE_NAME, block_size=n_max_context)

# Rebuild the per-opening tally from the stored trajectories so it can be
# compared with the summary printed after self-play.
print("Winner by initial move:")
dd = defaultdict(Counter)
for _td in td:
    action = int(_td.action[0])
    winner = int(_td.value[0][0])  # first value entry of the first move, truncated to int
    dd[action][winner] += 1
for action in sorted(dd):
    counts = dd[action]
    total = sum(counts.values())
    print(f"  a={action}: n={total:3} win[1]={100*counts[1]/total:.2f}% counts={counts}")


Winner by initial move:
  a=1: n=121 win[1]=59.50% counts=Counter({1: 72, 0: 49})
  a=2: n=139 win[1]=58.99% counts=Counter({1: 82, 0: 57})
  a=3: n= 95 win[1]=63.16% counts=Counter({1: 60, 0: 35})
  a=4: n=160 win[1]=78.75% counts=Counter({1: 126, 0: 34})
  a=5: n=146 win[1]=65.07% counts=Counter({1: 95, 0: 51})
  a=6: n=176 win[1]=61.36% counts=Counter({1: 108, 0: 68})
  a=7: n=163 win[1]=53.37% counts=Counter({1: 87, 0: 76})


In [10]:
# No loader workers in debug mode; 4 otherwise.
num_workers = 0 if DEBUG_MODE else 4

# Use SAVE_FILE_NAME rather than a hard-coded 'train_1000' so the loader keeps
# reading the same file as the rest of the notebook when NUM_GAMES changes.
# NOTE(review): batch_size=1 here while TrainConfig uses batch_size=64, and this
# loader is the one handed to Trainer — confirm which batch size training uses.
trajectory_loader = build_trajectory_loader(
    DATA_DIR, SAVE_FILE_NAME, block_size=n_max_context, batch_size=1,
    device=device, workers=num_workers)

# Print the first batch as a sanity check.
for batch in trajectory_loader:
    print(batch)
    break

(tensor([[6, 2, 2, 5, 4, 5, 6, 7, 3, 7, 2, 4, 6, 4, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
       device='mps:0'), tensor([[[0.0000, 0.1200, 0.1600, 0.1000, 0.1600, 0.1400, 0.1600, 0.1600],
         [0.0000, 0.1400, 0.1600, 0.1600, 0.1200, 0.1400, 0.1400, 0.1400],
         [0.0000, 0.1400, 0.1800, 0.1000, 0.1600, 0.1400, 0.1400, 0.1400],
         [0.0000, 0.1400, 0.1800, 0.1600, 0.1200, 0.1200, 0.1200, 0.1600],
         [0.0000, 0.1200, 0.1400, 0.1000, 0.1600, 0.1800, 0.1600, 0.1400],
         [0.0000, 0.1200, 0.1400, 0.1400, 0.1800, 0.1400, 0.1200, 0.1600],
         [0.0000, 0.1200, 0.1400, 0.1000, 0.1800, 0.1800, 0.1400, 0.1400],
         [0.0000, 0.1400, 0.1600, 0.1600, 0.1200, 0.1400, 0.1400, 0.1400],
         [0.0000, 0.1400, 0.1400, 0.1200, 0.1600, 0.1200, 0.1400, 0.1800],
         [0.0000, 0.1400, 0.1600, 0.2000, 0.1000, 0.1400, 0.1200, 0.1400],
         [0.0000, 0.1400, 0.1400, 0.1200, 0.1600, 0.1400, 0.1200, 0.1800]

## Train model


In [11]:
from rgi.rgizero.train import Trainer, TrainConfig

train_config = TrainConfig(
    model_name="connect4-e2e",
    model_version="v1",

    # Evaluation and logging cadence.
    eval_interval=250,    # evaluate often, because overfitting is expected
    eval_iters=200,
    log_interval=10_000,  # keep per-iteration logging sparse

    # Overfitting is expected on this small dataset, so only checkpoint when
    # the validation loss improves.
    always_save_checkpoint=False,

    # Batching.
    gradient_accumulation_steps=1,
    batch_size=64,

    # Optimiser and learning-rate schedule.
    learning_rate=1e-3,   # small networks can afford a somewhat higher rate
    max_iters=5000,
    lr_decay_iters=5000,  # usually kept equal to max_iters
    min_lr=1e-4,          # usually learning_rate / 10
    beta2=0.99,           # a bit larger because each iteration sees few tokens

    warmup_iters=100,     # possibly not strictly necessary
)

# The same loader serves as training and validation data for now, so the
# reported train and validation losses are computed on the same data.
trainer = Trainer(
    model=model,
    train_config=train_config,
    train_loader=trajectory_loader,
    val_loader=trajectory_loader,  # TODO: Create separate validation loader
    device=device,
)

trainer.train()

num decayed parameter tensors: 19, with 50,880 parameters
num non-decayed parameter tensors: 11, with 298 parameters
using fused AdamW: False
Training epoch 0 of 10
step 0: train loss 0.8588, val loss 0.8588
iter 0: loss 0.9355, time 802.88ms
step 250: train loss 0.8253, val loss 0.8253
saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
step 500: train loss 0.8457, val loss 0.8457
step 750: train loss 0.8281, val loss 0.8281
Training epoch 1 of 10
step 1000: train loss 0.8560, val loss 0.8560
step 1250: train loss 0.8222, val loss 0.8222
saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
step 1500: train loss 0.8389, val loss 0.8389
step 1750: train loss 0.8299, val loss 0.8299
Training epoch 2 of 10
step 2000: train loss 0.8492, val loss 0.8492
step 2250: train loss 0.8217, val loss 0.8217
saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
step 2500: train loss 0.8324, val loss 0.8324
step 2750: train loss 0.8265, val loss 0.8265
Training epoc

In [15]:
import dataclasses
# Save model
MODEL_PATH = MODEL_DIR / f"{SAVE_FILE_NAME}-round1.pt"
# NOTE(review): MODEL_PATH is not used below; the checkpoint is written under
# DATA_DIR. It probably belongs under MODEL_DIR — confirm before moving it, since
# other cells may read the current location.
checkpoint_path = DATA_DIR / f"{SAVE_FILE_NAME}-round1-checkpoint.pt"

# Everything needed to rebuild the model: weights, config, vocab and player count,
# plus the trainer's progress counters.
checkpoint = {
    'model': model.state_dict(),
    'model_config': dataclasses.asdict(model.config),
    'vocab': action_vocab.to_dict(),
    'iter_num': trainer.iter_num,
    'best_val_loss': trainer.best_val_loss,
    'num_players': game.num_players(state_0),
}
torch.save(checkpoint, checkpoint_path)

# Round-trip check: rebuild the model from the saved checkpoint alone.
# map_location places the loaded tensors on the current device.
loaded_checkpoint = torch.load(checkpoint_path, map_location=device)
loaded_model = ActionHistoryTransformer(
    config=TransformerConfig(**loaded_checkpoint['model_config']),
    action_vocab_size=loaded_checkpoint['vocab']['vocab_size'],
    num_players=loaded_checkpoint['num_players']
)
# Restore the trained weights; without this the rebuilt model keeps its random
# initialisation and is not the model that was saved.
loaded_model.load_state_dict(loaded_checkpoint['model'])
loaded_model.to(device)


ActionHistoryTransformer(
  (action_embedding): Embedding(8, 32)
  (transformer): Transformer(
    (wpe): Embedding(44, 32)
    (dropout): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-3): 4 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=32, out_features=96, bias=False)
          (c_proj): Linear(in_features=32, out_features=32, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=32, out_features=128, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=128, out_features=32, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
  )
  (ln_f): LayerNorm()
  (policy_value_head): PolicyValueHead(
    (policy): Linear(in_features=32, out_features=8, bias=True)
    (valu