## Alpha Zero详解

- http://joshvarty.github.io/AlphaZero/

考虑一个简单的游戏环境，`Connect2`。游戏规则是在一个$1 \times 4$的棋盘上，目标是在两个相邻位置方式相同的元素。整个博弈树展开为如下形式:

<img src="../images/11-connect2.png" width="100%">

很明显，这时一个先手必赢的游戏。

## Connect 2游戏环境代码实现

用1表示自己这方落子的表示，-1表示对方落子的表示，空的棋盘表示为0，因此初始状态是[0, 0, 0, 0]。赢得游戏的一方获得的奖励是+1，输掉游戏的一方获得奖励是-1。平局获得的奖励是0

Note: There is a very important detail here: The player who is playing always sees their own pieces as 1. In our implementation, we simply multiply the board by -1 every time we transition between players. Below, observe that state “toggles” as players take turns placing pieces:

In [1]:

import numpy as np

class Connect2Game(object):
    def __init__(self):
        """初始化函数"""
        self.columns = 4 # 四列
        self.win = 2 # 当两个连子即为赢
    
    def get_init_board(self):
        """获取初始化的棋盘"""
        return np.zeros((self.columns), dtype=np.int)
    
    def get_board_size(self):
        """获取棋盘大小"""
        return self.columns
    
    def get_action_size(self):
        """获取动作空间大小"""
        return self.columns
    
    def get_next_state(self, board, player, action):
        b = np.copy(board)
        b[action] = player
        
        # 返回新的game，但是需要转换玩家视角，player需要乘一个负号
        return (b, -player)
    
    def has_legal_moves(self, board):
        """判断是否还有合法的动作空间"""
        for index in range(self.columns):
            if board[index] == 0:
                return True
        return False
    
    def get_valid_moves(self, board):
        """获取有效的动作"""
        # 初始化时，所有的动作都是有效的
        valid_moves = [0] * self.get_action_size()
        b = np.copy(board)
        
        for index in range(self.columns):
            if board[index] == 0:
                valid_moves[index] = 1
                
        return valid_moves
    
    def is_win(self, board, player):
        """判断是否赢"""
        count = 0
        for index in range(self.columns):
            if board[index] == player:
                count += 1
            else:
                count += 0
            
            if count == self.win:
                return True
            
        return False
    
    def get_reward_for_player(self, board, player):
        """如果没有结束返回None，玩家1胜返回1，玩家2胜返回-1"""
        if self.is_win(board, player):
            return 1
        if self.is_win(board, player):
            return -1
        if self.has_legal_moves(board):
            return None
        
        return 0
    
    def get_canonical_board(self, board, player):
        return player * board

## Value Network

值网络的输入就是当前棋盘的状态表示，需要预测一个标量的输出，比如我方胜利的状态输入给`value network`的话，期望`value network`的输出是`1`，反之为`-1`，平局为`0`。

### 训练

为了训练`value network`，我们需要记录在`self-play`过程中的每一步，例如:

1. 初始化状态`[0, 0, 0, 0]`, 玩家`1`在`1`位置落子。
2. 状态进行转移，玩家身份互换，对于玩家`2`来说，状态为`[-1, 0, 0, 0]`, 玩家`2`选择在`3`位置落子。
3. 状态再次进行转移，玩家身份互换，对于玩家`1`来说，状态为`[1, 0, -1, 0]`, 玩家`1`选择在`2`位置落子。
4. 游戏结束。

游戏结束之后，回溯整个状态，将玩家`1`奖励设置为`1`，玩家`2`奖励设置为`-1`，记录表格如下所示:

```python
([ 0  0  0  0],  1)   # Player 1 plays in the first position
([-1  0  0  0], -1)   # Player 2 plays in the third position
([ 1  0 -1  0],  1)   # Player 1 plays in the second position
```

之后采集大量的样本，训练value network。

## Policy Network

