<a href="https://colab.research.google.com/github/pashakhomchenko/AlphaZero/blob/master/AlphaZero.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is my follow-along and exploration of this AlphaZero [tutorial](https://www.youtube.com/watch?v=wuSQpLinRB4). Let's dive in.

# Tic Tac Toe


## Game Setup

In [1]:
import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

from tqdm.notebook import trange
import random

In [2]:
class TicTacToe:
  """Rules and board helpers for 3x3 tic-tac-toe.

  A board is a ``(row_count, column_count)`` NumPy array holding ``1`` and
  ``-1`` for the two players and ``0`` for empty cells. Actions are flat cell
  indices in row-major order.
  """

  def __init__(self):
    self.row_count = 3
    self.column_count = 3
    self.action_size = self.row_count * self.column_count

  def __repr__(self):
    return "TicTacToe"

  def get_initial_state(self):
    """Return an empty board (all zeros)."""
    return np.zeros([self.row_count, self.column_count])

  def get_next_state(self, state, action, player):
    """Write ``player`` into the cell addressed by ``action`` and return ``state``.

    The board is modified in place; pass a copy to keep the previous position.
    """
    # action == 0 is the top left corner, action == 8 the bottom right
    row = action // self.column_count
    column = action % self.column_count
    state[row, column] = player
    return state

  def get_valid_moves(self, state):
    """Return a flat uint8 mask with 1 for every empty cell."""
    # state.reshape(-1) flattens the board
    return (state.reshape(-1) == 0).astype(np.uint8)

  def check_win(self, state, action):
    """Return True if the piece sitting on ``action`` completes a line.

    ``action`` may be None (no move played yet), in which case the result is
    False.
    """
    if action is None:
      return False

    row = action // self.column_count
    column = action % self.column_count
    player = state[row, column]
    return bool(
        np.sum(state[row, :]) == player * self.column_count
        or
        np.sum(state[:, column]) == player * self.row_count
        or
        np.sum(np.diag(state)) == player * self.row_count
        or
        # Flip the rows only: flipping every axis (the np.flip default) maps
        # the main diagonal onto itself and never exposes the anti-diagonal.
        np.sum(np.diag(np.flip(state, axis=0))) == player * self.row_count
    )

  def get_value_and_terminated(self, state, action):
    """Return ``(value, is_terminated)`` for the position reached by ``action``.

    ``value`` is 1 when the move just played wins and 0 otherwise (draw or
    game still running).
    """
    if self.check_win(state, action):
      # win, reward is 1
      return 1, True
    if np.sum(self.get_valid_moves(state)) == 0:
      # draw, reward is 0
      return 0, True
    # continue the game
    return 0, False

  def get_opponent(self, player):
    """Return the other player's id."""
    return -player

  def get_opponent_value(self, value):
    """Return ``value`` as seen by the opponent."""
    return -value

  def change_perspective(self, state, player):
    """Return the board multiplied by ``player`` (swaps piece signs for -1)."""
    return state * player

  def get_encoded_state(self, state):
    """Encode a board as three float32 planes: cells == -1, == 0 and == 1.

    A single board yields shape ``(3, rows, cols)``. A 3-d (batched) input
    yields ``(batch, 3, rows, cols)``.
    """
    encoded_state = np.stack(
        (state == -1, state == 0, state == 1)
    ).astype(np.float32)

    # Handle batched state: move the plane axis behind the batch axis
    if len(state.shape) == 3:
      encoded_state = np.swapaxes(encoded_state, 0, 1)

    return encoded_state

## Monte Carlo Tree Search

