# Step-by-step run of AlphaZero self-play & training.


In [1]:
import asyncio
import os
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch

# Game and players
from rgi.rgizero.experiment import ExperimentRunner, ExperimentConfig
from rgi.rgizero.data.trajectory_dataset import Vocab, print_dataset_stats, TrajectoryDataset
from rgi.rgizero.evaluators import ActionHistoryTransformerEvaluator, AsyncNetworkEvaluator
from rgi.rgizero.models.tuner import create_random_model

import notebook_utils
from notebook_utils import reload_local_modules

device = notebook_utils.detect_device()

## Disable for debugger stability?
# # Allow asyncio to work with jupyter notebook
# import nest_asyncio
# nest_asyncio.apply()

# Increase numpy print width
np.set_printoptions(linewidth=300)

%load_ext line_profiler

transform_config_fields: {'dropout', 'n_layer', 'bias', 'n_max_context', 'n_embd', 'n_head'}
train_config_fields: {'weight_decay', 'grad_clip', 'wandb_log', 'device', 'eval_iters', 'compile', 'eval_only', 'dtype', 'batch_size', 'log_interval', 'learning_rate', 'eval_interval', 'model_name', 'beta1', 'gradient_accumulation_steps', 'lr_decay_iters', 'warmup_iters', 'model_version', 'max_epochs', 'min_lr', 'decay_lr', 'max_iters', 'beta2', 'always_save_checkpoint'}
Detected device: mps


In [2]:
# Toggle for the (slow) self-play + training loop further down.
RUN_GENERATIONS = True

# Experiment definition. The parent experiment appears to supply already-existing
# generations (see the "Using forked data/model" lines in the loop output).
experiment_config = ExperimentConfig(
    experiment_name='smoketest-e2e-v2',
    parent_experiment_name='smoketest-e2e',
    game_name='connect4',
    seed=42,
    # Self-play
    num_generations=5,
    num_games_per_gen=10_000,
    num_simulations=50,
    # Training
    model_size="tiny",
    train_batch_size=10,
    max_training_epochs=2,
)


## Step 1: Set up game and experiment runner


In [3]:
from rgi.rgizero.data.trajectory_dataset import Vocab
from rgi.rgizero.common import TOKENS

# The runner exposes the game, the action vocabulary and this experiment's directories.
experiment_base_dir = Path.cwd().parent / 'experiments'
experiment_runner = ExperimentRunner(experiment_config, experiment_base_dir)

# Short aliases used throughout the notebook.
game, action_vocab, n_max_context = (
    experiment_runner.game,
    experiment_runner.action_vocab,
    experiment_runner.n_max_context,
)
DATA_DIR, MODEL_DIR = experiment_runner.data_dir, experiment_runner.models_dir

print('✅ Runner initialized')
print(f'Game: {experiment_runner.config.game_name}, Players: {experiment_runner.num_players}, Actions: {list(game.base_game.all_actions())}')
print('Data dir: ', DATA_DIR)
print('Model dir: ', MODEL_DIR)


✅ Runner initialized
Game: connect4, Players: 2, Actions: [1, 2, 3, 4, 5, 6, 7]
Data dir:  /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data
Model dir:  /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/models


## Step 2: Create random generation_0 model


In [4]:
# Create (or reload) the random generation-0 model and start the loop from it.
current_model = model_0 = experiment_runner.initialize()


Starting Experiment: smoketest-e2e-v2
Loading existing Gen 0 model.


In [None]:
results_dict = {}
trajectory_paths_dict = {}
model_dict = {0: model_0}

current_model = model_dict[0]
if RUN_GENERATIONS:
    for generation_id in range(1, experiment_config.num_generations+1):
        current_model = await experiment_runner.run_generation_step_async(generation_id, current_model)
        dataset_path = experiment_runner.get_trajectory_path(generation_id)
        
        # print stats for visibility
        print_dataset_stats(dataset_path, f'gen-{generation_id}', n_max_context, action_vocab, model=current_model, game=game)
        
        model_dict[generation_id] = current_model

# 10m to play 2x10k generations... probabilities still very wrong.
# Evaluation time: 0.015 seconds, size=574, eval-per-second=37837.60, total-batches=6000, mean-eval-per-second=94963.99, mean-time-per-batch=0.010, mean-batch-size=990.34

# >>> log(2) + log(7) -> 2.6390573296152584
## Model doesn't seem to improve loss at all?
# step.   0: losses: train:2.5971, train_policy_loss:1.9146, train_value_loss:0.6825, val:2.5972, val_policy_loss:1.9147, val_value_loss:0.6825
# step 1000: losses: train:2.6036, train_policy_loss:1.9122, train_value_loss:0.6914, val:2.6050, val_policy_loss:1.9132, val_value_loss:0.6917
# step 2000: losses: train:2.6056, train_policy_loss:1.9119, train_value_loss:0.6937, val:2.6056, val_policy_loss:1.9123, val_value_loss:0.6933
# iter    0/1170/5000: loss 2.5699, policy_loss:1.9129, value_loss:0.6570, time 5.18s, iter_time: 0.00ms
# iter 1000/1170/5000: loss 2.5996, policy_loss:1.9068, value_loss:0.6928, time 1.96s, iter_time: 1957.74ms
# iter 2339/2340/5000: loss 2.6014, policy_loss:1.9131, value_loss:0.6884, time 0.01s, iter_time: 14.61ms




=== Generation 1 ===
Using forked data for gen 1 from /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-1
Dataset for gen 1 exists at /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-1. Skipping play.
Using forked model for gen 1 from /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/models/gen-1.pt
Model for gen 1 exists at /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/models/gen-1.pt. Loading.
Using forked model for gen 1 from /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/models/gen-1.pt
Using forked data for gen 1 from /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-1
Dataset Stats:
  Trajectories: 1000
  Total actions: 14552
  Avg trajectory length: 14.55
Prefix Stats:
actions=(): 1000 win=618 loss=382 draw=0 win1%=61.80 model-win1%=0.46
actions=(1,): 157 win=80 loss=77 draw=0 win1%=50.96 model-win1%=0.60
actions=(2,): 124 win=79 loss=45 draw=0 win1%=63.71 model-win1%=0.59
actions=(3,): 113 win=71 loss=42 draw=0 win1%=62.83 model-w

# Tune Model (initial)


In [None]:
reload_local_modules(verbose=False)

state_0 = game.initial_state()
# Follow the experiment config instead of hard-coding a second copy of this value.
NUM_GENERATIONS = experiment_config.num_generations
LEARNING_RATE = 0.1

# Parameters which will never be used for tuning.
fixed_params = dict(
    model_name='c4-smoketest',
    model_version='0.1',
    num_players=game.num_players(state_0),
    vocab_size=action_vocab.vocab_size,
    num_generations=NUM_GENERATIONS,
    data_dir=DATA_DIR,

    eval_iters=200,
    log_interval=1000,
    eval_interval=10_000,

    device=device,
)

# Starting point for the search.
initial_params = dict(
    n_layer=2,
    n_head=2,
    n_embd=8,  # tiny model

    n_max_context=n_max_context,
    batch_size=32,
    gradient_accumulation_steps=1,

    max_iters=100,
    max_epochs=1_000_000,  # Make max_epochs high, rely on max_iters to stop.

    learning_rate=LEARNING_RATE,
    decay_lr=True,  # whether to decay the learning rate
    lr_decay_iters=100,  # usually equal to max_iters
    min_lr=LEARNING_RATE / 10,  # usually learning_rate / 10
    warmup_iters=0,  # not super necessary potentially

    weight_decay=1e-1,
    beta1=0.9,
    beta2=0.95,
    grad_clip=1.0,  # clip gradients at this value, or disable if == 0.0

    dtype="float16",

    dropout=0.0,
    bias=False,  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
)

# Candidate values for each tunable parameter.
tune_options = dict(
    n_layer=[1, 2, 3, 4, 5, 6, 8, 10, 12, 16, 32],
    # n_head is derived in computed_tune_options so that n_embd % n_head == 0.
    n_embd=[8, 16, 32, 64, 128, 256, 512, 1024, 2048],

    n_max_context=[initial_params['n_max_context']],
    batch_size=[16, 32, 64, 128, 256, 512, 1024],
    # TODO: We only support 1 for now. This fails if we don't have an exact multiple of the batch size per epoch.
    gradient_accumulation_steps=[1],

    max_iters=[100, 300, 1_000, 3_000, 5_000, 10_000, 30_000, 100_000, 300_000],
    max_epochs=[1_000_000],  # Make max_epochs high, rely on max_iters to stop.

    learning_rate=[0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0],
    decay_lr=[False, True],

    # TODO: What is a sensible range here?
    beta1=[0.90, 0.95, 0.99],
    beta2=[0.95, 0.98, 0.99],

    weight_decay=[0.01, 0.05, 0.1, 0.2],
    # Previously `[0,0, 1.0]`: a typo that produced two integer zeros.
    # 0.0 disables clipping, 1.0 clips gradients at 1.0.
    grad_clip=[0.0, 1.0],

    dtype=["bfloat16", "float16"],
    dropout=[0.0, 0.01, 0.02, 0.05, 0.1],
    bias=[True, False],
)

# Options derived from the values already chosen for other parameters.
# NOTE(review): warmup_iters reads lr_decay_iters, so this presumably relies on the Tuner
# evaluating entries in insertion order — confirm in tuner.py.
_n_head_options = [1, 2, 4, 8, 16, 32]
computed_tune_options = dict(
    min_lr=lambda opt: [opt['learning_rate'] / 10],
    lr_decay_iters=lambda opt: [opt['max_iters']],
    warmup_iters=lambda opt: [x for x in [0, 100, 500, 1000] if x < opt['lr_decay_iters']] if opt['decay_lr'] else [0],
    n_head=lambda opt: [n for n in _n_head_options if opt['n_embd'] % n == 0],
)

TUNER_VERSION = "0.0.4-smoketest"

from rgi.rgizero.models.tuner import Tuner

tuner = Tuner(
    fixed_params=fixed_params.copy(),
    initial_params=initial_params.copy(),
    tune_options=tune_options.copy(),
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=1.00)
tuner.autotune_smart()


In [None]:
from rgi.rgizero.models.tuner import clear_failures_from_cache_file

# Remove cached failed runs from the tuner's result cache (presumably so they are retried).
# NOTE(review): previously hard-coded to 'result_cache-v0.0.2.json', which does not match
# TUNER_VERSION. This assumes the Tuner names its cache file
# `result_cache-v{cache_version}.json` — confirm against tuner.py.
TUNER_CACHE_FILE = f'result_cache-v{TUNER_VERSION}.json'
clear_failures_from_cache_file(TUNER_CACHE_FILE)

In [None]:
# Re-run the smart autotune with target_improvement_per_minute=0.1.
tuner = Tuner(
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.1,
    fixed_params=dict(fixed_params),
    initial_params=dict(initial_params),
    tune_options=dict(tune_options),
    computed_tune_options=dict(computed_tune_options),
)
tuner.autotune_smart()


# Sanity check models


In [None]:
# Deliberate stop: "Run All" halts here; run the following cells manually.
raise NotImplementedError("Skip this...")

In [None]:
# Inspect training data
td_array = [TrajectoryDataset(DATA_DIR, f"gen-{generation_id}", block_size=n_max_context) for generation_id in range(1, NUM_GENERATIONS+1)]

In [None]:
# Sum each trajectory's `d.value[0]` (presumably the per-player outcome) into buckets keyed
# by action prefix, both across all generations ('*') and per generation.
unrolled = [(generation + 1, d) for generation, td in enumerate(td_array) for d in td]

MAX_PREFIX_LEN = 10  # track prefixes of length 1..MAX_PREFIX_LEN

dd = defaultdict(lambda: defaultdict(lambda: torch.tensor([0., 0.])))

for gen, d in unrolled:
    for prefix_len in range(1, MAX_PREFIX_LEN + 1):
        prefix = tuple(d.action[:prefix_len].tolist())
        if len(prefix) < prefix_len:
            # The trajectory is shorter than prefix_len. Every longer slice would repeat this
            # same full-length key and count the game several times, so stop here.
            break
        for g in ('*', gen):
            dd[prefix][g] += d.value[0]

print(f"len(dd) = {len(dd)}")


In [None]:
def eval_prefix(model, game, prefix):
    """Return `model`'s evaluation of the position reached by playing `prefix`.

    Each call builds a fresh ActionHistoryTransformerEvaluator from the notebook
    globals `device`, `n_max_context` and `action_vocab`. It then replays `prefix`
    from `game.initial_state()` and queries the evaluator with the resulting
    state's legal actions.
    """
    evaluator = ActionHistoryTransformerEvaluator(
        model, device=device, block_size=n_max_context, vocab=action_vocab
    )

    state = game.initial_state()
    for move in prefix:
        state = game.next_state(state, move)

    return evaluator.evaluate(game, state, game.legal_actions(state))


In [None]:
# FIXME: Something is borked? Player1 win percent should be much higher??
def compare_model_vs_data(model, game, dd, top_k=20):
    """Print data outcome statistics next to the model's value estimate for common prefixes.

    Selects the `top_k` prefixes of `dd` with the largest all-generations ('*') total.
    For each prefix, in sorted order, it prints:
    - the aggregated data counts;
    - the model's `player_values`, rescaled with (v + 1) / 2.

    Args:
        model: network passed to `eval_prefix`.
        game: game used to replay each prefix.
        dd: mapping prefix tuple -> {generation or '*': tensor of summed values}
            (e.g. the notebook's `dd`).
        top_k: number of most frequent prefixes to show (default 20, the previous
            hard-coded value).
    """
    most_common = sorted(dd.items(), key=lambda kv: kv[1]['*'].sum(), reverse=True)[:top_k]
    prefix_list = sorted(prefix for prefix, _ in most_common)

    for prefix in prefix_list:
        print(f"\nprefix={prefix}")
        counts = dd[prefix]['*']
        total = counts.sum()
        print(f"gen=*: {counts}, win_pct={100*counts[0]/total:.2f}%, sum={total}")
        # NOTE(review): drops prefix[0], presumably a start-of-sequence token — confirm the
        # remaining entries are game actions that `game.next_state` accepts.
        actions = prefix[1:]
        eval_result = eval_prefix(model, game, actions)
        # Assumes player_values lie in [-1, 1]; rescale to [0, 1] for comparison with win_pct.
        print(f'player_probs={(eval_result.player_values+1)/2}')


compare_model_vs_data(current_model, game, dd)


In [None]:
# Pick models to inspect. This previously called `create_random_model(model_config, ...)`
# and `load_model(1)`, but neither `model_config` nor `load_model` is defined in this
# notebook (NameError). Reuse the models produced by the generation loop instead:
# gen 0 is the runner's random initial model, gen 1 the first trained model.
model_0 = model_dict[0]
if RUN_GENERATIONS:
    model_1 = model_dict[1]


In [None]:
# Inspect the generation-0 model: raw action-embedding weights, then data vs. model values.
print("\n\n### Model 0")
print(model_0.action_embedding.weight)
compare_model_vs_data(model_0, game, dd)

In [None]:
# Same inspection for the generation-1 model, which is only set when generations were run.
if RUN_GENERATIONS:
    print("\n\n### Model 1")
    print(model_1.action_embedding.weight)
    compare_model_vs_data(model_1, game, dd)

## Run tournament to calculate Elo ratings


In [None]:
import asyncio
import numpy as np
from contextlib import asynccontextmanager
from rgi.rgizero.tournament import Tournament
from rgi.rgizero.players.alphazero import AlphazeroPlayer
from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformerEvaluator, AsyncNetworkEvaluator

@asynccontextmanager
async def create_player_factory(model, simulations, game, device, block_size, action_vocab, max_batch_size):
    """
    Creates a shared evaluator and returns a factory function that produces 
    new AlphazeroPlayer instances using that shared evaluator.
    """
    # 1. Setup the shared evaluator
    serial_evaluator = ActionHistoryTransformerEvaluator(
        model, 
        device=device, 
        block_size=n_max_context, 
        vocab=action_vocab
    )
    async_evaluator = AsyncNetworkEvaluator(
        base_evaluator=serial_evaluator, 
        max_batch_size=max_batch_size, 
        verbose=False
    )
    
    # 2. Start the evaluator background task
    await async_evaluator.start()
    
    try:
        # 3. Define the factory. This is called by Tournament for every game.
        # It creates a NEW player instance but uses the SHARED async_evaluator.
        def player_factory():
            # Create a fresh RNG for each game/player instance
            rng = np.random.default_rng(np.random.randint(0, 2**31))
            return AlphazeroPlayer(
                game, 
                async_evaluator, 
                rng=rng, 
                add_noise=True, 
                simulations=simulations
            )
            
        yield player_factory
        
    finally:
        # 4. Cleanup
        await async_evaluator.stop()

async def run_tournament_async():
    # Use async with to manage the lifecycle of the evaluators
    async with (
        # create_player_factory(model_dict[0], 100, game, device, block_size, action_vocab, 10) as factory_gen0_100,
        # create_player_factory(model_dict[1], 100, game, device, block_size, action_vocab, 10) as factory_gen1_100,
        # create_player_factory(model_dict[2], 100, game, device, block_size, action_vocab, 10) as factory_gen2_100,
        # create_player_factory(model_dict[3], 100, game, device, block_size, action_vocab, 10) as factory_gen3_100,
        # create_player_factory(model_dict[4], 100, game, device, block_size, action_vocab, 10) as factory_gen4_100,
        # create_player_factory(model_dict[5], 100, game, device, block_size, action_vocab, 10) as factory_gen5_100,
        # create_player_factory(model_dict[10], 100, game, device, block_size, action_vocab, 10) as factory_gen6_100,
        # create_player_factory(model_dict[15], 100, game, device, block_size, action_vocab, 10) as factory_gen7_100,
        # create_player_factory(model_dict[20], 100, game, device, block_size, action_vocab, 10) as factory_gen8_100,

        create_player_factory(model_dict[0], 200, game, device, block_size, action_vocab, 10) as factory_gen0_200,
        #create_player_factory(model_dict[1], 200, game, device, block_size, action_vocab, 10) as factory_gen1_200,
        #create_player_factory(model_dict[2], 200, game, device, block_size, action_vocab, 10) as factory_gen2_200,
        #create_player_factory(model_dict[3], 200, game, device, block_size, action_vocab, 10) as factory_gen3_200,
        #create_player_factory(model_dict[4], 200, game, device, block_size, action_vocab, 10) as factory_gen4_200,
        create_player_factory(model_dict[5], 200, game, device, block_size, action_vocab, 10) as factory_gen5_200,
        #create_player_factory(model_dict[10], 200, game, device, block_size, action_vocab, 10) as factory_gen10_200,
        #create_player_factory(model_dict[15], 200, game, device, block_size, action_vocab, 10) as factory_gen15_200,
        create_player_factory(model_dict[20], 200, game, device, block_size, action_vocab, 10) as factory_gen20_200,
        ):
        
        # The dictionary now maps names to FACTORIES (Callables), not Player instances
        player_factories = {
            # "factory_gen0_100": factory_gen0_100,
            # "factory_gen1_100": factory_gen1_100,
            # "factory_gen2_100": factory_gen2_100,
            # "factory_gen3_100": factory_gen3_100,
            # "factory_gen4_100": factory_gen4_100,
            # "factory_gen5_100": factory_gen5_100,
            # "factory_gen6_100": factory_gen6_100,
            # "factory_gen7_100": factory_gen7_100,

            "factory_gen0_200": factory_gen0_200,
            #"factory_gen1_200": factory_gen1_200,
            #"factory_gen2_200": factory_gen2_200,
            #"factory_gen3_200": factory_gen3_200,
            #"factory_gen4_200": factory_gen4_200,
            "factory_gen5_200": factory_gen5_200,
            #"factory_gen10_200": factory_gen10_200,
            #"factory_gen15_200": factory_gen15_200,
            "factory_gen20_200": factory_gen20_200,
        }
        
        tournament = Tournament(game, player_factories, initial_elo=1000)
        
        print("Running tournament...")
        # await tournament.run(num_games=1_000, concurrent_games=2000)
        await tournament.run(num_games=100, concurrent_games=2000)
        tournament.print_standings()

# RUN_TOURNAMENT = True
if RUN_TOURNAMENT:
    await run_tournament_async()

# Running tournament...
# Tournament Progress: 100%|██████████| 10000/10000 [1:25:59<00:00,  1.94it/s]

# Tournament Standings:
# Rank  Player               ELO        Games    W-L-D          
# -----------------------------------------------------------------
# 1     factory_gen6_200     1140.5     1247     827-419-1      
# 2     factory_gen2_200     1100.1     1251     693-554-4      
# 3     factory_gen5_100     1074.4     1251     598-652-1      
# 4     factory_gen3_200     1029.1     1252     674-573-5      
# 5     factory_gen4_200     1027.0     1248     711-536-1      
# 6     factory_gen0_200     1020.0     1254     444-810-0      
# 7     factory_gen5_200     990.2      1248     742-502-4      
# 8     factory_gen7_100     987.5      1250     650-597-3      
# 9     factory_gen7_200     979.2      1248     768-476-4      
# 10    factory_gen2_100     974.0      1249     522-723-4      
# 11    factory_gen6_100     966.6      1248     684-564-0      
# 12    factory_gen4_100     964.2      1251     557-693-1      
# 13    factory_gen1_100     962.5      1252     547-705-0      
# 14    factory_gen3_100     947.0      1251     528-723-0      
# 15    factory_gen1_200     941.1      1252     630-620-2      
# 16    factory_gen0_100     896.5      1248     410-838-0     


## 20 generations.
# Running tournament...
# Tournament Progress: 100%|██████████| 1000/1000 [08:35<00:00,  1.94it/s]

# Tournament Standings:
# Rank  Player               ELO        Games    W-L-D          
# -----------------------------------------------------------------
# 1     factory_gen10_200    1114.2     333      212-120-1      
# 2     factory_gen2_200     1032.6     333      190-141-2      
# 3     factory_gen1_200     1003.9     334      159-175-0      
# 4     factory_gen20_200    1000.9     335      171-164-0      
# 5     factory_gen5_200     974.6      331      183-146-2      
# 6     factory_gen0_200     873.8      334      82-251-1  

# Tune Model (continued)


In [None]:
# Continue tuning with target_improvement_per_minute=0.01.
tuner = Tuner(
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.01,
    fixed_params=dict(fixed_params),
    initial_params=dict(initial_params),
    tune_options=dict(tune_options),
    computed_tune_options=dict(computed_tune_options),
)
tuner.autotune_smart()

# Output from an earlier run:
# Using initial model as baseline.
# ## Initial Model, loss=2.1298508644104004 elapsed=171.78943705558777s
# ## Searching generation 0 with 22 candidates, including ['bias: False -> True', 'learning_rate: 0.005 -> 0.002', 'learning_rate: 0.005 -> 0.002', 'dtype: bfloat16 -> float16', 'weight_decay: 0.1 -> 0.2']
# ## improved: False, loss=2.1332 elapsed=178.64s, mutation bias: False -> True
# ## improved: False, loss=2.1395 elapsed=172.03s, mutation learning_rate: 0.005 -> 0.002
# ## improved: False, loss=2.1395 elapsed=172.03s, mutation learning_rate: 0.005 -> 0.002


In [None]:
reload_local_modules(verbose=False)

# Tune with target_improvement_per_minute=0.001.
tuner = Tuner(
    fixed_params=fixed_params.copy(),
    initial_params=initial_params.copy(),
    tune_options=tune_options.copy(), 
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.001)
tuner.autotune_smart()
# Build a standalone tiny model + evaluator for manual experiments.
# NOTE(review): this import rebinds ActionHistoryTransformerEvaluator from a different module
# than the top-of-notebook import (rgi.rgizero.evaluators) — confirm they are the same class.
from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformer, ActionHistoryTransformerEvaluator
from rgi.rgizero.models.transformer import TransformerConfig

# NOTE(review): n_max_context=100 is hard-coded rather than using the notebook's
# `n_max_context`, and the evaluator below uses block_size=5, far shorter than the model's
# context. Confirm both values are intentional.
tiny_config: TransformerConfig = TransformerConfig(n_max_context=100, n_layer=2, n_head=2, n_embd=8)
tiny_model = ActionHistoryTransformer(config=tiny_config, action_vocab_size=action_vocab.vocab_size, num_players=game.num_players(state_0))
tiny_model.to(device)
tiny_evaluator = ActionHistoryTransformerEvaluator(tiny_model, device=device, block_size=5, vocab=action_vocab)


In [None]:
# Tune with target_improvement_per_minute=0.0001.
tuner = Tuner(
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.0001,
    fixed_params=dict(fixed_params),
    initial_params=dict(initial_params),
    tune_options=dict(tune_options),
    computed_tune_options=dict(computed_tune_options),
)
tuner.autotune_smart()


In [None]:
reload_local_modules(verbose=False)

# Tune with target_improvement_per_minute=0.0.
tuner = Tuner(
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.0,
    fixed_params=dict(fixed_params),
    initial_params=dict(initial_params),
    tune_options=dict(tune_options),
    computed_tune_options=dict(computed_tune_options),
)
tuner.autotune_smart()

## Debug best model

In [None]:
reload_local_modules(verbose=False)

# Load the tuner's best model and compare its value estimates with the self-play data in `dd`.
best_model = tuner.load_best_model()
compare_model_vs_data(best_model, game, dd)


In [None]:
from pprint import pprint

# Summary of the best tuning result so far.
# Round to whole seconds *before* splitting into minutes/seconds. The previous
# `int(x)//60` + `x%60:.0f` could print e.g. "2m60s" for 179.6s.
best_minutes, best_seconds = divmod(int(round(tuner.best_loss_elapsed)), 60)
print(f'tuner.best_loss={tuner.best_loss}')
print(f'tuner.best_loss_elapsed={best_minutes}m{best_seconds}s')
pprint(tuner.best_params)
# best_params = tuner.initial_params


## Print tuner stats


In [None]:
reload_local_modules(verbose=False)

# Rebuild the tuner without starting a search, and summarize hyper-parameter effects
# from the cached results.
tuner = Tuner(
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.001,
    fixed_params=dict(fixed_params),
    initial_params=dict(initial_params),
    tune_options=dict(tune_options),
    computed_tune_options=dict(computed_tune_options),
)
tuner_stats = tuner.print_hparam_stats()

In [None]:
# Display the statistics returned by print_hparam_stats().
tuner_stats

In [None]:
# Hyper-parameter entries ranked by `mean_val_delta` (largest first); NaN deltas are skipped.
mean_delta_rows = sorted(
    (
        (stats['mean_val_delta'], name, stats['mean_val_1'], stats['mean_val_2'])
        for name, stats in tuner_stats.items()
        if not np.isnan(stats['mean_val_delta'])
    ),
    reverse=True,
)
for x in mean_delta_rows:
    print(x)


In [None]:
# Hyper-parameter entries ranked by `std_val_delta` (largest first); NaN deltas are skipped.
std_delta_rows = sorted(
    (
        (stats['std_val_delta'], name, stats['mean_val_1'], stats['mean_val_2'])
        for name, stats in tuner_stats.items()
        if not np.isnan(stats['std_val_delta'])
    ),
    reverse=True,
)
for x in std_delta_rows:
    print(x)


## Debug Convergence

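Synthetic sanity check: train on a toy 2-move game whose "winner" is a deterministic function of the moves played. The reward is the mean of the actions played, normalized by the number of actions. This verifies that the value head and the training loop can learn a simple pattern.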


In [None]:
# Deliberate stop: "Run All" halts here; the convergence-debug cells below are run manually.
raise NotImplementedError("xxx STOP HERE xxx")

In [None]:
state_0 = game.initial_state()
# NOTE(review): the runner-init cell lists actions via `game.base_game.all_actions()`;
# confirm that `game` itself also exposes `all_actions()`.
all_actions_0 = game.all_actions()

print(all_actions_0)


In [None]:
import random

def play_random_game_with_fake_reward(game, max_actions, rng=None) -> dict:
    """Play uniformly random legal moves and score the game with a synthetic reward.

    Used as a convergence sanity check: the outcome is a deterministic function of
    the moves played, so a working value head should be able to learn it.

    Args:
        game: object exposing initial_state / is_terminal / legal_actions /
            next_state / all_actions.
        max_actions: stop after this many moves even if the game is not over.
        rng: optional ``random.Random`` for reproducible games. Defaults to the
            module-level ``random`` functions, as before.

    Returns:
        A dict with these keys:
        - ``winner``: 1 or 2.
        - ``rewards``: np.array ``[fake_reward, 1 - fake_reward]``.
        - ``action_history``: the moves played.
        - ``legal_policies``: a uniform policy over the legal moves at each step.
        - ``final_state``: the state after the last move.
        - ``legal_action_idx``: per step, indices of the legal moves in
          ``game.all_actions()``.

    Raises:
        ValueError: if no move was played (terminal initial state or max_actions <= 0).
    """
    chooser = rng if rng is not None else random

    state = game.initial_state()
    action_history = []
    legal_policies = []
    legal_action_idx_list = []

    all_actions = game.all_actions()
    all_action_idx_map = {action: idx for idx, action in enumerate(all_actions)}

    while not game.is_terminal(state) and len(action_history) < max_actions:
        legal_actions = game.legal_actions(state)
        action = legal_actions[chooser.randrange(len(legal_actions))]

        action_history.append(action)
        legal_policies.append(np.ones(len(legal_actions)) / len(legal_actions))
        legal_action_idx_list.append(np.array([all_action_idx_map[a] for a in legal_actions]))

        state = game.next_state(state, action)

    if not action_history:
        raise ValueError("No actions were played; cannot compute a fake reward.")

    # Synthetic outcome: mean action scaled by the size of the action space. This assumes
    # integer actions 1..N (as in connect4), so the reward lies in (0, 1]. The original
    # divided by the *last step's* legal-move count, which made the scale depend on an
    # arbitrary position and crashed when no move was played.
    fake_reward = np.mean(action_history) / len(all_actions)
    rewards = np.array([fake_reward, 1.0 - fake_reward])
    winner = 1 if fake_reward >= 0.5 else 2

    return {
        "winner": winner,
        "rewards": rewards,
        "action_history": action_history,
        "legal_policies": legal_policies,
        "final_state": state,
        "legal_action_idx": legal_action_idx_list,
    }

In [None]:
# Smoke-test a single 2-move fake game and display the resulting dict.
play_random_game_with_fake_reward(game, max_actions=2)

In [None]:
# Generate 100k two-move fake games for the convergence check.
results = [play_random_game_with_fake_reward(game, max_actions=2) for _ in range(100_000)]
# NOTE(review): `print_game_stats` is not defined or imported anywhere in this notebook;
# it looks like a leftover from an earlier version, so import or replace it before running.
print_game_stats(results)


In [None]:
fake_gen_name = "fake-0"
# NOTE(review): `write_trajectory_dataset` is not defined or imported in this notebook — confirm
# the current helper name. Also check which split name it writes: the training cell below
# reads split f'gen-{fake_gen_name}'.
trajectory_path = write_trajectory_dataset(results, action_vocab, fake_gen_name)


In [None]:
# Train a fresh model on the fake dataset to check that training can fit a trivial target.
# NOTE(review): `model_config_dict`, `train_model`, `train_config` and `save_model` are not
# defined or imported anywhere in this notebook; this cell predates the ExperimentRunner API.
# fake_model_config = model_config_dict[MODEL_SIZE]
fake_model_config = model_config_dict["large"]
fake_model = create_random_model(fake_model_config, action_vocab_size=action_vocab.vocab_size, num_players=game.num_players(state_0), seed=42, device=device)

training_splits = [f'gen-{fake_gen_name}']
fake_model, fake_trainer = train_model(fake_model, training_splits, train_config)
save_model(fake_model, fake_trainer, fake_gen_name)

## model_size=tiny
# num decayed parameter tensors: 11, with 1,968 parameters
# num non-decayed parameter tensors: 7, with 50 parameters
# using fused AdamW: False
# step 0: train loss 2.7817, val loss 2.7816
# iter 0/49/488: loss 2.7821, time 2537.56ms
# iter 100/147/488: loss 2.6890, time 53.61ms
# iter 200/245/488: loss 2.6342, time 63.05ms
# iter 300/343/488: loss 2.6187, time 55.31ms
# iter 400/441/488: loss 2.6147, time 61.11ms

## model_size=large
# num decayed parameter tensors: 35, with 1,579,776 parameters
# num non-decayed parameter tensors: 19, with 2,186 parameters
# using fused AdamW: False
# step 0: train loss 2.8087, val loss 2.8088
# iter 0/49/488: loss 2.8099, time 11225.20ms
# iter 100/147/488: loss 2.6065, time 596.91ms
# iter 200/245/488: loss 2.6075, time 618.00ms
# iter 300/343/488: loss 2.6080, time 613.63ms
# iter 400/441/488: loss 2.6051, time 616.39ms

In [None]:
# for rerun in range(10):
#     print(f"Re-running training for {fake_gen_name} {rerun+1} of 10")
#     fake_model, fake_trainer = train_model(fake_model, training_splits, train_config)
#     save_model(fake_model, fake_trainer, fake_gen_name)

In [None]:
# Inspect training data: load the fake split(s) and sum `d.value[0]` by action prefix,
# both across all splits ('*') and per split index.
fake_td_array = [TrajectoryDataset(DATA_DIR, split, block_size=n_max_context) for split in training_splits]
fake_unrolled = [(generation + 1, d) for generation, td in enumerate(fake_td_array) for d in td]

MAX_FAKE_PREFIX_LEN = 2  # track prefixes of length 0..MAX_FAKE_PREFIX_LEN

fake_dd = defaultdict(lambda: defaultdict(lambda: torch.tensor([0., 0.])))

for gen, d in fake_unrolled:
    for prefix_len in range(MAX_FAKE_PREFIX_LEN + 1):
        prefix = tuple(d.action[:prefix_len].tolist())
        if len(prefix) < prefix_len:
            # Shorter trajectory: longer slices would re-count it under the same key.
            break
        for g in ('*', gen):
            fake_dd[prefix][g] += d.value[0]

print(f"len(fake_dd) = {len(fake_dd)}")


In [None]:
# NOTE(review): `load_model` is not defined or imported anywhere in this notebook; it looks
# like a leftover from an earlier version, so replace it with the current model-loading helper.
fake_model = load_model(fake_gen_name)
# Compare against the statistics of the *fake* dataset (`fake_dd`, built in the previous
# cell). Previously this passed the real self-play `dd`, so the comparison was meaningless.
compare_model_vs_data(fake_model, game, fake_dd)


In [None]:
# Continue training the fake model on the same splits.
# NOTE(review): `train_model` and `train_config` are not defined or imported in this notebook.
fake_model, fake_trainer = train_model(fake_model, training_splits, train_config)
