# Step-by-step run of alphazero self-play & training.


In [1]:
import os
import time
from pathlib import Path
from collections import defaultdict, Counter
import asyncio
from typing import Callable

import numpy as np
import torch
import torch.nn.functional as F

# Game and players
from rgi.rgizero.games.connect4 import Connect4Game
from rgi.rgizero.players.alphazero import AlphazeroPlayer
from rgi.rgizero.players.alphazero import play_game

from notebook_utils import reload_local_modules

print("✅ Imports successful")

if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
  device = 'mps'
else:
    device = 'cpu'
print(f'Using device: {device}')
assert device in ('cuda', 'mps'), f"No accelerator available, device={device}"

# Allow asyncio to work with jupyter notebook
import nest_asyncio
nest_asyncio.apply()

# Increase numpy print width
np.set_printoptions(linewidth=300)

%load_ext line_profiler

✅ Imports successful
Using device: mps


In [None]:
DEBUG_MODE = True     # Set options to make debugger work properly. Single worker, etc.
LOAD_MODEL = False
TRAIN_MODEL = True
MODEL_SIZE = "small"  # "tiny" or "small" or"large" or "xl"
NUM_SIMULATIONS = 200

# If False, we still load previous games from disk.
NUM_GAMES = 2000
MAX_TRAINING_ITERS = 10_000
CONFIG_ALIAS = f'trajectory_sims-{NUM_SIMULATIONS}_games-{NUM_GAMES}_size-{MODEL_SIZE}_train-{MAX_TRAINING_ITERS}_x1'
NUM_GENERATIONS = 5

## Step 1: Set up history-wrapped game


In [3]:
from rgi.rgizero.games.history_wrapper import HistoryTrackingGame
from rgi.rgizero.data.trajectory_dataset import Vocab
from rgi.rgizero.common import TOKENS

base_game, max_game_length = Connect4Game(connect_length=4), 7*6

game = HistoryTrackingGame(base_game)
state_0 = game.initial_state()
block_size = max_game_length + 2
all_actions = game.all_actions()
action_vocab = Vocab(itos=[TOKENS.START_OF_GAME] + list(all_actions))
n_max_context = max_game_length + 2
game_name = base_game.__class__.__name__

print("✅ Using HistoryTrackingGame from module")
print(f"Game: {game_name}, Players: {game.num_players(state_0)}, Actions: {list(game.all_actions())}")

DATA_DIR = Path.cwd().parent / "data" / "rgizero-e2e" / game_name / CONFIG_ALIAS
print("Creating data dir: ", DATA_DIR)
os.makedirs(DATA_DIR, exist_ok=True)

MODEL_DIR = Path.cwd().parent / "models" / "rgizero-e2e" / game_name / CONFIG_ALIAS
print("Creating model dir: ", MODEL_DIR)
os.makedirs(MODEL_DIR, exist_ok=True)


✅ Using HistoryTrackingGame from module
Game: Connect4Game, Players: 2, Actions: [1, 2, 3, 4, 5, 6, 7]
Creating data dir:  /Users/rodo/src/rgi3/data/rgizero-e2e/Connect4Game/trajectory_sims-200_games-2000_size-small_train-10000_x1
Creating model dir:  /Users/rodo/src/rgi3/models/rgizero-e2e/Connect4Game/trajectory_sims-200_games-2000_size-small_train-10000_x1


## Step 2: Create random generation_0 model