policy network与value network一样，接收棋盘状态的输入，但是需要给出每个动作的选择概率。策略网络对于每个动作的选择概率之后要勇于MCTS的指导搜索，并将policy network的这个指导称之为priors。

但是网络对于非法的动作，通常不会输出概率为0，因此我们需要将非法的动作mask掉，之后再将其re-normalizing一下，来保障最终的概率和为1。

### 训练

这里并不直接训练一个网络来决策出如何走，而是让`policy`来学习`MCTS`的输出，因此在记录的时候，我们需要记录状态，和每个状态下`MCTS`的输出(`probabilities produced by the MCTS`):

```python
([ 0  0  0  0], [0.1, 0.4, 0.4, 0.1])
([-1  0  0  0], [0.0, 0.3, 0.3, 0.3])
([ 1  0 -1  0], [0.0, 0.8, 0.0, 0.2])
```

`policy network`和`MCTS`是迭代进行的，`MCTS`之后需要用到`policy network`来指导搜索。

## Value Network和Policy Network代码实现

In [2]:
args = {
    'batch_size': 64,
    'num_simulations': 100,                         # Number of Monte Carlo simulations for each move
    'numIters': 500,                                # Total number of training iterations
    'numEps': 100,                                  # Number of full games (episodes) to run during each iteration
    'numItersForTrainExamplesHistory': 20,
    'epochs': 2,                                    # Number of epochs of training per iteration
    'checkpoint_path': 'latest.pth'                 # location to save latest set of weights
}

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Connect2Model(nn.Module):

    def __init__(self, board_size, action_size, device):

        super(Connect2Model, self).__init__()

        self.device = device
        self.size = board_size
        self.action_size = action_size

        self.fc1 = nn.Linear(in_features=self.size, out_features=16)
        self.fc2 = nn.Linear(in_features=16, out_features=16)

        # Two heads on our network
        self.action_head = nn.Linear(in_features=16, out_features=self.action_size)
        self.value_head = nn.Linear(in_features=16, out_features=1)

        self.to(device)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        action_logits = self.action_head(x)
        value_logit = self.value_head(x)

        return F.softmax(action_logits, dim=1), torch.tanh(value_logit)

    def predict(self, board):
        board = torch.FloatTensor(board.astype(np.float32)).to(self.device)
        board = board.view(1, self.size)
        self.eval()
        with torch.no_grad():
            pi, v = self.forward(board)

        return pi.data.cpu().numpy()[0], v.data.cpu().numpy()[0]

## Monte Carlo Tree Search

### Node

每个节点代表的就是一个可到达的棋盘状态，节点上的边就代表玩家在当前节点状态采取的动作，节点存储的信息如下所示:

```python
class Node:
  def __init__(self, prior, to_play):
      self.prior = prior      # The prior probability of selecting this state from its parent
      self.to_play = to_play  # The player whose turn it is. (1 or -1)

      self.children = {}      # A lookup of legal child positions
      self.visit_count = 0    # Number of times this state was visited during MCTS. "Good" are visited more then "bad" states.
      self.value_sum = 0      # The total value of this state from all visits
      self.state = None       # The board state as this node

  def value(self):
       # Average value for a node
      return self.value_sum / self.visit_count  
```

节点中的`prior`就是父亲节点选择到这个当前节点的概率，比如父亲节点输出的动作概率为`[0.8, 0, 0, 0.2]`, 如果选择第一个节点`0.8`概率的这个动作，就会转移到当前节点的话，那么当前节点的`prior`概率就是`0.8`。

### MCTS

定义好了节点之后，就可以通过mcts来构建博弈树了，整个博弈树的展开是多个simulations构成的，每个simulation会添加一个节点到博弈树中，由以下三部分组成:

1. Select

2. Expand

3. Backup

simulation次数越多，得到的模型就会越好。核心思路如下:

