This notebook evaluates the trained Connect Four networks by playing them against each other and against a random baseline.

In [1]:
from Connect4Game import C4Game

# Standard Connect Four board: 7 columns wide, 6 high, four in a row wins.
g = C4Game(
    width=7,
    height=6,
    win_length=4,
)

Set up a random baseline agent that gives every action (column) the same weight, i.e. a uniform policy.

In [26]:
class RandomAgent():
    """Baseline opponent that scores every column equally.

    Exposes the same two entry points the notebook uses for players:
    `forward` returns a (policy, value) pair and `getActionProb` returns the
    policy as a NumPy array. The policy is all ones (not normalized) and the
    value is always 0. The board is never inspected.
    """

    def __init__(self, g):
        # Game object; only its `width` (number of actions) is read.
        self.g = g

    def forward(self, board, temp = 1):
        """Return (float32 tensor of ones, one per column, 0); args are ignored."""
        n_actions = self.g.width
        uniform_scores = torch.ones(n_actions, dtype=torch.float32)
        return uniform_scores, 0

    def getActionProb(self, board, temp = 1):
        """Return the `forward` policy as a NumPy array; `temp` is ignored."""
        policy, _value = self.forward(board)
        return policy.detach().numpy()


Load all trained networks from the `checkpoint` folder.

In [27]:
import os
from C4_net import C4_net
import torch
from C4_es import *
folder = 'checkpoint'

# Checkpoint file (inside `folder`) for each trained network.
filenames = {
    'a2c' : 'a2c_C4_check.pth.tar',
    'az_es' : 'AZ_ES_C4_check.pth.tar',
    'az_grad' : 'AZ_C4_check.pth.tar',
    'cem' : 'cem_C4_check.pth.tar',
    'oneone' : 'oneone_C4_check.pth.tar'
}

# Player labels in play order. Derived from `filenames` so the labels used for
# the win tables cannot drift out of sync with the players actually loaded.
names = list(filenames) + ['random']


def _extract_state_dict(checkpoint):
    """Return the model weights stored in a loaded checkpoint.

    The AZ_* checkpoint files wrap the weights as {'state_dict': ...}; the
    other files are a bare state dict. Inspecting the object instead of the
    filename keeps this working if files are renamed.
    """
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        return checkpoint['state_dict']
    return checkpoint


players = {}
for name, fname in filenames.items():
    loc = os.path.join(folder, fname)
    players[name] = C4_net(g)
    # map_location='cpu' lets checkpoints saved on a GPU load on a CPU-only
    # machine; load_state_dict copies the values onto the model's own tensors.
    checkpoint = torch.load(loc, map_location='cpu')
    players[name].load_state_dict(_extract_state_dict(checkpoint))
    # NOTE(review): models are left in whatever mode C4_net(g) constructs them
    # in; if C4_net uses dropout/batch-norm, consider .eval() here — confirm.

players['random'] = RandomAgent(g)

# The win tables label rows/columns in this order; make sure it matches.
assert list(players) == names

In [28]:
# Games played per pairing. The original experiments used 100; a value of 1
# keeps this run fast, but one game per pairing is far too few to rank agents.
TOTAL_GAMES: int = 1

Play a round robin using the networks directly (no tree search): every player faces every player, including itself.

In [30]:
# wins[i][j] = wins recorded for player i (moving as player1) against player j.
# Rows and columns both follow the insertion order of `players`.
wins = []
for p1_name, player1 in players.items():
    print(p1_name)
    player1_wins = []
    for player2 in players.values():
        win_count = play_games(g, total = TOTAL_GAMES, player1 = player1, player2 = player2, temp = 1)
        player1_wins.append(win_count)
        print(win_count, end = ', ')
    print()
    wins.append(player1_wins)


a2c
0, 1, 1, 0, 1, 1, 
az_es
1, 0, 1, 0, 1, 1, 
az_grad
0, 0, 1, 1, 0, 0, 
cem
0, 0, 0, 0, 1, 0, 
oneone
1, 0, 1, 0, 1, 0, 
random
1, 1, 0, 1, 1, 1, 


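Store the win counts as a CSV file. After the transpose, each column is the first player and each row is the opponent it faced.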

In [31]:
import pandas as pd

# Label rows/columns from `players` (the dict the match loop iterated over), so
# labels are guaranteed to line up with the rows and columns of `wins`.
labels = list(players)

# `wins` is indexed [player1][player2]; after the transpose each column is the
# first player and each row is the opponent it faced. Named axes make that
# orientation explicit in both the display and the CSV header.
nnet_wins = (
    pd.DataFrame(wins, columns = labels, index = labels)
    .transpose()
    .rename_axis(index = 'player2', columns = 'player1')
)

nnet_wins.to_csv('nnet_wins.csv')

nnet_wins

Unnamed: 0,a2c,az_es,az_grad,cem,oneone,random
a2c,0,1,0,0,1,1
az_es,1,0,0,0,0,1
az_grad,1,1,1,0,1,0
cem,0,0,1,0,0,1
oneone,1,1,0,1,1,1
random,1,1,0,0,0,1


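Play the same round robin, but with each network guiding a Monte Carlo tree search (MCTS). The `random` player has no network, so under MCTS it uses whatever network `NNetWrapper` creates by default rather than a uniform policy.

This can take a while.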

In [37]:
from mcts_c4 import MCTS
from C4_net import NNetWrapper as wrapper

args = dotdict({
    'numEps': 5,        # Self-play games per training iteration (training setting; presumably unused here).
    'numMCTSSims': 20,  # Number of MCTS simulations per move.
    'cpuct': 1,         # Exploration constant for MCTS.
    'batch_size' : 8,  # number of samples to take for AZ-ES, N in paper
    'elite_size' : 4,  # elite size for AZ-ES, K in paper
})


def _mcts_player(player):
    """Build an MCTS player around `player` using a fresh NNetWrapper.

    A RandomAgent has no network to plug in, so its wrapper keeps the network
    that NNetWrapper(g) constructs by default (presumably an untrained
    C4_net — confirm). The "random" MCTS player is therefore not a uniform
    policy.
    """
    nnet_wrapper = wrapper(g)
    if not isinstance(player, RandomAgent):
        nnet_wrapper.nnet = player
    return MCTS(g, nnet_wrapper, args)


# wins[i][j] = wins recorded for player i (as player1) against player j, with
# both players searching via MCTS.
wins = []
for p1_name, player1 in players.items():
    print(p1_name)
    player1_wins = []
    for player2 in players.values():
        # New wrappers and MCTS objects for every pairing, so each match starts
        # from freshly constructed search objects.
        mcts1 = _mcts_player(player1)
        mcts2 = _mcts_player(player2)

        win_count = play_games(g, total = TOTAL_GAMES, player1 = mcts1, player2 = mcts2, temp = 1)

        print(win_count, end = ', ')
        player1_wins.append(win_count)

    print()
    wins.append(player1_wins)

a2c
0, 0, 0, 1, 1, 1, 
az_es
1, 1, 1, 0, 1, 1, 
az_grad
0, 1, 0, 1, 1, 1, 
cem
0, 0, 1, 0, 0, 0, 
oneone
0, 0, 1, 0, 0, 0, 
random
1, 0, 1, 1, 0, 1, 


In [38]:
# Label rows/columns from `players` (the dict the MCTS loop iterated over), so
# labels are guaranteed to line up with the rows and columns of `wins`.
labels = list(players)

# `wins` is indexed [player1][player2]; after the transpose each column is the
# first player and each row is the opponent it faced. Named axes make that
# orientation explicit in both the display and the CSV header.
mcts_wins = (
    pd.DataFrame(wins, columns = labels, index = labels)
    .transpose()
    .rename_axis(index = 'player2', columns = 'player1')
)

mcts_wins.to_csv('mcts_wins.csv')

mcts_wins

Unnamed: 0,a2c,az_es,az_grad,cem,oneone,random
a2c,0,1,0,0,0,1
az_es,0,1,1,0,0,0
az_grad,0,1,0,1,1,1
cem,1,0,1,0,0,1
oneone,1,1,1,0,0,0
random,1,1,1,0,0,1
