IMPORTING LIBRARIES

In [1]:
import chess
import random
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
from IPython.display import Audio

UTILITY FUNCTIONS

In [2]:
def beep():
    """Create an auto-playing audio object for the local ``beep-01.wav`` file.

    The object is returned rather than displayed here, so the calling cell
    decides where it gets rendered.
    """
    return Audio(filename = "beep-01.wav", autoplay = True)

ENCODING THE BOARD

In [3]:
def parse_fen(fen):
    """Encode a FEN string as an 8x8x18 integer array seen from the side to move.

    Plane layout along the last axis:
        0-5    pieces of the side to move, ordered R, N, B, Q, K, P
        6-11   pieces of the opponent, same order
        12     1 everywhere when white is to move, 0 when black is to move
        13,14  1 everywhere when the side to move LACKS the queenside /
               kingside castling right
        15,16  the same two flags for the opponent
        17     the last FEN field (an integer) broadcast over the board

    With white to move, row 0 is rank 1 and column 0 is file a; with black to
    move the board is rotated by 180 degrees (row 0 is rank 8, column 0 is
    file h).

    Args:
        fen: space-separated FEN string.

    Returns:
        numpy.ndarray of shape (8, 8, 18) with integer dtype.
    """
    fields = fen.split(' ')
    placement = fields[0]
    white_to_move = fields[1] == 'w'
    castle = fields[2]
    player = 1 if white_to_move else 0

    # Map piece letters to planes so that the mover's pieces always come first.
    if white_to_move:
        piece_order = "RNBQKPrnbqkp"
        castle_flags = ('Q', 'K', 'q', 'k')
    else:
        piece_order = "rnbqkpRNBQKP"
        castle_flags = ('q', 'k', 'Q', 'K')
    plane_of = {piece: plane for plane, piece in enumerate(piece_order)}

    enc = np.zeros((8, 8, 18), dtype = int)
    enc[:, :, 12] = player

    # Planes 13..16 flag castling rights that are missing from the FEN field.
    for plane, flag in zip((13, 14, 15, 16), castle_flags):
        if flag not in castle:
            enc[:, :, plane] = 1

    enc[:, :, 17] = int(fields[-1])

    # FEN lists rank 8 first; digits denote runs of empty squares.
    for i, rank in enumerate(placement.split('/')):
        j = 0
        for symbol in rank:
            if symbol.isdigit():
                j += int(symbol)
                continue
            row = 7 - i if white_to_move else i
            col = j if white_to_move else 7 - j
            enc[row, col, plane_of[symbol]] = 1
            j += 1
    return enc

ENCODING THE ACTION

