## MCTSを用いた五目並べの実践

In [11]:
from __future__ import division

from copy import deepcopy
from mcts import mcts
from gomoku import GomokuState
from gravity_gomoku import GravityGomokuState
from functools import reduce
import operator
import importlib
from graphviz import Graph

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

### 五目並べ状態を設定
- 状態にはインターフェースに沿っていれば任意のゲームを指定できる。

In [16]:
# Board size (the board is N x N, as the 6x6 printout below shows) and K,
# presumably the number of stones in a row needed to win — confirm in GomokuState.
# Any state object that follows the searcher's interface could be used instead.
N, K = 6, 5
initialState = GomokuState(N, K)

### 探索アルゴリズムを設定
- cleaverSearcher: MCTSによる賢いアルゴリズム  
- foolSearcher: UCTの探索項を0にした（explorationConstant=0）アルゴリズム

In [17]:
# Searchers used in this notebook. timeLimit is passed straight to mcts;
# presumably milliseconds — confirm against the mcts package.
searcher = mcts(timeLimit=10000)

# Opponents for the game loop: the same time budget, but foolSearcher sets
# explorationConstant=0, i.e. the UCT exploration term is zeroed.
cleaverSearcher, foolSearcher = (
    mcts(timeLimit=3000),
    mcts(timeLimit=3000, explorationConstant=0),
)

In [18]:
# Start the game from the initial position. This binds the same object (no
# copy); the game loop later rebinds `state` to whatever takeAction returns.
state = initialState

In [19]:
# Play one full game starting from `state`, printing the board and the search
# root's per-move visit counts after every move.
#
# Which searcher plays which side ("cleaver vs fool" here). To pit the MCTS
# searcher against itself ("cleaver vs cleaver"), map both players to
# cleaverSearcher.
searchers = {
    -1: cleaverSearcher,
    1: foolSearcher,
}

print(state.draw_board())
print("------")

while True:
    print("Player:{}".format(state.currentPlayer))

    # Pick the searcher for the side to move. Any other player value is an
    # error; otherwise whichever searcher was bound last would silently be
    # reused. `searcher` stays a module-level name because later cells use it.
    currentPlayer = state.getCurrentPlayer()
    try:
        searcher = searchers[currentPlayer]
    except KeyError:
        raise ValueError("unexpected player: {!r}".format(currentPlayer)) from None

    # Search for the best move.
    action = searcher.search(initialState=state)
    print(action)

    # Play it.
    state = state.takeAction(action)

    # Show the board.
    print(state.draw_board())

    # Show how often each child of the search root was visited. Distinct loop
    # names keep `action` bound to the move that was actually played.
    print(" ".join(
        "{}: {}".format(child_action, child.numVisits)
        for child_action, child in searcher.root.children.items()
    ))

    # Stop once the game has been decided.
    if state.isTerminal():
        break

    print("------")

XXXXXX
XXXXXX
XXXXXX
XXXXXX
XXXXXX
XXXXXX

------
Player:1
(2, 3)
XXXXXX
XXXXXX
XXXAXX
XXXXXX
XXXXXX
XXXXXX

(0, 0): 1 (0, 1): 1 (0, 2): 1 (0, 3): 5 (0, 4): 1 (0, 5): 2 (1, 0): 5 (1, 1): 46 (1, 2): 2 (1, 3): 22 (1, 4): 1 (1, 5): 3 (2, 0): 1 (2, 1): 6 (2, 2): 12 (2, 3): 322 (2, 4): 3 (2, 5): 1 (3, 0): 2 (3, 1): 30 (3, 2): 2 (3, 3): 1 (3, 4): 1 (3, 5): 22 (4, 0): 2 (4, 1): 1 (4, 2): 2 (4, 3): 23 (4, 4): 1 (4, 5): 2 (5, 0): 3 (5, 1): 3 (5, 2): 3 (5, 3): 2 (5, 4): 1 (5, 5): 1 
------
Player:-1
(3, 3)
XXXXXX
XXXXXX
XXXAXX
XXXBXX
XXXXXX
XXXXXX

(0, 0): 5 (0, 1): 11 (0, 2): 8 (0, 3): 8 (0, 4): 10 (0, 5): 20 (1, 0): 28 (1, 1): 18 (1, 2): 30 (1, 3): 11 (1, 4): 13 (1, 5): 11 (2, 0): 30 (2, 1): 18 (2, 2): 13 (2, 4): 13 (2, 5): 5 (3, 0): 20 (3, 1): 15 (3, 2): 13 (3, 3): 54 (3, 4): 24 (3, 5): 28 (4, 0): 26 (4, 1): 22 (4, 2): 22 (4, 3): 15 (4, 4): 24 (4, 5): 24 (5, 0): 8 (5, 1): 28 (5, 2): 8 (5, 3): 13 (5, 4): 18 (5, 5): 11 
------
Player:1
(3, 4)
XXXXXX
XXXXXX
XXXAXX
XXXBAX
XXXXXX
XXXXXX

(0, 0): 2

In [6]:
# Overwrite the board with a hand-made position: four empty rows and a filled
# bottom row (each row is a distinct list).
# NOTE(review): this grid is 5x5 while N is 6 above; presumably this cell ran in
# a session where the state was built with N=5 — confirm before reusing.
state.board = [[0] * 5 for _ in range(4)] + [[-1, 1, -1, -1, 1]]

In [7]:
# Render the hand-made position.
print(state.draw_board())

XXXXX
XXXXX
XXXXX
XXXXX
BABBA



In [8]:
# Inspect whose turn it is in this state (shown as the cell's output).
state.currentPlayer

1

In [9]:
# Search from the hand-made position; the chosen move is shown as the cell output.
searcher.search(initialState=state)