In [3]:
class Node:
  """One node of the rollout-based MCTS tree.

  ``state`` is stored from the point of view of the player to move at this
  node, and ``action_taken`` is the move that led here from ``parent``.
  """

  def __init__(self, game, args, state, parent=None, action_taken=None):
    self.game = game
    self.args = args
    self.state = state
    self.parent = parent
    self.action_taken = action_taken

    self.children = []
    # mask of moves that have not been turned into children yet
    self.expandable_moves = game.get_valid_moves(state)

    self.visit_count = 0
    self.value_sum = 0

  def is_fully_expanded(self):
    """True when every move has a child and there is at least one child.

    No moves left + children -> keep selecting downwards.
    No moves left + no children -> terminal position, selection stops.
    Moves left -> this is a leaf that can still be expanded, selection stops.
    """
    has_children = len(self.children) > 0
    return np.sum(self.expandable_moves) == 0 and has_children

  def select(self):
    """Return the child with the highest UCB score (first one wins ties)."""
    chosen, top_score = None, -np.inf
    for candidate in self.children:
      score = self.get_ucb(candidate)
      if score > top_score:
        chosen, top_score = candidate, score
    return chosen

  def get_ucb(self, child):
    """UCB score of ``child`` as seen from this node."""
    # The child stores values from the opponent's point of view, so the mean
    # is rescaled to [0, 1] and inverted: a bad child value is good for us.
    mean_value = child.value_sum / child.visit_count
    exploitation = 1 - (mean_value + 1) / 2
    exploration = self.args['C'] * math.sqrt(math.log(self.visit_count) / child.visit_count)
    return exploitation + exploration

  def expand(self):
    """Create and return a child for one randomly chosen unexpanded move."""
    action = np.random.choice(np.flatnonzero(self.expandable_moves == 1))
    # this move is now taken
    self.expandable_moves[action] = 0

    # play the move on a copy, then flip the board to the child's perspective
    next_board = self.game.get_next_state(self.state.copy(), action, 1)
    next_board = self.game.change_perspective(next_board, -1)

    new_child = Node(self.game, self.args, next_board, self, action)
    self.children.append(new_child)
    return new_child

  def simulate(self):
    """Return the result of a uniformly random playout from this node."""
    outcome, finished = self.game.get_value_and_terminated(self.state, self.action_taken)
    outcome = self.game.get_opponent_value(outcome)
    if finished:
      return outcome

    # play random moves until the game ends
    board = self.state.copy()
    mover = 1
    while True:
      legal = self.game.get_valid_moves(board)
      move = np.random.choice(np.flatnonzero(legal == 1))
      board = self.game.get_next_state(board, move, mover)
      outcome, finished = self.game.get_value_and_terminated(board, move)
      if not finished:
        mover = self.game.get_opponent(mover)
        continue
      # the result is reported for the last mover; negate it when that was -1
      if mover == -1:
        return self.game.get_opponent_value(outcome)
      return outcome

  def backpropogate(self, value):
    """Add ``value`` here and pass it up the tree, negating it at every level."""
    node = self
    while node is not None:
      node.value_sum += value
      node.visit_count += 1
      value = node.game.get_opponent_value(value)
      node = node.parent


class MCTS:
  """Plain Monte Carlo Tree Search driven by random rollouts."""

  def __init__(self, game, args):
    self.game, self.args = game, args

  def search(self, state):
    """Run ``args['num_searches']`` simulations from ``state``.

    Returns an array of length ``game.action_size`` holding the root
    children's visit counts normalised to sum to 1.
    """
    root = Node(self.game, self.args, state)

    for _ in range(self.args['num_searches']):
      # Selection: walk down while the current node has nothing left to expand
      leaf = root
      while leaf.is_fully_expanded():
        leaf = leaf.select()

      # The game reports the value for the player who just moved, i.e. the
      # opponent of the player to move at the leaf, so it is negated.
      outcome, finished = self.game.get_value_and_terminated(leaf.state, leaf.action_taken)
      outcome = self.game.get_opponent_value(outcome)

      if not finished:
        # Expansion followed by a random playout from the new child
        leaf = leaf.expand()
        outcome = leaf.simulate()

      # Backpropagation
      leaf.backpropogate(outcome)

    visit_counts = np.zeros(self.game.action_size)
    for child in root.children:
      visit_counts[child.action_taken] = child.visit_count
    return visit_counts / np.sum(visit_counts)



## Game with MCTS

In [4]:
# Set up a tic-tac-toe game and a rollout-based MCTS player for the
# (commented-out) interactive loop below.
tictactoe = TicTacToe()
player = 1

args = {
    'C': 1.41, # exploration weight read by Node.get_ucb
    'num_searches': 1000 # number of simulations per MCTS.search call
}
mcts = MCTS(tictactoe, args)

state = tictactoe.get_initial_state()

# while True:
#   print(state)

#   if player == 1:
#     valid_moves = tictactoe.get_valid_moves(state)
#     print("valid moves", [i for i in range(tictactoe.action_size) if valid_moves[i] == 1])
#     action = int(input(f"{player}:"))

#     if valid_moves[action] == 0:
#       print("action not valid")
#       continue
#   else:
#     neutral_state = tictactoe.change_perspective(state, player)
#     mcts_probs = mcts.search(neutral_state)
#     action = np.argmax(mcts_probs)

#   state = tictactoe.get_next_state(state, action, player)
#   value, is_terminated = tictactoe.get_value_and_terminated(state, action)

#   if is_terminated:
#     print(state)
#     if value == 1:
#       print(player, "won")
#     else:
#       print("draw")
#     break

#   # switch to the next player
#   player = tictactoe.get_opponent(player)

too easy

## AlphaMCTS

