In [1]:
import random

import numpy as np
import torch

from alpha_zero.MCTS import MCTS, PureMCTS
from alpha_zero.Coach import Coach
from alpha_zero.Arena import Arena
from alpha_zero.othello.OthelloGame import OthelloGame
from alpha_zero.tictactoe.TicTacToeGame import TicTacToeGame
from alpha_zero.othello.pytorch.NNet import NNetWrapper as nn
from alpha_zero.utils import dotdict

from alpha_zero.tictactoe import TicTacToePlayers
from alpha_zero.othello.OthelloPlayers import GreedyOthelloPlayer

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


# Monte Carlo Tree Search (MCTS)

Implement `getActionProb`, `select`, `simulate`, `backup`, `search`, `ucb_select` of class `PureMCTS` in `alpha_zero/MCTS.py`

The following code checks the correctness of your implementation, but the reference results may differ slightly depending on how you implement these methods.

In [2]:
from pathlib import Path

# Fix the seed of every RNG source (Python, NumPy, PyTorch) before the checks run.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(42)

# Search settings handed to PureMCTS in the preset checks below.
args = dotdict({"numMCTSSims": 20, "cpuct": 1.0})

# 3x3 Tic-Tac-Toe and its initial board: every preset tree is searched from here.
tictactoe_game = TicTacToeGame(n=3)
board = tictactoe_game.getInitBoard()
# NOTE(review): `s` does not appear to be read by any later cell — confirm before removing.
s = tictactoe_game.stringRepresentation(board)

# Directory holding the saved tree presets and their reference results.
target = Path("./data")

def check_select(path, target_actions):
    """Compare the actions along a selected path with the reference actions.

    Args:
        path: iterable of ``(state, action)`` pairs, in selection order.
        target_actions: indexable, sized sequence of the expected actions.

    Each visited ``(state, action)`` pair is printed as it is checked, and a
    confirmation line is printed when the whole path matches.

    Raises:
        ValueError: if an action differs from the reference at the same
            position, or if the path has more or fewer steps than the
            reference (previously a shorter path passed silently and a longer
            one surfaced as an IndexError).
    """
    expected_len = len(target_actions)
    steps_seen = 0
    # The loop variable is named `state` so it no longer shadows the `path` argument.
    for idx, (state, action) in enumerate(path):
        print(state, action)
        if idx >= expected_len:
            raise ValueError(
                f"path is wrong: longer than the {expected_len} expected steps"
            )
        if action != target_actions[idx]:
            raise ValueError("path is wrong")
        steps_seen += 1
    if steps_seen != expected_len:
        raise ValueError(
            f"path is wrong: expected {expected_len} steps, got {steps_seen}"
        )
    print("select checking: pass")

def check_simulate(reward, target_reward):
    """Verify a simulation reward against the stored reference value.

    Prints a confirmation line when the two values are equal.

    Raises:
        ValueError: if `reward` differs from `target_reward`.
    """
    mismatch = reward != target_reward
    if mismatch:
        raise ValueError("reward is wrong")
    print("simulate checking: pass")

def check_backup(Ns, Nsa, Qsa, target_Ns, target_Nsa, target_Qsa):
    """Compare the tree statistics after a backup with the reference values.

    Args:
        Ns, Nsa, Qsa: statistics produced by the implementation under test.
        target_Ns, target_Nsa, target_Qsa: the stored reference statistics.

    `Ns` and `Nsa` must equal their references exactly. `Qsa` is accepted when
    it is exactly equal, or when it has the same keys and every value agrees
    with the reference within a small floating-point tolerance, so that
    algebraically equivalent update formulas are not rejected over rounding
    noise.

    Prints a confirmation line when all three tables match.

    Raises:
        ValueError: naming the first table that does not match.
    """
    import math  # local import keeps this helper self-contained within the cell

    if Ns != target_Ns:
        raise ValueError("Ns is wrong")
    if Nsa != target_Nsa:
        raise ValueError("Nsa is wrong")
    if Qsa != target_Qsa:
        # Exact equality failed: fall back to a per-entry tolerant comparison.
        try:
            close_enough = Qsa.keys() == target_Qsa.keys() and all(
                math.isclose(Qsa[key], expected, rel_tol=1e-9, abs_tol=1e-9)
                for key, expected in target_Qsa.items()
            )
        except (AttributeError, TypeError):
            # Not mappings of numbers: keep the original strict verdict.
            close_enough = False
        if not close_enough:
            raise ValueError("Qsa is wrong")
    print("backup checking: pass")