In [4]:
def parse_actions(move, fen):
    """Encode a UCI move as a one-hot 8x8x73 action array seen from the mover.

    The source square selects the (row, column) cell and the displacement
    selects one of 73 planes:
        0-6    forward 1..7 squares        7-13   backward 1..7 squares
        14-20  right 1..7 squares          21-27  left 1..7 squares
        28-34  forward-right diagonal      35-41  backward-left diagonal
        42-48  forward-left diagonal       49-55  backward-right diagonal
        56-63  the eight knight jumps
        64-72  underpromotions: (no capture, left capture, right capture)
               x (rook, bishop, knight)
    "Forward", "left" and "right" are relative to the side to move, which is
    read from the second field of ``fen``. Queen promotions are encoded as
    ordinary one-square moves.

    Args:
        move: UCI move string such as ``'e2e4'`` or ``'a7a8n'``.
        fen: FEN string of the position the move is played in.

    Returns:
        tuple ``(flat, col, row, idx)`` where ``flat`` is the flattened one-hot
        array of length 8*8*73, ``col``/``row`` are the source square in the
        mover-relative frame and ``idx`` is the plane index.

    Raises:
        ValueError: if the displacement matches none of the 73 planes
            (previously this surfaced as an unbound-local error on ``idx``).
    """
    player = 1 if fen.split(' ')[1] == 'w' else 0
    file_to_num = {'a': 1, 'b': 2, 'c': 3, 'd':4, 'e':5, 'f':6, 'g':7, 'h':8}
    init_file = file_to_num[move[0]]
    init_rank = int(move[1])
    final_file = file_to_num[move[2]]
    final_rank = int(move[3])
    under_promo = move[-1] if move[-1] in ['n', 'N', 'b', 'B', 'r', 'R'] else None

    # Sign convention: positive file_diff = towards the mover's left,
    # negative rank_diff = forward (away from the mover's back rank).
    file_diff = (init_file - final_file)*player + (final_file - init_file)*(1 - player)
    rank_diff = (init_rank - final_rank)*player + (final_rank - init_rank)*(1 - player)

    knight_planes = {
        (1, -2): 56, (-1, -2): 57, (2, 1): 58, (2, -1): 59,
        (1, 2): 60, (-1, 2): 61, (-2, 1): 62, (-2, -1): 63,
    }

    if under_promo is not None: #underpromotions
        if file_diff == 0: #no capture
            base = 64
        elif file_diff > 0: # left capture
            base = 67
        else: # right capture
            base = 70
        # rook -> +0, bishop -> +1, knight -> +2
        idx = base + {'r': 0, 'b': 1}.get(under_promo.lower(), 2)
    elif file_diff == 0: #NS direction
        if rank_diff < 0: #upward
            idx = -rank_diff - 1
        else: #downward
            idx = 6 + rank_diff
    elif rank_diff == 0: #EW direction
        if file_diff < 0: #right
            idx = 13 - file_diff
        else: #left
            idx = 20 + file_diff
    elif abs(file_diff) == abs(rank_diff): #diagonal moves
        if rank_diff < 0 and file_diff < 0: #NE
            idx = 27 - rank_diff
        elif rank_diff > 0 and file_diff > 0: #SW
            idx = 34 + rank_diff
        elif rank_diff < 0 and file_diff > 0: #NW
            idx = 41 + file_diff
        else: #SE
            idx = 48 + rank_diff
    else: #knight jumps
        idx = knight_planes.get((file_diff, rank_diff))
        if idx is None:
            raise ValueError(
                "move %r cannot be encoded: displacement (file %d, rank %d) "
                "is not a queen-like or knight move" % (str(move), file_diff, rank_diff)
            )

    # Source square in the mover-relative frame (rotated 180 degrees for black).
    row = (init_rank - 1)*player + (8 - init_rank)*(1 - player)
    col = (init_file - 1)*player + (8 - init_file)*(1 - player)
    enc = np.zeros([8,8,73]).astype(int)
    enc[row, col, idx] = 1
    return enc.flatten(), col, row, idx

DECODING THE TENSOR OUTPUT

In [5]:
def decode_action(out_policy, fen):
    """Restrict a policy tensor to the legal moves of a position and rank them.

    Args:
        out_policy: array indexable as ``out_policy[row, col, plane]`` using
            the same mover-relative layout that ``parse_actions`` produces.
        fen: FEN string of the position to decode for.

    Returns:
        tuple ``(policy_dict, best_move)``. ``policy_dict`` maps each legal UCI
        move string to its renormalised probability, in descending order of
        probability; ``best_move`` is the most probable move. When the position
        has no legal moves the result is ``({}, None)``.
    """
    board = chess.Board(fen)
    legal_moves_str = np.array([str(move) for move in board.legal_moves])
    if legal_moves_str.size == 0:
        # Nothing to rank; previously this failed with an IndexError below.
        return {}, None

    prob = np.zeros(len(legal_moves_str))
    for i, move in enumerate(legal_moves_str):
        _, x, y, z = parse_actions(move, fen) #x = file, y = rank
        prob[i] = out_policy[y, x, z]

    total = np.sum(prob)
    if np.isfinite(total) and total != 0:
        prob /= total
    else:
        # The network put no usable mass on any legal move: fall back to a
        # uniform distribution instead of propagating NaNs from a 0/0 division.
        prob[:] = 1.0 / len(prob)

    idx = prob.argsort()[::-1]
    prob = prob[idx]
    legal_moves_str = legal_moves_str[idx]
    policy_dict = {move: p for move, p in zip(legal_moves_str, prob)}
    best_move = legal_moves_str[0]
    return policy_dict, best_move

