## RL POLICY NETWORK USING MCTS SELF PLAY

In [1]:
import chess
import pygame
import random
import numpy as np
import tensorflow as tf
from IPython.display import Audio
from collections import defaultdict

pygame 2.5.2 (SDL 2.28.3, Python 3.11.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


UTILITY FUNCTIONS

In [2]:
def parse_fen(fen):
    """Encode a FEN string as an 8x8x18 integer array seen from the mover's side.

    Plane layout (last axis):
      0-5    side to move's rook, knight, bishop, queen, king, pawn
      6-11   opponent's pieces in the same order
      12     1 everywhere when White is to move, else 0
      13-16  set to 1 when a castling right is ABSENT from the FEN, in the
             order: mover queenside, mover kingside, opponent queenside,
             opponent kingside
      17     filled with the integer value of the last FEN field

    With White to move, row 0 is rank 1 and column 0 is file a. With Black
    to move the board is rotated 180 degrees (row 0 is rank 8, column 0 is
    file h).
    """
    fields = fen.split(' ')
    placement = fields[0]
    white_to_move = fields[1] == 'w'

    # Mover's pieces occupy planes 0-5, the opponent's planes 6-11.
    if white_to_move:
        symbols = "RNBQKPrnbqkp"
        castle_flags = "QKqk"
    else:
        symbols = "rnbqkpRNBQKP"
        castle_flags = "qkQK"
    plane_of = {symbol: plane for plane, symbol in enumerate(symbols)}

    planes = np.zeros([8, 8, 18]).astype(int)
    planes[:, :, 12] = 1 if white_to_move else 0

    # A castling plane is switched on when the corresponding right is missing.
    castling = fields[2]
    for offset, flag in enumerate(castle_flags):
        if flag not in castling:
            planes[:, :, 13 + offset] = 1

    planes[:, :, 17] = int(fields[-1])

    # FEN lists ranks from 8 down to 1, each rank from file a to file h.
    for rank_no, rank_text in enumerate(placement.split('/')):
        file_no = 0
        for symbol in rank_text:
            if symbol.isdigit():
                file_no += int(symbol)  # run of empty squares
                continue
            if white_to_move:
                row, col = 7 - rank_no, file_no
            else:
                row, col = rank_no, 7 - file_no
            planes[row, col, plane_of[symbol]] = 1
            file_no += 1
    return planes

In [3]:
def parse_actions(move, fen):
    """Encode a UCI move as a one-hot 8x8x73 action tensor from the mover's side.

    The origin square selects the first two axes (row = rank, column = file,
    rotated 180 degrees when Black is to move) and the move type selects one
    of 73 planes:

      0-6    straight up 1-7 squares      7-13   straight down 1-7
      14-20  right 1-7                    21-27  left 1-7
      28-34  up-right 1-7                 35-41  down-left 1-7
      42-48  up-left 1-7                  49-55  down-right 1-7
      56-63  knight jumps
      64-66  underpromotion without capture (rook, bishop, knight)
      67-69  underpromotion capturing left  (rook, bishop, knight)
      70-72  underpromotion capturing right (rook, bishop, knight)

    Queen promotions carry no special plane; they are encoded as the
    ordinary one-square move.

    Args:
        move: UCI move string such as ``"e2e4"`` or ``"a7b8n"``.
        fen: FEN of the position; only the side-to-move field is read.

    Returns:
        ``(flat, file_idx, rank_idx, plane)`` where ``flat`` is the flattened
        4672-element int array containing a single 1, and
        ``flat.reshape(8, 8, 73)[rank_idx, file_idx, plane] == 1``.

    Raises:
        KeyError: a file letter is outside ``a``-``h``.
        ValueError: a rank character is not an integer, or the displacement
            is neither a straight/diagonal line nor a knight jump
            (previously this surfaced as an UnboundLocalError).
    """
    player = 1 if fen.split(' ')[1] == 'w' else 0
    file_to_num = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8}
    init_file = file_to_num[move[0]]
    init_rank = int(move[1])
    final_file = file_to_num[move[2]]
    final_rank = int(move[3])

    # Displacements are "origin minus destination" in the mover's frame, so a
    # negative rank_diff means forward and a negative file_diff means right.
    if player:
        file_diff = init_file - final_file
        rank_diff = init_rank - final_rank
        row, col = init_rank - 1, init_file - 1
    else:
        file_diff = final_file - init_file
        rank_diff = final_rank - init_rank
        row, col = 8 - init_rank, 8 - init_file

    under_promo = move[-1].lower() if move[-1] in ('n', 'N', 'b', 'B', 'r', 'R') else None

    if under_promo is not None:
        if file_diff == 0:    # no capture
            base = 64
        elif file_diff > 0:   # left capture
            base = 67
        else:                 # right capture
            base = 70
        idx = base + {'r': 0, 'b': 1, 'n': 2}[under_promo]
    elif file_diff == 0:      # north/south
        idx = -rank_diff - 1 if rank_diff < 0 else 6 + rank_diff
    elif rank_diff == 0:      # east/west
        idx = 13 - file_diff if file_diff < 0 else 20 + file_diff
    elif abs(file_diff) == abs(rank_diff):  # diagonals
        if rank_diff < 0 and file_diff < 0:    # NE
            idx = 27 - rank_diff
        elif rank_diff > 0 and file_diff > 0:  # SW
            idx = 34 + rank_diff
        elif rank_diff < 0 and file_diff > 0:  # NW
            idx = 41 + file_diff
        else:                                  # SE
            idx = 48 + rank_diff
    else:
        knight_planes = {
            (1, -2): 56, (-1, -2): 57, (2, 1): 58, (2, -1): 59,
            (1, 2): 60, (-1, 2): 61, (-2, 1): 62, (-2, -1): 63,
        }
        idx = knight_planes.get((file_diff, rank_diff))
        if idx is None:
            raise ValueError(
                f"move {move!r} is neither a straight, diagonal nor knight move"
            )

    enc = np.zeros([8, 8, 73]).astype(int)
    enc[row, col, idx] = 1
    return enc.flatten(), col, row, idx

