# Step-by-step run of alphazero self-play & training.


In [1]:
import os
import time
from pathlib import Path
import asyncio

import numpy as np
import torch

# Game and players
from rgi.rgizero.experiment import ExperimentRunner, ExperimentConfig
from rgi.rgizero.data.trajectory_dataset import Vocab, print_dataset_stats, TrajectoryDataset
from rgi.rgizero.evaluators import ActionHistoryTransformerEvaluator, AsyncNetworkEvaluator
from rgi.rgizero.models.tuner import create_random_model

import notebook_utils
from notebook_utils import reload_local_modules

device = notebook_utils.detect_device()

## Disable for debugger stability?
# # Allow asyncio to work with jupyter notebook
# import nest_asyncio
# nest_asyncio.apply()

# Increase numpy print width
np.set_printoptions(linewidth=300)

%load_ext line_profiler

transform_config_fields: {'bias', 'n_max_context', 'dropout', 'n_layer', 'n_head', 'n_embd'}
train_config_fields: {'batch_size', 'min_lr', 'model_version', 'beta2', 'max_epochs', 'learning_rate', 'always_save_checkpoint', 'max_iters', 'beta1', 'warmup_iters', 'eval_iters', 'weight_decay', 'model_name', 'early_stop_patience', 'grad_clip', 'compile', 'wandb_log', 'lr_decay_iters', 'device', 'eval_only', 'log_interval', 'dtype', 'gradient_accumulation_steps', 'eval_interval', 'decay_lr'}
Detected device: mps


In [2]:
RUN_GENERATIONS = True


# Create Experiment Config
experiment_config = ExperimentConfig(
    experiment_name='smoketest-e2e-v7',   # Use sliding window. Fix game reward bug. Shorted eval interval. Set max-epoch=1. max_epoch=2, less training steps.
    # experiment_name='smoketest-e2e-v2',
    # parent_experiment_name='smoketest-e2e',
    game_name='connect4',
    num_generations=10,
    num_games_per_gen=2_000,
    num_simulations=200,
    # model_size="tiny",
    # train_batch_size=10,
    # max_training_epochs=2,
    seed=42,
)

# Tuned params from connect4 with 23k training games.
tuned_params = {
    'batch_size': 512,
    'beta1': 0.9,
    'beta2': 0.99,
    'bias': False,
    'decay_lr': True,
    'dropout': 0.0,
    'dtype': 'float16',
    'grad_clip': 1.0,
    'gradient_accumulation_steps': 1,
    'learning_rate': 0.001,
    'lr_decay_iters': 5000,
    'max_epochs': 2,      # We retrain the same model each generation, so having a low epoch count good. This means we don't overfit during early generations.
    'max_iters': 10_000,  # 30_000
    'min_lr': 0.0001,
    'n_embd': 64,
    'n_head': 2,
    'n_layer': 4,
    'n_max_context': 44,
    'warmup_iters': 200,
    'weight_decay': 0.2,
    'eval_iters': 50,
    'log_interval': 250,
    'eval_interval': 500,
    'early_stop_patience': 1
    }

## Step 1: Set up game and experiment runner


In [3]:
from rgi.rgizero.data.trajectory_dataset import Vocab
from rgi.rgizero.common import TOKENS

# Initialize Experiment Runner
experiment_base_dir = Path.cwd().parent / 'experiments'
experiment_runner = ExperimentRunner(experiment_config, experiment_base_dir, training_args=tuned_params)
game = experiment_runner.game
action_vocab = experiment_runner.action_vocab
n_max_context = experiment_runner.n_max_context

DATA_DIR = experiment_runner.data_dir
MODEL_DIR = experiment_runner.models_dir

print('✅ Runner initialized')
print(f'Game: {experiment_runner.config.game_name}, Players: {experiment_runner.num_players}, Actions: {list(game.base_game.all_actions())}')
print('Data dir: ', DATA_DIR)
print('Model dir: ', MODEL_DIR)


✅ Runner initialized
Game: connect4, Players: 2, Actions: [1, 2, 3, 4, 5, 6, 7]
Data dir:  /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v7/data
Model dir:  /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v7/models


## Step 2: Create random generation_0 model


In [4]:
# Initialize (creates Random Gen 0 if needed)
model_0 = experiment_runner.initialize()
current_model = model_0


Starting Experiment: smoketest-e2e-v7
Initializing Random Gen 0 model.
Saved model to /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v7/models/gen-0.pt


In [5]:
results_dict = {}
trajectory_paths_dict = {}
model_dict = {0: model_0}

current_model = model_dict[0]
if RUN_GENERATIONS:
    for generation_id in range(1, experiment_config.num_generations+1):
        current_model = await experiment_runner.run_generation_step_async(generation_id, current_model)
        dataset_paths = experiment_runner.get_trajectory_paths(generation_id)
        
        # print stats for visibility
        print_dataset_stats(dataset_paths, n_max_context, action_vocab, model=current_model, game=game)
        
        model_dict[generation_id] = current_model
else:
    generation_id = experiment_config.num_generations
    current_model = experiment_runner.load_model(generation_id)

# 10m to play 2x10k generations... probabilities still very wrong.
# Evaluation time: 0.015 seconds, size=574, eval-per-second=37837.60, total-batches=6000, mean-eval-per-second=94963.99, mean-time-per-batch=0.010, mean-batch-size=990.34

# >>> log(2) + log(7) -> 2.6390573296152584
## Model doesn't seem to improve loss at all?
# step.   0: losses: train:2.5971, train_policy_loss:1.9146, train_value_loss:0.6825, val:2.5972, val_policy_loss:1.9147, val_value_loss:0.6825
# step 1000: losses: train:2.6036, train_policy_loss:1.9122, train_value_loss:0.6914, val:2.6050, val_policy_loss:1.9132, val_value_loss:0.6917
# step 2000: losses: train:2.6056, train_policy_loss:1.9119, train_value_loss:0.6937, val:2.6056, val_policy_loss:1.9123, val_value_loss:0.6933
# iter    0/1170/5000: loss 2.5699, policy_loss:1.9129, value_loss:0.6570, time 5.18s, iter_time: 0.00ms
# iter 1000/1170/5000: loss 2.5996, policy_loss:1.9068, value_loss:0.6928, time 1.96s, iter_time: 1957.74ms
# iter 2339/2340/5000: loss 2.6014, policy_loss:1.9131, value_loss:0.6884, time 0.01s, iter_time: 14.61ms




=== Generation 1 ===
Playing 2000 games...


Self Play:   4%|▍         | 89/2000 [00:44<04:38,  6.87it/s]  

Evaluation time: 0.009 seconds, size=1000, eval-per-second=114467.11, total-batches=1000, mean-eval-per-second=89943.26, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  22%|██▏       | 447/2000 [01:36<02:53,  8.96it/s]

Evaluation time: 0.021 seconds, size=1000, eval-per-second=47472.12, total-batches=2000, mean-eval-per-second=78431.52, mean-time-per-batch=0.013, mean-batch-size=1000.00


Self Play:  44%|████▍     | 878/2000 [02:33<01:15, 14.83it/s]

Evaluation time: 0.019 seconds, size=1000, eval-per-second=53859.44, total-batches=3000, mean-eval-per-second=72124.02, mean-time-per-batch=0.014, mean-batch-size=1000.00


Self Play:  66%|██████▋   | 1326/2000 [03:30<00:53, 12.59it/s]

Evaluation time: 0.020 seconds, size=678, eval-per-second=33158.87, total-batches=4000, mean-eval-per-second=61657.04, mean-time-per-batch=0.016, mean-batch-size=969.54


Self Play:  84%|████████▎ | 1673/2000 [04:07<00:49,  6.63it/s]

Evaluation time: 0.026 seconds, size=327, eval-per-second=12661.79, total-batches=5000, mean-eval-per-second=52680.10, mean-time-per-batch=0.017, mean-batch-size=873.28


Self Play:  94%|█████████▍| 1890/2000 [04:31<00:10, 10.94it/s]

Evaluation time: 0.006 seconds, size=110, eval-per-second=17903.51, total-batches=6000, mean-eval-per-second=46058.76, mean-time-per-batch=0.017, mean-batch-size=763.27


Self Play:  99%|█████████▉| 1979/2000 [04:42<00:03,  5.91it/s]

Evaluation time: 0.002 seconds, size=22, eval-per-second=9255.23, total-batches=7000, mean-eval-per-second=42685.32, mean-time-per-batch=0.016, mean-batch-size=662.38


Self Play: 100%|█████████▉| 1999/2000 [04:47<00:00,  3.43it/s]

Evaluation time: 0.002 seconds, size=2, eval-per-second=1226.23, total-batches=8000, mean-eval-per-second=41379.50, mean-time-per-batch=0.014, mean-batch-size=580.88


Self Play: 100%|██████████| 2000/2000 [04:47<00:00,  6.95it/s]