In [4]:
def print_dataset_stats(dataset_path, split_name):
    """Print statistics about a loaded trajectory dataset."""
    td = TrajectoryDataset(dataset_path.parent, split_name, block_size=n_max_context)
    
    # Calculate basic stats
    num_trajectories = len(td)
    total_actions = td._num_actions
    avg_trajectory_length = total_actions / num_trajectories if num_trajectories > 0 else 0
    
    # Get trajectory lengths, winners, and first moves
    trajectory_lengths = []
    winners = []
    first_moves = []
    
    for i in range(num_trajectories):
        start_idx = td.boundaries[i]
        end_idx = td.boundaries[i + 1]
        traj_length = end_idx - start_idx
        trajectory_lengths.append(traj_length)
        
        # Get winner from final values (values are the same throughout trajectory)
        # Values are in range [-1, 1] where positive means player 1 advantage
        final_values = td.value_data[start_idx]  # shape: (num_players,)
        if final_values[0] > final_values[1]:
            winners.append(1)
        elif final_values[1] > final_values[0]:
            winners.append(2)
        else:
            winners.append(None)  # Draw
        
        # Get first move (decode from vocab)
        first_action_encoded = td.action_data[start_idx]
        first_action = action_vocab.decode([first_action_encoded])[0]
        first_moves.append(first_action)
    
    # Print basic stats
    print(f"Dataset Stats:")
    print(f"  Trajectories: {num_trajectories}")
    print(f"  Total actions: {total_actions}")
    print(f"  Avg trajectory length: {avg_trajectory_length:.2f}")
    print(f"  Trajectory length - min: {min(trajectory_lengths)}, max: {max(trajectory_lengths)}, mean: {np.mean(trajectory_lengths):.2f}")
    
    # Print winner stats (similar to print_game_stats)
    print(f"Winner Stats:")
    winner_stats = Counter(winners)
    total_games = num_trajectories
    win1_pct = 100 * winner_stats[1] / total_games if total_games > 0 else 0
    win2_pct = 100 * winner_stats[2] / total_games if total_games > 0 else 0
    print(f"  Winner counts: win[1]={win1_pct:.2f}% win[2]={win2_pct:.2f}%, n={total_games}")
    
    # Print stats by initial move
    print(f"Winner Stats by initial move:")
    move_stats = defaultdict(Counter)
    for first_move, winner in zip(first_moves, winners):
        move_stats[first_move][winner] += 1
    
    for action in sorted(move_stats.keys()):
        counts = move_stats[action]
        total = sum(counts.values())
        win1_pct = 100 * counts[1] / total if total > 0 else 0
        print(f"  a={action}: n={total:3} win[1]={win1_pct:.2f}% counts={counts}")

def print_model_stats(model, config_alias=""):
    """Print statistics about a model."""
    # Count parameters
    num_params = sum(p.numel() for p in model.parameters())
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"Model Stats:")
    print(f"  Config: {model.config}")
    print(f"  Total parameters: {num_params:,}")
    print(f"  Trainable parameters: {num_trainable:,}")
    if config_alias:
        print(f"  Config alias: {config_alias}")


In [5]:
reload_local_modules(verbose=False)

from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformer
from rgi.rgizero.models.transformer import TransformerConfig

model_config_dict = {
    "tiny": TransformerConfig(n_max_context=n_max_context, n_layer=2, n_head=2, n_embd=8),
    "small": TransformerConfig(n_max_context=n_max_context, n_layer=4, n_head=4, n_embd=32),
    "large": TransformerConfig(n_max_context=n_max_context, n_layer=8, n_head=8, n_embd=128),
    "xl": TransformerConfig(n_max_context=n_max_context, n_layer=16, n_head=16, n_embd=256),
}


def create_random_model(config: TransformerConfig, action_vocab_size, num_players,  seed: int):
    torch.manual_seed(seed)
    np.random.seed(seed) # Ensure numpy operations are also seeded
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    model = ActionHistoryTransformer(config=config, action_vocab_size=action_vocab_size, num_players=num_players)
    model.to(device)
    return model

# Make model initialization deterministic
model_config = model_config_dict[MODEL_SIZE]
model_0 = create_random_model(model_config, action_vocab_size=action_vocab.vocab_size, num_players=game.num_players(state_0), seed=42)

# Step 3: Define play & generation code


