# AlphaZero Self-Play Training

Train agents through AlphaZero-style self-play (MCTS guided by a policy/value network), optionally warm-starting from a behavioral-cloning (BC) checkpoint.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/christianwissmann85/ai-cardgame/blob/master/notebooks/05_alphazero_training.ipynb)

In [None]:
# Colab setup (uncomment if needed)
# !curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
# import os; os.environ['PATH'] = f"{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"
# !pip install git+https://github.com/christianwissmann85/ai-cardgame.git

In [None]:
# Setup: Change to repo root directory (required for data files)
import os
from pathlib import Path


def find_repo_root():
    """Locate the repository root by walking up from the current directory.

    A directory counts as the root when it contains ``data/cards``.

    Returns:
        The first matching ``Path`` (nearest to the current working
        directory), or ``None`` when no ancestor matches.
    """
    cwd = Path.cwd()
    # ``Path.parents`` ends with the filesystem root, so the top-most
    # directory is checked too (a ``while path != path.parent`` loop stops
    # before testing it).
    for candidate in (cwd, *cwd.parents):
        if (candidate / 'data' / 'cards').exists():
            return candidate
    return None

# Data files are resolved relative to the repo root, so switch there when
# it can be found; otherwise just warn and keep the current directory.
repo_root = find_repo_root()
if repo_root is None:
    print("Warning: Could not find repo root.")
else:
    os.chdir(repo_root)
    print(f"Working directory: {os.getcwd()}")

In [None]:
import random
from collections import deque
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm

## 1. Network Architecture

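Same architecture as the behavioral-cloning (BC) model: a shared residual backbone feeding separate policy and value heads, so a BC checkpoint can be loaded as a warm start.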

In [None]:
class ResidualBlock(nn.Module):
    """Pre-norm two-layer MLP block with a skip connection.

    Computes ``relu(fc2(ln2(relu(fc1(ln1(x))))) + x)``; the input and output
    both have ``hidden_dim`` features.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        # Registration order (fc1, fc2, ln1, ln2) fixes the state_dict layout.
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        hidden = F.relu(self.fc1(self.ln1(x)))
        hidden = self.fc2(self.ln2(hidden))
        # Skip connection is added before the final non-linearity.
        return F.relu(hidden + x)


class AlphaZeroNetwork(nn.Module):
    """Shared residual trunk with a masked policy head and a tanh value head.

    Args:
        obs_dim: Number of input features per observation.
        action_dim: Number of policy logits produced.
        hidden_dim: Width of the trunk and of the policy head's hidden layer.
        num_blocks: Number of ``ResidualBlock`` layers in the trunk.
    """

    def __init__(self, obs_dim=326, action_dim=256, hidden_dim=256, num_blocks=4):
        super().__init__()
        value_hidden = hidden_dim // 2

        self.input_proj = nn.Linear(obs_dim, hidden_dim)

        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(ResidualBlock(hidden_dim))

        policy_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.policy_head = nn.Sequential(*policy_layers)

        value_layers = [
            nn.Linear(hidden_dim, value_hidden),
            nn.ReLU(),
            nn.Linear(value_hidden, 1),
            nn.Tanh(),  # squashes the value estimate into [-1, 1]
        ]
        self.value_head = nn.Sequential(*value_layers)

    def forward(self, obs, mask):
        """Return ``(logits, value)`` for a batch of observations.

        ``mask`` must be a boolean tensor shaped like the logits; positions
        where it is ``False`` get a logit of ``-inf`` so a softmax assigns
        them zero probability. ``value`` has a trailing dimension of size 1.
        """
        features = F.relu(self.input_proj(obs))
        for residual_block in self.blocks:
            features = residual_block(features)

        illegal = ~mask
        masked_logits = self.policy_head(features).masked_fill(illegal, float('-inf'))
        return masked_logits, self.value_head(features)

## 2. MCTS Implementation

In [None]:
@dataclass
class MCTSConfig:
    """Tunable parameters read by the MCTS search."""
    # Number of simulations run from the root per search() call.
    num_simulations: int = 50
    # Multiplier on the prior-weighted exploration term in MCTSNode.ucb_score.
    c_puct: float = 1.0
    # Concentration parameter of the Dirichlet noise drawn over the root's legal actions.
    dirichlet_alpha: float = 0.3
    # Weight of that noise when mixed into the root priors:
    # (1 - epsilon) * prior + epsilon * noise.
    dirichlet_epsilon: float = 0.25
    # When > 0, root visit counts are raised to 1 / temperature before normalising.
    temperature: float = 1.0


class MCTSNode:
    """Statistics for one edge/state of the search tree.

    Attributes:
        prior: Prior probability assigned to the action leading here.
        visit_count: Number of simulations backed up through this node.
        value_sum: Sum of the values backed up through this node.
        children: Child nodes keyed by action index.
    """

    def __init__(self, prior: float):
        self.children: dict[int, MCTSNode] = {}
        self.value_sum = 0.0
        self.visit_count = 0
        self.prior = prior

    @property
    def value(self) -> float:
        """Mean backed-up value, or 0.0 for a node that was never visited."""
        return self.value_sum / self.visit_count if self.visit_count else 0.0

    def ucb_score(self, parent_visits: int, c_puct: float) -> float:
        """Return mean value plus the prior-weighted exploration bonus.

        The bonus is ``c_puct * prior * sqrt(parent_visits) / (1 + visit_count)``.
        """
        numerator = c_puct * self.prior * np.sqrt(parent_visits)
        bonus = numerator / (1 + self.visit_count)
        return self.value + bonus

In [None]:
class MCTS:
    """Network-guided Monte Carlo tree search for a two-player zero-sum game.

    Sign convention: every child's statistics are stored from the
    perspective of the player who chooses among those children (the player
    to move at the parent), so picking the highest UCB score is always in
    the interest of the player making the choice. Values are flipped only
    when the player to move actually changes, which keeps the search correct
    when one player takes several actions in a row.

    The network's value output is assumed to estimate the outcome for the
    player to move in the evaluated state (this is how ``self_play_game``
    labels its value targets). Terminal rewards are read with
    ``game.get_reward(0)`` and negated for player 1 (zero-sum assumption).
    """

    def __init__(self, network, config: MCTSConfig, device):
        self.network = network
        self.config = config
        self.device = device

    def search(self, game) -> tuple[np.ndarray, float]:
        """Run MCTS from ``game``'s current state.

        Args:
            game: Game object providing ``observe``, ``action_mask``,
                ``fork``, ``step``, ``is_done``, ``current_player`` and
                ``get_reward``. It is not modified; simulations run on forks.

        Returns:
            ``(policy, root_value)``. ``policy`` is a float32 distribution
            over all actions (one entry per mask element) built from root
            visit counts; ``root_value`` is the mean backed-up value from
            the perspective of the player to move at the root.

        Raises:
            ValueError: If the current state has no legal action.
        """
        root = MCTSNode(prior=0.0)

        # Expand root
        obs = np.array(game.observe(), dtype=np.float32)
        mask = np.array(game.action_mask(), dtype=np.float32)
        priors, root_value = self._evaluate(obs, mask)

        legal_actions = np.where(mask > 0.5)[0]
        if len(legal_actions) == 0:
            raise ValueError("MCTS.search called on a state with no legal actions")

        # Add Dirichlet noise at root for exploration
        epsilon = self.config.dirichlet_epsilon
        noise = np.random.dirichlet([self.config.dirichlet_alpha] * len(legal_actions))
        for i, action in enumerate(legal_actions):
            noisy_prior = (1 - epsilon) * priors[action] + epsilon * noise[i]
            root.children[int(action)] = MCTSNode(prior=float(noisy_prior))

        # Count the root's own evaluation as its first visit. Without this
        # the root's visit count stays 0, sqrt(parent_visits) is 0, and the
        # exploration term (priors + noise) never influences root selection.
        root.visit_count = 1
        root.value_sum = root_value

        # Run simulations
        for _ in range(self.config.num_simulations):
            root_perspective_value = self._simulate(game.fork(), root)
            root.visit_count += 1
            root.value_sum += root_perspective_value

        return self._root_policy(root, len(mask)), root.value

    def _root_policy(self, root: MCTSNode, num_actions: int) -> np.ndarray:
        """Turn root visit counts into a float32 probability vector."""
        counts = np.zeros(num_actions, dtype=np.float32)
        for action, child in root.children.items():
            counts[action] = child.visit_count

        if counts.sum() <= 0:
            # No simulation visited any child (e.g. num_simulations == 0):
            # fall back to the root priors instead of dividing by zero.
            for action, child in root.children.items():
                counts[action] = child.prior
            return counts / counts.sum()

        temperature = self.config.temperature
        if temperature <= 0:
            # Zero temperature means greedy play: all mass on the most
            # visited action.
            greedy = np.zeros_like(counts)
            greedy[int(np.argmax(counts))] = 1.0
            return greedy

        # Dividing by the max first keeps the power in [0, 1], so small
        # temperatures (large exponents) cannot overflow float32.
        scaled = (counts / counts.max()) ** (1.0 / temperature)
        return scaled / scaled.sum()

    @staticmethod
    def _terminal_value(game, player: int) -> float:
        """Final reward of a finished game from ``player``'s perspective."""
        p0_reward = game.get_reward(0)
        return p0_reward if player == 0 else -p0_reward

    def _simulate(self, game, node: MCTSNode) -> float:
        """Run one simulation below ``node``, mutating ``game`` as it descends.

        Returns:
            The simulation's value from the perspective of the player to
            move at ``node``.
        """
        player = game.current_player()

        if game.is_done():
            return self._terminal_value(game, player)
        if not node.children:
            # Expanded state without legal actions: nothing to explore.
            return 0.0

        # Select action with highest UCB (first one wins ties)
        c_puct = self.config.c_puct
        parent_visits = node.visit_count
        best_action, child = max(
            node.children.items(),
            key=lambda item: item[1].ucb_score(parent_visits, c_puct),
        )
        game.step(best_action)

        if game.is_done():
            value = self._terminal_value(game, player)
        else:
            # Read the mover before recursing; the recursion advances `game`.
            next_player = game.current_player()
            if child.visit_count == 0:
                # First visit: expand the child and use the network estimate.
                obs = np.array(game.observe(), dtype=np.float32)
                mask = np.array(game.action_mask(), dtype=np.float32)
                priors, leaf_value = self._evaluate(obs, mask)
                for action in np.where(mask > 0.5)[0]:
                    child.children[int(action)] = MCTSNode(prior=float(priors[action]))
            else:
                leaf_value = self._simulate(game, child)
            # `leaf_value` belongs to `next_player`; flip it only if the
            # turn actually passed to the opponent.
            value = leaf_value if next_player == player else -leaf_value

        # Backpropagate
        child.visit_count += 1
        child.value_sum += value

        return value

    def _evaluate(self, obs: np.ndarray, mask: np.ndarray) -> tuple[np.ndarray, float]:
        """Evaluate one state with the network.

        Returns:
            ``(policy, value)``: the softmax of the masked logits as a NumPy
            vector, and the scalar value output.
        """
        with torch.no_grad():
            obs_t = torch.from_numpy(obs).unsqueeze(0).to(self.device)
            mask_t = torch.from_numpy(mask).bool().unsqueeze(0).to(self.device)
            logits, value = self.network(obs_t, mask_t)
            policy = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
            value = value.item()
        return policy, value

## 3. Self-Play Data Collection

In [None]:
@dataclass
class Experience:
    """One training sample recorded at a single decision point of self-play."""
    obs: np.ndarray  # game.observe() at that state, stored as float32
    mask: np.ndarray  # game.action_mask() at that state, stored as float32
    policy: np.ndarray  # Policy returned by MCTS.search for that state
    value: float  # Filled in after game ends


class ReplayBuffer:
    """Fixed-capacity FIFO store of self-play experiences.

    Once ``capacity`` items are held, adding more discards the oldest ones.
    """

    def __init__(self, capacity: int = 100000):
        self.buffer = deque(maxlen=capacity)

    def add(self, experiences: "list[Experience]"):
        """Append ``experiences`` in order, evicting the oldest on overflow."""
        self.buffer.extend(experiences)

    def sample(self, batch_size: int) -> dict:
        """Draw a batch uniformly at random, without replacement.

        Args:
            batch_size: Requested number of samples; capped at the current
                buffer size.

        Returns:
            Dict with stacked tensors under ``'obs'``, ``'mask'`` and
            ``'policy'``, and a float32 tensor under ``'value'``.
        """
        size = len(self.buffer)
        # Sample positions rather than copying the whole deque into a list
        # on every call.
        positions = random.sample(range(size), min(batch_size, size))
        samples = [self.buffer[i] for i in positions]
        return {
            'obs': torch.stack([torch.from_numpy(e.obs) for e in samples]),
            'mask': torch.stack([torch.from_numpy(e.mask) for e in samples]),
            'policy': torch.stack([torch.from_numpy(e.policy) for e in samples]),
            # Pin the dtype: integer rewards would otherwise infer int64 and
            # NumPy floats float64, neither of which matches a float32
            # value prediction in the loss.
            'value': torch.tensor([e.value for e in samples], dtype=torch.float32),
        }

    def __len__(self):
        return len(self.buffer)

In [None]:
def self_play_game(game, mcts: MCTS) -> list[Experience]:
    """Play one self-play game to completion and collect training samples.

    At every step the MCTS policy for the current state is recorded and the
    next action is sampled from it. After the game ends, each sample's
    ``value`` is set to the final reward from the perspective of the player
    who was to move when it was recorded.

    Args:
        game: A game already reset to its start state; it is advanced in place.
        mcts: Search object whose ``search(game)`` returns ``(policy, value)``.

    Returns:
        All of player 0's experiences followed by all of player 1's.
    """
    player_experiences = [[], []]  # Separate by player

    while not game.is_done():
        current_player = game.current_player()

        obs = np.array(game.observe(), dtype=np.float32)
        mask = np.array(game.action_mask(), dtype=np.float32)

        # Run MCTS
        policy, _ = mcts.search(game)

        # Store experience (value filled in later)
        player_experiences[current_player].append(
            Experience(obs=obs, mask=mask, policy=policy, value=0.0)
        )

        # Sample over however many actions the policy covers instead of a
        # hard-coded 256, and hand the game a plain Python int.
        action = int(np.random.choice(len(policy), p=policy))
        game.step(action)

    # Assign values based on game outcome
    p0_reward = game.get_reward(0)

    for exp in player_experiences[0]:
        exp.value = p0_reward

    for exp in player_experiences[1]:
        exp.value = -p0_reward  # Opponent's value is negated

    return player_experiences[0] + player_experiences[1]

## 4. Training

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create network
network = AlphaZeroNetwork().to(device)

# Optional: Load BC checkpoint for warm start
bc_checkpoint = Path("bc_model.pt")
if bc_checkpoint.exists():
    print(f"Found BC checkpoint: {bc_checkpoint}")
    try:
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only load checkpoints from a source you trust.
        ckpt = torch.load(bc_checkpoint, map_location=device, weights_only=False)
        network.load_state_dict(ckpt['model_state_dict'])
        print("Successfully loaded BC checkpoint for warm start")
    except Exception as e:
        print(f"Could not load BC checkpoint (architecture mismatch?): {e}")
        print("Starting from scratch instead")
        # load_state_dict copies every tensor that does match before it
        # raises, so the network may be half-loaded at this point. Rebuild
        # it so "from scratch" is actually true.
        network = AlphaZeroNetwork().to(device)
else:
    print("No BC checkpoint found, starting from scratch")

# Created after the checkpoint handling so the optimizer always tracks the
# parameters of the network that will be trained.
optimizer = torch.optim.AdamW(network.parameters(), lr=1e-4)
replay_buffer = ReplayBuffer(capacity=50000)

In [None]:
# Training config
@dataclass
class TrainConfig:
    """Sizes of the outer self-play / training loop."""
    num_iterations: int = 10  # Outer loop iterations
    games_per_iteration: int = 5  # Self-play games per iteration
    train_steps_per_iteration: int = 50  # Gradient steps per iteration
    batch_size: int = 64  # Samples drawn from the replay buffer per gradient step
    mcts_simulations: int = 25  # Reduced for speed

config = TrainConfig()
# Only the simulation count is overridden; the other MCTSConfig fields keep their defaults.
mcts_config = MCTSConfig(num_simulations=config.mcts_simulations)

In [None]:
from essence_wars._core import PyGame


def train_step(network, optimizer, batch, device):
    """Run one gradient step on a replay batch.

    Args:
        network: Module called as ``network(obs, mask)`` returning
            ``(logits, value)``; logits of illegal actions are ``-inf``.
        optimizer: Optimizer over ``network``'s parameters.
        batch: Dict with tensors under ``'obs'``, ``'mask'``, ``'policy'``
            and ``'value'``.
        device: Device the batch is moved to.

    Returns:
        ``(loss, policy_loss, value_loss)`` as Python floats.
    """
    obs = batch['obs'].to(device)
    mask = batch['mask'].bool().to(device)
    target_policy = batch['policy'].to(device)

    logits, value = network(obs, mask)
    value = value.squeeze(-1)
    # Match the prediction dtype so integer or float64 targets cannot break
    # the MSE loss.
    target_value = batch['value'].to(device=device, dtype=value.dtype)

    # Policy loss (cross-entropy with MCTS policy)
    log_probs = F.log_softmax(logits, dim=-1)
    # Illegal actions have log-prob -inf and target 0, and 0 * -inf is NaN.
    # Zero those entries so they contribute nothing to the loss or gradient.
    log_probs = log_probs.masked_fill(~mask, 0.0)
    policy_loss = -(target_policy * log_probs).sum(dim=-1).mean()

    # Value loss
    value_loss = F.mse_loss(value, target_value)

    loss = policy_loss + value_loss

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(network.parameters(), 1.0)
    optimizer.step()

    return loss.item(), policy_loss.item(), value_loss.item()

In [None]:
# Training loop
# Per-iteration means of each loss term plus the buffer size at that point.
history = {'loss': [], 'policy_loss': [], 'value_loss': [], 'buffer_size': []}

for iteration in range(config.num_iterations):
    print(f"\n=== Iteration {iteration + 1}/{config.num_iterations} ===")

    # --- Self-play phase: generate fresh games with the current network ---
    network.eval()
    mcts = MCTS(network, mcts_config, device)

    seed_base = iteration * 1000
    for g in tqdm(range(config.games_per_iteration), desc="Self-play"):
        game = PyGame()
        game.reset(seed=seed_base + g)
        replay_buffer.add(self_play_game(game, mcts))

    print(f"Buffer size: {len(replay_buffer)}")

    # --- Training phase: skipped until one full batch is available ---
    if len(replay_buffer) < config.batch_size:
        print("Not enough data, skipping training")
        continue

    network.train()
    step_metrics = {'loss': [], 'policy_loss': [], 'value_loss': []}

    for _ in tqdm(range(config.train_steps_per_iteration), desc="Training"):
        batch = replay_buffer.sample(config.batch_size)
        step_result = train_step(network, optimizer, batch, device)
        # train_step returns (loss, policy_loss, value_loss), matching the
        # key order of step_metrics.
        for metric_name, metric_value in zip(step_metrics, step_result):
            step_metrics[metric_name].append(metric_value)

    for metric_name, metric_values in step_metrics.items():
        history[metric_name].append(np.mean(metric_values))
    history['buffer_size'].append(len(replay_buffer))

    mean_loss = history['loss'][-1]
    mean_policy_loss = history['policy_loss'][-1]
    mean_value_loss = history['value_loss'][-1]
    print(f"Loss: {mean_loss:.4f} "
          f"(policy: {mean_policy_loss:.4f}, "
          f"value: {mean_value_loss:.4f})")

## 5. Evaluate

In [None]:
from essence_wars.benchmark import EssenceWarsBenchmark


# Create agent
class AlphaZeroAgent:
    """Greedy policy-head agent: plays the legal action with the largest logit.

    No tree search is performed; only the network's policy head is used.
    """

    def __init__(self, network, device, name="AlphaZero"):
        self.network = network
        self.device = device
        self._name = name

    @property
    def name(self):
        """Display name of the agent."""
        return self._name

    def select_action(self, obs, mask):
        """Return the index of the highest-scoring legal action.

        Args:
            obs: Observation for a single state, as a NumPy array or any
                array-like (e.g. a plain list).
            mask: Legal-action mask for that state; non-zero / ``True``
                entries mark legal actions.

        Returns:
            The chosen action index as a Python ``int``.
        """
        self.network.eval()
        with torch.no_grad():
            # as_tensor accepts lists as well as ndarrays; from_numpy raises
            # TypeError for anything that is not an ndarray.
            obs_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device)
            mask_t = torch.as_tensor(mask, dtype=torch.bool).unsqueeze(0).to(self.device)
            logits, _ = self.network(obs_t, mask_t)
            return logits.argmax(dim=-1).item()

    def reset(self):
        """No per-game state to clear."""
        pass

# Wrap the trained network; this agent plays greedily from the policy head
# (select_action takes the argmax of the masked logits, no MCTS).
agent = AlphaZeroAgent(network, device)

# Benchmark
# NOTE(review): presumably plays games_per_opponent games against each named
# baseline — confirm against the EssenceWarsBenchmark documentation.
benchmark = EssenceWarsBenchmark(games_per_opponent=20, verbose=True)
results = benchmark.evaluate(agent, baselines=["random", "greedy"])
print("\n" + results.summary())

## 6. Save Checkpoint

In [None]:
# Bundle everything needed to resume training or rebuild the model.
checkpoint = {
    'network_state_dict': network.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': {
        # Read the architecture from the live network so the saved config
        # cannot drift from the weights if the constructor arguments change.
        'hidden_dim': network.input_proj.out_features,
        'num_blocks': len(network.blocks),
    },
    'history': history,
    'iteration': config.num_iterations,
}

torch.save(checkpoint, "alphazero_model.pt")
print("Saved checkpoint to alphazero_model.pt")

## Tips for Better Training

1. **More iterations**: Run for 100-1000 iterations for stronger play
2. **More simulations**: Use 100-400 MCTS simulations for better policies
3. **Parallel self-play**: Use multiple processes for data generation
4. **Temperature schedule**: Start high (1.0) and decay to 0.1 over training
5. **Evaluation checkpoints**: Save and evaluate every N iterations

For production training, see `python/scripts/train_alphazero.py`.