COLLECTING DATA FROM GRANDMASTER GAMES FOR TRAINING

In [6]:
# Load the pre-extracted training data. Every file holds one record per line.
#   move.txt / fen.txt        -> `action` / `fen`, used together as policy targets / inputs
#   all_fen.txt               -> `all_fen`, positions looked up when stacking history
#   value_fen.txt / values.txt -> `value_fen` / `values`, value-network inputs / integer labels
#   game_nos.txt, pre_sum.txt -> integer lookup tables used to index into `all_fen`
with open('move.txt', 'r') as fm:
    action = list(map(str.strip, fm))
with open('all_fen.txt', 'r') as faf:
    all_fen = list(map(str.strip, faf))
with open('fen.txt', 'r') as ff:
    fen = list(map(str.strip, ff))
with open('value_fen.txt', 'r') as vf:
    value_fen = list(map(str.strip, vf))
with open('values.txt', 'r') as v:
    values = [int(line.strip()) for line in v]
with open('game_nos.txt', 'r') as fg:
    game_nos = [int(line.strip()) for line in fg]
with open('pre_sum.txt', 'r') as ps:
    pre_sum = [int(line.strip()) for line in ps]
# Optional extra policy data from puzzles (currently disabled):
# with open('puzzle_fen.txt', 'r') as fpf:
#     fen += [line.strip() for line in fpf.readlines()]
# with open('puzzle_action.txt', 'r') as fpa:
#     action += [line.strip() for line in fpa.readlines()]

TRAIN-TEST SPLIT (currently this only shuffles the training indices; no test set is held out)

In [7]:
# Build one shuffled index permutation per dataset. The RNG is re-seeded
# before each shuffle so both orderings are reproducible.

# Policy data
ind_train = np.arange(len(fen))
random.seed(0)
random.shuffle(ind_train)

# Value data
value_ind_train = np.arange(len(value_fen))
random.seed(42)
random.shuffle(value_ind_train)

DATA GENERATOR WITH HISTORY INCLUDED

In [8]:
def policy_data_generator(X, Y, indices, game_nos, pre_sum, all_fen, batch_size, is_validation = False, history = 3):
    """Endlessly yield ``(inputs, targets)`` batches for the policy network.

    Args:
        X: FEN strings of the positions (the full, unsplit data).
        Y: UCI move strings, aligned with ``X``.
        indices: order in which entries of ``X``/``Y`` are visited; one pass
            over ``indices`` is made per cycle.
        game_nos, pre_sum, all_fen: lookup tables used when ``history > 0``.
            The position index is ``pre_sum[game_nos[i]] + fullmove_number - 1``
            and the ``history`` entries before it in ``all_fen`` are stacked.
            NOTE(review): presumably ``pre_sum[g]`` is the offset of game
            ``g`` inside ``all_fen`` and ``all_fen`` has one entry per full
            move - confirm against the code that wrote these files.
        batch_size: number of samples per batch (the last batch of a pass may
            be smaller).
        is_validation: accepted for interface compatibility; not used.
        history: number of earlier positions stacked below the current one.
            The input has ``18 * (history + 1)`` channels, oldest first.

    Yields:
        tuple of ``numpy.ndarray``: inputs of shape
        ``(batch, 8, 8, 18 * (history + 1))`` and flattened one-hot targets of
        shape ``(batch, 8 * 8 * 73)``.

    Raises:
        ValueError: if ``indices`` is empty (the loop would otherwise spin
            forever without yielding).
    """
    if len(indices) == 0:
        raise ValueError("indices is empty; nothing to generate")
    planes = 18  # channels produced by parse_fen for a single position
    while True:
        # Walk over `indices` (not `X`) so that a subset of the data can be
        # passed without producing empty trailing batches.
        for start in range(0, len(indices), batch_size):
            batch_indices = indices[start: start + batch_size]
            if history > 0:
                batch_X = []
                for i in batch_indices:
                    board = chess.Board(X[i])
                    n_moves = board.fullmove_number
                    hash_id = pre_sum[game_nos[i]] + n_moves - 1
                    stack_input = np.empty((8, 8, planes*(history + 1)))
                    for j in range(history, -1, -1):
                        # Each position fills a full 18-plane slot; the
                        # previous 17-wide slices did not match parse_fen's
                        # output and left part of the buffer uninitialised.
                        slot = history - j
                        stack_input[:, :, planes*slot:planes*(slot + 1)] = parse_fen(all_fen[hash_id - j])
                    batch_X.append(stack_input)
            else:
                batch_X = [parse_fen(X[i]) for i in batch_indices]
            batch_Y = [parse_actions(Y[i], X[i])[0] for i in batch_indices]
            yield np.array(batch_X), np.array(batch_Y)