# Run one select -> expand -> simulate -> backup pass on each of the five saved
# tree presets and compare every stage with its stored reference result.
banner_edge = "=" * 20
for i in range(5):
    mcts = PureMCTS(tictactoe_game, args)
    mcts.load_tree(target / f"tree_{i}.npy")

    print(banner_edge + f" checking {i}-th tree preset " + banner_edge)

    # Selection from the initial board, then expansion of the reached leaf.
    path, leaf = mcts.select(board)
    leaf_s = tictactoe_game.stringRepresentation(leaf)
    mcts.expand(leaf, leaf_s)

    # Simulation from the leaf; its reward is propagated back along the path.
    reward = mcts.simulate(leaf, leaf_s)
    mcts.backup(path, reward)

    # NOTE(review): allow_pickle=True unpickles whatever the .npy files contain;
    # only run this against reference files from a trusted source.
    target_actions = np.load(target / f"tree_select_{i}.npy", allow_pickle=True)
    target_reward = np.load(target / f"tree_simulate_{i}.npy", allow_pickle=True)
    target_Ns, target_Nsa, target_Qsa = np.load(
        target / f"tree_backup_{i}.npy", allow_pickle=True
    )

    check_select(path, target_actions)
    check_simulate(reward, target_reward)
    check_backup(mcts.Ns, mcts.Nsa, mcts.Qsa, target_Ns, target_Nsa, target_Qsa)

[[0 0 0]
 [0 0 0]
 [0 0 0]] 7
[[ 0  0  0]
 [ 0  0  0]
 [ 0 -1  0]] 4
[[ 0  0  0]
 [ 0 -1  0]
 [ 0  1  0]] 1
[[ 0 -1  0]
 [ 0  1  0]
 [ 0 -1  0]] 2
select checking: pass
simulate checking: pass
backup checking: pass
[[0 0 0]
 [0 0 0]
 [0 0 0]] 1
[[ 0 -1  0]
 [ 0  0  0]
 [ 0  0  0]] 3
[[ 0  1  0]
 [-1  0  0]
 [ 0  0  0]] 2
select checking: pass
simulate checking: pass
backup checking: pass
[[0 0 0]
 [0 0 0]
 [0 0 0]] 4
[[ 0  0  0]
 [ 0 -1  0]
 [ 0  0  0]] 5
[[ 0  0  0]
 [ 0  1 -1]
 [ 0  0  0]] 7
select checking: pass
simulate checking: pass
backup checking: pass
[[0 0 0]
 [0 0 0]
 [0 0 0]] 4
[[ 0  0  0]
 [ 0 -1  0]
 [ 0  0  0]] 7
[[ 0  0  0]
 [ 0  1  0]
 [ 0 -1  0]] 1
select checking: pass
simulate checking: pass
backup checking: pass
[[0 0 0]
 [0 0 0]
 [0 0 0]] 4
[[ 0  0  0]
 [ 0 -1  0]
 [ 0  0  0]] 5
[[ 0  0  0]
 [ 0  1 -1]
 [ 0  0  0]] 7
[[ 0  0  0]
 [ 0 -1  1]
 [ 0 -1  0]] 1
select checking: pass
simulate checking: pass
backup checking: pass


We expect your implemented MCTS to beat the random strategy at a high rate.

There are two hyperparameters you can adjust:
* `numMCTSSims`: the number of rollouts the algorithm performs when picking an action
* `cpuct`: parameter to balance exploration and exploitation in UCB selection. A higher `cpuct` results in more exploration.

In [3]:
# Fix the seed of every RNG source (Python, NumPy, PyTorch) before playing.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(42)

total_matches = 100

args = dotdict({"numMCTSSims": 50, "cpuct": 1.0})
mcts = PureMCTS(tictactoe_game, args)


def mcts_player(x):
    """Return the index of the highest-probability action PureMCTS gives for `x` (temp=0)."""
    return np.argmax(mcts.getActionProb(x, temp=0))


random_player = TicTacToePlayers.RandomPlayer(tictactoe_game).play
arena = Arena(
    mcts_player,
    random_player,
    tictactoe_game,
    display=TicTacToeGame.display,
)

win, lose, tie = arena.playGames(total_matches, verbose=False)

print(f"vs random win: {win}, tie: {tie}, lose: {lose}")
# MCTS must avoid losing in more than 95% of the matches to be accepted.
not_lost = win + tie
if not_lost <= total_matches * 0.95:
    raise Exception("Implamentation of MCTS might be wrong")
print("Implamentation of MCTS is totally correct")

