In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import namedtuple
from random import random
import sys
sys.path.append('../')

In [2]:
from MCTS import MCTS
from Node import Node
from Game import TicTacToe_Board
from ReplayMemory import ReplayMemory

In [3]:
# Declare NN of policy:
class DNN(nn.Module):
    def __init__(self):
        super(DNN, self).__init__()
        
        self.l1 = nn.Linear(9, 20)
        self.l2 = nn.Linear(20, 9)
        self.l3 = nn.Linear(20,1)
        
    def forward(self, x):
        x = F.relu(self.l1(x))
        x1 = self.l2(x) # Policy head
        x2 = torch.tanh(self.l3(x)) # Value head.
        return x1, x2

In [104]:
#MCTS game:
T = 0.3
state = torch.tensor([0,0,0,0,0,0,0,0,0],dtype=torch.float)
end = False
player = -1

dnn = DNN()
game = TicTacToe_Board()
root = Node(state, 1, player)
mcts = MCTS(game, root, dnn)

node = root
while not end:
    mcts.explore(node)
    a = mcts.play(node, T)
    node = node.children[a]
    print(node.player)
    game.plot(node.state)
    
    end, winner = game.check_end(node.state)

1
-------------
| x |   |   | 
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
-1
-------------
| x | o |   | 
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
1
-------------
| x | o | x | 
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
-1
-------------
| x | o | x | 
-------------
| o |   |   | 
-------------
|   |   |   | 
-------------
1
-------------
| x | o | x | 
-------------
| o | x |   | 
-------------
|   |   |   | 
-------------
-1
-------------
| x | o | x | 
-------------
| o | x | o | 
-------------
|   |   |   | 
-------------
1
-------------
| x | o | x | 
-------------
| o | x | o | 
-------------
|   | x |   | 
-------------
-1
-------------
| x | o | x | 
-------------
| o | x | o | 
-------------
|   | x | o | 
-------------
1
-------------
| x | o | x | 
-------------
| o | x | o | 
-------------
| x | x | o | 
-------------


# Train loop

In [97]:
CAPACITY = 1000
REPLAY_START_SIZE = 100
BATCH = 100

EPISODES = 100

NGAMES=10

In [21]:
Transition = namedtuple('Transition',('state', 'policy', 'reward'))

In [98]:
dnn = DNN()
replay_memory = ReplayMemory(CAPACITY)

In [119]:
def loss_fn(z, v, policy, net_pol):
    a = torch.mean(torch.pow(z-v,2))
    b = (policy * net_pol).sum()
    print(a,b)
    return a - b

In [57]:
optimizer = torch.optim.SGD(dnn.parameters(), lr=1e-2, weight_decay=1e-4)

In [86]:
def optimize(model, optimizer, loss_fn, memory):
    if memory.__len__() < REPLAY_START_SIZE:
        return
    
    transitions = memory.sample(BATCH)
    batch = Transition(*zip(*transitions))
    
    states = torch.cat(batch.state).reshape([BATCH,-1])
    policy = torch.cat(batch.policy).reshape([BATCH,-1])
    z = torch.cat(batch.reward).reshape([BATCH,-1])
    
    net_pol, v = dnn(states)
    
    
    loss = loss_fn(z,v,policy,net_pol)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return loss.data.item()

In [94]:
# Train loop:
def play_game():
    T = 0.3
    state = torch.tensor([0,0,0,0,0,0,0,0,0],dtype=torch.float)
    end = False
    player = -1
    
    
    game = TicTacToe_Board()
    root = Node(state, 1, player)
    mcts = MCTS(game, root, dnn,ngames=NGAMES)
    
    node = root
    
    history = []
    while not end:
        history.append(node)
        mcts.explore(node)
        a = mcts.play(node, T)
        node = node.children[a]
        
        end, winner = game.check_end(node.state)
    
    # Save play in replaymemory
    for node in history:
        policy = torch.tensor([0,0,0,0,0,0,0,0,0],dtype=torch.float)
        p = mcts.eval_policy(node, 1)
        for i, a in enumerate(game.avail_actions(node.state)):
            policy[a] = p[i]
            
        replay_memory.add(node.state, policy, torch.tensor([winner*node.player], dtype=torch.float))

In [120]:
for episode in range(EPISODES):
    # Play game to fill memory
    play_game()
    
    # Optimize
    loss = optimize(dnn, optimizer, loss_fn, replay_memory)
    
    print(loss)