Writing 2000 trajectories...
Training model for gen 1...
num decayed parameter tensors: 19, with 200,064 parameters
num non-decayed parameter tensors: 11, with 586 parameters
using fused AdamW: False
step 0: losses: train:2.7687, train_policy_loss:2.0614, train_value_loss:0.7072, val:2.7748, val_policy_loss:2.0622, val_value_loss:0.7126
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-1/best.pt
iter 0/4/10000: loss 2.7673, policy_loss:2.0620, value_loss:0.7053, time 0.65s, iter_time: 0.00ms
step 3: losses: train:2.7556, train_policy_loss:2.0551, train_value_loss:0.7006, val:2.7591, val_policy_loss:2.0556, val_value_loss:0.7035
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-1/best.pt
saving checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-1
step 7: losses: train:2.7197, train_policy_loss:2.0341, train_value_loss:0.6856, val:2.7150, val_policy_loss:2.0352, val_value_loss:0.6798
saving best checkpoint to /U

Self Play:   4%|▍         | 84/2000 [00:45<04:09,  7.68it/s]  

Evaluation time: 0.019 seconds, size=1000, eval-per-second=52070.17, total-batches=1000, mean-eval-per-second=87920.44, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  20%|██        | 404/2000 [01:40<05:04,  5.23it/s]

Evaluation time: 0.021 seconds, size=1000, eval-per-second=47870.35, total-batches=2000, mean-eval-per-second=76821.89, mean-time-per-batch=0.013, mean-batch-size=1000.00


Self Play:  42%|████▏     | 834/2000 [02:39<02:07,  9.16it/s]

Evaluation time: 0.017 seconds, size=1000, eval-per-second=57424.75, total-batches=3000, mean-eval-per-second=71333.05, mean-time-per-batch=0.014, mean-batch-size=1000.00


Self Play:  63%|██████▎   | 1256/2000 [03:40<01:49,  6.81it/s]

Evaluation time: 0.018 seconds, size=748, eval-per-second=42521.17, total-batches=4000, mean-eval-per-second=61299.90, mean-time-per-batch=0.016, mean-batch-size=980.40


Self Play:  80%|████████  | 1600/2000 [04:26<00:33, 11.94it/s]

Evaluation time: 0.039 seconds, size=400, eval-per-second=10312.95, total-batches=5000, mean-eval-per-second=51443.67, mean-time-per-batch=0.017, mean-batch-size=898.39


Self Play:  92%|█████████▏| 1842/2000 [04:54<00:20,  7.81it/s]

Evaluation time: 0.008 seconds, size=156, eval-per-second=20131.42, total-batches=6000, mean-eval-per-second=45255.04, mean-time-per-batch=0.018, mean-batch-size=795.25


Self Play:  98%|█████████▊| 1968/2000 [05:08<00:03,  8.14it/s]

Evaluation time: 0.003 seconds, size=34, eval-per-second=11277.69, total-batches=7000, mean-eval-per-second=41886.33, mean-time-per-batch=0.017, mean-batch-size=693.97


Self Play: 100%|█████████▉| 1999/2000 [05:11<00:00,  9.27it/s]

Evaluation time: 0.002 seconds, size=1, eval-per-second=560.81, total-batches=8000, mean-eval-per-second=40859.16, mean-time-per-batch=0.015, mean-batch-size=608.51


Self Play: 100%|██████████| 2000/2000 [05:12<00:00,  6.39it/s]


Writing 2000 trajectories...
Training model for gen 2...
num decayed parameter tensors: 19, with 200,064 parameters
num non-decayed parameter tensors: 11, with 586 parameters
using fused AdamW: False
step 0: losses: train:2.7138, train_policy_loss:2.0346, train_value_loss:0.6792, val:2.7199, val_policy_loss:2.0335, val_value_loss:0.6864
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-2/best.pt
iter 0/8/10000: loss 2.7178, policy_loss:2.0350, value_loss:0.6828, time 0.67s, iter_time: 0.00ms
step 7: losses: train:2.6847, train_policy_loss:2.0134, train_value_loss:0.6713, val:2.6925, val_policy_loss:2.0113, val_value_loss:0.6812
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-2/best.pt
saving checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-2
step 15: losses: train:2.6444, train_policy_loss:1.9763, train_value_loss:0.6681, val:2.6537, val_policy_loss:1.9761, val_value_loss:0.6776
saving best checkpoint to /

Self Play:   2%|▏         | 49/2000 [00:44<02:46, 11.72it/s]  

Evaluation time: 0.007 seconds, size=1000, eval-per-second=139373.43, total-batches=1000, mean-eval-per-second=90218.53, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  17%|█▋        | 331/2000 [01:41<03:26,  8.10it/s]

Evaluation time: 0.020 seconds, size=1000, eval-per-second=50717.71, total-batches=2000, mean-eval-per-second=79266.10, mean-time-per-batch=0.013, mean-batch-size=1000.00


Self Play:  36%|███▌      | 716/2000 [02:39<03:16,  6.55it/s]

Evaluation time: 0.010 seconds, size=1000, eval-per-second=96589.54, total-batches=3000, mean-eval-per-second=73243.81, mean-time-per-batch=0.014, mean-batch-size=1000.00


Self Play:  57%|█████▋    | 1131/2000 [03:45<03:38,  3.97it/s]

Evaluation time: 0.047 seconds, size=870, eval-per-second=18698.86, total-batches=4000, mean-eval-per-second=61264.19, mean-time-per-batch=0.016, mean-batch-size=994.65


Self Play:  75%|███████▌  | 1506/2000 [04:42<00:41, 11.94it/s]

Evaluation time: 0.010 seconds, size=493, eval-per-second=51483.71, total-batches=5000, mean-eval-per-second=49022.93, mean-time-per-batch=0.019, mean-batch-size=933.19


Self Play:  89%|████████▉ | 1779/2000 [05:14<00:38,  5.69it/s]

Evaluation time: 0.012 seconds, size=222, eval-per-second=18859.59, total-batches=6000, mean-eval-per-second=44024.72, mean-time-per-batch=0.019, mean-batch-size=836.45


Self Play:  97%|█████████▋| 1939/2000 [05:33<00:09,  6.47it/s]

Evaluation time: 0.004 seconds, size=62, eval-per-second=13995.31, total-batches=7000, mean-eval-per-second=39871.58, mean-time-per-batch=0.018, mean-batch-size=734.95


Self Play: 100%|█████████▉| 1993/2000 [05:38<00:00,  9.95it/s]

Evaluation time: 0.002 seconds, size=7, eval-per-second=3662.69, total-batches=8000, mean-eval-per-second=38853.96, mean-time-per-batch=0.017, mean-batch-size=646.48


Self Play: 100%|██████████| 2000/2000 [05:40<00:00,  5.87it/s]


Writing 2000 trajectories...
Training model for gen 3...
num decayed parameter tensors: 19, with 200,064 parameters
num non-decayed parameter tensors: 11, with 586 parameters
using fused AdamW: False
step 0: losses: train:2.6531, train_policy_loss:1.9795, train_value_loss:0.6737, val:2.6370, val_policy_loss:1.9801, val_value_loss:0.6569
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-3/best.pt
iter 0/11/10000: loss 2.6458, policy_loss:1.9796, value_loss:0.6662, time 1.31s, iter_time: 0.00ms
step 10: losses: train:2.6426, train_policy_loss:1.9688, train_value_loss:0.6738, val:2.6448, val_policy_loss:1.9695, val_value_loss:0.6752
Early stopping triggered! Valid loss hasn't improved for 1 evals.
Reloading best model from /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-3/best.pt (val_loss=2.6370)
Saved model to /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v7/models/gen-3.pt
Dataset Stats:
  Trajectories: 6000
  Total actions: 103256
  Avg trajec

Self Play:   1%|▏         | 29/2000 [00:44<05:04,  6.47it/s]  

Evaluation time: 0.019 seconds, size=1000, eval-per-second=52695.57, total-batches=1000, mean-eval-per-second=89903.71, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  16%|█▌        | 319/2000 [01:40<06:56,  4.04it/s]

Evaluation time: 0.016 seconds, size=1000, eval-per-second=64172.34, total-batches=2000, mean-eval-per-second=77203.77, mean-time-per-batch=0.013, mean-batch-size=1000.00


Self Play:  36%|███▌      | 716/2000 [02:39<02:06, 10.14it/s]

Evaluation time: 0.018 seconds, size=1000, eval-per-second=56895.83, total-batches=3000, mean-eval-per-second=71333.21, mean-time-per-batch=0.014, mean-batch-size=1000.00


Self Play:  57%|█████▊    | 1150/2000 [03:46<02:59,  4.73it/s]

Evaluation time: 0.011 seconds, size=852, eval-per-second=75099.76, total-batches=4000, mean-eval-per-second=60008.93, mean-time-per-batch=0.017, mean-batch-size=992.98


