# Step-by-step run of AlphaZero self-play & training.


In [None]:

import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

# Game and players
from rgi.rgizero.games.count21 import Count21Game
from rgi.rgizero.games.connect4 import Connect4Game
from rgi.rgizero.players.alphazero import AlphazeroPlayer
from rgi.rgizero.players.random_player import RandomPlayer


from rgi.rgizero.common import TOKENS

import notebook_utils
from notebook_utils import reload_local_modules

print("✅ Imports successful")

# Select the compute device once for the whole notebook.
# NOTE(review): require_accelerator=True presumably makes this fail when no
# GPU/accelerator is available — confirm in notebook_utils.detect_device.
device = notebook_utils.detect_device(require_accelerator=True)

# Allow asyncio to work with jupyter notebook
import nest_asyncio
nest_asyncio.apply()

# Increase numpy print width
np.set_printoptions(linewidth=300)

# All artifacts written by this notebook live in a "data/rgizero-e2e" directory
# next to the current working directory's parent.
DATA_DIR = Path.cwd().parent / "data" / "rgizero-e2e"
# DATA_DIR is already a Path, so create it through pathlib; parents=True creates
# missing intermediate directories and exist_ok=True makes re-running the cell safe.
DATA_DIR.mkdir(parents=True, exist_ok=True)

## Step 1: Set up history-wrapped game

In [None]:
from rgi.rgizero.games.history_wrapper import HistoryTrackingGame

# Wrap our game with history tracking
# Exactly one (base_game, max_game_length) pair should be active; the alternatives
# are kept commented out so the notebook can be switched between games quickly.
# base_game, max_game_length = Count21Game(), 21
# base_game, max_game_length = Connect4Game(), 7*6
base_game, max_game_length = Connect4Game(connect_length=5), 7*6  # Make it harder to connect! This helps test variable policy and longer games.

# These names (game, state_0, block_size) are notebook globals read by later cells.
game = HistoryTrackingGame(base_game)
state_0 = game.initial_state()
# NOTE(review): the +2 presumably leaves room for special tokens such as
# START_OF_GAME — confirm. The later cells visible in this notebook pass a
# literal block size of 5 instead of this value.
block_size = max_game_length + 2

print("✅ Using HistoryTrackingGame from module")
print(f"Game: {base_game.__class__.__name__}, Players: {game.num_players(state_0)}, Actions: {list(game.all_actions())}")

## Step 2: Confirm we can self-play a game with a Random Evaluator.

In [None]:
# Reload before importing so the names bound below come from the current module objects.
# NOTE(review): assumes reload_local_modules re-imports the project's modules — confirm in notebook_utils.
reload_local_modules(verbose=False)

from rgi.rgizero.players.alphazero import AlphazeroPlayer, play_game, NetworkEvaluatorResult, NetworkEvaluator
from typing import override, Any

class RandomEvaluator(NetworkEvaluator):
    """Evaluator that ignores the position and answers with seeded random numbers.

    Useful for exercising the self-play loop without a trained network: both the
    policy and the per-player values are drawn uniformly from [0, 1).
    """

    def __init__(self, seed: int = 42):
        # A private generator keeps results reproducible for a given seed.
        self.rng = np.random.default_rng(seed)

    @override
    def evaluate(self, game, state, legal_actions: list[Any]):
        """Return a random policy (one entry per legal action) and random values.

        Args:
            game: Game object; only ``num_players(state)`` is called on it.
            state: Current state, forwarded to ``game.num_players``.
            legal_actions: Legal actions at ``state``; only its length is used.

        Returns:
            A ``NetworkEvaluatorResult`` built from the two random arrays.
        """
        # Draw the policy first, then the values, so the random stream is consumed
        # in a fixed order.
        random_policy = self.rng.random(size=len(legal_actions))
        random_values = self.rng.random(size=game.num_players(state))
        return NetworkEvaluatorResult(random_policy, random_values)

# Smoke test: play one game where the same AlphazeroPlayer instance (backed by the
# random evaluator) is passed twice in the player list.
evaluator = RandomEvaluator()
player = AlphazeroPlayer(game, evaluator)
game_result = play_game(game, [player, player])

# Bare expression so the notebook displays the result; game_result is reused by later cells.
game_result

## Step 3: Confirm we can read & write to trajectory_dataset

In [None]:
from rgi.rgizero.data.trajectory_dataset import Vocab

# Notebook globals reused by later cells.
all_actions = game.all_actions()
num_players = game.num_players(state_0)

# Vocabulary layout: the START_OF_GAME token first, followed by every game action.
# Unpacking (rather than tuple concatenation) yields the same tuple when
# all_actions is a tuple and also works if it is a list or another iterable.
vocab = Vocab(itos=(TOKENS.START_OF_GAME, *all_actions))
print(f"Vocab: {vocab}")

In [None]:
# Reload first, then import — the same order the other cells in this notebook use.
# `from ... import` binds names to the objects that exist at import time, so a
# reload issued afterwards would leave these names pointing at pre-reload objects
# (assuming reload_local_modules re-executes the project modules — confirm in notebook_utils).
reload_local_modules(verbose=False)
from rgi.rgizero.data.trajectory_dataset import TrajectoryDatasetBuilder, TrajectoryDataset, build_trajectory_loader

