In [None]:
# default_exp generate_training_set

# Generate Training Set
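> Reads chess games from the PGN files in `data/` and turns every position of each game into a training example labelled with the game's final result.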

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#exports
import os
import chess.pgn
import numpy as np
from mediocre_chess_ai.state import State


def get_dataset(num_samples):
    """Build a supervised dataset of board states from the PGN files in ``data/``.

    Every game in every file is replayed move by move. After each move the
    board is serialized with ``State(board).serialize()`` and stored together
    with the game's final result as its label.

    Parameters
    ----------
    num_samples : int, float or None
        Soft cap on the number of examples. Parsing stops after the first
        game that pushes the total *above* this number, so the result can
        contain somewhat more than ``num_samples`` rows. ``None`` means
        "parse everything".

    Returns
    -------
    X : numpy.ndarray
        One serialized board state per row (layout is whatever
        ``State.serialize`` produces).
    Y : numpy.ndarray
        The label for each row: ``1`` for a white win, ``-1`` for a black
        win and ``0`` for a draw.
    """
    # Accumulators for serialized board states and their labels.
    X, Y = [], []
    gn = 0

    # Label assigned to every position of a game, keyed by the PGN result tag.
    values = {'1/2-1/2': 0, '0-1': -1, '1-0': 1}

    # os.listdir order is arbitrary; sorting keeps the selected samples
    # reproducible between runs and machines.
    for fn in sorted(os.listdir("data")):
        path = os.path.join("data", fn)
        if not os.path.isfile(path):
            # Sub-directories cannot be opened as PGN files.
            continue

        # The context manager closes the file on every exit path, including
        # the early return below.
        with open(path) as pgn:
            while True:
                game = chess.pgn.read_game(pgn)
                if game is None:
                    # End of this file.
                    break

                # Skip games without a decisive/drawn result (e.g. "*").
                res = game.headers["Result"]
                if res not in values:
                    continue
                value = values[res]

                # Replay the game, recording the position after each move.
                board = game.board()
                for move in game.mainline_moves():
                    board.push(move)
                    X.append(State(board).serialize())
                    Y.append(value)
                print("parsing game %d, got %d examples" % (gn, len(X)))

                # Stop as soon as we have collected enough examples; return
                # arrays here too so callers always get the same types.
                if num_samples is not None and len(X) > num_samples:
                    return np.array(X), np.array(Y)
                gn += 1

    # Convert to numpy arrays.
    return np.array(X), np.array(Y)

In [None]:
#hide
# Build a dataset of roughly 1k examples and cache it to disk.
# np.savez does not create missing directories, so make sure the target exists.
os.makedirs("processed", exist_ok=True)
X, Y = get_dataset(1000)
# Positional arrays are stored under the default keys `arr_0` and `arr_1`.
np.savez("processed/dataset_1k.npz", X, Y)

parsing game 0, got 45 examples
parsing game 1, got 100 examples
parsing game 2, got 147 examples
parsing game 3, got 193 examples
parsing game 4, got 294 examples
parsing game 5, got 430 examples
parsing game 6, got 574 examples
parsing game 7, got 780 examples
parsing game 8, got 917 examples
parsing game 9, got 987 examples
parsing game 10, got 1140 examples


In [None]:
#hide
# Run nbdev's notebook-to-script export; it reports each notebook it converts.
from nbdev.export import notebook2script; notebook2script()

Converted 00_state.ipynb.
Converted 01_generate_training_set.ipynb.
Converted 02_train.ipynb.
Converted 03_play.ipynb.
Converted index.ipynb.
