# Step-by-step run of AlphaZero self-play & training.


In [None]:
import os
import time
from pathlib import Path
import asyncio

import numpy as np
import torch

# Game and players
from rgi.rgizero.experiment import ExperimentRunner, ExperimentConfig
from rgi.rgizero.data.trajectory_dataset import Vocab, print_dataset_stats, TrajectoryDataset
from rgi.rgizero.evaluators import ActionHistoryTransformerEvaluator, AsyncNetworkEvaluator
from rgi.rgizero.models.tuner import create_random_model

import notebook_utils
from notebook_utils import reload_local_modules

# Pick the compute device through the notebook helper; the value is reused
# further down as the `device` entry of the tuner's fixed parameters.
device = notebook_utils.detect_device()

## Disable for debugger stability?
# nest_asyncio is deliberately left commented out (see the question above).
# # Allow asyncio to work with jupyter notebook
# import nest_asyncio
# nest_asyncio.apply()

# Increase numpy print width so arrays are only wrapped after 300 characters.
np.set_printoptions(linewidth=300)

%load_ext line_profiler

In [None]:
# Switch for the self-play/training generation loop.
# NOTE: a later cell reassigns this flag to False before the loop reads it,
# so the value set here only matters if that reassignment is removed.
RUN_GENERATIONS = True


# Experiment configuration handed to ExperimentRunner in the next step.
# `parent_experiment_name` names the experiment this one builds on
# (presumably the source of forked early-generation data -- confirm against
# ExperimentRunner).
experiment_config = ExperimentConfig(
    experiment_name="smoketest-e2e-v2",
    parent_experiment_name="smoketest-e2e",
    game_name="connect4",
    num_generations=20,
    num_games_per_gen=10_000,
    num_simulations=50,
    model_size="tiny",
    train_batch_size=10,
    max_training_epochs=2,
    seed=42,
)

# Training arguments passed to ExperimentRunner as `training_args`.
# Values were tuned on connect4 with 23k training games.
tuned_params = dict(
    # --- optimizer / schedule / architecture, in the order the tuner reported them ---
    batch_size=256,
    beta1=0.9,
    beta2=0.95,
    bias=True,
    decay_lr=True,
    dropout=0.0,
    dtype='float16',
    grad_clip=1.0,
    gradient_accumulation_steps=1,
    learning_rate=0.002,
    lr_decay_iters=1000,
    max_epochs=1_000_000,
    max_iters=1000,
    min_lr=0.0002,
    n_embd=256,
    n_head=2,
    n_layer=3,
    n_max_context=44,
    warmup_iters=100,  # TODO: Tuner says this should be 500? That seems high...
    weight_decay=0.1,
    # --- entries from the tuner record that are intentionally not passed on ---
    # model_name='c4-smoketest',
    # model_version='0.1',
    # num_players=2,
    # vocab_size=8,
    # dataset_paths=(
    #     PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-1'),
    #     PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-2'),
    #     PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-3'),
    #     PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-4'),
    #     PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-5'),
    # ),
    # --- evaluation / logging cadence ---
    eval_iters=200,
    log_interval=1000,
    eval_interval=10_000,
    # device='mps',
)

## Step 1: Set up game and experiment runner


In [None]:
from rgi.rgizero.data.trajectory_dataset import Vocab
from rgi.rgizero.common import TOKENS

# Initialize Experiment Runner.
# The experiments directory is resolved relative to the current working
# directory (its parent), so this cell expects the notebook to be started from
# a sibling directory of `experiments/`.
experiment_base_dir = Path.cwd().parent / 'experiments'
experiment_runner = ExperimentRunner(experiment_config, experiment_base_dir, training_args=tuned_params)

# Convenience aliases used by the rest of the notebook.
game = experiment_runner.game
action_vocab = experiment_runner.action_vocab
n_max_context = experiment_runner.n_max_context

DATA_DIR = experiment_runner.data_dir
MODEL_DIR = experiment_runner.models_dir

# The check mark is written as a named escape so the source stays pure ASCII;
# the literal emoji had previously been corrupted into mojibake.
print('\N{WHITE HEAVY CHECK MARK} Runner initialized')
print(f'Game: {experiment_runner.config.game_name}, Players: {experiment_runner.num_players}, Actions: {list(game.base_game.all_actions())}')
print('Data dir: ', DATA_DIR)
print('Model dir: ', MODEL_DIR)