Self Play:  75%|███████▌  | 1503/2000 [04:41<01:11,  6.98it/s]

Evaluation time: 0.008 seconds, size=498, eval-per-second=61657.27, total-batches=5000, mean-eval-per-second=49587.16, mean-time-per-batch=0.019, mean-batch-size=929.08


Self Play:  88%|████████▊ | 1767/2000 [05:15<00:42,  5.48it/s]

Evaluation time: 0.005 seconds, size=235, eval-per-second=50440.69, total-batches=6000, mean-eval-per-second=43445.91, mean-time-per-batch=0.019, mean-batch-size=832.90


Self Play:  96%|█████████▋| 1930/2000 [05:33<00:09,  7.47it/s]

Evaluation time: 0.005 seconds, size=71, eval-per-second=14011.93, total-batches=7000, mean-eval-per-second=40162.74, mean-time-per-batch=0.018, mean-batch-size=734.58


Self Play: 100%|█████████▉| 1993/2000 [05:40<00:00, 10.41it/s]

Evaluation time: 0.002 seconds, size=9, eval-per-second=5211.75, total-batches=8000, mean-eval-per-second=38745.47, mean-time-per-batch=0.017, mean-batch-size=647.10


Self Play: 100%|██████████| 2000/2000 [05:41<00:00,  5.85it/s]


Writing 2000 trajectories...
Training model for gen 4...
num decayed parameter tensors: 19, with 200,064 parameters
num non-decayed parameter tensors: 11, with 586 parameters
using fused AdamW: False
step 0: losses: train:2.6579, train_policy_loss:1.9810, train_value_loss:0.6769, val:2.6465, val_policy_loss:1.9813, val_value_loss:0.6652
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-4/best.pt
iter 0/15/10000: loss 2.6432, policy_loss:1.9809, value_loss:0.6623, time 1.40s, iter_time: 0.00ms
step 14: losses: train:2.6383, train_policy_loss:1.9655, train_value_loss:0.6728, val:2.6348, val_policy_loss:1.9651, val_value_loss:0.6697
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-4/best.pt
saving checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-4
step 29: losses: train:2.6231, train_policy_loss:1.9517, train_value_loss:0.6714, val:2.6243, val_policy_loss:1.9522, val_value_loss:0.6720
saving best checkpoint to

Self Play:   2%|▏         | 47/2000 [00:44<02:43, 11.91it/s]  

Evaluation time: 0.012 seconds, size=1000, eval-per-second=85977.04, total-batches=1000, mean-eval-per-second=92157.41, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  14%|█▍        | 280/2000 [01:38<07:14,  3.95it/s]

Evaluation time: 0.021 seconds, size=1000, eval-per-second=47849.05, total-batches=2000, mean-eval-per-second=79669.83, mean-time-per-batch=0.013, mean-batch-size=1000.00


Self Play:  35%|███▌      | 706/2000 [02:36<02:17,  9.40it/s]

Evaluation time: 0.014 seconds, size=1000, eval-per-second=70407.31, total-batches=3000, mean-eval-per-second=73711.06, mean-time-per-batch=0.014, mean-batch-size=1000.00


Self Play:  55%|█████▍    | 1092/2000 [03:43<02:17,  6.59it/s]

Evaluation time: 0.022 seconds, size=908, eval-per-second=41866.04, total-batches=4000, mean-eval-per-second=61174.39, mean-time-per-batch=0.016, mean-batch-size=997.35


Self Play:  73%|███████▎  | 1463/2000 [04:43<01:04,  8.28it/s]

Evaluation time: 0.065 seconds, size=537, eval-per-second=8223.86, total-batches=5000, mean-eval-per-second=49290.93, mean-time-per-batch=0.019, mean-batch-size=941.96


Self Play:  87%|████████▋ | 1733/2000 [05:18<00:34,  7.71it/s]

Evaluation time: 0.013 seconds, size=267, eval-per-second=19901.53, total-batches=6000, mean-eval-per-second=43922.74, mean-time-per-batch=0.019, mean-batch-size=850.12


Self Play:  96%|█████████▌| 1916/2000 [05:36<00:07, 11.60it/s]

Evaluation time: 0.005 seconds, size=83, eval-per-second=15571.98, total-batches=7000, mean-eval-per-second=40789.78, mean-time-per-batch=0.018, mean-batch-size=751.18


Self Play:  99%|█████████▉| 1983/2000 [05:44<00:01,  9.27it/s]

Evaluation time: 0.002 seconds, size=18, eval-per-second=8475.24, total-batches=8000, mean-eval-per-second=39329.45, mean-time-per-batch=0.017, mean-batch-size=662.80


Self Play: 100%|██████████| 2000/2000 [05:47<00:00,  5.76it/s]

Evaluation time: 0.001 seconds, size=1, eval-per-second=882.64, total-batches=9000, mean-eval-per-second=38759.61, mean-time-per-batch=0.015, mean-batch-size=589.93
Writing 2000 trajectories...





Training model for gen 5...
num decayed parameter tensors: 19, with 200,064 parameters
num non-decayed parameter tensors: 11, with 586 parameters
using fused AdamW: False
step 0: losses: train:2.6285, train_policy_loss:1.9543, train_value_loss:0.6742, val:2.6236, val_policy_loss:1.9541, val_value_loss:0.6695
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-5/best.pt
iter 0/18/10000: loss 2.6308, policy_loss:1.9537, value_loss:0.6771, time 1.57s, iter_time: 0.00ms
step 17: losses: train:2.6223, train_policy_loss:1.9502, train_value_loss:0.6722, val:2.6187, val_policy_loss:1.9502, val_value_loss:0.6685
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-5/best.pt
saving checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-5
step 35: losses: train:2.6154, train_policy_loss:1.9420, train_value_loss:0.6734, val:2.6113, val_policy_loss:1.9421, val_value_loss:0.6692
saving best checkpoint to /Users/rodo/src/rgi3-sync/mo

Self Play:   3%|▎         | 62/2000 [00:46<03:30,  9.21it/s]  

Evaluation time: 0.020 seconds, size=1000, eval-per-second=49798.21, total-batches=1000, mean-eval-per-second=83258.90, mean-time-per-batch=0.012, mean-batch-size=1000.00


Self Play:  18%|█▊        | 352/2000 [01:40<02:42, 10.16it/s]

Evaluation time: 0.014 seconds, size=1000, eval-per-second=70472.37, total-batches=2000, mean-eval-per-second=76386.39, mean-time-per-batch=0.013, mean-batch-size=1000.00


Self Play:  38%|███▊      | 763/2000 [02:36<02:36,  7.88it/s]

Evaluation time: 0.021 seconds, size=1000, eval-per-second=48398.42, total-batches=3000, mean-eval-per-second=72057.57, mean-time-per-batch=0.014, mean-batch-size=1000.00


Self Play:  60%|█████▉    | 1198/2000 [03:38<01:40,  7.98it/s]

Evaluation time: 0.016 seconds, size=804, eval-per-second=48751.23, total-batches=4000, mean-eval-per-second=61844.03, mean-time-per-batch=0.016, mean-batch-size=989.65


Self Play:  79%|███████▊  | 1574/2000 [04:23<00:30, 13.84it/s]

Evaluation time: 0.022 seconds, size=426, eval-per-second=19608.59, total-batches=5000, mean-eval-per-second=52488.59, mean-time-per-batch=0.017, mean-batch-size=912.23


Self Play:  91%|█████████▏| 1826/2000 [04:48<00:20,  8.44it/s]

Evaluation time: 0.004 seconds, size=175, eval-per-second=48871.64, total-batches=6000, mean-eval-per-second=47285.35, mean-time-per-batch=0.017, mean-batch-size=808.48


Self Play:  98%|█████████▊| 1965/2000 [05:01<00:02, 14.38it/s]

Evaluation time: 0.003 seconds, size=40, eval-per-second=12326.22, total-batches=7000, mean-eval-per-second=44204.67, mean-time-per-batch=0.016, mean-batch-size=707.28


Self Play: 100%|██████████| 2000/2000 [05:06<00:00,  6.53it/s]

Evaluation time: 0.001 seconds, size=1, eval-per-second=892.03, total-batches=8000, mean-eval-per-second=42913.68, mean-time-per-batch=0.014, mean-batch-size=620.62
Writing 2000 trajectories...
Training model for gen 6...
num decayed parameter tensors: 19, with 200,064 parameters
num non-decayed parameter tensors: 11, with 586 parameters
using fused AdamW: False





