In [2]:
import chess
import chess.pgn
import os

In [3]:
def read_games_from_file(file_path):
    """Parse every game stored in a PGN file.

    Games are read one after another with ``chess.pgn.read_game`` until the
    parser returns ``None``. A read that raises ``ValueError`` is reported on
    stdout and skipped; reading then continues with the next game.

    Args:
        file_path: Path of the PGN file to read.

    Returns:
        List of the parsed game objects, in file order.
    """
    parsed = []
    with open(file_path, 'r') as handle:
        exhausted = False
        while not exhausted:
            try:
                current = chess.pgn.read_game(handle)
            except ValueError as e:
                print(f"Pominięto partię z powodu błędu: {e}")
                continue
            if current is None:
                # The parser signals end of input with None.
                exhausted = True
            else:
                parsed.append(current)
    return parsed

def is_checkmate(game):
    """Tell whether a game's main line ends in checkmate.

    The main-line moves are replayed on the game's starting board and the
    resulting position is then checked.

    Args:
        game: Object exposing ``board()`` and ``mainline_moves()``.

    Returns:
        Whatever ``is_checkmate()`` reports for the final position.
    """
    position = game.board()
    replay = game.mainline_moves()
    for ply in replay:
        position.push(ply)
    return position.is_checkmate()

def filter_checkmate_games(games, pgn_output_path):
    """Append the games that end in checkmate to a PGN file.

    Each game is tested with ``is_checkmate``. A game whose test raises is
    reported on stdout together with its index and left out. The kept games
    are written as ``str(game)`` followed by a blank line.

    Args:
        games: Iterable of game objects.
        pgn_output_path: File that receives the games; it is opened in
            append mode, so existing content is preserved.
    """
    mates = []
    for cnt, game in enumerate(games):
        try:
            ended_in_mate = is_checkmate(game)
        except Exception as e:
            print(f"Pominięto partię numer {cnt} z powodu błędu: {e}")
            continue
        if ended_in_mate:
            mates.append(game)

    with open(pgn_output_path, "a") as sink:
        sink.writelines(str(game) + "\n\n" for game in mates)


# Collect every game that ends in checkmate from the raw PGN files into one
# combined file.
data_path = "../data/raw/"
mates_name = "mates.pgn"
mates_path = os.path.join(data_path, mates_name)

# The combined file (and other derived files) are written into the same
# directory that is scanned here, so restrict the inputs to PGN files and
# leave out the output itself; otherwise a re-run would feed the output back
# in as an input. Sorting makes the order of games in the output repeatable.
files = sorted(
    name
    for name in os.listdir(data_path)
    if name.lower().endswith(".pgn") and name != mates_name
)

# filter_checkmate_games appends, so drop a stale result first to keep a
# re-run from duplicating games. Only do so when there is something to
# rebuild it from.
if files and os.path.exists(mates_path):
    os.remove(mates_path)

for file in files:
    games = read_games_from_file(os.path.join(data_path, file))
    filter_checkmate_games(games, mates_path)

In [4]:
# Replay every checkmate game and export one CSV-like row per position:
#   <placement>,<turn>,<castling>,<en passant>,<played move in UCI>,<legal moves>
# where <legal moves> is the repr of a list of [from_square, to_square] names.
data_path = "../data/raw/"
games = read_games_from_file(data_path+"mates.pgn")

