# Phase 1: Import Dependencies

In [None]:
import copy
import os
import shutil
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn


## Setup Device

In [None]:
# Choose the global compute device used by the rest of the notebook (DEVICE),
# and print hardware / version information.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
    # IPython shell escape (notebook only): show the visible GPUs.
    !nvidia-smi
    print()
else:
    DEVICE = torch.device('cpu')
# Note: the lines below are outside the if/else, so the CPU header and the
# /proc/cpuinfo listing are printed on every run, GPU or not.
print('CPU')
!cat /proc/cpuinfo | grep 'processor\|model\ name'
print()
!python3 --version
print('torch version   :', torch.__version__)
print('use device      :', DEVICE)


CPU
processor	: 0
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz
processor	: 1
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz

Python 3.10.12
torch version   : 2.2.1+cu121
use device      : cpu


## Environment

In [None]:
class TicTacToe:
    """Tic-tac-toe environment on a flat 9-cell board.

    Cell values are ``1.`` for O (the side that moves first after ``reset``),
    ``-1.`` for X and ``0.`` for an empty cell. ``current_player`` holds the
    value of the side to move; ``winner`` holds the value of the side that
    completed a line, or ``0.`` while nobody has.
    """
    BOARD_HEIGHT = 3
    BOARD_WIDTH = 3
    # Flat-board indices of every row, column and diagonal.
    LINES = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [0, 3, 6], [1, 4, 7], [2, 5, 8], [0, 4, 8], [2, 4, 6]]

    def __init__(self):
        # Placeholders only; reset() must be called before playing.
        self.board: np.ndarray = np.zeros(0)
        self.action_history = []
        self.current_player: float = 0.
        self.winner: float = 0.

    def reset(self) -> None:
        """Clear the board and history and give the first move to O (``1.``)."""
        num_cells = self.BOARD_HEIGHT * self.BOARD_WIDTH
        self.board = np.zeros(num_cells, dtype=np.single)
        self.action_history.clear()
        self.current_player = 1.
        self.winner = 0.

    def act(self, action: int) -> None:
        """Place current_player's mark on cell ``action`` and pass the turn.

        Raises:
            ValueError: if the cell is already occupied.
        """
        if self.board[action] != 0.:
            raise ValueError('invalid action id.')
        self.board[action] = self.current_player
        self.action_history.append(action)
        self._update_winner()
        self.current_player = -self.current_player

    def get_legal_actions(self) -> list:
        """Indices of the empty cells, in ascending order."""
        return list(np.flatnonzero(self.board == 0.))

    def is_terminal(self) -> bool:
        """True once someone has won or no empty cell is left."""
        return self.winner != 0. or bool(np.all(self.board != 0.))

    def get_eval_score(self) -> float:
        """Game result in absolute terms: the winner's value, 0. if none."""
        return self.winner

    def get_features(self) -> np.ndarray:
        """Encode the position as 4 float32 planes of shape (height, width).

        Planes, in order:
            0. cells owned by the side to move
            1. cells owned by the opponent
            2. all ones when O (Nought, ``1.``) is to move, else zeros
            3. all ones when X (Cross, ``-1.``) is to move, else zeros
        """
        to_move = self.current_player
        planes = np.stack([
            self.board == to_move,
            self.board == -to_move,
            np.full(self.board.shape, to_move == 1.),
            np.full(self.board.shape, to_move == -1.),
        ]).astype(np.single)
        return planes.reshape((-1, self.BOARD_HEIGHT, self.BOARD_WIDTH))

    @staticmethod
    def get_num_input_channels() -> int:
        """Number of feature planes produced by get_features()."""
        return 4

    def get_input_channel_height(self) -> int:
        """Height of each feature plane."""
        return self.BOARD_HEIGHT

    def get_input_channel_width(self) -> int:
        """Width of each feature plane."""
        return self.BOARD_WIDTH

    def get_policy_size(self) -> int:
        """Number of action ids (one per cell)."""
        return self.BOARD_HEIGHT * self.BOARD_WIDTH

    def to_string(self) -> str:
        """Render the board as a 2-D numpy string array ('O', 'X', ' ')."""
        cells = np.full(self.board.shape, ' ')
        cells[self.board == 1.] = 'O'
        cells[self.board == -1.] = 'X'
        return str(cells.reshape((self.BOARD_HEIGHT, self.BOARD_WIDTH)))

    def _update_winner(self) -> None:
        """Record current_player as winner if they now fill any complete line."""
        mover = self.current_player
        for line in self.LINES:
            cells = self.board[line]
            if np.all(cells == mover) and np.all(cells != 0.):
                self.winner = mover