step 0: losses: train:2.6085, train_policy_loss:1.9409, train_value_loss:0.6676, val:2.6061, val_policy_loss:1.9407, val_value_loss:0.6654
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-6/best.pt
iter 0/22/10000: loss 2.6046, policy_loss:1.9413, value_loss:0.6633, time 2.37s, iter_time: 0.00ms
step 21: losses: train:2.6067, train_policy_loss:1.9374, train_value_loss:0.6693, val:2.6057, val_policy_loss:1.9366, val_value_loss:0.6691
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-6/best.pt
saving checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-6
step 43: losses: train:2.5972, train_policy_loss:1.9303, train_value_loss:0.6669, val:2.5964, val_policy_loss:1.9290, val_value_loss:0.6675
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-6/best.pt
saving checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-6
Reloading best model from /Users/rodo/src/rgi3-sync/models

Self Play:   4%|▎         | 73/2000 [00:44<04:11,  7.67it/s]  

Evaluation time: 0.009 seconds, size=1000, eval-per-second=112459.89, total-batches=1000, mean-eval-per-second=91406.72, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  18%|█▊        | 353/2000 [01:39<03:00,  9.14it/s]

Evaluation time: 0.009 seconds, size=1000, eval-per-second=106486.85, total-batches=2000, mean-eval-per-second=87789.98, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  40%|████      | 809/2000 [02:34<01:48, 10.96it/s]

Evaluation time: 0.014 seconds, size=1000, eval-per-second=72237.12, total-batches=3000, mean-eval-per-second=86189.48, mean-time-per-batch=0.012, mean-batch-size=1000.00


Self Play:  61%|██████    | 1219/2000 [03:33<03:48,  3.42it/s]

Evaluation time: 0.014 seconds, size=782, eval-per-second=54911.95, total-batches=4000, mean-eval-per-second=73794.15, mean-time-per-batch=0.013, mean-batch-size=984.59


Self Play:  78%|███████▊  | 1563/2000 [04:18<00:34, 12.69it/s]

Evaluation time: 0.010 seconds, size=436, eval-per-second=41847.06, total-batches=5000, mean-eval-per-second=60039.84, mean-time-per-batch=0.015, mean-batch-size=907.80


Self Play:  91%|█████████ | 1812/2000 [04:44<00:19,  9.89it/s]

Evaluation time: 0.009 seconds, size=192, eval-per-second=20225.70, total-batches=6000, mean-eval-per-second=53450.47, mean-time-per-batch=0.015, mean-batch-size=807.33


Self Play:  98%|█████████▊| 1952/2000 [04:55<00:04, 10.60it/s]

Evaluation time: 0.002 seconds, size=49, eval-per-second=22973.50, total-batches=7000, mean-eval-per-second=50945.79, mean-time-per-batch=0.014, mean-batch-size=707.01


Self Play: 100%|█████████▉| 1997/2000 [04:58<00:00,  8.05it/s]

Evaluation time: 0.002 seconds, size=4, eval-per-second=2462.89, total-batches=8000, mean-eval-per-second=49788.15, mean-time-per-batch=0.012, mean-batch-size=621.18


Self Play: 100%|██████████| 2000/2000 [04:59<00:00,  6.68it/s]


Writing 2000 trajectories...
Training model for gen 7...
num decayed parameter tensors: 19, with 200,064 parameters
num non-decayed parameter tensors: 11, with 586 parameters
using fused AdamW: False
step 0: losses: train:2.5959, train_policy_loss:1.9287, train_value_loss:0.6672, val:2.5825, val_policy_loss:1.9289, val_value_loss:0.6536
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-7/best.pt
iter 0/25/10000: loss 2.6075, policy_loss:1.9292, value_loss:0.6783, time 2.97s, iter_time: 0.00ms
step 24: losses: train:2.5924, train_policy_loss:1.9253, train_value_loss:0.6671, val:2.5831, val_policy_loss:1.9259, val_value_loss:0.6572
Early stopping triggered! Valid loss hasn't improved for 1 evals.
Reloading best model from /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-7/best.pt (val_loss=2.5825)
Saved model to /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v7/models/gen-7.pt
Dataset Stats:
  Trajectories: 14000
  Total actions: 244265
  Avg traje

Self Play:   3%|▎         | 68/2000 [00:45<02:27, 13.14it/s]  

Evaluation time: 0.011 seconds, size=1000, eval-per-second=93022.78, total-batches=1000, mean-eval-per-second=103521.03, mean-time-per-batch=0.010, mean-batch-size=1000.00


Self Play:  18%|█▊        | 354/2000 [01:39<03:36,  7.59it/s]

Evaluation time: 0.015 seconds, size=1000, eval-per-second=64998.74, total-batches=2000, mean-eval-per-second=92729.67, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  39%|███▉      | 784/2000 [02:34<01:31, 13.23it/s]

Evaluation time: 0.012 seconds, size=1000, eval-per-second=81679.11, total-batches=3000, mean-eval-per-second=89958.51, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  60%|██████    | 1201/2000 [03:34<01:51,  7.18it/s]

Evaluation time: 0.010 seconds, size=800, eval-per-second=82390.69, total-batches=4000, mean-eval-per-second=77172.26, mean-time-per-batch=0.013, mean-batch-size=987.57


Self Play:  79%|███████▉  | 1581/2000 [04:15<00:36, 11.54it/s]

Evaluation time: 0.059 seconds, size=420, eval-per-second=7114.42, total-batches=5000, mean-eval-per-second=66773.70, mean-time-per-batch=0.014, mean-batch-size=910.11


Self Play:  92%|█████████▏| 1834/2000 [04:44<00:12, 13.49it/s]

Evaluation time: 0.004 seconds, size=167, eval-per-second=39886.61, total-batches=6000, mean-eval-per-second=55666.89, mean-time-per-batch=0.014, mean-batch-size=805.53


Self Play:  98%|█████████▊| 1965/2000 [04:54<00:02, 15.26it/s]

Evaluation time: 0.003 seconds, size=36, eval-per-second=11096.86, total-batches=7000, mean-eval-per-second=52838.24, mean-time-per-batch=0.013, mean-batch-size=703.88


Self Play: 100%|██████████| 2000/2000 [04:59<00:00,  6.68it/s]

Evaluation time: 0.002 seconds, size=1, eval-per-second=526.59, total-batches=8000, mean-eval-per-second=50829.02, mean-time-per-batch=0.012, mean-batch-size=617.83
Writing 2000 trajectories...





Training model for gen 8...
num decayed parameter tensors: 19, with 200,064 parameters
num non-decayed parameter tensors: 11, with 586 parameters
using fused AdamW: False
step 0: losses: train:2.5932, train_policy_loss:1.9277, train_value_loss:0.6655, val:2.5819, val_policy_loss:1.9270, val_value_loss:0.6550
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-8/best.pt
iter 0/29/10000: loss 2.5857, policy_loss:1.9296, value_loss:0.6561, time 2.85s, iter_time: 0.00ms
step 28: losses: train:2.5900, train_policy_loss:1.9230, train_value_loss:0.6669, val:2.5847, val_policy_loss:1.9245, val_value_loss:0.6602
Early stopping triggered! Valid loss hasn't improved for 1 evals.
Reloading best model from /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-8/best.pt (val_loss=2.5819)
Saved model to /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v7/models/gen-8.pt
Dataset Stats:
  Trajectories: 16000
  Total actions: 278600
  Avg trajectory length: 17.41
Prefix St

Self Play:   3%|▎         | 60/2000 [00:44<05:19,  6.07it/s]  

Evaluation time: 0.009 seconds, size=1000, eval-per-second=106484.15, total-batches=1000, mean-eval-per-second=102975.31, mean-time-per-batch=0.010, mean-batch-size=1000.00


Self Play:  16%|█▌        | 324/2000 [01:38<02:11, 12.74it/s]

Evaluation time: 0.015 seconds, size=1000, eval-per-second=67254.13, total-batches=2000, mean-eval-per-second=92737.04, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  38%|███▊      | 762/2000 [02:34<01:45, 11.70it/s]

Evaluation time: 0.016 seconds, size=1000, eval-per-second=62594.08, total-batches=3000, mean-eval-per-second=88553.27, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  58%|█████▊    | 1166/2000 [03:32<01:39,  8.41it/s]

Evaluation time: 0.012 seconds, size=838, eval-per-second=71647.81, total-batches=4000, mean-eval-per-second=80901.51, mean-time-per-batch=0.012, mean-batch-size=991.71


Self Play:  76%|███████▌  | 1523/2000 [04:14<00:41, 11.49it/s]

Evaluation time: 0.009 seconds, size=479, eval-per-second=55107.98, total-batches=5000, mean-eval-per-second=72128.20, mean-time-per-batch=0.013, mean-batch-size=923.56


Self Play:  90%|█████████ | 1802/2000 [04:41<00:19, 10.30it/s]

Evaluation time: 0.011 seconds, size=197, eval-per-second=18646.40, total-batches=6000, mean-eval-per-second=61336.16, mean-time-per-batch=0.013, mean-batch-size=823.07


Self Play:  97%|█████████▋| 1946/2000 [04:55<00:04, 11.82it/s]