# Debug stuff

In [None]:
# Initialize the experiment (creates Random Gen 0 if needed) and use the
# returned generation-0 model as the starting point for later cells.
model_0 = current_model = experiment_runner.initialize()


In [None]:
# Self-play / training loop over generations 1..num_generations.
# The flag is forced to False here, which overrides the value set in the
# configuration cell and skips the loop; set it to True to actually run.
RUN_GENERATIONS = False
# Bookkeeping containers; only `model_dict` is filled in by this cell.
results_dict = {}
trajectory_paths_dict = {}
model_dict = {0: model_0}

# Always restart from the generation-0 model.
current_model = model_dict[0]
if RUN_GENERATIONS:
    for generation_id in range(1, experiment_config.num_generations+1):
        # Top-level `await`: valid inside an IPython/Jupyter cell, not in a plain script.
        # Each step takes the previous model and returns the model for this generation.
        current_model = await experiment_runner.run_generation_step_async(generation_id, current_model)
        dataset_paths = experiment_runner.get_trajectory_paths(generation_id)
        
        # Print dataset stats (with the new model's predictions) for visibility.
        print_dataset_stats(dataset_paths, n_max_context, action_vocab, model=current_model, game=game)
        
        # Keep every generation's model so it can be inspected later.
        model_dict[generation_id] = current_model


# Tune Model (initial)


In [None]:
reload_local_modules(verbose=False)

state_0 = game.initial_state()
# NOTE(review): NUM_GENERATIONS is not read by the visible cells; the dataset
# paths below use experiment_config.num_generations instead.
NUM_GENERATIONS = 5
LEARNING_RATE = 0.1

# Parameters which will never be used for tuning: model identity, problem
# shape, training data location, evaluation/logging cadence and device.
fixed_params = {
    'model_name': 'c4-smoketest',
    'model_version': '0.1',
    'num_players': game.num_players(state_0),
    'vocab_size': action_vocab.vocab_size,
    'dataset_paths': tuple(experiment_runner.get_trajectory_paths(experiment_config.num_generations)),
    'eval_iters': 200,
    'log_interval': 1000,
    'eval_interval': 10_000,
    'device': device,
}

# Starting point for the tuner: a deliberately tiny model and a short run.
initial_params = {
    # Architecture.
    'n_layer': 2,
    'n_head': 2,
    'n_embd': 8,  # tiny model

    # Batch shape.
    'n_max_context': n_max_context,
    'batch_size': 32,
    'gradient_accumulation_steps': 1,

    # Run length: max_epochs is set very high so that max_iters is what stops training.
    'max_iters': 100,
    'max_epochs': 1_000_000,

    # Learning-rate schedule.
    'learning_rate': LEARNING_RATE,
    'decay_lr': True,  # whether to decay the learning rate
    'lr_decay_iters': 100,  # make equal to max_iters usually
    'min_lr': LEARNING_RATE / 10,  # learning_rate / 10 usually
    'warmup_iters': 0,  # not super necessary potentially

    # Optimizer.
    'weight_decay': 0.1,
    'beta1': 0.9,
    'beta2': 0.95,
    'grad_clip': 1.0,  # clip gradients at this value, or disable if == 0.0

    'dtype': "float16",

    # Regularization / layer options.
    'dropout': 0.0,
    'bias': False,  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    'last_file': None,  # Used in tuning key only.
}

# Candidate values the tuner may choose from for each directly tunable
# parameter. Parameters that depend on other choices (min_lr, lr_decay_iters,
# warmup_iters, n_head, last_file) are derived in `computed_tune_options`.
tune_options = dict(
    n_layer = [1, 2, 3, 4, 5, 6, 8, 10, 12, 16, 32],
    # n_head = [1, 2, 4, 8, 16, 32],   # Needs to be calculated to ensure n_embd % n_head == 0
    n_embd = [8, 16, 32, 64, 128, 256, 512, 1024, 2048],

    n_max_context = [initial_params['n_max_context']],
    batch_size = [16, 32, 64, 128, 256, 512, 1024],
    gradient_accumulation_steps = [1],  # TODO: We only support 1 for now. This fails if we don't have an exact multiple of the batch size per epoch.

    max_iters = [100, 300, 1_000, 3_000, 5_000, 10_000, 30_000, 100_000, 300_000],
    max_epochs = [1_000_000], # Make max_epoch high, rely on max_iters to stop.

    learning_rate = [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0],
    decay_lr = [False, True],

    # TODO: What is a sensible range here?
    beta1 = [0.90, 0.95, 0.99],
    beta2 = [0.95, 0.98, 0.99],

    weight_decay = [0.01, 0.05, 0.1, 0.2],
    # Two candidates only: clipping disabled (0.0) or clipping at 1.0.
    # (Previously written as `[0,0, 1.0]`, which listed the integer 0 twice.)
    grad_clip = [0.0, 1.0],  # clip gradients at this value, or disable if == 0.0

    dtype = ["bfloat16", "float16"],
    dropout = [0.0, 0.01, 0.02, 0.05, 0.1],
    bias = [True, False],
)

