In [None]:
# Move the working directory two levels up before importing anything.
# Presumably this is the project root from which `alpha_one` and `env` are importable — TODO confirm.
# NOTE: the path is relative, so re-running this cell moves up two more levels;
# restart the kernel before running the notebook again.
%cd ../..

In [None]:
import pyspiel
import math
import numpy as np

from alpha_one.model.model_manager import OpenSpielModelManager
from alpha_one.model.config import OpenSpielModelConfig
from alpha_one.utils.mcts import initialize_bot, compute_mcts_policy, mcts_inference
from env import MODEL_SAVES_DIR

In [None]:
# Which training run to inspect. These three values are combined into the
# model-manager key and the checkpoint paths used further down.
game_name = "connect_four"  # first path component of the run
run_name = "C4-3"  # second path component of the run
iteration = 99  # index of the checkpoint to evaluate

# 1. Load model

In [None]:
# The manager is addressed by "<game_name>/<run_name>".
model_manager = OpenSpielModelManager("/".join((game_name, run_name)))

In [None]:
# Legacy way to load model, use this for now
# The game is loaded from `game_name` (not a hard-coded string) so it always
# matches the run whose checkpoints are read below.
game = pyspiel.load_game(game_name)

# NOTE(review): the hyper-parameters are passed positionally; confirm their
# meaning against OpenSpielModelConfig and that they match the trained run.
config = OpenSpielModelConfig(game, "mlp", 64, 4, 1e-5, 5e-4)

# Both checkpoints live in the same run directory.
checkpoint_dir = f"{MODEL_SAVES_DIR}/{game_name}/{run_name}"

# `model` receives the checkpoint of `iteration`, `model_0` the checkpoint 0
# of the same run, so the two can be compared in the cells below.
model = model_manager.build_model(config)
model_0 = model_manager.build_model(config)
model.load_checkpoint(f"{checkpoint_dir}/checkpoint-{iteration}")
model_0.load_checkpoint(f"{checkpoint_dir}/checkpoint-0")

In [None]:
# Future way to load model

In [None]:
# Load both models through the model manager: `iteration` for `model`, 0 for `model_0`
# (presumably these are the saved iterations of the run — TODO confirm).
# NOTE: this rebinds `model` and `model_0`; if the legacy loading cell was run
# as well, the objects it created are replaced here.
model = model_manager.load_model(iteration)
model_0 = model_manager.load_model(0)

# 2. Set up game

In [None]:
# Load the game named by `game_name` so this cell stays consistent with the
# run configuration instead of repeating the game name as a literal.
game = pyspiel.load_game(game_name)

In [None]:
# Build a small example position by applying actions 3, 3, 2 to a fresh game,
# then show the resulting observation.
state = game.new_initial_state()
for opening_action in (3, 3, 2):
    state.apply_action(opening_action)
print(state.observation_string())

## 2.1 Direct inference

In [None]:
# Query the network directly (no search) for the example position.
# Observation and legal-action mask are each wrapped in a one-element list.
observation_batch = [state.observation_tensor()]
legal_mask_batch = [state.legal_actions_mask()]
model.inference(observation_batch, legal_mask_batch)

In [None]:
# Same direct query as for `model`, answered by `model_0` for comparison.
observation_batch = [state.observation_tensor()]
legal_mask_batch = [state.legal_actions_mask()]
model_0.inference(observation_batch, legal_mask_batch)

## 2.2 Using MCTS

In [None]:
# Settings shared by every `mcts_inference` call in this notebook.
UCT_C = math.sqrt(2.0)  # passed as `uct_c`
max_simulations = 100  # passed as `max_simulations`
temperature = 1  # passed as `temperature`

In [None]:
# MCTS-based inference for the example position, guided by `model`.
mcts_inference(
    game,
    model,
    state,
    uct_c=UCT_C,
    max_simulations=max_simulations,
    temperature=temperature,
)

In [None]:
# MCTS-based inference for the example position, guided by `model_0`.
mcts_inference(
    game,
    model_0,
    state,
    uct_c=UCT_C,
    max_simulations=max_simulations,
    temperature=temperature,
)

# 3. Interactive play

In [None]:
# How the computer chooses its move in the interactive game:
#   "direct" - sample from the policy returned by model.inference
#   "mcts"   - sample from the policy returned by mcts_inference
model_strategy = "mcts"

In [None]:
# Play one interactive game: the human types actions, the computer samples its
# move from a policy obtained according to `model_strategy`.
if model_strategy not in ("direct", "mcts"):
    # Fail early; otherwise `policy` would be undefined (or stale from an
    # earlier run of this cell) once the computer has to move.
    raise ValueError(f"Unknown model_strategy {model_strategy!r}, expected 'direct' or 'mcts'")

state = game.new_initial_state()
# Randomly assign the computer to player 0 or 1; the human takes the other id.
player_id_model = np.random.choice(2)
player_id_human = 1 - player_id_model
# Explicit flag rather than inspecting `human_input` after the loop: that name
# may be undefined or left over from a previous run of this cell.
cancelled = False
print(f"Welcome to a game of {game_name} against the Computer (iteration {iteration}). Enter 'c' to cancel the game")
print(f"Player Human: {player_id_human}, Player model: {player_id_model}")
while not state.is_terminal():
    current_player_str = "Human" if state.current_player() == player_id_human else "Computer"
    print(f"Current player: {current_player_str}")
    print(state.observation_string())
    if state.current_player() == player_id_model:
        if model_strategy == 'direct':
            # inference returns a pair; only the policy part is needed, and
            # the batch holds a single state, hence index 0.
            _, policy = model.inference([state.observation_tensor()], [state.legal_actions_mask()])
            policy = policy[0]
        else:
            # Use `model` (the checkpoint announced in the banner), matching
            # the 'direct' branch, instead of the iteration-0 model.
            policy = mcts_inference(game, model, state, uct_c=UCT_C, max_simulations=max_simulations, temperature=temperature)
        print(f"Computer policy: {policy}")
        action = np.random.choice(len(policy), p=policy)
        print(f"Computer action: {action}")
    else:
        legal_actions = np.where(state.legal_actions_mask())[0]
        print(f"Possible actions: {legal_actions}")
        action = None
        # Keep asking until the input is a legal action or the game is cancelled,
        # so a typo does not abort the whole game with an exception.
        while action is None:
            print("Your action: ")
            human_input = input().strip()
            if human_input == 'c':
                cancelled = True
                break
            try:
                candidate = int(human_input)
            except ValueError:
                print(f"Invalid input '{human_input}': enter one of the possible actions or 'c' to cancel")
                continue
            if candidate not in legal_actions:
                print(f"Action {candidate} is not possible, choose one of {legal_actions}")
                continue
            action = candidate
        if cancelled:
            break
    state.apply_action(action)

if cancelled:
    print("Game was cancelled")
else:
    returns = state.returns()
    human_return = returns[player_id_human]
    model_return = returns[player_id_model]
    # Compare both returns so that equal returns are reported as a draw
    # instead of a computer win.
    if human_return > model_return:
        print("The winner is: Human")
    elif human_return < model_return:
        print("The winner is: Computer")
    else:
        print("The game ended in a draw")
    # A player id is passed explicitly for the final board, as in the original cell.
    print(state.observation_string(0))