In [10]:
import torch
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import colors

%matplotlib widget 

The environment closely follows the OpenAI Gym API. It currently cannot be created with ```gym.make("env_id")```, though registering it should be easy.

In [11]:
import env
# One environment instance, shared by random_play, mtsc_play and gmcts_play below.
checkers_env = env.Env()

A state is an ndarray of shape ```(5+NUMBER_OF_PIECES, BOARD_SIZE, BOARD_SIZE)```

Where

* ```state[0]``` — black pieces
* ```state[1]``` — white pieces
* ```state[2]``` — pieces ids
* ```state[3]``` — current turn (blacks=1, whites=0)
* ```state[4]``` — whether the game is in a terminal state
* ```state[5+i]``` — allowed moves for piece number i

The allowed moves should probably be stored elsewhere, but there is enough RAM to leave this unoptimized for now.


In [12]:
""" Various plotting helper methods """

def show_state(state):
    """Render ``state`` with ``plt.imshow`` and display the figure right away.

    NOTE(review): ``imshow`` needs a 2-D plane (or an image with colour
    channels last), so a full ``(5+N, BOARD_SIZE, BOARD_SIZE)`` state
    presumably has to be sliced first, e.g. ``show_state(state[0])``.
    """
    plt.imshow(state)
    plt.show()

def show_board(board):
    """Draw a board as a grid of empty (white) and occupied (red) squares.

    Cells with a value in ``[0, 0.5)`` are drawn white, cells in ``[0.5, 18)``
    red. This function does not call ``plt.show()``.

    Args:
        board: 2-D array-like. Grid/tick positions are derived from its shape,
            so non-8x8 boards get a correct grid as well (8x8 is unchanged).
    """
    cmap = colors.ListedColormap(['white', 'red'])
    # Two colour bins for BoundaryNorm: [0, 0.5) -> white, [0.5, 18) -> red.
    bounds = [0, 0.5, 18]
    norm = colors.BoundaryNorm(bounds, cmap.N)
    plt.imshow(board, cmap=cmap, norm=norm, interpolation='none')

    n_rows, n_cols = np.shape(board)[:2]
    # imshow centres cell k at coordinate k, so cell borders are at k + 0.5;
    # for an 8x8 board these are exactly the previous hard-coded 0.5 .. 7.5.
    plt.xticks(np.arange(0.5, n_cols + 0.5), [])
    plt.yticks(np.arange(0.5, n_rows + 0.5), [])

    plt.grid()

def do_the_flip(arr):
    """Mirror ``arr`` left-to-right (reverse axis 1) for display.

    Flipping axis 0 twice is a no-op, so the former
    ``np.flip(np.flip(np.flip(arr, 0), 0), 1)`` was exactly this single
    axis-1 flip; the redundant work is removed and results are unchanged.
    NOTE(review): if a 180-degree rotation was intended, that would be
    ``np.flip(arr, (0, 1))`` -- confirm against the desired board orientation.

    Returns a view of ``arr``, as ``np.flip`` does.
    """
    return np.flip(arr, 1)

def show_trajectory_item(trajectory, index):
    """Plot the occupied squares of the state stored at ``trajectory[index]``.

    The item's ``original_state`` is converted with
    ``monte_carlo_tree.state_to_board``; planes 0 and 1 (black and white
    pieces) are summed into one occupancy board, mirrored with
    ``do_the_flip`` and drawn with ``show_board``.

    Note: ``monte_carlo_tree`` is imported in a later cell, so this helper only
    works once that cell has run.
    """
    step = trajectory[index]
    planes = monte_carlo_tree.state_to_board(step.original_state)
    occupancy = planes[0] + planes[1]
    show_board(do_the_flip(occupancy))

## Random Play Tree

This is a very basic algorithm that plays the game by making random moves. Sometimes it reaches the end goal, but overall it is super inefficient.

In [13]:
import monte_carlo_tree


def random_play():
    '''
    Play one game in which every move is chosen at random.

    A fresh ``RandomPlayTree`` is built over the shared ``checkers_env`` and
    simulated from its root. The result is what ``build_stats`` consumes
    (it reads ``.value`` and ``.depth()``).
    '''
    # The literal 8 appears to be the board size (gmcts_play passes BOARD_SIZE
    # in the same position) -- TODO confirm.
    random_tree = monte_carlo_tree.RandomPlayTree(checkers_env, 8)
    start = random_tree.root_node
    return random_tree.simulate(start)
    