```python
def run(self, model, state, to_play):

    root = Node(0, to_play)

    # EXPAND root
    action_probs, value = model.predict(state)
    valid_moves = self.game.get_valid_moves(state)
    action_probs = action_probs * valid_moves  # mask invalid moves
    action_probs /= np.sum(action_probs)
    root.expand(state, to_play, action_probs)

    for _ in range(self.args['num_simulations']):
        node = root
        search_path = [node]

        # SELECT
        while node.expanded():
            action, node = node.select_child()
            search_path.append(node)

        parent = search_path[-2]
        state = parent.state
        # Now we're at a leaf node and we would like to expand
        # Players always play from their own perspective
        next_state, _ = self.game.get_next_state(state, player=1, action=action)
        # Get the board from the perspective of the other player
        next_state = self.game.get_canonical_board(next_state, player=-1)

        # The value of the new state from the perspective of the other player
        value = self.game.get_reward_for_player(next_state, player=1)
        if value is None:
            # If the game has not ended:
            # EXPAND
            action_probs, value = model.predict(next_state)
            valid_moves = self.game.get_valid_moves(next_state)
            action_probs = action_probs * valid_moves  # mask invalid moves
            action_probs /= np.sum(action_probs)
            node.expand(next_state, parent.to_play * -1, action_probs)

        self.backpropagate(search_path, value, parent.to_play * -1)

    return root
```

接收当前的棋盘状态`state`，当前的玩家是哪一方，`mcts`的`simulation`次数是多少。之后就是创建跟节点，然后扩展根节点，当然在这里，扩展根节点并没有用到`policy network`的`prior`的知识。

扩展完之后，我们就有孩子节点，然后去选择孩子节点，选择孩子节点通过`UCB`的算法来选择。选择到了孩子节点之后，`node`就会被更新，并且返回了这个被选择的动作，之后，退出`while`循环，因为这个孩子节点并没有被`expand`，所以没有孩子节点，并且这个孩子节点被记录到`search_path`中。

之后将这个选择的动作给到`get_next_state`中，进行了状态转移，之后玩家身份也需要转换，状态也需要转换。

然后获取奖励，如果游戏没有结束，那么获取到的奖励将会是`None`，然后再次`expand`，获取到三个子节点，但是他们的`value`都是`None`，之后再回溯叠加奖励。

### UCB

UCB需要三个东西:

1. 孩子节点的先验概率。

2. 孩子节点的value，其实也就是对手的value。

3. 访问次数

In [4]:
import math

def ucb_score(parent, child):
    """
    The score for an action that would transition between the parent and child.
    """
    prior_score = child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)
    if child.visit_count > 0:
        # The value of the child is from the perspective of the opposing player
        value_score = -child.value()
    else:
        value_score = 0

    return value_score + prior_score

## MCTS代码实现

In [5]:
class Node:
    def __init__(self, prior, to_play):
        self.visit_count = 0
        self.to_play = to_play
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.state = None

    def expanded(self):
        return len(self.children) > 0

    def value(self):
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def select_action(self, temperature):
        """
        Select action according to the visit count distribution and the temperature.
        """
        visit_counts = np.array([child.visit_count for child in self.children.values()])
        actions = [action for action in self.children.keys()]
        if temperature == 0:
            action = actions[np.argmax(visit_counts)]
        elif temperature == float("inf"):
            action = np.random.choice(actions)
        else:
            # See paper appendix Data Generation
            visit_count_distribution = visit_counts ** (1 / temperature)
            visit_count_distribution = visit_count_distribution / sum(visit_count_distribution)
            action = np.random.choice(actions, p=visit_count_distribution)

        return action

    def select_child(self):
        """
        Select the child with the highest UCB score.
        """
        best_score = -np.inf
        best_action = -1
        best_child = None

        for action, child in self.children.items():
            score = ucb_score(self, child)
            if score > best_score:
                best_score = score
                best_action = action
                best_child = child

        return best_action, best_child

    def expand(self, state, to_play, action_probs):
        """
        We expand a node and keep track of the prior policy probability given by neural network
        """
        self.to_play = to_play
        self.state = state
        for a, prob in enumerate(action_probs):
            if prob != 0:
                self.children[a] = Node(prior=prob, to_play=self.to_play * -1)

    def __repr__(self):
        """
        Debugger pretty print node info
        """
        prior = "{0:.2f}".format(self.prior)
        return "{} Prior: {} Count: {} Value: {}".format(self.state.__str__(), prior, self.visit_count, self.value())