Arena.playGames (1): 100%|██████████| 50/50 [00:00<00:00, 104.93it/s]
Arena.playGames (2): 100%|██████████| 50/50 [00:00<00:00, 96.44it/s]

vs random win: 92, tie: 6, lose: 2
Implamentation of MCTS is totally correct





# AlphaZero

Implement `simulate`, `ucb_select` of class `MCTS` in `alpha_zero/MCTS.py`
Implement `executeEpisode` of class `Coach` in `alpha_zero/Coach.py`

There are 5 more hyperparameters you can adjust:
* `numIters`: number of iterations used to train the nnet
* `numEps`: number of complete self-play games to simulate during each new iteration.
* `tempThreshold`: the first `tempThreshold` steps in `executeEpisode` use a temperature (`temp`) of 1
* `updateThreshold`: win-rate threshold for deciding whether to accept the new network
* `maxlenOfQueue`: size of the history pool

Usually there is no need to adjust these hyperparameters.

In [4]:
# Fix the seed of every RNG source (Python, NumPy, PyTorch) before training.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(42)

training_args = dotdict({
    # Training-loop hyperparameters (described in the markdown above).
    "numIters": 10,
    "numEps": 100,
    "tempThreshold": 15,
    "updateThreshold": 0.6,
    "maxlenOfQueue": 200000,
    # Search and arena settings.
    "numMCTSSims": 25,
    "arenaCompare": 40,
    "cpuct": 1,

    # Checkpoint directory; the evaluation cell loads `best.pth.tar` from it.
    "checkpoint": "./othello_6/",
    "load_model": False,
    "numItersForTrainExamplesHistory": 20,
})

# 6x6 Othello, a network wrapper for it, and the coach that runs the training.
othello_game = OthelloGame(n=6)
nnet = nn(othello_game)

coach = Coach(othello_game, nnet, training_args)
coach.learn()

Self Play: 100%|██████████| 100/100 [00:53<00:00,  1.88it/s]


Checkpoint Directory does not exist! Making directory ./othello_6/
EPOCH ::: 1


Training Net: 100%|██████████| 406/406 [00:01<00:00, 322.87it/s, Loss_pi=3.14e+00, Loss_v=9.18e-01]


EPOCH ::: 2


Training Net: 100%|██████████| 406/406 [00:01<00:00, 373.48it/s, Loss_pi=2.81e+00, Loss_v=8.18e-01]


EPOCH ::: 3


Training Net: 100%|██████████| 406/406 [00:01<00:00, 360.99it/s, Loss_pi=2.58e+00, Loss_v=7.78e-01]


EPOCH ::: 4


Training Net: 100%|██████████| 406/406 [00:01<00:00, 327.89it/s, Loss_pi=2.36e+00, Loss_v=7.35e-01]


EPOCH ::: 5


Training Net: 100%|██████████| 406/406 [00:01<00:00, 374.75it/s, Loss_pi=2.21e+00, Loss_v=6.80e-01]


EPOCH ::: 6


Training Net: 100%|██████████| 406/406 [00:01<00:00, 359.03it/s, Loss_pi=2.11e+00, Loss_v=6.40e-01]


EPOCH ::: 7


Training Net: 100%|██████████| 406/406 [00:01<00:00, 355.96it/s, Loss_pi=2.03e+00, Loss_v=5.99e-01]


EPOCH ::: 8


Training Net: 100%|██████████| 406/406 [00:01<00:00, 384.08it/s, Loss_pi=1.95e+00, Loss_v=5.65e-01]


EPOCH ::: 9


Training Net: 100%|██████████| 406/406 [00:01<00:00, 359.62it/s, Loss_pi=1.93e+00, Loss_v=5.27e-01]


EPOCH ::: 10


Training Net: 100%|██████████| 406/406 [00:01<00:00, 402.41it/s, Loss_pi=1.89e+00, Loss_v=5.06e-01]
Arena.playGames (1): 100%|██████████| 20/20 [00:11<00:00,  1.80it/s]
Arena.playGames (2): 100%|██████████| 20/20 [00:11<00:00,  1.73it/s]


Checkpoint Directory exists! 
Checkpoint Directory exists! 


Self Play: 100%|██████████| 100/100 [00:53<00:00,  1.88it/s]


Checkpoint Directory exists! 
EPOCH ::: 1


Training Net: 100%|██████████| 815/815 [00:02<00:00, 323.93it/s, Loss_pi=1.85e+00, Loss_v=6.24e-01]


EPOCH ::: 2


Training Net: 100%|██████████| 815/815 [00:02<00:00, 338.97it/s, Loss_pi=1.79e+00, Loss_v=5.78e-01]