def build_stats(playfunc, n_games=100):
    '''
    Play ``n_games`` games with ``playfunc`` and print an outcome summary.

    Parameters
    ----------
    playfunc : callable
        Zero-argument function that plays one complete game and returns a node
        exposing a numeric ``value`` and a ``depth()`` method. ``value > 0`` is
        counted as a black win, ``value < 0`` as a white win, ``value == 0`` as
        a draw; ``depth()`` is reported as the number of moves.
    n_games : int
        Number of games to play. When no games are played the moves mean is
        printed as ``nan`` (without numpy's "Mean of empty slice" warning).
    '''
    black_wins = white_wins = draws = 0
    moves = []

    for _ in range(n_games):
        terminal = playfunc()
        if terminal.value > 0:
            black_wins += 1
        elif terminal.value < 0:
            white_wins += 1
        elif terminal.value == 0:
            # Explicit check kept on purpose: a value that is none of >0, <0,
            # ==0 (e.g. NaN) is not counted as any outcome.
            draws += 1

        moves.append(terminal.depth())

    # np.mean([]) warns and returns nan; report nan directly for zero games.
    moves_mean = np.mean(moves) if moves else float("nan")
    print("Blacks: ", black_wins, "Whites: ", white_wins, "Draws: ", draws, "Moves mean:", moves_mean)

In [14]:
# build_stats(random_play, 100)

## Monte Carlo Tree Search

In [15]:
from monte_carlo_tree import MonteCarloPlayTree


def mtsc_play():
    '''
    Play one game with the (unguided) ``MonteCarloPlayTree``.

    A fresh tree is built over the shared ``checkers_env`` and simulated from
    its root node; the returned node is what ``build_stats`` consumes
    (it reads ``.value`` and ``.depth()``).
    '''
    # The literal 8 appears to be the board size (gmcts_play passes BOARD_SIZE
    # in the same position) -- TODO confirm.
    search_tree = MonteCarloPlayTree(checkers_env, 8)
    start = search_tree.root_node
    return search_tree.simulate(start)


In [16]:
# build_stats(mtsc_play, 100)

## Guided Tree Search

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

BOARD_SIZE = 8

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class ActorCritic(nn.Module):
    """Convolutional actor-critic network over a square board.

    A shared three-layer conv trunk feeds two heads:

    * policy (actor) head: a probability for each of ``board_size**4`` move
      slots, with slots marked unavailable in ``moves`` set to zero and the
      remaining ones renormalised per sample;
    * value (critic) head: one scalar per sample in ``[-1, 1]`` (``tanh``).

    Args:
        board_size: Squares per board side. It sizes the fully connected
            layers, so inputs to ``forward`` must use the same board size.
    """

    def __init__(self, board_size=BOARD_SIZE):
        super().__init__()

        # Honour the constructor argument; previously the global BOARD_SIZE was
        # always used, so ActorCritic(board_size=n) silently built an 8x8 net.
        self.board_size = board_size

        # Shared trunk: 1 input plane -> 128 feature maps; padding=1 keeps the
        # spatial size at board_size x board_size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Policy head: one logit per move slot (board_size**4 of them).
        self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
        self.act_fc1 = nn.Linear(4 * (self.board_size ** 2), self.board_size ** 4)

        # Critic head
        self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
        self.val_fc1 = nn.Linear(2 * self.board_size ** 2, 64)
        self.val_fc2 = nn.Linear(64, 1)

    def forward(self, x, moves):
        """Compute masked move probabilities and a state value.

        Args:
            x: float tensor of shape (batch, board_size, board_size); a channel
                dimension is added internally. Presumably an encoded board
                plane -- TODO confirm the encoding with the caller.
            moves: tensor with batch * board_size**4 elements (reshaped to
                (batch, board_size**4)); non-zero entries mark available
                moves. The meaning of the four board_size axes is presumably
                (from_row, from_col, to_row, to_col) -- TODO confirm.

        Returns:
            prob: (batch, board_size, board_size, board_size, board_size);
                zero on unavailable moves, each sample summing to 1 (all
                zeros if a sample has no available moves).
            value: (batch, 1) in [-1, 1].
        """
        n = self.board_size

        x = x.unsqueeze(1)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))

        # Actor head: log-probabilities over all n**4 move slots.
        x_act = torch.relu(self.act_conv1(x))
        x_act = x_act.view(-1, 4 * n ** 2)
        x_act = torch.log_softmax(self.act_fc1(x_act), dim=1)

        # Availability of moves, flattened to line up with x_act.
        avail = moves.reshape(-1, n ** 4)

        # Zero out impossible moves and renormalise per sample. Subtracting each
        # row's max before exp avoids overflow; it cancels in the division.
        maxa = x_act.max(dim=1, keepdim=True).values
        exp = avail * torch.exp(x_act - maxa)
        # Per-row total (the whole-batch sum made rows not sum to 1 for batch > 1).
        # clamp_min avoids 0/0 = NaN when a row has no available moves.
        total = exp.sum(dim=1, keepdim=True).clamp_min(torch.finfo(exp.dtype).tiny)
        prob = exp / total

        prob = prob.view(-1, n, n, n, n)

        # Critic head
        value = torch.relu(self.val_conv1(x))
        value = value.view(-1, 2 * (n ** 2))
        value = F.relu(self.val_fc1(value))
        value = torch.tanh(self.val_fc2(value))

        return prob, value