In [4]:
def decode_action(out_policy, fen):
    """Look up the policy value of every legal move in the position ``fen``.

    ``out_policy`` is indexed as ``[rank, file, plane]`` using the indices
    returned by ``parse_actions``. Returns a dict mapping each legal
    ``chess.Move`` to the corresponding entry of ``out_policy``.
    """
    def prior_of(candidate):
        # parse_actions returns (flat, file index, rank index, plane).
        _, file_idx, rank_idx, plane = parse_actions(str(candidate), fen)
        return out_policy[rank_idx, file_idx, plane]

    position = chess.Board(fen)
    return {candidate: prior_of(candidate) for candidate in position.legal_moves}

In [5]:
def beep():
    """Return an IPython ``Audio`` object for ``beep-01.wav`` with autoplay on."""
    return Audio(filename="beep-01.wav", autoplay=True)

NETWORK BLOCKS

In [6]:
@tf.keras.utils.register_keras_serializable()
class ConvBlock(tf.keras.layers.Layer):
    """Conv2D ('same' padding) followed by batch normalisation and ReLU.

    Args:
        filters: number of output channels of the convolution.
        kernel_size: kernel size passed to ``Conv2D``.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can reproduce the constructor call.
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()

    def call(self, inputs):
        """Apply convolution, batch norm and ReLU in that order."""
        x = self.conv(inputs)
        x = self.batch_norm(x)
        return self.relu(x)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        config.update({"filters": self.filters, "kernel_size": self.kernel_size})
        return config

@tf.keras.utils.register_keras_serializable()
class ResBlock(tf.keras.layers.Layer):
    """Two Conv2D + batch-norm stages with a skip connection around them.

    The block input is added to the output of the second batch norm before
    the final ReLU, so the input is expected to already have ``filters``
    channels.

    Args:
        filters: number of output channels of both convolutions.
        kernel_size: kernel size passed to both ``Conv2D`` layers.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can reproduce the constructor call.
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv1 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()
        self.conv2 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
        self.batch_norm2 = tf.keras.layers.BatchNormalization()

    def call(self, inputs):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(inputs))))) + inputs)``."""
        out = self.conv1(inputs)
        out = self.batch_norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.batch_norm2(out)
        return self.relu(out + inputs)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        config.update({"filters": self.filters, "kernel_size": self.kernel_size})
        return config
    
@tf.keras.utils.register_keras_serializable()
class PolicyHead(tf.keras.layers.Layer):
    """Policy output head: Conv2D -> batch norm -> ReLU -> Flatten -> softmax Dense.

    The final dense layer has ``num_move_planes * 8 * 8`` units, i.e. one
    probability per (square, move plane) pair on an 8x8 board.

    Args:
        filters: number of output channels of the convolution.
        kernel_size: kernel size passed to ``Conv2D``.
        num_move_planes: number of move planes per square.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size, num_move_planes, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can reproduce the constructor call.
        self.filters = filters
        self.kernel_size = kernel_size
        self.num_move_planes = num_move_planes
        self.conv = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(num_move_planes * 8 * 8, activation='softmax')

    def call(self, inputs):
        """Return the flat softmax distribution over all encoded moves."""
        x = self.conv(inputs)
        x = self.batch_norm(x)
        x = self.relu(x)
        x = self.flatten(x)
        return self.dense(x)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "num_move_planes": self.num_move_planes,
        })
        return config