In [6]:
from rgi.rgizero.models.action_history_transformer import AsyncNetworkEvaluator, ActionHistoryTransformerEvaluator
from rgi.rgizero.players.alphazero import play_game_async
from tqdm.asyncio import tqdm

async def play_games_async(num_games: int, player_factory: Callable[[], AlphazeroPlayer]):
    tasks = []
    async def create_player_and_create_game():
        t0 = time.time()
        player = player_factory()
        game_result = await play_game_async(game, [player, player])
        t1 = time.time()
        game_result['time'] = t1 - t0
        return game_result

    tasks = [create_player_and_create_game() for _ in range(num_games)]
    results = await tqdm.gather(*tasks)   # same as asyncio.gather, but with a progress bar
    return results

async def play_generation_async(model, num_games, simulations=NUM_SIMULATIONS):
    serial_evaluator = ActionHistoryTransformerEvaluator(model, device=device, block_size=block_size, vocab=action_vocab)
    async_evaluator = AsyncNetworkEvaluator(base_evaluator=serial_evaluator, max_batch_size=1024, verbose=False)

    master_rng = np.random.default_rng(42)
    async_player_factory = lambda: AlphazeroPlayer(game, async_evaluator, rng=np.random.default_rng(master_rng.integers(0, 2**31)), add_noise=False, simulations=simulations)

    await async_evaluator.start()
    results = await play_games_async(num_games=num_games, player_factory=async_player_factory)
    await async_evaluator.stop()
    return results

def print_game_stats(results):
    print("Winner Stats:")
    winner_stats = Counter(result['winner'] for result in results)
    print(f"Winner counts: win[1]={100*winner_stats[1]/sum(winner_stats.values()):.2f}% win[2]={100*winner_stats[2]/sum(winner_stats.values()):.2f}%, n={sum(winner_stats.values())}")
    game_lengths = [len(result['action_history']) for result in results]
    print(f"Game Length min: {min(game_lengths)}, max: {max(game_lengths)}, mean: {np.mean(game_lengths):.2f}")
    print("Winner Stats by initial move:")
    dd = defaultdict(Counter)
    for result in results:
        dd[result['action_history'][0]][result['winner']] += 1
    for action, counts in sorted(dd.items()):
        print(f"  a={action}: n={sum(counts.values()):3} win[1]={100*counts[1]/sum(counts.values()):.2f}% counts={counts}")

## Step 3: Confirm we can read & write to trajectory_dataset


In [7]:
from rgi.rgizero.data.trajectory_dataset import TrajectoryDatasetBuilder, TrajectoryDataset, build_trajectory_loader
reload_local_modules(verbose=False)

def add_trajectory(game_result, vocab, td_builder):
    action_history = game_result['action_history']
    trajectory_length = len(action_history)
    legal_policies = game_result['legal_policies']
    legal_action_idx = game_result['legal_action_idx']
    rewards = game_result['rewards']

    # Translation key for converting legal_action_ids to vocab_action_idx.
    action_idx_to_vocab_idx = vocab.encode(all_actions)

    fixed_width_policies = np.zeros((trajectory_length, vocab.vocab_size))
    for i in range(trajectory_length):
        vocab_action_idx = action_idx_to_vocab_idx[legal_action_idx[i]]
        fixed_width_policies[i, vocab_action_idx] = legal_policies[i]

    encoded_action_history = vocab.encode(action_history)
    tiled_rewards = np.tile(rewards, (trajectory_length, 1))  # shape (num_players,) -> (num_moves, num_players)
    
    td_builder.add_trajectory(actions=encoded_action_history, fixed_width_policies=fixed_width_policies, values=tiled_rewards)

def write_trajectory_dataset(results, action_vocab, generation_id):
    td_builder = TrajectoryDatasetBuilder(action_vocab)
    for game_result in results:
        add_trajectory(game_result, action_vocab, td_builder)

    trajectory_path = td_builder.save(DATA_DIR, f"gen-{generation_id}")
    return trajectory_path

