In [None]:
import random

import numpy as np
import torch

from alpha_zero.MCTS import MCTS, PureMCTS
from alpha_zero.Coach import Coach
from alpha_zero.Arena import Arena
from alpha_zero.othello.OthelloGame import OthelloGame
from alpha_zero.tictactoe.TicTacToeGame import TicTacToeGame
from alpha_zero.othello.pytorch.NNet import NNetWrapper as nn
from alpha_zero.utils import dotdict

from alpha_zero.tictactoe import TicTacToePlayers
from alpha_zero.othello.OthelloPlayers import GreedyOthelloPlayer

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

# Monte Carlo Tree Search (MCTS)

Implement `select`, `simulate`, `backup`, `ucb_select` of class `PureMCTS` in `alpha_zero/MCTS.py`

The following code checks the correctness of your implementation, but the expected results might differ slightly depending on how you implement these methods.

In [None]:
from pathlib import Path

# Seed the Python, NumPy and PyTorch RNGs so the checks below are repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Search settings handed to PureMCTS in the checking loop.
args = dotdict(dict(numMCTSSims=20, cpuct=1.0))

# 3x3 tic-tac-toe game, its initial board, and that board's string key.
tictactoe_game = TicTacToeGame(n=3)
board = tictactoe_game.getInitBoard()
s = tictactoe_game.stringRepresentation(board)

# Directory holding the tree_*.npy fixture files loaded by the checking loop.
target = Path("./data")

def check_select(path, target_actions):
    """Verify the actions chosen by ``select`` against the reference actions.

    Args:
        path: Sequence of two-item entries whose second item is the chosen
            action (the first item is presumably the state key; it is only
            printed here).
        target_actions: Expected action for each step of ``path``, in order.

    Raises:
        ValueError: If ``path`` and ``target_actions`` differ in length, or
            if any step's action differs from the expected one.
    """
    # Compare lengths up front: otherwise a path that stops early would pass
    # silently, and a path that runs long would fail with an IndexError
    # instead of the checker's own error.
    if len(path) != len(target_actions):
        raise ValueError("path is wrong")
    # Use a distinct loop name so the `path` argument is not shadowed.
    for idx, (node, action) in enumerate(path):
        print(node, action)
        if action != target_actions[idx]:
            raise ValueError("path is wrong")
    print("select checking: pass")

def check_simulate(reward, target_reward):
    """Verify the reward produced by ``simulate`` against the reference reward.

    Args:
        reward: Value returned by the simulation step.
        target_reward: Expected value to compare against.

    Raises:
        ValueError: If the two values compare unequal.
    """
    mismatch = reward != target_reward
    if mismatch:
        raise ValueError("reward is wrong")
    print("simulate checking: pass")

def check_backup(Ns, Nsa, Qsa, target_Ns, target_Nsa, target_Qsa):
    """Verify the statistics left by ``backup`` against the reference values.

    Each of ``Ns``, ``Nsa`` and ``Qsa`` is compared with its ``target_*``
    counterpart, in that order; the first mismatch aborts the check.

    Raises:
        ValueError: Naming the first statistic that differs from its target.
    """
    comparisons = (
        ("Ns is wrong", Ns, target_Ns),
        ("Nsa is wrong", Nsa, target_Nsa),
        ("Qsa is wrong", Qsa, target_Qsa),
    )
    for message, actual, expected in comparisons:
        if actual != expected:
            raise ValueError(message)
    print("backup checking: pass")

# Replay one MCTS iteration (select -> expand -> simulate -> backup) on each of
# the five saved tree presets and compare every stage with its stored reference.
for i in range(5):
    # Fresh searcher per preset, populated from the saved tree file.
    mcts = PureMCTS(tictactoe_game, args)
    mcts.load_tree(target / f"tree_{i}.npy")

    print("=" * 20 + f"checking {i}-th tree preset " + "=" * 20)

    # Selection starts from the initial board and yields the traversed path
    # together with the leaf board that was reached.
    path, leaf = mcts.select(board)
    leaf_s = tictactoe_game.stringRepresentation(leaf)

    mcts.expand(leaf, leaf_s)

    # The reward obtained from the leaf is propagated back along the path.
    reward = mcts.simulate(leaf, leaf_s)
    mcts.backup(path, reward)

    # Reference results for this preset.
    # NOTE(review): allow_pickle=True unpickles arbitrary Python objects, so
    # only load fixture files from a trusted source.
    target_actions = np.load(target / f"tree_select_{i}.npy", allow_pickle=True)
    target_reward = np.load(target / f"tree_simulate_{i}.npy", allow_pickle=True)
    target_Ns, target_Nsa, target_Qsa = np.load(target / f"tree_backup_{i}.npy", allow_pickle=True)

    # Each checker raises ValueError on a mismatch and prints "pass" otherwise.
    check_select(path, target_actions)
    check_simulate(reward, target_reward)
    check_backup(mcts.Ns, mcts.Nsa, mcts.Qsa, target_Ns, target_Nsa, target_Qsa)

