In [1]:
import pygame
from abc import ABC, abstractmethod
from collections import defaultdict, namedtuple
import math
import chess
import random

class MCTS:
    "Monte Carlo tree searcher. First rollout the tree then choose a move."

    def __init__(self, exploration_weight=1):
        self.Q = defaultdict(int)  # total reward of each node
        self.N = defaultdict(int)  # total visit count for each node
        self.children = dict()  # children of each node
        self.exploration_weight = exploration_weight

    def choose(self, node):
        print(node)
        "Choose the best successor of node. (Choose a move in the game)"
        if node.is_terminal()[0]:
            raise RuntimeError(f"choose called on terminal node {node}")

        if node not in self.children:
            return node.find_random_child()

        def score(n):
            if self.N[n] == 0:
                return float("-inf")  # avoid unseen moves
            return self.Q[n] / self.N[n]  # average reward

        return max(self.children[node], key=score)

    def do_rollout(self, node):
        "Make the tree one layer better. (Train for one iteration.)"
        path = self._select(node)
        leaf = path[-1]
        self._expand(leaf)
        reward = self._simulate(leaf)
        self._backpropagate(path, reward)

    def _select(self, node):
        "Find an unexplored descendent of `node`"
        path = []
        while True:
            path.append(node)
            if node not in self.children or not self.children[node]:
                # node is either unexplored or terminal
                return path
            unexplored = self.children[node] - self.children.keys()
            if unexplored:
                n = unexplored.pop()
                path.append(n)
                return path
            node = self._uct_select(node)  # descend a layer deeper

    def _expand(self, node):
        "Update the `children` dict with the children of `node`"
        if node in self.children:
            return  # already expanded
        self.children[node] = node.find_children()

    def _simulate(self, node):
        "Returns the reward for a random simulation (to completion) of `node`"
        invert_reward = True
        while True:
            if node.is_terminal():
                reward = node.reward()
                return 1 - reward if invert_reward else reward
            node = node.find_random_child()
            invert_reward = not invert_reward

    def _backpropagate(self, path, reward):
        "Send the reward back up to the ancestors of the leaf"
        for node in reversed(path):
            self.N[node] += 1
            self.Q[node] += reward
            reward = 1 - reward  # 1 for me is 0 for my enemy, and vice versa

    def _uct_select(self, node):
        "Select a child of node, balancing exploration & exploitation"

        # All children of node should already be expanded:
        assert all(n in self.children for n in self.children[node])

        log_N_vertex = math.log(self.N[node])

        def uct(n):
            "Upper confidence bound for trees"
            return self.Q[n] / self.N[n] + self.exploration_weight * math.sqrt(
                log_N_vertex / self.N[n]
            )

        return max(self.children[node], key=uct)
    
class Node(ABC):
    """
    A representation of a single board state.
    MCTS works by constructing a tree of these Nodes.
    Could be e.g. a chess or checkers board state.
    """

    @abstractmethod
    def find_children(self):
        "All possible successors of this board state"
        return set()

    @abstractmethod
    def find_random_child(self):
        "Random successor of this board state (for more efficient simulation)"
        return None

    @abstractmethod
    def is_terminal(self):
        "Returns True if the node has no children"
        return True

    @abstractmethod
    def reward(self):
        "Assumes `self` is terminal node. 1=win, 0=loss, .5=tie, etc"
        return 0

    @abstractmethod
    def __hash__(self):
        "Nodes must be hashable"
        return 123456789

    @abstractmethod
    def __eq__(node1, node2):
        "Nodes must be comparable"
        return True


pygame 2.6.0 (SDL 2.28.4, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
"""
A chess implementation, will be used along side RL
"""

_CB = namedtuple("ChessBoard", "board turn winner terminal")

class ChessGame(_CB, Node):
    def find_children(self):
        if self.terminal:
            return set()
        children = set()
        for move in self.board.legal_moves:
            new_board = self.board.copy(stack=False)
            new_board.push(move)
            terminal, winner = self._check_terminal(new_board)
            children.add(ChessGame(board=new_board, turn=new_board.turn, winner=winner, terminal=terminal))
        return children

    def find_random_child(self):
        if self.terminal:
            return None
        move = random.choice(list(self.board.legal_moves))
        new_board = self.board.copy(stack=False)
        new_board.push(move)
        terminal, winner = self._check_terminal(new_board)
        return ChessGame(board=new_board, turn=new_board.turn, winner=winner, terminal=terminal)

    def is_terminal(self):
        return self.terminal, self.winner

    def _check_terminal(self, board):
        if board.is_checkmate():
            winner = not board.turn  # Winner is the opposite of the current turn
            return True, winner
        if board.is_stalemate() or board.is_insufficient_material() or board.can_claim_draw():
            return True, None  # Draw
        return False, None

    def reward(self):
        if not self.terminal:
            raise RuntimeError("reward called on non-terminal node")
        if self.winner is None:
            return 0.5  # Draw
        elif self.winner == chess.WHITE:
            return 1.0
        else:
            return 0.0

    def __hash__(self):
        return self.board.transposition_key()

    def __eq__(self, other):
        return isinstance(other, ChessGame) and self.board.transposition_key() == other.board.transposition_key()

In [3]:
def play_chess():
    game = ChessGame(board=chess.Board(), turn=True, winner=None, terminal=False)
    print(game.board)
    print("")
    
    while True:
        if game.is_terminal()[0]:
            result = game.board.result()
            print(f"Game over: {result}")
            break
        
        tree = MCTS()
        num_rollouts = 1000
        for _ in tqdm(range(num_rollouts), desc="MCTS Rollouts"):
            tree.do_rollout(game)
        
        game = tree.choose(game)
        print(game.board)
        print("")
        
        if game.is_terminal()[0]:
            result = game.board.result()
            print(f"Game over: {result}")
            break


In [4]:
play_chess()

ChessGame(board=Board('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'), turn=True, winner=None, terminal=False)
ChessGame(board=Board('rnbqkbnr/pppppppp/8/8/8/1P6/P1PPPPPP/RNBQKBNR b KQkq - 0 1'), turn=False, winner=None, terminal=False)
ChessGame(board=Board('rnbqkbnr/pppppp1p/8/6p1/8/1P6/P1PPPPPP/RNBQKBNR w KQkq - 0 2'), turn=True, winner=None, terminal=False)
ChessGame(board=Board('rnbqkbnr/pppppp1p/8/6p1/8/NP6/P1PPPPPP/R1BQKBNR b KQkq - 1 2'), turn=False, winner=None, terminal=False)


KeyboardInterrupt: 