## Train model


In [8]:
from rgi.rgizero.train import Trainer, TrainConfig

train_config = TrainConfig(
    model_name="connect4-e2e",
    model_version="v1",

    eval_interval = 250,  # keep frequent because we'll overfit
    eval_iters = 200,
    log_interval = 10_000,  # don't print too too often

    # we expect to overfit on this small dataset, so only save when val improves
    always_save_checkpoint = False,

    gradient_accumulation_steps = 1,
    batch_size = 64,

    learning_rate = 1e-3,  # with baby networks can afford to go a bit higher
    max_iters = MAX_TRAINING_ITERS,
    lr_decay_iters = 5000,  # make equal to max_iters usually
    min_lr = 1e-4,  # learning_rate / 10 usually
    beta2 = 0.99,  # make a bit bigger because number of tokens per iter is small

    warmup_iters = 100,  # not super necessary potentially
)

def train_model(model, training_splits, train_config):
    # Load dataset
    num_workers = 0 if DEBUG_MODE else 4

    trajectory_loader = build_trajectory_loader(
        DATA_DIR, training_splits, block_size=n_max_context, batch_size=1,
        device=device, workers=num_workers)
        
    trainer = Trainer(
        model=model,
        train_config=train_config,
        train_loader=trajectory_loader,
        val_loader=trajectory_loader,  # TODO: Create separate validation loader
        device=device
    )

    trainer.train()
    return model, trainer

In [9]:
import dataclasses

def get_model_path(generation_id):
    return MODEL_DIR / f"gen-{generation_id}.pt"

def save_model(model, trainer, generation_id):
    # Save model
    model_path = get_model_path(generation_id)

    checkpoint = {
        'model': model.state_dict(),
        'model_config': dataclasses.asdict(model.config),
        'vocab': action_vocab.to_dict(),
        'iter_num': trainer.iter_num,
        'best_val_loss': trainer.best_val_loss,
        'num_players': game.num_players(state_0),
    }
    torch.save(checkpoint, model_path)
    return model_path

def load_model(generation_id):
    model_path = get_model_path(generation_id)
    loaded_checkpoint = torch.load(model_path)
    loaded_model = ActionHistoryTransformer(
        config=TransformerConfig(**loaded_checkpoint['model_config']),
        action_vocab_size=loaded_checkpoint['vocab']['vocab_size'],
        num_players=loaded_checkpoint['num_players']
    )
    loaded_model.to(device)
    return loaded_model


In [10]:
# Do a single generation of play & train
async def run_generation(model, num_games, num_simulations, generation_id):
    print(f"\n\n## Running generation {generation_id} for config_alias={CONFIG_ALIAS}")
    split_name = f"gen-{generation_id}"
    expected_trajectory_path = DATA_DIR / split_name
    if not expected_trajectory_path.exists():
        print(f"Playing {num_games} games, simulations={num_simulations}, model_size={MODEL_SIZE}")
        results = await play_generation_async(model, num_games=NUM_GAMES, simulations=NUM_SIMULATIONS)
        print_game_stats(results)
        trajectory_path = write_trajectory_dataset(results, action_vocab, generation_id)
        assert trajectory_path == expected_trajectory_path
    else:
        print(f"Loading trajectory from {expected_trajectory_path}")
        print_dataset_stats(expected_trajectory_path, split_name)
        trajectory_path = expected_trajectory_path
        results = None

    model_path = get_model_path(generation_id)
    if not model_path.exists():
        print(f"Training model on {split_name}")
        training_splits = [f"gen-{i}" for i in range(1, generation_id+1)]
        updated_model, trainer = train_model(model, training_splits, train_config)
        save_model(updated_model, trainer, generation_id)
    else:
        print(f"Loading model from {model_path}")
        updated_model = load_model(generation_id)
        print_model_stats(updated_model, config_alias=MODEL_SIZE)

    return results, trajectory_path, updated_model