_n_head_options = [1, 2, 4, 8, 16, 32]


def _min_lr_choices(opt):
    """Single choice: one tenth of the selected learning rate."""
    return [opt['learning_rate'] / 10]


def _lr_decay_iters_choices(opt):
    """Single choice: decay over exactly max_iters iterations."""
    return [opt['max_iters']]


def _warmup_iters_choices(opt):
    """Warmup lengths shorter than the decay horizon; no warmup without decay."""
    if not opt['decay_lr']:
        return [0]
    return [x for x in [0, 100, 500, 1000] if x < opt['lr_decay_iters']]


def _n_head_choices(opt):
    """Head counts that divide the embedding width evenly."""
    return [n for n in _n_head_options if opt['n_embd'] % n == 0]


def _last_file_choices(opt):
    """Single choice: the last dataset path, as a string."""
    return [str(opt['dataset_paths'][-1])]


# Options derived from an already-chosen parameter dict `opt`; each entry maps
# a parameter name to a callable returning its list of allowed values.
# Order matters: warmup_iters reads lr_decay_iters, which is listed before it.
computed_tune_options = {
    'min_lr': _min_lr_choices,
    'lr_decay_iters': _lr_decay_iters_choices,
    'warmup_iters': _warmup_iters_choices,
    'n_head': _n_head_choices,
    'last_file': _last_file_choices,
}

# Version tag passed to the tuner as `cache_version`.
TUNER_VERSION = "0.0.4-smoketest"

from rgi.rgizero.models.tuner import Tuner

# Give the tuner its own shallow copies so the dicts defined above stay untouched.
tuner = Tuner(
    fixed_params=dict(fixed_params),
    initial_params=dict(initial_params),
    tune_options=dict(tune_options),
    computed_tune_options=dict(computed_tune_options),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.001,
)
# Bind the result to `_` so the notebook does not echo it as cell output.
_ = tuner.autotune_smart()


In [None]:
reload_local_modules(verbose=False)

# Build a fresh tuner from shallow copies of the same parameter dicts and run
# the search again, this time keeping the result.
tuner = Tuner(
    fixed_params=dict(fixed_params),
    initial_params=dict(initial_params),
    tune_options=dict(tune_options),
    computed_tune_options=dict(computed_tune_options),
    cache_version=TUNER_VERSION,
    target_improvement_per_minute=0.001,
)

tuner_result = tuner.autotune_smart()
# print(f'tuner_result={tuner_result}')

# Work on a copy so the manual overrides below do not modify the tuner's own record.
best_params = tuner.best_params.copy()
## Recalculating with best_params = {'batch_size': 512, 'beta1': 0.9, 'beta2': 0.99, 'bias': False, 'decay_lr': True, 'dropout': 0.0, 'dtype': 'float16', 'grad_clip': 1.0, 'gradient_accumulation_steps': 1, 'last_file': '/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-20', 'learning_rate': 0.01, 'lr_decay_iters': 5000, 'max_epochs': 1000000, 'max_iters': 30000, 'min_lr': 0.001, 'n_embd': 64, 'n_head': 8, 'n_layer': 4, 'n_max_context': 44, 'warmup_iters': 1000, 'weight_decay': 0.2, 'model_name': 'c4-smoketest', 'model_version': '0.1', 'num_players': 2, 'vocab_size': 8, 'dataset_paths': (PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-1'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-2'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-3'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-4'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-5'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-6'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-7'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-8'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-9'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-10'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-11'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-12'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-13'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-14'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-15'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-16'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-17'), 
## PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-18'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-19'), PosixPath('/Users/rodo/src/rgi3-sync/experiments/smoketest-e2e-v2/data/gen-20')), 'eval_iters': 200, 'log_interval': 1000, 'eval_interval': 10000, 'device': 'mps'}
## {'train': 2.2821048521995544, 'train_policy_loss': 1.746874241232872, 'train_value_loss': 0.5352306108176709, 'val': 2.324920549112208, 'val_policy_loss': 1.7480337900273941, 'val_value_loss': 0.5768867538255804, 'elapsed': 1651.836928844452, 'param_hash': '03189f0f4cb118a1ec142e488fe3fac12e97c118f60661796f8a5a64fa871d44'}

