In [1]:
import chess
import multiprocessing
import chess.pgn
import os
import sys
import numpy as np
from dotenv import load_dotenv, find_dotenv

sys.path.append("..")

from silvermind import states

In [2]:
load_dotenv(find_dotenv())

# Core count comes from NUM_CPU_CORES (environment or .env file); if it is not set,
# the user is prompted for it.
try:
    num_cpu_cores = os.environ['NUM_CPU_CORES']
except KeyError:
    num_cpu_cores = input("NUM_CPU_CORES")

# Leave 4 cores for the rest of the system, but always keep at least one worker:
# a value <= 0 would make the batching loop in make_dataset never advance (infinite loop).
num_cpu_cores = max(1, int(num_cpu_cores) - 4)

# Notebooks have no __file__; emulate it with the kernel's working directory
# (normally the notebook's folder) so the "../pgns" and "../dataset" paths resolve.
__file__ = os.path.abspath('')

### Build the training datasets (basic and balanced variants)

In [3]:
# PGN "Result" header -> training label. Any other value (draw "1/2-1/2", unfinished "*",
# or a malformed/missing header) maps to None and the game is skipped.
_RESULT_LABELS = {"1-0": 1, "0-1": 0}


def _read_games(file_name, max_games):
    """Parse at most ``max_games`` games from the PGN file ``file_name``."""
    games = []
    # python-chess leaves decoding to the caller; PGN files are normally UTF-8.
    with open(file_name, encoding="utf-8") as pgn:
        # Check the limit *before* parsing so no game is parsed only to be discarded.
        while len(games) < max_games:
            game = chess.pgn.read_game(pgn)
            if game is None:  # end of file
                break
            games.append(game)
    return games


def _game_label(game):
    """Return 1 for a white win, 0 for a black win, None for anything else."""
    return _RESULT_LABELS.get(game.headers.get("Result"))