In [None]:
results_dict = {}
trajectory_paths_dict = {}
model_dict = {0: model_0}

# 10 hours!
for generation_id in range(1, 10):
    current_model = model_dict[generation_id-1]
    results_i, trajectory_path_i, model_i = await run_generation(current_model, num_games=NUM_GAMES, num_simulations=NUM_SIMULATIONS, generation_id=generation_id)
    results_dict[generation_id] = results_i
    trajectory_paths_dict[generation_id] = trajectory_path_i
    model_dict[generation_id] = model_i




## Running generation 1 for config_alias=trajectory_sims-200_games-2000_size-small_train-10000_x1
Playing 2000 games, simulations=200, model_size=small


100%|██████████| 2000/2000 [04:31<00:00,  7.36it/s] 


Winner Stats:
Winner counts: win[1]=74.10% win[2]=25.90%, n=2000
Game Length min: 7, max: 35, mean: 13.93
Winner Stats by initial move:
  a=1: n=254 win[1]=64.17% counts=Counter({1: 163, 2: 91})
  a=2: n=271 win[1]=73.80% counts=Counter({1: 200, 2: 71})
  a=3: n=239 win[1]=79.50% counts=Counter({1: 190, 2: 49})
  a=4: n=296 win[1]=85.14% counts=Counter({1: 252, 2: 44})
  a=5: n=332 win[1]=79.82% counts=Counter({1: 265, 2: 67})
  a=6: n=304 win[1]=75.00% counts=Counter({1: 228, 2: 76})
  a=7: n=304 win[1]=60.53% counts=Counter({1: 184, 2: 120})
Training model on gen-1
num decayed parameter tensors: 19, with 50,880 parameters
num non-decayed parameter tensors: 11, with 298 parameters
using fused AdamW: False
Training epoch 0 of 10
step 0: train loss 0.8803, val loss 0.8661
iter 0: loss 1.2963, time 752.56ms
step 250: train loss 0.7977, val loss 0.8460
saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
step 500: train loss 0.8031, val loss 0.8179
saving checkpoint to /Users/

100%|██████████| 2000/2000 [05:08<00:00,  6.49it/s] 


Winner Stats:
Winner counts: win[1]=34.45% win[2]=65.35%, n=2000
Game Length min: 7, max: 42, mean: 16.69
Winner Stats by initial move:
  a=1: n=352 win[1]=27.27% counts=Counter({2: 255, 1: 96, None: 1})
  a=2: n=403 win[1]=36.48% counts=Counter({2: 253, 1: 147, None: 3})
  a=3: n=213 win[1]=32.39% counts=Counter({2: 144, 1: 69})
  a=4: n=117 win[1]=49.57% counts=Counter({2: 59, 1: 58})
  a=5: n=327 win[1]=40.06% counts=Counter({2: 196, 1: 131})
  a=6: n=335 win[1]=33.73% counts=Counter({2: 222, 1: 113})
  a=7: n=253 win[1]=29.64% counts=Counter({2: 178, 1: 75})
Training model on gen-2
num decayed parameter tensors: 19, with 50,880 parameters
num non-decayed parameter tensors: 11, with 298 parameters
using fused AdamW: False
Training epoch 0 of 10
step 0: train loss 0.9892, val loss 1.0391
iter 0: loss 1.1576, time 757.64ms
step 250: train loss 0.8800, val loss 0.8954
saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
step 500: train loss 0.9098, val loss 0.9666
step 750:

100%|██████████| 2000/2000 [04:55<00:00,  6.77it/s]


