In [59]:
import numpy as np
import time
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.optim as optim
import matplotlib.pyplot as plt
import sys, os

project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)

from Resources.Model import Model_v7
from Resources.Game import *


##### global parameters

In [60]:
# --- self-play hyper-parameters ---------------------------------------------

# Games are written to disk in batches to keep the number of file operations low:
# one batch = one input file + one label file holding [batch_size] decisive games.
batch_size = 20

# Factor applied to (value after move - value of current position) before the
# softmax that picks a move; the smaller "early" factor is used for the first
# six moves of a game, the larger one afterwards.
value_diff_scale_early = 5
value_diff_scale = 250

# Discount applied per move (of the same colour) counted back from the end of
# the game when turning the final result into position labels.
gamma = 0.98

##### local counters

In [61]:
# Totals accumulated by this notebook session only.
white_wins = black_wins = draws = 0

# game_count counts decisive games only (draws are skipped before it is incremented);
# batch_count is the number of batches this session has written.
game_count = 0
batch_count = 0

In [62]:
# Self-play data generation.
#
# Repeatedly plays a game with the newest saved model, labels every recorded
# position with the discounted game result and, once [batch_size] decisive games
# have been collected, writes them to disk as one inputs/labels file pair and
# updates the shared stats file. Drawn games are counted but produce no training
# data. The loop runs until the kernel is interrupted.

# All locations in one place so they are easy to change.
MODEL_SAVE_DIR = '../Monte Carlo/Model Saves MC v7 Parallel'
GAME_SAVE_DIR = '/Users/Philip/Desktop/Projects/RL Chess/MCTS/Game Saves v7'
STATS_PATH = os.path.join(GAME_SAVE_DIR, 'stats')

# True at start-up and right after a batch has been saved. Using a dedicated flag
# reloads the model once per batch rather than before every game that precedes
# the first decisive game of a batch.
reload_model = True

batch_white_wins = 0
batch_black_wins = 0
batch_draws = 0

# per-game tensors of the batch currently being collected (concatenated once on save)
pending_inputs = []
pending_labels = []