In [6]:
class MCTS:

    def __init__(self, game, model, args):
        self.game = game
        self.model = model
        self.args = args

    def run(self, model, state, to_play):

        root = Node(0, to_play)

        # EXPAND root
        action_probs, value = model.predict(state)
        valid_moves = self.game.get_valid_moves(state)
        action_probs = action_probs * valid_moves  # mask invalid moves
        action_probs /= np.sum(action_probs)
        root.expand(state, to_play, action_probs)

        for _ in range(self.args['num_simulations']):
            node = root
            search_path = [node]

            # SELECT
            while node.expanded():
                action, node = node.select_child()
                search_path.append(node)

            parent = search_path[-2]
            state = parent.state
            # Now we're at a leaf node and we would like to expand
            # Players always play from their own perspective
            next_state, _ = self.game.get_next_state(state, player=1, action=action)
            # Get the board from the perspective of the other player
            next_state = self.game.get_canonical_board(next_state, player=-1)

            # The value of the new state from the perspective of the other player
            value = self.game.get_reward_for_player(next_state, player=1)
            if value is None:
                # If the game has not ended:
                # EXPAND
                action_probs, value = model.predict(next_state)
                valid_moves = self.game.get_valid_moves(next_state)
                action_probs = action_probs * valid_moves  # mask invalid moves
                action_probs /= np.sum(action_probs)
                node.expand(next_state, parent.to_play * -1, action_probs)

            self.backpropagate(search_path, value, parent.to_play * -1)

        return root

    def backpropagate(self, search_path, value, to_play):
        """
        At the end of a simulation, we propagate the evaluation all the way up the tree
        to the root.
        """
        for node in reversed(search_path):
            node.value_sum += value if node.to_play == to_play else -value
            node.visit_count += 1

## 训练

In [7]:
import os
from random import shuffle
import torch.optim as optim