In [5]:
from torch.nn.modules.activation import Softmax
class ResNet(nn.Module):
  """Residual conv net with a policy head and a value head.

  The input is expected to have 3 channels (the stem is ``Conv2d(3, ...)``).
  ``forward`` returns ``(policy_logits, value)``: one logit per action and a
  single tanh-squashed scalar per sample.
  """

  def __init__(self, game, num_resBlocks, num_hidden, device):
    super().__init__()
    self.device = device
    board_cells = game.row_count * game.column_count

    stem = [
        nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
        nn.BatchNorm2d(num_hidden),
        nn.ReLU(),
    ]
    self.startBlock = nn.Sequential(*stem)

    self.backBone = nn.ModuleList(
        ResBlock(num_hidden) for _ in range(num_resBlocks)
    )

    # Policy head emits raw logits; callers apply softmax themselves.
    policy_layers = [
        nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(32 * board_cells, game.action_size),
    ]
    self.policyHead = nn.Sequential(*policy_layers)

    # Value head squashes its single output into [-1, 1] with tanh.
    value_layers = [
        nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
        nn.BatchNorm2d(3),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(3 * board_cells, 1),
        nn.Tanh(),
    ]
    self.valueHead = nn.Sequential(*value_layers)

    self.to(device)

  def forward(self, x):
    """Run the stem and residual tower, then both heads on the shared features."""
    features = self.startBlock(x)
    for block in self.backBone:
      features = block(features)
    return self.policyHead(features), self.valueHead(features)

class ResBlock(nn.Module):
  """Two 3x3 conv + batch-norm layers wrapped in a skip connection."""

  def __init__(self, num_hidden):
    super().__init__()
    conv_kwargs = dict(kernel_size=3, padding=1)
    self.conv1 = nn.Conv2d(num_hidden, num_hidden, **conv_kwargs)
    self.bn1 = nn.BatchNorm2d(num_hidden)
    self.conv2 = nn.Conv2d(num_hidden, num_hidden, **conv_kwargs)
    self.bn2 = nn.BatchNorm2d(num_hidden)

  def forward(self, x):
    """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
    out = self.conv1(x)
    out = F.relu(self.bn1(out))
    out = self.conv2(out)
    out = self.bn2(out)
    # skip connection: the block only has to learn a correction to its input
    return F.relu(out + x)


In [6]:
class AlphaNode:
  """Search-tree node for the network-guided MCTS.

  ``prior`` is the probability the parent's policy assigned to the move that
  leads to this node. ``state`` is stored from the point of view of the
  player to move here.
  """

  def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
    self.game = game
    self.args = args
    self.state = state
    self.parent = parent
    self.action_taken = action_taken
    self.prior = prior

    self.children = []

    self.visit_count = visit_count
    self.value_sum = 0

  def is_fully_expanded(self):
    """A node is expanded in one go, so having any child means fully expanded."""
    return len(self.children) > 0

  def select(self):
    """Return the child with the highest UCB score (first one wins ties)."""
    chosen, top_score = None, -np.inf
    for candidate in self.children:
      score = self.get_ucb(candidate)
      if score > top_score:
        chosen, top_score = candidate, score
    return chosen

  def get_ucb(self, child):
    """Prior-weighted UCB score of ``child`` as seen from this node."""
    if child.visit_count == 0:
      exploitation = 0
    else:
      # The child's mean value is from the opponent's point of view: rescale
      # it to [0, 1] and invert it so that a bad child value scores high.
      mean_value = child.value_sum / child.visit_count
      exploitation = 1 - (mean_value + 1) / 2
    exploration = self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior
    return exploitation + exploration

  def expand(self, policy):
    """Create one child for every action whose ``policy`` entry is positive."""
    for action, prob in enumerate(policy):
      if not prob > 0:
        continue
      # play the move on a copy, then flip the board to the child's perspective
      next_board = self.game.get_next_state(self.state.copy(), action, 1)
      next_board = self.game.change_perspective(next_board, -1)
      self.children.append(
          AlphaNode(self.game, self.args, next_board, self, action, prob)
      )

  def backpropogate(self, value):
    """Add ``value`` here and pass it up the tree, negating it at every level."""
    node = self
    while node is not None:
      node.value_sum += value
      node.visit_count += 1
      value = node.game.get_opponent_value(value)
      node = node.parent

class AlphaMCTS:
  """MCTS that uses the network's policy and value instead of random rollouts."""

  def __init__(self, game, args, model):
    self.game, self.args, self.model = game, args, model

  def _encode(self, state):
    """Encode one board as a batch of size 1 on the model's device."""
    encoded = self.game.get_encoded_state(state)
    return torch.tensor(encoded, device=self.model.device).unsqueeze(0)

  def _mask_and_normalize(self, policy, state):
    """Zero out illegal moves and rescale the rest to sum to 1."""
    masked = policy * self.game.get_valid_moves(state)
    return masked / np.sum(masked)

  @torch.no_grad()
  def search(self, state):
    """Run ``args['num_searches']`` simulations from ``state``.

    Returns an array of length ``game.action_size`` holding the root
    children's visit counts normalised to sum to 1.
    """
    # The root starts with one visit so that the first UCB exploration term
    # is not multiplied by sqrt(0).
    root = AlphaNode(self.game, self.args, state, visit_count=1)

    root_logits, _ = self.model(self._encode(state))
    root_policy = torch.softmax(root_logits, axis=1).squeeze(0).cpu().numpy()

    # Mix Dirichlet noise into the root policy to explore more at the start
    epsilon = self.args['dirichlet_epsilon']
    noise = np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
    root_policy = (1 - epsilon) * root_policy + epsilon * noise

    root.expand(self._mask_and_normalize(root_policy, state))

    for _ in range(self.args['num_searches']):
      # Selection
      leaf = root
      while leaf.is_fully_expanded():
        leaf = leaf.select()

      # The game reports the value for the player who just moved, i.e. the
      # opponent of the player to move at the leaf, so it is negated.
      outcome, finished = self.game.get_value_and_terminated(leaf.state, leaf.action_taken)
      outcome = self.game.get_opponent_value(outcome)

      if not finished:
        # The network replaces the rollout: its value estimate is backed up
        # and its policy (batch dimension removed) seeds the children.
        logits, predicted = self.model(self._encode(leaf.state))
        leaf_policy = torch.softmax(logits, axis=1).squeeze(0).cpu().numpy()
        leaf_policy = self._mask_and_normalize(leaf_policy, leaf.state)
        outcome = predicted.item()

        # Expansion
        leaf.expand(leaf_policy)

      # Backpropagation
      leaf.backpropogate(outcome)

    visit_counts = np.zeros(self.game.action_size)
    for child in root.children:
      visit_counts[child.action_taken] = child.visit_count
    return visit_counts / np.sum(visit_counts)