EPOCH ::: 3


Training Net: 100%|██████████| 815/815 [00:02<00:00, 363.44it/s, Loss_pi=1.73e+00, Loss_v=5.52e-01]


EPOCH ::: 4


Training Net: 100%|██████████| 815/815 [00:02<00:00, 355.88it/s, Loss_pi=1.70e+00, Loss_v=5.30e-01]


EPOCH ::: 5


Training Net: 100%|██████████| 815/815 [00:02<00:00, 355.85it/s, Loss_pi=1.67e+00, Loss_v=5.02e-01]


EPOCH ::: 6


Training Net: 100%|██████████| 815/815 [00:02<00:00, 355.12it/s, Loss_pi=1.64e+00, Loss_v=4.88e-01]


EPOCH ::: 7


Training Net: 100%|██████████| 815/815 [00:02<00:00, 352.60it/s, Loss_pi=1.62e+00, Loss_v=4.65e-01]


EPOCH ::: 8


Training Net: 100%|██████████| 815/815 [00:02<00:00, 370.09it/s, Loss_pi=1.60e+00, Loss_v=4.50e-01]


EPOCH ::: 9


Training Net: 100%|██████████| 815/815 [00:02<00:00, 385.49it/s, Loss_pi=1.59e+00, Loss_v=4.36e-01]


EPOCH ::: 10


Training Net: 100%|██████████| 815/815 [00:02<00:00, 365.53it/s, Loss_pi=1.58e+00, Loss_v=4.29e-01]
Arena.playGames (1): 100%|██████████| 20/20 [00:11<00:00,  1.69it/s]
Arena.playGames (2): 100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


Checkpoint Directory exists! 
Checkpoint Directory exists! 


Self Play: 100%|██████████| 100/100 [00:55<00:00,  1.82it/s]


Checkpoint Directory exists! 
EPOCH ::: 1


Training Net: 100%|██████████| 1224/1224 [00:03<00:00, 349.58it/s, Loss_pi=1.60e+00, Loss_v=5.15e-01]


EPOCH ::: 2


Training Net: 100%|██████████| 1224/1224 [00:03<00:00, 332.89it/s, Loss_pi=1.58e+00, Loss_v=4.94e-01]


EPOCH ::: 3


Training Net: 100%|██████████| 1224/1224 [00:03<00:00, 373.39it/s, Loss_pi=1.55e+00, Loss_v=4.76e-01]


EPOCH ::: 4


Training Net: 100%|██████████| 1224/1224 [00:03<00:00, 341.77it/s, Loss_pi=1.54e+00, Loss_v=4.62e-01]


EPOCH ::: 5


Training Net: 100%|██████████| 1224/1224 [00:03<00:00, 339.05it/s, Loss_pi=1.53e+00, Loss_v=4.55e-01]


EPOCH ::: 6


Training Net: 100%|██████████| 1224/1224 [00:03<00:00, 330.21it/s, Loss_pi=1.52e+00, Loss_v=4.40e-01]


EPOCH ::: 7


Training Net: 100%|██████████| 1224/1224 [00:03<00:00, 333.26it/s, Loss_pi=1.51e+00, Loss_v=4.27e-01]


EPOCH ::: 8


Training Net: 100%|██████████| 1224/1224 [00:03<00:00, 364.33it/s, Loss_pi=1.50e+00, Loss_v=4.23e-01]


EPOCH ::: 9


Training Net: 100%|██████████| 1224/1224 [00:03<00:00, 345.42it/s, Loss_pi=1.49e+00, Loss_v=4.11e-01]


EPOCH ::: 10


Training Net: 100%|██████████| 1224/1224 [00:03<00:00, 361.84it/s, Loss_pi=1.48e+00, Loss_v=4.01e-01]
Arena.playGames (1): 100%|██████████| 20/20 [00:11<00:00,  1.81it/s]
Arena.playGames (2): 100%|██████████| 20/20 [00:11<00:00,  1.68it/s]


Checkpoint Directory exists! 
Checkpoint Directory exists! 


Self Play: 100%|██████████| 100/100 [00:54<00:00,  1.84it/s]


Checkpoint Directory exists! 
EPOCH ::: 1


Training Net: 100%|██████████| 1633/1633 [00:05<00:00, 325.54it/s, Loss_pi=1.50e+00, Loss_v=4.91e-01]


EPOCH ::: 2


Training Net: 100%|██████████| 1633/1633 [00:04<00:00, 339.29it/s, Loss_pi=1.49e+00, Loss_v=4.71e-01]