Winner Stats:
Winner counts: win[1]=70.40% win[2]=28.85%, n=2000
Game Length min: 7, max: 42, mean: 17.32
Winner Stats by initial move:
  a=1: n=192 win[1]=60.94% counts=Counter({1: 117, 2: 74, None: 1})
  a=2: n=168 win[1]=66.07% counts=Counter({1: 111, 2: 55, None: 2})
  a=3: n=193 win[1]=77.20% counts=Counter({1: 149, 2: 43, None: 1})
  a=4: n= 55 win[1]=81.82% counts=Counter({1: 45, 2: 10})
  a=5: n=279 win[1]=77.06% counts=Counter({1: 215, 2: 58, None: 6})
  a=6: n=549 win[1]=71.58% counts=Counter({1: 393, 2: 154, None: 2})
  a=7: n=564 win[1]=67.02% counts=Counter({1: 378, 2: 183, None: 3})
Training model on gen-3
num decayed parameter tensors: 19, with 50,880 parameters
num non-decayed parameter tensors: 11, with 298 parameters
using fused AdamW: False
Training epoch 0 of 10
step 0: train loss 0.8756, val loss 0.8738
iter 0: loss 1.2976, time 757.87ms
step 250: train loss 0.9521, val loss 0.9340
step 500: train loss 0.8912, val loss 0.8923
step 750: train loss 0.8866, val loss 0

100%|██████████| 2000/2000 [04:42<00:00,  7.09it/s]


Winner Stats:
Winner counts: win[1]=81.60% win[2]=18.20%, n=2000
Game Length min: 7, max: 42, mean: 15.27
Winner Stats by initial move:
  a=1: n=127 win[1]=77.95% counts=Counter({1: 99, 2: 28})
  a=2: n=171 win[1]=75.44% counts=Counter({1: 129, 2: 42})
  a=3: n= 96 win[1]=77.08% counts=Counter({1: 74, 2: 21, None: 1})
  a=4: n=164 win[1]=92.07% counts=Counter({1: 151, 2: 13})
  a=5: n=289 win[1]=89.27% counts=Counter({1: 258, 2: 31})
  a=6: n=309 win[1]=78.32% counts=Counter({1: 242, 2: 66, None: 1})
  a=7: n=844 win[1]=80.45% counts=Counter({1: 679, 2: 163, None: 2})
Training model on gen-4
num decayed parameter tensors: 19, with 50,880 parameters
num non-decayed parameter tensors: 11, with 298 parameters
using fused AdamW: False
Training epoch 0 of 10
step 0: train loss 0.8452, val loss 0.8584
iter 0: loss 1.1197, time 795.09ms
step 250: train loss 0.8651, val loss 0.8691
step 500: train loss 0.9404, val loss 0.9036
step 750: train loss 0.8824, val loss 0.9169
step 1000: train loss 0

100%|██████████| 2000/2000 [05:29<00:00,  6.07it/s]


Winner Stats:
Winner counts: win[1]=79.20% win[2]=20.55%, n=2000
Game Length min: 7, max: 42, mean: 16.35
Winner Stats by initial move:
  a=1: n=211 win[1]=67.77% counts=Counter({1: 143, 2: 67, None: 1})
  a=2: n=281 win[1]=77.58% counts=Counter({1: 218, 2: 62, None: 1})
  a=3: n=136 win[1]=79.41% counts=Counter({1: 108, 2: 28})
  a=5: n=379 win[1]=84.96% counts=Counter({1: 322, 2: 57})
  a=6: n=385 win[1]=80.78% counts=Counter({1: 311, 2: 72, None: 2})
  a=7: n=608 win[1]=79.28% counts=Counter({1: 482, 2: 125, None: 1})
Training model on gen-5
num decayed parameter tensors: 19, with 50,880 parameters
num non-decayed parameter tensors: 11, with 298 parameters
using fused AdamW: False
Training epoch 0 of 10
step 0: train loss 0.8800, val loss 0.8751
iter 0: loss 0.7769, time 700.96ms
step 250: train loss 0.8825, val loss 0.8552
saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
step 500: train loss 0.8574, val loss 0.8914
step 750: train loss 0.8883, val loss 0.8659
step 1

100%|██████████| 2000/2000 [6:08:34<00:00, 11.06s/it]    


