In [4]:
import numpy as np
import time
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.optim as optim
import matplotlib.pyplot as plt
import sys, os

project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)

from Resources.Model import Model_v27
from Resources.Game import *


##### global parameters

In [11]:
# Discount factor: a position's label is the final reward multiplied by gamma
# once for every later position recorded for the same side.
gamma = 0.98


def scale(i):
    """Return the factor applied to value differences before the move softmax.

    `i` is the step counter of the self-play loop (1 on the first move); the
    factor grows quadratically with it.
    """
    step_squared = i ** 2
    return 0.1 * step_squared


# Games are generated in batches to reduce I/O.
# Each batch is written as one input file plus one label file; the number of
# games played per batch is chosen so that, on average, `batch_target` of them
# are decisive.
batch_target = 200

##### local counters

In [12]:
# Running totals of game results over all batches generated by this notebook session.
white_wins = black_wins = draws = 0

# Number of batches this session has finished and saved.
batch_count = 0

In [13]:
# Self-play data generation: keep producing new batches of games until stopped.
#
# Each pass of the outer loop loads the newest saved model, plays a batch of games
# in lock-step (one model call per step for all games together), labels the
# positions of decisive games with discounted rewards and writes them to disk.

MODEL_SAVE_DIR = '../Monte Carlo/Model Saves MC v27_2'
GAME_SAVE_DIR = '/Users/Philip/Desktop/Projects/RL Chess/MCTS/Game Saves v27_2'
STATS_PATH = GAME_SAVE_DIR + '/stats'

# Lower bound for the observed share of decisive games. Without it, global stats
# that contain only draws give a share of 0 and the batch-size division below
# raises ZeroDivisionError.
MIN_DECISIVE_FRACTION = 0.01


def _newest_model_index(filenames):
    """Return the largest N among names of the form 'model_N_batches', or None.

    Names that do not match the pattern (for example '.DS_Store') are ignored
    instead of crashing the int() conversion.
    """
    prefix, suffix = 'model_', '_batches'
    indices = []
    for name in filenames:
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        middle = name[len(prefix):-len(suffix)]
        if middle.isdecimal():
            indices.append(int(middle))
    return max(indices) if indices else None


def _discounted_labels(reward, n_positions):
    """Return one label per position: the last gets `reward`, earlier ones are discounted by gamma."""
    return [reward * gamma ** (n_positions - 1 - k) for k in range(n_positions)]


