In [None]:
from generate import *
from tree_search import *
from sklearn.externals import joblib

In [None]:
class Network:
    """Player that takes its moves directly from a policy model.

    The instance keeps a private 15x15 copy of the position in ``self.board``
    and updates it for its own moves (``make_move``) and for the opponent's
    moves (``pass_move``).

    Parameters
    ----------
    model
        Policy object handed to ``legal_move`` / ``legal_nn_move``.
    sk_learn : bool, optional
        When truthy the move is chosen with ``legal_move(board, model, False)``;
        otherwise ``legal_nn_move(board, model)`` is used.
    """

    def __init__(self, model, sk_learn=False):
        self.sk_learn = sk_learn
        self.model = model
        self.board = np.zeros((15, 15))

    def make_move(self, *args):
        """Pick a move for the side to play, record it and return it.

        Positional arguments are accepted for interface compatibility and
        ignored.
        """
        if not self.sk_learn:
            action = legal_nn_move(self.board, self.model)
        else:
            action = legal_move(self.board, self.model, False)
        # Recording our own move is the same bookkeeping as recording the
        # opponent's move.
        self.pass_move(action)
        return action

    def pass_move(self, action):
        """Record ``action`` on the private board.

        The stone value is derived from the board sum: -1.0 when the sum is 0
        and 1.0 when the sum is -1, so the two sides alternate.
        """
        stone = -1.0 - 2*self.board.sum()
        self.board[action] = stone

In [None]:
class Meatbag:
    """Human player who types moves on standard input.

    A move is a column letter followed by a 1-based row number, e.g. ``h8``.
    Letters after ``i`` are shifted down by one, so ``j`` addresses the same
    column as ``i``. The instance keeps its own 15x15 copy of the position in
    ``self.board``.
    """

    def __init__(self):
        self.board = np.zeros((15, 15))

    def _parse_move(self, raw):
        """Translate typed text into ``(row, col)``; return None if unusable.

        Rejects empty or one-character input, a non-numeric row, and any
        coordinate outside the board. Negative indices are rejected too,
        because NumPy would otherwise silently wrap them to the opposite edge.
        """
        text = raw.strip()
        if len(text) < 2:
            return None
        column_code = ord(text[0].lower())
        if column_code > ord('i'):
            column_code -= 1
        col = column_code - ord('a')
        try:
            row = int(text[1:]) - 1
        except ValueError:
            return None
        n_rows, n_cols = self.board.shape
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            return None
        return (row, col)

    def make_move(self, *args):
        """Prompt until a legal move on an empty cell is entered, then record it.

        Positional arguments are accepted for interface compatibility and
        ignored. Returns the chosen ``(row, col)`` tuple.
        """
        while True:
            move = self._parse_move(input())
            if move is not None and round(self.board[move]) == 0:
                break
            print('Wrong action')
        self.board[move] = -1.0 - 2*self.board.sum()
        return move

    def pass_move(self, action):
        """Record the opponent's ``action`` on the private board.

        The stone value is -1.0 when the board sums to 0 and 1.0 when it sums
        to -1, so the two sides alternate.
        """
        self.board[action] = -1.0 - 2*self.board.sum()

In [None]:
import MCTS_cpp

def cpp_predict_value(board, is_black_to_go):
    """Value callback passed to ``MCTS_cpp.MCTS_cpp`` by ``MCTS_Cpp_Wrapper``.

    Encodes the position with ``prepare_board(board, is_black_to_go)``, adds a
    leading batch axis, runs the module-level ``value_policy`` model on it and
    returns element ``[0, 0]`` of the prediction.

    NOTE(review): ``value_policy`` is looked up at call time; the loading cell
    in this notebook currently assigns it ``None``, so a real model must be
    loaded before this callback is used.
    """
    batch = prepare_board(board, is_black_to_go)[np.newaxis]
    prediction = value_policy.predict(batch)
    return prediction[0, 0]


def cpp_predict_probas(board, is_black_to_go):
    """Policy callback passed to ``MCTS_cpp.MCTS_cpp`` by ``MCTS_Cpp_Wrapper``.

    Encodes the position with ``prepare_board(board, not is_black_to_go)``,
    adds a leading batch axis, runs the module-level ``sl_policy`` model on it
    and returns the first row of the prediction.

    NOTE(review): the colour flag is negated here but not in
    ``cpp_predict_value`` — confirm this asymmetry is intended.
    NOTE(review): ``sl_policy`` is not assigned in the visible notebook cells
    (only ``sl_large_policy`` is); verify where it comes from.
    """
    batch = prepare_board(board, not is_black_to_go)[np.newaxis]
    prediction = sl_policy.predict(batch)
    return prediction[0]