(3, 1)

In [10]:
# NOTE(review): makeGraph is not defined in this file; presumably it renders
# the searcher's most recent search tree — confirm.
searcher.makeGraph()

In [9]:
# Inspect the raw board (nested lists of 1 / -1 / 0) of the initial state.
initialState.board

[[0, -1, 0, 1, 0, 0],
 [0, -1, 1, -1, -1, 0],
 [1, 0, 1, 1, 0, -1],
 [-1, 1, 1, 1, 0, 0],
 [0, 1, -1, 0, 1, 0],
 [0, -1, 1, -1, 0, 0]]

In [10]:
# Run a search from the initial state and keep the chosen move in `action`.
action = searcher.search(initialState=initialState)

In [5]:
# NOTE(review): makeGraph is not defined in this file; presumably it renders
# the searcher's most recent search tree — confirm.
searcher.makeGraph()

In [10]:
# Print the visit count of every child of the search root, one per line.
visit_counts = [child.numVisits for child in searcher.root.children.values()]
for count in visit_counts:
    print(count)

29
55
46
54
20
21
33
40
16
95
159
62
157
100
293
30
85
39
143
35
55
25
28
78
19
40
57
46
30
62
45


In [None]:
def make_graph(G, nd):
    """Recursively add ``nd`` and all of its descendants to the graph ``G``.

    Each node is keyed by ``str(nd.state)``. Its label has five lines:
    ``nd.rmsd_max``, ``nd.rmsd`` (4 significant digits), ``"ix iy"``,
    ``nd.visits``, and the value returned by ``nd.CalcUCT()``. An edge is
    drawn from the parent's state to this node's state whenever
    ``nd.parentNode`` is not None.

    NOTE(review): the attributes read here (rmsd, rmsd_max, ix, iy, visits,
    childNodes, CalcUCT) differ from those the game loop reads on mcts nodes
    (numVisits, children). This presumably targets a different node class —
    confirm before passing it ``searcher.root``.

    Args:
        G: graph exposing ``node(name, label)`` and ``edge(tail, head)``,
            e.g. ``graphviz.Graph``.
        nd: root of the (sub)tree to draw.
    """
    state = nd.state
    uct = nd.CalcUCT()
    label = "\n".join([
        str(nd.rmsd_max),
        "{:.4}".format(float(nd.rmsd)),
        str(nd.ix) + ' ' + str(nd.iy),
        str(nd.visits),
        str(uct),
    ])
    G.node(str(state), label)

    parent_node = nd.parentNode
    # Identity check: `!= None` would invoke any custom __eq__ on the node.
    if parent_node is not None:
        G.edge(str(parent_node.state), str(state))

    for child_node in nd.childNodes:
        make_graph(G, child_node)

In [7]:
# Smoke test for graphviz: build a two-node undirected graph and render it to
# PNG. render('test') returns the output path ('test.png', shown below).
G = Graph(format='png')
G.attr('node', shape='circle')  # draw every node as a circle
G.graph_attr.update(size="320")  # graphviz 'size' graph attribute
G.node('a', '100')  # node id 'a', label '100'
G.node('b', '200')  # node id 'b', label '200'
G.edge('a', 'b')
G.render('test')

'test.png'

In [48]:
# Sample 6x6 position for exercising isTerminal: player 1 (value 1) has five in
# a row on the anti-diagonal from (0, 4) down to (4, 0).
board = [
    [0, 0, 0, 0, 1, 0],
    [1, 0, 0, 1, 0, 0],
    [0, 1, 1, 0, 0, 0],
    [0, 1, 1, 0, 0, 0],
    [1, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0, 0],
]

In [49]:
N = 6
K = 5


def isTerminal(board, n=None, k=None):
    """Return the winner of a gomoku position, or 0 if nobody has won yet.

    A player wins with ``k`` stones in a row horizontally, vertically or
    diagonally. Cells are expected to hold 1 / -1 for the two players and 0
    for empty, as in the sample boards in this notebook. A winning run is
    detected by its sum reaching +k or -k.

    Args:
        board: square grid indexed ``board[row][col]``, with at least ``n``
            rows and columns.
        n: board size; defaults to the module-level ``N``.
        k: run length needed to win; defaults to the module-level ``K``.

    Returns:
        ``sum / k`` for the first winning run found (``1.0`` or ``-1.0``),
        otherwise ``0``.
    """
    n = N if n is None else n
    k = K if k is None else k

    def runs():
        # Horizontal: k consecutive cells along one row.
        for x in range(n):
            for y in range(n - k + 1):
                yield [board[x][y + i] for i in range(k)]
        # Vertical: k consecutive cells down one column.
        for y in range(n):
            for x in range(n - k + 1):
                yield [board[x + i][y] for i in range(k)]
        # Diagonal: top-left -> bottom-right.
        for x in range(n - k + 1):
            for y in range(n - k + 1):
                yield [board[x + i][y + i] for i in range(k)]
        # Anti-diagonal: top-right -> bottom-left.
        for x in range(n - k + 1):
            for y in range(k - 1, n):
                yield [board[x + i][y - i] for i in range(k)]

    # The generator is lazy, so the scan stops at the first winning run in the
    # order above.
    for run in runs():
        total = sum(run)
        if abs(total) == k:
            return total / k

    return 0

In [50]:
# Expected 1.0: player 1 has five on the anti-diagonal of the sample board.
isTerminal(board)

1.0

In [31]:
# Caution: board[0:5] is a list of rows, so [4] just selects row 4 (this equals
# board[4]); it is NOT column 4. A column would be [row[4] for row in board[0:5]].
board[0:5][4]

[0, 0, 0, 0, 1]

In [52]:
# Scratch check: `not 0` is True because 0 is falsy.
not 0

True