# best_params['max_iters'] = 30_000 # {'train': 2.2821048521995544, 'train_policy_loss': 1.746874241232872, 'train_value_loss': 0.5352306108176709, 'val': 2.324920549112208, 'val_policy_loss': 1.7480337900273941, 'val_value_loss': 0.5768867538255804, 'elapsed': 1651.836928844452, 'param_hash': '03189f0f4cb118a1ec142e488fe3fac12e97c118f60661796f8a5a64fa871d44'}
# best_params['max_iters'] = 10_000 #{'train': 2.2972692108154296, 'train_policy_loss': 1.7478227978944778, 'train_value_loss': 0.5494464221596718, 'val': 2.3170768583522126, 'val_policy_loss': 1.7486604136579178, 'val_value_loss': 0.5684164271635168, 'elapsed': 529.6186480522156, 'param_hash': 'c81f68b9b00650b83a229a76b93d77d51f7bd1af43148f6c1af75e8f562f2670'}
# best_params['max_iters'] = 5000 # {'train': 2.3059648656845093, 'train_policy_loss': 1.749609624147415, 'train_value_loss': 0.5563552376627922, 'val': 2.319357395172119, 'val_policy_loss': 1.7505432332263273, 'val_value_loss': 0.5688141549334806, 'elapsed': 267.94705629348755, 'param_hash': 'a1194e1f289d35c51e5a0f01692109358f947530b50dd54364f01baefdb6bedf'}
# best_params['max_iters'] = 3000 # {'train': 2.3192384707927705, 'train_policy_loss': 1.752758464217186, 'train_value_loss': 0.5664800041913987, 'val': 2.326388120651245, 'val_policy_loss': 1.7533490412375505, 'val_value_loss': 0.5730390969444724, 'elapsed': 172.16889691352844, 'param_hash': '9e0204edd0b23fd54025b0c4de66c870a268b95197e0e47c4460ec03875b3dcb'}
# best_params['max_iters'] = 1000
# Manual overrides on top of the tuner's best parameters.
# Learning-rate values tried in earlier runs, with the recorded result:
# best_params['learning_rate'] = 0.01 # {'train': 2.2821048521995544, 'train_policy_loss': 1.746874241232872, 'train_value_loss': 0.5352306108176709, 'val': 2.324920549112208, 'val_policy_loss': 1.7480337900273941, 'val_value_loss': 0.5768867538255804, 'elapsed': 1651.836928844452, 'param_hash': '03189f0f4cb118a1ec142e488fe3fac12e97c118f60661796f8a5a64fa871d44'}
# best_params['learning_rate'] = 0.0005
best_params.update(max_iters=3000, learning_rate=0.0001)

print(f'## Recalculating with best_params = {best_params}')
# Re-derive the dependent parameters after the manual overrides.
best_params = tuner._recalculate_tunable_params(best_params)
best_model = tuner.get_model_for_params(best_params)
# Element 2 of the result is printed -- presumably the loss/timing dict of the
# kind recorded in the comments above; confirm against Tuner.
training_outcome = tuner.train_and_compute_loss(best_params, reload_model=True)
print(training_outcome[2])



In [None]:
reload_local_modules(verbose=False)

# Compare the retrained model's predictions with the empirical outcomes in the
# trajectory data available at generation 20.
generation_id = 20
dataset_paths = experiment_runner.get_trajectory_paths(generation_id)
print_dataset_stats(
    dataset_paths,
    n_max_context,
    action_vocab,
    model=best_model,
    game=game,
)