EPOCH ::: 3


Training Net: 100%|██████████| 1633/1633 [00:04<00:00, 345.14it/s, Loss_pi=1.48e+00, Loss_v=4.58e-01]


EPOCH ::: 4


Training Net: 100%|██████████| 1633/1633 [00:04<00:00, 335.90it/s, Loss_pi=1.47e+00, Loss_v=4.44e-01]


EPOCH ::: 5


Training Net: 100%|██████████| 1633/1633 [00:04<00:00, 341.13it/s, Loss_pi=1.46e+00, Loss_v=4.39e-01]


EPOCH ::: 6


Training Net: 100%|██████████| 1633/1633 [00:04<00:00, 357.66it/s, Loss_pi=1.45e+00, Loss_v=4.30e-01]


EPOCH ::: 7


Training Net: 100%|██████████| 1633/1633 [00:05<00:00, 317.03it/s, Loss_pi=1.45e+00, Loss_v=4.18e-01]


EPOCH ::: 8


Training Net: 100%|██████████| 1633/1633 [00:04<00:00, 355.17it/s, Loss_pi=1.44e+00, Loss_v=4.17e-01]


EPOCH ::: 9


Training Net: 100%|██████████| 1633/1633 [00:04<00:00, 355.50it/s, Loss_pi=1.44e+00, Loss_v=4.08e-01]


EPOCH ::: 10


Training Net: 100%|██████████| 1633/1633 [00:04<00:00, 372.20it/s, Loss_pi=1.44e+00, Loss_v=3.96e-01]
Arena.playGames (1): 100%|██████████| 20/20 [00:11<00:00,  1.81it/s]
Arena.playGames (2): 100%|██████████| 20/20 [00:11<00:00,  1.81it/s]


Checkpoint Directory exists! 
Checkpoint Directory exists! 


Self Play: 100%|██████████| 100/100 [00:52<00:00,  1.91it/s]


Checkpoint Directory exists! 
EPOCH ::: 1


Training Net: 100%|██████████| 2041/2041 [00:06<00:00, 319.10it/s, Loss_pi=1.44e+00, Loss_v=4.59e-01]


EPOCH ::: 2


Training Net: 100%|██████████| 2041/2041 [00:06<00:00, 328.72it/s, Loss_pi=1.44e+00, Loss_v=4.51e-01]


EPOCH ::: 3


Training Net: 100%|██████████| 2041/2041 [00:06<00:00, 329.85it/s, Loss_pi=1.43e+00, Loss_v=4.40e-01]


EPOCH ::: 4


Training Net: 100%|██████████| 2041/2041 [00:06<00:00, 310.78it/s, Loss_pi=1.42e+00, Loss_v=4.28e-01]


EPOCH ::: 5


Training Net: 100%|██████████| 2041/2041 [00:06<00:00, 327.25it/s, Loss_pi=1.42e+00, Loss_v=4.22e-01]


EPOCH ::: 6


Training Net: 100%|██████████| 2041/2041 [00:06<00:00, 318.20it/s, Loss_pi=1.41e+00, Loss_v=4.14e-01]


EPOCH ::: 7


Training Net: 100%|██████████| 2041/2041 [00:06<00:00, 327.76it/s, Loss_pi=1.41e+00, Loss_v=4.11e-01]


EPOCH ::: 8


Training Net: 100%|██████████| 2041/2041 [00:06<00:00, 313.72it/s, Loss_pi=1.41e+00, Loss_v=4.01e-01]


EPOCH ::: 9


Training Net: 100%|██████████| 2041/2041 [00:06<00:00, 334.13it/s, Loss_pi=1.40e+00, Loss_v=3.96e-01]


EPOCH ::: 10


Training Net: 100%|██████████| 2041/2041 [00:05<00:00, 350.28it/s, Loss_pi=1.39e+00, Loss_v=3.91e-01]
Arena.playGames (1): 100%|██████████| 20/20 [00:11<00:00,  1.73it/s]
Arena.playGames (2): 100%|██████████| 20/20 [00:12<00:00,  1.66it/s]


Checkpoint Directory exists! 
Checkpoint Directory exists! 


Self Play: 100%|██████████| 100/100 [00:53<00:00,  1.89it/s]


Checkpoint Directory exists! 
EPOCH ::: 1


Training Net: 100%|██████████| 2448/2448 [00:06<00:00, 364.74it/s, Loss_pi=1.41e+00, Loss_v=4.54e-01]


EPOCH ::: 2