while True:

    batch_white_wins = 0
    batch_black_wins = 0
    batch_draws = 0

    # load newest model
    model = Model_v27()
    newest_model = _newest_model_index(os.listdir(MODEL_SAVE_DIR))
    if newest_model is not None:
        model.load_state_dict(torch.load('{}/model_{}_batches'.format(MODEL_SAVE_DIR, newest_model)))
    else:
        # no saved model yet: pause briefly and continue with the freshly initialised weights
        time.sleep(0.05)

    # stats layout: [0] number of batches, [1] white wins, [2] black wins, [3] draws
    with open(STATS_PATH, 'rb') as f:
        stats = torch.load(f)
    stats = stats.int()
    global_white_wins   = stats[1]
    global_black_wins   = stats[2]
    global_draws        = stats[3]

    global_decisive = global_white_wins + global_black_wins
    global_total = global_decisive + global_draws
    if global_total == 0:
        percentage_decisive = 0.5
    else:
        percentage_decisive = (global_decisive / global_total).item()
    percentage_decisive = max(percentage_decisive, MIN_DECISIVE_FRACTION)

    batch_size = int(batch_target // percentage_decisive) # so that on average we have [batch_target] decisive games

    meta_games = [Game() for _ in range(batch_size)]
    meta_boards_white = [[] for _ in range(batch_size)]
    meta_boards_black = [[] for _ in range(batch_size)]
    meta_active = [True] * batch_size
    ply = 0     # step counter; odd steps are recorded as white's moves, even steps as black's

    model.eval()

    while any(meta_active):

        ply += 1
        full_board_batch = []
        meta_board_batch_sizes = [] # save batch sizes to split model output afterwards
        meta_moves = []

        # go through games, collect positions for value evaluations
        for g, game in enumerate(meta_games):

            if not meta_active[g]:
                meta_board_batch_sizes.append(0)
                meta_moves.append([])
                continue

            meta_moves.append(game.PossibleMoves())
            game_ini = game.copy()
            # first entry is the current position, followed by one entry per candidate move
            board_batch = [board_to_tensor(game.pieces)]
            mate = False

            for move in meta_moves[-1]:
                game.PlayMove(move)
                board_batch.append(board_to_tensor(game.pieces))
                game.FlipBoard()
                if game.is_over():
                    # this candidate ends the game: keep the resulting state and stop searching
                    mate = True
                    break
                game = game_ini.copy()

            meta_games[g] = game.copy()

            if mate:
                meta_active[g] = False
                meta_board_batch_sizes.append(0)
                # flip back so the final position is recorded the same way as in the move loop below
                game.FlipBoard()
                side_boards = meta_boards_white if ply % 2 == 1 else meta_boards_black
                side_boards[g].append(board_to_bool_tensor(game.pieces))
            else:
                full_board_batch.extend(board_batch)
                meta_board_batch_sizes.append(len(board_batch))

        if not full_board_batch:
            break

        # get values of all positions (inference only, so no autograd graph is needed)
        full_board_batch = torch.stack(full_board_batch)
        with torch.no_grad():
            out = model(full_board_batch)
        meta_values = torch.split(out, meta_board_batch_sizes)

        # the same for every game at this step, so compute it once
        sharpness = scale(ply)

        # make moves for all games
        for g, game in enumerate(meta_games):

            if not meta_active[g]:
                continue
            values = meta_values[g]
            # values[0] belongs to the current position, values[1:] to the candidate moves
            values_diff = [sharpness * (values[j] - values[0]) for j in range(1, len(values))]
            move_prob = torch.softmax(torch.Tensor(values_diff), dim=0).numpy()
            chosen_i = np.random.choice(len(meta_moves[g]), p=move_prob)
            chosen_move = meta_moves[g][chosen_i]
            game.PlayMove(chosen_move)
            side_boards = meta_boards_white if ply % 2 == 1 else meta_boards_black
            side_boards[g].append(board_to_bool_tensor(game.pieces))
            game.FlipBoard()

    meta_inputs = []
    meta_labels = []

    for g, game in enumerate(meta_games):

        winner = game.get_winner()
        if winner == 'draw':
            # drawn games are counted but contribute no training positions
            draws += 1; batch_draws += 1
            continue
        elif winner == 'white':
            white_wins += 1; batch_white_wins += 1
            reward_white = 1;   reward_black = -1
        elif winner == 'black':
            black_wins += 1; batch_black_wins += 1
            reward_white = -1;  reward_black = 1
        else:
            # without this branch the rewards of the previous game would be reused silently
            raise ValueError('unexpected winner {!r} for game {}'.format(winner, g))

        labels_white = _discounted_labels(reward_white, len(meta_boards_white[g]))
        labels_black = _discounted_labels(reward_black, len(meta_boards_black[g]))

        meta_inputs.extend(meta_boards_white[g])
        meta_inputs.extend(meta_boards_black[g])
        meta_labels.extend(labels_white)
        meta_labels.extend(labels_black)

    if batch_white_wins + batch_black_wins == 0:
        print('no decisive games in the whole batch -> skip to next batch (batch size too small?)')
        continue

    inputs_tens = torch.stack(meta_inputs)
    labels_tens = torch.Tensor(meta_labels)

    # NOTE(review): this read-modify-write of the shared stats file is not locked. The
    # global counts appear to grow faster than the local ones, which suggests several
    # generators run at once; two of them could then claim the same batch index.
    # Confirm whether that matters.
    with open(STATS_PATH, 'rb') as f:
        stats = torch.load(f)
    stats = stats.int()
    stats[0] += 1;                  stats[1] += batch_white_wins
    stats[2] += batch_black_wins;   stats[3] += batch_draws
    torch.save(stats, STATS_PATH)

    new_batch_index = int(stats[0])
    torch.save(inputs_tens, '{}/inputs_{}'.format(GAME_SAVE_DIR, new_batch_index))
    torch.save(labels_tens, '{}/labels_{}'.format(GAME_SAVE_DIR, new_batch_index))

    print('local batches: {} --  w: {}, b: {}, d: {}'.format(batch_count, white_wins, black_wins, draws))

    batch_count += 1

    if batch_count % 5 == 0:
        global_w, global_b, global_d = int(stats[1]), int(stats[2]), int(stats[3])
        print(' -- global batches = {} --  w: {}, b: {}, d: {} (total: {})'.format(
            new_batch_index, global_w, global_b, global_d, global_w + global_b + global_d))

local batches: 0 --  w: 172, b: 153, d: 75
local batches: 1 --  w: 263, b: 264, d: 119
local batches: 2 --  w: 298, b: 315, d: 284
local batches: 3 --  w: 355, b: 363, d: 447
local batches: 4 --  w: 406, b: 423, d: 644
 -- global batches = 13 --  w: 1120, b: 1125, d: 1612 (total: 3857)
local batches: 5 --  w: 482, b: 489, d: 845
local batches: 6 --  w: 547, b: 561, d: 1072
local batches: 7 --  w: 626, b: 645, d: 1295
local batches: 8 --  w: 709, b: 748, d: 1506
local batches: 9 --  w: 791, b: 854, d: 1723
 -- global batches = 28 --  w: 2298, b: 2302, d: 4852 (total: 9452)
local batches: 10 --  w: 883, b: 963, d: 1932
local batches: 11 --  w: 969, b: 1060, d: 2164
local batches: 12 --  w: 1056, b: 1147, d: 2406
local batches: 13 --  w: 1148, b: 1219, d: 2665
local batches: 14 --  w: 1217, b: 1310, d: 2933
 -- global batches = 43 --  w: 3636, b: 3652, d: 8422 (total: 15710)
local batches: 15 --  w: 1304, b: 1399, d: 3188
local batches: 16 --  w: 1411, b: 1515, d: 3401
local batches: 17 -

KeyboardInterrupt: 