while True:

    # load newest model initially and for every new batch
    if reload_model:
        model = Model_v7()
        # Only files named model_<n>_batches are considered; any other directory
        # entry (e.g. .DS_Store) would otherwise make int() raise.
        saved_batch_numbers = [
            int(name[6:-8])
            for name in os.listdir(MODEL_SAVE_DIR)
            if name.startswith('model_') and name.endswith('_batches') and name[6:-8].isdigit()
        ]
        if saved_batch_numbers:
            newest_model = max(saved_batch_numbers)
            model.load_state_dict(torch.load(
                os.path.join(MODEL_SAVE_DIR, 'model_{}_batches'.format(newest_model))))
        # the model is only used for inference here
        model.eval()
        reload_model = False

    game = Game()
    ply = 0     # number of moves played so far in this game (odd = white's moves)
    boards_white = []
    boards_black = []

    while not game.is_over():

        ply += 1
        moves = game.PossibleMoves()

        game_ini = game.copy()
        # index 0: current position; index k: position after playing moves[k - 1]
        board_batch = [board_to_tensor(game.pieces)]

        # NOTE(review): `mate` is set for any move after which game.is_over() is
        # true; if is_over() also reports draws, this plays a drawing move without
        # consulting the model -- confirm against Game.is_over().
        mate = False

        for move in moves:
            game.PlayMove(move)
            board_batch.append(board_to_tensor(game.pieces))
            game.FlipBoard()
            if game.is_over():
                mate = True
                chosen_move = move
                game = game_ini.copy()
                break
            game = game_ini.copy()

        if not mate:
            # Pure inference: no autograd graph is needed for move selection.
            with torch.no_grad():
                values = model(torch.stack(board_batch))
                scale = value_diff_scale_early if ply < 7 else value_diff_scale
                # One entry per legal move: scaled gain over the current position.
                # assumes the model returns a single value per board -- TODO confirm
                values_diff = (scale * (values[1:] - values[0])).reshape(-1)
                move_prob = torch.softmax(values_diff, dim=0).numpy()
            chosen_i = np.random.choice(len(moves), p=move_prob)
            chosen_move = moves[chosen_i]

        game.PlayMove(chosen_move)

        # record the position reached by the move, grouped by the side that moved
        if ply % 2 == 1:
            boards_white.append(board_to_bool_tensor(game.pieces))
        else:
            boards_black.append(board_to_bool_tensor(game.pieces))

        game.FlipBoard()

    winner = game.get_winner()

    if winner == 'draw':
        # draws only enter the statistics, never the training data
        draws += 1; batch_draws += 1
        continue
    elif winner == 'white':
        white_wins += 1; batch_white_wins += 1
        reward_white = 1;   reward_black = -1
    elif winner == 'black':
        black_wins += 1; batch_black_wins += 1
        reward_white = -1;  reward_black = 1
    else:
        # without this branch the rewards of the previous game would be reused silently
        raise ValueError('unexpected winner: {!r}'.format(winner))

    # last position of each side gets the full reward, earlier ones are discounted by gamma per step
    labels_white = [reward_white * gamma**(len(boards_white) - 1 - k) for k in range(len(boards_white))]
    labels_black = [reward_black * gamma**(len(boards_black) - 1 - k) for k in range(len(boards_black))]

    pending_inputs.append(torch.stack(boards_white + boards_black))
    pending_labels.append(torch.Tensor(labels_white + labels_black))

    game_count += 1

    # Triggering on the number of collected games (not on the session-wide
    # game_count) keeps every saved batch at exactly [batch_size] games even when
    # this cell is re-run after an interrupt.
    if len(pending_inputs) >= batch_size:

        batch_count += 1

        batch_inputs = torch.cat(pending_inputs)
        batch_labels = torch.cat(pending_labels)

        # NOTE(review): this read-modify-write of the stats file is not atomic; if
        # several workers share the file (TODO confirm) updates can be lost and two
        # workers can end up with the same batch index.
        with open(STATS_PATH, 'rb') as f:
            stats = torch.load(f)
        stats = stats.int()
        stats[0] += 1 # batch index
        stats[1] += batch_white_wins
        stats[2] += batch_black_wins
        stats[3] += batch_draws
        torch.save(stats, STATS_PATH)

        new_batch_index = int(stats[0])

        torch.save(batch_inputs, os.path.join(GAME_SAVE_DIR, 'inputs_{}'.format(new_batch_index)))
        torch.save(batch_labels, os.path.join(GAME_SAVE_DIR, 'labels_{}'.format(new_batch_index)))

        print('local batches: {} --  w: {}, b: {}, d: {}'.format(batch_count, white_wins, black_wins, draws))

        if batch_count % 5 == 0:
            print(' -- global batches = {} --  w: {}, b: {}, d: {} (total: {})'.format(
                new_batch_index, stats[1], stats[2], stats[3], stats[1] + stats[2] + stats[3]))

        # start a fresh batch and pick up the newest model before the next game
        pending_inputs = []
        pending_labels = []
        reload_model = True

        batch_white_wins = 0
        batch_black_wins = 0
        batch_draws = 0

local batches: 1 --  w: 9, b: 11, d: 11
local batches: 2 --  w: 16, b: 24, d: 17
local batches: 3 --  w: 24, b: 36, d: 26
local batches: 4 --  w: 34, b: 46, d: 29
local batches: 5 --  w: 43, b: 57, d: 34
 -- global batches = 20 --  w: 202, b: 198, d: 127 (total: 527)
local batches: 6 --  w: 51, b: 69, d: 38
local batches: 7 --  w: 63, b: 77, d: 39
local batches: 8 --  w: 73, b: 87, d: 44
local batches: 9 --  w: 82, b: 98, d: 54
local batches: 10 --  w: 88, b: 112, d: 58
 -- global batches = 42 --  w: 430, b: 410, d: 271 (total: 1111)
local batches: 11 --  w: 97, b: 123, d: 63
local batches: 12 --  w: 103, b: 137, d: 77
local batches: 13 --  w: 115, b: 145, d: 105
local batches: 14 --  w: 125, b: 155, d: 149
local batches: 15 --  w: 138, b: 162, d: 197
 -- global batches = 69 --  w: 739, b: 641, d: 895 (total: 2275)
local batches: 16 --  w: 148, b: 172, d: 247
local batches: 17 --  w: 160, b: 180, d: 295
local batches: 18 --  w: 172, b: 188, d: 317
local batches: 19 --  w: 184, b: 196, 

KeyboardInterrupt: 