Evaluation time: 0.002 seconds, size=55, eval-per-second=24696.15, total-batches=7000, mean-eval-per-second=56396.47, mean-time-per-batch=0.013, mean-batch-size=722.30


Self Play: 100%|█████████▉| 1990/2000 [04:59<00:01,  6.62it/s]

Evaluation time: 0.002 seconds, size=8, eval-per-second=4446.65, total-batches=8000, mean-eval-per-second=54864.01, mean-time-per-batch=0.012, mean-batch-size=635.22


Self Play: 100%|██████████| 2000/2000 [05:01<00:00,  6.64it/s]


Writing 2000 trajectories...
Training model for gen 9...
num decayed parameter tensors: 19, with 200,064 parameters
num non-decayed parameter tensors: 11, with 586 parameters
using fused AdamW: False
step 0: losses: train:2.5923, train_policy_loss:1.9270, train_value_loss:0.6654, val:2.5938, val_policy_loss:1.9267, val_value_loss:0.6671
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-9/best.pt
iter 0/32/10000: loss 2.5962, policy_loss:1.9283, value_loss:0.6679, time 2.38s, iter_time: 0.00ms
step 31: losses: train:2.5863, train_policy_loss:1.9214, train_value_loss:0.6649, val:2.5872, val_policy_loss:1.9210, val_value_loss:0.6662
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-9/best.pt
saving checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-9
step 63: losses: train:2.5790, train_policy_loss:1.9144, train_value_loss:0.6646, val:2.5816, val_policy_loss:1.9149, val_value_loss:0.6667
saving best checkpoint to

Self Play:   3%|▎         | 68/2000 [00:44<05:59,  5.38it/s]  

Evaluation time: 0.013 seconds, size=1000, eval-per-second=74454.22, total-batches=1000, mean-eval-per-second=100950.28, mean-time-per-batch=0.010, mean-batch-size=1000.00


Self Play:  18%|█▊        | 366/2000 [01:39<06:45,  4.03it/s]

Evaluation time: 0.018 seconds, size=1000, eval-per-second=56513.30, total-batches=2000, mean-eval-per-second=90951.35, mean-time-per-batch=0.011, mean-batch-size=1000.00


Self Play:  40%|███▉      | 791/2000 [02:34<01:54, 10.53it/s]

Evaluation time: 0.011 seconds, size=1000, eval-per-second=93809.22, total-batches=3000, mean-eval-per-second=86845.31, mean-time-per-batch=0.012, mean-batch-size=1000.00


Self Play:  62%|██████▏   | 1237/2000 [03:30<02:10,  5.85it/s]

Evaluation time: 0.011 seconds, size=764, eval-per-second=67778.85, total-batches=4000, mean-eval-per-second=80116.92, mean-time-per-batch=0.012, mean-batch-size=983.19


Self Play:  81%|████████▏ | 1626/2000 [04:10<00:47,  7.94it/s]

Evaluation time: 0.093 seconds, size=375, eval-per-second=4036.56, total-batches=5000, mean-eval-per-second=67026.20, mean-time-per-batch=0.013, mean-batch-size=899.22


Self Play:  94%|█████████▎| 1870/2000 [04:40<00:10, 12.55it/s]

Evaluation time: 0.002 seconds, size=133, eval-per-second=55672.90, total-batches=6000, mean-eval-per-second=53744.30, mean-time-per-batch=0.015, mean-batch-size=788.07


Self Play:  99%|█████████▊| 1971/2000 [04:46<00:02, 13.97it/s]

Evaluation time: 0.003 seconds, size=31, eval-per-second=11376.62, total-batches=7000, mean-eval-per-second=51863.67, mean-time-per-batch=0.013, mean-batch-size=685.59


Self Play: 100%|█████████▉| 1998/2000 [04:49<00:00,  7.14it/s]

Evaluation time: 0.001 seconds, size=3, eval-per-second=2454.72, total-batches=8000, mean-eval-per-second=50789.09, mean-time-per-batch=0.012, mean-batch-size=601.32


Self Play: 100%|██████████| 2000/2000 [04:49<00:00,  6.90it/s]


Writing 2000 trajectories...
Training model for gen 10...
num decayed parameter tensors: 19, with 200,064 parameters
num non-decayed parameter tensors: 11, with 586 parameters
using fused AdamW: False
step 0: losses: train:2.5752, train_policy_loss:1.9118, train_value_loss:0.6634, val:2.5719, val_policy_loss:1.9120, val_value_loss:0.6599
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-10/best.pt
iter 0/36/10000: loss 2.5862, policy_loss:1.9097, value_loss:0.6764, time 4.09s, iter_time: 0.00ms
step 35: losses: train:2.5693, train_policy_loss:1.9072, train_value_loss:0.6621, val:2.5663, val_policy_loss:1.9073, val_value_loss:0.6589
saving best checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-10/best.pt
saving checkpoint to /Users/rodo/src/rgi3-sync/models/smoketest-e2e-v7/gen-10
step 71: losses: train:2.5638, train_policy_loss:1.9024, train_value_loss:0.6613, val:2.5625, val_policy_loss:1.9027, val_value_loss:0.6598
saving best checkpoin

In [6]:
STOP!

SyntaxError: invalid syntax (2635050600.py, line 1)

In [None]:
# current_model = await experiment_runner.run_generation_step_async(generation_id, current_model)

experiment_config.experiment_name='smoketest-e2e-v3-hack'   # Use sliding window.
experiment_config.parent_experiment_name='smoketest-e2e-v3'
experiment_config.num_generations=41
experiment_config.num_games_per_gen=1

experiment_runner = ExperimentRunner(experiment_config, experiment_base_dir, training_args=tuned_params)
experiment_runner.progress_bar = False
await experiment_runner.play_generation_async(current_model, gen_id=experiment_config.num_generations, write_dataset=False)

# Tune Model (initial)


In [None]:
reload_local_modules(verbose=False)

state_0 = game.initial_state()
NUM_GENERATIONS = 5
LEARNING_RATE = 0.1

# Parameters which will never be used for tuning.
fixed_params = dict(
    model_name='c4-smoketest',
    model_version='0.1',
    num_players = game.num_players(state_0),
    vocab_size = action_vocab.vocab_size,
    dataset_paths = tuple(experiment_runner.get_trajectory_paths(experiment_config.num_generations)),


    eval_iters = 200,
    log_interval = 1000,
    eval_interval = 10_000,

    device = device,
)

initial_params = dict(
    n_layer=2,
    n_head=2,
    n_embd=8,  # tiny model

    n_max_context=n_max_context,
    batch_size = 32,
    gradient_accumulation_steps = 1,

    max_iters=100,
    max_epochs=1_000_000, # Make max_epoch high, rely on max_iters to stop.
        
    learning_rate = LEARNING_RATE,    
    decay_lr = True,  # whether to decay the learning rate
    lr_decay_iters = 100,  # make equal to max_iters usually
    min_lr = LEARNING_RATE / 10,  # learning_rate / 10 usually
    warmup_iters = 0,  # not super necessary potentially

    weight_decay = 1e-1,
    beta1 = 0.9,
    beta2 = 0.95,
    grad_clip = 1.0,  # clip gradients at this value, or disable if == 0.0

    dtype = "float16",

    dropout = 0.0,
    bias = False,  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    last_file = None,   # Used in tuning key only.
)

tune_options = dict(
    n_layer = [1, 2, 3, 4, 5, 6, 8, 10, 12, 16, 32],
    # n_head = [1, 2, 4, 8, 16, 32],   # Needs to be calcualted to ensure n_embed % n_head == 0
    n_embd = [8, 16, 32, 64, 128, 256, 512, 1024, 2048],

    n_max_context = [initial_params['n_max_context']],
    batch_size = [16, 32, 64, 128, 256, 512, 1024],
    gradient_accumulation_steps = [1],  # TODO: We only support 1 for now. This fails is we don't have an exact multiple of the batch size per epoch.

    max_iters = [100, 300, 1_000, 3_000, 5_000, 10_000, 30_000, 100_000, 300_000],
    max_epochs = [1_000_000], # Make max_epoch high, rely on max_iters to stop.
 
    learning_rate = [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0],
    decay_lr = [False, True],

    # TODO: What is a sensible range here?
    beta1 = [0.90, 0.95, 0.99],
    beta2 = [0.95, 0.98, 0.99],

    weight_decay = [0.01, 0.05, 0.1, 0.2],
    grad_clip = [0,0, 1.0],  # clip gradients at this value, or disable if == 0.0

    dtype = ["bfloat16", "float16"],
    dropout = [0.0, 0.01, 0.02, 0.05, 0.1],
    bias = [True, False],    
)

_n_head_options = [1, 2, 4, 8, 16, 32]
computed_tune_options = dict(
    min_lr = lambda opt: [opt['learning_rate'] / 10],
    lr_decay_iters = lambda opt: [opt['max_iters']],
    warmup_iters = lambda opt: [x for x in [0, 100, 500, 1000] if x < opt['lr_decay_iters']] if opt['decay_lr'] else [0],
    n_head = lambda opt: [n for n in _n_head_options if opt['n_embd'] % n == 0],
    last_file = lambda opt: [str(opt['dataset_paths'][-1])],
)

