In [None]:
import torch
import numpy as np

In [None]:
# Side length of the square Go board. Passed to the environment (size=...),
# to every search tree and as the default for ActorCritic(board_size=...).
BOARD_SIZE = 8

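The environment closely follows the OpenAI Gym API. Note that it is currently created with the fully qualified id `gym.make('gym_go:go-v0', ...)` (see below) rather than a bare `gym.make("env_id")`; registering a short id should be easy to add.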

In [None]:
import gym
import gym_go

# Shared Go environment used by all trees below and read by build_stats via env.winning().
# komi=0: no komi compensation is applied (NOTE(review): confirm how gym_go applies komi).
# reward_method='real': NOTE(review): confirm the exact reward semantics in the gym_go docs.
env = gym.make('gym_go:go-v0', size=BOARD_SIZE, komi=0, reward_method='real')

In [None]:
import monte_carlo_tree


class GoNode(monte_carlo_tree.Node):
    """Go game tree node.

    Adapts the board tensor in ``self.state`` to the helpers the tree classes
    call: state encoding, candidate-move masks and game evaluation.
    ``self.state`` is not assigned in this class; it is presumably set by
    ``monte_carlo_tree.Node``.

    Planes read here:
        state[0]       -> ``whites()``
        state[1]       -> ``blacks()``
        state[2, 0, 0] -> whose turn it is (see ``current_player``)
        state[3]       -> extra "blocked" mask used by ``possible_actions``
                          (presumably gym_go's invalid-move plane)
    NOTE(review): confirm the colour/plane order against gym_go's own channel
    constants; the colour names here are this notebook's convention.
    """

    def prepare_state(self, player=None):
        """Return the board from ``player``'s point of view.

        Own stones are 1, opponent stones -1, empty points 0, e.g.

            [[ 1 -1 -1]
             [ 1  0  0]
             [ 0  0 -1]]

        Args:
            player: 1 or -1 as returned by ``current_player``. Defaults to
                the player to move.

        Returns:
            Array shaped like a single board plane.
        """
        if player is None:
            player = self.current_player()

        # Exploit colour symmetry: flip signs so the player of interest is always +1.
        if player == 1:
            return self.blacks() - self.whites()
        return self.whites() - self.blacks()

    def whites(self):
        """White pieces on board (plane 0)."""
        return self.state[0]

    def blacks(self):
        """Black pieces on board (plane 1)."""
        return self.state[1]

    def current_player(self):
        """Return 1 if ``state[2, 0, 0] == 1``, otherwise -1."""
        return 1 if self.state[2, 0, 0] == 1 else -1

    def possible_actions(self):
        """Mask of candidate moves.

        Returns:
            Array shaped like a board plane (same dtype as the state) holding 1
            where ``blacks() + whites() + state[3]`` is 0 (no stone and not
            masked, assuming 0/1 planes) and 0 elsewhere.
        """
        blocked = self.blacks() + self.whites() + self.state[3]

        mask = np.zeros_like(blocked)
        mask[blocked == 0] = 1

        return mask

    def possible_actions_list(self):
        """Indices (one row per point) where ``possible_actions()`` is 1."""
        return np.argwhere(self.possible_actions() == 1)

    def prepare_action(self, action):
        """Actions are passed through unchanged."""
        return action

    def evaluate(self, env):
        """Game result as reported by ``env.winning()``.

        ``build_stats`` in this notebook reads a positive value as a black
        win, a negative value as a white win and zero as a draw.
        """
        return env.winning()

## Random Play Tree

This is a very basic algorithm that plays the game by making random moves. It sometimes reaches the end goal, but overall it is extremely inefficient.

In [None]:
# Baseline player: a tree that plays random moves, built over the shared `env`
# (build_stats reads each game's result from that same `env`).
rt = monte_carlo_tree.RandomPlayTree(env, GoNode, BOARD_SIZE)


def random_play():
    '''
    Play one game with the random-move tree ``rt``.

    The game starts from ``rt.root_node``; whatever ``rt.simulate`` returns
    is handed straight back (``build_stats`` calls ``depth()`` on it).
    '''
    tree = rt
    start_node = tree.root_node
    return tree.simulate(start_node)
    

def build_stats(playfunc, n_games=100):
    '''
    Play ``n_games`` games with ``playfunc`` and print a win/draw summary.

    Works with any play function (e.g. ``random_play`` or ``mtsc_play``),
    not only random play.

    Args:
        playfunc: zero-argument callable that plays one game and returns an
            object with a ``depth()`` method (used as that game's move count).
        n_games: number of games to play.

    Returns:
        dict with keys ``black_wins``, ``white_wins``, ``draws`` and
        ``moves_mean`` (nan when no games were played), so the summary can
        be reused programmatically as well as read from the printout.

    NOTE: each game's result is read from the module-level ``env`` via
    ``env.winning()``, so ``playfunc`` must drive that same environment.
    Positive -> black win, negative -> white win, zero -> draw.
    '''
    black_wins = 0
    white_wins = 0
    draws = 0
    moves = []

    for _ in range(n_games):
        final_node = playfunc()
        reward = env.winning()

        if reward > 0:
            black_wins += 1
        elif reward < 0:
            white_wins += 1
        elif reward == 0:
            draws += 1

        moves.append(final_node.depth())

    # np.mean([]) emits a RuntimeWarning and returns nan; report nan explicitly instead.
    moves_mean = float(np.mean(moves)) if moves else float('nan')

    print("Blacks: ", black_wins, "Whites: ", white_wins, "Draws: ", draws, "Moves mean:", moves_mean)

    return {
        "black_wins": black_wins,
        "white_wins": white_wins,
        "draws": draws,
        "moves_mean": moves_mean,
    }

In [None]:
# build_stats(random_play, 100)

## Monte Carlo Tree Search

In [None]:
# Monte Carlo tree search player built over the shared `env`
# (build_stats reads each game's result from that same `env`).
mcst = monte_carlo_tree.MonteCarloPlayTree(env, GoNode, BOARD_SIZE)

def mtsc_play():
    '''
    Play one game with the Monte Carlo tree search player ``mcst``.

    The game starts from ``mcst.root_node``; whatever ``mcst.simulate``
    returns is handed straight back (``build_stats`` calls ``depth()`` on it).
    '''
    search_tree = mcst
    start_node = search_tree.root_node
    return search_tree.simulate(start_node)

In [None]:
# build_stats(mtsc_play, 100)

## AlphaZero

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class ActorCritic(nn.Module):
    """Convolutional policy/value network over a single-plane Go board.

    Input:
        Tensor of boards, presumably (batch, board_size, board_size) as built
        from ``GoNode.prepare_state`` — ``forward`` adds the channel dimension
        itself. TODO confirm against the caller in ``monte_carlo_tree``.

    Output:
        ``(prob, value)`` where ``prob`` is (batch, board_size, board_size)
        of per-point sigmoid scores and ``value`` is (batch, 1) in [-1, 1]
        (tanh).

    Args:
        board_size: side length of the board. Sizes the policy head and the
            output reshape (previously both were hard-coded to 8x8, so any
            other value was silently ignored).
    """

    def __init__(self, board_size=BOARD_SIZE):
        super(ActorCritic, self).__init__()

        self.board_size = board_size
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        # Policy head: one score per board point. For the default 8x8 board this
        # is Linear(128, 64), so existing checkpoints keep loading unchanged.
        self.layer1 = nn.Linear(128, board_size * board_size)
        # Value head: one scalar per board.
        self.layer2 = nn.Linear(128, 1)

    def forward(self, x):
        x = x.unsqueeze(1)  # add the single input channel: (B, H, W) -> (B, 1, H, W)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv4(x))
        # Global max-pool down to 1x1. On an 8x8 board the map is 2x2 here, so this
        # equals the old max_pool2d(x, 2); on larger boards it stops the flatten
        # below from folding leftover spatial positions into the batch dimension.
        x = F.adaptive_max_pool2d(x, 1)
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.flatten(x, 1)  # (B, 128, 1, 1) -> (B, 128)

        prob = torch.sigmoid(self.layer1(x))
        value = torch.tanh(self.layer2(x))

        return prob.view(-1, self.board_size, self.board_size), value.view(-1, 1)

# Single network instance on `device`; it is handed to the guided tree below
# and its weights are what the training loop checkpoints.
actor_critic_network = ActorCritic().to(device)

In [None]:
# AlphaZero-style player: a tree search presumably guided by actor_critic_network's
# policy/value outputs (see monte_carlo_tree.GuidedMonteCarloPlayTree to confirm).
azt = monte_carlo_tree.GuidedMonteCarloPlayTree(env, GoNode, BOARD_SIZE, actor_critic_network, device)

In [None]:
# Loss history. Kept in its own cell so re-running the training cell below keeps
# appending to it instead of starting again from an empty list.
losses = []

In [None]:
# Training loop: N_TRAIN_ROUNDS calls to azt.train(...), checkpointing the network
# weights after every round (the checkpoint file is overwritten each time).
N_TRAIN_ROUNDS = 50
TRAIN_STEPS_PER_ROUND = 10  # argument passed to azt.train; NOTE(review): confirm its unit (games?)
CHECKPOINT_PATH = "./actor_critic.pt"

for i in range(N_TRAIN_ROUNDS):
    # azt.train returns an iterable of loss values; keep them all for inspection.
    losses.extend(azt.train(TRAIN_STEPS_PER_ROUND))
    torch.save(actor_critic_network.state_dict(), CHECKPOINT_PATH)