actor_critic_network = ActorCritic().to(device)

In [18]:
# Guided tree search player built around the shared checkers_env and
# actor_critic_network (on `device`). The `tree.train(...)` cell below trains
# through this object; presumably it updates actor_critic_network's weights,
# which are later saved with torch.save.
tree = monte_carlo_tree.GuidedMonteCarloPlayTree(checkers_env, BOARD_SIZE, actor_critic_network, device)

In [23]:
# Train the guided tree; the argument is presumably the number of iterations.
# NOTE(review): the saved output shows only iterations 0-9 although 1000 are
# requested, and this cell's execution count (In [23]) is higher than the cells
# after it (In [20]-[22]), so the output is likely stale or from an interrupted
# run -- re-run the notebook top to bottom before trusting it. The logged loss
# also barely moves (2372.78 -> 2372.75), which is worth investigating.
tree.train(1000)

Iteration # 0
number of moves:  56
Loss: tensor(2372.7778, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 1
number of moves:  56
Loss: tensor(2372.7749, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 2
number of moves:  56
Loss: tensor(2372.7722, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 3
number of moves:  56
Loss: tensor(2372.7698, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 4
number of moves:  56
Loss: tensor(2372.7671, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 5
number of moves:  56
Loss: tensor(2372.7644, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 6
number of moves:  56
Loss: tensor(2372.7617, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 7
number of moves:  56
Loss: tensor(2372.7590, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 8
number of moves:  56
Loss: tensor(2372.7563, device='cuda:0', grad_fn=<SumBackward0>)
Iteration # 9
number of moves:  56
Loss: tensor(2372.7537, device='cuda:0', grad_fn=<SumBackward0>)


In [20]:
# Persist the network weights (state_dict only, not the module object).
# The path is relative to the kernel's working directory; presumably the
# ../assets directory must already exist -- TODO confirm / create it first.
torch.save(actor_critic_network.state_dict(), "./../assets/actor_critic.pt")

In [21]:
def gmcts_play():
    '''
    Play one game with the network-guided Monte Carlo tree.

    A fresh ``GuidedMonteCarloPlayTree`` is built over the shared
    ``checkers_env``, ``BOARD_SIZE``, ``actor_critic_network`` and ``device``,
    then simulated from its root node. The returned node is what
    ``build_stats`` consumes (it reads ``.value`` and ``.depth()``).
    '''
    guided_tree = monte_carlo_tree.GuidedMonteCarloPlayTree(
        checkers_env, BOARD_SIZE, actor_critic_network, device
    )
    start = guided_tree.root_node
    return guided_tree.simulate(start)

In [22]:
# Evaluate the guided player over 100 games.
# NOTE(review): the saved output reports 89/100 draws with a mean of ~55 moves,
# close to the 56-move games logged during training -- check whether these
# "draws" are games cut off at a move limit rather than genuine draws.
build_stats(gmcts_play, 100)

Blacks:  6 Whites:  5 Draws:  89 Moves mean: 55.18