def add_trajectory(game_result, vocab, td_builder):
    """Convert one finished game into fixed-width arrays and append it to ``td_builder``.

    Args:
        game_result: Mapping read by key: 'action_history', 'legal_policies',
            'legal_action_idx' and 'rewards'.
        vocab: Vocabulary object; its ``encode`` method and ``vocab_size``
            attribute are used.
        td_builder: Builder whose ``add_trajectory`` method receives the encoded game.

    Note:
        Reads the notebook-level ``all_actions`` to translate legal-action indices
        into vocab indices.
    """
    moves = game_result['action_history']
    num_moves = len(moves)
    per_move_policies = game_result['legal_policies']
    per_move_legal_idx = game_result['legal_action_idx']
    final_rewards = game_result['rewards']

    # Lookup used to turn legal-action indices into vocab indices.
    vocab_idx_of_action = vocab.encode(all_actions)

    # One row per move, one column per vocab entry; columns that are not legal
    # for a move keep their initial value of 0.
    dense_policies = np.zeros((num_moves, vocab.vocab_size))
    for move_no in range(num_moves):
        legal_columns = vocab_idx_of_action[per_move_legal_idx[move_no]]
        dense_policies[move_no, legal_columns] = per_move_policies[move_no]

    encoded_moves = vocab.encode(moves)
    # Repeat the final rewards for every move: shape (num_players,) -> (num_moves, num_players)
    rewards_per_move = np.tile(final_rewards, (num_moves, 1))

    td_builder.add_trajectory(
        actions=encoded_moves,
        fixed_width_policies=dense_policies,
        values=rewards_per_move,
    )

# Build a dataset from the single random self-play game and save it under
# DATA_DIR with the split name 'train_1000'.
# NOTE(review): the split is called 'train_1000' although only one trajectory is added here.
td_builder = TrajectoryDatasetBuilder(vocab)
add_trajectory(game_result, vocab, td_builder)

td_builder.save(DATA_DIR, 'train_1000')

## Load dataset

In [None]:
# Open the 'train_1000' split from DATA_DIR and display its first item.
# NOTE(review): the third positional argument (5) is presumably the block size,
# matching block_size=5 in the loader cell — confirm against TrajectoryDataset.
td = TrajectoryDataset(DATA_DIR, 'train_1000', 5)
td[0]


In [None]:
# Build a loader over the same split (batch size 1, CPU, 4 workers) ...
loader = build_trajectory_loader(
    DATA_DIR, 'train_1000', block_size=5, batch_size=1,
    device_is_cuda=False, workers=4)

# ... and print only the first batch as a quick check of what it yields.
for batch in loader:
    print(batch)
    break

## Create model (random weights) and play a single game.

In [None]:
# Reload before importing so the names bound below come from the current module objects.
reload_local_modules(verbose=False)

from rgi.rgizero.models.action_history_transformer import ActionHistoryTransformer, ActionHistoryTransformerEvaluator
from rgi.rgizero.models.transformer import TransformerConfig

# Deliberately tiny, untrained transformer (2 layers, 2 heads, 8-dim embeddings):
# the goal is to exercise the plumbing, not to play well.
tiny_config: TransformerConfig = TransformerConfig(n_max_context=100, n_layer=2, n_head=2, n_embd=8)
tiny_model = ActionHistoryTransformer(config=tiny_config, action_vocab_size=vocab.vocab_size, num_players=num_players)
tiny_model.to(device)
# NOTE(review): block_size is hard-coded to 5 here, while Step 1 computes
# block_size = max_game_length + 2. Confirm how the evaluator handles action
# histories longer than its block size before relying on this for full games.
tiny_evaluator = ActionHistoryTransformerEvaluator(tiny_model, device=device, block_size=5, vocab=vocab)


In [None]:
# Evaluate the initial state once with the untrained model and display the result.
legal_actions = game.legal_actions(state_0)
tiny_evaluator.evaluate(game, state_0, legal_actions)

In [None]:
# Play one full game where the same transformer-backed player is passed twice
# in the player list.
tiny_player = AlphazeroPlayer(game, tiny_evaluator)
tiny_game_result = play_game(game, [tiny_player, tiny_player])

# Bare expression so the notebook displays the result.
tiny_game_result

In [None]:
# Experiment: how does F.cross_entropy reduce over the batch dimension?
# A single random (input, target) row is compared with the same row repeated
# three times. The default reduction averages over rows, so tiling identical
# rows should leave the loss unchanged, while reduction="sum" should scale with
# the number of rows.
a_row, b_row = torch.rand(1, 2), torch.rand(1, 2)
print(f'a_row: {a_row}')
print(f'b_row: {b_row}')
print(f'loss: {F.cross_entropy(input=a_row, target=b_row)}')

# Repeat each row 3 times along dim 0 and once along dim 1: (1, 2) -> (3, 2).
tile_shape = (3, 1)
a_tiled, b_tiled = torch.tile(a_row, tile_shape), torch.tile(b_row, tile_shape)
print(f'a_tiled: {a_tiled}')
print(f'b_tiled: {b_tiled}')

# Default ("mean") reduction first, then the summed variant.
print(f'loss: {F.cross_entropy(input=a_tiled, target=b_tiled, reduction="mean")}')
print(f'loss: {F.cross_entropy(input=a_tiled, target=b_tiled, reduction="sum")}')
