In [17]:
# rgi/players/alphazero/mcts.py

import math
import numpy as np
import random
from typing import Dict, Generic, TypeVar, Tuple
from rgi.core.base import Game, TGameState, TAction

TGame = TypeVar("TGame", bound=Game)
T = TypeVar("T")

class MCTSNode(Generic[TGameState, TAction]):
    def __init__(self, game_state: TGameState, parent: "MCTSNode" = None, prior: float = 0):
        self.game_state = game_state
        self.parent = parent
        self.prior = prior
        self.children: Dict[TAction, MCTSNode] = {}
        self.visit_count = 0
        self.value_sum = 0.0

    def ucb_score(self, exploration_weight: float) -> float:
        if self.visit_count == 0:
            return float("inf")
        return (self.value_sum / self.visit_count) + exploration_weight * self.prior * math.sqrt(self.parent.visit_count) / (1 + self.visit_count)

class MCTS(Generic[TGameState, TAction]):
    def __init__(self, game: Game[TGameState, TAction], model: "AlphaZeroModel", exploration_weight: float = 1.0, dirichlet_alpha: float = 0.3):
        self.game = game
        self.model = model
        self.exploration_weight = exploration_weight
        self.dirichlet_alpha = dirichlet_alpha

    def search(
        self,
        initial_state: TGameState,
        num_simulations: int,
        temperature: float = 1.0
    ) -> Tuple[Dict[TAction, float], float]:
        """Returns action probabilities and predicted value."""
        root = MCTSNode(initial_state)

        for _ in range(num_simulations):
            node = root
            search_path = [node]

            # Selection
            while node.children:
                node = max(node.children.values(), key=lambda n: n.ucb_score(self.exploration_weight))
                search_path.append(node)

            # Expansion & Evaluation
            if not self.game.is_terminal(node.game_state):
                policy, value = self.model.predict(self.game.to_nn_input(node.game_state))
                valid_actions = self.game.legal_actions(node.game_state)
                
                if node.parent is None:  # Root node
                    noise = np.random.dirichlet([self.dirichlet_alpha] * len(valid_actions))
                    policy = 0.75 * policy + 0.25 * noise

                for action in valid_actions:
                    child_state = self.game.next_state(node.game_state, action)
                    node.children[action] = MCTSNode(
                        child_state,
                        parent=node,
                        prior=policy[action]
                    )
            else:
                value = self.game.reward(node.game_state, self.game.current_player_id(node.game_state))

            # Backpropagation
            self._backpropagate(search_path, value)

        # Return action probabilities
        total_visits = sum(child.visit_count for child in root.children.values())
        action_probs = {
            action: (child.visit_count / total_visits) ** (1 / temperature)
            for action, child in root.children.items()
        }
        return action_probs, root.value_sum / root.visit_count if root.visit_count else 0

    def _backpropagate(self, search_path: list[MCTSNode], value: float):
        for node in reversed(search_path):
            node.visit_count += 1
            node.value_sum += value
            value = -value  # Alternate player perspective

In [18]:
# rgi/players/alphazero/model.py
from abc import ABC, abstractmethod
import numpy as np

class AlphaZeroModel(ABC):
    @abstractmethod
    def predict(self, game_state: np.ndarray) -> tuple[np.ndarray, float]:
        """Return (action probabilities, value) for the given state."""
        pass

    @abstractmethod
    def train(self, states: list[np.ndarray], action_probs: list[np.ndarray], values: list[float]):
        """Update model parameters using training data."""
        pass

In [22]:
# rgi/players/alphazero/pytorch/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class PyTorchAlphaZeroNet(nn.Module):
    def __init__(self, input_shape: tuple, num_actions: int):
        super().__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.fc = nn.Linear(64 * input_shape[1] * input_shape[2], 256)
        self.policy_head = nn.Linear(256, num_actions)
        self.value_head = nn.Linear(256, 1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc(x))
        policy = F.softmax(self.policy_head(x), dim=1)
        value = torch.tanh(self.value_head(x))
        return policy, value

class PyTorchAlphaZeroModel(AlphaZeroModel):
    def __init__(self, input_shape: tuple, num_actions: int, device: str = "cuda"):
        self.net = PyTorchAlphaZeroNet(input_shape, num_actions).to(device)
        self.optimizer = torch.optim.Adam(self.net.parameters())
        self.device = device

    def predict(self, game_state: np.ndarray) -> tuple[np.ndarray, float]:
        with torch.no_grad():
            state_tensor = torch.tensor(game_state, dtype=torch.float32, device=self.device).unsqueeze(0)
            policy, value = self.net(state_tensor)
            return policy.squeeze().cpu().numpy(), value.item()

    def train(self, states: list[np.ndarray], action_probs: list[np.ndarray], values: list[float]):
        states_tensor = torch.tensor(np.array(states), dtype=torch.float32, device=self.device)
        policies_tensor = torch.tensor(np.array(action_probs), dtype=torch.float32, device=self.device)
        values_tensor = torch.tensor(np.array(values), dtype=torch.float32, device=self.device)

        self.optimizer.zero_grad()
        pred_policies, pred_values = self.net(states_tensor)
        loss = self._compute_loss(pred_policies, pred_values, policies_tensor, values_tensor)
        loss.backward()
        self.optimizer.step()

    def _compute_loss(
        self,
        pred_policies: torch.Tensor,
        pred_values: torch.Tensor,
        target_policies: torch.Tensor,
        target_values: torch.Tensor
    ) -> torch.Tensor:
        policy_loss = F.cross_entropy(pred_policies, target_policies)
        value_loss = F.mse_loss(pred_values.squeeze(), target_values)
        return policy_loss + value_loss

In [20]:
# rgi/players/alphazero/pytorch/trainer.py

from collections import deque
import numpy as np
# from rgi.players.alphazero.mcts import MCTS
import random

class AlphaZeroTrainer:
    def __init__(
        self,
        game: Game,
        model: AlphaZeroModel,
        num_simulations: int = 800,
        buffer_size: int = 100000,
        batch_size: int = 32
    ):
        self.game = game
        self.model = model
        self.mcts = MCTS(game)
        self.buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size

    def self_play(self, num_games: int):
        for _ in range(num_games):
            game_state = self.game.initial_state()
            game_history = []

            while not self.game.is_terminal(game_state):
                action_probs, _ = self.mcts.search(game_state, num_simulations=self.num_simulations)
                action = np.random.choice(list(action_probs.keys()), p=list(action_probs.values()))
                game_history.append((game_state, action_probs))
                game_state = self.game.next_state(game_state, action)

            # Assign final reward to all states in the game
            reward = self.game.reward(game_state, self.game.current_player_id(game_state))
            for idx, (state, action_probs) in enumerate(game_history):
                self.buffer.append((state, action_probs, reward * ((-1) ** idx)))  # Alternate perspective

    def train(self, epochs: int):
        for _ in range(epochs):
            batch = random.sample(self.buffer, min(len(self.buffer), self.batch_size))
            states, action_probs, values = zip(*batch)
            self.model.train(states, action_probs, values)

# Run Count21 game


In [26]:
# Run Count21 game

from rgi.games.count21 import count21

game = count21.Count21Game()
model = PyTorchAlphaZeroModel(input_shape=(21,), num_actions=3)
trainer = AlphaZeroTrainer(game, model)

# Initial evaluation vs random
evaluate_vs_random(model, game, num_games=100)

# Training loop
trainer.self_play(num_games=100)
trainer.train(epochs=10)

# Post-training evaluation
evaluate_vs_random(model, game, num_games=100)

IndexError: tuple index out of range