In [None]:
import keras
import chess
import numpy
import tensorflow
import time

In [None]:
!pip install mcts

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mcts
  Downloading mcts-1.0.4-py3-none-any.whl (4.2 kB)
Installing collected packages: mcts
Successfully installed mcts-1.0.4


In [None]:
# Mount Google Drive at /content/drive (Colab-only API); the model file loaded
# in the next cell is read from this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Load the saved Keras model from the mounted Drive. The notebook-level `model`
# variable is what eval() calls on an encoded board.
model = keras.models.load_model('/content/drive/My Drive/models/model_test.h5')

In [None]:
# Maps a board file letter ('a'..'h') to its column index (0..7).
squares_index = {file_letter: column for column, file_letter in enumerate('abcdefgh')}

def square_to_index(square):
  """Translate a python-chess square into (row, column) indices of an 8x8 plane.

  The row is 8 minus the rank digit of the square's name (so rank 8 is row 0),
  and the column is the file letter looked up in squares_index.
  """
  file_letter, rank_digit = chess.square_name(square)
  row = 8 - int(rank_digit)
  column = squares_index[file_letter]
  return row, column

def split_dims(board):
  """Encode a board as a (14, 8, 8) int8 array of 0/1 planes.

  Planes `piece - 1` hold the white pieces and planes `piece + 5` the black
  pieces, for each piece type in chess.PIECE_TYPES. Plane 12 marks the
  destination squares of the legal moves found with white to move, and plane 13
  the same with black to move.

  board.turn is switched temporarily to enumerate both sides' moves. It is
  restored in a `finally` clause so the caller's board is left unchanged even if
  move generation raises.
  """
  board3d = numpy.zeros((14, 8, 8), dtype=numpy.int8)

  for piece in chess.PIECE_TYPES:
    for square in board.pieces(piece, chess.WHITE):
      idx = numpy.unravel_index(square, (8, 8))
      # The first unravelled index is flipped (7 - idx[0]) before being used as the row.
      board3d[piece - 1][7 - idx[0]][idx[1]] = 1
    for square in board.pieces(piece, chess.BLACK):
      idx = numpy.unravel_index(square, (8, 8))
      board3d[piece + 5][7 - idx[0]][idx[1]] = 1

  original_turn = board.turn
  try:
    # Enumerate the moves of each side in turn by forcing the side to move.
    board.turn = chess.WHITE
    for move in board.legal_moves:
      i, j = square_to_index(move.to_square)
      board3d[12][i][j] = 1
    board.turn = chess.BLACK
    for move in board.legal_moves:
      i, j = square_to_index(move.to_square)
      board3d[13][i][j] = 1
  finally:
    # Always hand the board back with the side to move it came in with.
    board.turn = original_turn

  return board3d
def eval(board):
  """Run the loaded model on one position and return its raw output.

  The board is encoded with split_dims and given a leading axis of size 1
  before being passed to `model`.

  NOTE: this definition shadows Python's built-in eval() for the rest of the notebook.
  """
  single_batch = numpy.expand_dims(split_dims(board), 0)
  return model(single_batch)

In [None]:
# Sanity check: build a default chess.Board(), evaluate it with the model and
# print the resulting value as a plain array.
board = chess.Board()
print(tensorflow.keras.backend.get_value(eval(board)))

[[0.50175804]]


In [None]:
from __future__ import division

import time
import math
import random
import statistics
from random import randrange


def randomPolicy(state):
    """Short random rollout followed by a network evaluation.

    A terminal state returns its own reward immediately. Otherwise between 1 and
    4 uniformly random moves are played (stopping early if the game ends). If the
    resulting state is terminal its reward is returned; if not, the position is
    scored by eval() on the state's board.
    """
    if state.isTerminal():
        return state.getReward()

    depth_limit = randrange(1, 5)
    plies_played = 0
    while not state.isTerminal() and plies_played < depth_limit:
        chosen = random.choice(state.getPossibleActions())
        state = state.takeAction(chosen)
        plies_played += 1

    if state.isTerminal():
        return state.getReward()
    # Rollout stopped in a live position: fall back to the model's scalar output.
    return tensorflow.keras.backend.get_value(eval(state.board))[0][0]


