In [1]:
import torch
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import colors

%matplotlib widget 

Environment closely follows OpenAI gym API. Currently can not be invoked with ```gym.make("env_id")```, though it should be easy to do.

In [4]:
import gym

env = gym.make('gym_go:go-v0', size=7, komi=0, reward_method='real')

In [None]:
import monte_carlo_tree

class GoNode(monte_carlo_tree.Node):
    """Go Game Tree Node"""

    def prepared_game_state(self, player=None):
        """
        Prepare game state X from perspective of current player
        [
            [ 1 -1 -1 ]
            [ 1  0  0 ]
            [ 0  0 -1 ]
        ]

        
        Where  
            1:  current player
            -1: opposing player
            
        """

        if player == None:
            player = self.current_player()

        # take advantage of game symmetry        
        state = self.blacks() - self.whites() if player == 1 else self.whites() - self.blacks()

        return state

    def whites(self):
        """White pieces on board"""
        return state[0]
    
    def blacks(self):
        """Black pieces on board"""
        return state[1]

    def current_player(self):
        return state[2]
    
    def possible_actions(self, player=None, raw=False):
        """List of possible next actions"""
        actions = (self.blacks() + self.whites()) + self.state[3]
        actions[actions!=0] = np.nan
        action[action!=np.nan] = 1

        coords = []
        for piece_id in range(0, 18):
            piece = self.get_piece_by_id(piece_id)
            
            for action in self.state['moveset'][piece_id]:
                coords.append((piece['x'], piece['y'], action[0], action[1]))

        actions = np.zeros((8, 8, 8, 8))
        for c in coords:
            actions[c] = 1
        
        mask = self.possible_actions_mask(player)
        actions = actions * mask

        if raw:
            return actions

        return np.argwhere(actions).tolist()

    def possible_actions_mask(self, player=None):
        """Return list of possible next actions as i8 mask"""
        if player == None:
            player = self.current_player()

        coords = []
        for piece_id in range(0, 18):
            piece = self.get_piece_by_id(piece_id)

            if (piece['color'] == "Black" and player == 1) or (piece['color'] == "White" and player == 0):
                coords.append((piece['x'], piece['y']))

        actions = np.zeros((8, 8, 8, 8))
        for c in coords:
            actions[c] = 1

        return actions


## Random Play Tree

This is very basic algorithm that plays the game by making random moves. Sometimes it reaches the end goal, but overall it supper inneficient.

In [6]:
import monte_carlo_tree


def random_play():
    '''
    Play a game using random tree strategy
    '''
        
    tree = monte_carlo_tree.RandomPlayTree(env, 7)
    
    root_node = tree.root_node
    return tree.simulate(root_node)
    

def build_stats(playfunc, n_games=100):
    '''
    Play a number of random games and display result
    '''

    black_wins = 0
    white_wins = 0
    draws = 0
    moves = []
    
    for _ in range(n_games):
        m = playfunc()
        if m.reward > 0:
            black_wins += 1
        elif m.reward < 0:
            white_wins += 1
        elif m.reward == 0:
            draws += 1
       
        moves.append(m.depth())
    
    print("Blacks: ", black_wins, "Whites: ", white_wins, "Draws: ", draws, "Moves mean:", np.mean(moves))

In [7]:
build_stats(random_play, 1)

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

## Monte Carlo Search Tree

In [5]:
from monte_carlo_tree import MonteCarloPlayTree


def mtsc_play():
    '''
    Play a game using MonteCarloSearchTree
    '''

    mcst = MonteCarloPlayTree(checkers_env, 8)
    
    root_node = mcst.root_node
    terminal_node = mcst.simulate(root_node)
    
    return terminal_node


In [6]:
build_stats(mtsc_play, 1)

Blacks:  0 Whites:  0 Draws:  1 Moves mean: 103.0


## Guided Tree Search

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