In [7]:
# Set up tic-tac-toe and a network-guided MCTS player (untrained network) for
# the (commented-out) interactive loop below.
tictactoe = TicTacToe()
player = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

args = {
    'C': 2, # exploration weight read by AlphaNode.get_ucb
    'num_searches': 1000, # number of simulations per AlphaMCTS.search call
    # AlphaMCTS.search always reads both noise settings; epsilon 0 disables the
    # root noise for play, alpha is still needed to draw the (unused) sample.
    'dirichlet_epsilon': 0.0,
    'dirichlet_alpha': 0.3
}
model = ResNet(tictactoe, 4, 64, device)
model.eval()

mcts = AlphaMCTS(tictactoe, args, model)

state = tictactoe.get_initial_state()

# while True:
#   print(state)

#   if player == 1:
#     valid_moves = tictactoe.get_valid_moves(state)
#     print("valid moves", [i for i in range(tictactoe.action_size) if valid_moves[i] == 1])
#     action = int(input(f"{player}:"))

#     if valid_moves[action] == 0:
#       print("action not valid")
#       continue
#   else:
#     neutral_state = tictactoe.change_perspective(state, player)
#     mcts_probs = mcts.search(neutral_state)
#     action = np.argmax(mcts_probs)

#   state = tictactoe.get_next_state(state, action, player)
#   value, is_terminated = tictactoe.get_value_and_terminated(state, action)

#   if is_terminated:
#     print(state)
#     if value == 1:
#       print(player, "won")
#     else:
#       print("draw")
#     break

#   # switch to the next player
#   player = tictactoe.get_opponent(player)

## AlphaZero

