In [None]:
# Mount Google Drive at /content/drive so the notebook can read and write the
# dataset files under /content/drive/MyDrive (Colab-only dependency).
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [None]:
!pip install chess zstandard

Collecting chess
  Downloading chess-1.11.1.tar.gz (156 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/156.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m156.5/156.5 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... [?25l[?25hdone
  Created wheel for chess: filename=chess-1.11.1-py3-none-any.whl size=148497 sha256=6fc159e574561e368fb2e419e10c0863d362d5b3ab7970fec6c06ebdb80334ae
  Stored in directory: /root/.cache/pip/wheels/f0/3f/76/8783033e8524d407e1bebaf72fdd3f3eba27e0c030e92bbd87
Successfully built chess
Installing collected packages: chess
Successfully installed chess-1.11.1


In [None]:

# Chess position reading and conversion to tensor
import chess
import re
import time
import math


# Integer code of every piece symbol (lowercase = black, uppercase = white).
# Code 6 is intentionally unused: fast_board_to_boardmap divides the codes by 12
# and fills empty squares with 0.5, i.e. 6 / 12.
chr_to_num = {"k": 0, "q": 1, "r": 2, "b": 3, "n": 4, "p": 5, "P": 7, "N": 8, "B": 9, "R": 10, "Q": 11, "K": 12}

def square_to_int(sq):
    """Convert an algebraic square name such as "e4" into an integer index.

    The file letter contributes ``(letter - 'a') * 8`` and the rank digit
    contributes ``rank - 1``, so "a1" -> 0, "a2" -> 1, "b1" -> 8, "h8" -> 63.
    """
    file_part = (ord(sq[0]) - ord("a")) * 8
    rank_part = int(sq[1]) - 1
    return file_part + rank_part

def squareint_to_square(sqint):
    """Split a square index into the pair ``(sqint // 8, sqint % 8)``."""
    major, minor = divmod(sqint, 8)
    return (major, minor)

def int_to_bin(anint, pad=4):
    """Return the binary digits of ``anint`` as a list of ints, most significant first.

    The digit string is left-padded with zeros up to ``pad`` digits; a value that
    needs more than ``pad`` digits is returned in full, without truncation.
    """
    digits = bin(anint)[2:]
    padded = digits.rjust(pad, "0")
    return [int(bit) for bit in padded]

def fast_board_to_boardmap(board):
    """Encode the piece placement of ``board`` as one 8x8 plane wrapped in a list.

    Every square starts at 0.5 (empty). A square holding a piece is set to
    ``round(chr_to_num[symbol] / 12, 4)``, where ``symbol`` is the uppercase
    letter for white pieces and the lowercase letter for black pieces. The
    plane is indexed as ``plane[square // 8][square % 8]``.

    Returns:
        ``[plane]`` - a list holding the single 8x8 list-of-lists plane.
    """
    # Slower than piece_map() when there are less pieces on the board, but faster (~2x) in most cases.
    # (piece type, white symbol, black symbol), visited in this order.
    piece_symbols = (
        (chess.PAWN, "P", "p"),
        (chess.KNIGHT, "N", "n"),
        (chess.BISHOP, "B", "b"),
        (chess.ROOK, "R", "r"),
        (chess.QUEEN, "Q", "q"),
        (chess.KING, "K", "k"),
    )
    plane = [[0.5] * 8 for _ in range(8)]
    for piece_type, white_symbol, black_symbol in piece_symbols:
        for color, symbol in ((chess.WHITE, white_symbol), (chess.BLACK, black_symbol)):
            value = round((chr_to_num[symbol]) / 12, 4)
            for square in board.pieces(piece_type, color):
                row, col = squareint_to_square(square)
                plane[row][col] = value
    return [plane]

def fast_board_to_feature(board):
    """Build the 12-element list of non-placement features for ``board``.

    Layout of the returned list:
        [0]      side to move, ``int(board.turn)``
        [1]      1 if an en-passant square is set and a legal en-passant capture exists, else 0
        [2:8]    6-bit en-passant square (all zeros when the flag above is 0)
        [8:12]   castling rights: white kingside, white queenside,
                 black kingside, black queenside
    """
    side_to_move = [int(board.turn)]

    ep_square = board.ep_square
    if ep_square is None or not board.has_legal_en_passant():
        ep_flag = [0]
        ep_bits = int_to_bin(0, pad=6)
    else:
        # Re-index the square as (square % 8) * 8 + (square // 8) before encoding it.
        high, low = divmod(ep_square, 8)
        ep_flag = [1]
        ep_bits = int_to_bin(low * 8 + high, pad=6)

    castling = []
    for color in (chess.WHITE, chess.BLACK):
        castling.append(int(board.has_kingside_castling_rights(color)))
        castling.append(int(board.has_queenside_castling_rights(color)))

    return side_to_move + ep_flag + ep_bits + castling

def cp_to_win_prob(cp):
    return 0.4 * (1 - math.exp(-cp / 200)) / (1 + math.exp(-cp / 200)) + 0.5

def mate_to_win_prob(mate):
    """Map a mate-in-N value to a score.

    Negative ``mate`` gives a score of at most 0.1: it is 0.0 for ``mate == -1``,
    rises by 1/200 per extra move and is capped at 0.1 from ``mate == -21`` on.
    Non-negative ``mate`` mirrors this: 1.0 for ``mate == 1``, falling by 1/200
    per extra move and floored at 0.9 from ``mate == 21`` on.
    """
    if not mate < 0:
        return max(0.9, 0.9 + (21 - mate)/200)
    return min(0.1, 0.1 + (abs(mate) - 21)/200)


In [None]:
# Lichess database decompression

import numpy as np
import io
import zstandard as zstd
import time
import json
import pickle
import matplotlib.pyplot as plt

# Stream the compressed Lichess evaluation export, classify every position by
# its first principal variation, and write class-balanced chunks of samples.
# Each kept record yields two samples: the position itself and the position
# after the first move of the principal variation.

# Classes a position can fall into; SKIP marks a class whose quota is already met.
WINNING = 0
ADVANTAGE = 1
DRAW = 2
SKIP = -1

# Samples per pickled chunk, and number of chunks to write for each class.
CHUNK_SIZE = 65536
MAX_CHUNKS = {WINNING: 112, ADVANTAGE: 250, DRAW: 150}
# File-name prefix and progress label of each class.
CHUNK_PREFIX = {WINNING: "winning", ADVANTAGE: "adv", DRAW: "draw"}
CHUNK_LABEL = {WINNING: "win list,", ADVANTAGE: "adv list,", DRAW: "draw list,"}
OUTPUT_DIR = "/content/drive/MyDrive/parrot/evaluation_database"

your_filename = "/content/drive/MyDrive/parrot/lichess_db_eval.jsonl.zst"
dctx = zstd.ZstdDecompressor()
l = 0  # sample counter: +2 for every record, kept or skipped
num_datasets = 0
num_sub_datasets = 0
# Per-class buffers of [board planes, feature vectors, scores] and chunk counters.
buffers = {WINNING: [[], [], []], ADVANTAGE: [[], [], []], DRAW: [[], [], []]}
chunks_written = {WINNING: 0, ADVANTAGE: 0, DRAW: 0}
st = time.time()
i = 0
flag = None


def _classify(pv):
    """Return ``(score, mate, category)`` for one principal variation dict.

    ``mate`` is None for centipawn evaluations. The presence of the "cp" key is
    checked explicitly, so errors raised while scoring are no longer mistaken
    for "this is a mate evaluation".
    """
    if "cp" in pv:
        cp_score = cp_to_win_prob(pv["cp"])
        if 0.4 < cp_score < 0.6:
            # Draw: -1 ~< evaluation ~< 1
            return cp_score, None, DRAW
        if 0.2 < cp_score < 0.8:
            # Advantage: -4 ~< evaluation ~< 4
            return cp_score, None, ADVANTAGE
        # Completely winning position
        return cp_score, None, WINNING
    mate_in = pv["mate"]
    return mate_to_win_prob(mate_in), mate_in, WINNING


def _store_sample(category, position, position_score):
    """Append one sample to the class buffer and write the buffer out once it is full."""
    buf = buffers[category]
    buf[0].append(fast_board_to_boardmap(position))
    buf[1].append(fast_board_to_feature(position))
    buf[2].append(position_score)
    if len(buf[2]) == CHUNK_SIZE:
        chunk_idx = chunks_written[category]
        # Close the chunk file deterministically instead of relying on garbage collection.
        with open(f"{OUTPUT_DIR}/{CHUNK_PREFIX[category]}_data_{chunk_idx}.chess", "wb") as chunk_file:
            pickle.dump(buf, chunk_file)
        print(CHUNK_LABEL[category], chunk_idx)
        buffers[category] = [[], [], []]
        chunks_written[category] = chunk_idx + 1


with open(your_filename, 'rb') as compressed:
    with dctx.stream_reader(compressed) as reader:
        text_stream = io.TextIOWrapper(reader, encoding='utf-8')
        for raw_line in text_stream:
            record = json.loads(raw_line)
            pv = record["evals"][0]["pvs"][0]
            score, mate, flag = _classify(pv)
            if chunks_written[flag] >= MAX_CHUNKS[flag]:
                flag = SKIP
            if flag == SKIP:
                l += 2
                continue

            fen = record["fen"]
            # Take the whole first move token (moves are space-separated in the
            # Lichess export) so 5-character promotions such as "e7e8q" are kept;
            # slicing to 4 characters pushed the pawn without promoting it.
            move = pv["line"].partition(" ")[0]
            board = chess.Board(fen)
            next_board = board.copy()
            next_board.push(chess.Move.from_uci(move))

            # Mate scores are nudged by 0.0025 for the position after the move;
            # centipawn scores are reused unchanged.
            if mate is not None:
                if mate < 0:
                    next_score = max(0, score - 0.0025)
                else:
                    next_score = min(1, score + 0.0025)
            else:
                next_score = score

            for sample_board, sample_score in ((board, score), (next_board, next_score)):
                _store_sample(flag, sample_board, sample_score)
                l += 1
                if l % 262144 == 262133:
                    print(l, time.time() - st)

            # Samples are added in pairs and CHUNK_SIZE is even, so every buffer is
            # empty once its quota is met; nothing is left to collect, stop reading.
            if all(chunks_written[c] >= MAX_CHUNKS[c] for c in MAX_CHUNKS):
                break

# Keep the original module-level names available for any later cell that reads them.
wlist, alist, dlist = buffers[WINNING], buffers[ADVANTAGE], buffers[DRAW]
wi, ai, di = chunks_written[WINNING], chunks_written[ADVANTAGE], chunks_written[DRAW]


draw list, 0
win list, 0
draw list, 1
262133 41.32717561721802
draw list, 2
win list, 1
adv list, 0
draw list, 3
524277 88.90053248405457
draw list, 4
win list, 2
draw list, 5
786421 133.9401478767395
win list, 3
draw list, 6
adv list, 1
draw list, 7
win list, 4
1048565 181.11942672729492
draw list, 8
draw list, 9
win list, 5
adv list, 2
1310709 228.16621470451355
draw list, 10
win list, 6
draw list, 11
1572853 272.3724091053009
draw list, 12
adv list, 3
win list, 7
draw list, 13
1834997 318.7781729698181
draw list, 14
win list, 8
draw list, 15
adv list, 4
2097141 366.37301087379456
draw list, 16
win list, 9
draw list, 17
draw list, 18
win list, 10
2359285 413.26839232444763
adv list, 5
draw list, 19
win list, 11
draw list, 20
2621429 456.8270924091339
draw list, 21
adv list, 6
win list, 12
draw list, 22
2883573 502.4304356575012
win list, 13
draw list, 23
adv list, 7
draw list, 24
3145717 546.6686346530914
win list, 14
draw list, 25
draw list, 26
3407861 591.8338649272919
adv list, 8


In [None]:
# Data shuffling

import pickle
import matplotlib.pyplot as plt
import random
import time

def read_data(phase, num):
    """Load one pickled chunk of the evaluation database.

    Args:
        phase: file-name prefix of the chunk family (the value used for
            ``{phase}`` in ``{phase}_data_{num}.chess``).
        num: index of the chunk within that family.

    Returns:
        The unpickled object stored in the chunk file.
    """
    chunk_path = f"/content/drive/MyDrive/parrot/evaluation_database/{phase}_data_{num}.chess"
    # NOTE: pickle.load can execute arbitrary code; only load chunks you produced yourself.
    # The context manager closes the handle instead of leaking it until garbage collection.
    with open(chunk_path, "rb") as chunk_file:
        data_list = pickle.load(chunk_file)
    return data_list

# Repeatedly pick four distinct chunk files of one family, pool and shuffle
# their samples, and write the shuffled samples back over the same four files.

l = 0  # number of shuffle rounds completed
s = time.time()
P = 2
# P -> (file-name prefix, number of shuffle rounds, highest chunk index).
# An unknown P now fails with KeyError instead of silently reusing stale values.
PHASES = {0: ("draw", 600, 149), 1: ("winning", 450, 111), 2: ("adv", 1000, 249)}
st, num, rmax = PHASES[P]

# Number of samples written back to each file; the slicing below assumes
# every chunk read holds exactly this many samples.
SHUFFLE_CHUNK_SIZE = 65536

while l < num:
    # random.sample never repeats an element, so the four indices are always distinct.
    r0, r1, r2, r3 = picks = random.sample(range(0, rmax + 1), 4)

    p, f, e = [], [], []
    for chunk_idx in picks:
        boards_part, features_part, scores_part = read_data(st, chunk_idx)
        p.extend(boards_part)
        f.extend(features_part)
        e.extend(scores_part)

    # Shuffle the three parallel lists together so samples stay aligned.
    zipped = list(zip(p, f, e))
    random.shuffle(zipped)
    p, f, e = zip(*zipped)

    for slot, chunk_idx in enumerate(picks):
        lo = slot * SHUFFLE_CHUNK_SIZE
        hi = lo + SHUFFLE_CHUNK_SIZE
        # Close each file as soon as it is written.
        with open(f"/content/drive/MyDrive/parrot/evaluation_database/{st}_data_{chunk_idx}.chess", "wb") as chunk_file:
            pickle.dump([p[lo:hi], f[lo:hi], e[lo:hi]], chunk_file)

    l += 1
    print(l, time.time() - s, r0, r1, r2, r3)

1 28.787046670913696 22 163 10 178
2 52.41506099700928 78 195 20 215
3 73.28612875938416 234 151 149 233
4 95.52377009391785 169 101 205 47
5 117.0937168598175 49 176 146 3
6 138.3934829235077 153 146 23 69
7 158.54436874389648 208 248 58 171
8 177.00164532661438 47 205 24 107
9 198.3479974269867 200 246 66 94
10 218.92748928070068 131 91 95 180
11 237.42337369918823 234 71 205 95
12 259.4865918159485 1 33 191 112
13 279.36227893829346 148 222 182 164
14 300.44947838783264 217 223 13 137
15 319.6152687072754 58 70 65 207
16 339.8308570384979 45 227 155 63
17 362.7792887687683 185 38 139 212
18 383.85117077827454 163 208 0 105
19 405.1830303668976 42 51 191 36
20 433.5633752346039 54 92 118 183
21 454.2021653652191 113 156 6 64
22 474.3661222457886 212 183 227 49
23 496.40686202049255 108 249 186 38
24 515.3590886592865 190 246 149 218
25 537.083794593811 182 140 215 50
26 558.223292350769 103 44 29 117
27 581.1365685462952 234 172 55 120
28 602.1902742385864 212 165 221 11
29 625.33138

In [None]:
# I had to convert the absolute evaluations into evaluations relative to the side to move (every sample is rewritten as if White were to move) to balance the data. Should have done this earlier!

import os
import sys
import pickle
import numpy as np

# Rewrite every file of the puzzle database in place: piece codes become signed
# values seen from the side to move, and scores of black-to-move samples are
# inverted. Files that already hold the two-list output format are skipped.
l = os.listdir("/content/drive/MyDrive/parrot/puzzle_database")
os.chdir("/content/drive/MyDrive/parrot/puzzle_database")

# Encoded square value -> signed value: 0.5 (empty) -> 0, values above 0.5 map
# to positive numbers up to 1.0 and values below 0.5 to negative numbers down
# to -1.0. The keys match the values written by fast_board_to_boardmap.
TRANSFORM = {0.5: 0, 1.0: 1.0, 0.9167: 0.8333, 0.8333: 0.6667, 0.75: 0.5, 0.6667: 0.3333, 0.5833: 0.1667, 0.0: -1.0, 0.0833: -0.8333, 0.1667: -0.6667, 0.25: -0.5, 0.3333: -0.3333, 0.4167: -0.1667}

for filename in l:
    try:
        # NOTE: pickle.load can execute arbitrary code; only run this on files you produced.
        with open(filename, "rb") as source_file:
            data_list = pickle.load(source_file)
    except Exception:
        # Unreadable or non-pickle entries are reported and skipped; Ctrl-C is
        # no longer swallowed as it was by the bare except.
        print(filename, "caused an error!")
        continue
    if len(data_list) == 2:
        print(filename, "has already been updated.")
        continue
    board_list, score_list = [], []
    for entry in range(4096):
        turn = data_list[1][entry][0]
        sgn = 1 if turn == 1 else -1
        board = data_list[0][entry]
        # Replace every square code by its signed value, negated when turn != 1.
        for row in range(8):
            for col in range(8):
                board[0][row][col] = TRANSFORM[board[0][row][col]] * sgn
        if turn == 1:
            score_list.append(data_list[2][entry])
        elif turn == 0:
            # np.rot90 with its default axes rotates in the plane of the first
            # two axes; assuming board is shaped [1][8][8], k=2 therefore
            # reverses the order of the 8 rows and leaves each row unchanged.
            board = np.rot90(board, 2).tolist()
            score_list.append(1 - data_list[2][entry])
        board_list.append(board)
    with open(filename, "wb") as target_file:
        pickle.dump([board_list, score_list], target_file)
    print(filename)



data_4_mid_21.chess
data_2_mid_11.chess
data_0_mid_81.chess
data_3_mid_56.chess
data_0_mid_28.chess
data_1_mid_47.chess
data_0_mid_46.chess
data_4_mid_49.chess
data_2_mid_16.chess
data_3_mid_9.chess
data_4_mid_44.chess
data_0_mid_12.chess
data_4_mid_42.chess
data_1_mid_6.chess
data_0_mid_52.chess
data_2_mid_69.chess
data_0_mid_56.chess
data_0_mid_68.chess
data_0_mid_64.chess
data_1_mid_9.chess
data_0_mid_54.chess
data_3_mid_37.chess
data_2_mid_55.chess
data_1_mid_25.chess
data_0_mid_42.chess
data_2_mid_46.chess
data_3_mid_30.chess
data_4_mid_39.chess
data_0_mid_19.chess
data_2_mid_23.chess
data_0_mid_36.chess
data_4_mid_12.chess
data_1_mid_35.chess
data_2_mid_63.chess
data_0_mid_67.chess
data_4_mid_5.chess
data_3_mid_67.chess
data_0_mid_79.chess
data_4_mid_70.chess
data_1_mid_23.chess
data_4_mid_67.chess
data_1_mid_63.chess
data_2_mid_66.chess
data_3_mid_59.chess
data_0_mid_53.chess
data_3_mid_44.chess
data_4_mid_33.chess
data_4_mid_51.chess
data_2_mid_70.chess
data_4_mid_50.chess
data