# Step-by-step run of AlphaZero self-play & training.


In [1]:
import os
import time
from pathlib import Path
from collections import defaultdict, Counter
import asyncio
from typing import Callable

import numpy as np
import torch
import torch.nn.functional as F

# Game and players
from rgi.rgizero.games.connect4 import Connect4Game
from rgi.rgizero.players.alphazero import AlphazeroPlayer
from rgi.rgizero.players.alphazero import play_game

from notebook_utils import reload_local_modules

print("✅ Imports successful")

# Pick the accelerator backend: CUDA first, then MPS; 'cpu' is only chosen so
# the assertion below can name it in its error message.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f'Using device: {device}')
# This notebook is not meant to run without an accelerator.
assert device in ('cuda', 'mps'), f"No accelerator available, device={device}"

# Jupyter already runs an asyncio event loop; nest_asyncio patches asyncio so
# that event-loop calls made from notebook cells can nest inside it.
import nest_asyncio
nest_asyncio.apply()

# Widen numpy's print line width so long arrays are not wrapped.
np.set_printoptions(linewidth=300)

%load_ext line_profiler

✅ Imports successful
Using device: mps


In [2]:
# --- Run switches -----------------------------------------------------------
DEBUG_MODE = True     # Debugger-friendly settings: train_model() uses 0 data-loader workers when this is set.
LOAD_MODEL = False    # NOTE(review): not referenced in the cells visible here — confirm it is still used.
TRAIN_MODEL = True    # NOTE(review): not referenced in the cells visible here — confirm it is still used.
MODEL_SIZE = "tiny"  # "tiny" or "small" or "large" or "xl"; this initial value is what ends up in CONFIG_ALIAS.
NUM_SIMULATIONS = 200    # Passed as `simulations=` to each AlphazeroPlayer during self-play.
RUN_GENERATIONS = True   # Gates the self-play + training loop over generations.
RUN_TOURNAMENT = False   # NOTE(review): presumably gates the tournament section further down — confirm.

# NOTE(review): the original comment here read "If False, we still load previous games from disk."
# It does not describe NUM_GAMES (an int); it presumably belonged to a boolean flag — confirm.
NUM_GAMES = 10_000           # Self-play games requested per generation.
MAX_TRAINING_EPOCHS = 10
TRAIN_BATCH_SIZE = 2048
MAX_TRAINING_ITERS = 1_000_000 // TRAIN_BATCH_SIZE   # = 488 with the batch size above.
# CONFIG_ALIAS names the data/model directories. It is built from the values assigned ABOVE this line.
CONFIG_ALIAS = f'trajectory_sims-{NUM_SIMULATIONS}_games-{NUM_GAMES}_size-{MODEL_SIZE}_train-{MAX_TRAINING_ITERS}_x1'
NUM_GENERATIONS = 20

# DEBUG override: MODEL_SIZE / MAX_TRAINING_ITERS / MAX_TRAINING_EPOCHS are reassigned AFTER CONFIG_ALIAS
# was built, so the directory names still say "size-tiny_train-488" while the run itself uses the
# values below. Keep this ordering in mind when comparing runs on disk.
MODEL_SIZE = "small"
MAX_TRAINING_ITERS = 100_000_000 // TRAIN_BATCH_SIZE
MAX_TRAINING_EPOCHS = 10_000

## Step 1: Set up history-wrapped game


In [3]:
from rgi.rgizero.games.history_wrapper import HistoryTrackingGame
from rgi.rgizero.data.trajectory_dataset import Vocab
from rgi.rgizero.common import TOKENS

# Base game plus the longest possible game length (7 * 6 = 42 moves).
base_game = Connect4Game(connect_length=4)
max_game_length = 7 * 6
game_name = base_game.__class__.__name__

# Wrap the base game in HistoryTrackingGame and derive the action vocabulary:
# index 0 is the start-of-game token, followed by every action of the game.
game = HistoryTrackingGame(base_game)
state_0 = game.initial_state()
all_actions = game.all_actions()
action_vocab = Vocab(itos=[TOKENS.START_OF_GAME, *all_actions])

# Context length used by both the evaluator (block_size) and the model/dataset (n_max_context).
block_size = n_max_context = max_game_length + 2

print("✅ Using HistoryTrackingGame from module")
print(f"Game: {game_name}, Players: {game.num_players(state_0)}, Actions: {list(game.all_actions())}")

# Trajectories and checkpoints live in sibling trees under the parent of the working directory,
# keyed by game name and CONFIG_ALIAS.
DATA_DIR = Path.cwd().parent.joinpath("data", "rgizero-e2e", game_name, CONFIG_ALIAS)
print("Creating data dir: ", DATA_DIR)
DATA_DIR.mkdir(parents=True, exist_ok=True)

MODEL_DIR = Path.cwd().parent.joinpath("models", "rgizero-e2e", game_name, CONFIG_ALIAS)
print("Creating model dir: ", MODEL_DIR)
MODEL_DIR.mkdir(parents=True, exist_ok=True)


✅ Using HistoryTrackingGame from module
Game: Connect4Game, Players: 2, Actions: [1, 2, 3, 4, 5, 6, 7]
Creating data dir:  /Users/rodo/src/rgi3/data/rgizero-e2e/Connect4Game/trajectory_sims-200_games-10000_size-tiny_train-488_x1
Creating model dir:  /Users/rodo/src/rgi3/models/rgizero-e2e/Connect4Game/trajectory_sims-200_games-10000_size-tiny_train-488_x1


## Step 2: Create random generation_0 model


In [4]:
def print_dataset_stats(dataset_path, split_name):
    """Print summary statistics for one saved trajectory split.

    Loads ``split_name`` from ``dataset_path.parent`` and prints trajectory
    counts and lengths, overall win/draw percentages, and the same percentages
    grouped by the first recorded action of each trajectory.

    Args:
        dataset_path: Path of the split on disk; only its parent directory is
            passed to ``TrajectoryDataset``.
        split_name: Name of the split to load, e.g. ``"gen-1"``.

    An empty split prints the basic counts and returns early instead of
    failing on ``min()`` / ``max()`` of an empty list.
    """
    td = TrajectoryDataset(dataset_path.parent, split_name, block_size=n_max_context)

    num_trajectories = len(td)
    total_actions = td._num_actions

    trajectory_lengths = []
    winners = []
    first_moves = []

    for i in range(num_trajectories):
        start_idx = td.boundaries[i]
        end_idx = td.boundaries[i + 1]
        trajectory_lengths.append(end_idx - start_idx)

        # One value per player. add_trajectory() tiles the same reward vector
        # over every step of a game, so the first step's row is representative.
        # Only players 1 and 2 (indices 0 and 1) are compared here.
        final_values = td.value_data[start_idx]
        if final_values[0] > final_values[1]:
            winners.append(1)
        elif final_values[1] > final_values[0]:
            winners.append(2)
        else:
            winners.append(None)  # Draw

        # First stored action of the trajectory, decoded back through the vocab.
        first_action_encoded = td.action_data[start_idx]
        first_moves.append(action_vocab.decode([first_action_encoded])[0])

    print("Dataset Stats:")
    print(f"  Trajectories: {num_trajectories}")
    print(f"  Total actions: {total_actions}")
    if num_trajectories == 0:
        print("  (no trajectories; skipping length and winner stats)")
        return

    avg_trajectory_length = total_actions / num_trajectories
    print(f"  Avg trajectory length: {avg_trajectory_length:.2f}")
    print(f"  Trajectory length - min: {min(trajectory_lengths)}, max: {max(trajectory_lengths)}, mean: {np.mean(trajectory_lengths):.2f}")

    # Overall winner stats, in the same shape as print_game_stats (including draws).
    print("Winner Stats:")
    winner_stats = Counter(winners)
    win1_pct = 100 * winner_stats[1] / num_trajectories
    win2_pct = 100 * winner_stats[2] / num_trajectories
    draw_pct = 100 * winner_stats[None] / num_trajectories
    print(f"  Winner counts: win[1]={win1_pct:.2f}% win[2]={win2_pct:.2f}%, n={num_trajectories}, draw={draw_pct:.2f}%")

    # Winner stats grouped by the first move of each trajectory.
    print("Winner Stats by initial move:")
    move_stats = defaultdict(Counter)
    for first_move, winner in zip(first_moves, winners):
        move_stats[first_move][winner] += 1

    for action in sorted(move_stats.keys()):
        counts = move_stats[action]
        total = sum(counts.values())  # always >= 1: every key was incremented at least once
        win1_pct = 100 * counts[1] / total
        win2_pct = 100 * counts[2] / total
        draw_pct = 100 * counts[None] / total
        print(f"  a={action}: n={total:3} win[1]={win1_pct:.2f}% counts={counts}, win[2]={win2_pct:.2f}% draw={draw_pct:.2f}%")

def print_model_stats(model, config_alias=""):
    """Print the config and parameter counts of ``model``.

    Args:
        model: Object exposing ``parameters()`` and a ``config`` attribute.
        config_alias: Optional label; printed as a final line when non-empty.
    """
    # Collect (size, trainable?) once, then derive both totals from it.
    param_info = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total_count = sum(size for size, _ in param_info)
    trainable_count = sum(size for size, needs_grad in param_info if needs_grad)

    print("Model Stats:")
    print(f"  Config: {model.config}")
    print(f"  Total parameters: {total_count:,}")
    print(f"  Trainable parameters: {trainable_count:,}")
    if not config_alias:
        return
    print(f"  Config alias: {config_alias}")


In [5]:
reload_local_modules(verbose=False)

from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformer
from rgi.rgizero.models.transformer import TransformerConfig