TUNER_VERSION = "0.0.6-smoketest"

from rgi.rgizero.models.tuner import Tuner

tuner = Tuner(
    fixed_params=fixed_params.copy(),
    initial_params=initial_params.copy(),
    tune_options=tune_options.copy(), 
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=1.00)
tuner.autotune_smart()


In [None]:
tuner = Tuner(
    fixed_params=fixed_params.copy(),
    initial_params=initial_params.copy(),
    tune_options=tune_options.copy(), 
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.1)
tuner.autotune_smart()


In [None]:
tuner = Tuner(
    fixed_params=fixed_params.copy(),
    initial_params=initial_params.copy(),
    tune_options=tune_options.copy(), 
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.01)
tuner.autotune_smart()

In [None]:
tuner = Tuner(
    fixed_params=fixed_params.copy(),
    initial_params=initial_params.copy(),
    tune_options=tune_options.copy(), 
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.001)
tuner.autotune_smart()

In [None]:
tuner = Tuner(
    fixed_params=fixed_params.copy(),
    initial_params=initial_params.copy(),
    tune_options=tune_options.copy(), 
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.0001)
tuner.autotune_smart()

# Sanity check models


In [None]:
raise NotImplementedError("Skip this...")

In [None]:
reload_local_modules(verbose=False)

tuner = Tuner(
    fixed_params=fixed_params.copy(),
    initial_params=initial_params.copy(),
    tune_options=tune_options.copy(), 
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.001)

tuner_result = tuner.autotune_smart()
# print(f'tuner_result={tuner_result}')

best_params = tuner.best_params.copy()
## Recalculating with best_params = {'batch_size': 512, 'beta1': 0.9, 'beta2': 0.99, 'bias': False, 'decay_lr': True, 'dropout': 0.0, 'dtype': 'float16', 'grad_clip': 1.0, 'gradient_accumulation_steps': 1, 'last_file': '/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-20', 'learning_rate': 0.01, 'lr_decay_iters': 5000, 'max_epochs': 1000000, 'max_iters': 30000, 'min_lr': 0.001, 'n_embd': 64, 'n_head': 8, 'n_layer': 4, 'n_max_context': 44, 'warmup_iters': 1000, 'weight_decay': 0.2, 'model_name': 'c4-smoketest', 'model_version': '0.1', 'num_players': 2, 'vocab_size': 8, 'dataset_paths': (PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-1'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-2'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-3'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-4'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-5'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-6'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-7'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-8'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-9'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-10'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-11'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-12'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-13'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-14'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-15'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-16'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-17'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-18'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-19'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-20')), 'eval_iters': 200, 'log_interval': 1000, 'eval_interval': 10000, 'device': 'mps'}
## {'train': 2.2821048521995544, 'train_policy_loss': 1.746874241232872, 'train_value_loss': 0.5352306108176709, 'val': 2.324920549112208, 'val_policy_loss': 1.7480337900273941, 'val_value_loss': 0.5768867538255804, 'elapsed': 1651.836928844452, 'param_hash': '03189f0f4cb118a1ec142e488fe3fac12e97c118f60661796f8a5a64fa871d44'}

best_params['max_iters'] = 30_000 # {'train': 2.2821048521995544, 'train_policy_loss': 1.746874241232872, 'train_value_loss': 0.5352306108176709, 'val': 2.324920549112208, 'val_policy_loss': 1.7480337900273941, 'val_value_loss': 0.5768867538255804, 'elapsed': 1651.836928844452, 'param_hash': '03189f0f4cb118a1ec142e488fe3fac12e97c118f60661796f8a5a64fa871d44'}
# best_params['max_iters'] = 10_000 #{'train': 2.2972692108154296, 'train_policy_loss': 1.7478227978944778, 'train_value_loss': 0.5494464221596718, 'val': 2.3170768583522126, 'val_policy_loss': 1.7486604136579178, 'val_value_loss': 0.5684164271635168, 'elapsed': 529.6186480522156, 'param_hash': 'c81f68b9b00650b83a229a76b93d77d51f7bd1af43148f6c1af75e8f562f2670'}
# best_params['max_iters'] = 5000 # {'train': 2.3059648656845093, 'train_policy_loss': 1.749609624147415, 'train_value_loss': 0.5563552376627922, 'val': 2.319357395172119, 'val_policy_loss': 1.7505432332263273, 'val_value_loss': 0.5688141549334806, 'elapsed': 267.94705629348755, 'param_hash': 'a1194e1f289d35c51e5a0f01692109358f947530b50dd54364f01baefdb6bedf'}
# best_params['max_iters'] = 3000 # {'train': 2.3192384707927705, 'train_policy_loss': 1.752758464217186, 'train_value_loss': 0.5664800041913987, 'val': 2.326388120651245, 'val_policy_loss': 1.7533490412375505, 'val_value_loss': 0.5730390969444724, 'elapsed': 172.16889691352844, 'param_hash': '9e0204edd0b23fd54025b0c4de66c870a268b95197e0e47c4460ec03875b3dcb'}

# best_params['learning_rate'] = 0.01 # {'train': 2.2821048521995544, 'train_policy_loss': 1.746874241232872, 'train_value_loss': 0.5352306108176709, 'val': 2.324920549112208, 'val_policy_loss': 1.7480337900273941, 'val_value_loss': 0.5768867538255804, 'elapsed': 1651.836928844452, 'param_hash': '03189f0f4cb118a1ec142e488fe3fac12e97c118f60661796f8a5a64fa871d44'}
best_params['learning_rate'] = 0.0005 # {'train': 2.2302739822864535, 'train_policy_loss': 1.7535288149118424, 'train_value_loss': 0.47674516707658765, 'val': 2.4826266765594482, 'val_policy_loss': 1.755559900227715, 'val_value_loss': 0.7270667658132666, 'elapsed': 1617.1719007492065, 'param_hash': '4364fb3cfa7a3b4d33fe30f70ec9957a0a14bc7d9a195b5a2de2247d2a72a6d1'}

print(f'\n## Recalculating with best_params = {best_params}')
best_params = tuner._recalculate_tunable_params(best_params)
best_model = tuner.get_model_for_params(best_params)
print(tuner.train_and_compute_loss(best_params, reload_model=True)[2])



In [None]:
dataset_paths = experiment_runner.get_trajectory_paths(NUM_GENERATIONS)
print_dataset_stats(dataset_paths, n_max_context, action_vocab, model=best_model, game=game)

In [None]:
best_model

In [None]:
# Inspect training data
td_array = [TrajectoryDataset(DATA_DIR, f"gen-{generation_id}", block_size=n_max_context) for generation_id in range(1, NUM_GENERATIONS+1)]

In [None]:
# [td for td in td_array]
unrolled = [(generation+1, d) for generation, td in enumerate(td_array) for d in td]

# gen, d = unrolled[0], 
# d.action[:2]
# d.value[0]

dd = defaultdict(lambda: defaultdict(lambda: torch.tensor([0., 0.])))

for gen, d in unrolled:
    for g in ['*', gen]:    
        # dd[tuple(tuple(d.action[:0].tolist()))][g] += d.value[0]
        dd[tuple(tuple(d.action[:1].tolist()))][g] += d.value[0]
        dd[tuple(tuple(d.action[:2].tolist()))][g] += d.value[0]
        dd[tuple(tuple(d.action[:3].tolist()))][g] += d.value[0]
        dd[tuple(tuple(d.action[:4].tolist()))][g] += d.value[0]
        dd[tuple(tuple(d.action[:5].tolist()))][g] += d.value[0]
        dd[tuple(tuple(d.action[:6].tolist()))][g] += d.value[0]
        dd[tuple(tuple(d.action[:7].tolist()))][g] += d.value[0]
        dd[tuple(tuple(d.action[:8].tolist()))][g] += d.value[0]
        dd[tuple(tuple(d.action[:9].tolist()))][g] += d.value[0]
        dd[tuple(tuple(d.action[:10].tolist()))][g] += d.value[0]

print(f"len(dd) = {len(dd)}")


In [None]:
def eval_prefix(model, game, prefix):
    serial_evaluator = ActionHistoryTransformerEvaluator(model, device=device, block_size=n_max_context, vocab=action_vocab)
    state = game.initial_state()
    for action in prefix:
        state = game.next_state(state, action)
    legal_actions = game.legal_actions(state)
    result = serial_evaluator.evaluate(game, state, legal_actions)
    return result