In [9]:
def value_data_generator(X, Y, indices, batch_size):
    """Endlessly yield ``(inputs, targets)`` batches for the value network.

    Args:
        X: FEN strings of the positions (the full, unsplit data).
        Y: numeric value labels aligned with ``X``.
        indices: order in which entries of ``X``/``Y`` are visited; one pass
            over ``indices`` is made per cycle.
        batch_size: number of samples per batch (the last batch of a pass may
            be smaller).

    Yields:
        tuple of ``numpy.ndarray``: encoded boards of shape ``(batch, 8, 8, 18)``
        and float labels of shape ``(batch,)``.

    Raises:
        ValueError: if ``indices`` is empty (the loop would otherwise spin
            forever without yielding).
    """
    if len(indices) == 0:
        raise ValueError("indices is empty; nothing to generate")
    while True:
        # Walk over `indices` (not `X`) so that a subset of the data can be
        # passed without producing empty trailing batches.
        for start in range(0, len(indices), batch_size):
            batch_indices = indices[start: start + batch_size]
            batch_X = [parse_fen(X[i]) for i in batch_indices]
            batch_Y = [Y[i] for i in batch_indices]
            yield np.array(batch_X), np.array(batch_Y).astype(float)

POLICY NETWORK

In [10]:
@tf.keras.utils.register_keras_serializable()
class ConvBlock(tf.keras.layers.Layer):
    """Conv2D ('same' padding) followed by batch normalisation and ReLU."""

    def __init__(self, filters, kernel_size, **kwargs):
        super(ConvBlock, self).__init__(**kwargs)
        # Kept so that get_config can report the constructor arguments.
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv = tf.keras.layers.Conv2D(filters, kernel_size, padding = 'same')
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()

    def call(self, inputs):
        x = self.conv(inputs)
        x = self.batch_norm(x)
        x = self.relu(x)
        return x

    def get_config(self):
        """Return the constructor arguments so a saved model can rebuild the layer."""
        config = super(ConvBlock, self).get_config()
        config.update({'filters': self.filters, 'kernel_size': self.kernel_size})
        return config