invalid = 0
# Open the output once, in write mode: re-running the cell rebuilds the file
# instead of appending a second copy of every row, and the file is no longer
# reopened for each game.
with open(data_path+"fen_data64.txt", "w") as f:
    for cnt, game in enumerate(games):
        data = []
        # NOTE(review): the replay starts from the standard initial position,
        # not from game.board(); games that start elsewhere presumably end up
        # counted as invalid below — confirm this is intended.
        board = chess.Board()
        cor = True

        for move in game.mainline_moves():
            legal_moves = list(board.legal_moves)
            if move not in legal_moves:
                # One illegal move discards the whole game.
                cor = False
                invalid += 1
                break

            # Keep a single entry per from/to pair: of the promotion moves
            # only the queen promotion is listed.
            legals_str = [
                [chess.square_name(move_.from_square), chess.square_name(move_.to_square)]
                for move_ in legal_moves
                if move_.promotion is None or move_.promotion == chess.QUEEN
            ]
            # Drop the last two FEN fields (halfmove clock, fullmove number).
            fen_data = board.fen().split(" ")[:-2]
            fen_data.append(move.uci())
            fen_data.append(str(legals_str))
            data.append(",".join(fen_data)+"\n")
            board.push(move)

        if cor:
            f.writelines(data)
        if cnt % 1000 == 0:
            print(f"{cnt}/{len(games)}")
print("invalid: ", invalid)

0/51366
1000/51366
2000/51366
3000/51366
4000/51366
5000/51366
6000/51366
7000/51366
8000/51366
9000/51366
10000/51366
11000/51366
12000/51366
13000/51366
14000/51366
15000/51366
16000/51366
17000/51366
18000/51366
19000/51366
20000/51366
21000/51366
22000/51366
23000/51366
24000/51366
25000/51366
26000/51366
27000/51366
28000/51366
29000/51366
30000/51366
31000/51366
32000/51366
33000/51366
34000/51366
35000/51366
36000/51366
37000/51366
38000/51366
39000/51366
40000/51366
41000/51366
42000/51366
43000/51366
44000/51366
45000/51366
46000/51366
47000/51366
48000/51366
49000/51366
50000/51366
51000/51366
invalid:  244


In [5]:

# Scan the exported rows to (a) report positions that lack a king and
# (b) find the largest number of legal moves listed for any position.
data_path = "../data/raw/"
max_move_cnt = 0

# Iterate over the file lazily; the rows are only needed one at a time.
with open(data_path+"fen_data64.txt", "r") as f:
    for cnt, fen in enumerate(f):
        # Look at the piece-placement field only: the castling field of the
        # same row may also contain "K"/"k" and would hide a missing king.
        placement = fen.split(",", 1)[0]
        if "k" not in placement or "K" not in placement:
            print (fen)
            print (cnt)
        # The legal-move list is "[[from, to], ...]": every move contributes
        # one "[" and the enclosing list contributes one more.
        list_idx = fen.index("[")
        len_ = fen.count("[", list_idx)
        max_move_cnt = max(max_move_cnt, len_ - 1)
        if cnt % 1000000 == 0:
            print(cnt)


0
1000000
2000000
3000000
4000000
5000000
6000000
7000000


In [6]:
# Show the largest legal-move count found by the scan; the tokenization step
# pads every legal-move list to this length.
max_move_cnt

79

In [7]:
import json
import csv
import ast
# Turn every exported row into three token lists (position, played move,
# padded legal moves) and write them to a CSV file.
prep_tokenized_path = "../data/prep/tokenized64.csv"
to_tokenize_path = "../data/raw/fen_data64.txt"
vocab_src_path = "../src/lstm64/vocab_src.json"
vocab_tar_path = "../src/lstm64/vocab_tar.json"

# FEN castling letter -> key used for it in the source vocabulary.
CASTLING_KEYS = {"K": "Ki", "Q": "Qu", "k": "ki", "q": "qu"}


