In [1]:
import numpy as np

import gym
import gym_tictactoe
from tqdm import tqdm
from copy import deepcopy

# 3d tic-tac-toe in openai gym
# https://pypi.org/project/gym-tictactoe/
# pip install gym-tictactoe

In [2]:
env = gym.make('tictactoe-v0')

[2019-04-04 15:43:21,808] Making new env: tictactoe-v0


In [3]:
s = env.reset()

In [4]:
action = '1021'
'''
1: player 1
0: first axis
2: second axis
1: third axis
'''
s_, r, done, info = env.step(action)
# env.render()

In [5]:
print(s_)
print(r)
print(done)
print(info)

[[[0, 0, 0], [0, 0, 0], [0, 1, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 0]]]
0
False
{'round': 1, 'next_player': 2}


In [6]:
# def state2string(state):
#     return ''.join(str(j) for rows in state for row in rows for j in row)

# def state2pos(state):
#     # state to possible actions
#     return [''.join(str(i) for i in list(ind)) \
#             for ind in np.argwhere(np.array(state)==0)]

def state2string(state):
    # care abount only first axis
    return ''.join(str(j) for ind in state[0] for j in ind)

def state2pos(state):
    # state to possible actions
    return ['0'+''.join(str(i) for i in list(ind)) \
            for ind in np.argwhere(np.array(state[0])==0)]

def state2traj(state):
    # state to episode trajectory
    traj = [0 for i in range(len(np.argwhere(np.array(state)!=0)))]
    for i, ind in enumerate(np.argwhere(np.array(state)==1)):
        traj[2*i] = '1'+''.join(str(i) for i in list(ind))
    for i, ind in enumerate(np.argwhere(np.array(state)==2)):
        traj[2*i+1] = '2'+''.join(str(i) for i in list(ind))
    return traj

def give_env(state):
    e = gym.make('tictactoe-v0')
    s = e.reset()
    traj = state2traj(state)
    for t in traj:
        e.step(t)
    return e

def one_step(state, action):
    # give one step simulation
    e = give_env(state)
    s_, r, done, info = e.step(action)
    s_ = deepcopy(s_)
    return s_, r, done

In [7]:
print(state2string(s_))
print(state2pos(s_))
print(state2traj(s_))
print(one_step(s_,'2'+state2pos(s_)[0]))

[2019-04-04 15:43:22,094] Making new env: tictactoe-v0


000000010
['000', '001', '002', '010', '011', '012', '020', '022']
['1021']
o - -    - - -    - - -    
- - x    - - -    - - -    
- - -    - - -    - - -    
([[[2, 0, 0], [0, 0, 0], [0, 1, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 0]]], 0, False)


In [8]:
class Node:
    # node class for each state
    def __init__(self, player, state, parent=None, done=False):
        self.player = player # 1 or 2
        self.parent = parent # node instance 
        self.state = state[:] # state to play. shape : (3, 3, 3)
        self.done = done
        self.state_str = state2string(self.state)
        self.pos = state2pos(self.state) # possible actions
        self.counts = np.array([0 for i in range(len(self.pos))]) # values for possible actions
        self.values = np.array([0 for i in range(len(self.pos))]) # counts for possible actions
        
    def best_action(self, c=5.):
        ucb = self.values/self.counts + c * np.sqrt(2*np.log(np.sum(self.counts))/self.counts)
        ind = np.random.choice(np.argwhere(ucb == np.max(ucb)).flatten().tolist())
        action = self.pos[ind]
        return ind, str(self.player) + action, self.values/self.counts, self.counts

In [9]:
n = Node(2, s_)

In [10]:
print(n.state_str)
print(n.pos)
print(n.counts)
print(n.values)

000000010
['000', '001', '002', '010', '011', '012', '020', '022']
[0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0]


