In [33]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
from tqdm.notebook import trange

import random
import math
import chess

In [41]:

class ChessGame:
    """Chess rules adapter (backed by python-chess) used by MCTS and AlphaZero.

    Actions are indexed as ``from_square * 64 + to_square`` (0..4095). The
    promotion piece is not part of the index, so the promotions of one pawn
    move share an index; decoding an index prefers the queen promotion.
    """

    def __init__(self):
        # 64 origin squares x 64 destination squares.
        self.action_size = 4096

    def get_action_size(self):
        """Return the length of the action / policy vector."""
        return self.action_size

    def get_initial_state(self):
        """Return a new board in the standard starting position."""
        return chess.Board()

    def get_next_state(self, state, action, player=None):
        """Return a copy of ``state`` with ``action`` played; ``state`` is unchanged.

        ``action`` may be a ``chess.Move`` or an integer action index (decoded
        with ``get_move_from_index``). ``player`` is unused because the board
        tracks the side to move; it is optional so callers may omit it.
        """
        next_state = state.copy()
        next_state.push(self._to_move(state, action))
        return next_state

    def get_move_from_index(self, state, action_index):
        """Return the legal move of ``state`` encoded by ``action_index``.

        Raises:
            ValueError: if no legal move of ``state`` maps to ``action_index``.
        """
        from_square, to_square = divmod(int(action_index), 64)
        candidates = [
            move for move in state.legal_moves
            if move.from_square == from_square and move.to_square == to_square
        ]
        if not candidates:
            raise ValueError(f"no legal move for action index {action_index}")
        # Only promotions share an index; the queen has the highest piece type.
        return max(candidates, key=lambda move: move.promotion or 0)

    def _to_move(self, state, action):
        """Normalize ``action`` (a ``chess.Move`` or an action index) to a move."""
        if isinstance(action, chess.Move):
            return action
        return self.get_move_from_index(state, action)

    def get_valid_moves(self, state):
        """Return a 0/1 float mask of length ``action_size`` marking legal actions."""
        valid_moves = np.zeros(self.action_size)
        for move in state.legal_moves:
            valid_moves[self.get_action_index(move)] = 1
        return valid_moves

    def get_value_and_terminated(self, state, action):
        """Play ``action`` on a copy of ``state`` and score the resulting position.

        Returns:
            (value, terminated): value is from White's point of view -- 1 if
            White wins, -1 if Black wins, 0 for a draw or an unfinished game.
        """
        next_state = self.get_next_state(state, action)
        if next_state.is_checkmate():
            # The side to move in `state` is the one that just delivered mate.
            return (1, True) if state.turn == chess.WHITE else (-1, True)
        if next_state.is_game_over():
            # Any other end: stalemate, insufficient material, 75-move rule,
            # fivefold repetition.
            return 0, True
        return 0, False

    def is_game_over(self, state):
        """Return True if ``state`` is a finished game (mate or automatic draw)."""
        return state.is_game_over()

    def get_opponent(self, player):
        """Return the other color (players are python-chess booleans)."""
        return not player

    def get_opponent_value(self, value):
        """Negate a value to express it from the opponent's perspective."""
        return -value

    def change_perspective(self, state, player):
        """Return the board as seen by ``player``: unchanged for White, mirrored for Black."""
        if player:
            return state
        else:
            return state.mirror()

    def get_encoded_state(self, state):
        """Encode ``state`` as a (12, 8, 8) float32 one-hot array.

        Plane = ``get_piece_index(piece)``; the remaining axes are [rank][file].
        """
        encoded_state = np.zeros((12, 8, 8), dtype=np.float32)
        for square, piece in state.piece_map().items():
            piece_index = self.get_piece_index(piece)
            rank = chess.square_rank(square)
            file = chess.square_file(square)
            encoded_state[piece_index][rank][file] = 1
        return encoded_state

    def get_piece_index(self, piece):
        """Return the encoding plane for ``piece``: 0-5 for White, 6-11 for Black.

        Within each color the order is pawn, knight, bishop, rook, queen, king.
        """
        # python-chess piece types are numbered PAWN=1 ... KING=6.
        color_offset = 0 if piece.color else 6
        return piece.piece_type - 1 + color_offset

    def get_action_index(self, move):
        """Return ``from_square * 64 + to_square`` for ``move``.

        The promotion piece is deliberately ignored so every index stays inside
        ``action_size`` and never collides with a different (from, to) pair.
        """
        return move.from_square * 64 + move.to_square