def _encode_position(placement, turn, castling, en_passant, vocab):
    """Encode the position fields of one row as source tokens.

    Args:
        placement: FEN piece-placement field.
        turn: FEN side-to-move field ("w" or "b").
        castling: FEN castling field, "-" when empty.
        en_passant: FEN en passant field, "-" when empty.
        vocab: Source vocabulary mapping keys to token ids.

    Returns:
        Token list: SOS, one token per square or rank separator, the side to
        move, optional castling and en passant tokens, EOS.

    Raises:
        KeyError: If a character or key is missing from the vocabulary.
    """
    tokens = [vocab['SOS']]

    # Expand empty-square run lengths ("3" -> "111") so that every square is
    # represented by exactly one character.
    for run in range(8, 1, -1):
        placement = placement.replace(str(run), "1" * run)
    tokens.extend(vocab[char] for char in placement)

    tokens.append(vocab["True" if turn == "w" else "False"])

    if castling != "-":
        tokens.extend(vocab[CASTLING_KEYS[flag]] for flag in castling)

    if en_passant != "-":
        # Only the file letter of the en passant square is encoded.
        tokens.append(vocab[en_passant[0] + "_enpas"])

    # Always terminate the sequence, whether or not an en passant token was
    # emitted.
    tokens.append(vocab["EOS"])
    return tokens


def _encode_target(uci, vocab):
    """Encode a UCI move as SOS, from-square, to-square, [promotion], EOS."""
    tokens = [vocab['SOS'], vocab[uci[0:2]], vocab[uci[2:4]]]
    if len(uci) == 5:
        # Fifth character is the promotion piece letter.
        tokens.append(vocab[uci[4]])
    tokens.append(vocab["EOS"])
    return tokens


def _encode_legals(legals_repr, vocab, width):
    """Parse the legal-move list of a row and pad it to ``width`` entries.

    Args:
        legals_repr: Text of a list of [from_square, to_square] name pairs.
        vocab: Target vocabulary mapping square names to token ids.
        width: Number of entries in the result; missing ones are [-1, -1].

    Raises:
        ValueError: If the row lists more than ``width`` moves.
    """
    # strip() instead of dropping the last character: a final line without a
    # trailing newline must not lose its closing bracket.
    pairs = ast.literal_eval(legals_repr.strip())
    encoded = [[vocab[from_], vocab[to]] for from_, to in pairs]
    if len(encoded) > width:
        # Padding can never reach the requested width; fail loudly instead of
        # looping forever.
        raise ValueError(
            f"{len(encoded)} legal moves exceed the padding width {width}"
        )
    encoded.extend([-1, -1] for _ in range(width - len(encoded)))
    return encoded


with open(vocab_tar_path, "r") as f:
    vocab_tar = json.load(f)

with open(vocab_src_path, "r") as f:
    vocab_src = json.load(f)

# Stream the input instead of loading every row into memory first.
with open(to_tokenize_path, "r") as source:
    # newline="" is what the csv module expects of files it writes to.
    with open(prep_tokenized_path, "w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(['sequence', 'target', 'legal'])
        for idx, line in enumerate(source):
            # The sixth field is the legal-move list, which itself contains
            # commas, so split off only the first five fields.
            fen, turn, cas, en_pass, tar, legals_str = line.split(",", 5)

            seq_token = _encode_position(fen, turn, cas, en_pass, vocab_src)
            tar_token = _encode_target(tar, vocab_tar)
            # max_move_cnt comes from the scan of this same file in the
            # earlier cell.
            legals = _encode_legals(legals_str, vocab_tar, max_move_cnt)

            if idx % 100000 == 0:
                print(idx)

            writer.writerow([seq_token, tar_token, legals])

0
100000
200000
300000
400000
500000
600000
700000
800000
900000
1000000
1100000
1200000
1300000
1400000
1500000
1600000
1700000
1800000
1900000
2000000
2100000
2200000
2300000
2400000
2500000
2600000
2700000
2800000
2900000
3000000
3100000
3200000
3300000
3400000
3500000
3600000
3700000
3800000
3900000
4000000
4100000
4200000
4300000
4400000
4500000
4600000
4700000
4800000
4900000
5000000
5100000
5200000
5300000
5400000
5500000
5600000
5700000
5800000
5900000
6000000
6100000
6200000
6300000
6400000
6500000
6600000
6700000
6800000
6900000
7000000
7100000
7200000
7300000
7400000
7500000
7600000
7700000
