In [1]:
from bokeh.io import push_notebook, output_notebook
from bokeh.plotting import figure, show, output_file
from bokeh.models.tickers import FixedTicker
from pickle import Pickler, Unpickler
from IPython.display import clear_output
from IPython.core.debugger import Pdb
import numpy as np
import time

# Route bokeh output (e.g. the BoardGUI figures shown via `show`) inline into this notebook.
output_notebook()

# Game Environment

## Game Base Class

In [2]:
def pos_check(bottom_left, top_right, input_pos):
    """
    Check that an input position is an integer (x, y) pair inside the board.

    bottom_left : board's bottom-left coordinates (inclusive).
    top_right   : board's top-right coordinates (inclusive).
    input_pos   : input position.

    Returns:
      True if input_pos has exactly two integer components that lie within
      [bottom_left, top_right] element-wise, otherwise False.
    """
    pos = np.asarray(input_pos)
    # Reject malformed input first: a wrong-sized position would otherwise
    # fail to broadcast in the comparisons below instead of returning False.
    # np.issubdtype replaces the np.int alias removed in NumPy 1.24 and also
    # accepts every numpy integer width (int32 on some platforms).
    if pos.size != 2 or not np.issubdtype(pos.dtype, np.integer):
        return False
    bl = np.asarray(bottom_left)
    tr = np.asarray(top_right)
    return bool((bl <= pos).all() and (pos <= tr).all())