# 1000 iter training
transform_config_fields: {'n_max_context', 'n_layer', 'dropout', 'n_head', 'bias', 'n_embd'}
train_config_fields: {'decay_lr', 'wandb_log', 'warmup_iters', 'dtype', 'compile', 'batch_size', 'model_version', 'lr_decay_iters', 'eval_iters', 'log_interval', 'model_name', 'device', 'gradient_accumulation_steps', 'weight_decay', 'eval_interval', 'grad_clip', 'always_save_checkpoint', 'beta1', 'min_lr', 'eval_only', 'max_epochs', 'max_iters', 'learning_rate', 'beta2'}
Using forked data for gen 1 from /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-1
Using forked data for gen 2 from /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-2
Using forked data for gen 3 from /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-3
Dataset Stats:
  Trajectories: 173000
  Total actions: 2754358
  Avg trajectory length: 15.92
Prefix Stats:
actions=(): 173000 win=104006 loss=68957 draw=37 win1%=60.12 model-win1%=59.44
actions=(1,): 46966 win=26530 loss=20423 draw=13 win1%=56.49 model-win1%=55.45
actions=(1, 1): 13319 win=7577 loss=5739 draw=3 win1%=56.89 model-win1%=58.33
actions=(2,): 8666 win=4963 loss=3698 draw=5 win1%=57.27 model-win1%=58.16
actions=(3,): 16091 win=10062 loss=6024 draw=5 win1%=62.53 model-win1%=61.10
actions=(4,): 19604 win=14349 loss=5254 draw=1 win1%=73.19 model-win1%=72.11
actions=(5,): 18147 win=11351 loss=6793 draw=3 win1%=62.55 model-win1%=62.14
actions=(6,): 17995 win=10412 loss=7581 draw=2 win1%=57.86 model-win1%=57.84
actions=(7,): 45531 win=26339 loss=19184 draw=8 win1%=57.85 model-win1%=57.81
actions=(7, 1): 11328 win=7096 loss=4230 draw=2 win1%=62.64 model-win1%=61.39
actions=(7, 7): 8785 win=5147 loss=3637 draw=1 win1%=58.59 model-win1%=60.27

# 3000 iter training
transform_config_fields: {'n_max_context', 'n_layer', 'dropout', 'n_head', 'bias', 'n_embd'}
train_config_fields: {'decay_lr', 'wandb_log', 'warmup_iters', 'dtype', 'compile', 'batch_size', 'model_version', 'lr_decay_iters', 'eval_iters', 'log_interval', 'model_name', 'device', 'gradient_accumulation_steps', 'weight_decay', 'eval_interval', 'grad_clip', 'always_save_checkpoint', 'beta1', 'min_lr', 'eval_only', 'max_epochs', 'max_iters', 'learning_rate', 'beta2'}
Using forked data for gen 1 from /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-1
Using forked data for gen 2 from /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-2
Using forked data for gen 3 from /Users/rodo/src/rgi3-sync/experiments/smoketest-e2e/data/gen-3
Dataset Stats:
  Trajectories: 173000
  Total actions: 2754358
  Avg trajectory length: 15.92
Prefix Stats:
actions=(): 173000 win=104006 loss=68957 draw=37 win1%=60.12 model-win1%=59.36
actions=(1,): 46966 win=26530 loss=20423 draw=13 win1%=56.49 model-win1%=55.66
actions=(1, 1): 13319 win=7577 loss=5739 draw=3 win1%=56.89 model-win1%=58.03
actions=(2,): 8666 win=4963 loss=3698 draw=5 win1%=57.27 model-win1%=58.40
actions=(3,): 16091 win=10062 loss=6024 draw=5 win1%=62.53 model-win1%=61.81
actions=(4,): 19604 win=14349 loss=5254 draw=1 win1%=73.19 model-win1%=73.78
actions=(5,): 18147 win=11351 loss=6793 draw=3 win1%=62.55 model-win1%=61.63
actions=(6,): 17995 win=10412 loss=7581 draw=2 win1%=57.86 model-win1%=58.78
actions=(7,): 45531 win=26339 loss=19184 draw=8 win1%=57.85 model-win1%=57.19
actions=(7, 1): 11328 win=7096 loss=4230 draw=2 win1%=62.64 model-win1%=63.81
actions=(7, 7): 8785 win=5147 loss=3637 draw=1 win1%=58.59 model-win1%=59.49