In [None]:
class ConnectFour(TicTacToe):
    """Connect Four on a 6x7 board; an action is the column to drop a disc into.

    ``board_view`` is a (height, width) reshape of the flat ``board``. Row 0
    is the bottom row: a disc lands in the lowest empty row of its column, and
    ``to_string`` flips the rows so the printout reads top-down.
    """
    BOARD_HEIGHT = 6
    BOARD_WIDTH = 7

    def __init__(self):
        super().__init__()
        # Placeholder; reset() sets it to a 2-D view of self.board.
        self.board_view: np.ndarray = np.zeros(0)

    def reset(self) -> None:
        """Clear the board and rebuild the 2-D view of it."""
        super().reset()
        # Reshaping the freshly allocated contiguous board returns a view of it.
        self.board_view = self.board.reshape(self.BOARD_HEIGHT, self.BOARD_WIDTH)

    def act(self, action: int) -> None:
        """Drop current_player's disc into column ``action`` and pass the turn.

        Raises:
            ValueError: if ``action`` is not a column index or the column is full.
        """
        # Reject out-of-range columns explicitly: a negative index would wrap
        # in board_view but address a different cell in the flat board below.
        if not 0 <= action < self.BOARD_WIDTH:
            raise ValueError('invalid action id.')
        # Lowest empty row of the column (argmax yields 0 for a full column,
        # whose bottom cell is occupied, so the check below rejects it).
        action_row = np.argmax(self.board_view == 0., axis=0)[action]
        if self.board_view[action_row, action] != 0.:
            raise ValueError('invalid action id.')
        # Write both arrays: copy.deepcopy copies numpy views as independent
        # arrays, so in a deep-copied env board_view no longer aliases board.
        self.board[action_row * self.BOARD_WIDTH + action] = self.current_player
        self.board_view[action_row, action] = self.current_player
        self.action_history.append(action)
        if self._check_connected(action_row, action) >= 4:
            self.winner = self.current_player
        self.current_player *= -1.

    def get_legal_actions(self) -> list:
        """Columns that still contain at least one empty cell."""
        has_empty_cell = np.any(self.board_view == 0., axis=0)
        return list(np.where(has_empty_cell)[0])

    def get_policy_size(self) -> int:
        """Number of action ids (one per column)."""
        return self.BOARD_WIDTH

    def to_string(self) -> str:
        """Render the board top row first as a 2-D numpy string array."""
        result = np.empty(self.board_view.shape, dtype=str)
        result[np.where(self.board_view == 1.)] = 'O'
        result[np.where(self.board_view == -1.)] = 'X'
        result[np.where(self.board_view == 0.)] = ' '
        return str(np.flip(result, 0))

    def _check_connected(self, action_row: int, action_column: int) -> int:
        """Longest run of the mover's discs on the lines through the given cell.

        Checks the row, the column and both diagonals passing through
        (action_row, action_column). Each line is scanned in full, not only the
        segment touching the new disc; any earlier run of 4 by the same player
        would already have ended the game.
        """
        row_array = self.board_view[action_row, :]
        column_array = self.board_view[:, action_column]
        # Diagonal with offset k holds cells (i, i + k); k = col - row hits our cell.
        diagonal_array = self.board_view.diagonal(offset=action_column - action_row)
        # After fliplr, column c becomes W - 1 - c, giving the anti-diagonal.
        flipped = np.fliplr(self.board_view)
        flipped_diagonal_array = flipped.diagonal(offset=self.BOARD_WIDTH - 1 - action_column - action_row)
        max_connections = []
        action_player = self.board_view[action_row, action_column]
        for arr in [row_array, column_array, diagonal_array, flipped_diagonal_array]:
            # Positions that break a run, padded with sentinels at both ends;
            # consecutive gaps minus one are run lengths.
            split_index = np.where(arr != action_player)[0]
            split_index = np.insert(split_index, 0, -1)
            split_index = np.insert(split_index, len(split_index), len(arr))
            max_connections.append(np.max(np.diff(split_index)))
        return max(max_connections) - 1


### Test Environment