Winner Stats:
Winner counts: win[1]=74.45% win[2]=25.55%, n=2000
Game Length min: 7, max: 40, mean: 16.72
Winner Stats by initial move:
  a=1: n=248 win[1]=68.15% counts=Counter({1: 169, 2: 79})
  a=2: n=277 win[1]=74.37% counts=Counter({1: 206, 2: 71})
  a=3: n=194 win[1]=77.32% counts=Counter({1: 150, 2: 44})
  a=4: n=137 win[1]=83.94% counts=Counter({1: 115, 2: 22})
  a=5: n=398 win[1]=79.90% counts=Counter({1: 318, 2: 80})
  a=6: n=320 win[1]=71.88% counts=Counter({1: 230, 2: 90})
  a=7: n=426 win[1]=70.66% counts=Counter({1: 301, 2: 125})
Training model on gen-6
num decayed parameter tensors: 19, with 50,880 parameters
num non-decayed parameter tensors: 11, with 298 parameters
using fused AdamW: False
Training epoch 0 of 10
step 0: train loss 0.8402, val loss 0.8619
iter 0: loss 0.5096, time 676.92ms
step 250: train loss 0.8390, val loss 0.9075
step 500: train loss 0.9353, val loss 0.8651
step 750: train loss 0.9134, val loss 0.8200
saving checkpoint to /Users/rodo/src/rgi3/models

100%|██████████| 2000/2000 [05:37<00:00,  5.92it/s] 


Winner Stats:
Winner counts: win[1]=71.80% win[2]=28.10%, n=2000
Game Length min: 7, max: 42, mean: 16.48
Winner Stats by initial move:
  a=1: n=273 win[1]=65.57% counts=Counter({1: 179, 2: 94})
  a=2: n=262 win[1]=70.23% counts=Counter({1: 184, 2: 78})
  a=3: n=151 win[1]=74.83% counts=Counter({1: 113, 2: 38})
  a=4: n=194 win[1]=85.57% counts=Counter({1: 166, 2: 28})
  a=5: n=424 win[1]=72.64% counts=Counter({1: 308, 2: 116})
  a=6: n=315 win[1]=73.33% counts=Counter({1: 231, 2: 83, None: 1})
  a=7: n=381 win[1]=66.93% counts=Counter({1: 255, 2: 125, None: 1})
Training model on gen-7
num decayed parameter tensors: 19, with 50,880 parameters
num non-decayed parameter tensors: 11, with 298 parameters
using fused AdamW: False
Training epoch 0 of 10
step 0: train loss 0.9137, val loss 0.9247
iter 0: loss 0.4834, time 748.08ms
step 250: train loss 0.9224, val loss 0.8732
saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
step 500: train loss 0.8249, val loss 0.7689
saving ch

100%|██████████| 2000/2000 [35:55<00:00,  1.08s/it]   


Winner Stats:
Winner counts: win[1]=69.65% win[2]=30.15%, n=2000
Game Length min: 7, max: 42, mean: 17.11
Winner Stats by initial move:
  a=1: n=322 win[1]=62.11% counts=Counter({1: 200, 2: 120, None: 2})
  a=2: n=306 win[1]=71.90% counts=Counter({1: 220, 2: 85, None: 1})
  a=3: n=219 win[1]=68.49% counts=Counter({1: 150, 2: 69})
  a=4: n=160 win[1]=85.62% counts=Counter({1: 137, 2: 23})
  a=5: n=451 win[1]=71.18% counts=Counter({1: 321, 2: 129, None: 1})
  a=6: n=327 win[1]=69.42% counts=Counter({1: 227, 2: 100})
  a=7: n=215 win[1]=64.19% counts=Counter({1: 138, 2: 77})