tensor(0.9897, grad_fn=<MeanBackward0>) tensor(-2.8327, grad_fn=<SumBackward0>)
3.8223655223846436
tensor(0.9820, grad_fn=<MeanBackward0>) tensor(1.0510, grad_fn=<SumBackward0>)
-0.06899690628051758
tensor(0.9900, grad_fn=<MeanBackward0>) tensor(-2.8139, grad_fn=<SumBackward0>)
3.8038864135742188
tensor(1.0156, grad_fn=<MeanBackward0>) tensor(1.6030, grad_fn=<SumBackward0>)
-0.5873874425888062
tensor(0.9960, grad_fn=<MeanBackward0>) tensor(-0.7285, grad_fn=<SumBackward0>)
1.7245323657989502
tensor(1.0031, grad_fn=<MeanBackward0>) tensor(-0.3560, grad_fn=<SumBackward0>)
1.3590631484985352
tensor(0.9887, grad_fn=<MeanBackward0>) tensor(-1.7879, grad_fn=<SumBackward0>)
2.776566505432129
tensor(0.9953, grad_fn=<MeanBackward0>) tensor(-4.8000, grad_fn=<SumBackward0>)
5.795280933380127
tensor(0.9802, grad_fn=<MeanBackward0>) tensor(-0.9452, grad_fn=<SumBackward0>)
1.9253618717193604
tensor(1.0063, grad_fn=<MeanBackward0>) tensor(-2.5988, grad_fn=<SumBackward0>)
3.6051154136657715
tensor(0.98

tensor(1.0083, grad_fn=<MeanBackward0>) tensor(-6.1604, grad_fn=<SumBackward0>)
7.168708801269531
tensor(0.9903, grad_fn=<MeanBackward0>) tensor(1.6728, grad_fn=<SumBackward0>)
-0.6825123429298401
tensor(0.9971, grad_fn=<MeanBackward0>) tensor(-2.1351, grad_fn=<SumBackward0>)
3.1322453022003174
tensor(1.0144, grad_fn=<MeanBackward0>) tensor(-3.1548, grad_fn=<SumBackward0>)
4.169234752655029
tensor(1.0057, grad_fn=<MeanBackward0>) tensor(-3.9700, grad_fn=<SumBackward0>)
4.975736141204834
tensor(1.0125, grad_fn=<MeanBackward0>) tensor(-4.2592, grad_fn=<SumBackward0>)
5.271704196929932
tensor(1.0029, grad_fn=<MeanBackward0>) tensor(-7.0969, grad_fn=<SumBackward0>)
8.099796295166016
tensor(1.0106, grad_fn=<MeanBackward0>) tensor(-0.3581, grad_fn=<SumBackward0>)
1.3686509132385254
tensor(1.0101, grad_fn=<MeanBackward0>) tensor(-0.9748, grad_fn=<SumBackward0>)
1.984839916229248
tensor(1.0215, grad_fn=<MeanBackward0>) tensor(-4.8555, grad_fn=<SumBackward0>)
5.8769450187683105
tensor(1.0236, g

In [122]:
#Test:
T = 0.3
state = torch.tensor([0,0,0,0,0,0,0,0,0],dtype=torch.float)
end = False
player = -1

game = TicTacToe_Board()
root = Node(state, 1, player)
mcts = MCTS(game, root, dnn)

node = root
while not end:
    # Expand state:
    mcts.player *= -1
    actions = game.avail_actions(node.state)
    mcts.expand(node, actions)
    a = torch.argmax(torch.tensor([child.P for child in node.children]))
    node = node.children[a]
    print(node.Q)
    game.plot(node.state)
    
    end, winner = game.check_end(node.state)

0.0
-------------
|   |   | x | 
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
0.0
-------------
|   |   | x | 
-------------
|   |   |   | 
-------------
|   | o |   | 
-------------
0.0
-------------
| x |   | x | 
-------------
|   |   |   | 
-------------
|   | o |   | 
-------------
0.0
-------------
| x | o | x | 
-------------
|   |   |   | 
-------------
|   | o |   | 
-------------
0.0
-------------
| x | o | x | 
-------------
|   |   |   | 
-------------
| x | o |   | 
-------------
0.0
-------------
| x | o | x | 
-------------
|   | o |   | 
-------------
| x | o |   | 
-------------