In [42]:
class ResNetChess(nn.Module):
    """AlphaZero-style policy/value network over the 12-plane 8x8 board encoding.

    Args:
        game: object exposing ``action_size`` (length of the policy output).
        num_resBlocks: number of residual blocks in the trunk.
        num_hidden: channel width of the trunk (also the value head's hidden size).
        device: device the parameters are placed on; kept as ``self.device`` so
            callers can build input tensors on the same device.
    """

    def __init__(self, game, num_resBlocks, num_hidden, device):
        super().__init__()

        self.device = device
        self.startBlock = nn.Sequential(
            nn.Conv2d(12, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU()
        )

        self.backBone = nn.ModuleList(
            [ResBlock(num_hidden) for _ in range(num_resBlocks)]
        )

        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, game.action_size)
        )

        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(8 * 8, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, 1),
            nn.Tanh()
        )

        # Put the parameters on `device`; inputs are created on self.device, and
        # a CPU model fed CUDA tensors (or vice versa) fails at the first conv.
        self.to(device)

    def forward(self, x):
        """Return ``(policy, value)`` for a batch of encoded boards.

        ``x`` has shape (batch, 12, 8, 8). ``policy`` is (batch, action_size) and
        already softmax-normalized; ``value`` is (batch, 1) in [-1, 1] (tanh).
        """
        out = self.startBlock(x)
        for block in self.backBone:
            out = block(out)
        policy = self.policyHead(out)
        value = self.valueHead(out)
        return F.softmax(policy, dim=1), value

class ResBlock(nn.Module):
    """Residual unit: conv-BN-ReLU-conv-BN whose output is added back to the input.

    Both convolutions are 3x3 with padding 1, so spatial size and channel count
    (``num_hidden``) are preserved and the skip addition is shape-compatible.
    """

    def __init__(self, num_hidden):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)

        def batch_norm():
            return nn.BatchNorm2d(num_hidden)

        # Layers are created in the same order as a literal Sequential would,
        # keeping parameter names (block.0, block.1, ...) and init order stable.
        layers = [conv3x3(), batch_norm(), nn.ReLU(), conv3x3(), batch_norm()]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual


In [43]:
import math
import numpy as np

class Node:
    """A search-tree node holding a board position and running visit statistics.

    ``action`` is the move that led here from ``parent``; ``value`` accumulates
    the values passed to ``update`` and ``visit_count`` counts those calls.
    """

    def __init__(self, state, action=None, parent=None):
        self.state = state
        self.action = action
        self.parent = parent
        self.children = []
        self.value = 0
        self.visit_count = 0
        # Lazily computed number of legal moves in `state` (the state never changes).
        self._num_legal_moves = None

    def is_fully_expanded(self):
        """Return True when every legal move of ``state`` has a child node.

        A position with no legal moves returns False: there is no child to
        descend into, so callers must treat it as a leaf.
        """
        if self._num_legal_moves is None:
            # python-chess's legal_moves generator is iterable but does not
            # support len(); counting by iteration works for any iterable.
            self._num_legal_moves = sum(1 for _ in self.state.legal_moves)
        return self._num_legal_moves > 0 and len(self.children) == self._num_legal_moves

    def add_child(self, child):
        """Attach ``child`` to this node."""
        self.children.append(child)

    def update(self, value):
        """Accumulate one backed-up ``value`` and count the visit."""
        self.value += value
        self.visit_count += 1

    def get_ucb_score(self, exploration_constant):
        """UCB1 score: mean value plus an exploration bonus; inf when unvisited.

        Must only be called on a node with a parent.
        """
        if self.visit_count == 0:
            return math.inf
        return (self.value / self.visit_count) + exploration_constant * math.sqrt(
            math.log(self.parent.visit_count) / self.visit_count
        )

    def select_best_child(self, exploration_constant):
        """Return the child with the highest UCB score (first one on ties), or None."""
        return max(
            self.children,
            key=lambda child: child.get_ucb_score(exploration_constant),
            default=None,
        )