class MCTS_Cpp_Wrapper:
    """Adapter giving the ``MCTS_cpp`` search the ``make_move`` / ``pass_move``
    interface used by the other players in this notebook.
    """

    def __init__(self):
        # NOTE(review): ``rollout_policy`` is not defined in the visible cells,
        # and the meaning of the trailing 0.5 is set by MCTS_cpp — confirm both.
        self.search = MCTS_cpp.MCTS_cpp(
            cpp_predict_probas,
            rollout_policy,
            cpp_predict_value,
            0.5,
        )

    def make_move(self, *args):
        """Forward the first positional argument to ``MCTS_cpp.make_move``.

        At least one positional argument is required; calling with none raises
        ``IndexError``. NOTE(review): the game loops in this notebook call
        ``make_move()`` without arguments — confirm what should be passed here.
        """
        first_arg = args[0]
        return MCTS_cpp.make_move(self.search, first_arg)

    def pass_move(self, action):
        """Report the opponent's ``action`` to the underlying search."""
        MCTS_cpp.pass_move(self.search, action)

In [None]:
# Models used by the game cells below. The paths are relative, so the files
# are presumably expected next to the notebook / in the working directory.
sl_large_policy = load_model('large_policy_model')
# The RL policy and the value model are currently disabled (set to None); the
# original load calls are kept in the trailing comments for reference.
# cpp_predict_value reads value_policy, so it cannot run while this is None.
rl_policy = None # load_model('rl_policy_72%.keras')
value_policy = None # load_model('rl_value_model')
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load a .pkl that comes from a trusted source.
# NOTE(review): joblib is imported via sklearn.externals at the top of the
# notebook; that import path was removed in newer scikit-learn releases.
filtered_rollout_policy = joblib.load('filtered_sklearn_policy.pkl')

In [None]:
# Interactive game: a human (moves first, stone -1.0) against the MCTS player
# (moves second, stone +1.0).
# (0, 0) is only a placeholder "last move" for the first termination check.
board = np.zeros((15, 15), dtype='float32')
last_action = (0, 0)
turn = 0

# Alternative line-up kept for reference:
# players = [MonteCarloTreeSearch(sl_policy, filtered_rollout_policy, value_policy, 1.0, verbose=False),
#            Network(sl_policy)]
players = [
    Meatbag(),
    MonteCarloTreeSearch(sl_large_policy, None, None, time=5, mixing_param=1.0, verbose=1),
]

# Play until get_rollout_result reports an outcome for the latest move.
while get_rollout_result(board, last_action) is None:
    show_board(board)
    mover_idx = turn % 2
    last_action = players[mover_idx].make_move()
    # Tell the other player what was played so its private board stays in sync.
    players[1 - mover_idx].pass_move(last_action)
    board[last_action] = -1.0 + 2 * mover_idx
    turn += 1

show_board(board)
res = get_rollout_result(board, last_action)
if res is None:
    print("Not finished")
else:
    outcome = round(res)
    if outcome == 1:
        print('White won!')
    elif outcome == -1:
        print('Black won!')
    else:
        print('Draw!')

In [None]:
# Head-to-head match between two MCTS implementations: 25 rounds, two games per
# round with the colours swapped, i.e. 50 games in total.
#   counter_5    - games won by the unqualified MonteCarloTreeSearch
#   counter_3    - games won by search2.MonteCarloTreeSearch
#   counter_draw - every other outcome
# NOTE(review): `search2` is not imported in the visible cells — verify it is
# available in the kernel before running.
counter_5 = tqdm_notebook(total=50, desc='5', leave=True)
counter_3 = tqdm_notebook(total=50, desc='3', leave=True)
counter_draw = tqdm_notebook(total=50, desc='draw', leave=True)

for i in range(25):
    for default_plays_first in (True, False):
        # players[0] moves first (stone -1.0), players[1] second (stone +1.0).
        if default_plays_first:
            players = [MonteCarloTreeSearch(sl_large_policy, None, None, time=1, mixing_param=1.0, verbose=0),
                       search2.MonteCarloTreeSearch(sl_large_policy, None, None, time=1, mixing_param=1.0, verbose=0)]
            first_mover_counter, second_mover_counter = counter_5, counter_3
        else:
            players = [search2.MonteCarloTreeSearch(sl_large_policy, None, None, time=1, mixing_param=1.0, verbose=0),
                       MonteCarloTreeSearch(sl_large_policy, None, None, time=1, mixing_param=1.0, verbose=0)]
            first_mover_counter, second_mover_counter = counter_3, counter_5

        board = np.zeros((15, 15), dtype='float32')
        last_action = (0, 0)
        turn = 0
        while get_rollout_result(board, last_action) is None:
            mover_idx = turn % 2
            last_action = players[mover_idx].make_move()
            players[1 - mover_idx].pass_move(last_action)
            board[last_action] = -1.0 + 2 * mover_idx
            turn += 1

        res = get_rollout_result(board, last_action)
        outcome = round(res)
        if outcome == 1:
            # +1 is the second mover's stone value.
            second_mover_counter.update()
        elif outcome == -1:
            first_mover_counter.update()
        else:
            counter_draw.update()