def _report_progress(i, total, verbose):
    """Print roughly five progress lines over ``total`` games when ``verbose`` is set."""
    if verbose and i % max(1, total // 5) == 0:
        print(f"batch {i / total:.1%} complete")


def basic_dataset(return_dict, file_name, max_games=sys.maxsize, verbose=False):
    """Build (X, y) from every position of every decisive game in one PGN file.

    For each game won by white (label 1) or black (label 0), the starting position and
    the position after every main-line move are encoded with ``states.serialize`` and
    labelled with the game's result. Draws, unfinished games and games with an
    unrecognised result are skipped.

    Args:
        return_dict: Shared mapping (e.g. a ``multiprocessing.Manager().dict()``);
            the result is stored as ``return_dict[file_name] = (X, y)`` so it can be
            handed back from a worker process.
        file_name: Path to a ``.pgn`` file.
        max_games: Maximum number of games to read from the file.
        verbose: Print coarse progress messages.
    """
    games = _read_games(file_name, max_games)
    X, y = [], []
    for i, game in enumerate(games):
        _report_progress(i, len(games), verbose)
        result = _game_label(game)
        if result is None:
            continue
        board = game.board()
        X.append(states.serialize(board))
        y.append(result)
        for move in game.mainline_moves():
            board.push(move)
            X.append(states.serialize(board))
            y.append(result)

    return_dict[file_name] = (X, y)


def balanced_dataset(return_dict, file_name, max_games=sys.maxsize, verbose=False):
    """Build (X, y) from a filtered subset of positions of decisive games in one PGN file.

    Same as ``basic_dataset`` except that these positions are omitted:

    - the starting position;
    - the positions after the first five half-moves (``move_num < 5``);
    - any position reached by a capturing move.

    Args:
        return_dict: Shared mapping; receives ``return_dict[file_name] = (X, y)``.
        file_name: Path to a ``.pgn`` file.
        max_games: Maximum number of games to read from the file.
        verbose: Print coarse progress messages.
    """
    games = _read_games(file_name, max_games)
    X, y = [], []
    for i, game in enumerate(games):
        _report_progress(i, len(games), verbose)
        result = _game_label(game)
        if result is None:
            continue
        board = game.board()
        for move_num, move in enumerate(game.mainline_moves()):
            # is_capture must be evaluated on the board *before* the move is pushed.
            is_capture = board.is_capture(move)
            board.push(move)
            if is_capture or move_num < 5:
                continue
            X.append(states.serialize(board))
            y.append(result)

    return_dict[file_name] = (X, y)

def make_dataset(location=f"{__file__}/../pgns", worker_func=basic_dataset, verbose=False, max_games_per_batch=sys.maxsize):
    """Run ``worker_func`` on every ``.pgn`` file in ``location`` and stack the results.

    One ``multiprocessing.Process`` is created per file. They are run in batches of
    ``num_cpu_cores``: each batch is started, then fully joined before the next starts.

    Args:
        location: Directory containing the ``.pgn`` files.
        worker_func: Callable ``(return_dict, file_name, max_games, verbose)`` that stores
            ``(X, y)`` under ``return_dict[file_name]`` (e.g. ``basic_dataset``).
        verbose: Print progress messages (also forwarded to the workers).
        max_games_per_batch: Maximum number of games read from *each* file.

    Returns:
        ``(X, y)`` as numpy arrays. Rows are concatenated in sorted file-name order,
        so the result is reproducible across runs.
    """
    # Sorted so the row order of the dataset does not depend on os.listdir order.
    file_names = sorted(
        os.path.join(location, file_title)
        for file_title in os.listdir(location)
        if file_title.endswith(".pgn")
    )
    # A non-positive batch size would stop the batching loop from ever advancing.
    batch_size = max(1, num_cpu_cores)

    # The context manager shuts the Manager's server process down when we are done.
    with multiprocessing.Manager() as manager:
        return_dict = manager.dict()

        if verbose:
            print("Creating workers")
        workers = [
            multiprocessing.Process(
                target=worker_func,
                args=(return_dict, file_name, max_games_per_batch, verbose),
            )
            for file_name in file_names
        ]

        if verbose:
            print("Activating workers")
        try:
            for start in range(0, len(workers), batch_size):
                batch = workers[start:start + batch_size]
                for worker in batch:
                    worker.start()
                for worker in batch:
                    worker.join()
                if verbose:
                    done = start + len(batch)
                    print(f"Dataset creation {done / len(workers):.1%} complete...")
        finally:
            # On interruption (e.g. KeyboardInterrupt) don't leave child processes running.
            for worker in workers:
                if worker.is_alive():
                    worker.terminate()
                    worker.join()

        # Copy results out while the manager (and thus return_dict) is still alive.
        X, y = [], []
        for worker, file_name in zip(workers, file_names):
            if file_name not in return_dict:
                # Previously skipped silently; make crashed workers visible.
                print(
                    f"warning: no data for {file_name} "
                    f"(worker exit code {worker.exitcode}); skipping",
                    file=sys.stderr,
                )
                continue
            res_X, res_y = return_dict[file_name]
            X.extend(res_X)
            y.extend(res_y)

    return np.array(X), np.array(y)

In [4]:
# Read at most 2500 games from each PGN file and keep the filtered ("balanced") positions.
X, y = make_dataset(max_games_per_batch=2500, worker_func=balanced_dataset, verbose=True)

print(X.shape, y.shape)

# makedirs(exist_ok=True) replaces the check-then-mkdir pair: no race, parents created too.
dataset_dir = f"{__file__}/../dataset"
os.makedirs(dataset_dir, exist_ok=True)

# Arrays passed positionally are stored under the keys "arr_0" (X) and "arr_1" (y).
np.savez(f"{dataset_dir}/balanced_game_data.npz", X, y)

Creating workers
Activating workers
batch 0.0% complete
batch 0.0% complete
batch 0.0% complete
batch 0.0% complete


Process Process-6:
Process Process-13:
Process Process-10:
Process Process-15:
Traceback (most recent call last):
Process Process-14:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Process Process-17:
Process Process-12:
Process Process-18:
Process Process-3:
Process Process-11:
Process Process-2:
Process Process-4:
  File "/usr/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/

KeyboardInterrupt: 