class BoardGame():
    """
    General board game base class.

    The default implementation in this base class assumes a game like
    Tic-tac-toe, which:
    1. has only 2 players and 2 kinds of stone.
    2. never changes a placed stone during the game.
    3. lets a player place a stone on any empty position.
    4. has mirror and rotational symmetries.
    Override methods in a child class to model a different game.

    Player ids: 1 is player 1 (moves first) and -1 is player 2.

    NOTICE: two formats are used to address board cells.
    'pos' (positions): 1-based (x, y); (1, 1) is the bottom-left cell. This
    format is used for drawing and user input.
    'indices': 0-based (row, column) array indices; (0, 0) is the top-left
    cell. This format is used for array slicing and MCTS.
    """
    def __init__(self, bsize, dsize=(600,600), use_gui=False):
        """
        Inputs:
        bsize  : board size as (rows, columns), i.e. the shape of the board
                 array.
        dsize  : draw size as (width px, height px).
        use_gui: whether to use BoardGUI to draw a graphical game board.
        """
        # Stored as a tuple so that it compares equal to ndarray.shape.
        self.bsize = tuple(bsize)
        self.mode = None  # 'PvP', 'PvE' or 'EvE'; set by the play methods

        # curPlayer: current player. 1 is player 1 and -1 is player 2.
        # NOTICE: player 1 moves first by default.
        self.curPlayer = 1

        # One (move number, player, (x, y) pos) tuple per accepted move.
        self.playHistory = []

        if use_gui:
            # NOTE(review): BoardGUI documents its size as (column, row) while
            # bsize here is (rows, columns); they only agree for square boards.
            self.gui = BoardGUI(bsize, dsize)
        else:
            self.gui = None
        # np.int was removed in NumPy 1.24; builtin int is the same dtype.
        self.board = np.zeros(shape=self.bsize, dtype=int)

        # Player settings; child game classes override these.
        self.p1 = {'name':'1', 'char':'b', 'color':'#000000'}
        self.p2 = {'name':'2', 'char':'w', 'color':'#000000'}

        # initialize attributes used by MCTS
        self._mcts_attr_init()

    def draw_board(self):
        """Draw the board through the GUI, if one is attached."""
        if self.gui is not None:
            self.gui.draw()

    def pvp(self):
        """
        Human player vs Human player.

        Prompts the players in turn for 'x,y' positions until judge_winner
        reports a result, then prints it.
        """
        self.mode = "PvP"
        while True:
            judge = self.judge_winner()
            if judge is not None:
                break
            # re-prompt the same player until a legal move is entered
            while not self._human_move():
                pass
            self.curPlayer *= -1

        self._print_result(judge)

    def pve(self, pc1):
        """
        Player vs Environment.

        The computer side (player 1 or player 2) is chosen at random and
        stored in self._pc1.

        NOTICE: the pc player must have a 'play' function that accepts
        (board, player) and returns move indices.

        Raises:
          ValueError: if the pc player returns an illegal move.
        """
        self.mode = "PvE"

        # decide pc1 be either player 1 or player 2
        self._pc1 = np.random.choice([1, -1])

        while True:
            judge = self.judge_winner()
            if judge is not None:
                break
            if self.curPlayer == self._pc1:
                time.sleep(0.25)  # wait for clear_output
                indices = pc1.play(self.board, self._pc1)
                clear_output()
                if not self.update_move(self.curPlayer, self._indices2pos(indices)):
                    raise ValueError("Invaild PC input.")
            else:
                # re-prompt the human until a legal move is entered
                while not self._human_move():
                    pass
            self.curPlayer *= -1

        self._print_result(judge)

    def eve(self, pc1, pc2):
        """
        Environment vs Environment.

        NOTICE: pc players must have a 'play' function that accepts
        (board, player) and returns move indices.

        Returns:
          (pc1's player id, judge_winner() result)

        Raises:
          ValueError: if a pc player returns an illegal move.
        """
        self.mode = "EvE"

        # decide pc1 be either player 1 or player 2
        self._pc1 = np.random.choice([1, -1])
        self._pc2 = self._pc1 * -1
        while True:
            judge = self.judge_winner()
            if judge is not None:
                break
            if self.curPlayer == self._pc1:
                self._pc_move(pc1, self._pc1)
            else:
                self._pc_move(pc2, self._pc2)
            self.curPlayer *= -1

        return self._pc1, judge

    def reset_game(self):
        """Restore the initial state: empty board, player 1 to move, no history."""
        self.curPlayer = 1
        self.playHistory = []
        self.board = np.zeros(shape=self.bsize, dtype=int)
        if self.gui is not None:
            self.gui.reset()

    def update_move(self, player, pos):
        """
        Default implementation for updating a game player's move.

        Inputs:
         player   : 1 for Player1, and -1 for Player2
         pos      : 1-based move position as (x, y), (1, 1) is bottom-left

        Returns:
         True : Update succeeded (board, history and GUI updated).
         False: Update failed (invalid move).
        """
        # NOTICE: bsize is in (y, x) format; bsize[::-1] is the top-right pos.
        if pos_check((1, 1), self.bsize[::-1], pos):
            # here we convert pos to indices.
            if self._update_board(self.board, player, self._pos2indices(pos)):
                self.playHistory.append((len(self.playHistory)+1, player,
                                         (pos[0],pos[1])))
                if self.gui is not None and player in (1, -1):
                    style = self.p1 if player == 1 else self.p2
                    self.gui.update((pos[0], pos[1]), style['char'], style['color'])
                return True
        return False

    def judge_winner(self, external=None):
        """
        Inputs:
         external: optional board to judge instead of self.board (for MCTS).

        Returns:
          1   : Player 1 win.
         -1   : Player 2 win.
          0   : Draw.
          None: Game not end.
        """
        raise NotImplementedError("judge_winner")

    def _update_board(self, board, player, indices):
        """
        Default implementation for updating a game board array from the
        given player and board array indices. Only empty cells can be taken.

        Returns True if the stone was placed, False if the cell was occupied.
        """
        if board[indices[0], indices[1]] == 0:
            board[indices[0], indices[1]] = player
            return True
        else:
            return False

    def _pos2indices(self, pos):
        """Convert a 1-based (x, y) pos into 0-based (row, column) indices."""
        return (-pos[1]+self.bsize[0], pos[0]-1)

    def _indices2pos(self, indices):
        """Convert 0-based (row, column) indices into a 1-based (x, y) pos."""
        return (indices[1]+1, -(indices[0]-self.bsize[0]))

    # ---- helpers shared by the play-mode methods ----

    def _player_name(self, player):
        """Display name of `player` (1 -> p1, otherwise p2)."""
        return self.p1['name'] if player == 1 else self.p2['name']

    @staticmethod
    def _parse_pos(text):
        """Parse user input like '1,2' into an (x, y) int tuple; None if malformed."""
        try:
            pos = tuple(int(x) for x in text.split(','))
        except ValueError:
            return None
        return pos if len(pos) == 2 else None

    def _print_board_text(self):
        """Print the raw board array when there is no GUI to look at."""
        if self.gui is None:
            print(self.board, '\n')

    def _human_move(self):
        """Prompt the current (human) player once; return True if the move was applied."""
        time.sleep(0.25)  # wait for clear_output
        text = input("Player " + self._player_name(self.curPlayer) + ": ")
        clear_output()
        pos = self._parse_pos(text)
        if pos is None:
            print("Invalid input. Please input position like 1,1\n")
            self._print_board_text()
            return False
        updated = self.update_move(self.curPlayer, pos)
        if not updated:
            print("You can not move a piece to "+ str(pos) + "\n")
        self._print_board_text()
        return updated

    def _pc_move(self, pc, player):
        """Apply the move chosen by `pc.play`; raise ValueError if it is illegal."""
        pos = self._indices2pos(pc.play(self.board, player))
        if not self.update_move(self.curPlayer, pos):
            raise ValueError("Invaild PC input.")

    def _print_result(self, judge):
        """Print the game result returned by judge_winner."""
        if judge == 1:
            print("Player "+self.p1['name']+" win!")
        elif judge == -1:
            print("Player "+self.p2['name']+" win!")
        elif judge == 0:
            print("Draw!")

    # NOTICE: the following functions are made for MCTS. They do not use
    # self.board; the input board must have the same size as self.board.

    def get_legal_moves(self, board):
        """
        Returns all the legal moves (empty cells) as (row, column) indices.

        Raises ValueError if board's shape differs from bsize.
        """
        if board.shape != self.bsize:
            raise ValueError("invaild input")

        ys, xs = np.where(board == 0)
        return list(zip(ys, xs))

    def get_legal_actions(self, board):
        """
        Returns a 0/1 mask of length mctsActionSize; 1 marks a legal action.
        """
        legal_action = np.zeros(shape=self.mctsActionSize, dtype=int)
        for move in self.get_legal_moves(board):
            legal_action[self.mcts_indices2action(move)] = 1
        return legal_action

    def get_symmetries(self, board, pi):
        """
        Returns the 8 (board, flattened pi) pairs given by the 4 rotations,
        each with and without a left-right flip.

        NOTE(review): rot90 swaps the axes of a non-square board, so this
        assumes a square board.
        """
        if board.shape != self.bsize or pi.size != self.board.size:
            raise ValueError("invaild input")

        pi_board = np.reshape(pi, self.bsize)
        syms = []

        for i in range(1, 5):
            for flip in [True, False]:
                new_board = np.rot90(board, i)
                new_pi = np.rot90(pi_board, i)
                if flip:
                    new_board = np.fliplr(new_board)
                    new_pi = np.fliplr(new_pi)
                syms.append((new_board, new_pi.ravel()))
        return syms

    def get_next_state(self, board, player, action):
        """
        Returns (next board, next player) after `player` takes `action`.
        The input board is not modified. Raises ValueError on an illegal action.
        """
        next_board = np.copy(board)
        # convert action number to board indices.
        move_indices = self.mcts_action2indices(action)
        if self._update_board(next_board, player, move_indices):
            return (next_board, -player)
        else:
            raise ValueError("invaild input")

    def get_canonical_form(self, board, player):
        """Board from `player`'s point of view (their stones become 1)."""
        return board * player

    def get_game_ended(self, board, player):
        """
        Judge the input board and give the reward for `player`:
        None if the game has not ended, 0 for a draw, 1 for a win and
        -1 for a loss.
        """
        judge = self.judge_winner(board)
        if judge is None:      # game not end
            return None
        elif judge == 0:       # if draw
            return 0           # draw have 0 reward
        elif judge == player:  # if player win
            return 1           # win have 1 reward
        elif judge == -player: # if player lose
            return -1          # lose have -1 reward

    # NOTICE: the following functions are used in MCTS.

    def mcts_board2string(self, board):
        """Hashable key for a board: its raw bytes (tostring() was removed in NumPy 2.0)."""
        return board.tobytes()

    def mcts_action2pos(self, action):
        """Convert an MCTS action number into a 1-based (x, y) pos."""
        indices = self.mcts_action2indices(action)
        return self._indices2pos(indices)

    def mcts_action2indices(self, action):
        """
        Convert an action number into (row, column) indices.

        Actions enumerate the board in row-major order (the order of a
        flattened board), so the row stride is the number of columns.
        """
        return (action // self.bsize[1], action % self.bsize[1])

    def mcts_indices2action(self, indices):
        """Inverse of mcts_action2indices."""
        return indices[1] + self.bsize[1] * indices[0]

    def _mcts_attr_init(self):
        """Initialize the attributes MCTS needs (one action per board cell)."""
        self.mctsActionSize = self.board.size

class BoardGUI():
    """
    General board game GUI drawn with bokeh.

    Cells are addressed by 1-based (x, y) positions; (1, 1) is bottom-left.
    """
    def __init__(self, bsize, dsize):
        """
        bsize: board size. (column, row)
        dsize: draw size. (width px, height px)
        """
        self.bsize = np.array(bsize)
        self.dsize = np.array(dsize)
        self.fig = figure()
        self.fig_handle = None  # notebook handle from show(); set by draw()
        self.bg_color = None

    def draw(self):
        """
        Create a fresh board figure and show it in the notebook.

        Grid lines sit on half-integer coordinates so that each integer
        (x, y) position falls in the middle of a cell.
        """
        # width/height work in Bokeh 2.x and 3.x; plot_width/plot_height
        # were removed in Bokeh 3.
        self.fig = figure(width=int(self.dsize[0]),
                          height=int(self.dsize[1]),
                          x_range=(0.5, self.bsize[0] + 0.51),
                          y_range=(0.5, self.bsize[1] + 0.51))
        if self.bg_color:
            # previously referenced an undefined name `color` (NameError)
            self.fig.background_fill_color = self.bg_color
        self.fig.xaxis.ticker = np.arange(1, self.bsize[0]+1)
        self.fig.yaxis.ticker = np.arange(1, self.bsize[1]+1)
        self.fig.xgrid.ticker = FixedTicker(ticks=np.arange(0.5, self.bsize[0]+1.5))
        self.fig.ygrid.ticker = FixedTicker(ticks=np.arange(0.5, self.bsize[1]+1.5))
        self.fig.xgrid.grid_line_color = "#000000"
        self.fig.ygrid.grid_line_color = "#000000"

        self.fig_handle = show(self.fig, notebook_handle=True)

    def background(self, color):
        """Set the background color used by the next draw()."""
        self.bg_color = color

    def _push(self):
        """Send pending figure changes to the output shown by draw(), if any."""
        if self.fig_handle is not None:
            # The handle must be passed by keyword: the first positional
            # parameter of push_notebook is `document` (keyword-only in Bokeh 3).
            push_notebook(handle=self.fig_handle)

    def reset(self):
        """Clear the displayed document; does nothing before the first draw()."""
        if self.fig_handle is None:
            return
        self.fig_handle.doc.clear()
        self._push()

    def update(self, pos, char, color):
        """
        Draw one stone and push it to the notebook.

        pos  : (x, y) position to update.
        char : 'x', 'o', 'w' or 'b'.
        color: color code, like #000000.

        Returns True on success. Raises ValueError for an out-of-board
        position or an unknown char.
        """
        # NOTE(review): this bound reads bsize as (rows, columns) via [::-1]
        # while draw() reads it as (columns, rows); they only agree for
        # square boards.
        if pos_check([1,1], self.bsize[::-1], pos) and char in ('x', 'o', 'w', 'b'):
            x = pos[0]
            y = pos[1]
            # marker size: 80% of the smaller cell dimension in pixels
            s = np.min(self.dsize/self.bsize) * 0.8

            if char == 'x':
                self.fig.cross(x, y, size=s,
                               angle=np.pi/4, line_width=s/8,
                               color=color)
            elif char == 'o':
                self.fig.circle(x, y, size=s*0.8,
                                line_width=s/8,
                                color=color,
                                fill_color="white")
            elif char == 'w':
                self.fig.circle(x, y, size=s*0.8,
                                line_width=s/32,
                                color=color,
                                fill_color="white")
            else:  # 'b'
                self.fig.circle(x, y, size=s*0.8,
                                line_width=s/32,
                                color=color)
            self._push()
            return True

        raise ValueError("BoardGUI: Wrong position/char input.")

## Tic-tac-toe Game

In [3]:
import random

class TicTacToeGame(BoardGame):
    """
    Tic-tac-toe game environment class.

    Board data: 1=X (player 1), -1=O (player 2), 0=Empty.
    """
    def __init__(self, use_gui=False):
        super().__init__((3, 3), use_gui=use_gui)
        self.p1 = {'name':'X', 'char':'x', 'color':'#6A5ACD'}
        self.p2 = {'name':'O', 'char':'o', 'color':'#FFDEAD'}

    def judge_winner(self, external=None):
        """
        Inputs:
          external: optional external board (used by MCTS); self.board is
                    judged when omitted.

        Returns:
          -1   : a line of three -1 stones exists (checked first).
           1   : a line of three 1 stones exists.
           0   : the board is full with no line (draw).
           None: the game has not ended.

        Raises:
          ValueError: if the external board's shape is not 3x3.
        """
        if external is None:
            board = self.board
        else:
            board = external
            if board.shape != self.bsize:
                raise ValueError("invaild external input")

        # A line of three equal stones sums to +/-3. np.int was removed in
        # NumPy 1.24, so the builtin int is used as the accumulator dtype.
        line_sums = np.concatenate([
            np.sum(board, axis=1, dtype=int),                # rows
            np.sum(board, axis=0, dtype=int),                # columns
            [np.trace(board, dtype=int),                     # down-right diagonal
             np.trace(np.fliplr(board), dtype=int)],         # down-left diagonal
        ])
        if -3 in line_sums:
            return -1
        if 3 in line_sums:
            return 1
        if 0 not in board:
            return 0
        return None

class TicTacToeSolver():
    """
    A rule-based tic-tac-toe player.

    `play(board, player)` returns (row, column) indices for `player`
    (1 or -1) on a 3x3 numpy board, using this order of preference:
    centre, gap on a diagonal holding two equal stones, own winning gap on
    a row/column, any gap on a row/column holding two equal stones, a few
    2x2 / corner / edge heuristics, and finally a random empty cell.
    """
    @staticmethod
    def _scan_lines(board, hit):
        """
        Walk rows and columns with running sums; whenever a running sum
        satisfies `hit`, return the first empty cell of that row/column.
        Returns None if nothing is found.
        """
        for line in range(3):
            row_sum = 0
            col_sum = 0
            for step in range(3):
                row_sum += board[line][step]
                col_sum += board[step][line]
                if hit(row_sum):
                    for k in range(3):
                        if board[line][k] == 0:
                            return (line, k)
                elif hit(col_sum):
                    for k in range(3):
                        if board[k][line] == 0:
                            return (k, line)
        return None

    def play(self, board, player):
        """Choose a move for `player`; returns (row, column) indices."""
        # The centre is always taken first.
        if board[1][1] == 0:
            return (1, 1)

        # Two equal stones on the main diagonal: fill its gap.
        if abs(board[0][0] + board[1][1] + board[2][2]) == 2:
            gap = next(((k, k) for k in range(3) if board[k][k] == 0), None)
            if gap is not None:
                return gap

        # Two equal stones on the anti-diagonal: fill its gap.
        if abs(board[0][2] + board[1][1] + board[2][0]) == 2:
            gap = next(((k, 2 - k) for k in range(3) if board[k][2 - k] == 0), None)
            if gap is not None:
                return gap

        # Our own two-in-a-row on a row/column, then anyone's.
        move = self._scan_lines(board, lambda total: total == 2 * player)
        if move is not None:
            return move
        move = self._scan_lines(board, lambda total: abs(total) == 2)
        if move is not None:
            return move

        # 2x2 squares with exactly one gap and an unbalanced stone count.
        for i in range(2):
            for j in range(2):
                cells = [(i, j), (i + 1, j), (i, j + 1), (i + 1, j + 1)]
                values = [board[r][c] for r, c in cells]
                if values.count(0) == 1 and abs(sum(values)) == 1:
                    return cells[values.index(0)]

        # Two opposing corner stones: take one of the free corners.
        corner_cells = [(0, 0), (2, 0), (0, 2), (2, 2)]
        corner_values = [board[r][c] for r, c in corner_cells]
        if sum(corner_values) == 0 and corner_values.count(0) == 2:
            return random.choice([cell for cell, value
                                  in zip(corner_cells, corner_values) if value == 0])

        if board[0][0] == board[0][2] == board[2][0] == board[2][2] == 0:
            return random.choice([(0, 0), (0, 2), (2, 0), (2, 2)])

        if board[0][1] == board[1][0] == board[1][2] == board[2][1] == 0:
            return random.choice([(0, 1), (1, 0), (1, 2), (2, 1)])

        time.sleep(1.0)
        return self.random_pos(board)

    def random_pos(self, board):
        """Return a uniformly random empty cell as (row, column) indices."""
        empties = [(r, c)
                   for r in range(board.shape[0])
                   for c in range(board.shape[1])
                   if board[r][c] == 0]
        return random.choice(empties)

# AlphaZero NNet

In [4]:
import sys
import os

import keras.backend as K
from keras.layers import (Add, Input, Reshape, Conv2D,
                          BatchNormalization, Activation, 
                          Flatten, Dense)
from keras.models import Model
from keras.optimizers import Adam, Adagrad
from keras.utils import plot_model

Using TensorFlow backend.


## NNet Base Class

In [5]:
class BoardNNet():
    """
    General neural network base class.

    Subclasses implement `_build_nnet`, which must set `self.model` to a
    compiled Keras model returning [pi, v]. `train` reads the 'pi_loss' and
    'v_loss' entries of the fit history, so the outputs are expected to be
    named 'pi' and 'v'. Files are written under ./nnet_save.

    config keys read here: 'nnetLabel', 'batchSize', 'epochs'.
    """
    def __init__(self, config, save_summary=False):
        self.config = config
        self.label = config['nnetLabel']

        self._build_nnet()
        os.makedirs('nnet_save', exist_ok=True)
        if save_summary:
            self._save_summary()

    def _build_nnet(self):
        """Build and compile `self.model`; must be implemented by subclasses."""
        raise NotImplementedError("build_nnet")

    def _save_summary(self, filename='model_summary'):
        """
        Write the model summary to nnet_save/<label>_summary.txt and the
        architecture diagram to nnet_save/<label>_summary.png.

        NOTE: `filename` is unused (the name is derived from the label); it
        is kept for backward compatibility.
        """
        import contextlib

        os.makedirs('nnet_save', exist_ok=True)
        fpath = os.path.join('nnet_save', self.label + '_summary')
        # redirect_stdout restores sys.stdout even if summary() raises, and
        # the with-statement closes the file in every case.
        with open(fpath + '.txt', 'w') as f, contextlib.redirect_stdout(f):
            self.model.summary(line_length=140)
        plot_model(self.model, show_shapes=True,
                   to_file=fpath + '.png')

    def predict(self, boards):
        """
        Run the network on a single 2-D board or a batch of boards.

        Returns the model's (pis, vs) outputs; a single board is given a
        leading batch axis of size 1 first.
        """
        if boards.ndim == 2:
            boards = boards[np.newaxis, :, :]
        pis, vs = self.model.predict(boards)
        return pis, vs

    def save_checkpoint(self, filename='checkpoint.h5'):
        """Save the model weights to nnet_save/<filename>."""
        filepath = os.path.join('nnet_save', filename)
        self.model.save_weights(filepath)

    def load_checkpoint(self, filename='checkpoint.h5'):
        """Load the model weights from nnet_save/<filename>."""
        filepath = os.path.join('nnet_save', filename)
        self.model.load_weights(filepath)

    def train(self, examples, iters):
        """
        Fit the model on training examples and pickle the fit history.

        Inputs:
        examples: training data in the form of [(board, pi, v), ...]
        iters   : training iteration number; used in the log line and in
                  nnet_save/<label>/train_history_iter<iters>.pickle
        """
        config = self.config
        input_boards = np.asarray([e[0] for e in examples])
        target_pis = np.asarray([e[1] for e in examples])
        target_vs = np.asarray([e[2] for e in examples])
        fit = self.model.fit(x = input_boards,
                             y = [target_pis, target_vs],
                             batch_size=config["batchSize"],
                             epochs=config["epochs"],
                             verbose=False)

        print("Fitting result at iters[%04d]." % iters)
        for pi_loss, v_loss in zip(fit.history['pi_loss'], fit.history['v_loss']):
            print('pi_loss: %1.4f' % pi_loss, '  ' , 'v_loss: %1.4f' % v_loss)
        print('')

        # makedirs also creates nnet_save itself if it is missing.
        history_dir = os.path.join('nnet_save', self.label)
        os.makedirs(history_dir, exist_ok=True)
        history_path = os.path.join(history_dir,
                                    'train_history_iter%04d.pickle' % iters)
        with open(history_path, "wb") as f:
            Pickler(f).dump(fit.history)
        
        
    
class BoardAI():
    """
    Computer player that chooses moves by MCTS guided by `nnet`.

    Its `play(board, player)` matches the pc-player interface used by
    BoardGame.pve / BoardGame.eve.
    """
    def __init__(self, game, nnet, config):
        self.game, self.nnet = game, nnet
        self.mcts = MCTS(game, nnet, config)

    def play(self, board, player):
        """Return the (row, column) indices of the most visited MCTS action."""
        probs = self.mcts.calc_action_prob(
            self.game.get_canonical_form(board, player), temp=0)
        return self.game.mcts_action2indices(np.argmax(probs))

## Tic-tac-toe NNet

In [6]:
class TicTacToeNNetCNN(BoardNNet):
    """
    TicTacToe neural network: a convolutional trunk, two fully-connected
    layers, then a policy head 'pi' (9-way softmax) and a value head 'v'
    (tanh).
    """
    def _build_nnet(self):
        cfg = self.config

        with K.name_scope(self.label):
            # Input layer: a 3x3 board viewed as a single-channel image.
            self.input_boards = Input(shape=(3,3))
            feat = Reshape((3, 3, 1))(self.input_boards)

            # Convolutional trunk. Every layer keeps the 3x3 size ('same')
            # except the last one of a multi-layer stack, which uses 'valid'.
            depth = cfg['numConvLayers']
            trunk = []
            for idx in range(depth):
                shrink = idx != 0 and idx == depth - 1
                feat = Conv2D(cfg['numConvChannels'], 3,
                              padding='valid' if shrink else 'same')(feat)
                feat = BatchNormalization(axis=3)(feat)
                feat = Activation('relu')(feat)
                trunk.append(feat)
            hidden = Flatten()(trunk[-1])

            # Fully-connected layers: Dense -> BatchNorm -> ReLU, twice.
            for units in (256, 128):
                hidden = Dense(units)(hidden)
                hidden = BatchNormalization(axis=1)(hidden)
                hidden = Activation('relu')(hidden)

            # Output heads.
            self.pi = Dense(9, activation='softmax', name='pi')(hidden)
            self.v = Dense(1, activation='tanh', name='v')(hidden)

            # Model compile
            self.model = Model(inputs=self.input_boards,
                               outputs=[self.pi, self.v])
            self.model.compile(loss=['categorical_crossentropy',
                                     'mean_squared_error'],
                               optimizer=Adam(cfg['lr']))

            # Reset keras' layer-name counter so each network (built in its
            # own name scope) starts numbering its layers from scratch.
            K.reset_uids()
        
class TicTacToeNNetFC(BoardNNet):
    """
    TicTacToe neural network, fully-connected network only version:
    flatten -> 256 -> 128 units, then a policy head 'pi' (9-way softmax)
    and a value head 'v' (tanh).
    """
    def _build_nnet(self):
        cfg = self.config

        with K.name_scope(self.label):
            # Input layer, flattened to a 9-vector.
            self.input_boards = Input(shape=(3,3))
            hidden = Flatten()(self.input_boards)

            # Fully-connected layers: Dense -> BatchNorm -> ReLU, twice.
            for units in (256, 128):
                hidden = Dense(units)(hidden)
                hidden = BatchNormalization(axis=1)(hidden)
                hidden = Activation('relu')(hidden)

            # Output heads.
            self.pi = Dense(9, activation='softmax', name='pi')(hidden)
            self.v = Dense(1, activation='tanh', name='v')(hidden)

            # Model compile
            self.model = Model(inputs=self.input_boards,
                               outputs=[self.pi, self.v])
            self.model.compile(loss=['categorical_crossentropy',
                                     'mean_squared_error'],
                               optimizer=Adam(cfg['lr']))
            # Restart keras' layer numbering for the next network.
            K.reset_uids()

# MCTS

In [7]:
class MCTS():
    """
    Class of the MCTS (Monte Carlo tree search), guided by a network.

    All boards given to this class are treated as canonical boards: the
    search always evaluates and moves as player 1 (see
    BoardGame.get_canonical_form).

    config keys read: 'numMCTSSims', 'cpuct'.
    """
    def __init__(self, game, nnet, config, debug=False):
        self.game = game
        self.nnet = nnet
        self.config = config
        self.debug = debug

        self.Qsa = {}     # mean action value of edge (s,a)
        self.Nsa = {}     # visit count of edge (s,a)

        self.Ps = {}      # prior probability for board s (legal moves only)
        self.Ns = {}      # times board s was visited
        self.Es = {}      # game-ended result for board s (None if not ended)
        self.Vs = {}      # legal-move mask for board s

        self.EPS = 1e-8   # a small value (keeps the UCB > 0 when Ns == 0)

    def calc_action_prob(self, canonicalBoard, temp=1):
        """
        Performs numMCTSSims simulations of MCTS starting from canonicalBoard.

        Returns:
          probs: a policy vector where the probability of the ith action is
                 proportional to Nsa(s,a)**(1/temp); a one-hot vector on the
                 most visited action when temp == 0. If no root edge was
                 visited (e.g. numMCTSSims <= 1), the legal-move mask is
                 used in place of the visit counts.
        """
        if self.debug:
            Pdb().set_trace()

        for _ in range(self.config['numMCTSSims']):
            self.search(canonicalBoard)

        s = self.game.mcts_board2string(canonicalBoard)
        counts = [self.Nsa.get((s, a), 0)
                  for a in range(self.game.mctsActionSize)]

        if sum(counts) == 0:
            # Avoid dividing by zero below (and picking an arbitrary, possibly
            # illegal, action): fall back to the legal-move mask of the root,
            # or to all actions if the root was never expanded.
            mask = self.Vs.get(s)
            if mask is not None and np.sum(mask) > 0:
                counts = list(mask)
            else:
                counts = [1] * len(counts)

        if temp == 0:  # one-hot pi on the most visited action
            bestAction = np.argmax(counts)
            probs = [0] * len(counts)
            probs[bestAction] = 1
            return np.asarray(probs)

        # pi proportional to Nsa(s,a)**(1/temp); the total is computed once
        # instead of once per action.
        counts = [x ** (1./temp) for x in counts]
        total = float(sum(counts))
        return np.asarray([x / total for x in counts])

    def search(self, canonicalBoard):
        """
        Performs one iteration of MCTS. It is recursively called till a leaf
        node is found. The action chosen at each node is the one with the
        maximum upper confidence bound.

        Once a leaf node is found, the neural network is called to return an
        initial policy P and a value v for the state; v is propagated up the
        search path. If the leaf node is terminal, the game outcome is
        propagated instead. Ns, Nsa and Qsa are updated on the way back.

        NOTICE: the return value is the negative of the value of the current
        state: v is in [-1, 1] from the point of view of the player to move,
        and since the search runs on canonical boards the value flips sign
        at every back-up step.

        Returns:
            v: the negative of the value of the current canonicalBoard
        """
        if self.debug:
            Pdb().set_trace()

        s = self.game.mcts_board2string(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.get_game_ended(canonicalBoard, 1)
        if self.Es[s] is not None:   # terminal node
            return -self.Es[s]

        if s not in self.Ps:         # leaf node
            return -self._expand(s, canonicalBoard)

        a = self._select(s)

        next_s, next_player = self.game.get_next_state(canonicalBoard, 1, a)
        next_s = self.game.get_canonical_form(next_s, next_player)

        v = self.search(next_s)

        if (s, a) in self.Qsa:
            n = self.Nsa[(s, a)]
            # incremental mean of the backed-up values
            self.Qsa[(s, a)] = (n * self.Qsa[(s, a)] + v) / (n + 1)
            self.Nsa[(s, a)] = n + 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v

    def _expand(self, s, canonicalBoard):
        """Evaluate a new leaf with the network, store its prior; return its value."""
        result = self.nnet.predict(canonicalBoard)
        # The value is converted to a plain float so that Q values and the
        # values returned by search() are scalars rather than 1-element arrays.
        v = float(np.ravel(result[1][0])[0])
        valids = self.game.get_legal_actions(canonicalBoard)
        ps = result[0][0] * valids   # mask illegal action probabilities

        sum_ps = np.sum(ps)
        if sum_ps > 0:
            ps = ps / sum_ps         # renormalize action probabilities
        else:
            # NOTICE: all valid moves may be masked if the NNet architecture
            # is insufficient, it overfits, or something else is wrong. If
            # this message shows up often, check the NNet and/or training.
            # As a workaround every valid move is made equally probable.
            print("All valid moves were masked, do workaround.")
            ps = ps + valids
            ps = ps / np.sum(ps)

        self.Ps[s] = ps
        self.Vs[s] = valids
        self.Ns[s] = 0
        return v

    def _select(self, s):
        """Return the legal action of expanded board s with the highest UCB."""
        valids = self.Vs[s]
        cpuct = self.config['cpuct']
        cur_best = -float('inf')
        best_act = -1

        for a in range(self.game.mctsActionSize):
            if not valids[a]:
                continue
            if (s, a) in self.Qsa:
                u = self.Qsa[(s, a)] + cpuct * self.Ps[s][a] \
                    * np.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
            else:
                # Q(s,a) = 0 for an unvisited edge
                u = cpuct * self.Ps[s][a] * np.sqrt(self.Ns[s] + self.EPS)

            if u > cur_best:
                cur_best = u
                best_act = a

        return best_act

# Coach

In [8]:
from collections import deque
from tqdm import tnrange, tqdm_notebook

In [9]:
class Coach():
    """
    This class executes the self-play + learning loop. It uses the functions
    defined in BoardGame (reset_game, get_canonical_form, get_symmetries,
    get_next_state, get_game_ended, eve) and BoardNNet (save_checkpoint,
    load_checkpoint, train).
    """
    def __init__(self, game, nnet, pnet, config,
                 coach_debug=False, mcts_debug=False):
        """
        Inputs:
        game        : BoardGame instance used for self-play and pitting.
        nnet        : network being trained.
        pnet        : competitor network; receives a snapshot of nnet's
                      weights before each training step in learn().
        config      : configuration dict (see the all-in-one config cell).
        coach_debug : if True, break into the IPython debugger at the start
                      of execute_episode() and learn().
        mcts_debug  : forwarded to the MCTS instance created here.
        """
        self.game = game
        self.nnet = nnet
        self.pnet = pnet      # the competitor network
        self.config = config
        self.debug = coach_debug
        self.mcts = MCTS(game, nnet, config, mcts_debug)
        # one collection of examples per training iteration, oldest first
        self.trainExamplesHistory = []

    def execute_episode(self):
        """
        Play one episode of self-play, starting from a freshly reset game.

        On every turn, the MCTS policy for the canonical board is recorded
        together with each symmetric variant returned by
        game.get_symmetries(). The move temperature is 1 while
        episodeStep < config['tempThreshold'], and 0 afterwards. The game is
        played until game.get_game_ended() returns something other than None.
        That result is then used to assign a value to every recorded example.

        Returns:
          trainExamples: a list of examples of the form (canonicalBoard, pi, v)
                         pi is the MCTS informed policy vector,
                         v is the game result r for the player who was to
                         move at that example: r as returned by
                         get_game_ended for the player to move after the
                         final move, and -r for the other player (the
                         original design uses +1 win, -1 loss and +0.001
                         draw — defined by get_game_ended, not here).
        """
        if self.debug:
            Pdb().set_trace()

        trainExamples = []
        episodeStep = 0

        self.game.reset_game()
        board = self.game.board
        self.curPlayer = self.game.curPlayer

        while True:
            episodeStep += 1

            canonicalBoard = self.game.get_canonical_form(board, self.curPlayer)
            # explore (temp=1) for the first moves, then play greedily.
            # Read from self.config, not the notebook-global `config`.
            temp = int(episodeStep < self.config['tempThreshold'])

            pi = self.mcts.calc_action_prob(canonicalBoard, temp=temp)
            # the value slot is filled in once the game has ended
            for b, p in self.game.get_symmetries(canonicalBoard, pi):
                trainExamples.append((b, self.curPlayer, p, None))

            action = np.random.choice(len(pi), p=pi)
            board, self.curPlayer = self.game.get_next_state(board,
                                                             self.curPlayer,
                                                             action)

            r = self.game.get_game_ended(board, self.curPlayer)
            if r is not None:
                # r is from self.curPlayer's perspective; flip the sign for
                # examples recorded while the other player was to move.
                return [(x[0], x[2], r * ((-1) ** (x[1] != self.curPlayer)))
                        for x in trainExamples]

    def learn(self):
        """
        Run config['numIters'] iterations of self-play, training and pitting.

        Printed output is written to ./nnet_save/<nnetLabel>_log.txt for the
        duration of the call. stdout is restored and the log file closed
        even if an iteration raises.
        """
        import contextlib
        import os

        if self.debug:
            # break before redirecting stdout so the debugger prompt stays in
            # the notebook instead of being written to the log file
            Pdb().set_trace()

        log_dir = './nnet_save'
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, self.config['nnetLabel'] + '_log.txt')

        with open(log_path, 'w') as f, contextlib.redirect_stdout(f):
            for i in tnrange(self.config['numIters'], desc='Iters'):
                self._run_iteration(i)

    def _collect_self_play_examples(self):
        """Play config['numEps'] episodes and return their examples, keeping
        at most config['maxlenTrainExamples'] (the newest ones)."""
        examples = deque(maxlen=self.config['maxlenTrainExamples'])
        for _ in tnrange(self.config['numEps'], desc='Eps', leave=False):
            # reset search tree.
            # NOTE(review): the mcts_debug flag given to __init__ is not
            # forwarded here (same as the original code) — confirm intent.
            self.mcts = MCTS(self.game, self.nnet, self.config)
            examples += self.execute_episode()
        return examples

    def _run_iteration(self, i):
        """One training iteration (0-based index i): self-play, update the
        example history, train nnet, then accept or reject it by pitting."""
        config = self.config
        tmp_ckpt = config['nnetLabel'] + '_nnet_tmp.h5'
        best_ckpt = config['nnetLabel'] + '_nnet_curbest.h5'

        # save the iteration examples to the history
        self.trainExamplesHistory.append(self._collect_self_play_examples())
        if len(self.trainExamplesHistory) > config['maxlenTrainExamplesHistory']:
            print("TrainExamplesHistory reach maxlen, "
                  "remove the oldest trainExamples")
            self.trainExamplesHistory.pop(0)

        # flatten the whole history and shuffle examples before training
        trainExamples = [ex for iter_examples in self.trainExamplesHistory
                         for ex in iter_examples]
        np.random.shuffle(trainExamples)

        # snapshot the current nnet into pnet, then train nnet
        self.nnet.save_checkpoint(filename=tmp_ckpt)
        self.pnet.load_checkpoint(filename=tmp_ckpt)
        self.nnet.train(trainExamples, i + 1)

        # pit the new network against the previous one
        nwins, pwins, draws = self.compare()
        print('New vs Prev. Wins : %d vs %d ; Draws : %d' % (nwins, pwins, draws))

        # win rate is measured over decisive games only (draws ignored)
        decisive = pwins + nwins
        if decisive > 0 and float(nwins) / decisive < config['updateThreshold']:
            print('Rejecting new model...')
            self.nnet.load_checkpoint(filename=tmp_ckpt)
        else:
            print('Accepting new model...')
            self.nnet.save_checkpoint(filename=best_ckpt)
        print('')

    def compare(self, external=None):
        """
        Pit two AIs against each other with game.eve() and tally the results.

        external : None to pit BoardAI(nnet) against BoardAI(pnet) for
                   config['numCompares'] games; otherwise a sequence whose
                   first three items are (first_ai, second_ai, num_games).

        Returns (nnet_wins, pnet_wins, draws), where "nnet" is the first AI.
        """
        if external is None:
            nnet_ai = BoardAI(self.game, self.nnet, self.config)
            pnet_ai = BoardAI(self.game, self.pnet, self.config)
            numCompares = self.config['numCompares']
        else:
            nnet_ai = external[0]
            pnet_ai = external[1]
            numCompares = external[2]

        nnet_wins = 0
        pnet_wins = 0
        draws = 0
        for _ in range(numCompares):
            # nnet_ai is pc1 in eve(); eve() returns pc1's player number and
            # the winner (0 for a draw), so nnet won iff the two are equal.
            self.game.reset_game()
            nnet_player, winner = self.game.eve(nnet_ai, pnet_ai)
            if winner == 0:  # draw
                draws += 1
            elif nnet_player == winner:
                nnet_wins += 1
            else:
                pnet_wins += 1
        return nnet_wins, pnet_wins, draws
            

# All-in-one Configuration

In [10]:
from copy import copy

config = {
    # NNet Parameter
    'nnetLabel': '',                  # label used to name checkpoint and log files
    'numConvChannels': 32,
    'numConvLayers': 4,
    'batchSize': 100,
    'epochs': 10,
    'lr': 0.001,

    # Coach Settings
    'numIters': 1000,                 # number of training iterations
    'numEps': 100,                    # self-play episodes per iteration
    'maxlenTrainExamples': 200000,    # max examples kept from ONE iteration
    'maxlenTrainExamplesHistory': 20, # max number of iterations kept in history
    'updateThreshold': 0.6,           # min win rate (draws excluded) to accept nnet
    'numCompares': 40,                # games pitted per iteration

    # MCTS Settings
    'tempThreshold': 15,              # temp=1 while episodeStep < this, then temp=0
    'cpuct': 1,                       # exploration weight in the UCB formula
    'numMCTSSims': 25,                # presumably MCTS simulations per move — TODO confirm
}

# Tic-tac-toe Test

In [11]:
# Tic-tac-toe run: same settings as `config`, but a short 15-iteration
# schedule and its own label for checkpoint/log file names.
config_ticfc = dict(config, numIters=15, nnetLabel='ticfc')
nnet_ticfc = TicTacToeNNetFC(config_ticfc, save_summary=True)
# competitor network, used by Coach when pitting new vs. previous weights
pnet_ticfc = TicTacToeNNetFC(config_ticfc)

In [12]:
game_tictac = TicTacToeGame(use_gui=True)

# Coach instance driving self-play, training and pitting for tic-tac-toe
coach_tic = Coach(game=game_tictac, nnet=nnet_ticfc, pnet=pnet_ticfc,
                  config=config_ticfc)

In [13]:
# NOTE: coach_tic.learn() is not run in this notebook, so this cell expects a
# checkpoint saved by an earlier training run. The file name follows the
# '<nnetLabel>_nnet_curbest.h5' pattern that Coach writes for accepted models.
nnet_ticfc.load_checkpoint(config_ticfc['nnetLabel'] + '_nnet_curbest.h5')

In [14]:
# AI opponent backed by the loaded tic-tac-toe network
ai = BoardAI(game_tictac, nnet_ticfc, config_ticfc)

In [15]:
# draw the game board (the game was created with use_gui=True)
game_tictac.draw_board()

In [16]:
# play a player-vs-AI (PvE) game against `ai`; the recorded outcome is shown below
game_tictac.pve(ai)

Draw!