Training model on gen-8
num decayed parameter tensors: 19, with 50,880 parameters
num non-decayed parameter tensors: 11, with 298 parameters
using fused AdamW: False
Training epoch 0 of 10
step 0: train loss 0.9348, val loss 0.8634
iter 0: loss 0.3115, time 649.77ms
step 250: train loss 0.9272, val loss 0.8559
saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
step 500: train loss 0.9016, val loss 0.8975

100%|██████████| 2000/2000 [09:21<00:00,  3.56it/s]  


Winner Stats:
Winner counts: win[1]=53.30% win[2]=46.60%, n=2000
Game Length min: 7, max: 42, mean: 16.93
Winner Stats by initial move:
  a=1: n=341 win[1]=48.09% counts=Counter({2: 176, 1: 164, None: 1})
  a=2: n=334 win[1]=50.90% counts=Counter({1: 170, 2: 163, None: 1})
  a=3: n=268 win[1]=55.60% counts=Counter({1: 149, 2: 119})
  a=4: n= 92 win[1]=92.39% counts=Counter({1: 85, 2: 7})
  a=5: n=447 win[1]=54.81% counts=Counter({1: 245, 2: 202})
  a=6: n=238 win[1]=55.04% counts=Counter({1: 131, 2: 107})
  a=7: n=280 win[1]=43.57% counts=Counter({2: 158, 1: 122})
Training model on gen-9
num decayed parameter tensors: 19, with 50,880 parameters
num non-decayed parameter tensors: 11, with 298 parameters
using fused AdamW: False
Training epoch 0 of 10
step 0: train loss 0.8265, val loss 0.9363
iter 0: loss 0.4596, time 912.02ms
step 250: train loss 0.9797, val loss 0.9409
step 500: train loss 0.9528, val loss 0.9221
saving checkpoint to /Users/rodo/src/rgi3/models/connect4-e2e/v1
step 75

# Debug


In [None]:
# Play single game
result = await play_generation_async(current_model, num_games=1, simulations=NUM_SIMULATIONS)

In [68]:
# Inspect training data
td_array = [TrajectoryDataset(DATA_DIR, f"gen-{generation_id}", block_size=n_max_context) for generation_id in range(1, 10)]

In [69]:
# [td for td in td_array]
unrolled = [(generation+1, d) for generation, td in enumerate(td_array) for d in td]

# gen, d = unrolled[0], 
# d.action[:2]
# d.value[0]

dd = defaultdict(lambda: defaultdict(lambda: [0, 0]))

for gen, d in unrolled:
    for g in [0, gen]:
        dd[tuple(tuple(d.action[:0].tolist()))][g][int(d.value[0][0])] += 1
        dd[tuple(tuple(d.action[:1].tolist()))][g][int(d.value[0][0])] += 1
        dd[tuple(tuple(d.action[:2].tolist()))][g][int(d.value[0][0])] += 1
        dd[tuple(tuple(d.action[:3].tolist()))][g][int(d.value[0][0])] += 1
        dd[tuple(tuple(d.action[:4].tolist()))][g][int(d.value[0][0])] += 1

print(f"len(dd) = {len(dd)}")


len(dd) = 2644


In [None]:
## Someting is borked? Player1 win percent should be much higher??
prefix = (3,)
for gen, counts in dd[prefix].items():
    print(f"gen={gen}: {counts}, win_pct={100*counts[0]/sum(counts):.2f}%, sum={sum(counts)}")

gen=0: [557, 1152], win_pct=32.59%, sum=1709
gen=1: [49, 190], win_pct=20.50%, sum=239
gen=2: [144, 69], win_pct=67.61%, sum=213
gen=3: [44, 149], win_pct=22.80%, sum=193
gen=4: [22, 74], win_pct=22.92%, sum=96
gen=5: [28, 108], win_pct=20.59%, sum=136
gen=6: [44, 150], win_pct=22.68%, sum=194
gen=7: [38, 113], win_pct=25.17%, sum=151
gen=8: [69, 150], win_pct=31.51%, sum=219
gen=9: [119, 149], win_pct=44.40%, sum=268