In [8]:
class AlphaZero:
  """Self-play training loop: generate games with AlphaMCTS, then fit the model."""

  def __init__(self, model, optimizer, game, args):
    self.model = model
    self.optimizer = optimizer
    self.game = game
    self.args = args
    self.mcts = AlphaMCTS(game, args, model)

  def selfPlay(self):
    """Play one game against itself and return its training samples.

    Returns:
      A list of ``(encoded_state, action_probs, outcome)`` tuples, one per
      move played. ``outcome`` is the final game value from the point of view
      of the player who was to move in that position.
    """
    memory = [] # (neutral_state, action_probs, player) for every move of this game
    player = 1
    state = self.game.get_initial_state()

    while True:
      # the search always sees the board as if player 1 were to move
      neutral_state = self.game.change_perspective(state, player)
      action_probs = self.mcts.search(neutral_state)

      memory.append((neutral_state, action_probs, player))

      # explore / exploit control: temperature > 1 flattens the distribution
      temperature_action_probs = action_probs ** (1 / self.args['temperature'])
      temperature_action_probs /= np.sum(temperature_action_probs)
      action = np.random.choice(self.game.action_size, p=temperature_action_probs)
      state = self.game.get_next_state(state, action, player)

      value, is_terminated = self.game.get_value_and_terminated(state, action)

      if is_terminated:
        returnMemory = []
        for hist_neutral_state, hist_action_probs, hist_player in memory:
          # value belongs to the player who made the last move; flip it for the other one
          hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
          returnMemory.append((self.game.get_encoded_state(hist_neutral_state),
                           hist_action_probs,
                           hist_outcome))
        return returnMemory

      player = self.game.get_opponent(player)

  def train(self, memory):
    """Run one optimisation pass over ``memory`` in mini-batches.

    ``memory`` is shuffled in place. Every sample is used exactly once; the
    last batch may be smaller than ``args['batch_size']``.
    """
    random.shuffle(memory)
    batch_size = self.args['batch_size']
    for batchIdx in range(0, len(memory), batch_size):
      # Slicing clamps at the end of the list on its own. Clamping to
      # len(memory) - 1 dropped the final sample and could yield an empty batch.
      sample = memory[batchIdx:batchIdx + batch_size]
      state, policy_targets, value_targets = zip(*sample)

      state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

      state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
      policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
      value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

      out_policy, out_value = self.model(state)

      # how is this the right loss, if policy_targets are also partially generated by the model
      policy_loss = F.cross_entropy(out_policy, policy_targets)
      value_loss = F.mse_loss(out_value, value_targets)
      loss = policy_loss + value_loss

      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

  def learn(self):
    """Alternate self-play and training for ``args['num_iterations']`` rounds.

    After every round the model and optimizer state dicts are saved as
    ``model_{iteration}_{game}.pt`` and ``optimizer_{iteration}_{game}.pt``.
    """
    for iteration in range(self.args['num_iterations']):
      memory = [] # training data for the model

      self.model.eval()
      for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
        memory += self.selfPlay()

      self.model.train()
      for epoch in trange(self.args['num_epochs']):
        self.train(memory)

      torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
      torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")

