In [1]:
# environment:
# pip3 install torch

In [2]:
# Implementation of simple game: Tic-Tac-Toe
# You can change this to another two-player game.

import numpy as np

BLACK, WHITE = 1, -1  # first turn or second turn player

class State:
    '''Board implementation of Tic-Tac-Toe.

    Attributes:
        board: 3x3 float array indexed as (x, y); 0 is empty, BLACK / WHITE are stones.
        color: color of the player who moves next (BLACK moves first).
        win_color: color that completed a line, or 0 while nobody has won.
        record: list of the integer actions (0-8) played so far, in order.
    '''
    X, Y = 'ABC',  '123'
    C = {0: '_', BLACK: 'O', WHITE: 'X'}

    def __init__(self):
        self.board = np.zeros((3, 3)) # (x, y)
        self.color = BLACK
        self.win_color = 0
        self.record = []

    def action2str(self, a):
        '''Convert an action index (0-8) into its coordinate string, e.g. 4 -> 'B2'.'''
        return self.X[a // 3] + self.Y[a % 3]

    def str2action(self, s):
        '''Convert a coordinate string such as 'B2' into its action index (0-8).

        Raises:
            ValueError: if the string does not start with a valid row letter
                followed by a valid column digit. (str.find() returns -1 for
                unknown characters, which previously produced a wrong action
                silently.)
        '''
        x = self.X.find(s[0]) if len(s) >= 2 else -1
        y = self.Y.find(s[1]) if len(s) >= 2 else -1
        if x < 0 or y < 0:
            raise ValueError('invalid action string: %r' % (s,))
        return x * 3 + y

    def record_string(self):
        '''Return the moves played so far as a space separated coordinate string.'''
        return ' '.join([self.action2str(a) for a in self.record])

    def __str__(self):
        # output board.
        s = '   ' + ' '.join(self.Y) + '\n'
        for i in range(3):
            # board holds floats; convert to int for the symbol lookup
            s += self.X[i] + ' ' + ' '.join([self.C[int(self.board[i, j])] for j in range(3)]) + '\n'
        s += 'record = ' + self.record_string()
        return s

    def play(self, action):
        '''State transition function.

        Args:
            action: position integer (0-8), or a string holding a whitespace
                separated sequence of coordinates (e.g. 'B2 A1 C2') that is
                played move by move.

        Returns:
            self, so that calls can be chained.

        Raises:
            ValueError: if an integer action is outside 0-8 or a coordinate
                string is malformed.

        Note: the target cell is not checked for being empty; callers are
        expected to choose from legal_actions().
        '''
        if isinstance(action, str):
            for astr in action.split():
                self.play(self.str2action(astr))
            return self

        # negative indices would silently wrap around in numpy, so reject them
        if not 0 <= action < 3 * 3:
            raise ValueError('action out of range: %r' % (action,))

        x, y = action // 3, action % 3
        self.board[x, y] = self.color

        # check whether 3 stones are on the line (row, column, diagonal, anti-diagonal)
        if self.board[x, :].sum() == 3 * self.color \
          or self.board[:, y].sum() == 3 * self.color \
          or (x == y and np.diag(self.board, k=0).sum() == 3 * self.color) \
          or (x == 2 - y and np.diag(self.board[::-1,:], k=0).sum() == 3 * self.color):
            self.win_color = self.color

        self.color = -self.color
        self.record.append(action)
        return self

    def terminal(self):
        '''Return True when somebody has won or all 9 cells have been played.'''
        return self.win_color != 0 or len(self.record) == 3 * 3

    def terminal_reward(self):
        '''Return the result (1 win, -1 loss, 0 otherwise) seen from the player to move.'''
        return self.win_color if self.color == BLACK else -self.win_color

    def legal_actions(self):
        '''Return the list of actions whose cell is still empty.'''
        return [a for a in range(3 * 3) if self.board[a // 3, a % 3] == 0]

    def feature(self):
        '''Input tensor for neural nets (state).

        Returns a float32 array of shape (2, 3, 3): plane 0 marks the stones of
        the player to move, plane 1 marks the stones of the opponent.
        '''
        return np.stack([self.board == self.color, self.board == -self.color]).astype(np.float32)

    def action_feature(self, action):
        '''Input tensor for neural nets (action): a (1, 3, 3) float32 one-hot plane.'''
        a = np.zeros((1, 3, 3), dtype=np.float32)
        a[0, action // 3, action % 3] = 1
        return a

# Position after a single move: show the board and the planes fed to the nets
state = State()
state.play('B1')
print(state)
print('input feature', state.feature(), sep='\n')

# Position after three moves: show only the input planes
state = State()
state.play('B2 A1 C2')
print('input feature', state.feature(), sep='\n')

   1 2 3
A _ _ _
B O _ _
C _ _ _
record = B1
input feature
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [1. 0. 0.]
  [0. 0. 0.]]]
input feature
[[[1. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 1. 0.]
  [0. 1. 0.]]]


In [3]:
# Neural nets with PyTorch
# small version of nets used in MuZero paper

import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv(nn.Module):
    '''Bias-free 2D convolution (stride 1, padding kernel_size // 2), optionally followed by batch norm.'''
    def __init__(self, filters0, filters1, kernel_size, bn=False):
        super().__init__()
        self.conv = nn.Conv2d(
            filters0,
            filters1,
            kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=False,
        )
        # bn stays None when batch normalization is not requested
        self.bn = nn.BatchNorm2d(filters1) if bn else None

    def forward(self, x):
        out = self.conv(x)
        return out if self.bn is None else self.bn(out)

class ResidualBlock(nn.Module):
    '''Residual block: ReLU(x + BN(Conv3x3(x))) with an unchanged channel count.'''
    def __init__(self, filters):
        super().__init__()
        self.conv = Conv(filters, filters, 3, True)

    def forward(self, x):
        # skip connection around the convolution, then the non-linearity
        residual = self.conv(x)
        return F.relu(x + residual)

In [4]:
num_filters = 8  # channel count of the hidden state produced by the conv layers below
num_blocks = 2  # number of ResidualBlocks stacked in Representation and Dynamics

class Representation(nn.Module):
    '''Conversion from observation to inner abstract state.

    A first convolution maps the observation planes to num_filters channels,
    then num_blocks residual blocks refine the result.
    '''
    def __init__(self, input_shape):
        super().__init__()
        self.input_shape = input_shape
        rows, cols = self.input_shape[1], self.input_shape[2]
        self.board_size = rows * cols

        self.layer0 = Conv(self.input_shape[0], num_filters, 3, bn=True)
        self.blocks = nn.ModuleList([ResidualBlock(num_filters) for _ in range(num_blocks)])

    def forward(self, x):
        hidden = F.relu(self.layer0(x))
        for residual in self.blocks:
            hidden = residual(hidden)
        return hidden

    def inference(self, x):
        '''Encode one observation given as a numpy array and return the hidden state as numpy.'''
        self.eval()
        batch = torch.from_numpy(x).unsqueeze(0)  # add the batch axis
        with torch.no_grad():
            encoded = self(batch)
        return encoded.cpu().numpy()[0]

class Prediction(nn.Module):
    '''Policy and value prediction from inner abstract state.

    The policy head is two 1x1 convolutions flattened to one logit per action;
    the value head is a 1x1 convolution followed by a bias-free linear layer.
    '''
    def __init__(self, action_shape):
        super().__init__()
        self.board_size = np.prod(action_shape[1:])
        self.action_size = action_shape[0] * self.board_size

        # policy head
        self.conv_p1 = Conv(num_filters, 4, 1, bn=True)
        self.conv_p2 = Conv(4, 1, 1)

        # value head
        self.conv_v = Conv(num_filters, 4, 1, bn=True)
        self.fc_v = nn.Linear(self.board_size * 4, 1, bias=False)

    def forward(self, rp):
        policy_hidden = F.relu(self.conv_p1(rp))
        policy_logits = self.conv_p2(policy_hidden).view(-1, self.action_size)

        value_hidden = F.relu(self.conv_v(rp))
        value_raw = self.fc_v(value_hidden.view(-1, self.board_size * 4))

        # policy is a probability distribution; tanh keeps the value in -1 ~ 1
        policy = F.softmax(policy_logits, dim=-1)
        value = torch.tanh(value_raw)
        return policy, value

    def inference(self, rp):
        '''Predict (policy vector, scalar value) for one hidden state given as a numpy array.'''
        self.eval()
        batch = torch.from_numpy(rp).unsqueeze(0)  # add the batch axis
        with torch.no_grad():
            policy, value = self(batch)
        return policy.cpu().numpy()[0], value.cpu().numpy()[0][0]

class Dynamics(nn.Module):
    '''Abstract state transition: maps (hidden state, action plane) to the next hidden state.'''
    def __init__(self, rp_shape, act_shape):
        super().__init__()
        self.rp_shape = rp_shape
        in_channels = rp_shape[0] + act_shape[0]  # state and action are stacked along channels
        self.layer0 = Conv(in_channels, num_filters, 3, bn=True)
        self.blocks = nn.ModuleList([ResidualBlock(num_filters) for _ in range(num_blocks)])

    def forward(self, rp, a):
        stacked = torch.cat([rp, a], dim=1)
        hidden = self.layer0(stacked)
        for residual in self.blocks:
            hidden = residual(hidden)
        return hidden

    def inference(self, rp, a):
        '''Advance one hidden state by one action; inputs and output are numpy arrays.'''
        self.eval()
        rp_batch = torch.from_numpy(rp).unsqueeze(0)
        a_batch = torch.from_numpy(a).unsqueeze(0)
        with torch.no_grad():
            next_rp = self(rp_batch, a_batch)
        return next_rp.cpu().numpy()[0]

class Nets(nn.Module):
    '''Whole nets: representation, prediction and dynamics networks bundled together.'''
    def __init__(self):
        super().__init__()
        # Shapes are derived from a fresh game state so they follow the game definition
        state = State()
        input_shape = state.feature().shape
        action_shape = state.action_feature(0).shape
        rp_shape = (num_filters, *input_shape[1:])

        self.representation = Representation(input_shape)
        self.prediction = Prediction(action_shape)
        self.dynamics = Dynamics(rp_shape, action_shape)

    def predict_all(self, state0, path):
        '''Predict p and v from original state and path.

        Args:
            state0: game state providing feature() and action_feature(action).
            path: sequence of actions applied to the hidden state via the
                dynamics network (may be empty).

        Returns:
            A list with len(path) + 1 entries of (policy array, scalar value):
            one for the encoded state0 and one after each action in path.
        '''
        outputs = []
        self.eval()
        x = torch.from_numpy(state0.feature()).unsqueeze(0)
        with torch.no_grad():
            rp = self.representation(x)
            outputs.append(self.prediction(rp))
            for action in path:
                # action_feature() returns a numpy array; it has to become a
                # tensor before the batch axis can be added with unsqueeze()
                a = torch.from_numpy(state0.action_feature(action)).unsqueeze(0)
                rp = self.dynamics(rp, a)
                outputs.append(self.prediction(rp))
        #  return as numpy arrays
        return [(p.cpu().numpy()[0], v.cpu().numpy()[0][0]) for p, v in outputs]

In [5]:
def show_net(nets, state):
    '''Print the board followed by the nets' policy (in per-mille, board shaped) and value.'''
    print(state)
    policy, value = nets.predict_all(state, [])[-1]
    board_dims = nets.representation.input_shape[1:3]
    per_mille = (policy * 1000).astype(int).reshape((-1, *board_dims))
    print('p = ')
    print(per_mille)
    print('v = ', value)
    print()

# Outputs before training: policy and value of freshly initialized nets on the empty board
show_net(Nets(), State())

   1 2 3
A _ _ _
B _ _ _
C _ _ _
record = 
p = 
[[[111 111 111]
  [111 111 111]
  [111 111 111]]]
v =  0.0



In [6]:
# Implementation of Monte Carlo Tree Search

class Node:
    '''Search statistics of one abstract (or root) state.

    p and v are the stored policy and value. n and q_sum accumulate per-action
    visit counts and returns; n_all and q_sum_all are the overall totals.
    '''
    def __init__(self, p, v):
        self.p = p
        self.v = v
        self.n = np.zeros_like(p)
        self.q_sum = np.zeros_like(p)
        # prior: the node starts with one virtual visit worth half of v
        self.n_all = 1
        self.q_sum_all = v / 2

    def update(self, action, q_new):
        '''Record one more visit of action that returned q_new.'''
        # visit counts (per action and overall)
        self.n[action] += 1
        self.n_all += 1
        # accumulated returns (per action and overall)
        self.q_sum[action] += q_new
        self.q_sum_all += q_new

In [7]:
import time, copy

class Tree:
    '''Monte Carlo Tree.

    Nodes are stored in a dict. The root is keyed by the real game record; a
    node reached through the dynamics network is keyed by
    '<record>|<searched path>'.
    '''
    def __init__(self, nets):
        self.nets = nets
        self.nodes = {}

    def _node_key(self, state, path):
        '''Build the dict key of the node reached from state by following path.'''
        key = state.record_string()
        if len(path) > 0:
            key += '|' + ' '.join(map(state.action2str, path))
        return key

    def search(self, state, path, rp, depth):
        '''Run one simulation from the node given by (state, path).

        Args:
            state: real game state at the root of the search (never modified).
            path: actions taken so far in this simulation; extended in place.
            rp: hidden state (numpy array) of the current node.
            depth: number of actions taken from the root.

        Returns:
            The value found by this simulation, seen from the player to move
            at the current node.
        '''
        # Key that decides where this node is stored
        key = self._node_key(state, path)
        if key not in self.nodes:
            # Leaf: expand it and return the predicted value from the new state
            p, v = self.nets.prediction.inference(rp)
            self.nodes[key] = Node(p, v)
            return v

        # State transition by an action selected from bandit
        node = self.nodes[key]
        p = node.p
        mask = np.zeros_like(p)
        if depth == 0:
            # Add noise to policy on the root node
            p = 0.75 * p + 0.25 * np.random.dirichlet([0.15] * len(p))
            # On the root node, we choose action only from legal actions
            mask[state.legal_actions()] = 1
            p *= mask
            p /= p.sum() + 1e-16

        n, q_sum = 1 + node.n, node.q_sum_all / node.n_all + node.q_sum
        # PUCB formula; the mask bonus keeps illegal root actions from being chosen
        ucb = q_sum / n + 2.0 * np.sqrt(node.n_all) * p / n + mask * 4
        best_action = np.argmax(ucb)

        # Search next state by recursively calling this function.
        # rp is rebound to the hidden state of the child node.
        rp = self.nets.dynamics.inference(rp, state.action_feature(best_action))
        path.append(best_action)
        # Sign flip: with the assumption of changing player by turn
        q_new = -self.search(state, path, rp, depth + 1)
        # Back up the result into the statistics of the chosen action
        node.update(best_action, q_new)

        return q_new

    def think(self, state, num_simulations, temperature = 0, show=False):
        '''End point of MCTS.

        Args:
            state: real game state to search from.
            num_simulations: number of simulations to run.
            temperature: softness of the returned distribution; 0 makes it
                (almost) greedy with respect to the visit counts.
            show: print the board and, about once per second, the progress.

        Returns:
            Probability distribution over all actions weighted by the number
            of simulations. Illegal actions get probability 0 (when the state
            has at least one legal action), so sampling from the result
            always yields a playable move.
        '''
        if show:
            print(state)
        root_key = state.record_string()
        # The root hidden state does not change between simulations; encode it once
        root_rp = self.nets.representation.inference(state.feature())
        start, prev_time = time.time(), 0
        for _ in range(num_simulations):
            # Each simulation grows the tree by one node
            self.search(state, [], root_rp, depth=0)

            # Display search result on every second
            if show:
                tmp_time = time.time() - start
                if int(tmp_time) > int(prev_time):
                    prev_time = tmp_time
                    root, pv = self.nodes[root_key], self.pv(state)
                    if pv:  # nothing to report before the first action was visited
                        print('%.2f sec. best %s. q = %.4f. n = %d / %d. pv = %s'
                              % (tmp_time, state.action2str(pv[0]), root.q_sum[pv[0]] / root.n[pv[0]],
                                 root.n[pv[0]], root.n_all, ' '.join([state.action2str(a) for a in pv])))

        #  Return probability distribution weighted by the number of simulations
        root = self.nodes[root_key]
        n = root.n + 1
        legal = state.legal_actions()
        if len(legal) > 0:
            # The +1 smoothing above would otherwise give occupied cells a
            # non-zero probability
            legal_mask = np.zeros_like(n)
            legal_mask[legal] = 1
            n = n * legal_mask
        n = (n / np.max(n)) ** (1 / (temperature + 1e-8))
        return n / n.sum()  #next move probability representing good moves

    def pv(self, state):
        '''Return principal variation (action sequence which is considered as the best).

        Follows the most visited legal action from the root downwards, using
        the same node keys as search(), until an unvisited node or the end of
        the game is reached.
        '''
        s, pv_seq = copy.deepcopy(state), []
        while not s.terminal():
            node = self.nodes.get(self._node_key(state, pv_seq))
            if node is None or node.n.sum() == 0:
                break
            legal = s.legal_actions()
            if len(legal) == 0:
                break
            # max() keeps the first of equally visited actions, like a stable sort would
            best_action = max(legal, key=lambda a: node.n[a])
            pv_seq.append(best_action)
            s.play(best_action)
        return pv_seq

In [None]:
# Search with initialized nets
# Each position gets a fresh Tree and freshly initialized (untrained) Nets.

# Empty board, 100 simulations
tree = Tree(Nets())
tree.think(State(), 100, show=True)

# Position after 'A1 C1 A2 C2', 200 simulations
tree = Tree(Nets())
tree.think(State().play('A1 C1 A2 C2'), 200, show=True)

# Position after 'B2 A2 A3 C1 B3', 200 simulations
tree = Tree(Nets())
tree.think(State().play('B2 A2 A3 C1 B3'), 200, show=True)

# Position after 'B2 A2 A3 C1', 200 simulations; as the last expression of the
# cell, the returned move distribution is also displayed by the notebook
tree = Tree(Nets())
tree.think(State().play('B2 A2 A3 C1'), 200, show=True)

In [9]:
# Training of neural nets

import torch.optim as optim

batch_size = 32  # number of sampled training positions per optimizer step in train()
num_epochs = 30  # number of passes train() makes per call; the learning rate decays after each

def gen_target(ep, k):
    '''Generate inputs and targets for training.

    Args:
        ep: episode tuple (path, reward, observations, actions, policies);
            reward is seen from the first player.
        k: number of steps to unroll after the sampled turn.

    Returns:
        (observation at the sampled turn, k + 1 action planes,
         k + 1 policy targets, k + 1 single-element value targets).
    '''
    record, reward = ep[0], ep[1]
    observations, actions, policies = ep[2], ep[3], ep[4]

    num_turns = len(record)
    turn_idx = np.random.randint(num_turns)
    ps, vs, ax = [], [], []
    for step in range(k + 1):
        t = turn_idx + step
        if t >= num_turns:
            # state after finishing game: p is 0 (loss is 0)
            p = np.zeros_like(policies[-1])
            # random action selection, as a one-hot plane of the usual shape
            plane_shape = actions[-1].shape
            one_hot = np.zeros(np.prod(plane_shape), dtype=np.float32)
            one_hot[np.random.randint(len(one_hot))] = 1
            a = one_hot.reshape(plane_shape)
        else:
            p, a = policies[t], actions[t]
        # the value target alternates in sign with the player to move
        value = -reward if t % 2 else reward
        vs.append([value])
        ps.append(p)
        ax.append(a)

    return observations[turn_idx], ax, ps, vs

def train(episodes, nets=None):
    '''Train neural nets.

    Args:
        episodes: list of episode tuples as consumed by gen_target().
        nets: the Nets instance to train. When omitted a new Nets() is created
            for this call. (A default of Nets() in the signature would be
            built once at definition time and shared by every call.)

    Returns:
        The trained nets (the same object that was passed in, if any). With
        no episodes the nets are returned untrained.
    '''
    if nets is None:
        nets = Nets()
    if len(episodes) == 0:
        # nothing to sample from; also avoids dividing by zero in the summary below
        return nets

    optimizer = optim.SGD(nets.parameters(), lr=1e-3, weight_decay=1e-4, momentum=0.75)
    p_loss_sum, v_loss_sum = 0, 0
    for _ in range(num_epochs):
        p_loss_sum, v_loss_sum = 0, 0
        nets.train()
        for _ in range(0, len(episodes), batch_size):
            k = 4  # number of unrolled steps (fixed; could be np.random.randint(4))
            x, ax, p_target, v_target = zip(*[gen_target(episodes[np.random.randint(len(episodes))], k) for _ in range(batch_size)])
            x = torch.from_numpy(np.array(x))
            ax = torch.from_numpy(np.array(ax))
            p_target = torch.from_numpy(np.array(p_target))
            v_target = torch.FloatTensor(np.array(v_target))

            # Change the order of axis as [time step, batch, ...]
            ax = torch.transpose(ax, 0, 1)
            p_target = torch.transpose(p_target, 0, 1)
            v_target = torch.transpose(v_target, 0, 1)

            p_loss, v_loss = 0, 0

            # Compute losses for k (+ current) steps
            for t in range(k + 1):
                rp = nets.representation(x) if t == 0 else nets.dynamics(rp, ax[t - 1])
                p, v = nets.prediction(rp)
                # clamp so that a probability that underflows to 0 cannot turn
                # 0 * log(0) into NaN
                p_loss += torch.sum(-p_target[t] * torch.log(torch.clamp(p, min=1e-12)))
                v_loss += torch.sum((v_target[t] - v) ** 2)

            p_loss_sum += p_loss.item()
            v_loss_sum += v_loss.item()

            optimizer.zero_grad()
            (p_loss + v_loss).backward()
            optimizer.step()

        # decay the learning rate after every epoch
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.85
    # losses of the last epoch, averaged over the number of episodes
    print('p_loss %f v_loss %f' % (p_loss_sum / len(episodes), v_loss_sum / len(episodes)))
    return nets

In [10]:
#  Battle against random agents

def vs_random(nets, n=100):
    '''Play n games against a uniformly random agent, alternating who moves first.

    The nets pick the legal action with the highest raw policy output (no
    search). Returns a dict mapping the result seen from the nets
    (1 win, 0 draw, -1 loss) to the number of games with that result.
    '''
    results = {}
    for game in range(n):
        nets_to_move = game % 2 == 0  # the nets move first in even-numbered games
        state = State()
        while not state.terminal():
            if nets_to_move:
                policy, _ = nets.predict_all(state, [])[-1]
                # greedy choice among legal moves; ties go to the lowest action
                action = max(state.legal_actions(), key=lambda a: policy[a])
            else:
                action = np.random.choice(state.legal_actions())
            state.play(action)
            nets_to_move = not nets_to_move
        # terminal_reward() is seen from the player to move; flip it if that is the random agent
        r = state.terminal_reward()
        if not nets_to_move:
            r = -r
        results[r] = results.get(r, 0) + 1
    return results

In [11]:
# Run label; the training loop appends its log to '<test_prefix>_training_log.txt'
test_prefix = 'fix01_1500iter'

In [12]:
import time
def humanize_time(secs):
    '''Format a duration given in seconds as 'HH:MM:SS' (fractions are truncated).'''
    total_minutes, seconds = secs // 60, secs % 60
    hours, minutes = total_minutes // 60, total_minutes % 60
    return '%02d:%02d:%02d' % (hours, minutes, seconds)

In [13]:
# Main algorithm of MuZero

#num_games = 50000
num_games = 1500

num_train_steps = 10  # train after every this many self-play games
num_simulations = 30  # MCTS simulations per move

log_path = test_prefix + '_training_log.txt'

def log(*args, console=True, **kwargs):
    '''Print to the notebook (unless console=False) and append the same text to the log file.'''
    if console:
        print(*args, **kwargs)
    with open(log_path, 'a') as f:
        print(*args, file=f, **kwargs)

nets = Nets()

# Display battle results as {-1: lose 0: draw 1: win} (for episode generated for training, 1 means that the first player won)
vs_random_sum = vs_random(nets)
print('vs_random = ', sorted(vs_random_sum.items()))

episodes = []
result_distribution = {1:0, 0:0, -1:0}
start_time = time.time()
last_time = start_time  # the first interval is measured from the start of the run

for g in range(num_games):
    # Generate one episode
    record, p_targets, features, action_features = [], [], [], []
    state = State()
    temperature = 0.7 # temperature using to make policy targets from search results
    while not state.terminal():
        tree = Tree(nets)
        # num_simulations = total number of search trials for this move
        p_target = tree.think(state, num_simulations, temperature)
        p_targets.append(p_target)
        features.append(state.feature())
        # Select action with generated distribution, and then make a transition by that action
        action = np.random.choice(np.arange(len(p_target)), p=p_target)
        action_features.append(state.action_feature(action))
        state.play(action) # update physical state
        record.append(action)
        # Lower the temperature to reduce the diversity of the search (more greedy exploitation)
        temperature *= 0.8
    # reward seen from the first turn player
    reward = state.terminal_reward() * (1 if len(record) % 2 == 0 else -1)
    result_distribution[reward] += 1
    episodes.append((record, reward, features, action_features, p_targets))

    # Progress line: 'game ' starts each group of num_train_steps games
    if g % num_train_steps == 0:
        log('game ', end='')
    log(g, ' ', end='')

    # Training of neural nets
    if (g + 1) % num_train_steps == 0:
        now = time.time()
        runtime = humanize_time(now - last_time)
        last_time = now
        # The elapsed time goes to the log file only
        log('Time took from last training = '+str(runtime), console=False)
        # Show the result distribution of generated episodes
        log('generated = ', sorted(result_distribution.items()))

        nets = train(episodes, nets)
        vs_random_once = vs_random(nets)
        log('vs_random = ', sorted(vs_random_once.items()), end='')

        for r, n in vs_random_once.items():
            # .get(): an outcome may not have occurred in the initial evaluation
            vs_random_sum[r] = vs_random_sum.get(r, 0) + n
        log(' sum = ', sorted(vs_random_sum.items()))

        #show_net(nets, State())
        #show_net(nets, State().play('A1 C1 A2 C2'))
        #show_net(nets, State().play('A1 B2 C3 B3 C1'))
        #show_net(nets, State().play('B2 A2 A3 C1 B3'))
        #show_net(nets, State().play('B2 A2 A3 C1'))
print('finished')

vs_random =  [(-1, 42), (0, 10), (1, 48)]
game 0  1  2  3  4  5  6  7  8  9  generated =  [(-1, 3), (0, 1), (1, 6)]
p_loss 25.736346 v_loss 4.982507
vs_random =  [(-1, 46), (0, 10), (1, 44)] sum =  [(-1, 88), (0, 20), (1, 92)]
game 10  11  12  13  14  15  16  17  18  19  generated =  [(-1, 6), (0, 3), (1, 11)]
p_loss 12.180507 v_loss 2.736623
vs_random =  [(-1, 53), (0, 11), (1, 36)] sum =  [(-1, 141), (0, 31), (1, 128)]
game 20  21  22  23  24  25  26  27  28  29  generated =  [(-1, 10), (0, 3), (1, 17)]
p_loss 8.955230 v_loss 2.049779
vs_random =  [(-1, 32), (0, 22), (1, 46)] sum =  [(-1, 173), (0, 53), (1, 174)]
game 30  31  32  33  34  35  36  37  38  39  generated =  [(-1, 15), (0, 5), (1, 20)]
p_loss 12.066294 v_loss 5.209023
vs_random =  [(-1, 32), (0, 18), (1, 50)] sum =  [(-1, 205), (0, 71), (1, 224)]
game 40  41  42  43  44  45  46  47  48  49  generated =  [(-1, 21), (0, 5), (1, 24)]
p_loss 9.023237 v_loss 3.054748
vs_random =  [(-1, 47), (0, 12), (1, 41)] sum =  [(-1, 252),

In [None]:
# Show outputs from trained nets

# (title, moves played from the empty board) for each position to inspect
demo_positions = [
    ('initial state', ''),
    ('WIN by put', 'A1 C1 A2 C2'),
    ("LOSE by opponent's double", 'B2 A2 A3 C1 B3'),
    ('WIN through double', 'B2 A2 A3 C1'),
    # hard case: putting on A1 will cause double
    ('strategic WIN by following double', 'B1 A3'),
]

for title, moves in demo_positions:
    print(title)
    # an empty move string leaves the board empty
    show_net(nets, State().play(moves))


In [None]:
# Search with trained nets
# Uses the `nets` trained by the main loop (the name `net` is never defined in this notebook).

tree = Tree(nets)
tree.think(State(), 100000, show=True)