In [3]:
import numpy as np
import time
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.optim as optim
import matplotlib.pyplot as plt
import sys, os

project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)

from Resources.Model import Model_v8
from Resources.Game import *


##### global parameters

In [5]:
gamma = 0.98
value_diff_scale = 50
value_diff_scale_early = 1
# games saved in batches to reduce i/o stream
# each batch is a input file and a label file containing [batch_size] individual games
batch_target = 100

##### local counters

In [6]:
white_wins = 0
black_wins = 0
draws = 0

batch_count = 0         # number of batches locally done

In [8]:
# keep generating new batches of data until stopped
while True:

    batch_white_wins = 0
    batch_black_wins = 0
    batch_draws = 0

    # load newest model
    model = Model_v8()
    model_saves = os.listdir('../Monte Carlo/Model Saves MC v8_3')
    if len(model_saves) > 0:
        newest_model = max(int(i[6:-8]) for i in model_saves)
        model.load_state_dict(torch.load('../Monte Carlo/Model Saves MC v8_3/model_{}_batches'.format(newest_model)))
    else:
        time.sleep(0.05)

    # print('loaded model ', newest_model)

    with open('/Users/Philip/Desktop/Projects/RL Chess/MCTS/Game Saves v8_3/stats', 'rb') as f:
        stats = torch.load(f)
    stats = stats.int()
    global_white_wins   = stats[1]
    global_black_wins   = stats[2]
    global_draws        = stats[3]

    percentage_decisive = (global_white_wins + global_black_wins) / (global_white_wins + global_black_wins + global_draws)
    batch_size = int(batch_target // percentage_decisive.item()) # so that on average we have [batch_target] decisive games

    meta_games = []
    meta_boards_white = []; meta_boards_black = []
    i = 0

    meta_active = []

    for game_ind in range(batch_size):
        meta_games.append( Game() )
        meta_boards_white.append([]); meta_boards_black.append([])
        meta_active.append(True)

    model.eval()

    # print('start with {} games'.format(batch_size))

    while True in meta_active:

        # if i % 20 == 0:
        #     print('i = {}, with {} active games '.format(i, meta_active.count(True)))

        i += 1
        full_board_batch = []
        meta_board_batch_sizes = [] # save batch sizes to split model output afterwards
        meta_moves = []

        # go through games, collect positions for value evaluations
        for g, game in enumerate(meta_games):

            if not meta_active[g]:
                meta_board_batch_sizes.append(0)
                meta_moves.append([])
                continue

            meta_moves.append(game.PossibleMoves())
            game_ini = game.copy()
            board_batch = [board_to_tensor(game.pieces)]
            mate = False

            for move in meta_moves[-1]:
                game.PlayMove(move)
                board_batch.append(board_to_tensor(game.pieces))
                game.FlipBoard()
                if game.is_over():
                    mate = True
                    break
                game = game_ini.copy()

            meta_games[g] = game.copy()

            if mate:
                meta_active[g] = False
                meta_board_batch_sizes.append(0)
                game.FlipBoard()
                if i % 2 == 1:
                    meta_boards_white[g].append(board_to_bool_tensor(game.pieces))
                if i % 2 == 0:
                    meta_boards_black[g].append(board_to_bool_tensor(game.pieces))

            if not mate:
                full_board_batch = full_board_batch + board_batch
                meta_board_batch_sizes.append(len(board_batch))

        if len(full_board_batch) == 0:
            break

        # get values of all positions
        full_board_batch = torch.stack(full_board_batch)
        out = model(full_board_batch).detach()
        meta_values = torch.split(out, meta_board_batch_sizes)

        # make moves for all games
        for g, game in enumerate(meta_games):

            if not meta_active[g]:
                continue
            values = meta_values[g]
            if i < 7:
                scale = value_diff_scale_early
            else:
                scale = value_diff_scale
            values_diff = [scale*(values[i] - values[0]) for i in range(1, len(values))]
            move_prob = torch.softmax(torch.Tensor(values_diff), dim=0).numpy()
            chosen_i = np.random.choice(range(len(meta_moves[g])), p=move_prob)
            chosen_move = meta_moves[g][chosen_i]
            game.PlayMove(chosen_move)
            if i % 2 == 1:
                meta_boards_white[g].append(board_to_bool_tensor(game.pieces))
            if i % 2 == 0:
                meta_boards_black[g].append(board_to_bool_tensor(game.pieces))
            game.FlipBoard()

    # print('games done, start evaluating')

    meta_inputs = []
    meta_labels = []

    for g, game in enumerate(meta_games):
        
        winner = game.get_winner()
        if winner == 'draw':
            draws += 1; batch_draws += 1
            continue
            reward_white = 0;   reward_black = 0

        elif winner == 'white':
            white_wins += 1; batch_white_wins += 1
            reward_white = 1;   reward_black = -1

        elif winner == 'black':
            black_wins += 1; batch_black_wins += 1
            reward_white = -1;  reward_black = 1

        labels_white = [reward_white * gamma**(len(meta_boards_white[g]) - 1 - i) for i in range(len(meta_boards_white[g]))]
        labels_black = [reward_black * gamma**(len(meta_boards_black[g]) - 1 - i) for i in range(len(meta_boards_black[g]))]

        meta_inputs = meta_inputs + meta_boards_white[g] + meta_boards_black[g]
        meta_labels = meta_labels + labels_white + labels_black

    if batch_white_wins + batch_black_wins == 0:
        print('no decisive games in the whole batch -> skip to next batch (batch size too small?)')
        continue

    inputs_tens = torch.stack(meta_inputs)
    labels_tens = torch.Tensor(meta_labels)

    # print('evaluation done, save batch')

    with open('/Users/Philip/Desktop/Projects/RL Chess/MCTS/Game Saves v8_3/stats', 'rb') as f:
        stats = torch.load(f)
    stats = stats.int()
    stats[0] += 1;                  stats[1] += batch_white_wins
    stats[2] += batch_black_wins;   stats[3] += batch_draws
    torch.save(stats, '/Users/Philip/Desktop/Projects/RL Chess/MCTS/Game Saves v8_3/stats')

    # print('batch index = ', stats[0])

    new_batch_index = stats[0]
    torch.save(inputs_tens, '/Users/Philip/Desktop/Projects/RL Chess/MCTS/Game Saves v8_3/inputs_{}'.format(new_batch_index))
    torch.save(labels_tens, '/Users/Philip/Desktop/Projects/RL Chess/MCTS/Game Saves v8_3/labels_{}'.format(new_batch_index))

    print('local batches: {} --  w: {}, b: {}, d: {}'.format(batch_count, white_wins, black_wins, draws))

    batch_count += 1

    if batch_count % 5 == 0:
        print(' -- global batches = {} --  w: {}, b: {}, d: {} (total: {})'.format(
            new_batch_index, stats[1], stats[2], stats[3], stats[1] + stats[2] + stats[3]))

local batches: 1 --  w: 99, b: 94, d: 307
local batches: 2 --  w: 148, b: 145, d: 457
local batches: 3 --  w: 192, b: 192, d: 616
local batches: 4 --  w: 246, b: 238, d: 766
 -- global batches = 1232 --  w: 61478, b: 61698, d: 185734 (total: 308910)
local batches: 5 --  w: 298, b: 306, d: 896
local batches: 6 --  w: 353, b: 363, d: 1034
local batches: 7 --  w: 415, b: 409, d: 1176
local batches: 8 --  w: 472, b: 468, d: 1310
local batches: 9 --  w: 538, b: 509, d: 1453
 -- global batches = 1257 --  w: 62833, b: 63045, d: 189282 (total: 315160)
local batches: 10 --  w: 591, b: 560, d: 1599
local batches: 11 --  w: 639, b: 614, d: 1747
local batches: 12 --  w: 692, b: 660, d: 1898
local batches: 13 --  w: 768, b: 701, d: 2031
local batches: 14 --  w: 826, b: 759, d: 2165
 -- global batches = 1282 --  w: 64263, b: 64274, d: 192873 (total: 321410)
local batches: 15 --  w: 880, b: 811, d: 2309
local batches: 16 --  w: 930, b: 858, d: 2461
local batches: 17 --  w: 963, b: 910, d: 2625
local 

KeyboardInterrupt: 