In [2]:
import numpy as np
import tensorflow as tf
import os
from src.game import GoState, roll_axis
from src.net import DualRes
import pachi_py

In [3]:
# Number of child actions every node expands into (used by Node.select and
# Node.expand as range(N_ACTIION)).
# NOTE(review): the identifier is misspelled ("ACTIION"); it is kept as-is
# because the Node class refers to it by this exact name. The value 3 is
# presumably a small debugging setting rather than the full move count — confirm.
N_ACTIION = 3
# Exploration constant that scales the prior/visit-count term in Node.get_u.
C_PUCT = 1.38

In [6]:
class Node():
    """One node of the Monte-Carlo search tree.

    A node stores the action that led to it, the resulting game state, its
    prior probability and the running search statistics (visit count ``n``,
    total value ``w`` and mean value ``q``).
    """

    def __init__(self, parent, action, prob=None, state = None):
        """Create a node.

        Args:
            parent: parent ``Node``, or ``None`` for the root.
            action: action index taken from ``parent`` to reach this node
                (``None`` for the root).
            prob: prior probability assigned to ``action`` (``None`` for the root).
            state: game state of this node. When omitted it is derived with
                ``parent.state.act(action)``; if that call fails the node is
                flagged ``illegal`` and its state stays ``None``.
        """
        self.parent = parent
        self.a = action
        self.illegal = False

        # State creation
        if state is None:
            try:
                state = parent.state.act(self.a)
            except Exception:
                # Deliberate best effort: any failure to apply the action
                # (illegal move, parent without a state, ...) marks this node
                # as illegal. Unlike a bare ``except`` this no longer swallows
                # KeyboardInterrupt / SystemExit.
                state = None
                self.illegal = True

        self.state = state

        # New nodes have no children
        self.children = dict()

        self.n = 0      # visit count
        self.w = 0      # sum of backed-up values
        self.q = 0      # mean backed-up value, w / n (kept current by backup)
        self.p = prob   # prior probability of the action leading here

        # terminal state, unable to expand
        # NOTE(review): never set to True anywhere in this class.
        self.is_terminal = False

    def select(self):
        """Return the child maximising Q + U (the PUCT selection rule).

        Must only be called on an expanded node.
        """
        assert len(self.children) != 0
        # Combine the exploitation term (mean value q) with the exploration
        # bonus; ranking by the bonus alone would ignore every backed-up value.
        scores = [self.children[a].q + self.children[a].get_u()
                  for a in range(N_ACTIION)]
        sa = np.argmax(scores)
        return self.children[sa]

    def get_u(self):
        """Exploration bonus: C_PUCT * p * sqrt(1 + parent.n) / (1 + n)."""
        return C_PUCT * self.p * np.sqrt(1+self.parent.n) / (1 + self.n)

    def expand(self):
        """Create one child per action index in ``range(N_ACTIION)``.

        NOTE(review): ``get_prob`` is not defined in the visible cells; it is
        called without arguments and indexed by action — confirm where it
        comes from and that it depends on this node's state.
        """
        assert len(self.children) == 0
        pi = get_prob()
        self.children = {a: Node(self, a, pi[a]) for a in range(N_ACTIION)}

    def backup(self, v):
        """Add value ``v`` to this node's statistics and to every ancestor.

        Illegal nodes always back up -1 regardless of ``v``.

        NOTE(review): the same ``v`` is passed unchanged to all ancestors (no
        sign flip between plies) — confirm this matches the perspective of
        the values being backed up.
        """
        if self.illegal:
            v = -1

        self.w += v
        self.n += 1
        # Keep the mean value in sync with the totals.
        self.q = self.w / self.n

        if self.parent is not None:
            self.parent.backup(v)

    def __repr__(self):
        # The root has neither an action nor a prior; formatting None with
        # '{:2d}' / '{:2.2f}' raises TypeError, so those fields fall back to
        # the text 'None'. Output for ordinary nodes is unchanged.
        a_txt = 'None' if self.a is None else '{:2d}'.format(self.a)
        p_txt = 'None' if self.p is None else '{:2.2f}'.format(self.p)
        return 'a: {} p: {} \n'.format(a_txt, p_txt) + repr(self.state)