Training Net: 100%|██████████| 2448/2448 [00:07<00:00, 345.82it/s, Loss_pi=1.40e+00, Loss_v=4.43e-01]


EPOCH ::: 3


Training Net: 100%|██████████| 2448/2448 [00:07<00:00, 345.92it/s, Loss_pi=1.39e+00, Loss_v=4.37e-01]


EPOCH ::: 4


Training Net: 100%|██████████| 2448/2448 [00:07<00:00, 343.90it/s, Loss_pi=1.39e+00, Loss_v=4.31e-01]


EPOCH ::: 5


Training Net: 100%|██████████| 2448/2448 [00:06<00:00, 364.14it/s, Loss_pi=1.38e+00, Loss_v=4.25e-01]


EPOCH ::: 6


Training Net: 100%|██████████| 2448/2448 [00:06<00:00, 357.36it/s, Loss_pi=1.38e+00, Loss_v=4.22e-01]


EPOCH ::: 7


Training Net: 100%|██████████| 2448/2448 [00:07<00:00, 319.93it/s, Loss_pi=1.38e+00, Loss_v=4.15e-01]


EPOCH ::: 8


Training Net: 100%|██████████| 2448/2448 [00:07<00:00, 349.18it/s, Loss_pi=1.37e+00, Loss_v=4.08e-01]


EPOCH ::: 9


Training Net: 100%|██████████| 2448/2448 [00:07<00:00, 335.14it/s, Loss_pi=1.37e+00, Loss_v=4.06e-01]


EPOCH ::: 10


Training Net: 100%|██████████| 2448/2448 [00:06<00:00, 364.67it/s, Loss_pi=1.37e+00, Loss_v=4.02e-01]
Arena.playGames (1): 100%|██████████| 20/20 [00:11<00:00,  1.76it/s]
Arena.playGames (2): 100%|██████████| 20/20 [00:12<00:00,  1.65it/s]


Checkpoint Directory exists! 
Checkpoint Directory exists! 


Self Play: 100%|██████████| 100/100 [00:53<00:00,  1.87it/s]


Checkpoint Directory exists! 
EPOCH ::: 1


Training Net: 100%|██████████| 2858/2858 [00:08<00:00, 327.34it/s, Loss_pi=1.38e+00, Loss_v=4.57e-01]


EPOCH ::: 2


Training Net: 100%|██████████| 2858/2858 [00:08<00:00, 326.02it/s, Loss_pi=1.37e+00, Loss_v=4.48e-01]


EPOCH ::: 3


Training Net: 100%|██████████| 2858/2858 [00:08<00:00, 332.48it/s, Loss_pi=1.37e+00, Loss_v=4.41e-01]


EPOCH ::: 4


Training Net: 100%|██████████| 2858/2858 [00:08<00:00, 336.16it/s, Loss_pi=1.36e+00, Loss_v=4.37e-01]


EPOCH ::: 5


Training Net: 100%|██████████| 2858/2858 [00:08<00:00, 340.43it/s, Loss_pi=1.36e+00, Loss_v=4.30e-01]


EPOCH ::: 6


Training Net: 100%|██████████| 2858/2858 [00:08<00:00, 341.58it/s, Loss_pi=1.36e+00, Loss_v=4.22e-01]


EPOCH ::: 7


Training Net: 100%|██████████| 2858/2858 [00:08<00:00, 343.16it/s, Loss_pi=1.36e+00, Loss_v=4.22e-01]


EPOCH ::: 8


Training Net: 100%|██████████| 2858/2858 [00:08<00:00, 339.80it/s, Loss_pi=1.35e+00, Loss_v=4.19e-01]


EPOCH ::: 9


Training Net: 100%|██████████| 2858/2858 [00:08<00:00, 342.31it/s, Loss_pi=1.35e+00, Loss_v=4.13e-01]


EPOCH ::: 10


Training Net: 100%|██████████| 2858/2858 [00:08<00:00, 332.88it/s, Loss_pi=1.35e+00, Loss_v=4.12e-01]
Arena.playGames (1): 100%|██████████| 20/20 [00:11<00:00,  1.77it/s]
Arena.playGames (2): 100%|██████████| 20/20 [00:11<00:00,  1.71it/s]
Self Play: 100%|██████████| 100/100 [00:52<00:00,  1.89it/s]


Checkpoint Directory exists! 
EPOCH ::: 1


Training Net: 100%|██████████| 3264/3264 [00:10<00:00, 323.29it/s, Loss_pi=1.39e+00, Loss_v=4.83e-01]


EPOCH ::: 2