# Architecture presets keyed by MODEL_SIZE. Each entry is (n_layer, n_head, n_embd);
# every preset shares the same n_max_context.
model_config_dict = {
    size_name: TransformerConfig(n_max_context=n_max_context, n_layer=n_layer, n_head=n_head, n_embd=n_embd)
    for size_name, (n_layer, n_head, n_embd) in {
        "tiny": (2, 2, 8),
        "small": (4, 4, 32),
        "large": (8, 8, 128),
        "xl": (16, 16, 256),
    }.items()
}


def create_random_model(config: TransformerConfig, action_vocab_size, num_players,  seed: int):
    """Build a freshly initialised ActionHistoryTransformer on ``device``.

    The numpy and torch RNGs are seeded with ``seed`` before the model is
    constructed; when CUDA is available the CUDA RNGs are seeded too and cuDNN
    is switched to deterministic, non-benchmark mode.

    Args:
        config: Transformer architecture configuration.
        action_vocab_size: Size of the action vocabulary.
        num_players: Number of players in the game.
        seed: Seed applied to every RNG touched here.

    Returns:
        The new model, moved to the module-level ``device``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    fresh_model = ActionHistoryTransformer(
        config=config,
        action_vocab_size=action_vocab_size,
        num_players=num_players,
    )
    fresh_model.to(device)
    return fresh_model

# Select the architecture preset for the configured MODEL_SIZE.
model_config = model_config_dict[MODEL_SIZE]
# NOTE(review): an earlier TODO here said "Use MODEL_SIZE!" — the line above already does.
# The commented-out override is kept for reference only:
# model_config = model_config_dict["small"] # Override to see if we can fit better.
# Generation-0 model: random weights, created with a fixed seed so initialization is deterministic.
model_0 = create_random_model(model_config, action_vocab_size=action_vocab.vocab_size, num_players=game.num_players(state_0), seed=42)

## Step 3: Define play & generation code


In [6]:
from rgi.rgizero.models.action_history_transformer import AsyncNetworkEvaluator, ActionHistoryTransformerEvaluator
from rgi.rgizero.players.alphazero import play_game_async
from tqdm.asyncio import tqdm

async def play_games_async(num_games: int, player_factory: Callable[[], AlphazeroPlayer], max_concurrent_games: int = 1000):
    """Play ``num_games`` self-play games concurrently and return their results.

    Each game gets its own player from ``player_factory``; that single player
    plays both sides of the module-level ``game``.

    Args:
        num_games: Number of games to play.
        player_factory: Zero-argument callable returning a new player.
        max_concurrent_games: Upper bound on games in flight at once, enforced
            with a semaphore.

    Returns:
        List of game-result dicts as returned by ``play_game_async``, each with
        an extra ``'time'`` key holding the elapsed seconds for that game
        (measured from after the semaphore is acquired).
    """
    sem = asyncio.Semaphore(max_concurrent_games)

    async def create_player_and_create_game():
        async with sem:
            # perf_counter is monotonic, so the duration cannot be skewed by
            # system clock adjustments the way time.time() differences can.
            t0 = time.perf_counter()
            player = player_factory()
            game_result = await play_game_async(game, [player, player])
            game_result['time'] = time.perf_counter() - t0
            return game_result

    tasks = [create_player_and_create_game() for _ in range(num_games)]
    return await tqdm.gather(*tasks)   # same as asyncio.gather, but with a progress bar

async def play_generation_async(model, num_games, simulations=NUM_SIMULATIONS, max_concurrent_games=1024):
    """Run one generation of self-play with ``model`` and return the game results.

    A single batched ``AsyncNetworkEvaluator`` wrapping ``model`` is shared by
    all games; each game gets its own ``AlphazeroPlayer`` with a per-player RNG
    seeded from a fixed master RNG (seed 42).

    Args:
        model: Network used to evaluate positions.
        num_games: Number of self-play games to play.
        simulations: Passed to each ``AlphazeroPlayer`` as ``simulations``.
        max_concurrent_games: Concurrency limit for games; also used as the
            evaluator's ``max_batch_size``.

    Returns:
        The list of results from ``play_games_async``.
    """
    serial_evaluator = ActionHistoryTransformerEvaluator(model, device=device, block_size=block_size, vocab=action_vocab)
    async_evaluator = AsyncNetworkEvaluator(base_evaluator=serial_evaluator, max_batch_size=max_concurrent_games, verbose=False)

    master_rng = np.random.default_rng(42)

    def async_player_factory():
        # Each player draws its own seed from the master RNG.
        player_rng = np.random.default_rng(master_rng.integers(0, 2**31))
        return AlphazeroPlayer(game, async_evaluator, rng=player_rng, add_noise=False, simulations=simulations)

    await async_evaluator.start()
    try:
        return await play_games_async(num_games=num_games, player_factory=async_player_factory, max_concurrent_games=max_concurrent_games)
    finally:
        # Always stop the evaluator, even if a game raised or the cell was
        # interrupted, so it is not left running between cells.
        await async_evaluator.stop()

def print_game_stats(results):
    """Print winner and game-length statistics for a batch of game results.

    Args:
        results: Iterable of dicts with at least ``'winner'`` (1, 2 or ``None``
            for a draw) and ``'action_history'`` (non-empty sequence of actions).

    With no results, only the header and a short notice are printed instead of
    raising ``ZeroDivisionError`` / ``ValueError``.
    """
    results = list(results)
    num_games = len(results)

    print("Winner Stats:")
    if num_games == 0:
        print("  (no games to summarize)")
        return

    winner_stats = Counter(result['winner'] for result in results)
    print(f"Winner counts: win[1]={100*winner_stats[1]/num_games:.2f}% win[2]={100*winner_stats[2]/num_games:.2f}%, n={num_games}, draw={100*winner_stats[None]/num_games:.2f}%")

    game_lengths = [len(result['action_history']) for result in results]
    print(f"Game Length min: {min(game_lengths)}, max: {max(game_lengths)}, mean: {np.mean(game_lengths):.2f}")

    # Winner counts grouped by the first action of each game.
    print("Winner Stats by initial move:")
    by_first_move = defaultdict(Counter)
    for result in results:
        by_first_move[result['action_history'][0]][result['winner']] += 1
    for action in sorted(by_first_move):
        counts = by_first_move[action]
        total = sum(counts.values())
        print(f"  a={action}: n={total:3} win[1]={100*counts[1]/total:.2f}% counts={counts}, win[2]={100*counts[2]/total:.2f}% draw={100*counts[None]/total:.2f}%")

## Step 4: Confirm we can read & write to trajectory_dataset


In [7]:
from rgi.rgizero.data.trajectory_dataset import TrajectoryDatasetBuilder, TrajectoryDataset, build_trajectory_loader
reload_local_modules(verbose=False)

def add_trajectory(game_result, vocab, td_builder):
    """Convert one game result into dataset form and append it to ``td_builder``.

    Args:
        game_result: Dict with ``'action_history'``, ``'legal_policies'``,
            ``'legal_action_idx'`` and ``'rewards'``.
        vocab: Action vocabulary providing ``encode`` and ``vocab_size``.
        td_builder: Builder whose ``add_trajectory`` receives the encoded
            actions, per-step policies and per-step values.
    """
    history = game_result['action_history']
    num_steps = len(history)
    per_step_policies = game_result['legal_policies']
    per_step_legal_idx = game_result['legal_action_idx']

    # Lookup from an index into all_actions to the matching vocab index.
    vocab_idx_of_action = vocab.encode(all_actions)

    # Scatter each step's legal-move policy into a vocab-wide row; entries for
    # actions that were not legal at that step stay 0.
    policy_matrix = np.zeros((num_steps, vocab.vocab_size))
    for step in range(num_steps):
        columns = vocab_idx_of_action[per_step_legal_idx[step]]
        policy_matrix[step, columns] = per_step_policies[step]

    # Repeat the game's reward vector for every step: (num_players,) -> (num_moves, num_players).
    td_builder.add_trajectory(
        actions=vocab.encode(history),
        fixed_width_policies=policy_matrix,
        values=np.tile(game_result['rewards'], (num_steps, 1)),
    )

def write_trajectory_dataset(results, action_vocab, generation_id):
    """Save all game results of one generation as split ``gen-<generation_id>``.

    Args:
        results: Iterable of game-result dicts accepted by ``add_trajectory``.
        action_vocab: Vocabulary used to encode actions.
        generation_id: Generation number used in the split name.

    Returns:
        Whatever ``TrajectoryDatasetBuilder.save`` returns for the split
        written under ``DATA_DIR``.
    """
    builder = TrajectoryDatasetBuilder(action_vocab)
    for one_game in results:
        add_trajectory(one_game, action_vocab, builder)
    return builder.save(DATA_DIR, f"gen-{generation_id}")

## Train model


In [8]:
from rgi.rgizero.train import Trainer, TrainConfig

# Peak learning rate; min_lr below is derived from it.
LEARNING_RATE = 0.05

train_config = TrainConfig(
    model_name="connect4-e2e",
    model_version="v1",

    # Evaluation / logging cadence, in iterations.
    eval_interval=1000,
    eval_iters=20,
    log_interval=100,

    # Checkpointing: not forced on every evaluation. The original note says
    # checkpoints are then only saved when the validation loss improves.
    always_save_checkpoint=False,

    # Batch shape.
    gradient_accumulation_steps=1,
    batch_size=TRAIN_BATCH_SIZE,

    # Training length limits.
    max_epochs=MAX_TRAINING_EPOCHS,
    max_iters=MAX_TRAINING_ITERS,

    # Learning-rate schedule: no warmup, decay over the full run down to LEARNING_RATE / 10.
    learning_rate=LEARNING_RATE,
    lr_decay_iters=MAX_TRAINING_ITERS,
    min_lr=LEARNING_RATE / 10,
    warmup_iters=0,

    # Optimizer.
    beta2=0.99,
)

def train_model(model, training_splits, train_config):
    """Train ``model`` on the given dataset splits.

    Args:
        model: Model handed to the ``Trainer``; the same object is returned.
        training_splits: Split names under ``DATA_DIR`` to train on.
        train_config: ``TrainConfig`` controlling the run; its ``batch_size``
            is also used for the data loader.

    Returns:
        ``(model, trainer)`` after ``trainer.train()`` has finished.
    """
    # A single shuffled loader is used for both training and validation.
    # TODO: Create separate validation loader
    loader = build_trajectory_loader(
        DATA_DIR,
        training_splits,
        block_size=n_max_context,
        batch_size=train_config.batch_size,
        device=device,
        workers=0 if DEBUG_MODE else 4,  # no worker processes while debugging
        shuffle=True,
    )

    trainer = Trainer(
        model=model,
        train_config=train_config,
        train_loader=loader,
        val_loader=loader,
        device=device,
    )
    trainer.train()
    return model, trainer

In [9]:
import dataclasses

def get_model_path(generation_id):
    """Return the checkpoint path for ``generation_id`` inside ``MODEL_DIR``."""
    checkpoint_name = f"gen-{generation_id}.pt"
    return MODEL_DIR / checkpoint_name

def save_model(model, trainer, generation_id):
    """Write a checkpoint for ``generation_id`` and return its path.

    The checkpoint holds the model weights, the model config as a plain dict,
    the action vocabulary, the trainer's iteration counter and best validation
    loss, and the number of players. This is what ``load_model`` reads back.
    """
    destination = get_model_path(generation_id)
    torch.save(
        dict(
            model=model.state_dict(),
            model_config=dataclasses.asdict(model.config),
            vocab=action_vocab.to_dict(),
            iter_num=trainer.iter_num,
            best_val_loss=trainer.best_val_loss,
            num_players=game.num_players(state_0),
        ),
        destination,
    )
    return destination

def load_model(generation_id):
    """Rebuild the model saved for ``generation_id`` and move it to ``device``.

    Args:
        generation_id: Generation whose checkpoint (see ``get_model_path``)
            should be loaded.

    Returns:
        An ``ActionHistoryTransformer`` constructed from the checkpoint's
        stored config, vocab size and player count, with the saved weights
        loaded, placed on the module-level ``device``.
    """
    model_path = get_model_path(generation_id)
    # Deserialize onto CPU first: without map_location, tensors are restored to
    # the device they were saved from, which fails when the checkpoint was
    # written on a different accelerator (e.g. cuda) than the one available
    # here (e.g. mps). The model is moved to `device` below.
    loaded_checkpoint = torch.load(model_path, map_location="cpu")
    loaded_model = ActionHistoryTransformer(
        config=TransformerConfig(**loaded_checkpoint['model_config']),
        action_vocab_size=loaded_checkpoint['vocab']['vocab_size'],
        num_players=loaded_checkpoint['num_players']
    )
    loaded_model.load_state_dict(loaded_checkpoint['model'])
    loaded_model.to(device)
    return loaded_model


In [10]:
# Do a single generation of play & train
async def run_generation(model, num_games, num_simulations, generation_id):
    """Run one generation: self-play with ``model``, then train the next model.

    Both stages are cached on disk. If the trajectory split for this generation
    already exists it is only summarized, and if the checkpoint already exists
    it is loaded instead of trained.

    Args:
        model: Model from the previous generation, used for self-play and as
            the starting point for training. It is not modified.
        num_games: Number of self-play games to play.
        num_simulations: Simulations per move for the self-play players.
        generation_id: 1-based generation number; names the split
            (``gen-<id>``) and the checkpoint.

    Returns:
        ``(results, trajectory_path, updated_model)``. ``results`` is ``None``
        when the trajectories were already on disk.
    """
    import copy

    print(f"\n\n## Running generation {generation_id} for config_alias={CONFIG_ALIAS}")
    split_name = f"gen-{generation_id}"
    expected_trajectory_path = DATA_DIR / split_name
    if not expected_trajectory_path.exists():
        print(f"Playing {num_games} games, simulations={num_simulations}, model_size={MODEL_SIZE}")
        # Use the arguments rather than the NUM_GAMES / NUM_SIMULATIONS globals,
        # so the run matches what the caller asked for (and what is printed above).
        results = await play_generation_async(model, num_games=num_games, simulations=num_simulations)
        print_game_stats(results)
        trajectory_path = write_trajectory_dataset(results, action_vocab, generation_id)
        assert trajectory_path == expected_trajectory_path
    else:
        print(f"Loading trajectory from {expected_trajectory_path}")
        print_dataset_stats(expected_trajectory_path, split_name)
        trajectory_path = expected_trajectory_path
        results = None

    model_path = get_model_path(generation_id)
    if not model_path.exists():
        print(f"Training model on {split_name}")
        # Train on every generation's data so far, not only the newest split.
        training_splits = [f"gen-{i}" for i in range(1, generation_id+1)]
        # TODO: We're continuing training from the previous model here ... should we train a new model from scratch?
        print(train_config)
        # train_model() returns the very object it was given, so train a copy:
        # otherwise every generation would alias one model and earlier
        # generations could not be compared against later ones.
        updated_model, trainer = train_model(copy.deepcopy(model), training_splits, train_config)
        save_model(updated_model, trainer, generation_id)
    else:
        print(f"Loading model from {model_path}")
        updated_model = load_model(generation_id)
        print_model_stats(updated_model, config_alias=MODEL_SIZE)

    return results, trajectory_path, updated_model

In [11]:
results_dict = {}
trajectory_paths_dict = {}
model_dict = {0: model_0}

current_model = model_dict[0]
if RUN_GENERATIONS:
    for generation_id in range(1, NUM_GENERATIONS+1):
        current_model = model_dict[generation_id-1]
        results_i, trajectory_path_i, model_i = await run_generation(current_model, num_games=NUM_GAMES, num_simulations=NUM_SIMULATIONS, generation_id=generation_id)
        results_dict[generation_id] = results_i
        trajectory_paths_dict[generation_id] = trajectory_path_i
        model_dict[generation_id] = model_i

## refactor, learning_rate = 0.05, warmup_iters=0
# step 0: train loss 2.7801, val loss 2.7801
# iter 0/5/488: loss 2.7801, time 611.02ms
# iter 100/105/488: loss 2.5840, time 63.73ms
# iter 200/205/488: loss 2.5958, time 62.98ms
# iter 300/305/488: loss 2.5835, time 60.15ms
# iter 400/405/488: loss 2.5793, time 63.62ms

# ## model = small
# step 0: train loss 2.7741, val loss 2.7741
# iter 0/5/488: loss 2.7743, time 1624.89ms
# iter 100/105/488: loss 2.6157, time 141.39ms
# iter 200/205/488: loss 2.6120, time 161.22ms
# iter 300/305/488: loss 2.5983, time 203.82ms


# using fused AdamW: False
# step 0: train loss 2.7801, val loss 2.7801
# iter 0/5/4882: loss 2.7801, time 1422.53ms
# iter 100/105/4882: loss 2.5970, time 110.71ms
# iter 200/205/4882: loss 2.5962, time 116.96ms
# iter 300/305/4882: loss 2.5917, time 160.95ms
# iter 400/405/4882: loss 2.5885, time 63.37ms
# iter 500/505/4882: loss 2.5912, time 65.25ms
# iter 600/605/4882: loss 2.6000, time 67.49ms
# iter 700/705/4882: loss 2.5780, time 61.43ms
# iter 800/805/4882: loss 2.5864, time 265.56ms
# iter 900/905/4882: loss 2.5857, time 263.09ms
# step 1000: train loss 2.5849, val loss 2.5847
# saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
# iter 1000/1005/4882: loss 2.5844, time 1812.58ms
# iter 1100/1105/4882: loss 2.5832, time 62.89ms
# iter 1200/1205/4882: loss 2.5743, time 95.25ms
# iter 1300/1305/4882: loss 2.5720, time 324.18ms
# iter 1400/1405/4882: loss 2.5880, time 73.66ms
# iter 1500/1505/4882: loss 2.5745, time 295.39ms
# iter 1600/1605/4882: loss 2.5726, time 76.05ms
# iter 1700/1705/4882: loss 2.5670, time 63.20ms
# iter 1800/1805/4882: loss 2.5720, time 62.66ms
# iter 1900/1905/4882: loss 2.5694, time 449.06ms
# step 2000: train loss 2.5806, val loss 2.5806
# saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
# iter 2000/2005/4882: loss 2.5893, time 920.12ms
# iter 2100/2105/4882: loss 2.5686, time 430.53ms
# iter 2200/2205/4882: loss 2.5741, time 63.05ms
# iter 2300/2305/4882: loss 2.5679, time 60.90ms
# iter 2400/2405/4882: loss 2.5754, time 69.07ms
# iter 2500/2505/4882: loss 2.5673, time 68.33ms
# iter 2600/2605/4882: loss 2.5648, time 66.26ms
# iter 2700/2705/4882: loss 2.5622, time 69.76ms
# iter 2800/2805/4882: loss 2.5541, time 143.65ms
# iter 2900/2905/4882: loss 2.5634, time 66.40ms
# step 3000: train loss 2.5550, val loss 2.5547
# saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
# iter 3000/3005/4882: loss 2.5545, time 975.15ms
# iter 3100/3105/4882: loss 2.5594, time 63.55ms
# iter 3200/3205/4882: loss 2.5499, time 64.17ms
# iter 3300/3305/4882: loss 2.5481, time 70.28ms
# iter 3400/3405/4882: loss 2.5565, time 73.58ms
# iter 3500/3505/4882: loss 2.5602, time 72.22ms
# iter 3600/3605/4882: loss 2.5429, time 88.68ms
# iter 3700/3705/4882: loss 2.5259, time 63.15ms
# iter 3800/3805/4882: loss 2.5346, time 66.07ms
# iter 3900/3905/4882: loss 2.5386, time 73.50ms
# step 4000: train loss 2.5350, val loss 2.5345
# saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
# iter 4000/4005/4882: loss 2.5424, time 1217.41ms
# iter 4100/4105/4882: loss 2.5290, time 101.01ms
# iter 4200/4205/4882: loss 2.5323, time 61.94ms
# iter 4300/4305/4882: loss 2.5250, time 72.57ms
# iter 4400/4405/4882: loss 2.5243, time 68.38ms
# iter 4500/4505/4882: loss 2.5331, time 73.33ms
# iter 4600/4605/4882: loss 2.5246, time 101.00ms
# iter 4700/4705/4882: loss 2.5336, time 67.27ms
# iter 4800/4805/4882: loss 2.5170, time 79.40ms



## Running generation 1 for config_alias=trajectory_sims-200_games-10000_size-tiny_train-488_x1
Loading trajectory from /Users/rodo/src/rgi3/data/rgizero-e2e/Connect4Game/trajectory_sims-200_games-10000_size-tiny_train-488_x1/gen-1
Dataset Stats:
  Trajectories: 10000
  Total actions: 140493
  Avg trajectory length: 14.05
  Trajectory length - min: 7, max: 42, mean: 14.05
Winner Stats:
  Winner counts: win[1]=61.93% win[2]=38.06%, n=10000
Winner Stats by initial move:
  a=1: n=1564 win[1]=51.92% counts=Counter({1: 812, 2: 752}), win[2]=48.08% draw=0.00%
  a=2: n=1344 win[1]=57.37% counts=Counter({1: 771, 2: 573}), win[2]=42.63% draw=0.00%
  a=3: n=1297 win[1]=65.23% counts=Counter({1: 846, 2: 451}), win[2]=34.77% draw=0.00%
  a=4: n=1493 win[1]=76.36% counts=Counter({1: 1140, 2: 353}), win[2]=23.64% draw=0.00%
  a=5: n=1525 win[1]=66.49% counts=Counter({1: 1014, 2: 511}), win[2]=33.51% draw=0.00%
  a=6: n=1401 win[1]=62.03% counts=Counter({1: 869, 2: 532}), win[2]=37.97% draw=0.00%
 

# Sanity check models


In [12]:
# Smoke test: play a single self-play game with current_model.
# play_generation_async returns a list of game results, so `result` is a one-element list.
result = await play_generation_async(current_model, num_games=1, simulations=NUM_SIMULATIONS)

100%|██████████| 1/1 [00:05<00:00,  5.32s/it]


In [13]:
# Inspect training data: open the saved split of every generation (gen-1 .. gen-NUM_GENERATIONS) from DATA_DIR.
# NOTE(review): presumably fails if any of those splits has not been written yet — confirm.
td_array = [TrajectoryDataset(DATA_DIR, f"gen-{generation_id}", block_size=n_max_context) for generation_id in range(1, NUM_GENERATIONS+1)]

In [14]:
# Flatten all generations into (generation_number, sample) pairs; generation numbers are 1-based.
unrolled = [(generation+1, d) for generation, td in enumerate(td_array) for d in td]

# dd[action_prefix][g] accumulates d.value[0] over every sample whose action
# sequence starts with action_prefix, where g is either a generation number or
# '*' for all generations combined. New entries start from a [0., 0.] tensor.
dd = defaultdict(lambda: defaultdict(lambda: torch.tensor([0., 0.])))

for gen, d in unrolled:
    for g in ('*', gen):
        # Prefixes of length 1 through 10.
        # NOTE(review): if a sample's action sequence is shorter than prefix_len,
        # the slice is the whole sequence and that key is incremented again —
        # presumably sequences are padded to the block size; confirm.
        for prefix_len in range(1, 11):
            action_prefix = tuple(d.action[:prefix_len].tolist())
            dd[action_prefix][g] += d.value[0]

print(f"len(dd) = {len(dd)}")


len(dd) = 352822


In [15]:
def eval_prefix(model, game, prefix):
    """Evaluate ``model`` on the position reached by playing ``prefix``.

    Args:
        model: Network wrapped in an ``ActionHistoryTransformerEvaluator``.
        game: Game used to build the state and list legal actions.
        prefix: Sequence of actions applied in order from the initial state.

    Returns:
        The evaluator's result for the resulting state and its legal actions.
    """
    evaluator = ActionHistoryTransformerEvaluator(model, device=device, block_size=block_size, vocab=action_vocab)
    position = game.initial_state()
    for move in prefix:
        position = game.next_state(position, move)
    return evaluator.evaluate(game, position, game.legal_actions(position))


In [16]:
## Something is borked? Player 1's win percentage should be much higher??
def compare_model_vs_data(model, game, dd):
    """Print empirical outcome sums next to the model's value estimate.

    For the 20 action prefixes with the largest all-generation ('*') totals in
    ``dd``, prints the per-generation accumulated values and the share held by
    index 0, followed by ``model``'s player values for that position rescaled
    with ``(v + 1) / 2``.

    Args:
        model: Model to evaluate via ``eval_prefix``.
        game: Game used to replay each prefix.
        dd: Mapping ``prefix -> {generation or '*': tensor of accumulated values}``.
            Each prefix is a tuple of encoded actions whose first element must
            be 0.
    """
    # (A discarded expression that indexed the 11th entry of dd used to sit
    # here; it had no effect but raised IndexError for small dd.)
    top_k = sorted(dd.items(), key=lambda kv: kv[1]['*'].sum(), reverse=True)[:20]
    prefix_list = sorted(prefix for prefix, _ in top_k)

    for prefix in prefix_list:
        print(f"\nprefix={prefix}")
        for gen, counts in dd[prefix].items():
            total = sum(counts)
            print(f"gen={gen}: {counts}, win_pct={100*counts[0]/total:.2f}%, sum={total}")
        # Drop the leading 0 entry before replaying; the remaining entries are
        # passed to the game as actions.
        assert prefix[0] == 0
        actions = prefix[1:]
        eval_result = eval_prefix(model, game, actions)
        print(f'player_probs={(eval_result.player_values+1)/2}')

# Compare current_model's value estimates against the aggregated self-play outcomes in dd.
compare_model_vs_data(current_model, game, dd)



prefix=(0,)
gen=*: tensor([126100.5000,  73899.5000]), win_pct=63.05%, sum=200000.0
gen=1: tensor([6193.5000, 3806.5000]), win_pct=61.94%, sum=10000.0
gen=2: tensor([4645., 5355.]), win_pct=46.45%, sum=10000.0
gen=3: tensor([5202.5000, 4797.5000]), win_pct=52.03%, sum=10000.0
gen=4: tensor([4570., 5430.]), win_pct=45.70%, sum=10000.0
gen=5: tensor([5839., 4161.]), win_pct=58.39%, sum=10000.0
gen=6: tensor([6709., 3291.]), win_pct=67.09%, sum=10000.0
gen=7: tensor([6454.5000, 3545.5000]), win_pct=64.54%, sum=10000.0
gen=8: tensor([5073., 4927.]), win_pct=50.73%, sum=10000.0
gen=9: tensor([5754.5000, 4245.5000]), win_pct=57.54%, sum=10000.0
gen=10: tensor([4074., 5926.]), win_pct=40.74%, sum=10000.0
gen=11: tensor([8811., 1189.]), win_pct=88.11%, sum=10000.0
gen=12: tensor([8075., 1925.]), win_pct=80.75%, sum=10000.0
gen=13: tensor([7093.5000, 2906.5000]), win_pct=70.93%, sum=10000.0
gen=14: tensor([5778.5000, 4221.5000]), win_pct=57.78%, sum=10000.0
gen=15: tensor([8239., 1761.]), win_

In [17]:
# Rebuild a fresh, untrained generation-0 model with the same config and seed as the original model_0.
# NOTE(review): presumably needed because the original model_0 object is handed to train_model()
# during generation 1 — confirm.
model_0 = create_random_model(model_config, action_vocab_size=action_vocab.vocab_size, num_players=game.num_players(state_0), seed=42)
if RUN_GENERATIONS:
    # Generation 1 as saved on disk. model_1 is only defined when RUN_GENERATIONS is set.
    model_1 = load_model(1)


In [18]:
# Untrained baseline: show its action-embedding weights and how its value estimates compare with the data.
print("\n\n### Model 0")
print(model_0.action_embedding.weight)
compare_model_vs_data(model_0, game, dd)



### Model 0
Parameter containing:
tensor([[-0.0504,  0.0093, -0.0226,  0.0165,  0.0072, -0.0277,  0.0177, -0.0237,
          0.0188, -0.0143,  0.0269, -0.0013, -0.0198,  0.0144, -0.0243,  0.0237,
         -0.0140,  0.0181,  0.0318, -0.0157,  0.0148, -0.0034, -0.0324,  0.0208,
         -0.0291, -0.0305,  0.0342, -0.0172,  0.0260, -0.0208, -0.0124,  0.0096],
        [-0.0126,  0.0072,  0.0104, -0.0769, -0.0078, -0.0041, -0.0216,  0.0517,
          0.0385, -0.0091, -0.0166,  0.0048,  0.0219, -0.0037, -0.0307,  0.0049,
         -0.0108,  0.0062,  0.0002, -0.0093, -0.0116, -0.0281, -0.0046, -0.0189,
         -0.0464, -0.0034, -0.0101, -0.0152, -0.0383,  0.0229, -0.0058, -0.0105],
        [-0.0208, -0.0252,  0.0038, -0.0199, -0.0087, -0.0136, -0.0072,  0.0004,
         -0.0083,  0.0022, -0.0049,  0.0290,  0.0091, -0.0186, -0.0067,  0.0133,
         -0.0194, -0.0004,  0.0097,  0.0346, -0.0028, -0.0301,  0.0015,  0.0175,
         -0.0117,  0.0087,  0.0185, -0.0033,  0.0177,  0.0148,  0.0138,

In [19]:
# Same inspection for the generation-1 model; guarded because model_1 only exists when RUN_GENERATIONS is set.
if RUN_GENERATIONS:
    print("\n\n### Model 1")
    print(model_1.action_embedding.weight)
    compare_model_vs_data(model_1, game, dd)



### Model 1
Parameter containing:
tensor([[ 0.1855, -0.1236,  0.0325, -0.2279, -0.1312,  0.2044, -0.0552,  0.0015,
          0.0456, -0.0645, -0.2379,  0.0170,  0.0728,  0.1388, -0.0804,  0.0047,
          0.2121, -0.1895,  0.0768,  0.0613,  0.1133, -0.3043,  0.0052, -0.1532,
          0.2247,  0.2399,  0.0458,  0.2320,  0.2496, -0.0537, -0.2677, -0.1787],
        [ 0.0808, -0.2869,  0.0688, -0.1646,  0.3144,  0.0958,  0.0340, -0.1392,
         -0.2065,  0.4270,  0.2191, -0.1682,  0.0057, -0.0339, -0.0352,  0.3426,
          0.0628,  0.1943, -0.1348, -0.0467, -0.1249, -0.0522,  0.3732,  0.0999,
         -0.0559, -0.1346, -0.1585, -0.2288, -0.0584, -0.0580,  0.0284, -0.2018],
        [ 0.3760, -0.0341, -0.0524, -0.1812, -0.2408,  0.1764, -0.0913, -0.1694,
         -0.2916, -0.1422, -0.1412, -0.0781,  0.1165,  0.0633,  0.1347, -0.4096,
          0.2124, -0.2153, -0.0055, -0.1609, -0.2353,  0.0733, -0.1883, -0.4946,
          0.1866,  0.2968,  0.3776,  0.3554,  0.0164, -0.1493, -0.3089,

## Run tournament to calculate Elo


In [20]:
import asyncio
import numpy as np
from contextlib import asynccontextmanager
from rgi.rgizero.tournament import Tournament
from rgi.rgizero.players.alphazero import AlphazeroPlayer
from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformerEvaluator, AsyncNetworkEvaluator

@asynccontextmanager
async def create_player_factory(model, simulations, game, device, block_size, action_vocab, max_batch_size):
    """
    Async context manager yielding a zero-argument player factory for ``model``.

    One batched evaluator is created for ``model`` and started on entry; every
    call to the yielded factory returns a new AlphazeroPlayer that shares it.
    The evaluator is stopped on exit, including when the body raises.
    """
    # Shared evaluator: a serial evaluator for the model wrapped in the async batching layer.
    shared_evaluator = AsyncNetworkEvaluator(
        base_evaluator=ActionHistoryTransformerEvaluator(
            model,
            device=device,
            block_size=block_size,
            vocab=action_vocab,
        ),
        max_batch_size=max_batch_size,
        verbose=False,
    )

    def player_factory():
        # New player per call, each with its own RNG seeded from numpy's global
        # RNG, all backed by the shared evaluator.
        seed = np.random.randint(0, 2**31)
        return AlphazeroPlayer(
            game,
            shared_evaluator,
            rng=np.random.default_rng(seed),
            add_noise=True,
            simulations=simulations,
        )

    await shared_evaluator.start()
    try:
        yield player_factory
    finally:
        await shared_evaluator.stop()

async def run_tournament_async(
    entrants=((0, 200), (1, 200), (2, 200), (5, 200), (10, 200), (20, 200)),
    num_games=1_000,
    concurrent_games=2000,
    max_batch_size=10,
    initial_elo=1000,
):
    """
    Run a tournament between saved model generations and print the standings.

    Args:
        entrants: iterable of (generation, simulations) pairs. `generation` is a
            key into the notebook-level `model_dict`; `simulations` is forwarded
            to create_player_factory. Each pair is registered under the name
            "factory_gen<generation>_<simulations>". The default reproduces the
            six 200-simulation players this cell previously hard-coded; add pairs
            such as (0, 100) to bring back a 100-simulation roster.
        num_games: forwarded to Tournament.run.
        concurrent_games: forwarded to Tournament.run.
            NOTE(review): the default (2000) exceeds num_games (1000); presumably
            Tournament caps it -- confirm.
        max_batch_size: evaluator batch size passed to create_player_factory.
        initial_elo: starting rating passed to Tournament.

    An evaluator is started only for the listed entrants (the earlier version
    also started evaluators for generations that never played). All evaluators
    are stopped when the function returns or raises.
    """
    # Local import keeps this cell self-contained; contextlib is standard library.
    from contextlib import AsyncExitStack

    async with AsyncExitStack() as evaluator_stack:
        # Maps player name -> factory callable (not Player instances).
        player_factories = {}
        for generation, simulations in entrants:
            player_name = f"factory_gen{generation}_{simulations}"
            if player_name in player_factories:
                # A repeated entrant would only start a second, unused evaluator.
                continue
            factory_cm = create_player_factory(
                model_dict[generation], simulations, game, device, block_size, action_vocab, max_batch_size
            )
            # The stack closes the entered contexts in reverse order on exit,
            # the same order a nested `async with` would use.
            player_factories[player_name] = await evaluator_stack.enter_async_context(factory_cm)

        tournament = Tournament(game, player_factories, initial_elo=initial_elo)

        print("Running tournament...")
        await tournament.run(num_games=num_games, concurrent_games=concurrent_games)
        tournament.print_standings()

# Only run the tournament when RUN_TOURNAMENT is enabled in the config cell.
# The top-level `await` is valid in an IPython/Jupyter cell, not in a plain script.
if RUN_TOURNAMENT:
    await run_tournament_async()

# Running tournament...
# Tournament Progress: 100%|██████████| 10000/10000 [1:25:59<00:00,  1.94it/s]

# Tournament Standings:
# Rank  Player               ELO        Games    W-L-D          
# -----------------------------------------------------------------
# 1     factory_gen6_200     1140.5     1247     827-419-1      
# 2     factory_gen2_200     1100.1     1251     693-554-4      
# 3     factory_gen5_100     1074.4     1251     598-652-1      
# 4     factory_gen3_200     1029.1     1252     674-573-5      
# 5     factory_gen4_200     1027.0     1248     711-536-1      
# 6     factory_gen0_200     1020.0     1254     444-810-0      
# 7     factory_gen5_200     990.2      1248     742-502-4      
# 8     factory_gen7_100     987.5      1250     650-597-3      
# 9     factory_gen7_200     979.2      1248     768-476-4      
# 10    factory_gen2_100     974.0      1249     522-723-4      
# 11    factory_gen6_100     966.6      1248     684-564-0      
# 12    factory_gen4_100     964.2      1251     557-693-1      
# 13    factory_gen1_100     962.5      1252     547-705-0      
# 14    factory_gen3_100     947.0      1251     528-723-0      
# 15    factory_gen1_200     941.1      1252     630-620-2      
# 16    factory_gen0_100     896.5      1248     410-838-0     


## 20 generations.
# Running tournament...
# Tournament Progress: 100%|██████████| 1000/1000 [08:35<00:00,  1.94it/s]

# Tournament Standings:
# Rank  Player               ELO        Games    W-L-D          
# -----------------------------------------------------------------
# 1     factory_gen10_200    1114.2     333      212-120-1      
# 2     factory_gen2_200     1032.6     333      190-141-2      
# 3     factory_gen1_200     1003.9     334      159-175-0      
# 4     factory_gen20_200    1000.9     335      171-164-0      
# 5     factory_gen5_200     974.6      331      183-146-2      
# 6     factory_gen0_200     873.8      334      82-251-1  

# Tune Model


In [21]:
# reload_local_modules()

from rgi.rgizero.models.tuner import Tuner

# Field names of the two config dataclasses. train_with() uses these sets to
# validate overrides and to route each override to the matching config.
transform_config_fields, train_config_fields = (
    {field.name for field in dataclasses.fields(config_cls)}
    for config_cls in (TransformerConfig, TrainConfig)
)

print(f"transform_config_fields: {transform_config_fields}")
print(f"train_config_fields: {train_config_fields}")


def train_with(**overrides):
    """Train a freshly created model with the given config overrides and report its losses.

    Each override must be a field name of TransformerConfig or TrainConfig and
    is passed to whichever config declares it; the first unknown name raises
    ValueError. The model is created with seed 42 and trained on the splits
    'gen-1' .. 'gen-NUM_GENERATIONS'.

    Returns:
        (loss_dict, elapsed): the trainer's estimated losses converted to
        floats (the 'train' and 'val' entries are printed), and the wall-clock
        seconds spent in this call.
    """
    started_at = time.time()

    # Reject the first override that neither config dataclass declares.
    unknown = next(
        (
            name
            for name in overrides
            if name not in transform_config_fields and name not in train_config_fields
        ),
        None,
    )
    if unknown is not None:
        raise ValueError(f"Invalid override: {unknown}")

    model_config = TransformerConfig(
        **{name: value for name, value in overrides.items() if name in transform_config_fields}
    )
    train_config = TrainConfig(
        **{name: value for name, value in overrides.items() if name in train_config_fields}
    )
    print(f"model_config={model_config}")
    print(f"train_config={train_config}")

    model = create_random_model(
        model_config,
        action_vocab_size=action_vocab.vocab_size,
        num_players=game.num_players(state_0),
        seed=42,
    )

    training_splits = [f"gen-{generation_id}" for generation_id in range(1, NUM_GENERATIONS + 1)]
    model, trainer = train_model(model, training_splits, train_config)

    # Convert whatever estimate_loss() returns per split into plain floats.
    loss_dict = {split: float(loss) for split, loss in trainer.estimate_loss().items()}

    elapsed = time.time() - started_at
    print(f"## train_loss: {loss_dict['train']:.4f}, val_loss: {loss_dict['val']:.4f}, Time taken: {elapsed}s, overrides={overrides}")
    return loss_dict, elapsed

# train_with(model_name='c4-tuning', model_version='0.1',
#             n_max_context=n_max_context, n_layer=2, n_head=2, n_embd=8,  # tiny model
#             max_iters=100, max_epochs=10,
#             eval_iters = 20, log_interval = 100, 
#             gradient_accumulation_steps = 1,
#             batch_size = 32, learning_rate = LEARNING_RATE,    
#             lr_decay_iters = 100,  # make equal to max_iters usually
#             min_lr = LEARNING_RATE / 10,  # learning_rate / 10 usually
#             beta2 = 0.99,  # make a bit bigger because number of tokens per iter is small
#             warmup_iters = 0,  # not super necessary potentially
#            )

transform_config_fields: {'dropout', 'n_embd', 'n_head', 'bias', 'n_layer', 'n_max_context'}
train_config_fields: {'beta2', 'device', 'learning_rate', 'gradient_accumulation_steps', 'max_iters', 'eval_iters', 'beta1', 'warmup_iters', 'always_save_checkpoint', 'max_epochs', 'model_version', 'grad_clip', 'lr_decay_iters', 'weight_decay', 'model_name', 'log_interval', 'compile', 'wandb_log', 'eval_interval', 'decay_lr', 'min_lr', 'batch_size', 'dtype', 'eval_only'}


In [22]:
# Reload the project's rgi.* modules and refresh the names this notebook imported
# from them (the output below lists each reloaded module and updated global).
reload_local_modules()

reloaded rgi
reloaded rgi.rgizero
reloaded rgi.rgizero.games
reloaded rgi.rgizero.games.base
reloaded rgi.rgizero.games.connect4
reloaded rgi.rgizero.players
reloaded rgi.rgizero.players.base
reloaded rgi.rgizero.players.alphazero
reloaded rgi.rgizero.common
reloaded rgi.rgizero.games.history_wrapper
reloaded rgi.rgizero.data
reloaded rgi.rgizero.data.trajectory_dataset
reloaded rgi.rgizero.models
reloaded rgi.rgizero.models.transformer
reloaded rgi.rgizero.models.token_transformer
reloaded rgi.rgizero.models.action_history_transformer
reloaded rgi.rgizero.train
reloaded rgi.rgizero.tournament
reloaded rgi.rgizero.models.tuner
  -> Updated 'Connect4Game' in globals() from 'rgi.rgizero.games.connect4'
  -> Updated 'AlphazeroPlayer' in globals() from 'rgi.rgizero.players.alphazero'
  -> Updated 'play_game' in globals() from 'rgi.rgizero.players.alphazero'
  -> Updated 'HistoryTrackingGame' in globals() from 'rgi.rgizero.games.history_wrapper'
  -> Updated 'Vocab' in globals() from 'rgi.r

In [23]:
# TODO: Pass as arg to Tuner instead of monkey patching...
# Point the tuner module's `train_with` attribute at the notebook's train_with wrapper.
import rgi.rgizero.models.tuner
rgi.rgizero.models.tuner.train_with = train_with

# Baseline hyperparameters: a tiny model trained for 100 iterations.
# NOTE(review): the tuner output below visits hyperparameters in this dict's
# insertion order, so reordering entries may change the search -- confirm in Tuner.
initial_params = dict(
    model_name='c4-tuning', model_version='0.1',
    n_max_context=n_max_context, n_layer=2, n_head=2, n_embd=8,  # tiny model
    max_iters=100, max_epochs=1_000_000, # Make max_epoch high, rely on max_iters to stop.
    eval_iters = 200, log_interval = 1000, eval_interval = 10_000,
    gradient_accumulation_steps = 1,
    batch_size = 32, learning_rate = LEARNING_RATE,    
    decay_lr = True,  # whether to decay the learning rate
    lr_decay_iters = 100,  # make equal to max_iters usually
    min_lr = LEARNING_RATE / 10,  # learning_rate / 10 usually
    warmup_iters = 0,  # not super necessary potentially

    weight_decay = 1e-1,
    beta1 = 0.9,
    beta2 = 0.95,
    grad_clip = 1.0,  # clip gradients at this value, or disable if == 0.0

    dtype = "float16",

    # block_size = block_size,
    # vocab_size = action_vocab.vocab_size,  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    dropout = 0.0,
    bias = False,  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

)

# Start with exactly one candidate per hyperparameter: its baseline value.
tune_options = {k: [v] for k, v in initial_params.items()}

# Widen the candidate lists for the hyperparameters that should actually be searched.
tune_options.update(dict(
    batch_size = [16, 32, 64, 128, 256, 512],
    gradient_accumulation_steps = [1, 4, 16],
    learning_rate = [0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0],
    lr_decay_iters = [100],
    beta1 = [0.9],
    beta2 = [0.95, 0.99],
    warmup_iters = [0],
    n_embd = [8, 16, 32, 64, 128],
    n_layer = [2, 4, 8, 16, 32],
    n_head = [2, 4, 8, 16, 32],
    max_iters = [100, 300, 1_000, 3_000, 10_000, 30_000, 100_000],

    dtype = ["float16"], # ["bfloat16", "float16", "float32"],
    # vocab_size = [action_vocab.vocab_size],
    dropout = [0.0, 0.01, 0.05, 0.1],
    bias = [True, False],
    
    decay_lr = [True, False],
))

# Candidate lists computed from the current option values (`opt`) instead of being fixed.
# NOTE(review): min_lr, lr_decay_iters and warmup_iters also have static lists in
# tune_options; presumably the computed lists take precedence -- confirm in Tuner.
computed_tune_options = dict(
    min_lr = lambda opt: [opt['learning_rate'] / 10],
    lr_decay_iters = lambda opt: [opt['max_iters']],
    warmup_iters = lambda opt: [0, 100, 1000] if opt['decay_lr'] else [0],
)

# Passed to Tuner as cache_version.
# NOTE(review): presumably bumping this invalidates previously cached results -- confirm.
TUNER_VERSION = "0.0.2"


# TODO: We need to recalculate the 'calculated' options every time any hparam is changed...
# Copies are passed so the Tuner cannot mutate the notebook-level dicts (shallow copies only).
# NOTE(review): target_improvement_per_minute looks like the minimum loss improvement per
# extra minute of training needed to accept a change -- confirm in Tuner.
tuner = Tuner(
    tune_options=tune_options.copy(), 
    initial_params=initial_params.copy(),
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=100.00)
tuner.autotune()


Initial params: None
## Initial Model, loss=2.4566147327423096 elapsed=13.604346990585327s
## Tuning generation 1: model_name
## Tuning generation 1: model_version
## Tuning generation 1: n_max_context
## Tuning generation 1: n_layer
## Tuning generation 1: n_layer improved, val=4, best=2.4735538959503174, delta=-0.016939163208007812 elapsed=10.176063060760498s delta=3.428283929824829s
## Tuning generation 1: n_head
## Tuning generation 1: n_head improved, val=4, best=2.467500686645508, delta=0.00605320930480957 elapsed=9.368541955947876s delta=0.8075211048126221s
## Tuning generation 1: n_embd
## Tuning generation 1: max_iters
## Tuning generation 1: max_epochs
## Tuning generation 1: eval_iters
## Tuning generation 1: log_interval
## Tuning generation 1: eval_interval
## Tuning generation 1: gradient_accumulation_steps
## Tuning generation 1: batch_size
## Tuning generation 1: learning_rate
## Tuning generation 1: decay_lr
## Tuning generation 1: lr_decay_iters
## Computed tune optio

True

In [24]:
# Same search as the previous cell, with target_improvement_per_minute lowered from 100.00 to 10.00.
# NOTE(review): the identical "Initial Model" line in the output suggests results are
# reused from a cache keyed by cache_version -- confirm in Tuner.
tuner = Tuner(
    tune_options=tune_options.copy(), 
    initial_params=initial_params.copy(),
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=10.00)
tuner.autotune()

Initial params: None
## Initial Model, loss=2.4566147327423096 elapsed=13.604346990585327s
## Tuning generation 1: model_name
## Tuning generation 1: model_version
## Tuning generation 1: n_max_context
## Tuning generation 1: n_layer
## Tuning generation 1: n_layer improved, val=4, best=2.4735538959503174, delta=-0.016939163208007812 elapsed=10.176063060760498s delta=3.428283929824829s
## Tuning generation 1: n_head
## Tuning generation 1: n_head improved, val=4, best=2.467500686645508, delta=0.00605320930480957 elapsed=9.368541955947876s delta=0.8075211048126221s
## Tuning generation 1: n_embd
## Tuning generation 1: max_iters
## Tuning generation 1: max_epochs
## Tuning generation 1: eval_iters
## Tuning generation 1: log_interval
## Tuning generation 1: eval_interval
## Tuning generation 1: gradient_accumulation_steps
## Tuning generation 1: batch_size
## Tuning generation 1: learning_rate
## Tuning generation 1: decay_lr
## Tuning generation 1: lr_decay_iters
## Computed tune optio

True

In [25]:
# Same search again, with target_improvement_per_minute lowered from 10.00 to 1.00.
tuner = Tuner(
    tune_options=tune_options.copy(), 
    initial_params=initial_params.copy(),
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=1.00)
tuner.autotune()

Initial params: None
## Initial Model, loss=2.4566147327423096 elapsed=13.604346990585327s
## Tuning generation 1: model_name
## Tuning generation 1: model_version
## Tuning generation 1: n_max_context
## Tuning generation 1: n_layer
## Tuning generation 1: n_layer improved, val=4, best=2.4735538959503174, delta=-0.016939163208007812 elapsed=10.176063060760498s delta=3.428283929824829s
## Tuning generation 1: n_head
## Tuning generation 1: n_head improved, val=4, best=2.467500686645508, delta=0.00605320930480957 elapsed=9.368541955947876s delta=0.8075211048126221s
## Tuning generation 1: n_embd
## Tuning generation 1: max_iters
## Tuning generation 1: max_epochs
## Tuning generation 1: eval_iters
## Tuning generation 1: log_interval
## Tuning generation 1: eval_interval
## Tuning generation 1: gradient_accumulation_steps
## Tuning generation 1: batch_size
## Tuning generation 1: learning_rate
## Tuning generation 1: decay_lr
## Tuning generation 1: lr_decay_iters
## Computed tune optio

True

In [26]:
# Same search again, with target_improvement_per_minute lowered from 1.00 to 0.10.
# In the output below this is the first threshold at which a larger max_iters is accepted.
tuner = Tuner(
    tune_options=tune_options.copy(), 
    initial_params=initial_params.copy(),
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.10)
tuner.autotune()

Initial params: None
## Initial Model, loss=2.4566147327423096 elapsed=13.604346990585327s
## Tuning generation 1: model_name
## Tuning generation 1: model_version
## Tuning generation 1: n_max_context
## Tuning generation 1: n_layer
## Tuning generation 1: n_head
## Tuning generation 1: n_head improved, val=8, best=2.4566547870635986, delta=-4.00543212890625e-05 elapsed=8.165772676467896s delta=5.438574314117432s
## Tuning generation 1: n_embd
## Tuning generation 1: max_iters
## Tuning generation 1: max_iters improved, val=1000, best=2.4133732318878174, delta=0.04328155517578125 elapsed=22.913728952407837s delta=-14.747956275939941s
## Tuning generation 1: max_epochs
## Tuning generation 1: eval_iters
## Tuning generation 1: log_interval
## Tuning generation 1: eval_interval
## Tuning generation 1: gradient_accumulation_steps
## Tuning generation 1: batch_size
## Tuning generation 1: batch_size improved, val=64, best=2.3779265880584717, delta=0.0354466438293457 elapsed=27.58056807518

True

In [27]:
# Same search again, with target_improvement_per_minute lowered from 0.10 to 0.01.
tuner = Tuner(
    tune_options=tune_options.copy(), 
    initial_params=initial_params.copy(),
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.01)
tuner.autotune()

Initial params: None
## Initial Model, loss=2.4566147327423096 elapsed=13.604346990585327s
## Tuning generation 1: model_name
## Tuning generation 1: model_version
## Tuning generation 1: n_max_context
## Tuning generation 1: n_layer
## Tuning generation 1: n_head
## Tuning generation 1: n_head improved, val=8, best=2.4566547870635986, delta=-4.00543212890625e-05 elapsed=8.165772676467896s delta=5.438574314117432s
## Tuning generation 1: n_embd
## Tuning generation 1: max_iters
## Tuning generation 1: max_iters improved, val=3000, best=2.3668911457061768, delta=0.08976364135742188 elapsed=59.66869020462036s delta=-51.502917528152466s
## Tuning generation 1: max_epochs
## Tuning generation 1: eval_iters
## Tuning generation 1: log_interval
## Tuning generation 1: eval_interval
## Tuning generation 1: gradient_accumulation_steps
## Tuning generation 1: batch_size
## Tuning generation 1: batch_size improved, val=128, best=2.3502447605133057, delta=0.016646385192871094 elapsed=89.431571006

True

In [28]:
# Same search again, with target_improvement_per_minute lowered from 0.01 to 0.001.
# The recorded run of this cell was interrupted by hand (see the KeyboardInterrupt below).
tuner = Tuner(
    tune_options=tune_options.copy(), 
    initial_params=initial_params.copy(),
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.001)
tuner.autotune()

Initial params: None
## Initial Model, loss=2.4566147327423096 elapsed=13.604346990585327s
## Tuning generation 1: model_name
## Tuning generation 1: model_version
## Tuning generation 1: n_max_context
## Tuning generation 1: n_layer
## Tuning generation 1: n_head
## Tuning generation 1: n_embd
## Tuning generation 1: max_iters
## Tuning generation 1: max_iters improved, val=10000, best=2.360471248626709, delta=0.09614348411560059 elapsed=187.25318884849548s delta=-173.64884185791016s
## Tuning generation 1: max_epochs
## Tuning generation 1: eval_iters
## Tuning generation 1: log_interval
## Tuning generation 1: eval_interval
## Tuning generation 1: gradient_accumulation_steps
## Tuning generation 1: batch_size
## Tuning generation 1: batch_size improved, val=64, best=2.347130060195923, delta=0.013341188430786133 elapsed=206.6811351776123s delta=-19.42794632911682s
## Tuning generation 1: learning_rate
## Tuning generation 1: decay_lr
## Tuning generation 1: lr_decay_iters
## Computed



using fused AdamW: False
step 0: train loss 2.7524, val loss 2.7525
iter 0/391/30000: loss 2.7629, time 33845.71ms


KeyboardInterrupt: 

In [44]:
# Pick up any on-disk changes to the local modules (quietly) before rebuilding the Tuner.
reload_local_modules(verbose=False)
# Build a Tuner with the same options as the last sweep; autotune() is NOT called here.
tuner = Tuner(
    tune_options=tune_options.copy(), 
    initial_params=initial_params.copy(),
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.001)
# print stats based on cached results.
# The returned dict is kept in `tuner_stats` for inspection in the following cells.
tuner_stats = tuner.print_hparam_stats()

Initial params: None
## tree=defaultdict(<function Tuner.print_hparam_stats.<locals>.<lambda> at 0x52b7722a0>, {'batch_size': defaultdict(<function Tuner.print_hparam_stats.<locals>.<lambda>.<locals>.<lambda> at 0x52b786ac0>, {"[('beta1', 0.9), ('beta2', 0.95), ('bias', False), ('decay_lr', True), ('dropout', 0.0), ('dtype', 'bfloat16'), ('eval_interval', 10000), ('eval_iters', 200), ('grad_clip', 1.0), ('gradient_accumulation_steps', 1), ('learning_rate', 0.05), ('log_interval', 200), ('lr_decay_iters', 100), ('max_epochs', 1000000), ('max_iters', 100), ('min_lr', 0.005), ('model_name', 'c4-tuning'), ('model_version', '0.1'), ('n_embd', 8), ('n_head', 2), ('n_layer', 2), ('n_max_context', 44), ('warmup_iters', 0), ('weight_decay', 0.1)]": defaultdict(<function Tuner.print_hparam_stats.<locals>.<lambda>.<locals>.<lambda>.<locals>.<lambda> at 0x52b786a20>, {32: {'train': 2.4529192447662354, 'val': 2.4566147327423096, 'elapsed': 13.30814504623413}}), "[('beta1', 0.9), ('beta2', 0.95), ('

  x = asanyarray(arr - arrmean)
  stats_dict['std_elapsed_delta'] = stats_dict['mean_elapsed_delta'] / stats_dict['mean_elapsed_1']
  


In [45]:
# Display the stats dict; per the output its keys are (hparam, value_1, value_2) tuples
# and its values hold mean/std of the validation loss and elapsed time for each side.
tuner_stats

{('batch_size', 32, 64): {'mean_val_1': 2.427546977996826,
  'mean_val_2': 2.413127040863037,
  'mean_val_delta': 0.014419937133789062,
  'mean_elapsed_1': 41.317571036020915,
  'mean_elapsed_2': 54.33175460497538,
  'mean_elapsed_delta': -13.014183568954468,
  'std_val_1': 0.04657474285224739,
  'std_val_2': 0.049823862605564706,
  'std_elapsed_1': 59.09380767082747,
  'std_elapsed_2': 79.07996321354398,
  'std_val_delta': 0.005940126911854109,
  'std_elapsed_delta': -0.31497939599616404},
 ('batch_size', 32, 16): {'mean_val_1': 2.468580897649129,
  'mean_val_2': 2.4686112880706785,
  'mean_val_delta': -3.0390421549479167e-05,
  'mean_elapsed_1': 7.628191741307576,
  'mean_elapsed_2': 8.731634044647217,
  'mean_elapsed_delta': -1.1034423033396403,
  'std_val_1': 0.008104734346281891,
  'std_val_2': 0.006920688074390324,
  'std_elapsed_1': 0.6606143017651211,
  'std_elapsed_2': 4.23662562546057,
  'std_val_delta': -1.2310887432702924e-05,
  'std_elapsed_delta': -0.1446531944608009},
 (

In [56]:
# Rank every (hparam, value_1, value_2) comparison by its 'mean_val_delta', largest first.
# Entries whose delta is NaN are dropped; infinite deltas are kept and sort to the top.
sorted(
    (
        (stats['mean_val_delta'], comparison)
        for comparison, stats in tuner_stats.items()
        if not np.isnan(stats['mean_val_delta'])
    ),
    reverse=True,
)


[(inf, ('warmup_iters', 100, 0)),
 (inf, ('n_head', 16, 4)),
 (inf, ('n_head', 16, 2)),
 (inf, ('max_iters', 3000, 1000)),
 (inf, ('gradient_accumulation_steps', 4, 1)),
 (0.11702394485473633, ('batch_size', 64, 512)),
 (0.09156906604766846, ('max_iters', 100, 30000)),
 (0.08910207748413086, ('max_iters', 100, 3000)),
 (0.08820090293884278, ('max_iters', 100, 10000)),
 (0.08159437179565429, ('max_iters', 300, 3000)),
 (0.07621145248413086, ('max_iters', 300, 10000)),
 (0.07434988021850586, ('n_embd', 8, 64)),
 (0.06997077805655343, ('max_iters', 3000, 30000)),
 (0.06852293014526367, ('batch_size', 128, 512)),
 (0.06849443912506104, ('max_iters', 300, 30000)),
 (0.05829771359761556, ('n_embd', 8, 32)),
 (0.05670485496520996, ('max_iters', 1000, 10000)),
 (0.05446876798357282, ('max_iters', 100, 1000)),
 (0.04985438452826606, ('max_iters', 3000, 10000)),
 (0.04596608877182007, ('batch_size', 64, 256)),
 (0.0431286096572876, ('learning_rate', 0.5, 0.05)),
 (0.041318804025650024, ('n_embd'

## Debug Convergence

Synthetic sanity-check: train on a toy 2-step game where the first action strongly determines the winner. This verifies the value head and training loop can learn simple patterns.


In [None]:
# Deliberate stop: halts "Run All" at this point so the synthetic-data debugging
# cells below only run when executed by hand. An explicit raise keeps the cell
# valid Python (the previous marker was a syntax error) and states the reason.
raise RuntimeError("STOP HERE: the cells below are manual debugging experiments; run them individually.")

In [None]:
# Initial state of the history-tracking game and its full action list.
# state_0 is used further down to query the number of players when building the fake model.
state_0 = game.initial_state()
all_actions_0 = game.all_actions()

print(all_actions_0)


In [None]:
import random

def play_random_game_with_fake_reward(game, max_actions, rng=None) -> dict:
    """Play uniformly random moves and label the game with a synthetic reward.

    The reward ignores the real game outcome: it is the mean of the played
    actions divided by the total number of actions, so a model can only fit
    it by looking at the action history.
    NOTE(review): this assumes actions are numeric and range over
    1..len(game.all_actions()), which keeps the reward within [0, 1] -- confirm
    for games other than the one used in this notebook.

    Args:
        game: object providing initial_state(), all_actions(), is_terminal(state),
            legal_actions(state) and next_state(state, action).
        max_actions: maximum number of moves to play before stopping.
        rng: optional object with a randrange(n) method (e.g. random.Random(seed))
            for reproducible games; defaults to the global `random` module.

    Returns:
        dict with keys "winner" (1 if reward >= 0.5 else 2), "rewards"
        (array [reward, 1 - reward]), "action_history", "legal_policies"
        (uniform distribution over the legal actions at each step),
        "final_state" and "legal_action_idx" (indices of each step's legal
        actions within game.all_actions()).

    Raises:
        ValueError: if no move could be played (max_actions <= 0 or the initial
            state is already terminal), since the reward would be undefined.
    """
    if rng is None:
        rng = random

    state = game.initial_state()
    action_history = []
    legal_policies = []
    legal_action_idx_list = []

    all_actions = game.all_actions()
    all_action_idx_map = {action: idx for idx, action in enumerate(all_actions)}

    while not game.is_terminal(state) and len(action_history) < max_actions:
        legal_actions = game.legal_actions(state)
        action = legal_actions[rng.randrange(len(legal_actions))]

        action_history.append(action)
        legal_policies.append(np.ones(len(legal_actions)) / len(legal_actions))
        legal_action_idx_list.append(
            np.array([all_action_idx_map[legal_action] for legal_action in legal_actions])
        )

        state = game.next_state(state, action)

    if not action_history:
        raise ValueError("No actions were played; the fake reward is undefined for an empty game.")

    # Normalise by the size of the full action set. The number of legal actions
    # at the last step shrinks as the board fills, which would push the reward
    # above 1 and make the second player's reward negative.
    fake_reward = np.mean(action_history) / len(all_actions)
    rewards = np.array([fake_reward, 1.0 - fake_reward])
    winner = 1 if fake_reward >= 0.5 else 2

    return {
        "winner": winner,
        "rewards": rewards,
        "action_history": action_history,
        "legal_policies": legal_policies,
        "final_state": state,
        "legal_action_idx": legal_action_idx_list,
    }

In [None]:
# Smoke test: play a single game capped at 2 actions and display the returned record.
play_random_game_with_fake_reward(game, max_actions=2)

In [None]:
# Generate 100,000 synthetic games of at most 2 actions each, then print summary stats.
# NOTE(review): `results` may overwrite a variable of the same name from earlier cells -- confirm.
results = [play_random_game_with_fake_reward(game, max_actions=2) for _ in range(100_000)]
print_game_stats(results)


In [None]:
# Store the synthetic games under a dedicated generation name so they stay
# separate from the real 'gen-N' self-play data.
# NOTE(review): presumably write_trajectory_dataset writes to disk and returns the path -- confirm.
fake_gen_name = "fake-0"
trajectory_path = write_trajectory_dataset(results, action_vocab, fake_gen_name)


In [None]:
# Train a fresh model on the synthetic data only. The "large" config is hard-coded
# here; the commented line switches back to the notebook-wide MODEL_SIZE.
# fake_model_config = model_config_dict[MODEL_SIZE]
fake_model_config = model_config_dict["large"]
fake_model = create_random_model(fake_model_config, action_vocab_size=action_vocab.vocab_size, num_players=game.num_players(state_0), seed=42)

# NOTE(review): this assumes write_trajectory_dataset stored the games under the split
# name 'gen-fake-0' (i.e. it prefixes the name with 'gen-') -- confirm.
# This also rebinds the notebook-level `training_splits`.
training_splits = [f'gen-{fake_gen_name}']
fake_model, fake_trainer = train_model(fake_model, training_splits, train_config)
save_model(fake_model, fake_trainer, fake_gen_name)

## model_size=tiny
# num decayed parameter tensors: 11, with 1,968 parameters
# num non-decayed parameter tensors: 7, with 50 parameters
# using fused AdamW: False
# step 0: train loss 2.7817, val loss 2.7816
# iter 0/49/488: loss 2.7821, time 2537.56ms
# iter 100/147/488: loss 2.6890, time 53.61ms
# iter 200/245/488: loss 2.6342, time 63.05ms
# iter 300/343/488: loss 2.6187, time 55.31ms
# iter 400/441/488: loss 2.6147, time 61.11ms

## model_size=large
# num decayed parameter tensors: 35, with 1,579,776 parameters
# num non-decayed parameter tensors: 19, with 2,186 parameters
# using fused AdamW: False
# step 0: train loss 2.8087, val loss 2.8088
# iter 0/49/488: loss 2.8099, time 11225.20ms
# iter 100/147/488: loss 2.6065, time 596.91ms
# iter 200/245/488: loss 2.6075, time 618.00ms
# iter 300/343/488: loss 2.6080, time 613.63ms
# iter 400/441/488: loss 2.6051, time 616.39ms

In [None]:
# for rerun in range(10):
#     print(f"Re-running training for {fake_gen_name} {rerun+1} of 10")
#     fake_model, fake_trainer = train_model(fake_model, training_splits, train_config)
#     save_model(fake_model, fake_trainer, fake_gen_name)

In [None]:
# [td for td in td_array]
fake_td_array = [TrajectoryDataset(DATA_DIR, split, block_size=n_max_context) for split in training_splits]
fake_unrolled = [(generation+1, d) for generation, td in enumerate(fake_td_array) for d in td]

# gen, d = unrolled[0], 
# d.action[:2]
# d.value[0]

# Inspect training data
fake_dd = defaultdict(lambda: defaultdict(lambda: torch.tensor([0., 0.])))

for gen, d in fake_unrolled:
    for g in ['*', gen]:    
        fake_dd[tuple(tuple(d.action[:0].tolist()))][g] += d.value[0]
        fake_dd[tuple(tuple(d.action[:1].tolist()))][g] += d.value[0]
        fake_dd[tuple(tuple(d.action[:2].tolist()))][g] += d.value[0]
        # fake_dd[tuple(tuple(d.action[:3].tolist()))][g] += d.value[0]

print(f"len(fake_dd) = {len(fake_dd)}")


In [None]:
# Reload the model saved for the synthetic generation and check it against the
# synthetic aggregates in `fake_dd`. This cell previously passed `dd`, which is not
# the dictionary built from the fake trajectories, so the comparison used the wrong data.
fake_model = load_model(fake_gen_name)
compare_model_vs_data(fake_model, game, fake_dd)


In [None]:
# Run another training pass on the current fake_model using the same splits and train_config.
fake_model, fake_trainer = train_model(fake_model, training_splits, train_config)
