In [None]:
!pip install chess zstandard
import numpy as np
import io
import zstandard as zstd
import time
import json
import pickle
import matplotlib.pyplot as plt

# Mount Google Drive at /content/drive; the later cells read puzzle CSVs from and write pickled
# chunks to /content/drive/MyDrive/parrot/puzzle_database/.
# NOTE(review): force_remount=True presumably re-mounts even when Drive is already mounted, so
# re-running this cell is safe -- confirm against the Colab documentation.
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

Collecting chess
  Downloading chess-1.11.2.tar.gz (6.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... [?25l[?25hdone
  Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=a9940fa1ac564debc21d0ec0dcc45d4cc0cad89587c7f9dfbf0cb077738e3ae2
  Stored in directory: /root/.cache/pip/wheels/fb/5d/5c/59a62d8a695285e59ec9c1f66add6f8a9ac4152499a2be0113
Successfully built chess
Installing collected packages: chess
Successfully installed chess-1.11.2


In [None]:
# Chess position reading and conversion to tensor
import chess
import re
import time
import math


# Integer code for each piece symbol (lowercase = black, uppercase = white).
# fast_board_to_boardmap divides these codes by 12. Code 6 is deliberately absent:
# 6 / 12 == 0.5 is the value that function assigns to empty squares.
chr_to_num = {"k": 0, "q": 1, "r": 2, "b": 3, "n": 4, "p": 5, "P": 7, "N": 8, "B": 9, "R": 10, "Q": 11, "K": 12}

def square_to_int(sq):
    """Convert an algebraic square name such as "e4" into a file-major index.

    The result is ``file * 8 + rank`` with file "a" -> 0 and rank "1" -> 0, so
    "a1" -> 0, "a8" -> 7, "b1" -> 8 and "h8" -> 63. Only the first two
    characters of ``sq`` are read.
    """
    file_letter, rank_digit = sq[0], sq[1]
    file_idx = ord(file_letter) - ord("a")
    rank_idx = int(rank_digit) - 1
    return 8 * file_idx + rank_idx

def squareint_to_square(sqint):
    """Split a square index into the pair ``(sqint // 8, sqint % 8)``.

    fast_board_to_boardmap uses the first element as the row and the second
    as the column of its 8x8 grid.
    """
    return divmod(sqint, 8)

def int_to_bin(anint, pad=4):
    """Return the binary digits of ``anint`` as a list of ints, most significant first.

    The digits are left-padded with zeros up to ``pad`` entries. A value that
    needs more than ``pad`` bits is not truncated; the list is simply longer.
    """
    digits = bin(anint)[2:]
    return [int(ch) for ch in digits.rjust(pad, "0")]

def fast_board_to_boardmap(board):
    """Encode the piece placement of ``board`` as a single 8x8 grid of floats.

    Every cell starts at 0.5 (empty square). Each occupied square is set to
    ``round(chr_to_num[symbol] / 12, 4)`` for the piece standing on it. A
    square index is mapped to ``grid[square // 8][square % 8]`` via
    squareint_to_square.

    Returns a one-element list ``[grid]`` where ``grid`` is a list of 8 rows
    of 8 floats.

    Slower than piece_map() when there are fewer pieces on the board, but
    faster (~2x) in most cases.
    """
    # (piece type, colour, key into chr_to_num), listed in the order the squares are filled.
    piece_table = (
        (chess.PAWN, chess.WHITE, "P"),
        (chess.PAWN, chess.BLACK, "p"),
        (chess.KNIGHT, chess.WHITE, "N"),
        (chess.KNIGHT, chess.BLACK, "n"),
        (chess.BISHOP, chess.WHITE, "B"),
        (chess.BISHOP, chess.BLACK, "b"),
        (chess.ROOK, chess.WHITE, "R"),
        (chess.ROOK, chess.BLACK, "r"),
        (chess.QUEEN, chess.WHITE, "Q"),
        (chess.QUEEN, chess.BLACK, "q"),
        (chess.KING, chess.WHITE, "K"),
        (chess.KING, chess.BLACK, "k"),
    )
    grid = [[0.5] * 8 for _ in range(8)]
    for piece_type, colour, symbol in piece_table:
        cell_value = round(chr_to_num[symbol] / 12, 4)
        for square in board.pieces(piece_type, colour):
            row, col = squareint_to_square(square)
            grid[row][col] = cell_value
    return [grid]

def fast_board_to_feature(board):
    """Build the 12-entry list of non-placement features for ``board``.

    Layout of the returned list of ints:
      [0]      side to move, ``int(board.turn)``
      [1]      1 if an en passant square is set and a legal en passant capture exists, else 0
      [2:8]    6-bit encoding of the en passant square, all zeros when [1] is 0
      [8:12]   castling rights: white kingside, white queenside,
               black kingside, black queenside

    The en passant square is re-indexed as ``file * 8 + rank`` (the same
    convention as square_to_int) before being converted to bits.
    """
    side_to_move = [int(board.turn)]
    ep_square = board.ep_square
    if ep_square is not None and board.has_legal_en_passant():
        ep_rank, ep_file = divmod(ep_square, 8)
        ep_flag = [1]
        ep_bits = int_to_bin(ep_file * 8 + ep_rank, pad=6)
    else:
        ep_flag = [0]
        ep_bits = int_to_bin(0, pad=6)
    castling = [
        int(board.has_kingside_castling_rights(chess.WHITE)),
        int(board.has_queenside_castling_rights(chess.WHITE)),
        int(board.has_kingside_castling_rights(chess.BLACK)),
        int(board.has_queenside_castling_rights(chess.BLACK)),
    ]
    return side_to_move + ep_flag + ep_bits + castling

def cp_to_win_prob(cp):
    """Map a centipawn score to a value in the open interval (0.1, 0.9).

    Computes ``0.4 * tanh(cp / 400) + 0.5``. This is algebraically the same as
    ``0.4 * (1 - exp(-cp / 200)) / (1 + exp(-cp / 200)) + 0.5``, but the exp
    form raises OverflowError once ``-cp / 200`` exceeds about 709 (cp below
    roughly -141,900). tanh saturates instead, so extreme scores approach
    0.1 / 0.9.

    A score of 0 maps to exactly 0.5.
    """
    return 0.4 * math.tanh(cp / 400) + 0.5

def mate_to_win_prob(mate):
    """Map a signed mate distance to a value in [0, 0.1] or [0.9, 1].

    Positive ``mate`` (the evaluated side mates in ``mate`` moves) gives
    ``0.9 + (21 - mate) / 200``, floored at 0.9: mate in 1 -> 1.0, mate in 21
    or more -> 0.9. Negative ``mate`` mirrors this: mated in 1 -> 0.0, mated
    in 21 or more -> 0.1.

    ``mate == 0`` carries no sign, so it falls into the non-negative branch.
    The raw formula would give 1.005 there, which is not a valid probability,
    so the result is capped at 1.0. Callers that can tell who delivered the
    mate should handle 0 themselves.
    """
    if mate < 0:
        return max(0.0, min(0.1, 0.1 + (abs(mate) - 21)/200))
    else:
        return min(1.0, max(0.9, 0.9 + (21 - mate)/200))


In [None]:
# Download the x86-64 (sse41-popcnt) Stockfish build from GitHub's "latest" release URL and unpack it.
# NOTE(review): "latest" is not pinned, so the engine version depends on when this cell runs
# (the captured output below shows a redirect to sf_17.1). Evaluations may differ between versions.
!wget https://github.com/official-stockfish/Stockfish/releases/latest/download/stockfish-ubuntu-x86-64-sse41-popcnt.tar
!tar -xvf stockfish-ubuntu-x86-64-sse41-popcnt.tar

--2025-04-15 14:25:16--  https://github.com/official-stockfish/Stockfish/releases/latest/download/stockfish-ubuntu-x86-64-sse41-popcnt.tar
Resolving github.com (github.com)... 140.82.114.4
Connecting to github.com (github.com)|140.82.114.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/official-stockfish/Stockfish/releases/download/sf_17.1/stockfish-ubuntu-x86-64-sse41-popcnt.tar [following]
--2025-04-15 14:25:16--  https://github.com/official-stockfish/Stockfish/releases/download/sf_17.1/stockfish-ubuntu-x86-64-sse41-popcnt.tar
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/20976138/1f1e24b8-d3b0-49bf-92fd-ebd01964592a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250415%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250415T142516Z&X-Amz-Expires=300&X-Amz-Signature=671f5d

In [None]:
import chess.engine

In [None]:
# Start the unpacked Stockfish binary as a UCI engine process; the dataset-generation cell
# calls engine.analyse() on this object.
# NOTE(review): no engine.quit() appears in the visible cells -- confirm the process is shut
# down once generation is finished.
engine = chess.engine.SimpleEngine.popen_uci("/content/stockfish/stockfish-ubuntu-x86-64-sse41-popcnt")

In [None]:
import time
import os

# --- Run state --------------------------------------------------------------
# These stay module-level on purpose: ei / mi / hi and the *_list buffers can be
# inspected after an interrupted run to decide where to resume.
l = 0            # number of CSV lines processed so far (progress counter)
PERSON = 4       # selects the range of puzzle CSV files handled by this run
EASY = 1
MID = 2
HARD = 3
# Each buffer is [boardmaps, feature vectors, target values]; the three lists grow in lockstep.
hard_list = [[], [], []]
mid_list = [[], [], []]
easy_list = [[], [], []]
# Index used in the file name of the next chunk written for each difficulty.
# NOTE(review): these start above zero, presumably to resume an earlier run -- confirm before re-running.
mi = 35
hi = 14
ei = 19
s = time.time()

CHUNK_SIZE = 4096  # samples per output file


def _evaluate_for_white(board, seconds):
    """Analyse ``board`` for ``seconds`` and return the target value from White's point of view.

    The engine score is taken relative to the side to move and multiplied by
    +1 (White to move) or -1 (Black to move) before being mapped through
    cp_to_win_prob / mate_to_win_prob. A mate distance of 0 yields 1 when
    Black is to move and 0 when White is to move.
    """
    info = engine.analyse(board, chess.engine.Limit(time=seconds))["score"]
    score = info.relative
    color = 1 if info.turn else -1
    # Ask the score what kind it is instead of catching every exception around
    # `score.cp`: a bare except would also hide unrelated failures.
    if not score.is_mate():
        return cp_to_win_prob(score.score() * color)
    moves = score.mate()
    if moves == 0:
        # The sign is lost at distance 0, so decide from whose turn it is.
        return int(not info.turn)
    return mate_to_win_prob(moves * color)


def _dump_samples(samples, label, index):
    """Pickle one full buffer to Drive, closing the file, and print the elapsed time."""
    with open(f"/content/drive/MyDrive/parrot/puzzle_database/data_{PERSON}_{label}_{index}.chess", "wb") as out:
        pickle.dump(samples, out)
    print(label, index, time.time() - s)


def _store_sample(flag, board, value):
    """Append one (boardmap, features, value) sample to the buffer for ``flag``.

    When that buffer reaches CHUNK_SIZE samples it is written out, replaced by
    an empty buffer, and its file index is advanced.
    """
    global easy_list, mid_list, hard_list, ei, mi, hi
    if flag == EASY:
        samples = easy_list
    elif flag == MID:
        samples = mid_list
    elif flag == HARD:
        samples = hard_list
    else:
        raise ValueError(f"unknown difficulty flag: {flag!r}")

    samples[0].append(fast_board_to_boardmap(board))
    samples[1].append(fast_board_to_feature(board))
    samples[2].append(value)
    if len(samples[2]) < CHUNK_SIZE:
        return

    if flag == EASY:
        _dump_samples(easy_list, "easy", ei)
        easy_list = [[], [], []]
        ei += 1
    elif flag == MID:
        _dump_samples(mid_list, "mid", mi)
        mid_list = [[], [], []]
        mi += 1
    else:
        _dump_samples(hard_list, "hard", hi)
        hard_list = [[], [], []]
        hi += 1


# NOTE(review): samples still sitting in the buffers when the loop ends are not written anywhere.
for i in range(186 * PERSON + 30, 186 * (PERSON + 1)):
    with open(f"/content/drive/MyDrive/parrot/puzzle_database/puzzles_{i}.csv", "r") as file:
        data = file.readlines()
    print("Opened dataset ", i)
    for line in data:
        l += 1
        if l % 250 == 0:
            print(l)
        # Columns used: [1] FEN, [2] space-separated UCI moves, [3] integer rating.
        lst = line.split(",")
        fen = lst[1]
        mvlst = lst[2].split(" ")
        rating = int(lst[3])
        if rating < 1200:
            flag = EASY
        elif rating > 2100:
            flag = HARD
        else:
            flag = MID
        board = chess.Board(fen)
        # Analyse the puzzle's starting position with stockfish for 0.1 seconds ...
        _store_sample(flag, board, _evaluate_for_white(board, 0.1))
        # ... then every position reached along the solution line for 0.08 seconds each.
        for mvstr in mvlst:
            board.push(chess.Move.from_uci(mvstr))
            _store_sample(flag, board, _evaluate_for_white(board, 0.08))

Opened dataset  774
250
500
750
1000
1250
mid 35 607.9909334182739
1500
1750
2000
2250
2500
easy 19 1142.1784763336182
2750
mid 36 1190.7453954219818
3000
3250
hard 14 1472.1567256450653
3500
3750
4000
4250
mid 37 1824.7540807724
4500
4750
5000
Opened dataset  775
easy 20 2199.6508333683014
5250
5500
mid 38 2434.2299320697784
5750
6000
6250
6500
6750
hard 15 2951.870923280716
7000
mid 39 3058.8736221790314
7250
7500
easy 21 3276.299509048462
7750
8000
8250
8500
mid 40 3666.2021288871765
8750
9000
9250
9500
9750
10000
Opened dataset  776
mid 41 4263.6901977062225
easy 22 4341.537752389908
10250
hard 16 4406.585518121719
10500
10750
11000
11250
mid 42 4858.741576433182
11500
11750
12000
12250
12500
12750
easy 23 5425.012931585312
mid 43 5482.020093679428
13000
13250
13500
hard 17 5810.802311658859
13750
14000
14250
mid 44 6089.324741840363
14500
14750
15000
Opened dataset  777
15250
easy 24 6511.616156578064
15500
15750
mid 45 6699.945583343506
16000
16250
16500
16750
17000
hard 18 7261.

In [None]:
import matplotlib.pyplot as plt
import scipy.stats

def read_puzzle_data(person, difficulty, num):
  """Load one pickled chunk ``data_{person}_{difficulty}_{num}.chess`` from Drive.

  Returns whatever object was pickled into that file; the generation cell
  writes a three-element list [boardmaps, features, values].
  """
  path = f"/content/drive/MyDrive/parrot/puzzle_database/data_{person}_{difficulty}_{num}.chess"
  # SECURITY: pickle.load can execute arbitrary code -- only use on files this notebook produced.
  # The with-block closes the handle; the previous open(...) call never did.
  with open(path, "rb") as fh:
    data_list = pickle.load(fh)
  return data_list

In [None]:
# Shuffling data

import random
import time

P = 0  # difficulty to shuffle in this run: 0 = easy, 1 = mid, 2 = hard
# Number of chunk files available per person (index 0-4) for each difficulty.
EASY_DISTRIBUTION = [51, 37, 42, 41, 42]
MID_DISTRIBUTION = [92, 66, 76, 74, 76]
HARD_DISTRIBUTION = [36, 27, 30, 29, 30]

DISTROS = [EASY_DISTRIBUTION, MID_DISTRIBUTION, HARD_DISTRIBUTION]

# Number of shuffle rounds.
if P in [0, 2]:
    num = 1000
else:
    num = 1600

# Fail loudly on an unexpected P instead of leaving `st` undefined (or stale from an earlier run).
if P not in (0, 1, 2):
    raise ValueError(f"P must be 0, 1 or 2, got {P!r}")
st = ("easy", "mid", "hard")[P]

SHUFFLE_CHUNK = 4096  # samples per chunk file

# Each round picks one random chunk per person, pools and shuffles their samples, then
# writes the shuffled slices back over the same five files.
# NOTE(review): the files are overwritten in place, so a round interrupted between writes can
# leave samples duplicated or missing. `random` is not seeded, so runs are not reproducible.
start = time.time()
for i in range(num):
    # Draw the chunk indices in person order so the random stream is consumed as before.
    picks = [random.randint(0, DISTROS[P][person] - 1) for person in range(len(DISTROS[P]))]

    bl, fl, el = [], [], []
    for person, pick in enumerate(picks):
        boards_part, features_part, values_part = read_puzzle_data(person, st, pick)
        bl.extend(boards_part)
        fl.extend(features_part)
        el.extend(values_part)

    zipped = list(zip(bl, fl, el))
    random.shuffle(zipped)
    bl, fl, el = zip(*zipped)

    for person, pick in enumerate(picks):
        begin = person * SHUFFLE_CHUNK
        end = begin + SHUFFLE_CHUNK
        # The with-block closes each output file; the previous open(...) calls never did.
        with open(f"/content/drive/MyDrive/parrot/puzzle_database/data_{person}_{st}_{pick}.chess", "wb") as out:
            pickle.dump([bl[begin:end], fl[begin:end], el[begin:end]], out)

    print(i, time.time() - start)

0 9.041930675506592
1 14.648443937301636
2 22.721497058868408
3 29.45970058441162
4 37.96062183380127
5 45.532405614852905
6 53.19490456581116
7 61.557371854782104
8 69.41981434822083
9 73.88979649543762
10 80.04471182823181
11 88.11757469177246
12 96.61164665222168
13 101.2178738117218
14 107.2203528881073
15 111.44609880447388
16 117.4992024898529
17 125.21446967124939
18 127.6591112613678
19 136.79325246810913
20 144.82150411605835
21 148.65180921554565
22 154.38727498054504
23 158.71089959144592
24 164.69064044952393
25 170.7264142036438
26 177.10282397270203
27 183.66146850585938
28 191.9273567199707
29 201.607323884964
30 206.38585758209229
31 212.0861358642578
32 219.446439743042
33 222.04428482055664
34 228.68093132972717
35 235.21382641792297
36 239.67642307281494
37 244.4661614894867
38 249.93140316009521
39 254.9374074935913
40 261.05940413475037
41 263.44419980049133
42 266.45428347587585
43 267.5431983470917
44 273.7078626155853
45 276.8602650165558
46 277.9447522163391
47