Training Net: 100%|██████████| 3264/3264 [00:09<00:00, 334.23it/s, Loss_pi=1.37e+00, Loss_v=4.74e-01]


EPOCH ::: 3


Training Net: 100%|██████████| 3264/3264 [00:10<00:00, 321.25it/s, Loss_pi=1.37e+00, Loss_v=4.65e-01]


EPOCH ::: 4


Training Net: 100%|██████████| 3264/3264 [00:09<00:00, 335.23it/s, Loss_pi=1.36e+00, Loss_v=4.61e-01]


EPOCH ::: 5


Training Net: 100%|██████████| 3264/3264 [00:09<00:00, 327.65it/s, Loss_pi=1.36e+00, Loss_v=4.51e-01]


EPOCH ::: 6


Training Net: 100%|██████████| 3264/3264 [00:09<00:00, 359.58it/s, Loss_pi=1.36e+00, Loss_v=4.48e-01]


EPOCH ::: 7


Training Net: 100%|██████████| 3264/3264 [00:09<00:00, 342.48it/s, Loss_pi=1.36e+00, Loss_v=4.42e-01]


EPOCH ::: 8


Training Net: 100%|██████████| 3264/3264 [00:09<00:00, 337.80it/s, Loss_pi=1.35e+00, Loss_v=4.38e-01]


EPOCH ::: 9


Training Net: 100%|██████████| 3264/3264 [00:09<00:00, 345.02it/s, Loss_pi=1.35e+00, Loss_v=4.34e-01]


EPOCH ::: 10


Training Net: 100%|██████████| 3264/3264 [00:09<00:00, 340.32it/s, Loss_pi=1.35e+00, Loss_v=4.29e-01]
Arena.playGames (1): 100%|██████████| 20/20 [00:11<00:00,  1.80it/s]
Arena.playGames (2): 100%|██████████| 20/20 [00:11<00:00,  1.69it/s]
Self Play: 100%|██████████| 100/100 [00:53<00:00,  1.85it/s]


Checkpoint Directory exists! 
EPOCH ::: 1


Training Net: 100%|██████████| 3674/3674 [00:10<00:00, 336.56it/s, Loss_pi=1.39e+00, Loss_v=5.14e-01]


EPOCH ::: 2


Training Net: 100%|██████████| 3674/3674 [00:11<00:00, 323.81it/s, Loss_pi=1.38e+00, Loss_v=5.02e-01]


EPOCH ::: 3


Training Net: 100%|██████████| 3674/3674 [00:11<00:00, 332.41it/s, Loss_pi=1.37e+00, Loss_v=4.95e-01]


EPOCH ::: 4


Training Net: 100%|██████████| 3674/3674 [00:11<00:00, 330.12it/s, Loss_pi=1.37e+00, Loss_v=4.85e-01]


EPOCH ::: 5


Training Net: 100%|██████████| 3674/3674 [00:10<00:00, 338.31it/s, Loss_pi=1.36e+00, Loss_v=4.76e-01]


EPOCH ::: 6


Training Net: 100%|██████████| 3674/3674 [00:11<00:00, 333.59it/s, Loss_pi=1.36e+00, Loss_v=4.68e-01]


EPOCH ::: 7


Training Net: 100%|██████████| 3674/3674 [00:11<00:00, 326.63it/s, Loss_pi=1.35e+00, Loss_v=4.67e-01]


EPOCH ::: 8


Training Net: 100%|██████████| 3674/3674 [00:11<00:00, 307.47it/s, Loss_pi=1.35e+00, Loss_v=4.62e-01]


EPOCH ::: 9


Training Net: 100%|██████████| 3674/3674 [00:11<00:00, 324.78it/s, Loss_pi=1.35e+00, Loss_v=4.56e-01]


EPOCH ::: 10


Training Net: 100%|██████████| 3674/3674 [00:11<00:00, 327.88it/s, Loss_pi=1.34e+00, Loss_v=4.54e-01]
Arena.playGames (1): 100%|██████████| 20/20 [00:11<00:00,  1.80it/s]
Arena.playGames (2): 100%|██████████| 20/20 [00:11<00:00,  1.74it/s]


Checkpoint Directory exists! 
Checkpoint Directory exists! 


Self Play: 100%|██████████| 100/100 [00:53<00:00,  1.87it/s]


Checkpoint Directory exists! 
EPOCH ::: 1


Training Net: 100%|██████████| 4084/4084 [00:11<00:00, 344.72it/s, Loss_pi=1.35e+00, Loss_v=4.73e-01]


EPOCH ::: 2