In [None]:
## Someting is borked? Player1 win percent should be much higher??
def compare_model_vs_data(model, game, dd):    
    list(dd.items())[10][1]['*'].sum() > 100
    top_k = sorted(dd.items(), key=lambda kv: kv[1]['*'].sum(), reverse=True)[:20]
    top_k_keys = sorted(k for k, v in top_k)
    
    prefix_list = top_k_keys

    # prefix_list = [
    #     (0,), 
    #     (0,1), (0,2), (0,3), (0,4), (0,5), (0,6), (0,7),
    #     (0,1,1), (0,1,2), (0,1,3), (0,1,4), (0,1,5), (0,1,6), (0,1,7),
    #     (0,4,1), (0,4,2), (0,4,3), (0,4,4), (0,4,5), (0,4,6), (0,4,7),
    # ]

    for prefix in prefix_list:
        print(f"\nprefix={prefix}")
        for gen, counts in dd[prefix].items():
            if gen == '*':
                print(f"gen={gen}: {counts}, win_pct={100*counts[0]/sum(counts):.2f}%, sum={sum(counts)}")
        # # assert prefix[0] == 0
        actions = prefix[1:]
        eval_result = eval_prefix(model, game, actions)
        # print(f'legal_policy={eval_result.legal_policy}')
        # print(f'player_values={eval_result.player_values}')
        print(f'player_probs={(eval_result.player_values+1)/2}')

compare_model_vs_data(current_model, game, dd)


In [None]:
# Copy model
model_0 = create_random_model(model_config, action_vocab_size=action_vocab.vocab_size, num_players=game.num_players(state_0), seed=42, device=device)
if RUN_GENERATIONS:
    model_1 = load_model(1)


In [None]:
print("\n\n### Model 0")
print(model_0.action_embedding.weight)
compare_model_vs_data(model_0, game, dd)

In [None]:
if RUN_GENERATIONS:
    print("\n\n### Model 1")
    print(model_1.action_embedding.weight)
    compare_model_vs_data(model_1, game, dd)

## Run tournament to calcualte ELO


In [None]:
import asyncio
import numpy as np
from contextlib import asynccontextmanager
from rgi.rgizero.tournament import Tournament
from rgi.rgizero.players.alphazero import AlphazeroPlayer
from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformerEvaluator, AsyncNetworkEvaluator

@asynccontextmanager
async def create_player_factory(model, simulations, game, device, block_size, action_vocab, max_batch_size):
    """
    Creates a shared evaluator and returns a factory function that produces 
    new AlphazeroPlayer instances using that shared evaluator.
    """
    # 1. Setup the shared evaluator
    serial_evaluator = ActionHistoryTransformerEvaluator(
        model, 
        device=device, 
        block_size=n_max_context, 
        vocab=action_vocab
    )
    async_evaluator = AsyncNetworkEvaluator(
        base_evaluator=serial_evaluator, 
        max_batch_size=max_batch_size, 
        verbose=False
    )
    
    # 2. Start the evaluator background task
    await async_evaluator.start()
    
    try:
        # 3. Define the factory. This is called by Tournament for every game.
        # It creates a NEW player instance but uses the SHARED async_evaluator.
        def player_factory():
            # Create a fresh RNG for each game/player instance
            rng = np.random.default_rng(np.random.randint(0, 2**31))
            return AlphazeroPlayer(
                game, 
                async_evaluator, 
                rng=rng, 
                add_noise=True, 
                simulations=simulations
            )
            
        yield player_factory
        
    finally:
        # 4. Cleanup
        await async_evaluator.stop()

async def run_tournament_async():
    # Use async with to manage the lifecycle of the evaluators
    async with (
        # create_player_factory(model_dict[0], 100, game, device, block_size, action_vocab, 10) as factory_gen0_100,
        # create_player_factory(model_dict[1], 100, game, device, block_size, action_vocab, 10) as factory_gen1_100,
        # create_player_factory(model_dict[2], 100, game, device, block_size, action_vocab, 10) as factory_gen2_100,
        # create_player_factory(model_dict[3], 100, game, device, block_size, action_vocab, 10) as factory_gen3_100,
        # create_player_factory(model_dict[4], 100, game, device, block_size, action_vocab, 10) as factory_gen4_100,
        # create_player_factory(model_dict[5], 100, game, device, block_size, action_vocab, 10) as factory_gen5_100,
        # create_player_factory(model_dict[10], 100, game, device, block_size, action_vocab, 10) as factory_gen6_100,
        # create_player_factory(model_dict[15], 100, game, device, block_size, action_vocab, 10) as factory_gen7_100,
        # create_player_factory(model_dict[20], 100, game, device, block_size, action_vocab, 10) as factory_gen8_100,

        create_player_factory(model_dict[0], 200, game, device, block_size, action_vocab, 10) as factory_gen0_200,
        #create_player_factory(model_dict[1], 200, game, device, block_size, action_vocab, 10) as factory_gen1_200,
        #create_player_factory(model_dict[2], 200, game, device, block_size, action_vocab, 10) as factory_gen2_200,
        #create_player_factory(model_dict[3], 200, game, device, block_size, action_vocab, 10) as factory_gen3_200,
        #create_player_factory(model_dict[4], 200, game, device, block_size, action_vocab, 10) as factory_gen4_200,
        create_player_factory(model_dict[5], 200, game, device, block_size, action_vocab, 10) as factory_gen5_200,
        #create_player_factory(model_dict[10], 200, game, device, block_size, action_vocab, 10) as factory_gen10_200,
        #create_player_factory(model_dict[15], 200, game, device, block_size, action_vocab, 10) as factory_gen15_200,
        create_player_factory(model_dict[20], 200, game, device, block_size, action_vocab, 10) as factory_gen20_200,
        ):
        
        # The dictionary now maps names to FACTORIES (Callables), not Player instances
        player_factories = {
            # "factory_gen0_100": factory_gen0_100,
            # "factory_gen1_100": factory_gen1_100,
            # "factory_gen2_100": factory_gen2_100,
            # "factory_gen3_100": factory_gen3_100,
            # "factory_gen4_100": factory_gen4_100,
            # "factory_gen5_100": factory_gen5_100,
            # "factory_gen6_100": factory_gen6_100,
            # "factory_gen7_100": factory_gen7_100,

            "factory_gen0_200": factory_gen0_200,
            #"factory_gen1_200": factory_gen1_200,
            #"factory_gen2_200": factory_gen2_200,
            #"factory_gen3_200": factory_gen3_200,
            #"factory_gen4_200": factory_gen4_200,
            "factory_gen5_200": factory_gen5_200,
            #"factory_gen10_200": factory_gen10_200,
            #"factory_gen15_200": factory_gen15_200,
            "factory_gen20_200": factory_gen20_200,
        }
        
        tournament = Tournament(game, player_factories, initial_elo=1000)
        
        print("Running tournament...")
        # await tournament.run(num_games=1_000, concurrent_games=2000)
        await tournament.run(num_games=100, concurrent_games=2000)
        tournament.print_standings()

# RUN_TOURNAMENT = True
if RUN_TOURNAMENT:
    await run_tournament_async()

# Running tournament...
# Tournament Progress: 100%|██████████| 10000/10000 [1:25:59<00:00,  1.94it/s]

# Tournament Standings:
# Rank  Player               ELO        Games    W-L-D          
# -----------------------------------------------------------------
# 1     factory_gen6_200     1140.5     1247     827-419-1      
# 2     factory_gen2_200     1100.1     1251     693-554-4      
# 3     factory_gen5_100     1074.4     1251     598-652-1      
# 4     factory_gen3_200     1029.1     1252     674-573-5      
# 5     factory_gen4_200     1027.0     1248     711-536-1      
# 6     factory_gen0_200     1020.0     1254     444-810-0      
# 7     factory_gen5_200     990.2      1248     742-502-4      
# 8     factory_gen7_100     987.5      1250     650-597-3      
# 9     factory_gen7_200     979.2      1248     768-476-4      
# 10    factory_gen2_100     974.0      1249     522-723-4      
# 11    factory_gen6_100     966.6      1248     684-564-0      
# 12    factory_gen4_100     964.2      1251     557-693-1      
# 13    factory_gen1_100     962.5      1252     547-705-0      
# 14    factory_gen3_100     947.0      1251     528-723-0      
# 15    factory_gen1_200     941.1      1252     630-620-2      
# 16    factory_gen0_100     896.5      1248     410-838-0     


## 20 generations.
# Running tournament...
# Tournament Progress: 100%|██████████| 1000/1000 [08:35<00:00,  1.94it/s]

# Tournament Standings:
# Rank  Player               ELO        Games    W-L-D          
# -----------------------------------------------------------------
# 1     factory_gen10_200    1114.2     333      212-120-1      
# 2     factory_gen2_200     1032.6     333      190-141-2      
# 3     factory_gen1_200     1003.9     334      159-175-0      
# 4     factory_gen20_200    1000.9     335      171-164-0      
# 5     factory_gen5_200     974.6      331      183-146-2      
# 6     factory_gen0_200     873.8      334      82-251-1  