We expect your MCTS implementation to beat the random strategy at a high rate.

There are two hyperparameters you can adjust:
* `numMCTSSims`: the number of rollouts (simulations) the algorithm performs before picking an action
* `cpuct`: parameter to balance exploration and exploitation in UCB selection. Higher cpuct results in more exploration.

In [None]:
# Re-seed all RNGs so the match results are repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

total_matches = 100

args = dotdict(dict(numMCTSSims=50, cpuct=1.0))
mcts = PureMCTS(tictactoe_game, args)

random_player = TicTacToePlayers.RandomPlayer(tictactoe_game).play


def mcts_player(x):
    """Return the index of the most probable action from the pure-MCTS policy at temp=0."""
    return np.argmax(mcts.getActionProb(x, temp=0))


arena = Arena(mcts_player, random_player, tictactoe_game, display=TicTacToeGame.display)

win, lose, tie = arena.playGames(total_matches, verbose=False)

print(f"vs random win: {win}, tie: {tie}, lose: {lose}")
# Wins plus ties must exceed 95% of all matches for the implementation to pass.
if (win + tie) <= total_matches * 0.95:
    raise Exception("Implamentation of MCTS might be wrong")
print("Implamentation of MCTS is totally correct")

# AlphaZero

Implement `simulate`, `ucb_select` of class `MCTS` in `alpha_zero/MCTS.py`
Implement `executeEpisode` of class `Coach` in `alpha_zero/Coach.py`

There are 5 more hyperparameters you can adjust:
* `numIters`: number of training iterations for the network
* `numEps`: number of complete self-play games to simulate during each iteration
* `tempThreshold`: the first `tempThreshold` steps of `executeEpisode` use a temperature of 1
* `updateThreshold`: win-rate threshold for accepting the new network
* `maxlenOfQueue`: size of the history pool

Usually there is no need to adjust these hyperparameters.

In [None]:
# Re-seed all RNGs so the training run is repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Configuration handed to Coach; the keys listed in the text above are the
# ones intended to be tuned.
training_args = dotdict(dict(
    numIters=10,
    numEps=100,
    tempThreshold=15,
    updateThreshold=0.6,
    maxlenOfQueue=200000,
    numMCTSSims=25,
    arenaCompare=40,
    cpuct=1,

    checkpoint='./othello_6/',
    load_model=False,
    numItersForTrainExamplesHistory=20,
))

# 6x6 Othello and a network wrapper built for it.
othello_game = OthelloGame(n=6)

nnet = nn(othello_game)

c = Coach(othello_game, nnet, training_args)

# Run the training loop.
c.learn()

We expect your AlphaZero player (the network-guided `MCTS`) to beat the greedy strategy at a high rate, and to roughly tie with the pure MCTS algorithm while running faster.

In [None]:
# Re-seed all RNGs so the evaluation matches are repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

total_matches = 40

args = dotdict({'numMCTSSims': 50, 'cpuct': 1.0})
othello_game = OthelloGame(n = 6)
nnet = nn(othello_game)

# Load the checkpoint named best.pth.tar from the othello_6 folder.
nnet.load_checkpoint("othello_6", "best.pth.tar")
alpha_zero = MCTS(othello_game, nnet, args)
mcts = PureMCTS(othello_game, args)

# Each player maps a board to an action; both search players pick the action
# with the highest probability at temp=0.
greed_player = GreedyOthelloPlayer(othello_game).play
mcts_player = lambda x: np.argmax(mcts.getActionProb(x, temp=0))
your_alpha_zero_player = lambda x: np.argmax(alpha_zero.getActionProb(x, temp=0))

# Match 1: AlphaZero vs the greedy player; more than 80% wins are required.
arena = Arena(your_alpha_zero_player, greed_player, othello_game, display=OthelloGame.display)

win, lose, tie = arena.playGames(total_matches, verbose=False)
print(f"vs greed win: {win}, tie: {tie}, lose: {lose}")
if win <= total_matches * 0.8:
    raise Exception("Implamentation of alphaZero might be wrong or the hyperparameters are not good enough")

# Match 2: AlphaZero vs pure MCTS; at least 40% wins are required.
arena = Arena(your_alpha_zero_player, mcts_player, othello_game, display=OthelloGame.display)

win, lose, tie = arena.playGames(total_matches, verbose=False)
print(f"vs mcts win: {win}, tie: {tie}, lose: {lose}")
if win >= total_matches * 0.4:
    # This check validates the AlphaZero player, so the success message names
    # it (the message previously reported "MCTS").
    print("Implamentation of alphaZero is totally correct")
else:
    raise Exception("Implamentation of alphaZero might be wrong or the hyperparameters are not good enough")