class treeNode():
    """One node of the search tree, wrapping a game state.

    Attributes set here:
      state           -- the wrapped state object
      sideToMove      -- False/True flag read during backpropagation to decide
                         whether `reward` or `1 - reward` is credited to the node
      isTerminal      -- result of state.isTerminal() at construction time
      isFullyExpanded -- starts equal to isTerminal
      parent          -- parent node, or None for the root
      numVisits, totalReward, children -- search statistics, initially 0, 0, {}
    """

    def __init__(self, state, parent):
        self.state = state
        # Query the state once; both flags below are derived from it.
        self.isTerminal = state.isTerminal()
        if self.isTerminal:
            reward = state.getReward()
            if reward == 1:
                self.sideToMove = False
            elif reward == 0:
                self.sideToMove = True
            else:
                # Any other terminal reward (e.g. 0.5) previously left sideToMove
                # unset, which made backpropagation fail with AttributeError.
                self.sideToMove = state.board.turn
        else:
            self.sideToMove = state.board.turn
        self.isFullyExpanded = self.isTerminal
        self.parent = parent
        self.numVisits = 0
        self.totalReward = 0
        self.children = {}


class mcts():
    """Monte Carlo tree search over states exposing isTerminal / getPossibleActions /
    takeAction / getReward and a `board` attribute.

    Exactly one of `timeLimit` (milliseconds per search) or `iterationLimit`
    (number of rounds per search) must be supplied.

    Args:
        timeLimit: search budget in milliseconds, or None.
        iterationLimit: number of search rounds (>= 1), or None.
        explorationConstant: weight of the exploration term in child selection.
        rolloutPolicy: callable taking a state and returning a reward.

    Raises:
        ValueError: if both or neither limit is given, or iterationLimit < 1.
    """

    def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2),
                 rolloutPolicy=randomPolicy):
        if timeLimit is not None:
            if iterationLimit is not None:
                raise ValueError("Cannot have both a time limit and an iteration limit")
            # time taken for each MCTS search in milliseconds
            self.timeLimit = timeLimit
            self.limitType = 'time'
        else:
            if iterationLimit is None:
                raise ValueError("Must have either a time limit or an iteration limit")
            # number of iterations of the search
            if iterationLimit < 1:
                # The check accepts 1, so the message says "at least one".
                raise ValueError("Iteration limit must be at least one")
            self.searchLimit = iterationLimit
            self.limitType = 'iterations'
        self.explorationConstant = explorationConstant
        self.rollout = rolloutPolicy

    def search(self, initialState):
        """Run the search from `initialState` and return the chosen action.

        After every round the currently preferred root action is printed
        together with the round index.

        Raises:
            ValueError: if `initialState` is already terminal (no action exists).
        """
        self.root = treeNode(initialState, None)
        if self.root.isTerminal:
            # A terminal root never gets children, so picking a best child would
            # otherwise fail inside random.choice with an unhelpful IndexError.
            raise ValueError("Cannot search from a terminal state")

        if self.limitType == 'time':
            deadline = time.time() + self.timeLimit / 1000
            i = 0
            while time.time() < deadline:
                self._executeAndReport(i)
                i = i + 1
        else:
            for i in range(self.searchLimit):
                self._executeAndReport(i)
        bestChild = self.getBestChild(self.root, 0)
        return self.getAction(self.root, bestChild)

    def _executeAndReport(self, iteration):
        """Run one search round, then print the root's current best action."""
        self.executeRound()
        bestChild = self.getBestChild(self.root, 0)
        print(iteration, ": ", self.getAction(self.root, bestChild))

    def executeRound(self):
        """One round: select/expand a node, roll out from it, back up the reward."""
        node = self.selectNode(self.root)
        reward = self.rollout(node.state)
        self.backpropogate(node, reward)

    def selectNode(self, node):
        """Descend through fully expanded nodes; return a newly expanded child
        of the first node that is not fully expanded, or the terminal node reached."""
        while not node.isTerminal:
            if node.isFullyExpanded:
                node = self.getBestChild(node, self.explorationConstant)
            else:
                return self.expand(node)
        return node

    def expand(self, node):
        """Create and return the child for the first action of `node` that has
        no child yet, marking `node` fully expanded once every action has one."""
        actions = node.state.getPossibleActions()
        for action in actions:
            if action not in node.children:
                newNode = treeNode(node.state.takeAction(action), node)
                node.children[action] = newNode
                if len(actions) == len(node.children):
                    node.isFullyExpanded = True
                return newNode

        raise Exception("Should never reach here")

    def backpropogate(self, node, reward):
        """Walk from `node` up to the root, incrementing visit counts.

        A node whose sideToMove is False is credited `reward`; every other node
        is credited `1 - reward`.
        """
        while node is not None:
            if node.sideToMove == False:
                node.totalReward += reward
            else:
                node.totalReward += 1 - reward
            node.numVisits += 1
            node = node.parent

    def getBestChild(self, node, explorationValue):
        """Return a child of `node` maximising average reward plus
        explorationValue * sqrt(2 * ln(parent visits) / child visits).
        Ties are broken uniformly at random."""
        bestValue = float("-inf")
        bestNodes = []
        for child in node.children.values():
            nodeValue = child.totalReward / child.numVisits + explorationValue * math.sqrt(
                2 * math.log(node.numVisits) / child.numVisits)
            if nodeValue > bestValue:
                bestValue = nodeValue
                bestNodes = [child]
            elif nodeValue == bestValue:
                bestNodes.append(child)
        return random.choice(bestNodes)

    def getAction(self, root, bestChild):
        """Return the action under which `bestChild` is stored in root.children
        (None if it is not a child of `root`)."""
        for action, node in root.children.items():
            if node is bestChild:
                return action