class MCTS:
    """UCT Monte Carlo tree search driven by uniformly random rollouts.

    ``model`` is stored but not consulted: node values come only from random
    playouts. ``args`` must provide ``"num_searches"`` and
    ``"exploration_constant"``. Each node's ``value`` is scored for the player
    who made the move leading to it, which is what the parent maximizes when
    selecting a child by UCB.
    """

    def __init__(self, game, model, args):
        self.game = game
        self.model = model
        self.args = args

    def search(self, state, player=None):
        """Run ``args["num_searches"]`` iterations from ``state`` and pick a move.

        ``player`` is accepted for interface compatibility and ignored; the side
        to move is read from ``state.turn``.

        Returns:
            The move (``child.action``) of the most-visited root child.
        Raises:
            ValueError: if nothing was expanded (terminal ``state`` or
                ``num_searches`` < 1).
        """
        root = Node(state)
        for _ in range(self.args["num_searches"]):
            leaf = self._traverse(root)
            value = self._simulate(leaf.state)
            self._backpropagate(leaf, value)
        if not root.children:
            raise ValueError("MCTS expanded no moves: state is terminal or num_searches < 1")
        # Visit count is the standard final choice; a mean value from a single
        # lucky rollout would otherwise dominate.
        best_child = max(root.children, key=lambda child: child.visit_count)
        return best_child.action

    def _traverse(self, node):
        """Descend by UCB through fully expanded nodes, then expand one new child.

        Returns the new child, or the reached node itself when it is game over.
        """
        exploration = self.args["exploration_constant"]
        # The `children` check stops at nodes with no child to select
        # (e.g. terminal positions), where select_best_child would return None.
        while node.children and node.is_fully_expanded():
            node = node.select_best_child(exploration)
        if node.state.is_game_over():
            return node
        action = self._select_unexplored_action(node)
        next_state = self.game.get_next_state(node.state, action, node.state.turn)
        child_node = Node(next_state, action, node)
        node.add_child(child_node)
        return child_node

    def _select_unexplored_action(self, node):
        """Return a random legal move of ``node.state`` that has no child yet."""
        explored = {child.action for child in node.children}
        unexplored = [move for move in node.state.legal_moves if move not in explored]
        return random.choice(unexplored)

    def _simulate(self, state):
        """Play random moves from ``state`` to the end of the game.

        Returns the result for the player who moved INTO ``state`` (the opponent
        of ``state.turn``): 1 win, -1 loss, 0 draw.
        """
        perspective = state.turn
        while not state.is_game_over():
            move = self._select_random_action(state)
            state = self.game.get_next_state(state, move, state.turn)
        if state.is_checkmate():
            # In a checkmate, the side to move has lost.
            result = -1 if state.turn == perspective else 1
        else:
            result = 0
        return self.game.get_opponent_value(result)

    def _select_random_action(self, state):
        """Return a uniformly random legal move of ``state``."""
        return random.choice(list(state.legal_moves))

    def _backpropagate(self, node, value):
        """Add ``value`` to each node up to the root, flipping its sign every ply."""
        while node is not None:
            node.update(value)
            value = self.game.get_opponent_value(value)
            node = node.parent