In [11]:
class UCT:
    def __init__(self, player, state, iter_num=200):
        self.player = player
        self.root = Node(self.player, state) # root node of tree
        self.iter_num = iter_num # computational budget
        self.node_list = [self.root] # tree nodes
        
    def uct_search(self, c):
        # collect trajectory you can fix this with conditional loop
        for i in range(self.iter_num):
            new_node, traj, r = self.tree_policy(c)
            if new_node is not None:
                if not(new_node.done) and len(state2pos(new_node.state)) != 0:
                    r = self.default_policy(new_node)
            self.backup(traj, r)
        return self.root.best_action(c=0.0)
            
    def tree_policy(self, c):
        # index of node_list for episode trajectory
        node = self.root
        traj = []
        r = 0
        while not(node.done) and len(node.pos)>0:
            # if there exists unchosen action do it.
            # note that every tree_policy call makes expand call until ariive to leaf node
            if len(np.argwhere(node.counts==0))>0:
                new_node, ind, r = self.expand(node)
                traj.append((node,ind))
                if new_node is not None:
                    self.node_list.append(new_node)
                break
            # else choose best action
            else:
                ind, a, v, counts = node.best_action(c)
                state, r, done = one_step(node.state, a)
                if len(state2pos(deepcopy(state)))==0:
                    done = True
                if not(done):
                    # if not done make state for new node
                    e = give_env(state)
                    action = np.random.choice(state2pos(state))
                    s_, r, done, info = e.step(str(3-self.player)+action)
                    r = -r
                    state = deepcopy(s_)
                    e.close()
                    for n in self.node_list:
                        # if seen state stop 
                        if n.state == state:
                            new_node = n
                            traj.append((node, ind))
                            break
                    else:
                        # else add node to node_list
                        new_node = Node(self.player, state, parent=node, done=done)
                        traj.append((node,ind))
                        self.node_list.append(new_node)
                else:
                    # if done 
                    new_node = None
                    traj.append((node, ind))
                    break
            node = new_node
        return new_node, traj, r
            
    def expand(self, node):
        # only call when there exists zero-chosen action in the node
        assert len(np.argwhere(node.counts==0)) != 0
        # choose zero-chosen action
        ind = np.argwhere(node.counts==0)[0][0]
        action = node.pos[ind]
        e = give_env(node.state)
        s_, r, done, info = e.step(str(node.player)+action)
        if len(state2pos(s_))==0:
            done = True
        if not(done):
            # give next node whose player is itself
            # note that reward is negative of player's reward
            s_, r, done, info = e.step(str(3-node.player)+np.random.choice(state2pos(deepcopy(s_))))
            new_node = Node(self.player, deepcopy(s_), parent=node, done=done)
            r = -r
            e.close()
            return node, ind, r
        else:
            return None, ind, r
    
    def default_policy(self, node):
        done = node.done
        player = node.player
        s = node.state
        e = give_env(s)
        r = 0
        if len(state2pos(node.state))==0:
            done = True
        while not(done):
            action = np.random.choice(state2pos(node.state))
            s_, r, done, info = e.step(str(player)+action)
            player = 3 - player # 1 -> 2, 2 -> 1
            s = deepcopy(s_)
        if player == node.player:
            # lose
            r *= -1
        e.close()
        return r
    
    def backup(self, traj, r):
        for (node,ind) in traj:
            node.counts[ind] += 1
            node.values[ind] += r

In [None]:
env = gym.make('tictactoe-v0')
s = env.reset()
player = 1
c = 5.0
done = False
traj = []
while not(done) and len(state2pos(s)) != 0:
    u = UCT(player, s)
    ind, a, value, counts = u.uct_search(c)
    s_, r, done, info = env.step(a)
    s = deepcopy(s_)
    player = 3 - player # 1 -> 2, 2 -> 1
    # anneal the exploration rate
    c *= 0.9
    traj.append((s, value, counts, r, done, info))

In [13]:
traj

[([[[0, 0, 0], [0, 0, 0], [1, 0, 0]],
   [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
   [[0, 0, 0], [0, 0, 0], [0, 0, 0]]],
  array([ 0.33333333,  0.09090909, -0.05263158,  0.33333333,  0.33333333,
          0.04761905,  0.46153846, -0.05263158,  0.04761905]),
  array([24, 22, 19, 24, 24, 21, 26, 19, 21]),
  0,
  False,
  {'next_player': 2, 'round': 1}),
 ([[[0, 0, 0], [0, 2, 0], [1, 0, 0]],
   [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
   [[0, 0, 0], [0, 0, 0], [0, 0, 0]]],
  array([ 0.        , -0.14285714, -0.04347826,  0.        ,  0.47058824,
         -0.14285714,  0.12      ,  0.21428571]),
  array([24, 21, 23, 24, 34, 21, 25, 28]),
  0,
  False,
  {'next_player': 1, 'round': 2}),
 ([[[0, 0, 0], [0, 2, 0], [1, 1, 0]],
   [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
   [[0, 0, 0], [0, 0, 0], [0, 0, 0]]],
  array([ 0.26666667,  0.35483871,  0.        ,  0.35483871, -0.09090909,
          0.48571429,  0.18518519]),
  array([30, 31, 24, 31, 22, 35, 27]),
  0,
  False,
  {'next_player': 2, 'round': 3}),
 ([[[0,