In [None]:
# Smoke-test the environment: play one random game per seed, printing the
# mover, board and feature planes before each move, then the final result.
seeds = [100, 200, 300]
for seed in seeds:
    np.random.seed(seed)
    env = TicTacToe()
    env.reset()
    while not env.is_terminal():
        legal_actions = env.get_legal_actions()
        action_id = np.random.choice(legal_actions)
        print('player:', env.current_player, '(O)' if env.current_player == 1.0 else '(X)')
        print('board:', env.to_string(), sep='\n')
        print('features:', env.get_features(), sep='\n')
        env.act(action_id)
        print('-----------------------')
        print('act action id:', action_id)

    print('score:', env.get_eval_score())
    print(env.to_string(), '========================', sep='\n')


player: 1.0 (O)
board:
[[' ' ' ' ' ']
 [' ' ' ' ' ']
 [' ' ' ' ' ']]
features:
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
-----------------------
act action id: 8
player: -1.0 (X)
board:
[[' ' ' ' ' ']
 [' ' ' ' ' ']
 [' ' ' ' 'O']]
features:
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 1.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]]
-----------------------
act action id: 0
player: 1.0 (O)
board:
[['X' ' ' ' ']
 [' ' ' ' ' ']
 [' ' ' ' 'O']]
features:
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 1.]]

 [[1. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
-----------------------
act action id: 4
player: -1.0 (X)
board:
[['X' ' ' ' ']
 [' ' 'O' ' ']
 [' ' ' ' 'O']]
features:
[[[1. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 1. 0.]


## MCTS

In [None]:
class Node:
    """A node of the MCTS search tree.

    Stores the side value ``player`` supplied by the creator, the ``action``
    that leads to this node, its prior ``policy`` and running visit
    statistics (``visit_count`` and the mean of all values passed to
    ``update``).
    """

    def __init__(self, player: float, action: int, policy: float = 0):
        self.player: float = player
        self.action: int = action
        self.policy: float = policy
        # Running statistics maintained by update().
        self.visit_count: int = 0
        self.mean_value: float = 0.
        self.children: list[Node] = []

    def is_leaf(self) -> bool:
        """True while no child has been added."""
        return not self.children

    def add_child(self, player: float, action: int, policy: float) -> None:
        """Append a new, unvisited child node."""
        self.children.append(Node(player, action, policy))

    def update(self, value: float) -> None:
        """Count one more visit and fold ``value`` into the incremental mean."""
        self.visit_count = self.visit_count + 1
        self.mean_value = self.mean_value + (value - self.mean_value) / self.visit_count


class MCTS:
    """AlphaZero-style Monte Carlo tree search guided by a policy/value network.

    Leaf values are handled in the absolute perspective of
    ``env.get_eval_score()`` (positive when the side whose value is ``1.``
    wins); the network's value output is assumed to use the same perspective
    (TODO confirm against the training targets). ``backup`` multiplies by
    ``node.player`` so that a child's ``mean_value`` is expressed from the
    viewpoint of the player who chose that child's action, i.e. the player
    maximising over it in ``select_child``.
    """
    DIRICHLET_NOISE_ALPHA = 0.3  # Usually (1 / sqrt(number of actions))
    DIRICHLET_NOISE_EPSILON = 0.25
    PUCT_C1 = 1.25   # pb_c_init in the MuZero pseudocode
    PUCT_C2 = 19652  # pb_c_base in the MuZero pseudocode
    TEMPERATURE = 1.0

    def __init__(self, env: TicTacToe, network: nn.Module):
        self.env: TicTacToe = env
        self.network: nn.Module = network
        self.root: Node | None = None

    def simulate(self, num_simulations: int, use_dirichlet_noise: bool = True) -> None:
        """Build a fresh tree at the env's current state and run the simulations.

        The root is expanded (not backed up) before the loop, so afterwards
        ``root.visit_count == num_simulations``.
        """
        self.root = Node(self.env.current_player, 0)
        node_path = [self.root]
        self.expand_and_evaluate(node_path)
        if use_dirichlet_noise:
            self.add_exploration_noise(self.root)

        for _ in range(num_simulations):
            # Selection
            node_path = self.selection()
            # Expansion and evaluation
            value = self.expand_and_evaluate(node_path)
            # Backup
            self.backup(node_path, value)

    def selection(self) -> list[Node]:
        """Descend from the root via select_child until reaching a leaf; return the path."""
        node_path = [self.root]
        node = self.root
        while not node.is_leaf():
            node = self.select_child(node)
            node_path.append(node)
        return node_path

    def expand_and_evaluate(self, node_path) -> float:
        """Replay ``node_path`` on a copy of the env, expand its leaf, return a value.

        Terminal leaves are not expanded and their exact score is returned.
        Otherwise the leaf gets one child per legal action, with the network's
        prior renormalised over the legal actions, and the network value is
        returned.
        """
        # go to state
        env_transition = copy.deepcopy(self.env)
        for child in node_path[1:]:
            env_transition.act(child.action)

        if env_transition.is_terminal():
            return env_transition.get_eval_score()

        features = env_transition.get_features()
        features_tensor = torch.from_numpy(features).unsqueeze(0).to(DEVICE)
        _, policy, value = self.network(features_tensor)
        policy = policy.squeeze(0).detach().cpu().numpy()
        legal_actions = env_transition.get_legal_actions()
        priors = policy[legal_actions]
        prior_sum = priors.sum()
        if prior_sum > 0:
            # Redistribute the probability mass the network put on illegal moves.
            priors = priors / prior_sum
        player = env_transition.current_player
        leaf_node = node_path[-1]
        for action, prior in zip(legal_actions, priors):
            leaf_node.add_child(player, action, float(prior))
        return value.squeeze(0).item()

    @staticmethod
    def backup(node_path: list[Node], value: float) -> None:
        """Add ``value`` (converted to each node's player perspective) to every node on the path."""
        for node in node_path:
            node.update(value * node.player)

    def decide_action(self, use_softmax: bool = True) -> int:
        """Pick the move to play from the root's child visit counts.

        Args:
            use_softmax: when True, sample an action with probability
                proportional to ``visit_count ** (1 / TEMPERATURE)`` (despite
                the name, no softmax is applied); when False, return the most
                visited action (the first one on ties).

        Raises:
            ValueError: if the root has no children (e.g. simulate() was run
                on a terminal state).
        """
        children = self.root.children
        if not children:
            raise ValueError('root has no children to choose from.')
        candidate_actions = [child.action for child in children]
        action_weights = np.array([child.visit_count ** (1 / self.TEMPERATURE) for child in children],
                                  dtype=np.float64)
        weight_sum = action_weights.sum()
        if weight_sum == 0:
            # No visits recorded (e.g. num_simulations == 0): fall back to uniform.
            action_weights = np.ones_like(action_weights)
            weight_sum = action_weights.sum()
        action_weights /= weight_sum
        if use_softmax:
            return np.random.choice(candidate_actions, p=action_weights)
        return candidate_actions[int(np.argmax(action_weights))]

    def add_exploration_noise(self, parent: Node) -> None:
        """Mix Dirichlet(alpha) noise into the priors of ``parent``'s children."""
        # Normalised independent Gamma(alpha, 1) draws form a Dirichlet sample.
        dirichlet_noise = np.random.gamma(self.DIRICHLET_NOISE_ALPHA, 1, len(parent.children))
        dirichlet_noise /= np.sum(dirichlet_noise)
        for child, noise in zip(parent.children, dirichlet_noise):
            child.policy = child.policy * (1 - self.DIRICHLET_NOISE_EPSILON) + noise * self.DIRICHLET_NOISE_EPSILON

    def select_child(self, parent: Node) -> Node:
        """Return the child with the highest PUCT score (ties broken at random).

        score = Q + U, where Q is ``child.mean_value`` (already from the
        choosing player's viewpoint) and, following the MuZero pseudocode,
            U = (log((N + C2 + 1) / C2) + C1) * P * sqrt(N) / (1 + n)
        with N the parent's visit count, n the child's and P its prior.
        """
        parent_visits = parent.visit_count
        exploration = ((np.log((parent_visits + self.PUCT_C2 + 1) / self.PUCT_C2) + self.PUCT_C1)
                       * np.sqrt(parent_visits))
        scores = np.array([child.mean_value + exploration * child.policy / (child.visit_count + 1)
                           for child in parent.children])
        # Random tie-breaking matters on the first simulation: the root has
        # not been backed up yet (N == 0), so every child scores exactly 0.
        best = np.flatnonzero(scores == scores.max())
        return parent.children[np.random.choice(best)]

    def get_normalize_child_visits(self) -> np.ndarray:
        """Root visit distribution over the full action space (0 for unexpanded actions)."""
        child_visits = np.zeros(self.env.get_policy_size(), dtype=np.single)
        total_visits = self.root.visit_count
        if total_visits == 0:
            # No simulation has been run; avoid dividing by zero.
            return child_visits
        for child in self.root.children:
            child_visits[child.action] = child.visit_count / total_visits
        return child_visits


## Replay Buffer

In [None]:
@dataclass
class Game:
    """Record of one finished game, as stored in the replay buffer."""
    # Action ids in the order they were passed to env.act().
    action_history: list[int]
    # Final score of the game (self_play fills it from env.get_eval_score()).
    terminal_value: float
    # One normalised root visit distribution per move, aligned with action_history.
    child_visits: list[np.ndarray]


class ReplayBuffer:
    """Fixed-capacity FIFO store of finished games used as training data."""

    def __init__(self, buffer_num_games: int):
        """
        Args:
            buffer_num_games: maximum number of games kept; once full, the
                oldest game is evicted when a new one is saved.

        Raises:
            ValueError: if buffer_num_games is smaller than 1.
        """
        if buffer_num_games < 1:
            raise ValueError('buffer_num_games must be a positive integer.')
        # deque(maxlen=...) drops the oldest game in O(1) when full.
        self.buffer: deque[Game] = deque(maxlen=buffer_num_games)
        self.buffer_num_games: int = buffer_num_games

    def save_game(self, game: Game) -> None:
        """Append ``game``, evicting the oldest stored game if the buffer is full."""
        self.buffer.append(game)

    def sample_batch(self, env: TicTacToe, batch_size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Sample ``batch_size`` positions uniformly over all stored positions.

        Games are drawn with probability proportional to their length, then a
        move index is drawn uniformly inside each game. Positions are rebuilt
        by replaying the game on ``env``, which is reset (its state is
        overwritten).

        Returns:
            (features, target_policy, target_value): stacked env.get_features()
            arrays, the stored child-visit distributions for those positions,
            and the game's terminal value for each sample, all float32.

        Raises:
            ValueError: if the buffer holds no positions to sample.
        """
        if not self.buffer:
            raise ValueError('cannot sample from an empty replay buffer.')
        features = []
        target_value = []
        target_policy = []
        game_lengths = np.array([len(game.action_history) for game in self.buffer])
        total_positions = np.sum(game_lengths)
        if total_positions == 0:
            raise ValueError('replay buffer holds no positions to sample.')
        select_weights = game_lengths / total_positions
        sample_games_index = np.random.choice(len(self.buffer), size=batch_size, p=select_weights, replace=True)
        unique, counts = np.unique(sample_games_index, return_counts=True)
        for game_index, count in zip(unique, counts):
            env.reset()
            game = self.buffer[game_index]
            # Sorted steps let one forward replay of the game serve every sample.
            env_steps = np.random.randint(len(game.action_history), size=count)
            env_steps.sort()
            for step in env_steps:
                # Advance from the current replay position to just before move `step`.
                for action in game.action_history[len(env.action_history): step]:
                    env.act(action)
                features.append(env.get_features())
                target_value.append(game.terminal_value)
                target_policy.append(game.child_visits[step])

        features = np.stack(features, dtype=np.single)
        target_policy = np.stack(target_policy, dtype=np.single)
        target_value = np.stack(target_value, dtype=np.single)
        return features, target_policy, target_value


## AlphaZero Network

In [None]:
class AlphaZeroNetwork(nn.Module):
    """Small convolutional policy/value network.

    Input is a (batch, input_channels, height, width) tensor of feature
    planes. ``forward`` returns the policy as log-probabilities and as
    probabilities over ``action_size`` actions, plus a tanh-squashed value
    of shape (batch, 1).
    """

    def __init__(self,
                 input_channels: int,
                 input_channel_height: int,
                 input_channel_width: int,
                 action_size: int,
                 hidden_channels: int = 16,
                 ):
        super().__init__()
        num_cells = input_channel_height * input_channel_width
        # Trunk: 3x3 conv (size-preserving padding) then a 1x1 conv to one plane.
        # Modules are created in this exact order so seeded initialisation is stable.
        self.conv1 = nn.Conv2d(input_channels, hidden_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.conv2 = nn.Conv2d(hidden_channels, 1, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(1)
        self.flat = nn.Flatten()
        # Both heads read the flattened single plane: one feature per board cell.
        self.policy_head = nn.Linear(num_cells, action_size)
        self.value_head = nn.Linear(num_cells, 1)

    def forward(self, x):
        hidden = torch.relu(self.bn1(self.conv1(x)))
        plane = torch.relu(self.bn2(self.conv2(hidden)))
        cells = self.flat(plane)
        logits = self.policy_head(cells)
        value = torch.tanh(self.value_head(cells))
        return torch.log_softmax(logits, dim=1), torch.softmax(logits, dim=1), value


## Logger

In [None]:
def now_time() -> str:
    """Return the current local time as a ``'[YYYY-MM-DD HH:MM:SS] '`` log prefix."""
    return f'[{datetime.now():%Y-%m-%d %H:%M:%S}] '


class Logger:
    """Writes the training log and saves/loads TorchScript checkpoints.

    Files live under a folder named after the environment class:
    ``<EnvName>/training_log.txt`` and ``<EnvName>/model/weight_iter_<i>.pt``.
    Constructing a Logger deletes any existing ``<EnvName>/model`` folder;
    the training log is opened in append mode and is not cleared.
    """

    def __init__(self, env: TicTacToe, device: torch.device | str = 'cpu'):
        env_name = type(env).__name__
        self.folder_name: str = env_name
        self.model_folder_name: str = f'{env_name}/model'
        self.device = device
        # Start each run from an empty checkpoint folder.
        if os.path.exists(self.model_folder_name):
            shutil.rmtree(self.model_folder_name)
        os.makedirs(self.model_folder_name)

    def write_log(self, write_message: str, timestamp: bool = True) -> None:
        """Append one line to training_log.txt and echo it to stdout."""
        with open(f'{self.folder_name}/training_log.txt', 'a') as log_file:
            line = now_time() + write_message if timestamp else write_message
            log_file.write(line + '\n')
            print(line)

    def _checkpoint_path(self, iteration: int) -> str:
        """Path of the checkpoint file for training iteration ``iteration``."""
        return f'{self.model_folder_name}/weight_iter_{iteration}.pt'

    def save_network(self, network: nn.Module, iteration: int) -> None:
        """Compile ``network`` with TorchScript and save it for ``iteration``."""
        torch.jit.script(network).save(self._checkpoint_path(iteration))

    def load_network(self, iteration: int) -> nn.Module:
        """Load the TorchScript checkpoint for ``iteration`` onto ``self.device``."""
        return torch.jit.load(self._checkpoint_path(iteration),
                              map_location=torch.device(self.device))


# Phase 2: AlphaZero Algorithm

## Self-Play

In [None]:
def self_play(env: TicTacToe,
              network: nn.Module,
              logger: Logger,
              sp_num_games_per_iteration: int,
              sp_mcts_simulation: int,
              display_step: int = 10) -> list[Game]:
    """Generate training games by letting ``network`` play against itself.

    Every move is chosen by a fresh MCTS (with root Dirichlet noise and
    visit-count sampling) over a deep copy of ``env``. Each returned Game
    holds the move list, the final score and the root visit distribution
    recorded for every move. ``env`` is reset and played in place.
    """
    network.eval()
    finished_games = []
    with torch.no_grad():
        for game_no in range(1, sp_num_games_per_iteration + 1):
            env.reset()
            visit_distributions = []
            while not env.is_terminal():
                search = MCTS(copy.deepcopy(env), network)
                search.simulate(sp_mcts_simulation)
                chosen_action = search.decide_action()
                env.act(chosen_action)
                visit_distributions.append(search.get_normalize_child_visits())

            assert len(env.action_history) == len(visit_distributions)
            finished_games.append(Game(list(env.action_history), env.get_eval_score(), visit_distributions))

            if game_no % display_step == 0:
                logger.write_log(f'sp games: {game_no} / {sp_num_games_per_iteration}')
    return finished_games


## Optimization

In [None]:
def optimization(env: TicTacToe,
                 network: nn.Module,
                 logger: Logger,
                 replay_buffer: ReplayBuffer,
                 op_batch_size: int,
                 op_training_steps: int,
                 display_step: int = 10,
                 learning_rate: float = 0.02,
                 momentum: float = 0.9,
                 weight_decay: float = 1e-4) -> None:
    """Train ``network`` for ``op_training_steps`` SGD steps on replay samples.

    The loss is the cross-entropy between the network's policy output and the
    stored visit distribution plus the MSE between the value output and the
    game's terminal value. A new SGD optimizer is built on every call, so
    momentum state does not carry over between calls.
    """
    network.train()
    sgd = torch.optim.SGD(network.parameters(),
                          lr=learning_rate,
                          momentum=momentum,
                          weight_decay=weight_decay)
    for step in range(1, op_training_steps + 1):
        sgd.zero_grad()
        batch_features, batch_policy, batch_value = replay_buffer.sample_batch(env, op_batch_size)
        inputs = torch.from_numpy(batch_features).to(DEVICE)
        policy_target = torch.from_numpy(batch_policy).to(DEVICE)
        value_target = torch.from_numpy(batch_value).unsqueeze(-1).to(DEVICE)

        # The first network output is already log-softmaxed; cross_entropy's
        # own log_softmax leaves log-probabilities unchanged (mathematically).
        log_policy, _, predicted_value = network(inputs)
        loss_policy = nn.functional.cross_entropy(log_policy, policy_target)
        loss_value = nn.functional.mse_loss(predicted_value, value_target)
        (loss_policy + loss_value).backward()
        sgd.step()

        if step % display_step != 0:
            continue
        logger.write_log(f'op training steps: {step} / {op_training_steps}')
        logger.write_log(f'\tloss_policy: {loss_policy.item():.4f}', timestamp=False)
        logger.write_log(f'\tloss_value : {loss_value.item():.4f}', timestamp=False)


## Evaluation

In [None]:
def evaluation(env: TicTacToe,
               logger: Logger,
               iteration_a: int,
               iteration_b: int,
               num_games: int,
               mcts_simulation: int):
    """Play two saved checkpoints against each other and log the result.

    Game 1 starts with iteration_a moving first (``current_player == 1.``);
    the seats swap after every game. MCTS runs without Dirichlet noise, but
    moves are still sampled from visit counts. Statistics and the win rate
    are logged from the viewpoint of the higher iteration number.

    Raises:
        ValueError: if num_games is smaller than 1.
    """
    if num_games < 1:
        raise ValueError('num_games must be a positive integer.')
    network_a = logger.load_network(iteration_a)
    network_b = logger.load_network(iteration_b)
    network_a.eval()
    network_b.eval()
    # winner_count[0] always describes network_a (the side moving first in the current game).
    winner_count = [{'iter': iteration_a, 'winP1': 0, 'winP2': 0, 'draw': 0, 'lossP1': 0, 'lossP2': 0, 'scores': []},
                    {'iter': iteration_b, 'winP1': 0, 'winP2': 0, 'draw': 0, 'lossP1': 0, 'lossP2': 0, 'scores': []}]
    with torch.no_grad():
        for i in range(num_games):
            env.reset()
            while not env.is_terminal():
                network = network_a if env.current_player == 1. else network_b
                mcts = MCTS(copy.deepcopy(env), network)
                mcts.simulate(mcts_simulation, use_dirichlet_noise=False)
                action = mcts.decide_action(use_softmax=True)
                env.act(action)
            eval_score = env.get_eval_score()
            winner_count[0]['scores'].append(eval_score)
            winner_count[1]['scores'].append(-eval_score)
            if eval_score == 1.:
                winner_count[0]['winP1'] += 1
                winner_count[1]['lossP2'] += 1
            elif eval_score == -1.:
                winner_count[0]['lossP1'] += 1
                winner_count[1]['winP2'] += 1
            else:
                winner_count[0]['draw'] += 1
                winner_count[1]['draw'] += 1
            # Swap player
            network_a, network_b = network_b, network_a
            winner_count[0], winner_count[1] = winner_count[1], winner_count[0]

    # Sort in place (sorted() would discard the result) so pop() returns the
    # higher iteration regardless of how many seat swaps happened.
    winner_count.sort(key=lambda item: item['iter'])
    own_info = winner_count.pop()
    opponent_info = winner_count.pop()
    title_name = f'iteration {own_info.pop("iter")} vs. {opponent_info.pop("iter")}'
    mean_score = sum(own_info.pop('scores')) / num_games
    logger.write_log(f'\t{title_name}: {own_info}', timestamp=False)
    # Map a mean score in [-1, 1] to a rate in [0, 1] (a draw counts as half a win).
    logger.write_log(f'\t{title_name} win rate: {(mean_score + 1) / 2:.2%}', timestamp=False)


## Train Model (Main Function)

In [None]:
# hyperparameters
# Number of self-play -> training iterations.
max_iteration = 50
# Self-play games generated per iteration.
sp_num_games_per_iteration = 10
# MCTS simulations per move (used for self-play and for evaluation).
sp_mcts_simulation = 8
# The replay buffer keeps games from this many recent iterations.
buffer_num_iteration = 5
op_batch_size = 32
op_training_steps = 50
# Evaluate every eval_interval_iteration iterations, playing eval_num_games games each time.
eval_interval_iteration = 5
eval_num_games = 100
network_num_hidden_channels = 8

# Seed numpy (MCTS / sampling) and torch (network init) for reproducibility.
random_seed = 600
np.random.seed(random_seed)
torch.manual_seed(random_seed)

env = TicTacToe()
# Swap the line above for this one to train on Connect Four instead.
# env = ConnectFour()
network = AlphaZeroNetwork(input_channels=env.get_num_input_channels(),
                           input_channel_height=env.get_input_channel_height(),
                           input_channel_width=env.get_input_channel_width(),
                           action_size=env.get_policy_size(),
                           hidden_channels=network_num_hidden_channels
                           ).to(DEVICE)
replay_buffer = ReplayBuffer(buffer_num_iteration * sp_num_games_per_iteration)
# Creating the Logger wipes the previous run's model folder.
logger = Logger(env, device=DEVICE)
# Checkpoint 0 = untrained network; used by the first self-play and as an evaluation baseline.
logger.save_network(network, 0)

for i in range(1, max_iteration + 1):
    logger.write_log(f'iteration {i}:')

    # Part 1: Self-Play
    # Self-play uses the checkpoint saved at the end of the previous iteration.
    latest_network = logger.load_network(i - 1)
    games = self_play(env, latest_network, logger,
                      sp_num_games_per_iteration=sp_num_games_per_iteration,
                      sp_mcts_simulation=sp_mcts_simulation,
                      display_step=10)
    for game in games:
        replay_buffer.save_game(game)

    # Part 2: Training
    # The live `network` keeps training across iterations; its weights are checkpointed as iteration i.
    optimization(env, network, logger, replay_buffer,
                 op_batch_size=op_batch_size,
                 op_training_steps=op_training_steps,
                 display_step=10)
    logger.save_network(network, i)

    # Evaluation
    # Compare against the untrained baseline and, after the first interval, against the previous evaluated checkpoint.
    if i % eval_interval_iteration == 0:
        logger.write_log('evaluation:')
        evaluation(env, logger,
                   iteration_a=0,
                   iteration_b=i,
                   num_games=eval_num_games,
                   mcts_simulation=sp_mcts_simulation)
        if i - eval_interval_iteration > 0:
            evaluation(env, logger,
                    iteration_a=i - eval_interval_iteration,
                    iteration_b=i,
                    num_games=eval_num_games,
                    mcts_simulation=sp_mcts_simulation)


[2024-04-18 10:29:08] iteration 1:
[2024-04-18 10:29:09] sp games: 10 / 10
[2024-04-18 10:29:09] op training steps: 10 / 50
	loss_policy: 2.1167
	loss_value : 0.5453
[2024-04-18 10:29:09] op training steps: 20 / 50
	loss_policy: 2.1662
	loss_value : 0.5785
[2024-04-18 10:29:09] op training steps: 30 / 50
	loss_policy: 2.0788
	loss_value : 0.3294
[2024-04-18 10:29:09] op training steps: 40 / 50
	loss_policy: 2.1420
	loss_value : 0.4288
[2024-04-18 10:29:09] op training steps: 50 / 50
	loss_policy: 2.0470
	loss_value : 0.2275
[2024-04-18 10:29:09] iteration 2:
[2024-04-18 10:29:10] sp games: 10 / 10
[2024-04-18 10:29:10] op training steps: 10 / 50
	loss_policy: 2.0597
	loss_value : 0.5637
[2024-04-18 10:29:10] op training steps: 20 / 50
	loss_policy: 2.0099
	loss_value : 0.4261
[2024-04-18 10:29:10] op training steps: 30 / 50
	loss_policy: 1.9436
	loss_value : 0.5477
[2024-04-18 10:29:10] op training steps: 40 / 50
	loss_policy: 1.9575
	loss_value : 0.4956
[2024-04-18 10:29:10] op traini