In [44]:
class AlphaZeroChess:
    """Self-play and training loop around an MCTS player and a policy/value network.

    ``args`` keys read here: ``num_iterations``, ``num_self_play_iterations``,
    ``num_epochs`` and ``batch_size`` (plus whatever MCTS reads). The model must
    expose ``device`` and return ``(softmax policy, value)`` from its forward.
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        # MCTS takes (game, model, args); keywords guard against argument-order slips.
        self.mcts = MCTS(game=game, model=model, args=args)

    def _policy_target(self, move, player):
        """One-hot policy vector for ``move`` in the ``change_perspective`` frame.

        For Black the board is mirrored before encoding, so the move's squares
        are mirrored too to keep state and target in the same frame.
        """
        if not player:
            move = chess.Move(
                chess.square_mirror(move.from_square),
                chess.square_mirror(move.to_square),
                move.promotion,
            )
        target = np.zeros(self.game.action_size, dtype=np.float32)
        target[self.game.get_action_index(move)] = 1.0
        return target

    @staticmethod
    def _final_value(state, player):
        """Result of finished game ``state`` for ``player``: 1 win, -1 loss, 0 draw."""
        if not state.is_checkmate():
            return 0.0
        # In a checkmate the side to move has lost.
        return -1.0 if state.turn == player else 1.0

    def self_play(self):
        """Play one game with MCTS and return training samples.

        Returns a list of ``(encoded_state, policy_target, value_target)``:
        the board from the mover's perspective (``change_perspective`` then
        ``get_encoded_state``), a one-hot target on the move MCTS chose, and the
        final game result for the mover. ``args['temperature']`` is not used:
        ``MCTS.search`` returns a single move rather than a distribution.
        """
        history = []
        state = self.game.get_initial_state()
        player = chess.WHITE
        while not self.game.is_game_over(state):
            neutral_state = self.game.change_perspective(state, player)
            # Search the real (unmirrored) board so the move can be pushed as-is.
            move = self.mcts.search(state, player)
            history.append((
                self.game.get_encoded_state(neutral_state),
                self._policy_target(move, player),
                player,
            ))
            state = self.game.get_next_state(state, move, player)
            player = self.game.get_opponent(player)

        return [
            (encoded, policy, self._final_value(state, mover))
            for encoded, policy, mover in history
        ]

    @staticmethod
    def _loss(out_policy, out_value, policy_targets, value_targets):
        """Policy cross-entropy plus value MSE.

        ``out_policy`` is already softmax-normalized by the model, so the
        cross-entropy uses its log directly; ``F.cross_entropy`` expects logits
        and would apply a second softmax.
        """
        log_policy = torch.log(out_policy.clamp_min(1e-8))
        policy_loss = -(policy_targets * log_policy).sum(dim=1).mean()
        value_loss = F.mse_loss(out_value, value_targets)
        return policy_loss + value_loss

    def train(self, memory):
        """Run one epoch of minibatch updates over ``memory`` (shuffled in place)."""
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        device = self.model.device
        for start in range(0, len(memory), batch_size):
            batch = memory[start:start + batch_size]
            states, policy_targets, value_targets = zip(*batch)

            states = torch.tensor(np.array(states), dtype=torch.float32, device=device)
            policy_targets = torch.tensor(np.array(policy_targets), dtype=torch.float32, device=device)
            value_targets = torch.tensor(
                np.array(value_targets).reshape(-1, 1), dtype=torch.float32, device=device
            )

            self.optimizer.zero_grad()
            out_policy, out_value = self.model(states)
            loss = self._loss(out_policy, out_value, policy_targets, value_targets)
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training; checkpoint model and optimizer each iteration."""
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for self_play_iter in trange(self.args['num_self_play_iterations']):
                memory += self.self_play()

            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}.pt")


In [45]:
# Build the game, network and optimizer, then run the self-play / training loop.
chess_game = ChessGame()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training builds its input tensors on model.device, so the parameters must
# live there too; `.to` returns the same module (its `device` attribute is kept).
model = ResNetChess(chess_game, 4, 64, device).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

args = {
    'C': 2,                            # not read by any code in this notebook
    'num_searches': 60,                # MCTS iterations per move
    'num_iterations': 3,               # outer self-play/train cycles
    'num_self_play_iterations': 500,   # games per cycle
    'num_epochs': 4,                   # passes over the collected samples
    'batch_size': 64,
    'temperature': 1.25,
    'exploration_constant': 1.25,      # UCB exploration weight in MCTS
}

alphaZero = AlphaZeroChess(model, optimizer, chess_game, args)
alphaZero.learn()

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]