In [9]:
# Train AlphaZero on tic-tac-toe.
tictactoe = TicTacToe()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet(game, num_resBlocks, num_hidden, device): 4 residual blocks, 64 hidden channels
model = ResNet(tictactoe, 4, 64, device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

args = {
    'C': 2, # exploration weight read by AlphaNode.get_ucb
    'num_searches': 60, # simulations per AlphaMCTS.search call
    'num_iterations': 3, # self-play + training rounds in AlphaZero.learn
    'num_selfPlay_iterations': 500, # games generated per round
    'num_epochs': 4, # passes over the collected games per round
    'batch_size': 64,
    'temperature': 1.25, # explore / exploit control
    'dirichlet_epsilon': 0.25, # noise params
    'dirichlet_alpha': 0.3
}

alphaZero = AlphaZero(model, optimizer, tictactoe, args)
alphaZero.learn()

  0%|          | 0/500 [00:00<?, ?it/s]

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-9-290e10f74a49>", line 22, in <cell line: 22>
    alphaZero.learn()
  File "<ipython-input-8-9eb5ec8f7279>", line 68, in learn
    memory += self.selfPlay()
  File "<ipython-input-8-9eb5ec8f7279>", line 16, in selfPlay
    action_probs = self.mcts.search(neutral_state)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "<ipython-input-6-09e4800247fc>", line 76, in search
    policy, _ = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "<ipython-input-5-5e997dc1b1bd>", line 35, in forward
    x = self.startBlock(x)
  File "/usr/local/lib/python3.10/dist-packages/torch

TypeError: ignored

In [10]:
import matplotlib.pyplot as plt

# Load the last tic-tac-toe checkpoint and plot the policy it predicts for a
# hand-built position.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tictactoe = TicTacToe()

state = tictactoe.get_initial_state()
state = tictactoe.get_next_state(state, 2, -1)
state = tictactoe.get_next_state(state, 4, -1)
state = tictactoe.get_next_state(state, 6, 1)
state = tictactoe.get_next_state(state, 8, 1)

encoded_state = tictactoe.get_encoded_state(state) # 3 planes

tensor_state = torch.tensor(encoded_state, device=device).unsqueeze(0)

model = ResNet(tictactoe, 4, 64, device)
# AlphaZero.learn saves checkpoints as model_{iteration}_{game}.pt, so the
# checkpoint of the third iteration is model_2_TicTacToe.pt.
model.load_state_dict(torch.load("model_2_TicTacToe.pt", map_location=device))
model.eval()

policy, value = model(tensor_state)
value = value.item()
policy = torch.softmax(policy, axis=1).squeeze(0).detach().cpu().numpy()

print(value)
print(state)

plt.bar(range(tictactoe.action_size), policy)
plt.show()

FileNotFoundError: ignored

# Connect Four

## Regular AlphaZero

In [11]:
class ConnectFour:
  """Rules and board helpers for 6x7 Connect Four.

  A board is a ``(row_count, column_count)`` NumPy array holding ``1`` and
  ``-1`` for the two players and ``0`` for empty cells. Row 0 is the top row,
  so pieces fall towards the highest row index. An action is a column index.
  """

  def __init__(self):
    self.row_count = 6
    self.column_count = 7
    self.action_size = self.column_count
    self.in_a_row = 4

  def __repr__(self):
    return "ConnectFour"

  def get_initial_state(self):
    """Return an empty board (all zeros)."""
    return np.zeros([self.row_count, self.column_count])

  def get_next_state(self, state, action, player):
    """Drop ``player``'s piece into column ``action`` and return ``state``.

    The board is modified in place; pass a copy to keep the previous position.
    """
    # the piece lands on the lowest empty cell, i.e. the largest empty row index
    row = np.max(np.where(state[:, action] == 0))
    state[row, action] = player
    return state

  def get_valid_moves(self, state):
    """Return a uint8 mask with 1 for every column whose top cell is empty."""
    return (state[0] == 0).astype(np.uint8)

  def check_win(self, state, action):
    """Return True if the piece last dropped into ``action`` completes four in a row.

    ``action`` may be None (no move played yet), in which case the result is
    False.
    """
    if action is None:
      return False

    # The most recent piece sits on top of the stack, i.e. at the SMALLEST
    # occupied row index of the column (row 0 is the top of the board).
    row = np.min(np.where(state[:, action] != 0))
    column = action
    player = state[row, column]

    def count(offset_row, offset_column):
      """Count consecutive ``player`` cells next to the piece in one direction."""
      for i in range(1, self.in_a_row):
        r = row + offset_row * i
        c = column + offset_column * i
        if (
            r < 0
            or r >= self.row_count
            or c < 0
            or c >= self.column_count
            or state[r, c] != player
        ):
          return i - 1
      return self.in_a_row - 1

    # in_a_row - 1 neighbours are needed because the played piece counts too
    needed = self.in_a_row - 1
    return bool(
        count(1, 0) >= needed # vertical (only downwards: nothing is above the new piece)
        or (count(0, 1) + count(0, -1)) >= needed # horizontal
        or (count(1, 1) + count(-1, -1)) >= needed # top left diagonal
        or (count(1, -1) + count(-1, 1)) >= needed # top right diagonal
    )

  def get_value_and_terminated(self, state, action):
    """Return ``(value, is_terminated)`` for the position reached by ``action``.

    ``value`` is 1 when the move just played wins and 0 otherwise (draw or
    game still running).
    """
    if self.check_win(state, action):
      # win, reward is 1
      return 1, True
    if np.sum(self.get_valid_moves(state)) == 0:
      # draw, reward is 0
      return 0, True
    # continue the game
    return 0, False

  def get_opponent(self, player):
    """Return the other player's id."""
    return -player

  def get_opponent_value(self, value):
    """Return ``value`` as seen by the opponent."""
    return -value

  def change_perspective(self, state, player):
    """Return the board multiplied by ``player`` (swaps piece signs for -1)."""
    return state * player

  def get_encoded_state(self, state):
    """Encode a board as three float32 planes: cells == -1, == 0 and == 1.

    A single board yields shape ``(3, rows, cols)``. A 3-d (batched) input
    yields ``(batch, 3, rows, cols)``.
    """
    encoded_state = np.stack(
        (state == -1, state == 0, state == 1)
    ).astype(np.float32)

    # Handle batched state: move the plane axis behind the batch axis
    if len(state.shape) == 3:
      encoded_state = np.swapaxes(encoded_state, 0, 1)

    return encoded_state

This is way too slow with a larger neural net, so we will need to parallelize.

In [12]:
# Train the sequential (one game at a time) AlphaZero on Connect Four.
game = ConnectFour()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet(game, num_resBlocks, num_hidden, device): 9 residual blocks, 128 hidden channels
model = ResNet(game, 9, 128, device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

args = {
    'C': 2, # exploration weight read by AlphaNode.get_ucb
    'num_searches': 600, # simulations per AlphaMCTS.search call
    'num_iterations': 8, # self-play + training rounds in AlphaZero.learn
    'num_selfPlay_iterations': 500, # games generated per round
    'num_epochs': 4, # passes over the collected games per round
    'batch_size': 128,
    'temperature': 1.25, # explore / exploit control
    'dirichlet_epsilon': 0.25, # noise params
    'dirichlet_alpha': 0.3
}

alphaZero = AlphaZero(model, optimizer, game, args)
alphaZero.learn()

  0%|          | 0/500 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [13]:
# Play Connect Four in the console: the human is player 1, the (untrained)
# network-guided MCTS answers as player -1.
game = ConnectFour()
player = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

args = {
    'C': 2,
    'num_searches': 20,
    'dirichlet_epsilon': 0.0, # noise params
    'dirichlet_alpha': 0.3
}
model = ResNet(game, 9, 128, device)
model.eval()

mcts = AlphaMCTS(game, args, model)

state = game.get_initial_state()

while True:
  print(state)

  if player == 1:
    valid_moves = game.get_valid_moves(state)
    print("valid moves", [i for i in range(game.action_size) if valid_moves[i] == 1])
    try:
      action = int(input(f"{player}:"))
    except ValueError:
      # not a number: ask again instead of crashing the game
      print("action not valid")
      continue

    # The range check comes first: an out-of-range index would raise and a
    # negative one would silently wrap around to another column.
    if not 0 <= action < game.action_size or valid_moves[action] == 0:
      print("action not valid")
      continue
  else:
    # the search always sees the board as if player 1 were to move
    neutral_state = game.change_perspective(state, player)
    mcts_probs = mcts.search(neutral_state)
    action = np.argmax(mcts_probs)

  state = game.get_next_state(state, action, player)
  value, is_terminated = game.get_value_and_terminated(state, action)

  if is_terminated:
    print(state)
    if value == 1:
      print(player, "won")
    else:
      print("draw")
    break

  # switch to the next player
  player = game.get_opponent(player)

[[0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
valid moves [0, 1, 2, 3, 4, 5, 6]


KeyboardInterrupt: ignored

## Parallel AlphaZero

In [14]:
class AlphaMCTSParellel:
  """Network-guided MCTS that advances many self-play games in lock-step.

  All games are searched together so that the network is evaluated once per
  simulation on a batch of leaf positions instead of once per game.
  """

  def __init__(self, game, args, model):
    self.game = game
    self.args = args
    self.model = model

  @torch.no_grad()
  def search(self, states, spGames):
    """Build a search tree for every game in ``spGames``.

    Args:
      states: stacked boards, one per entry of ``spGames`` and in the same
        order, already seen from the player to move.
      spGames: game holders; each gets its ``root`` replaced by the new tree
        and its ``node`` used as scratch space for the leaf to expand.

    Nothing is returned: callers read the visit counts from ``spg.root``.
    """
    policy, _ = self.model(
            torch.tensor(self.game.get_encoded_state(states), device=self.model.device)
    )
    policy = torch.softmax(policy, axis=1).cpu().numpy()

    # Add dirichlet noise to the starting policy to explore more at the start
    policy = (1 - self.args['dirichlet_epsilon']) * policy + self.args['dirichlet_epsilon'] \
            * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size, size=policy.shape[0])

    for i, spg in enumerate(spGames):
      spg_policy = policy[i]
      valid_moves = self.game.get_valid_moves(states[i])
      spg_policy *= valid_moves
      spg_policy /= np.sum(spg_policy) # turn into probs

      # Save the root to spg class; it starts with one visit so that the first
      # UCB exploration term is not multiplied by sqrt(0)
      spg.root = AlphaNode(self.game, self.args, states[i], visit_count=1)
      spg.root.expand(spg_policy)

    for search in range(self.args['num_searches']):
      for spg in spGames:
        # Forget the leaf of the previous simulation. Without this reset a game
        # that now ends in a terminal node would have its stale leaf expanded
        # (and backed up) a second time.
        spg.node = None
        node = spg.root

        # Selection
        while node.is_fully_expanded():
          node = node.select()

        value, is_terminated = self.game.get_value_and_terminated(node.state, node.action_taken)
        # value above is value of the opponent
        value = self.game.get_opponent_value(value)

        if is_terminated:
          node.backpropogate(value)
        else:
          spg.node = node

      expandable_spGames = [mappingIdx for mappingIdx, spg in enumerate(spGames) if spg.node is not None]
      if not expandable_spGames:
        # every game hit a terminal node in this simulation
        continue

      # evaluate all leaves in one batch
      leaf_states = np.stack([spGames[mappingIdx].node.state for mappingIdx in expandable_spGames])
      policy, value = self.model(
          torch.tensor(self.game.get_encoded_state(leaf_states), device=self.model.device)
      )
      policy = torch.softmax(policy, axis=1).cpu().numpy()
      # move the values to the host once, so the tree stores plain floats
      value = value.cpu().numpy()

      for i, mappingIdx in enumerate(expandable_spGames):
        node = spGames[mappingIdx].node
        spg_policy, spg_value = policy[i], value[i].item()

        # mask out illegal moves
        valid_moves = self.game.get_valid_moves(node.state)
        spg_policy *= valid_moves
        spg_policy /= np.sum(spg_policy)

        # Expansion
        node.expand(spg_policy)
        node.backpropogate(spg_value)

In [18]:
class AlphaZeroParellel:
  """Self-play training loop that plays ``args['num_parallel_games']`` games at once."""

  def __init__(self, model, optimizer, game, args):
    self.model = model
    self.optimizer = optimizer
    self.game = game
    self.args = args
    self.mcts = AlphaMCTSParellel(game, args, model)

  def selfPlay(self):
    """Play a batch of games in lock-step and return all their training samples.

    Returns:
      A list of ``(encoded_state, action_probs, outcome)`` tuples, one per
      move of every game. ``outcome`` is the final game value from the point
      of view of the player who was to move in that position.
    """
    return_memory = [] # training data collected from all finished games
    player = 1
    spGames = [SPG(self.game) for spg in range(self.args['num_parallel_games'])]

    while len(spGames) > 0:
      states = np.stack([spg.state for spg in spGames])
      # the search always sees the boards as if player 1 were to move
      neutral_states = self.game.change_perspective(states, player)

      self.mcts.search(neutral_states, spGames)

      # iterate backwards so that deleting a finished game keeps the
      # remaining indices valid
      for i in range(len(spGames))[::-1]:
        spg = spGames[i]

        action_probs = np.zeros(self.game.action_size)
        for child in spg.root.children:
          action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs)

        spg.memory.append((spg.root.state, action_probs, player))

        # explore / exploit control
        temperature_action_probs = action_probs ** (1 / self.args['temperature'])
        temperature_action_probs /= np.sum(temperature_action_probs)
        action = np.random.choice(self.game.action_size, p=temperature_action_probs)

        spg.state = self.game.get_next_state(spg.state, action, player)

        value, is_terminated = self.game.get_value_and_terminated(spg.state, action)

        if is_terminated:
          for hist_neutral_state, hist_action_probs, hist_player in spg.memory:
            # value belongs to the player who made the last move; flip it for the other one
            hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
            return_memory.append((self.game.get_encoded_state(hist_neutral_state),
                            hist_action_probs,
                            hist_outcome))
          # Remove the finished game exactly once, after all of its samples
          # are collected (deleting inside the loop above removed one game per sample).
          del spGames[i]

      player = self.game.get_opponent(player)

    return return_memory

  def train(self, memory):
    """Run one optimisation pass over ``memory`` in mini-batches.

    ``memory`` is shuffled in place. Every sample is used exactly once; the
    last batch may be smaller than ``args['batch_size']``.
    """
    random.shuffle(memory)
    batch_size = self.args['batch_size']
    for batchIdx in range(0, len(memory), batch_size):
      # Slicing clamps at the end of the list on its own. Clamping to
      # len(memory) - 1 dropped the final sample and could yield an empty batch.
      sample = memory[batchIdx:batchIdx + batch_size]
      state, policy_targets, value_targets = zip(*sample)

      state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

      state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
      policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
      value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

      out_policy, out_value = self.model(state)

      # how is this the right loss, if policy_targets are also partially generated by the model
      policy_loss = F.cross_entropy(out_policy, policy_targets)
      value_loss = F.mse_loss(out_value, value_targets)
      loss = policy_loss + value_loss

      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

  def learn(self):
    """Alternate self-play and training for ``args['num_iterations']`` rounds.

    Each round runs ``num_selfPlay_iterations // num_parallel_games`` batched
    self-play calls. After every round the model and optimizer state dicts
    are saved as ``model_{iteration}_{game}.pt`` and
    ``optimizer_{iteration}_{game}.pt``.
    """
    for iteration in range(self.args['num_iterations']):
      memory = [] # training data for the model

      self.model.eval()
      for selfPlay_iteration in trange(self.args['num_selfPlay_iterations'] // self.args['num_parallel_games']):
        memory += self.selfPlay()

      self.model.train()
      for epoch in trange(self.args['num_epochs']):
        self.train(memory)

      torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
      torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")

In [16]:
class SPG: # Self Play Game
  """Per-game bookkeeping for one of the games played in parallel."""

  def __init__(self, game):
    # search tree root and the leaf waiting for expansion; both unset at first
    self.root = self.node = None
    # one entry per move played so far in this game
    self.memory = []
    # current board, starting from the game's initial position
    self.state = game.get_initial_state()

In [None]:
# Train the parallel AlphaZero on Connect Four.
game = ConnectFour()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet(game, num_resBlocks, num_hidden, device): 9 residual blocks, 128 hidden channels
model = ResNet(game, 9, 128, device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

args = {
    'C': 2, # exploration weight read by AlphaNode.get_ucb
    'num_searches': 600, # simulations per AlphaMCTSParellel.search call
    'num_iterations': 8, # self-play + training rounds in AlphaZeroParellel.learn
    'num_selfPlay_iterations': 500, # games generated per round
    'num_parallel_games': 250, # games per selfPlay call; learn runs 500 // 250 = 2 calls per round
    'num_epochs': 4, # passes over the collected games per round
    'batch_size': 128,
    'temperature': 1.25, # explore / exploit control
    'dirichlet_epsilon': 0.25, # noise params
    'dirichlet_alpha': 0.3
}

alphaZero = AlphaZeroParellel(model, optimizer, game, args)
alphaZero.learn()

  0%|          | 0/2 [00:00<?, ?it/s]