Training Net: 100%|██████████| 4084/4084 [00:11<00:00, 358.47it/s, Loss_pi=1.34e+00, Loss_v=4.69e-01]


EPOCH ::: 3


Training Net: 100%|██████████| 4084/4084 [00:11<00:00, 347.28it/s, Loss_pi=1.34e+00, Loss_v=4.64e-01]


EPOCH ::: 4


Training Net: 100%|██████████| 4084/4084 [00:11<00:00, 367.08it/s, Loss_pi=1.34e+00, Loss_v=4.60e-01]


EPOCH ::: 5


Training Net: 100%|██████████| 4084/4084 [00:11<00:00, 348.43it/s, Loss_pi=1.33e+00, Loss_v=4.55e-01]


EPOCH ::: 6


Training Net: 100%|██████████| 4084/4084 [00:11<00:00, 370.42it/s, Loss_pi=1.33e+00, Loss_v=4.51e-01]


EPOCH ::: 7


Training Net: 100%|██████████| 4084/4084 [00:11<00:00, 362.68it/s, Loss_pi=1.32e+00, Loss_v=4.45e-01]


EPOCH ::: 8


Training Net: 100%|██████████| 4084/4084 [00:11<00:00, 364.13it/s, Loss_pi=1.32e+00, Loss_v=4.43e-01]


EPOCH ::: 9


Training Net: 100%|██████████| 4084/4084 [00:10<00:00, 372.24it/s, Loss_pi=1.32e+00, Loss_v=4.40e-01]


EPOCH ::: 10


Training Net: 100%|██████████| 4084/4084 [00:11<00:00, 356.94it/s, Loss_pi=1.32e+00, Loss_v=4.35e-01]
Arena.playGames (1): 100%|██████████| 20/20 [00:11<00:00,  1.79it/s]
Arena.playGames (2): 100%|██████████| 20/20 [00:11<00:00,  1.81it/s]


We expect your implemented AlphaZero (the `MCTS` class guided by the trained network) to beat the greedy strategy at a high rate.

It should also at least tie with the pure MCTS algorithm, but run at a higher speed.

In [5]:
# Fix the seed of every RNG source (Python, NumPy, PyTorch) before evaluating.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

total_matches = 40

args = dotdict({'numMCTSSims': 50, 'cpuct': 1.0})
othello_game = OthelloGame(n = 6)
nnet = nn(othello_game)

# Load the checkpoint from the directory the training cell writes to.
nnet.load_checkpoint("othello_6", "best.pth.tar")
alpha_zero = MCTS(othello_game, nnet, args)
mcts = PureMCTS(othello_game, args)

# Each player maps a board to the index of its highest-probability action (temp=0).
greed_player = GreedyOthelloPlayer(othello_game).play
mcts_player = lambda x: np.argmax(mcts.getActionProb(x, temp=0))
your_alpha_zero_player = lambda x: np.argmax(alpha_zero.getActionProb(x, temp=0))

# Match 1: AlphaZero vs the greedy player; it must win more than 80% of the games.
arena = Arena(your_alpha_zero_player, greed_player, othello_game, display=OthelloGame.display)

win, lose, tie = arena.playGames(total_matches, verbose=False)
print(f"vs greed win: {win}, tie: {tie}, lose: {lose}")
if win <= total_matches * 0.8:
    raise Exception("Implamentation of alphaZero might be wrong or the hyperparameters are not good enough")

# Match 2: AlphaZero vs pure MCTS; it must win at least 40% of the games.
arena = Arena(your_alpha_zero_player, mcts_player, othello_game, display=OthelloGame.display)

win, lose, tie = arena.playGames(total_matches, verbose=False)
print(f"vs mcts win: {win}, tie: {tie}, lose: {lose}")
if win >= total_matches * 0.4:
    # This branch validates the AlphaZero player, so the message names it
    # (the original text was copy-pasted from the pure-MCTS check).
    print("Implementation of alphaZero is totally correct")
else:
    raise Exception("Implamentation of alphaZero might be wrong or the hyperparameters are not good enough")

Arena.playGames (1): 100%|██████████| 20/20 [00:12<00:00,  1.64it/s]
Arena.playGames (2): 100%|██████████| 20/20 [00:12<00:00,  1.66it/s]


vs greed win: 39, tie: 0, lose: 1


Arena.playGames (1): 100%|██████████| 20/20 [01:00<00:00,  3.04s/it]
Arena.playGames (2): 100%|██████████| 20/20 [01:02<00:00,  3.14s/it]

vs mcts win: 39, tie: 0, lose: 1
Implamentation of MCTS is totally correct