BOARD_SIZE = 8

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class ActorCritic(nn.Module):

    def __init__(self, board_size):
        super(ActorCritic, self).__init__()
        
        self.board_size = BOARD_SIZE

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        
        # Policy head
        self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
        self.act_fc1 = nn.Linear(4*(self.board_size**2), self.board_size**4)
        
        # Critic head
        self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
        self.val_fc1 = nn.Linear(2*self.board_size**2, 64)
        self.val_fc2 = nn.Linear(64, 1)

        # Activation Functions
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

        
    def forward(self, x):

        x = x.unsqueeze(1)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        
        # Actor head
        y_actions = self.relu(self.act_conv1(x)).view(-1, 4*self.board_size**2)
        y_actions = self.act_fc1(y_actions)
        y_actions = self.sigmoid(y_actions)
        
        # Critic head
        y_value = self.relu(self.val_conv1(x))
        y_value = y_value.view(-1, 2*(self.board_size**2))
        y_value = self.relu(self.val_fc1(y_value))
        y_value = self.tanh(self.val_fc2(y_value))

        return y_actions, y_value

actor_critic_network = ActorCritic(BOARD_SIZE).to(device)

In [8]:
tree = monte_carlo_tree.GuidedMonteCarloPlayTree(checkers_env, BOARD_SIZE, actor_critic_network, device)

In [9]:
losses = []

In [10]:
# tree eats too much memory. free up every 10 cycles
actor_critic_network.load_state_dict(torch.load("./actor_critic.pt"))

for i in range(50):
    tree = monte_carlo_tree.GuidedMonteCarloPlayTree(checkers_env, BOARD_SIZE, actor_critic_network, device)
    for loss in tree.train(10):
        losses.append(loss)
    torch.save(actor_critic_network.state_dict(), "./actor_critic.pt")
    del tree.root_node


Iteration # 0
number of actions:  116
Loss: tensor(0.7676, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 1
number of actions:  29
Loss: tensor(2.7165, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 2
number of actions:  90
Loss: tensor(360.3415, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 3
number of actions:  52
Loss: tensor(205.2152, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 4
number of actions:  115
Loss: tensor(2.3692, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 5
number of actions:  116
Loss: tensor(1.2483, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 6
number of actions:  111
Loss: tensor(0.2214, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 7
number of actions:  128
Loss: tensor(23.4483, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 8
number of actions:  65
Loss: tensor(249.7696, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 9
number of actions:  60
Loss: tensor(224.4137, device='cuda:0', grad_fn=<SumBackward

In [None]:
def gmcts_play():
    '''
    Play a game using random tree strategy
    '''
        
    tree = monte_carlo_tree.GuidedMonteCarloPlayTree(checkers_env, BOARD_SIZE, actor_critic_network, device)
    
    root_node = tree.root_node
    return tree.simulate(root_node)

In [None]:
# build_stats(gmcts_play, 100)

In [None]:
example = torch.rand(1, 8, 8).to(device)
traced_script_module = torch.jit.trace(actor_critic_network, example)
traced_script_module.save("model.pt")

In [None]:
class Actor(nn.Module):

    def __init__(self, actor_critic_network):
        super(Actor, self).__init__()
        self.actor_critic = actor_critic_network

    def forward(self, x):
        y, _ = self.actor_critic(x)
        return y

class Critic(nn.Module):

    def __init__(self, actor_critic_network):
        super(Actor, self).__init__()
        self.actor_critic = actor_critic_network

    def forward(self, x):
        _, y = self.actor_critic(x)
        return y

actor_network = Actor(actor_critic_network).to(torch.device("cpu"))
critic_network = Actor(actor_critic_network).to(torch.device("cpu"))

In [None]:
from torch import jit

example = torch.rand(1, 8, 8).to(torch.device("cpu"))
print(example.size())
traced_script_module = torch.jit.trace(actor_network, example)
traced_script_module.save("actor_network.pt")

In [None]:
from matplotlib import pyplot as plt 
plt.plot([l.detach().cpu().numpy() for l in losses])
plt.show()