@tf.keras.utils.register_keras_serializable()
class ResBlock(tf.keras.layers.Layer):
    """Residual block: two Conv2D + batch-norm stages with a skip connection.

    The input is added to the output of the second batch-norm before the final
    ReLU, so the input must be shape-compatible with the block's output.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super(ResBlock, self).__init__(**kwargs)
        # Kept so that get_config can report the constructor arguments.
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv1 = tf.keras.layers.Conv2D(filters, kernel_size, padding = 'same')
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()
        self.conv2 = tf.keras.layers.Conv2D(filters, kernel_size, padding = 'same')
        self.batch_norm2 = tf.keras.layers.BatchNormalization()

    def call(self, inputs):
        out = self.conv1(inputs)
        out = self.batch_norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.batch_norm2(out)
        return self.relu(out + inputs)

    def get_config(self):
        """Return the constructor arguments so a saved model can rebuild the layer."""
        config = super(ResBlock, self).get_config()
        config.update({'filters': self.filters, 'kernel_size': self.kernel_size})
        return config
    
@tf.keras.utils.register_keras_serializable()
class PolicyHead(tf.keras.layers.Layer):
    """Policy head: Conv2D + batch-norm + ReLU, flattened into a softmax layer.

    The output has ``num_move_planes * 8 * 8`` entries, i.e. one probability per
    (square, move plane) combination.
    """

    def __init__(self, filters, kernel_size, num_move_planes, **kwargs):
        super(PolicyHead, self).__init__(**kwargs)
        # Kept so that get_config can report the constructor arguments.
        self.filters = filters
        self.kernel_size = kernel_size
        self.num_move_planes = num_move_planes
        self.conv = tf.keras.layers.Conv2D(filters, kernel_size, padding = 'same')
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(num_move_planes * 8 * 8, activation = 'softmax')

    def call(self, inputs):
        x = self.conv(inputs)
        x = self.batch_norm(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.dense(x)
        return x

    def get_config(self):
        """Return the constructor arguments so a saved model can rebuild the layer."""
        config = super(PolicyHead, self).get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'num_move_planes': self.num_move_planes,
        })
        return config

@tf.keras.utils.register_keras_serializable()
class ValueHead(tf.keras.layers.Layer):
    """Value head: Conv2D + batch-norm + ReLU, then Dense(64, relu) and Dense(1, tanh).

    The tanh output is a single scalar per sample in the range [-1, 1].
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super(ValueHead, self).__init__(**kwargs)
        # Kept so that get_config can report the constructor arguments.
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv = tf.keras.layers.Conv2D(filters, kernel_size, padding = 'same')
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(64, activation = 'relu')
        self.dense2 = tf.keras.layers.Dense(1, activation = 'tanh')

    def call(self, inputs):
        x = self.conv(inputs)
        x = self.batch_norm(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x

    def get_config(self):
        """Return the constructor arguments so a saved model can rebuild the layer."""
        config = super(ValueHead, self).get_config()
        config.update({'filters': self.filters, 'kernel_size': self.kernel_size})
        return config

In [11]:
# Policy network: a stem ConvBlock, a tower of 10 identical residual blocks and
# a softmax policy head over the 8*8*73 move encoding.
history = 0
policy_model = tf.keras.Sequential(
    [ConvBlock(filters = 256, kernel_size = (3, 3), input_shape = (8, 8, 18*(history + 1)))]
    + [ResBlock(filters = 256, kernel_size = (3, 3)) for _ in range(10)]
    + [PolicyHead(filters = 128, kernel_size = (1, 1), num_move_planes = 73)]
)

# Targets are one-hot move encodings, hence categorical cross-entropy.
policy_model.compile(
    optimizer = 'adam',
    loss = 'categorical_crossentropy',
    metrics = ['accuracy'],
)

In [12]:
# Value network: the same stem and 10-block residual tower as the policy
# network, finished with a scalar tanh value head.
history = 0
value_model = tf.keras.Sequential(
    [ConvBlock(filters = 256, kernel_size = (3, 3), input_shape = (8, 8, 18*(history + 1)))]
    + [ResBlock(filters = 256, kernel_size = (3, 3)) for _ in range(10)]
    + [ValueHead(filters = 1, kernel_size = (1, 1))]
)

# Scalar regression target, hence mean squared error.
value_model.compile(optimizer = 'adam', loss = 'mse')

CALLBACK FUNCTION FOR SAVING POLICY MODEL

In [13]:
# Checkpoint the complete policy model (not just the weights) every 10000
# training batches, overwriting the same file each time.
policy_checkpoint_filepath = 'policy_model_res11.h5'
policy_model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    policy_checkpoint_filepath,
    save_weights_only = False,
    save_freq = 10000,
    verbose = 1,
)

CALLBACK FUNCTION FOR SAVING VALUE MODEL

