In [1]:
import gym
import numpy as np
import time
import pachi_py

from gym.envs.board_game.go import _coord_to_action, GoState
SEED = 0
np.random.seed(SEED)

Let's see if we can beat the gym environment using only MCTS

If we seed the gym environment, we can recreate the game state

### In each iteration of MCTS, there are four steps

1. Selection
1. Expansion
1. Simulation
1. Backpropagation

In [2]:
env_id = 'Go9x9-v0'
def get_env():
    env = gym.make(env_id)
    env.seed(SEED)
    env.reset()
    return env

In [3]:
def get_legal_actions(state):
    actions = [_coord_to_action(state.board, c) for c in state.board.get_legal_coords(state.color)]
    return actions

In [4]:
def fast_sim(state):
    
    state = GoState(state.board.clone(), state.color)
    og_player = state.color
    
    actions = dict()
    actions[pachi_py.BLACK] = []
    actions[pachi_py.WHITE] = []
    
    while not state.board.is_terminal:
        player = state.color   
        a = np.random.choice(state.board.get_legal_coords(player))
        a = _coord_to_action(state.board, a)
        state = state.act(a)
        actions[player].append(a)
        # resign is the end of the game
        if a==pachi_py.RESIGN_COORD:
            break
    
    # Score with respect to player
    score = np.sign(state.board.official_score)
    if og_player == pachi_py.BLACK:
        score = np.negative(score)
    
    # Resignation is always loosing
    if a == pachi_py.RESIGN_COORD:
        if player == og_player:
            score = -1
        else:
            score = 1

    return score, actions

In [5]:
class Node():
    def __init__(self, state, parent=None):
        
        self.state = GoState(state.board.clone(), state.color)
        self.unexpanded = get_legal_actions(self.state)
        self.parent = parent
        self.child = dict() # key: action
        self.r = 0.
        self.n = 0.
        
    def select(self):
        actions = []
        ucbs = []
        for a, n in self.child.items():
            actions.append(a)
            ucbs.append(n.get_ucb())
        i = np.argmax(ucbs)
        a = actions[i]
        return self.child[a]
    
    def get_ucb(self):
        assert self.n > 0
        return self.r / self.n + 2 * np.sqrt(np.log(self.parent.n)/self.n)
        
    def expand(self):
        assert self.expandable
        action = self.unexpanded[0]
        self.child[action] = Node(self.state, self)
        self.unexpanded.remove(action)
        return self.child[action]
    
    def update(self, r):
        self.n += 1
        self.r +=r
        
    @property
    def n_child(self):
        return len(self.child)
    
    @property
    def expandable(self):
        if len(self.unexpanded) > 0:
            return True
        else:
            return False

In [6]:
def mcts(root_node):
    current_node = root_node
    
    # Selection
    while not current_node.expandable:
        current_node = current_node.select()
        
    # Expansion
    current_node = current_node.expand()
    
    # Simulate
    r, _ = fast_sim(current_node.state)
    current_node.update(r)
    
    # backprop
    while current_node.parent is not None:
        r = -r
        current_node = current_node.parent
        current_node.update(r)

# Test

In [7]:
env = get_env()
root_node = Node(env.state, None)
for i in range(10000):
    mcts(root_node)

In [8]:
# for k, v in root_node.child.items():
#     print('ACTION: {:2} VISITS: {:3.0f} REWARD: {:3.0f}'.format(k, v.n, v.r))

