In [1]:
import numpy as np
import tqdm

from attention.connect4_mcts.policy import Model as AttentionModel
from attention.connect4_mcts.mcts import MCTS as AttentionMCTS
from attention.connect4_mcts.players import MctsPlayer as AttentionPlayer

from no_attention.connect4_mcts.policy import Model as ConvolutionModel
from no_attention.connect4_mcts.players import MctsPlayer as ConvolutionPlayer
from no_attention.connect4_mcts.mcts import MCTS as ConvolutionMCTS

from attention.connect4_mcts.game import Game
from attention.connect4_mcts.game import GameResult

In [2]:
attention_model = AttentionModel(256, 1, 1, 'cpu')
attention_model.load('attention/model.pt')
convolution_model = ConvolutionModel(128, 10, 1, 'cpu')
convolution_model.load('no_attention/model.pt')

In [3]:
def play_game_attention_first(attention_model, convolution_model) -> float:
    game = Game()
    states = np.zeros((0, 4, 6, 7), np.float32)
    attention_mcts = AttentionMCTS(1)
    convolution_mcts = ConvolutionMCTS(1)
    i = 0
    while not game.is_terminal():
        if i % 2 == 0:
            probs, _wdl = attention_mcts.run(game, attention_model.policy_function, 180, states)
            move = np.argmax(probs)
        else:
            probs, _wdl = convolution_mcts.run(game, convolution_model.policy_function, 180)
            move = np.argmax(probs)
        states = np.append(states, game.get_state()[None, ...], axis=0)
        attention_mcts.make_move(move)
        convolution_mcts.make_move(move)
        game.make_move(move)
        i += 1
    if game.get_winner() == GameResult.RED_WINS:
        return 1
    if game.get_winner() == GameResult.DRAW:
        return .5
    if game.get_winner() == GameResult.RED_LOSES:
        return 0
    

In [4]:
def play_game_convolution_first(attention_model, convolution_model) -> float:
    game = Game()
    states = np.zeros((0, 4, 6, 7), np.float32)
    attention_mcts = AttentionMCTS(1)
    convolution_mcts = ConvolutionMCTS(1)
    i = 0
    while not game.is_terminal():
        if i % 2 == 1:
            probs, _wdl = attention_mcts.run(game, attention_model.policy_function, 10, states)
            move = np.argmax(probs)
        else:
            probs, _wdl = convolution_mcts.run(game, convolution_model.policy_function, 10)
            move = np.argmax(probs)
        states = np.append(states, game.get_state()[None, ...], axis=0)
        attention_mcts.make_move(move)
        convolution_mcts.make_move(move)
        game.make_move(move)
        i += 1
    if game.get_winner() == GameResult.RED_WINS:
        return 0
    if game.get_winner() == GameResult.DRAW:
        return .5
    if game.get_winner() == GameResult.RED_LOSES:
        return 1

In [5]:
score = 0
for _ in tqdm.trange(10):
    score += play_game_attention_first(attention_model, convolution_model)
    score += play_game_convolution_first(attention_model, convolution_model)

100%|██████████| 10/10 [09:06<00:00, 54.65s/it]


In [8]:
score

14