# Tune Model (continued)


In [None]:
tuner = Tuner(
    fixed_params=fixed_params.copy(),
    initial_params=initial_params.copy(),
    tune_options=tune_options.copy(), 
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.01)
tuner.autotune_smart()

# Using initial model as baseline.
# ## Initial Model, loss=2.1298508644104004 elapsed=171.78943705558777s
# ## Searching generation 0 with 22 candidates, including ['bias: False -> True', 'learning_rate: 0.005 -> 0.002', 'learning_rate: 0.005 -> 0.002', 'dtype: bfloat16 -> float16', 'weight_decay: 0.1 -> 0.2']
# ## improved: False, loss=2.1332 elapsed=178.64s, mutation bias: False -> True
# ## improved: False, loss=2.1395 elapsed=172.03s, mutation learning_rate: 0.005 -> 0.002
# ## improved: False, loss=2.1395 elapsed=172.03s, mutation learning_rate: 0.005 -> 0.002


In [None]:
reload_local_modules(verbose=False)

tuner = Tuner(
    fixed_params=fixed_params.copy(),
    initial_params=initial_params.copy(),
    tune_options=tune_options.copy(), 
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.001)
tuner.autotune_smart()
from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformer, ActionHistoryTransformerEvaluator
from rgi.rgizero.models.transformer import TransformerConfig

tiny_config: TransformerConfig = TransformerConfig(n_max_context=100, n_layer=2, n_head=2, n_embd=8)
tiny_model = ActionHistoryTransformer(config=tiny_config, action_vocab_size=action_vocab.vocab_size, num_players=game.num_players(state_0))
tiny_model.to(device)
tiny_evaluator = ActionHistoryTransformerEvaluator(tiny_model, device=device, block_size=5, vocab=action_vocab)


In [None]:
tuner = Tuner(
    fixed_params=fixed_params.copy(),
    initial_params=initial_params.copy(),
    tune_options=tune_options.copy(), 
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.0001)
tuner.autotune_smart()


In [None]:
reload_local_modules(verbose=False)

tuner = Tuner(
    fixed_params=fixed_params.copy(),
    tune_options=tune_options.copy(), 
    initial_params=initial_params.copy(),
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.0)
tuner.autotune_smart()

## Debug best model

In [None]:
reload_local_modules(verbose=False)

best_model = tuner.load_best_model()
compare_model_vs_data(best_model, game, dd)


In [None]:
from pprint import pprint
print(f'tuner.best_loss={tuner.best_loss}')
print(f'tuner.best_loss_elapsed={int(tuner.best_loss_elapsed)//60}m{tuner.best_loss_elapsed%60:.0f}s')
pprint(tuner.best_params)
# best_params = tuner.initial_params


## Print tuner stats


In [None]:
reload_local_modules(verbose=False)
tuner = Tuner(
    fixed_params=fixed_params.copy(),
    tune_options=tune_options.copy(), 
    initial_params=initial_params.copy(),
    computed_tune_options=computed_tune_options.copy(),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.001)
# print stats based on cached results.
tuner_stats = tuner.print_hparam_stats()

In [None]:
tuner_stats

In [None]:
# [(k,v['mean_val_delta']) for (k,v) in sorted(tuner_stats.items(), key=lambda x: x[1]['mean_val_delta'], reverse=True)]

for x in  sorted([(v['mean_val_delta'], k, v['mean_val_1'], v['mean_val_2']) for (k,v) in tuner_stats.items() if not np.isnan(v['mean_val_delta'])], reverse=True): print(x)
# sorted([(v['mean_val_delta'], k) for (k,v) in tuner_stats.items() if not np.isnan(v['mean_val_delta'])], reverse=True)


In [None]:
for x in sorted([(v['std_val_delta'], k, v['mean_val_1'], v['mean_val_2']) for (k,v) in tuner_stats.items() if not np.isnan(v['std_val_delta'])], reverse=True): print(x)


## Debug Convergence

Synthetic sanity-check: train on a toy 2-step game where the first action strongly determines the winner. This verifies the value head and training loop can learn simple patterns.


In [None]:
raise NotImplementedError("xxx STOP HERE xxx")

In [None]:
state_0 = game.initial_state()
all_actions_0 = game.all_actions()

print(all_actions_0)


In [None]:
import random

def play_random_game_with_fake_reward(game, max_actions) -> dict:
    state = game.initial_state()
    action_history = []
    legal_policies = []
    legal_action_idx_list = []

    all_actions = game.all_actions()
    all_action_idx_map = {action: idx for idx, action in enumerate(all_actions)}

    num_actions = 0
    while not game.is_terminal(state) and num_actions < max_actions:
        current_player = game.current_player_id(state)
        legal_actions = game.legal_actions(state)
        action_idx = random.randrange(len(legal_actions))
        action = legal_actions[action_idx]

        action_history.append(action)
        legal_policies.append(np.ones(len(legal_actions))/len(legal_actions))
        legal_action_idx = np.array([all_action_idx_map[action] for action in legal_actions])
        legal_action_idx_list.append(legal_action_idx)

        state = game.next_state(state, action)
        num_actions += 1

    # Determine outcome
    fake_reward = np.mean(action_history) / len(legal_actions)
    rewards = np.array([fake_reward, 1.0-fake_reward])
    if fake_reward >= 0.5:
        winner = 1
    else:
        winner = 2

    return {
        "winner": winner,
        "rewards": rewards,
        "action_history": action_history,
        "legal_policies": legal_policies,
        "final_state": state,
        "legal_action_idx": legal_action_idx_list,
    }

In [None]:
play_random_game_with_fake_reward(game, max_actions=2)

In [None]:
results = [play_random_game_with_fake_reward(game, max_actions=2) for _ in range(100_000)]
print_game_stats(results)


In [None]:
fake_gen_name = "fake-0"
trajectory_path = write_trajectory_dataset(results, action_vocab, fake_gen_name)


In [None]:
# fake_model_config = model_config_dict[MODEL_SIZE]
fake_model_config = model_config_dict["large"]
fake_model = create_random_model(fake_model_config, action_vocab_size=action_vocab.vocab_size, num_players=game.num_players(state_0), seed=42, device=device)

training_splits = [f'gen-{fake_gen_name}']
fake_model, fake_trainer = train_model(fake_model, training_splits, train_config)
save_model(fake_model, fake_trainer, fake_gen_name)

## model_size=tiny
# num decayed parameter tensors: 11, with 1,968 parameters
# num non-decayed parameter tensors: 7, with 50 parameters
# using fused AdamW: False
# step 0: train loss 2.7817, val loss 2.7816
# iter 0/49/488: loss 2.7821, time 2537.56ms
# iter 100/147/488: loss 2.6890, time 53.61ms
# iter 200/245/488: loss 2.6342, time 63.05ms
# iter 300/343/488: loss 2.6187, time 55.31ms
# iter 400/441/488: loss 2.6147, time 61.11ms

## model_size=large
# num decayed parameter tensors: 35, with 1,579,776 parameters
# num non-decayed parameter tensors: 19, with 2,186 parameters
# using fused AdamW: False
# step 0: train loss 2.8087, val loss 2.8088
# iter 0/49/488: loss 2.8099, time 11225.20ms
# iter 100/147/488: loss 2.6065, time 596.91ms
# iter 200/245/488: loss 2.6075, time 618.00ms
# iter 300/343/488: loss 2.6080, time 613.63ms
# iter 400/441/488: loss 2.6051, time 616.39ms

In [None]:
# for rerun in range(10):
#     print(f"Re-running training for {fake_gen_name} {rerun+1} of 10")
#     fake_model, fake_trainer = train_model(fake_model, training_splits, train_config)
#     save_model(fake_model, fake_trainer, fake_gen_name)

In [None]:
# [td for td in td_array]
fake_td_array = [TrajectoryDataset(DATA_DIR, split, block_size=n_max_context) for split in training_splits]
fake_unrolled = [(generation+1, d) for generation, td in enumerate(fake_td_array) for d in td]

# gen, d = unrolled[0], 
# d.action[:2]
# d.value[0]

# Inspect training data
fake_dd = defaultdict(lambda: defaultdict(lambda: torch.tensor([0., 0.])))

for gen, d in fake_unrolled:
    for g in ['*', gen]:    
        fake_dd[tuple(tuple(d.action[:0].tolist()))][g] += d.value[0]
        fake_dd[tuple(tuple(d.action[:1].tolist()))][g] += d.value[0]
        fake_dd[tuple(tuple(d.action[:2].tolist()))][g] += d.value[0]
        # fake_dd[tuple(tuple(d.action[:3].tolist()))][g] += d.value[0]

print(f"len(fake_dd) = {len(fake_dd)}")


In [None]:
fake_model = load_model(fake_gen_name)
compare_model_vs_data(fake_model, game, dd)


In [None]:
fake_model, fake_trainer = train_model(fake_model, training_splits, train_config)
