In [1]:
import numpy as np
import time
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.optim as optim
import matplotlib.pyplot as plt
import sys, os

project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)

from Resources.Model import Model_v8
from Resources.Game import *


##### global parameters

In [6]:
gamma = 0.98
value_diff_scale = 50
value_diff_scale_early = 1
# games saved in batches to reduce i/o stream
# each batch is a input file and a label file containing [batch_size] individual games
batch_size = 100

##### local counters

In [11]:
white_wins = 0
black_wins = 0
draws = 0

game_count = 0          # counting decisive games
batch_count = 0         # number of batches locally done

In [12]:
first_load = True
initialize_batch = True

batch_white_wins = 0
batch_black_wins = 0
batch_draws = 0

while True:

    # load newest model initially and for every new batch
    if initialize_batch or first_load:
        model = Model_v8()
        model_saves = os.listdir('../Monte Carlo/Model Saves MC v8_3')
        if len(model_saves) > 0:
            newest_model = max(int(i[6:-8]) for i in model_saves)
            model.load_state_dict(torch.load('../Monte Carlo/Model Saves MC v8_3/model_{}_batches'.format(newest_model)))

        first_load = False

    game = Game()
    i = 0
    boards_white = [];  boards_black = []

    model.eval()

    while not game.is_over():
        
        i += 1
        moves = game.PossibleMoves()

        game_ini = game.copy()
        board_batch = [board_to_tensor(game.pieces)]

        mate = False

        for move in moves:
            game.PlayMove(move)
            board_batch.append(board_to_tensor(game.pieces))
            game.FlipBoard()
            if game.is_over():
                mate = True
                chosen_move = move
                game = game_ini.copy()
                break
            game = game_ini.copy()

        if not mate:
            board_tensor = torch.stack(board_batch)
            values = model(board_tensor)
            if i < 7:
                scale = value_diff_scale_early
            else:
                scale = value_diff_scale
            values_diff = [scale*(values[i] - values[0]) for i in range(1, len(values))]
            move_prob = torch.softmax(torch.Tensor(values_diff), dim=0).numpy()
            chosen_i = np.random.choice(range(len(moves)), p=move_prob)
            chosen_move = moves[chosen_i]
            
        game.PlayMove(chosen_move)

        if i % 2 == 1:
            boards_white.append(board_to_bool_tensor(game.pieces))
        if i % 2 == 0:
            boards_black.append(board_to_bool_tensor(game.pieces))

        game.FlipBoard()
    
    winner = game.get_winner()

    if winner == 'draw':
        draws += 1; batch_draws += 1
        continue
        reward_white = 0;   reward_black = 0

    elif winner == 'white':
        white_wins += 1; batch_white_wins += 1
        reward_white = 1;   reward_black = -1

    elif winner == 'black':
        black_wins += 1; batch_black_wins += 1
        reward_white = -1;  reward_black = 1

    labels_white = [reward_white * gamma**(len(boards_white) - 1 - i) for i in range(len(boards_white))]
    labels_black = [reward_black * gamma**(len(boards_black) - 1 - i) for i in range(len(boards_black))]

    inputs_tens = torch.stack(boards_white + boards_black)
    labels_tens = torch.Tensor(labels_white + labels_black)

    if initialize_batch:
        batch_inputs = inputs_tens.clone()
        batch_labels = labels_tens.clone()
        initialize_batch = False
    else:
        batch_inputs = torch.cat((batch_inputs, inputs_tens))
        batch_labels = torch.cat((batch_labels, labels_tens))

    game_count += 1

    if game_count % batch_size == 0:

        batch_count += 1

        with open('/Users/Philip/Desktop/Projects/RL Chess/MCTS/Game Saves v8_3/stats', 'rb') as f:
            stats = torch.load(f)
        stats = stats.int()
        stats[0] += 1 # batch index
        stats[1] += batch_white_wins
        stats[2] += batch_black_wins
        stats[3] += batch_draws
        torch.save(stats, '/Users/Philip/Desktop/Projects/RL Chess/MCTS/Game Saves v8_3/stats')

        # print('update index ', stats[0])

        new_batch_index = stats[0]

        torch.save(batch_inputs, '/Users/Philip/Desktop/Projects/RL Chess/MCTS/Game Saves v8_3/inputs_{}'.format(new_batch_index))
        torch.save(batch_labels, '/Users/Philip/Desktop/Projects/RL Chess/MCTS/Game Saves v8_3/labels_{}'.format(new_batch_index))

        print('local batches: {} --  w: {}, b: {}, d: {}'.format(batch_count, white_wins, black_wins, draws))

        if batch_count % 5 == 0:
            print(' -- global batches = {} --  w: {}, b: {}, d: {} (total: {})'.format(
                new_batch_index, stats[1], stats[2], stats[3], stats[1] + stats[2] + stats[3]))

        initialize_batch = True

        batch_white_wins = 0
        batch_black_wins = 0
        batch_draws = 0

local batches: 1 --  w: 53, b: 47, d: 114
local batches: 2 --  w: 102, b: 98, d: 233
local batches: 3 --  w: 161, b: 139, d: 353
local batches: 4 --  w: 216, b: 184, d: 510
local batches: 5 --  w: 265, b: 235, d: 676
 -- global batches = 25 --  w: 1250, b: 1250, d: 3149 (total: 5649)
local batches: 6 --  w: 310, b: 290, d: 847
local batches: 7 --  w: 363, b: 337, d: 988
local batches: 8 --  w: 414, b: 386, d: 1149
local batches: 9 --  w: 461, b: 439, d: 1281
local batches: 10 --  w: 512, b: 488, d: 1408
 -- global batches = 47 --  w: 2369, b: 2331, d: 6546 (total: 11246)
local batches: 11 --  w: 558, b: 542, d: 1573
local batches: 12 --  w: 605, b: 595, d: 1760
local batches: 13 --  w: 644, b: 656, d: 1950
local batches: 14 --  w: 691, b: 709, d: 2129
local batches: 15 --  w: 748, b: 752, d: 2329
 -- global batches = 73 --  w: 3651, b: 3649, d: 11016 (total: 18316)
local batches: 16 --  w: 796, b: 804, d: 2496
local batches: 17 --  w: 847, b: 853, d: 2631
local batches: 18 --  w: 898, 

KeyboardInterrupt: 