In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import math
from tqdm.notebook import trange
import random
import pickle
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [None]:
class Gomoku():
    """
    Rules engine for Gomoku (five in a row).

    Players are encoded as 1 and -1, empty cells as 0. An action is a flat
    cell index in row-major order (action = row * columns + column).
    The board defaults to 15x15 but any size may be requested. A player wins
    by lining up `in_a_row` (5) marks horizontally, vertically, or diagonally.
    """

    def __init__(self, rows=15, columns=15):
        self.action_size = rows * columns
        self.columns = columns
        self.rows = rows
        self.in_a_row = 5  # number of consecutive marks required to win

    def __repr__(self):
        return "Gomoku"

    def get_initial_state(self):
        """Return an empty (rows, columns) board filled with zeros."""
        return np.zeros((self.rows, self.columns))

    def get_next_state(self, state, action, player):
        """Write `player` into the cell addressed by `action`.

        The board is modified in place and also returned; callers that need
        the previous position must copy it first.
        """
        row, column = divmod(action, self.columns)
        state[row, column] = player
        return state

    def get_valid_moves(self, state):
        """Return a uint8 mask shaped like `state`: 1 where the cell is empty."""
        return (state == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return True if the mark placed by `action` completes a winning line.

        `action` is None for a position with no previous move (never a win).
        Only lines passing through the played cell are inspected.
        """
        if action is None:
            return False

        row, column = divmod(action, self.columns)
        player = state[row, column]
        if player == 0:
            # An empty cell cannot complete a line; without this guard a run
            # of empty cells would be counted as five matching "marks".
            return False

        def count(offset_row, offset_column):
            """Count consecutive `player` marks from the cell in one direction."""
            for i in range(1, self.in_a_row):
                r = row + offset_row * i
                c = column + offset_column * i
                if (
                    r < 0
                    or r >= self.rows
                    or c < 0
                    or c >= self.columns
                    or state[r, c] != player
                ):
                    return i - 1
            return self.in_a_row - 1

        # The played cell itself is not counted, hence `in_a_row - 1`.
        needed = self.in_a_row - 1
        return (
            (count(1, 0) + count(-1, 0)) >= needed  # vertical
            or (count(0, 1) + count(0, -1)) >= needed  # horizontal
            or (count(1, 1) + count(-1, -1)) >= needed  # top left diagonal
            or (count(1, -1) + count(-1, 1)) >= needed  # top right diagonal
        )

    def get_value_and_terminated(self, state, action):
        """Return (value, terminated): (1, True) on a win for the player who
        just moved, (0, True) on a full board, (0, False) otherwise."""
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player's id."""
        return -player

    def get_opponent_value(self, value):
        """Return `value` as seen by the opponent."""
        return -value

    def change_perspective(self, state, player):
        """Return the board multiplied by `player` (a new array), so that the
        given player's marks become +1."""
        return state * player

    def get_encoded_state(self, state):
        """One-hot encode a board into float32 planes (-1 marks, empty, +1 marks).

        A single (rows, columns) board becomes (3, rows, columns); a batch
        (n, rows, columns) becomes (n, 3, rows, columns).
        """
        encoded_state = np.stack(
            (state == -1, state == 0, state == 1)
        ).astype(np.float32)

        if state.ndim == 3:
            # stacking put the plane axis first; move the batch axis back to front
            encoded_state = np.swapaxes(encoded_state, 0, 1)

        return encoded_state

In [None]:
class ConnectFour():
    """
    Rules engine for Connect Four (default 6x7 board).

    Players are encoded as 1 and -1, empty cells as 0. An action is a column
    index; the piece drops to the lowest empty row of that column (row 0 is
    the top of the board). Four aligned pieces win.
    """

    def __init__(self, rows=6, columns=7):
        # Honour the requested board size (previously the arguments were
        # ignored and the board was always 6x7).
        self.columns = columns
        self.rows = rows
        self.action_size = self.columns
        self.in_a_row = 4  # number of consecutive marks required to win

    def __repr__(self):
        return "ConnectFour"

    def get_initial_state(self):
        """Return an empty (rows, columns) board filled with zeros."""
        return np.zeros((self.rows, self.columns))

    def get_next_state(self, state, action, player):
        """Drop `player`'s piece into column `action`.

        The board is modified in place and also returned. The column must
        still have an empty cell.
        """
        # largest row index that is still empty == lowest free slot
        row = np.max(np.where(state[:, action] == 0))
        state[row, action] = player
        return state

    def get_valid_moves(self, state):
        """Return a uint8 mask over columns: 1 where the top cell is empty."""
        return (state[0] == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return True if the top piece of column `action` completes a line.

        `action` is None for a position with no previous move (never a win).
        """
        if action is None:
            return False

        # The most recently dropped piece is the topmost non-empty cell.
        row = np.min(np.where(state[:, action] != 0))
        column = action
        player = state[row][column]

        def count(offset_row, offset_column):
            """Count consecutive `player` pieces from the cell in one direction."""
            for i in range(1, self.in_a_row):
                r = row + offset_row * i
                c = column + offset_column * i
                if (
                    r < 0
                    or r >= self.rows
                    or c < 0
                    or c >= self.columns
                    or state[r][c] != player
                ):
                    return i - 1
            return self.in_a_row - 1

        # The played cell itself is not counted, hence `in_a_row - 1`.
        needed = self.in_a_row - 1
        return (
            count(1, 0) >= needed  # vertical (nothing can sit above the top piece)
            or (count(0, 1) + count(0, -1)) >= needed  # horizontal
            or (count(1, 1) + count(-1, -1)) >= needed  # top left diagonal
            or (count(1, -1) + count(-1, 1)) >= needed  # top right diagonal
        )

    def get_value_and_terminated(self, state, action):
        """Return (value, terminated): (1, True) on a win for the player who
        just moved, (0, True) when no column is playable, (0, False) otherwise."""
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player's id."""
        return -player

    def get_opponent_value(self, value):
        """Return `value` as seen by the opponent."""
        return -value

    def change_perspective(self, state, player):
        """Return the board multiplied by `player` (a new array), so that the
        given player's pieces become +1."""
        return state * player

    def get_encoded_state(self, state):
        """One-hot encode a board into float32 planes (-1 pieces, empty, +1 pieces).

        A single (rows, columns) board becomes (3, rows, columns); a batch
        (n, rows, columns) becomes (n, 3, rows, columns).
        """
        encoded_state = np.stack(
            (state == -1, state == 0, state == 1)
        ).astype(np.float32)

        if state.ndim == 3:
            # stacking put the plane axis first; move the batch axis back to front
            encoded_state = np.swapaxes(encoded_state, 0, 1)

        return encoded_state

In [None]:
class Node:
    """
    One position in the Monte Carlo search tree used by AlphaZero.

    A node stores the board it represents (always from the point of view of
    the player about to move), a link to its parent and children, and the
    visit/value statistics gathered by the simulations that passed through it.

    Attributes:
    - game: The game object, providing methods to compute game states and transitions.
    - args: A dictionary of arguments/hyperparameters (reads 'C').
    - state: The current state of the game associated with this node.
    - parent: The parent node of this node. None if it is the root.
    - action_taken: The action that led to this node's state from the parent node.
    - prior: The prior probability of selecting this node.
    - children: A list of child nodes representing possible next states.
    - visit_count: The number of times this node has been visited during MCTS simulations.
    - value_sum: The cumulative value from simulations passing through this node.
    """

    def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game, self.args = game, args
        self.state = state
        self.parent, self.action_taken = parent, action_taken
        self.prior = prior

        self.children = []

        self.visit_count = visit_count
        self.value_sum = 0

    def is_fully_expanded(self):
        """A node counts as expanded as soon as it has at least one child."""
        return bool(self.children)

    def select(self):
        """Return the child with the highest UCB score (None if there are no children)."""
        champion, champion_score = None, -np.inf

        for candidate in self.children:
            score = self.get_ucb(candidate)
            if score > champion_score:
                champion, champion_score = candidate, score

        return champion

    def get_ucb(self, child):
        """Score `child` as exploitation term plus prior-weighted exploration bonus."""
        if child.visit_count:
            # The child's mean value is from the opponent's point of view:
            # rescale it from [-1, 1] to [0, 1] and invert it.
            exploitation = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        else:
            exploitation = 0
        exploration = self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior
        return exploitation + exploration

    def expand(self, policy):
        """Create one child for every action whose prior in `policy` is positive."""
        for action, prob in enumerate(policy):
            if not prob > 0:
                continue

            # Play the move as player 1 on a copy, then flip the board so the
            # child is again seen by the player to move.
            board = self.game.get_next_state(self.state.copy(), action, 1)
            board = self.game.change_perspective(board, player=-1)

            self.children.append(Node(self.game, self.args, board, self, action, prob))

    def backpropagate(self, value):
        """Add `value` to this node and, sign-flipped at each level, to all ancestors."""
        node = self
        while node is not None:
            node.value_sum += value
            node.visit_count += 1

            if node.parent is not None:
                value = node.game.get_opponent_value(value)
            node = node.parent

In [None]:
class AlphaZeroParallel:
    """
    AlphaZero training loop that plays several self-play games in lockstep.

    Attributes:
    - model: The neural network model used to predict policy and value outputs.
    - optimizer: The optimizer used for updating model weights.
    - policy_loss_fn: Loss function for the policy head of the model.
    - value_loss_fn: Loss function for the value head of the model.
    - game: An instance of the game class defining game mechanics and rules.
    - args: A dictionary of arguments/hyperparameters for training and MCTS.
    - verbose: A boolean for showing running loss averages during training.
    - mcts: An instance of the MCTSParallel class used for simulating moves.
    """

    def __init__(self, model, optimizer, policy_loss_fn, value_loss_fn, game, args, verbose=True):
        self.model = model
        self.optimizer = optimizer
        self.policy_loss_fn = policy_loss_fn
        self.value_loss_fn = value_loss_fn
        self.game = game
        self.args = args
        self.verbose = verbose
        self.mcts = MCTSParallel(game, args, model)

    def get_canonical_boards(self, state, action_probs, player):
        """
        state: np.ndarray shape of (rows, columns)
        action_probs: np.ndarray shape of (action_size)
        player: int

        returns:
        list of [state, action_probs, player] entries

        Does data augmentation with the board symmetries of the current game
        (`self.game`): rotations and flips for Gomoku, the left-right mirror
        for ConnectFour. Any other game gets the sample back unaugmented.
        """
        board = state.astype(np.int8)
        probs = action_probs.astype(np.float32)
        game_name = repr(self.game)

        # gomoku allows for more configurations of the board than connect four.
        if game_name == "Gomoku":
            rows, columns = board.shape
            # Quarter turns only map the board onto itself when it is square;
            # a rectangular board keeps the identity and the half turn.
            turns = range(4) if rows == columns else (0, 2)

            memory = []
            for k in turns:
                turned_board = np.rot90(board, k)
                turned_probs = np.rot90(probs.reshape(rows, columns), k)
                memory.extend([
                    [turned_board, turned_probs.reshape(-1), player],
                    [np.flip(turned_board, axis=0),
                     np.flip(turned_probs, axis=0).reshape(-1),
                     player],
                ])
            return memory

        if game_name == "ConnectFour":
            # Actions are columns, so mirroring the board reverses the policy.
            return [
                [board, probs, player],
                [np.flip(board, axis=1), np.flip(probs, axis=0), player],
            ]

        # Unknown symmetry group: keep the sample as it is.
        return [[board, probs, player]]

    def selfPlay(self):
        """
        returns:
        list of tuples (encoded_state, action_probs, outcome)

        Generates training data by simulating `args['num_parallel_games']`
        games at once using Monte Carlo Tree Search (MCTS). `outcome` is the
        final result from the point of view of the player to move in that
        sample.
        """

        return_memory = []
        player = 1
        active_games = [SPG(self.game) for _ in range(self.args['num_parallel_games'])]

        while active_games:

            # monitor progress
            print(f"{len(active_games)} parallel games left")

            states = np.stack([spg.state for spg in active_games])
            neutral_states = self.game.change_perspective(states, player)

            self.mcts.search(neutral_states, active_games)

            # Walk backwards so finished games can be deleted while iterating.
            for i in reversed(range(len(active_games))):
                spg = active_games[i]

                # Policy target: normalised visit counts of the root's children.
                action_probs = np.zeros(self.game.action_size)
                for child in spg.root.children:
                    action_probs[child.action_taken] = child.visit_count
                action_probs /= np.sum(action_probs)

                spg.memory.extend(self.get_canonical_boards(spg.root.state, action_probs, player))

                temperature_action_probs = action_probs ** (1 / self.args['temperature'])
                temperature_action_probs /= np.sum(temperature_action_probs)
                action = np.random.choice(self.game.action_size, p=temperature_action_probs)

                spg.state = self.game.get_next_state(spg.state, action, player)
                value, is_terminal = self.game.get_value_and_terminated(spg.state, action)

                if is_terminal:
                    # `value` is from the point of view of the player who just moved.
                    for hist_neutral_state, hist_action_probs, hist_player in spg.memory:
                        hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                        return_memory.append((
                            self.game.get_encoded_state(hist_neutral_state),
                            hist_action_probs,
                            hist_outcome
                        ))
                    del active_games[i]

            player = self.game.get_opponent(player)

        return return_memory

    def train(self, memory):
        """
        memory: list of tuples (state, action_probs, value)

        Runs one pass over `memory` (shuffled in place) in mini-batches of
        `args['batch_size']`, optimising a weighted sum of policy and value loss.
        """

        policy_losses = []
        value_losses = []
        # Read the weight from this instance's hyperparameters, not a module global.
        policy_weight = self.args['policy_value_bias']

        random.shuffle(memory)
        for batchIdx in range(0, len(memory), self.args['batch_size']):
            sample = memory[batchIdx:batchIdx + self.args['batch_size']]
            state, policy_targets, value_targets = zip(*sample)
            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            policy_loss = self.policy_loss_fn(out_policy, policy_targets)
            value_loss = self.value_loss_fn(out_value, value_targets)
            loss = policy_weight * policy_loss + (1 - policy_weight) * value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            policy_losses.append(policy_loss.item())
            value_losses.append(value_loss.item())

            # Visualize loss
            # NOTE(review): batchIdx is a sample offset, not a batch number, so
            # this does not fire "every 50 batches" - confirm the intended cadence.
            if self.verbose and batchIdx % 50 == 0:
                print(f"batch {batchIdx // self.args['batch_size']}: policy_loss mean: {np.mean(policy_losses):.3f} | value loss mean: {np.mean(value_losses):.3f}")

    def learn(self):
        """
        Iterates between selfplay and model training.

        Each iteration pickles the generated samples and saves the model and
        optimizer state dicts to the current working directory.
        """

        for iteration in range(self.args['num_iterations']):
            memory = []

            # generate games
            self.model.eval()
            for selfPlay_iteration in range(self.args['num_selfPlay_iterations'] // self.args['num_parallel_games']):
                memory.extend(self.selfPlay())

            if memory:
                with open(f"iteration_{iteration}_memory.plk", "wb") as file:
                    pickle.dump(memory, file)

            # train the model
            self.model.train()
            for epoch in range(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")

            if torch.cuda.is_available():
                torch.cuda.empty_cache()


class SPG:
    """
    'SelfPlayGame': the per-game record kept while games are played in parallel.

    Holds the current board, the samples collected so far, and two search-tree
    slots (`root`, `node`) that start out empty.
    """

    def __init__(self, game):
        # tree slots are empty until a search assigns them
        self.root = self.node = None
        self.memory = list()
        self.state = game.get_initial_state()

In [None]:
class MCTSParallel:
    """
    Performs Monte Carlo Tree Search (MCTS) for several games at once,
    batching the network evaluations across games.

    Attributes:
    - game: An instance of the game class defining game mechanics and rules.
    - args: A dictionary of arguments/hyperparameters for training and MCTS.
      Reads 'dirichlet_epsilon', 'dirichlet_alpha' and the simulation count
      under 'num_mcts_searches' (the legacy spelling 'num_mstc_searches' is
      still accepted).
    - model: The neural network model used to predict policy and value outputs.
    """
    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model

    def _num_searches(self):
        """Return the simulation count, accepting both spellings of the key."""
        if 'num_mcts_searches' in self.args:
            return self.args['num_mcts_searches']
        return self.args['num_mstc_searches']

    @torch.no_grad()
    def _predict(self, states):
        """Run the network on a batch of boards; return (softmax policy array, value tensor)."""
        encoded = torch.tensor(self.game.get_encoded_state(states), device=self.model.device)
        logits, value = self.model(encoded)
        return torch.softmax(logits, dim=1).cpu().numpy(), value

    def _mask_and_normalize(self, policy, states):
        """Zero the priors of illegal moves and renormalise each row to sum to 1.

        The mask comes from the game object itself, so this works for any game
        that implements `get_valid_moves`, not only those known by name.
        """
        valid_moves = np.stack([self.game.get_valid_moves(state).reshape(-1) for state in states])
        policy = policy * valid_moves
        return policy / np.sum(policy, axis=1, keepdims=True)

    @torch.no_grad()
    def search(self, states, spGames):
        """
        states: np.ndarray shape of (parallelgames, rows, columns)
        spGames: list of SPG objects

        Performs MCTS simulations in parallel. Nothing is returned: each
        spGame's `root` is replaced by a freshly searched tree.
        """

        # get predictions (policy) and add noise
        policy, _ = self._predict(states)
        epsilon = self.args['dirichlet_epsilon']
        noise = np.random.dirichlet(
            [self.args['dirichlet_alpha']] * self.game.action_size, size=policy.shape[0]
        )
        policy = (1 - epsilon) * policy + epsilon * noise

        # normalize and mask invalid moves.
        policy = self._mask_and_normalize(policy, states)

        # initialize root nodes
        for i, spg in enumerate(spGames):
            spg.root = Node(self.game, self.args, states[i], visit_count=1)
            spg.root.expand(policy[i])

        for _ in range(self._num_searches()):
            # Selection: descend every tree to a leaf.
            for spg in spGames:
                spg.node = None
                node = spg.root

                while node.is_fully_expanded():
                    node = node.select()

                value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
                # The result belongs to the player who moved into this node.
                value = self.game.get_opponent_value(value)

                if is_terminal:
                    node.backpropagate(value)
                else:
                    # Remember the leaf; it is evaluated in one batch below.
                    spg.node = node

            expandable_spGames = [idx for idx, spg in enumerate(spGames) if spg.node is not None]
            if not expandable_spGames:
                continue

            # Expansion and evaluation: one network call for all pending leaves.
            leaf_states = np.stack([spGames[idx].node.state for idx in expandable_spGames])
            leaf_policy, leaf_value = self._predict(leaf_states)
            leaf_policy = self._mask_and_normalize(leaf_policy, leaf_states)
            leaf_value = leaf_value.cpu().numpy()

            for i, idx in enumerate(expandable_spGames):
                node = spGames[idx].node
                node.expand(leaf_policy[i])
                node.backpropagate(leaf_value[i])

In [None]:
class ResNet(nn.Module):
  """
  Policy/value network: a convolutional stem, a tower of residual blocks and
  two heads.

  Input: game state in shape (3, rows, columns). One dimension for player 1 pieces,
         one for player 2 pieces, and one for empty spaces.
  Output: policy in shape (action_size) and value in shape (1)
  """

  def __init__(self, game, num_resBlocks, num_hidden, device):
    super().__init__()

    self.device = device
    board_cells = game.rows * game.columns

    # Stem: lift the 3 input planes to `num_hidden` feature maps
    stem_layers = [
        nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
        nn.BatchNorm2d(num_hidden),
        nn.ReLU(),
    ]
    self.startBlock = nn.Sequential(*stem_layers)

    # Tower of residual blocks
    self.backBone = nn.ModuleList(
        ResBlock(num_hidden) for _ in range(num_resBlocks)
    )

    # Policy head: one raw score (logit) per action
    policy_layers = [
        nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(32 * board_cells, game.action_size),
    ]
    self.policyHead = nn.Sequential(*policy_layers)

    # Value head: a single tanh-squashed scalar in [-1, 1]
    value_layers = [
        nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
        nn.BatchNorm2d(3),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(3 * board_cells, 1),
        nn.Tanh(),
    ]
    self.valueHead = nn.Sequential(*value_layers)

    self.to(device)

  def forward(self, x):
    features = self.startBlock(x)
    for block in self.backBone:
      features = block(features)

    return self.policyHead(features), self.valueHead(features)


class ResBlock(nn.Module):
  """
  Two 3x3 conv + batch-norm stages whose output is added back onto the
  block's input (skip connection) before the final ReLU.
  """

  def __init__(self, num_hidden):
     super().__init__()
     conv_options = dict(kernel_size=3, padding=1)  # padding keeps the spatial size
     self.conv1 = nn.Conv2d(num_hidden, num_hidden, **conv_options)
     self.bn1 = nn.BatchNorm2d(num_hidden)
     self.conv2 = nn.Conv2d(num_hidden, num_hidden, **conv_options)
     self.bn2 = nn.BatchNorm2d(num_hidden)
     self.relu = nn.ReLU()

  def forward(self, x):
    out = self.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    out += x  # skip connection; the caller's tensor is left untouched

    return self.relu(out)

In [None]:
# Training cell: builds a Gomoku network and runs the AlphaZero self-play /
# training loop. Executing this cell starts training immediately.
# NOTE: `game` and `args` are module-level names; other cells read them as globals.
game = Gomoku()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet(game, num_resBlocks=9, num_hidden=128, device=device)
# Optional warm start. NOTE(review): the checkpoint names refer to ConnectFour
# while `game` is Gomoku - confirm the file before uncommenting.
#model.load_state_dict(torch.load("model_3_ConnectFour.pt", map_location=device))

optim = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
#optim.load_state_dict(torch.load("optimizer_3_ConnectFour.pt", map_location=device))

# Value head is regressed with MSE; policy head is trained with cross-entropy.
value_loss_fn = nn.MSELoss()
policy_loss_fn = nn.CrossEntropyLoss()

# hyperparameters
args = {
    'C': 2,                          # UCB exploration constant
    'num_mstc_searches': 1,        # Number of MCTS simulations per move (key spelling as read by MCTSParallel.search)
    'num_iterations': 4,             # Training iterations (self-play followed by training)
    'num_selfPlay_iterations': 2,  # Self-play games per iteration (integer-divided by num_parallel_games to get the number of rounds)
    'num_parallel_games': 2,       # Number of games to play in parallel
    'num_epochs': 4,                 # Training epochs per iteration
    'batch_size': 128,              # Training batch size
    'temperature': 1,                # Temperature for action selection
    'policy_value_bias': 0.5,        # Weight of the policy loss; the value loss gets 1 - this
    'dirichlet_epsilon': 0.15,       # Exploration noise weight
    'dirichlet_alpha': 0.15         # Dirichlet distribution parameter
}

alphaZero = AlphaZeroParallel(model, optim, policy_loss_fn, value_loss_fn, game, args)
# Saves a samples pickle plus model/optimizer checkpoints every iteration.
alphaZero.learn()

In [None]:
# Evaluation cell: two ConnectFour networks play one game against each other,
# each choosing moves with its own (non-parallel) MCTS.
# NOTE: this rebinds the module-level `game` and `args`, and it needs the cell
# that defines `MCTS` to have been executed first.
game = ConnectFour()
player = 1

# Only the keys read by MCTS.search / Node.get_ucb are needed here.
args = {
    'C': 0.1,
    'num_mcts_searches': 50,
    'dirichlet_epsilon': 0.05,
    'dirichlet_alpha': 0.1
}

# With the load_state_dict lines commented out both networks keep their
# random initial weights.
model1 = ResNet(game, 9, 128, device)
#model1.load_state_dict(torch.load("model_3_ConnectFour (1).pt", map_location=device))
model1.eval()

model2 = ResNet(game, 9, 128, device)
#model2.load_state_dict(torch.load("model_7_ConnectFour (1).pt", map_location=device))
model2.eval()

mcts1 = MCTS(game, args, model1)
mcts2 = MCTS(game, args, model2)

state = game.get_initial_state()


while True:
    print(state)

    # model1 moves as player 1, model2 as player -1. Each search sees the board
    # from the mover's perspective and returns
    # (action_probs, root value estimate, value of the last simulation).
    if player == 1:
        neutral_state1 = game.change_perspective(state, player)
        mcts_probs1, pos_value1, value1 = mcts1.search(neutral_state1)
        # play greedily: the move with the largest share of visits
        action = np.argmax(mcts_probs1)

    else:
        neutral_state = game.change_perspective(state, player)
        mcts_probs, pos_value, value = mcts2.search(neutral_state)
        action = np.argmax(mcts_probs)

    # get_next_state mutates `state` in place and returns it
    state = game.get_next_state(state, action, player)

    value, is_terminal = game.get_value_and_terminated(state, action)

    if is_terminal:
        print(state)
        # value == 1 means the player who just moved completed a line
        if value == 1:
            print(player, "won")
        else:
            print("draw")
        break

    player = game.get_opponent(player)

In [None]:
class TicTacToe():
    """
    Rules engine for Tic-Tac-Toe (default 3x3 board).

    Players are encoded as 1 and -1, empty cells as 0. An action is a flat
    cell index in row-major order (action = row * columns + column). A player
    wins by filling an entire row, column, or diagonal.
    """

    def __init__(self, rows=3, columns=3):
        self.action_size = rows * columns
        self.columns = columns
        self.rows = rows

    def __repr__(self):
        return "TicTacToe"

    def get_initial_state(self):
        """Return an empty (rows, columns) board filled with zeros."""
        return np.zeros((self.rows, self.columns))

    def get_next_state(self, state, action, player):
        """Write `player` into the cell addressed by `action`.

        The board is modified in place and also returned.
        """
        row, column = divmod(action, self.columns)
        state[row, column] = player
        return state

    def get_valid_moves(self, state):
        """Return a flat uint8 mask of length rows * columns: 1 where the cell is empty."""
        return (state.reshape(-1) == 0).astype(np.uint8)

    def canonical_boards(self):
        """Placeholder; no board augmentation is implemented (returns None)."""
        pass

    def check_win(self, state, action):
        """Return True if the mark placed by `action` completes a full line.

        `action` is None for a position with no previous move (never a win).
        """
        if action is None:
            return False

        row, column = divmod(action, self.columns)
        player = state[row, column]
        if player == 0:
            # An empty cell cannot win; without this guard an empty line
            # (sum 0) would equal `0 * n` and be reported as a win.
            return False

        # A line is complete when its sum equals the player's mark times its length.
        return (
            np.sum(state[row, :]) == player * self.columns
            or np.sum(state[:, column]) == player * self.rows
            or np.sum(np.diag(state)) == player * self.rows
            or np.sum(np.diag(np.flip(state, axis=0))) == player * self.rows
        )

    def get_value_and_terminated(self, state, action):
        """Return (value, terminated): (1, True) on a win for the player who
        just moved, (0, True) on a full board, (0, False) otherwise."""
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player's id."""
        return -player

    def get_opponent_value(self, value):
        """Return `value` as seen by the opponent."""
        return -value

    def change_perspective(self, state, player):
        """Return the board multiplied by `player` (a new array), so that the
        given player's marks become +1."""
        return state * player

    def get_encoded_state(self, state):
        """One-hot encode a board into float32 planes (-1 marks, empty, +1 marks).

        A single (rows, columns) board becomes (3, rows, columns); a batch
        (n, rows, columns) becomes (n, 3, rows, columns).
        """
        encoded_state = np.stack(
            (state == -1, state == 0, state == 1)
        ).astype(np.float32)

        if state.ndim == 3:
            # stacking put the plane axis first; move the batch axis back to front
            encoded_state = np.swapaxes(encoded_state, 0, 1)

        return encoded_state

In [None]:
class AlphaZero:
    """
    Single-game (non-parallel) AlphaZero training loop.

    Attributes:
    - model: The neural network model used to predict policy and value outputs.
    - optimizer: The optimizer used for updating model weights.
    - policy_loss_fn: Loss function for the policy head of the model.
    - value_loss_fn: Loss function for the value head of the model.
    - game: An instance of the game class defining game mechanics and rules.
    - args: A dictionary of arguments/hyperparameters for training and MCTS.
    - mcts: An instance of the MCTS class used for choosing moves.
    """

    def __init__(self, model, optimizer, policy_loss_fn, value_loss_fn, game, args):
        self.model = model
        self.optimizer = optimizer
        self.policy_loss_fn = policy_loss_fn
        self.value_loss_fn = value_loss_fn
        self.game = game
        self.args = args
        self.mcts = MCTS(game, args, model)

    def selfPlay(self):
        """
        Play one game against itself and return the training samples.

        returns:
        list of tuples (encoded_state, action_probs, outcome), where `outcome`
        is the final result from the point of view of the player to move.
        """
        memory = []
        player = 1
        state = self.game.get_initial_state()

        while True:
            neutral_state = self.game.change_perspective(state, player)

            # The search may return the probabilities alone or as the first
            # element of a tuple with extra diagnostics; only the
            # probabilities are needed here.
            search_result = self.mcts.search(neutral_state)
            action_probs = search_result[0] if isinstance(search_result, tuple) else search_result

            memory.append((neutral_state, action_probs, player))

            # Sharpen/flatten with the temperature, then renormalise so the
            # vector is a valid distribution for np.random.choice.
            temperature_action_probs = action_probs ** (1 / self.args['temperature'])
            temperature_action_probs /= np.sum(temperature_action_probs)
            action = np.random.choice(self.game.action_size, p=temperature_action_probs)

            state = self.game.get_next_state(state, action, player)

            value, is_terminal = self.game.get_value_and_terminated(state, action)

            if is_terminal:
                # `value` is from the point of view of the player who just moved.
                returnMemory = []
                for hist_neutral_state, hist_action_probs, hist_player in memory:
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    returnMemory.append((
                        self.game.get_encoded_state(hist_neutral_state),
                        hist_action_probs,
                        hist_outcome
                    ))
                return returnMemory

            # Hand the move to the other side (must run on every non-terminal turn).
            player = self.game.get_opponent(player)

    def train(self, memory):
        """
        memory: list of tuples (state, action_probs, value)

        Runs one pass over `memory` (shuffled in place) in mini-batches of
        `args['batch_size']`, optimising policy loss + value loss.
        """
        random.shuffle(memory)
        for batchIdx in range(0, len(memory), self.args['batch_size']):
            # Plain slicing already clamps at the end of the list, so no sample is dropped.
            sample = memory[batchIdx:batchIdx + self.args['batch_size']]
            state, policy_targets, value_targets = zip(*sample)

            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)
            policy_loss = self.policy_loss_fn(out_policy, policy_targets)
            value_loss = self.value_loss_fn(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """
        Iterates between selfplay and model training, saving the model and
        optimizer state dicts to the working directory after every iteration.
        """
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for selfPlay_iteration in range(self.args['num_selfPlay_iterations']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in range(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")

In [None]:
class MCTS:
    """
    Single-game Monte Carlo Tree Search (no batching across games), used for inference.

    Attributes:
    - game: An instance of the game class defining game mechanics and rules.
    - args: A dictionary of arguments/hyperparameters for training and MCTS.
      Reads 'dirichlet_epsilon', 'dirichlet_alpha' and 'num_mcts_searches'.
    - model: The neural network model used to predict policy and value outputs.
    """

    def __init__(self, game, args, model):
        self.game, self.args, self.model = game, args, model

    def _evaluate(self, board):
        """Run the network on one board; return (softmax policy array, value tensor)."""
        batch = torch.tensor(self.game.get_encoded_state(board), device=self.model.device).unsqueeze(0)
        logits, value = self.model(batch)
        return torch.softmax(logits, axis=1).squeeze(0).cpu().numpy(), value

    def _keep_legal(self, priors, board):
        """Zero the priors of illegal moves in place and renormalise to sum to 1."""
        priors *= self.game.get_valid_moves(board).reshape(-1)
        priors /= np.sum(priors)
        return priors

    @torch.no_grad()
    def search(self, state):
        """
        state: board seen from the perspective of the player to move

        returns:
        tuple (action_probs, pos_value, value): the normalised visit counts of
        the root's children, the network's value for `state`, and the value
        backed up by the final simulation.
        """
        root = Node(self.game, self.args, state, visit_count=1)

        # Root priors get Dirichlet noise mixed in before masking.
        priors, pos_value = self._evaluate(state)
        priors = (1 - self.args['dirichlet_epsilon']) * priors + self.args['dirichlet_epsilon'] \
            * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
        root.expand(self._keep_legal(priors, state))

        for _ in range(self.args['num_mcts_searches']):
            # Selection: follow the best UCB child down to a leaf.
            leaf = root
            while leaf.is_fully_expanded():
                leaf = leaf.select()

            value, is_terminal = self.game.get_value_and_terminated(leaf.state, leaf.action_taken)
            # The result belongs to the player who moved into the leaf.
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                # Expansion: let the network rate the leaf and supply priors.
                leaf_priors, leaf_value = self._evaluate(leaf.state)
                leaf_priors = self._keep_legal(leaf_priors, leaf.state)
                value = leaf_value.item()
                leaf.expand(leaf_priors)

            leaf.backpropagate(value)

        visit_share = np.zeros(self.game.action_size)
        for child in root.children:
            visit_share[child.action_taken] = child.visit_count
        visit_share /= np.sum(visit_share)
        return visit_share, pos_value.item(), value