In [None]:
import chess
import copy

class State:
    """Adapter exposing a chess board through the interface the tree search uses.

    Attributes:
        board -- the wrapped board object (a python-chess Board in this notebook)
    """

    def __init__(self, board):
        self.board = board

    def getCurrentPlayer(self):
        """Return the side to move, taken from board.turn.

        The previous implementation read a `currentPlayer` attribute that was
        never assigned and therefore always raised AttributeError.
        """
        return self.board.turn

    def getPossibleActions(self):
        """Return the board's legal moves as a list."""
        return list(self.board.legal_moves)

    def takeAction(self, move):
        """Return a new state with `move` played; this state is left untouched.

        Only the board is duplicated (via its own copy() method) instead of
        deep-copying the whole state object on every expansion and rollout step.
        """
        new_state = copy.copy(self)
        new_state.board = self.board.copy()
        new_state.board.push(move)
        return new_state

    def isTerminal(self):
        """True on stalemate, checkmate, or when a fifty-move draw can be claimed."""
        return bool(
            self.board.is_stalemate()
            or self.board.is_checkmate()
            or self.board.can_claim_fifty_moves()
        )

    def getReward(self):
        """Map board.result() to a reward: '1-0' -> 1, '0-1' -> 0, anything else -> 0.5."""
        result = self.board.result()
        if result == '1-0':
            return 1
        if result == '0-1':
            return 0
        return 0.5

In [None]:
# Test position given as a FEN string (white to move).
board = chess.Board('4r1k1/pp1p1p1p/6p1/3p4/3b1PPq/4nB1P/P3PK2/7R w - - 8 27')

# Wrap the board for the search and run a fixed number of iterations.
# NOTE: search() prints the current best move after every iteration, so this
# cell emits one output line per iteration (100000 here).
initialState = State(board)
searcher = mcts(iterationLimit=100000, explorationConstant= math.sqrt(2))
action = searcher.search(initialState=initialState)
print(action)