class Trainer:

    def __init__(self, game, model, args):
        self.game = game
        self.model = model
        self.args = args
        self.mcts = MCTS(self.game, self.model, self.args)

    def exceute_episode(self):

        train_examples = []
        current_player = 1
        episode_step = 0
        state = self.game.get_init_board()

        while True:
            episode_step += 1

            canonical_board = self.game.get_canonical_board(state, current_player)

            self.mcts = MCTS(self.game, self.model, self.args)
            root = self.mcts.run(self.model, canonical_board, to_play=1)

            action_probs = [0 for _ in range(self.game.get_action_size())]
            for k, v in root.children.items():
                action_probs[k] = v.visit_count

            action_probs = action_probs / np.sum(action_probs)
            train_examples.append((canonical_board, current_player, action_probs))

            action = root.select_action(temperature=0)
            state, current_player = self.game.get_next_state(state, current_player, action)
            reward = self.game.get_reward_for_player(state, current_player)

            if reward is not None:
                ret = []
                for hist_state, hist_current_player, hist_action_probs in train_examples:
                    # [Board, currentPlayer, actionProbabilities, Reward]
                    ret.append((hist_state, hist_action_probs, reward * ((-1) ** (hist_current_player != current_player))))

                return ret

    def learn(self):
        for i in range(1, self.args['numIters'] + 1):

            print("{}/{}".format(i, self.args['numIters']))

            train_examples = []

            for eps in range(self.args['numEps']):
                iteration_train_examples = self.exceute_episode()
                train_examples.extend(iteration_train_examples)

            shuffle(train_examples)
            self.train(train_examples)
            filename = self.args['checkpoint_path']
            self.save_checkpoint(folder=".", filename=filename)

    def train(self, examples):
        optimizer = optim.Adam(self.model.parameters(), lr=5e-4)
        pi_losses = []
        v_losses = []

        for epoch in range(self.args['epochs']):
            self.model.train()

            batch_idx = 0

            while batch_idx < int(len(examples) / self.args['batch_size']):
                sample_ids = np.random.randint(len(examples), size=self.args['batch_size'])
                boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
                boards = torch.FloatTensor(np.array(boards).astype(np.float64))
                target_pis = torch.FloatTensor(np.array(pis))
                target_vs = torch.FloatTensor(np.array(vs).astype(np.float64))

                # predict
                boards = boards.contiguous()#.cuda()
                target_pis = target_pis.contiguous()#.cuda()
                target_vs = target_vs.contiguous()#.cuda()

                # compute output
                out_pi, out_v = self.model(boards)
                l_pi = self.loss_pi(target_pis, out_pi)
                l_v = self.loss_v(target_vs, out_v)
                total_loss = l_pi + l_v

                pi_losses.append(float(l_pi))
                v_losses.append(float(l_v))

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                batch_idx += 1

            print()
            print("Policy Loss", np.mean(pi_losses))
            print("Value Loss", np.mean(v_losses))
            print("Examples:")
            print(out_pi[0].detach())
            print(target_pis[0])

    def loss_pi(self, targets, outputs):
        loss = -(targets * torch.log(outputs)).sum(dim=1)
        return loss.mean()

    def loss_v(self, targets, outputs):
        loss = torch.sum((targets-outputs.view(-1))**2)/targets.size()[0]
        return loss

    def save_checkpoint(self, folder, filename):
        if not os.path.exists(folder):
            os.mkdir(folder)

        filepath = os.path.join(folder, filename)
        torch.save({
            'state_dict': self.model.state_dict(),
        }, filepath)

In [8]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



game = Connect2Game()
board_size = game.get_board_size()
action_size = game.get_action_size()

model = Connect2Model(board_size, action_size, device)

trainer = Trainer(game, model, args)
trainer.learn()

1/500


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return np.zeros((self.columns), dtype=np.int)



Policy Loss 1.4615188439687092
Value Loss 1.0158655941486359
Examples:
tensor([0.2474, 0.1966, 0.2601, 0.2958])
tensor([0.0300, 0.0400, 0.8800, 0.0500])

Policy Loss 1.4592657387256622
Value Loss 0.9948037962118784
Examples:
tensor([0.2299, 0.2020, 0.2514, 0.3167])
tensor([0.5500, 0.4500, 0.0000, 0.0000])
2/500

Policy Loss 1.443178693453471
Value Loss 0.9317807555198669
Examples:
tensor([0.2467, 0.2011, 0.2628, 0.2894])
tensor([0.0300, 0.0300, 0.8900, 0.0500])

Policy Loss 1.4429944852987926
Value Loss 0.9395932952562968
Examples:
tensor([0.2458, 0.2198, 0.2627, 0.2717])
tensor([0., 1., 0., 0.])
3/500

Policy Loss 1.4357839822769165
Value Loss 0.8991268078486124
Examples:
tensor([0.2457, 0.2055, 0.2658, 0.2830])
tensor([0.0300, 0.0300, 0.9100, 0.0300])

Policy Loss 1.4346220592657726
Value Loss 0.9015381038188934
Examples:
tensor([0.2447, 0.2263, 0.2625, 0.2665])
tensor([0., 1., 0., 0.])
4/500

Policy Loss 1.4544566075007122
Value Loss 0.9462345341841379
Examples:
tensor([0.2327, 0.2

KeyboardInterrupt: 

## 测试

测试代码在:

1. https://github.com/JoshVarty/AlphaZeroSimple/blob/ecf72a468aba26b8b155ec6fb1b91697a2fbb7a9/tests.py