ACTION:  0 VISITS: 130 REWARD:  -2
ACTION:  1 VISITS: 145 REWARD:   2
ACTION:  2 VISITS:  76 REWARD: -14
ACTION:  3 VISITS: 138 REWARD:   0
ACTION:  4 VISITS:  94 REWARD: -11
ACTION:  5 VISITS:  92 REWARD: -11
ACTION:  6 VISITS: 121 REWARD:  -5
ACTION:  7 VISITS: 185 REWARD:  13
ACTION:  8 VISITS:  85 REWARD: -12
ACTION:  9 VISITS: 109 REWARD:  -7
ACTION: 10 VISITS: 103 REWARD:  -9
ACTION: 11 VISITS: 100 REWARD:  -9
ACTION: 12 VISITS: 168 REWARD:   8
ACTION: 13 VISITS:  66 REWARD: -16
ACTION: 14 VISITS:  67 REWARD: -16
ACTION: 15 VISITS:  83 REWARD: -13
ACTION: 16 VISITS:  95 REWARD: -11
ACTION: 17 VISITS: 126 REWARD:  -3
ACTION: 18 VISITS: 145 REWARD:   1
ACTION: 19 VISITS: 239 REWARD:  30
ACTION: 20 VISITS: 134 REWARD:  -1
ACTION: 21 VISITS: 109 REWARD:  -7
ACTION: 22 VISITS:  64 REWARD: -16
ACTION: 23 VISITS: 104 REWARD:  -8
ACTION: 24 VISITS: 104 REWARD:  -8
ACTION: 25 VISITS: 135 REWARD:  -1
ACTION: 26 VISITS:  53 REWARD: -17
ACTION: 27 VISITS: 118 REWARD:  -5
ACTION: 28 VISITS: 1

In [None]:
a = root_node.child[0].child[33]
for k, v in a.child.items():
    print('ACTION: {:2} VISITS: {:3.0f} REWARD: {:3.0f}'.format(k, v.n, v.r))

ACTION: 81 VISITS:   1 REWARD:   1


# Play a Game

In [None]:
env = get_env()
done = False
while not done:
    root_node = Node(env.state, None)
    tic = time.time()
    for i in range(8000):
        mcts(root_node)
    toc = time.time() - tic
    print('DECISION TIME: {:.1f}s'.format(toc))
        
    actions = []
    actions_visits = []
    actions_value = []

    for k, v in root_node.child.items():
        actions.append(k)
        actions_visits.append(v.n)
        actions_value.append(v.r)
    max_n = np.argmax(actions_visits)
    max_r = np.argmax(actions_value)
    print('Most Visited:\t A:{} F:{} V:{}'.format(actions[max_n], actions_visits[max_n], actions_value[max_n]))
    print('Highest Value:\t A:{} F:{} V:{}'.format(actions[max_r], actions_visits[max_r], actions_value[max_r]))
    a = actions[max_n]
    ob, r, done, info = env.step(a)
    env.render()

DECISION TIME: 24.1s
Most Visited:	 A:60 F:195.0 V:28.0
Highest Value:	 A:60 F:195.0 V:28.0
To play: black
Move:   2  Komi: 0.0  Handicap: 0  Captures B: 0 W: 0
      A B C D E F G H J  
    +-------------------+
  9 | . . . . . . . . . |
  8 | . . . . . . . . . |
  7 | . . . . . . . . . |
  6 | . . . O). . . . . |
  5 | . . . . . . . . . |
  4 | . . . . . . . . . |
  3 | . . . . . . X . . |
  2 | . . . . . . . . . |
  1 | . . . . . . . . . |
    +-------------------+
DECISION TIME: 23.9s
Most Visited:	 A:77 F:208.0 V:35.0
Highest Value:	 A:77 F:208.0 V:35.0
To play: black
Move:   4  Komi: 0.0  Handicap: 0  Captures B: 0 W: 0
      A B C D E F G H J  
    +-------------------+
  9 | . . . . . . . . . |
  8 | . . . . . . . . . |
  7 | . . . . . . . . . |
  6 | . . . O . . . . . |
  5 | . . . . . . . . . |
  4 | . . . . . . . . . |
  3 | . . . . . . X . . |
  2 | . . . . . O). . . |
  1 | . . . . . X . . . |
    +-------------------+
DECISION TIME: 24.1s
Most Visited:	 A:73 F:215.0 V:38.

In [None]:
if r > 0:
    print('WIN')
elif r < 0:
    print('LOST')
else:
    print('DRAW')

In [None]:
env.state.board.official_score

# Closing Thoughts

* When expanding a new node, parallelisation can be used to increase speed. E.g. expand all nodes at the same time.
* Tree policy enhancements such as AMAF and RAVE would be more suitable to Go.
* The simulation/rollout phase would be more accurate if we understand the policy of the enemy. E.g. a neural network that predicts enemy moves.
* MCTS along is not enought as the game of Go has a very large search space