In [7]:
def mcts(root_node, n=2):
    """Run ``n`` search iterations starting from ``root_node``.

    Each iteration walks down the tree with ``select()`` until it reaches a
    node without children, expands that node, and then backs up one value per
    freshly created child. Progress is reported with ``print``.

    NOTE(review): ``eval`` is called without arguments; the builtin ``eval``
    requires one, so a project-level ``eval`` must be defined before this
    runs — confirm.
    """
    for iteration in range(n):
        print('INFO: MCTS {}'.format(iteration))
        leaf = root_node

        # Selection: descend until an unexpanded node is reached.
        print(len(leaf.children))
        while leaf.children:
            print('INFO: SELECTION')
            leaf = leaf.select()

        # Expansion, followed by evaluation and backup of every new child.
        print('INFO: EXPAND')
        leaf.expand()
        print('INFO: EVALUATION & BACKUP')
        for child in leaf.children.values():
            child.backup(eval())

# TensorFlow Model

In [22]:
# Load the trained TensorFlow graph and its latest checkpoint for model
# `version` from models/<version>/ (TF1-style graph + Session API).
version = 'v1'
# Start from an empty default graph before importing the saved one.
tf.reset_default_graph()
# NOTE(review): this session is kept open as a notebook-level global and is
# never closed in the visible cells.
sess = tf.Session()
meta = os.path.join('models', version, version + '.meta')
# Rebuild the graph structure from the .meta file, then restore the weights.
saver = tf.train.import_meta_graph(meta)
saver.restore(sess, tf.train.latest_checkpoint(os.path.join('models', version)))
graph = tf.get_default_graph()
# Look up the tensors by name. The two placeholders are the ones that
# sample() feeds with the board and the (player - 1) value respectively.
tf_board = graph.get_tensor_by_name('Placeholder:0')
tf_player = graph.get_tensor_by_name('Placeholder_1:0')
# NOTE(review): 'BiasAdd' suggests this is the policy head's raw dense-layer
# output (pre-softmax logits) rather than probabilities — confirm.
tf_pi = graph.get_tensor_by_name('policy_head/dense/BiasAdd:0')
tf_v = graph.get_tensor_by_name('value_head/Tanh:0')


def sample(board, player):
    """Evaluate the loaded network on a single position.

    Args:
        board: encoded board fed to the ``tf_board`` placeholder (wrapped in
            a one-element batch).
        player: player id; ``player - 1`` is fed to ``tf_player``
            (presumably to make the id zero-based — confirm).

    Returns:
        ``(pi, v)``: the policy-head and value-head outputs for the batch.
    """
    feed = {
        tf_board: [board],
        tf_player: [player - 1],
    }
    pi, v = sess.run([tf_pi, tf_v], feed)
    return pi, v

INFO:tensorflow:Restoring parameters from models/v1/v1


In [23]:
# Initial position: a board created with size 9 and pachi_py.BLACK as the
# player argument (presumably the side to move — confirm against GoState).
s0 = GoState(pachi_py.CreateBoard(9), pachi_py.BLACK)
# Root of the search tree: no parent, no action and no prior; the state is
# passed in explicitly instead of being derived from a parent.
root_node = Node(None, None, None, s0)

In [24]:
# Evaluate the network once on the initial position.
# NOTE(review): roll_axis (from src.game) presumably rearranges the encoded
# board's axes into the layout the network expects — confirm.
pi, v = sample(roll_axis(s0.board.encode()), pachi_py.BLACK)

In [26]:
def _action_to_coord(board, a):
    '''Converts actions to Pachi coordinates'''
    if a == board.size**2: return pachi_py.PASS_COORD
    if a == board.size**2 + 1: return pachi_py.RESIGN_COORD
    return board.ij_to_coord(a // board.size, a % board.size)

In [27]:
# Show the board coordinate of the highest-scoring policy output for the
# initial position.
# NOTE(review): np.argmax works on the flattened `pi`; sample() feeds a batch
# of one, so the flat index presumably equals the action index — confirm.
s0.board.coord_to_str(_action_to_coord(s0.board, np.argmax(pi)))

b'G5'