In [14]:
# Checkpoint the complete value model (not just the weights) every 1000
# training batches, overwriting the same file each time.
value_checkpoint_filepath = 'value_model_res12.h5'
value_model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    value_checkpoint_filepath,
    save_weights_only = False,
    save_freq = 1000,
    verbose = 1,
)

In [15]:
# Sanity check: number of value labels loaded from values.txt.
len(values)

193131

TRAINING THE VALUE NETWORK

In [16]:
# Train the value network for a single epoch over the shuffled value data.
batch_size = 32
value_train_data_generator = value_data_generator(value_fen, values, value_ind_train, batch_size)
# Floor division: only complete batches are counted for the epoch.
steps_per_epoch = len(value_ind_train) // batch_size
value_model.fit(
    value_train_data_generator,
    epochs = 1,
    steps_per_epoch = steps_per_epoch,
    callbacks = [value_model_checkpoint_callback],
)
# The generator loops forever, so close it explicitly once training is done.
value_train_data_generator.close()

 283/6035 [>.............................] - ETA: 3:05:13 - loss: 0.6828

KeyboardInterrupt: 

In [None]:
# Save the full value model to the same .h5 path the checkpoint callback uses.
value_model.save('value_model_res12.h5')

In [None]:
# WARNING: this cell powers off the machine one second after it runs (Windows
# `shutdown /s /t 1`). Skip it when using "Run All", otherwise the cells below
# (policy training and inference) never execute.
import os
os.system("shutdown /s /t 1") 

TRAINING THE POLICY NETWORK

In [None]:
# Train the policy network for a single epoch on single-position inputs
# (history = 0, matching the 18-channel input the model was built with).
batch_size = 32
train_data_generator = policy_data_generator(fen, action, ind_train, game_nos, pre_sum, all_fen, batch_size, history = 0)
# Floor division: only complete batches are counted for the epoch.
steps_per_epoch = len(ind_train) // batch_size
policy_model.fit(
    train_data_generator,
    epochs = 1,
    steps_per_epoch = steps_per_epoch,
    callbacks = [policy_model_checkpoint_callback],
)
# The generator loops forever, so close it explicitly once training is done.
train_data_generator.close()

In [None]:
# Play the notification sound (placed right after the training cell,
# presumably to signal that training has finished).
beep()

In [None]:
def play_best_move_hist(model, fen, history = 3):
    """Run a policy model on one position and decode its move distribution.

    The encoded position is repeated ``history + 1`` times along the channel
    axis to fill the model's stacked input, i.e. the current position stands
    in for every history slot.

    Args:
        model: object with a Keras-style ``predict`` whose output for one
            sample can be reshaped to ``(8, 8, 73)``.
        fen: FEN string of the position to evaluate.
        history: number of history slots the model expects.

    Returns:
        tuple ``(out_policy, decoded)`` where ``out_policy`` is the raw
        ``(8, 8, 73)`` policy and ``decoded`` is the result of
        ``decode_action(out_policy, fen)``.
    """
    stack_fen = np.tile(parse_fen(fen), (1, 1, history + 1))
    batch = np.array([stack_fen])  # add the leading batch dimension
    prediction = model.predict(batch)
    out_policy = prediction[0].reshape((8, 8, 73))
    decoded = decode_action(out_policy, fen)
    return out_policy, decoded

In [None]:
# Sample position (white to move) for trying out the policy network; the bare
# `b` on the last line displays the board as the cell output.
samp_fen = 'rn2kb1r/pp3ppp/4p1qn/1p4B1/2B5/3P2QP/PPP2PP1/R3K2R w - - 1 0'
b = chess.Board(samp_fen)
b

In [None]:
# Query the policy network for the sample position. The notebook defines no
# variable called `model`; the (8, 8, 73) output decoded here belongs to
# `policy_model`.
out, p = play_best_move_hist(policy_model, samp_fen, history = 0)
# p is (policy_dict, best_move); show the ranked move probabilities.
p[0]