In [1]:
import torch
from torch.functional import F
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from agents.model_agent import MancalaModel

import re
from datetime import datetime
import random
import numpy as np

from game.mancalaenv import MancalaEnv
from torch import optim
import torch.distributions as dist
import os

from agents.agent import Agent
from game.mancala import Mancala
from copy import deepcopy
from game.play import play
from agents.model_agent import ModelAgent

import statistics as stats



In [2]:
time_tag = re.sub(r'[^\d]', '-', str(datetime.now().time()))
log_dir = "runs/" + time_tag
writer = SummaryWriter(log_dir)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
seed = 1234
n_holes = 7
lr = 0.01
max_game_length = 1000
reward_discount = 0.99
epochs = 200000
print_interval = 2000
writer_interval = 2000
solved_win_rate = 0.99
eps = np.finfo(np.float32).eps.item()

torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

In [3]:
from torchsummary import summary
model = MancalaModel(n_inputs=n_holes*2, n_outputs=n_holes)
env = MancalaEnv()
optimizer = optim.Adam(model.parameters(), lr=lr)

summary(model, (n_holes*2,), batch_size=1, device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                   [1, 256]           3,840
              ReLU-2                   [1, 256]               0
            Linear-3                   [1, 256]          65,792
              ReLU-4                   [1, 256]               0
            Linear-5                     [1, 7]           1,799
            Linear-6                     [1, 1]             257
Total params: 71,688
Trainable params: 71,688
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.27
Estimated Total Size (MB): 0.28
----------------------------------------------------------------


In [4]:
def get_random_move(side, env):
    return np.random.choice(env.get_valid_moves(side), 1, replace=False)[0]

In [5]:
def calculate_loss(rewards, log_probabilities, values):
    discounted_rewards = []
    accumulated_rewards = 0
    for current_reward in rewards[::-1]:
        accumulated_rewards = reward_discount * accumulated_rewards + current_reward
        discounted_rewards.append(accumulated_rewards)

    discounted_rewards = torch.tensor(discounted_rewards[::-1]).float().to(device)
    normalized_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + eps)

    policy_loss = []
    value_loss = []
    for reward, log_probability, value in zip(normalized_rewards, log_probabilities, values):
        policy_loss.append((reward - value) * -log_probability)
        reward = reward.unsqueeze(0)
        value_loss.append(F.smooth_l1_loss(value, reward))

    return torch.stack(policy_loss).sum() + torch.stack(value_loss).sum()

In [6]:

def record(total_loss, env, epoch, model_side):
    writer.add_scalar('loss', total_loss, epoch)
    match_record_dir = log_dir + '/_' + str(epoch) + '.txt'
    torch.save(model.state_dict(), log_dir + '_' + str(epoch) + '.msd')
    with open(match_record_dir, 'w') as file:
        file.write(f'model side is: {model_side}' + os.linesep)
        for side, move, score, board in env.move_history:
            file.write(f'{side} moved {move}, score: {score}' + os.linesep)
            file.write(board + os.linesep)

In [7]:
class RandomAgent(Agent):
    def get_move(self, env, side):
        game = Mancala(n_holes, n_holes, board=deepcopy(env.board))
        return np.random.choice(game.get_valid_moves(side), 1, replace=False)[0]

def evaluation(model, epoch):
    wins = 0
    n_games = 100
    model.eval()
    model.to('cpu')
    random_agent = RandomAgent()
    model_agent = ModelAgent(model=model)
    for _ in range(n_games):
        winner = play(model_agent, random_agent)
        if winner == model_agent:
            wins += 1

    writer.add_scalar(f'wins against random / {n_games} games', wins, epoch)
    print(f'wins against random agent: {wins} / {n_games}, {wins/n_games*100:.2f}% @ epoch={epoch}')
    return wins/n_games

In [8]:
def select_action(env, side, model):
    x = ModelAgent.get_model_input(env, side).to(device)
    distribution, value = model.train().to(device)(x)
    outputs = dist.Categorical(distribution)
    action = outputs.sample()
    return outputs.log_prob(action), action.item() + 1, value

In [9]:

def train_one_game(model: nn.Module, epoch):
    game_finished = False
    env.reset()
    rewards = []
    values = []
    log_probabilities = []
    game_step = 0
    win_rate = -1
    next_player = random.choice(['north', 'south'])
    while not game_finished:
        if next_player == 'north':
            # model move
            log_prob, action, value = select_action(env, 'north', model)
            next_player, reward, done = env.step('north', action)
            rewards.append(reward)
            log_probabilities.append(log_prob)
            values.append(value)
        else:
            # opponent move
            with torch.no_grad():
                next_player, _, done = env.step('south', get_random_move('south', env))

        game_step += 1
        if done or game_step > max_game_length:
            game_finished = True

    optimizer.zero_grad()
    total_loss = calculate_loss(rewards, log_probabilities, values)
    total_loss.backward()
    optimizer.step()

    if epoch % writer_interval == 0:
        record(total_loss, env, epoch, 'north')
        win_rate = evaluation(model, epoch)

    return total_loss.detach(), win_rate

In [10]:
import time, datetime

def get_time_elapsed(start):
    return str(datetime.timedelta(seconds=int(time.time() - start)))

start_time = time.time()
for i in range(0, epochs):
    loss, win_rate = train_one_game(model, i)
    time_elapsed = int(time.time() - start_time)
    if i % print_interval == 0:
        print(f'i={i:8d} loss={loss:6f} elapsed={get_time_elapsed(start_time)}')
    if win_rate >= solved_win_rate:
        print(f'Game solved with win rate: {win_rate} >= {solved_win_rate}')
        break
print(f'Finished Training, total time take: {get_time_elapsed(start_time)}')

wins against random agent: 0 / 100, 0.00% @ epoch=0
i=       0 loss=8.066968 elapsed=0:00:01
wins against random agent: 0 / 100, 0.00% @ epoch=2000
i=    2000 loss=0.011399 elapsed=0:00:14
wins against random agent: 0 / 100, 0.00% @ epoch=4000
i=    4000 loss=0.118097 elapsed=0:00:28
wins against random agent: 0 / 100, 0.00% @ epoch=6000
i=    6000 loss=0.054873 elapsed=0:00:41
wins against random agent: 0 / 100, 0.00% @ epoch=8000
i=    8000 loss=0.272164 elapsed=0:00:56
wins against random agent: 0 / 100, 0.00% @ epoch=10000
i=   10000 loss=3.568330 elapsed=0:01:15
wins against random agent: 0 / 100, 0.00% @ epoch=12000
i=   12000 loss=0.001453 elapsed=0:01:28
wins against random agent: 0 / 100, 0.00% @ epoch=14000
i=   14000 loss=0.211788 elapsed=0:01:41
wins against random agent: 0 / 100, 0.00% @ epoch=16000
i=   16000 loss=0.012473 elapsed=0:01:57
wins against random agent: 0 / 100, 0.00% @ epoch=18000
i=   18000 loss=0.673434 elapsed=0:02:13
wins against random agent: 0 / 100, 0.

KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), 'final_model.msd')