@tf.keras.utils.register_keras_serializable()
class ValueHead(tf.keras.layers.Layer):
    """Value output head ending in a single tanh unit.

    Pipeline: Conv2D -> batch norm -> ReLU -> Flatten -> Dense(64, relu)
    -> Dense(1, tanh), so the output lies in [-1, 1].

    Args:
        filters: number of output channels of the convolution.
        kernel_size: kernel size passed to ``Conv2D``.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can reproduce the constructor call.
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(64, activation='relu')
        self.dense2 = tf.keras.layers.Dense(1, activation='tanh')

    def call(self, inputs):
        """Return the scalar value estimate for each input in the batch."""
        x = self.conv(inputs)
        x = self.batch_norm(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        config.update({"filters": self.filters, "kernel_size": self.kernel_size})
        return config

NODE/GAME STATE CLASS

In [7]:
class node:
    """A position in the search tree together with its search statistics."""

    def __init__(self, fen, P, W = 0, N = 0):
        """
        fen = FEN string of the position held by this node
        P = Prior probability of the action leading to this node
        W = total value score accumulated for the action
        N = Number of times the action was taken
        """
        self.fen = fen
        self.W, self.N, self.P = W, N, P
        # Mean value; defined as 0 for a node that has never been visited.
        self.Q = W / N if N != 0 else 0
        self.child = []    # child nodes
        self.actions = []  # actions associated with this node; empty at creation

In [8]:
def game_end(fen):
    """Return 1 if the position ``fen`` ends the game, otherwise 0.

    The game is considered over when the halfmove clock has reached 100
    (fifty-move rule) or when the side to move has no legal move
    (checkmate or stalemate).
    """
    board = chess.Board(fen)
    # ">=" rather than "==": a FEN may already carry a clock above 100.
    if board.halfmove_clock >= 100:
        return 1
    # Pull at most one move instead of generating the whole legal-move list.
    if next(iter(board.legal_moves), None) is None:
        return 1
    return 0

In [9]:
def game_res(fen):
    """Return 1 if the position ``fen`` is checkmate (a win for either side), else 0 (draw)."""
    position = chess.Board(fen)
    return 1 if position.is_checkmate() else 0

In [10]:
def simul(state, policy_model, value_model, side_to_play):
    """
    Perform a single MCTS simulation from ``state`` (selection of the best
    move + expansion of a leaf node) and return the node's updated total W.

    A node with ``state.N == 0`` is treated as a leaf and expanded: one child
    is created per legal move with the policy network's probability as its
    prior, and every child position is scored either by ``game_res`` (when
    ``game_end`` reports the game is over) or by the value network. Those
    scores are added to ``state.W``.

    For an already expanded node, the child maximising ``Q + U`` is searched
    recursively, with ``U = 2 * P * sqrt(N_parent) / (1 + N_child)``. The
    change in that child's W is added to ``state.W``.

    In both cases a contribution is added as-is when the side to move at
    ``state`` equals ``side_to_play`` and negated otherwise.

    Args:
        state: ``node`` to search from; mutated in place.
        policy_model: model whose ``predict`` output reshapes to (8, 8, 73).
        value_model: model whose ``predict`` output is read at ``[0][0]``.
            NOTE(review): the sign handling assumes a particular perspective
            for this value -- confirm against how the model was trained.
        side_to_play: 1 if the search is run for White, 0 for Black.

    Returns:
        ``state.W`` after the update (a running total, not a per-visit
        value); the caller derives the increment by subtracting the W it
        saw before the call.
    """
    board = chess.Board(state.fen)
    # +1 when the side to move here is the searching side, -1 otherwise.
    sign = 1 if int(board.turn) == side_to_play else -1

    if state.N == 0: #checking for leaf node
        state.actions = list(board.legal_moves)
        out_policy = policy_model.predict(np.array([parse_fen(state.fen)]), verbose = 0)[0].reshape((8, 8, 73))
        policy_dict = decode_action(out_policy, state.fen)
        # Iterate the materialised list, not the live generator: the board is
        # pushed/popped inside the loop, and this keeps state.child aligned
        # index-for-index with state.actions.
        for move in state.actions:
            board.push(move)
            child_fen = board.fen()
            board.pop()
            state.child.append(node(child_fen, policy_dict[move]))
            if game_end(child_fen):
                move_val = game_res(child_fen)
            else:
                move_val = value_model.predict(np.array([parse_fen(child_fen)]), verbose = 0)[0][0]
            state.W += sign*move_val
    elif state.child:
        # UCB selection with exploration constant c = 2.
        ucb_scores = [
            child.Q + 2*child.P*np.sqrt(state.N)/(1 + child.N)
            for child in state.child
        ]
        best_child = state.child[np.argmax(ucb_scores)]
        best_child_W = best_child.W
        # simul returns the child's cumulative W, so subtract the old total
        # to obtain what this simulation contributed.
        best_move_val = simul(best_child, policy_model, value_model, side_to_play) - best_child_W
        state.W += sign*best_move_val

    state.N += 1
    state.Q = state.W/state.N
    return state.W

In [11]:
def MCTS(root, policy_model, value_model, side_to_play, num_simul = 10): #root is the current game state/ node
    """Run ``num_simul`` simulations from ``root`` and pick the most visited child.

    Returns the index (into ``root.child``) of the child with the highest
    visit count N. The choice is greedy; no sampling over visit counts is
    performed.
    """
    for _ in range(num_simul):
        simul(root, policy_model, value_model, side_to_play)
    visit_counts = [child.N for child in root.child]
    return np.argmax(visit_counts)

In [12]:
def update(scrn, board):
    '''
    Redraw the pieces and grid lines of ``board`` onto ``scrn`` and flip the display.

    Squares are 100x100 pixels; square 0 is drawn in the bottom-left corner.
    Piece images come from the module-level ``pieces`` dict and the grid
    colour from the module-level ``WHITE``.
    '''
    for square in range(64):
        occupant = board.piece_at(square)
        if occupant is not None:
            row, col = divmod(square, 8)
            scrn.blit(pieces[str(occupant)], (col*100, 700 - row*100))
    # Seven interior grid lines in each direction.
    for line_no in range(1, 8):
        offset = line_no*100
        pygame.draw.line(scrn, WHITE, (0, offset), (800, offset))
        pygame.draw.line(scrn, WHITE, (offset, 0), (offset, 800))
    pygame.display.flip()
    
# def main_two_agent(move, BOARD = chess.Board()):
#     '''
#     for agent vs agent game
#     '''
#     #name window
#     pygame.display.set_caption('Self play')
    
#     #variable to be used later
#     status = True
#     scrn.blit(board_img, (0, 0))
#     update(scrn, BOARD)
#     while (status):
#         BOARD.push(move)
#         scrn.blit(board_img, (0, 0))
#         update(scrn, BOARD)
#     # deactivates the pygame library
#         if BOARD.outcome() != None:
#             print(BOARD.outcome())
#             status = False
#             print(BOARD)
#     pygame.quit()
    
# main_one_agent(agent, model, agent_color=False)

LOADING THE MODELS

In [13]:
# Locations of the saved Keras models.
VALUE_MODEL_PATH = 'value_model_res11.h5'
# NOTE(review): absolute, machine-specific path -- unlike the value model,
# which is resolved relative to the working directory.
POLICY_MODEL_PATH = 'C:/Users/DELL/Downloads/policy_model_res11.h5'

value_model = tf.keras.models.load_model(VALUE_MODEL_PATH)
policy_model = tf.keras.models.load_model(POLICY_MODEL_PATH)

In [14]:
def self_play(root, policy_model, value_model): #root is the starting position
    """Play one game against itself from ``root`` and return the moves in SAN.

    Each turn the current position is drawn on the module-level ``scrn``
    (over ``board_img``), an MCTS search is run for the side to move, and the
    chosen move is printed and appended to the result. The game stops when

    * ``game_end`` reports the position is terminal,
    * the same position has been reached three times, or
    * the pygame window is closed.

    pygame is shut down before returning, even if the search raises.

    Args:
        root: ``node`` holding the starting position.
        policy_model: policy network passed through to ``MCTS``.
        value_model: value network passed through to ``MCTS``.

    Returns:
        List of the moves played, as SAN strings.
    """
    seen_positions = defaultdict(int)
    moves = []
    try:
        while not game_end(root.fen):
            # A repeated position needs the same placement, side to move,
            # castling rights and en-passant square: the first four FEN fields.
            position_key = ' '.join(root.fen.split(' ')[:4])
            seen_positions[position_key] += 1

            b = chess.Board(root.fen)
            scrn.blit(board_img, (0, 0))
            update(scrn, b)

            # Stop immediately when the window is closed instead of running
            # one more search first.
            if any(event.type == pygame.QUIT for event in pygame.event.get()):
                break
            if seen_positions[position_key] >= 3:
                break

            side_to_play = int(b.turn)
            best_idx = MCTS(root, policy_model, value_model, side_to_play)
            moves.append(b.san(root.actions[best_idx]))
            print(moves[-1])
            root = root.child[best_idx]
    finally:
        pygame.quit()
    return moves

RUNNING THE SELF PLAY MODULE

In [None]:
# Start the game from the standard initial position with prior 1.
root = node(chess.STARTING_FEN, 1)

# Initialise pygame BEFORE creating the window; the display module is only
# guaranteed to be usable once it has been initialised.
pygame.init()
X = 800
Y = 800
scrn = pygame.display.set_mode((X, Y))

#basic colours
WHITE = (255, 255, 255)
GREY = (128, 128, 128)
YELLOW = (204, 204, 0)
BLUE = (50, 255, 255)
BLACK = (0, 0, 0)

#load piece images, keyed by piece symbol (lowercase = black, uppercase = white)
pieces = {'p': pygame.image.load('pieces/b_pawn.png'),
          'n': pygame.image.load('pieces/b_knight.png'),
          'b': pygame.image.load('pieces/b_bishop.png'),
          'r': pygame.image.load('pieces/b_rook.png'),
          'q': pygame.image.load('pieces/b_queen.png'),
          'k': pygame.image.load('pieces/b_king.png'),
          'P': pygame.image.load('pieces/w_pawn.png'),
          'N': pygame.image.load('pieces/w_knight.png'),
          'B': pygame.image.load('pieces/w_bishop.png'),
          'R': pygame.image.load('pieces/w_rook.png'),
          'Q': pygame.image.load('pieces/w_queen.png'),
          'K': pygame.image.load('pieces/w_king.png')
          }
#load the chess board
board_img = pygame.image.load('board.png')
#name window
pygame.display.set_caption('Self play')

# Run one self-play game; self_play draws the board itself and shuts pygame
# down when the game is over.
moves = self_play(root, policy_model, value_model)

e4
g6
Nc3
Bg7
d4
a6
f4
b5
Nf3
Bb7
e5
d6
Bd3
Nd7
O-O
Nb6
a4
b4
Na2
a5
Qe2
Nh6
b3
O-O
Bb5
Nf5
g4
Nxd4
Nxd4
dxe5
fxe5
Qxd4+
Be3
Qxe3+
Qxe3
Bxe5
Qxe5
Nd7
Qxe7
Nc5
Qf6
Be4
Qh4


In [None]:
# import cProfile
# root = node(chess.STARTING_FEN, 1)
# cProfile.run('simul(root, policy_model, value_model, 1)')

In [None]:
# Presumably an end-of-run notification: beep() returns an autoplaying Audio
# object. Keep this as the bare final expression so the notebook can display
# the returned object.
beep()