In [None]:
# Change the working directory two levels up (presumably to the repository root so the
# project-local imports below resolve — confirm against the notebook's location).
# NOTE(review): the path is relative, so re-running this cell without a kernel restart
# moves up two more directories each time.
%cd ../..
# Enable automatic reloading of edited modules before each cell execution.
%reload_ext autoreload
%autoreload 2

In [None]:
import pyspiel
import math
import numpy as np
import matplotlib.pyplot as plt

from alpha_one.model.model_manager import OpenSpielCheckpointManager, OpenSpielModelManager, PolicyGradientModelManager
from alpha_one.model.agent import PolicyGradientAgent, MCTSAgent, DirectInferenceAgent
from alpha_one.model.config import OpenSpielModelConfig
from alpha_one.utils.mcts import initialize_bot, compute_mcts_policy, mcts_inference, MCTSConfig
from alpha_one.model.evaluation import EvaluationManager
from alpha_one.plots import PlotManager
from env import MODEL_SAVES_DIR

In [None]:
# Name of the game; used for pyspiel.load_game, the model managers, and the plot manager below.
game_name = "connect_four"

In [None]:
# Select which family of trained models to inspect. Both managers used to be constructed
# back to back, so the AlphaZero manager was always discarded; the choice is now explicit.
MODEL_TYPE = "policy_gradient"  # alternative: "alpha_zero"

if MODEL_TYPE == "alpha_zero":
    # Alpha Zero
    model_manager = OpenSpielModelManager(game_name, 'C4')
elif MODEL_TYPE == "policy_gradient":
    # Policy Gradient
    model_manager = PolicyGradientModelManager(game_name)
else:
    raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE!r}")

# List the stored runs so that `run_name` can be chosen from them.
print(model_manager.list_runs())

In [None]:
# Run whose checkpoints are inspected; passed to get_checkpoint_manager and PlotManager below.
run_name = "PG-6"

In [None]:
# Obtain the checkpoint manager of the selected run and show which checkpoints it holds.
checkpoint_manager = model_manager.get_checkpoint_manager(run_name)
available_checkpoints = checkpoint_manager.list_checkpoints()
print(available_checkpoints)

In [None]:
# Checkpoint loaded as the main `model` under inspection.
# NOTE: with the value 0, `model` is the same checkpoint as the `model_0` baseline loaded below.
checkpoint = 0

# 1. Load model

In [None]:
# Load the selected checkpoint (`model`) together with checkpoint 0 (`model_0`) for comparison.
model, model_0 = [
    checkpoint_manager.load_checkpoint(checkpoint_id) for checkpoint_id in (checkpoint, 0)
]

# 2. Setup game and agents

In [None]:
# Load the game via pyspiel; states for play and analysis are created from it below.
game = pyspiel.load_game(game_name)

In [None]:
# MCTS settings shared by all MCTS-based agents and inference calls in this notebook.
UCT_C = math.sqrt(2)
max_simulations = 100
temperature = 0
mcts_config = MCTSConfig(UCT_C, max_simulations, temperature)

# One MCTS agent for the selected checkpoint and one for checkpoint 0.
# (Alternative without tree search: agent = DirectInferenceAgent(model))
agent, agent_0 = [
    MCTSAgent.from_config(game, network, mcts_config) for network in (model, model_0)
]

In [None]:
# Policy Gradient: wrap both loaded models in policy-gradient agents.
# Running this cell rebinds `agent` and `agent_0`.
agent, agent_0 = [PolicyGradientAgent(network) for network in (model, model_0)]

# 3. Interactive play

In [None]:
# Use an MCTS agent built from the selected checkpoint for interactive play.
# `agent_mcts` is created here explicitly: it was otherwise only assigned inside the
# analysis loops further down, so this cell raised a NameError on a fresh kernel.
agent_mcts = MCTSAgent.from_config(game, model, mcts_config)
agent = agent_mcts

In [None]:
# Play one interactive game: a human (via stdin) against the currently selected `agent`.
state = game.new_initial_state()
player_id_model = np.random.choice(2)  # randomly decides which seat the computer takes
player_id_human = 1 - player_id_model
game_cancelled = False  # explicit flag, so the summary below never depends on `human_input`
print(f"Welcome to a game of {game_name} against the Computer (iteration {checkpoint}). Enter 'c' to cancel the game")
print(f"Player Human: {player_id_human}, Player model: {player_id_model}")
while not state.is_terminal():
    current_player_str = "Human" if state.current_player() == player_id_human else "Computer"
    print(f"Current player: {current_player_str}")
    print(state.observation_string())
    if state.current_player() == player_id_model:
        action, policy = agent.next_move(state)
        print(f"Computer policy: {policy}")
        print(f"Computer action: {action}")
    else:
        legal_actions = np.where(state.legal_actions_mask())[0]
        print(f"Possible actions: {legal_actions}")
        action = None
        # Keep asking until a legal action is entered or the game is cancelled;
        # invalid input no longer aborts the cell with an exception.
        while action is None:
            print(f"Your action: ")
            human_input = input().strip()
            if human_input == 'c':
                game_cancelled = True
                break
            try:
                candidate = int(human_input)
            except ValueError:
                print(f"'{human_input}' is not a number, please try again")
                continue
            if candidate not in legal_actions:
                print(f"Action {candidate} is not legal, please try again")
                continue
            action = candidate
        if game_cancelled:
            break
    state.apply_action(action)

if game_cancelled:
    print("Game was cancelled")
else:
    # Distinguish win, loss, and draw; a draw used to be reported as a Computer win.
    human_return = state.returns()[player_id_human]
    if human_return > 0:
        print("The winner is: Human")
    elif human_return < 0:
        print("The winner is: Computer")
    else:
        print("The game ended in a draw")
    print(state.observation_string(0))

# 4. Analysis of specific game scenarios

In [None]:
# Plot helper for this game/run; its save_current_plot method is used to store the figures below.
plot_manager = PlotManager(game_name, run_name)

## 4.1 Sure win

In [None]:
# Scenario 1: starting from the initial position, actions 3 and 2 are applied alternately,
# three times each, and the resulting position is printed.
state = game.new_initial_state()
for move in [3, 2] * 3:
    state.apply_action(move)
print(state.observation_string())

In [None]:
# Ask the selected agent for its move in the scenario position; the cell output shows
# the (action, policy) pair returned by next_move.
agent.next_move(state)

In [None]:
%%capture
correct_move_probabilities = dict()
prevent_win_probabilities = dict()
correct_move_probabilities_mcts = dict()
prevent_win_probabilities_mcts = dict()
for iteration in checkpoint_manager.list_checkpoints():
    model_tmp = checkpoint_manager.load_checkpoint(iteration)
    agent_mcts = MCTSAgent.from_config(game, model_tmp, mcts_config)
    agent_direct = DirectInferenceAgent(model_tmp
                                       )
    _, policy = agent_direct.next_move(state)
    _, policy_mcts = agent_mcts.next_move(state)
    
    correct_move_probabilities[iteration] = policy[3]
    prevent_win_probabilities[iteration] = policy[2]
    correct_move_probabilities_mcts[iteration] = policy_mcts[3]
    prevent_win_probabilities_mcts[iteration] = policy_mcts[2]

In [None]:
# Plot, per checkpoint, the probability of the winning move and of the blocking move,
# for direct inference (solid lines) and MCTS (dotted lines), then save the figure.
plt.figure(figsize=(12,4))
plt.title('Learned Policies for scenario 1 (Sure win)')
# Explicit cycle colors ('C0', 'C1') give each MCTS curve exactly the color of its
# direct-inference counterpart; the named 'blue'/'orange' did not match the default cycle.
plt.plot(list(correct_move_probabilities.keys()), list(correct_move_probabilities.values()), label='winning move', c='C0')
plt.plot(list(correct_move_probabilities_mcts.keys()), list(correct_move_probabilities_mcts.values()), label='winning move (MCTS)', linestyle=':', c='C0')
plt.plot(list(prevent_win_probabilities.keys()), list(prevent_win_probabilities.values()), label='prevent enemy win', c='C1')
plt.plot(list(prevent_win_probabilities_mcts.keys()), list(prevent_win_probabilities_mcts.values()), label='prevent enemy win (MCTS)', linestyle=':', c='C1')
# Axis labels so the figure is readable on its own.
plt.xlabel('Checkpoint')
plt.ylabel('Policy probability')
plt.legend()

plot_manager.save_current_plot("policies_scenario_1.pdf")
plt.show()

## 4.2 Prevent Sure win

In [None]:
# Scenario 2: starting from the initial position, apply the action sequence
# 1, 2, 3, 2, 3, 2 and print the resulting position.
state = game.new_initial_state()
for move in (1, 2, 3, 2, 3, 2):
    state.apply_action(move)
print(state.observation_string())

In [None]:
# Ask the selected agent for its move in the scenario 2 position; the cell output shows
# the (action, policy) pair returned by next_move.
agent.next_move(state)

In [None]:
# Direct model inference (no tree search) on the scenario position, given as a batch of one
# observation tensor and one legal-actions mask.
model.inference([state.observation_tensor()], [state.legal_actions_mask()])

In [None]:
# MCTS inference with the checkpoint-0 model. Uses the shared `max_simulations` setting
# instead of a hard-coded 100, so it stays consistent with `mcts_config` if that is tuned.
mcts_inference(game, model_0, state, uct_c=UCT_C, max_simulations=max_simulations, temperature=temperature)

In [None]:
%%capture
prevent_win_probabilities = dict()
prevent_win_probabilities_mcts = dict()
for iteration in checkpoint_manager.list_checkpoints():
    model_tmp = checkpoint_manager.load_checkpoint(iteration)
    
    agent_mcts = MCTSAgent.from_config(game, model_tmp, mcts_config)
    agent_direct = DirectInferenceAgent(model_tmp
                                       )
    _, policy = agent_direct.next_move(state)
    _, policy_mcts = agent_mcts.next_move(state)
    
    prevent_win_probabilities[iteration] = policy[2]
    prevent_win_probabilities_mcts[iteration] = policy_mcts[2]

In [None]:
# Plot, per checkpoint, the probability of the blocking move for direct inference (solid)
# and MCTS (dotted), then save the figure.
plt.figure(figsize=(12, 4))
plt.title('Learned Policies for scenario 2 (Prevent Sure win)')
# The explicit cycle color 'C0' gives both curves exactly the same color; the named 'blue'
# did not match the default first cycle color.
plt.plot(list(prevent_win_probabilities.keys()), list(prevent_win_probabilities.values()), label='Prevent enemy win', c='C0')
plt.plot(list(prevent_win_probabilities_mcts.keys()), list(prevent_win_probabilities_mcts.values()), label='Prevent enemy win (MCTS)', linestyle=':', c='C0')
# Axis labels so the figure is readable on its own.
plt.xlabel('Checkpoint')
plt.ylabel('Policy probability')
plt.legend()

plot_manager.save_current_plot("policies_scenario_2.pdf")
plt.show()

## 4.3 Prevent Sure win next turn 

In [None]:
# Scenario 3: starting from the initial position, apply the action sequence 3, 3, 2
# and print the resulting position.
state = game.new_initial_state()
for move in (3, 3, 2):
    state.apply_action(move)
print(state.observation_string())

In [None]:
# Ask the selected agent for its move in the scenario 3 position; the cell output shows
# the (action, policy) pair returned by next_move.
agent.next_move(state)

In [None]:
%%capture
correct_move_left_probabilities = dict()
correct_move_right_probabilities = dict()
correct_move_left_probabilities_mcts = dict()
correct_move_right_probabilities_mcts = dict()
for iteration in checkpoint_manager.list_checkpoints():
    model_tmp = checkpoint_manager.load_checkpoint(iteration)
    
    agent_mcts = MCTSAgent.from_config(game, model_tmp, mcts_config)
    agent_direct = DirectInferenceAgent(model_tmp
                                       )
    _, policy = agent_direct.next_move(state)
    _, policy_mcts = agent_mcts.next_move(state)
    
    correct_move_left_probabilities[iteration] = policy[1]
    correct_move_left_probabilities_mcts[iteration] = policy_mcts[1]
    correct_move_right_probabilities[iteration] = policy[4]
    correct_move_right_probabilities_mcts[iteration] = policy_mcts[4]

In [None]:
# Plot, per checkpoint, the probability of the two correct replies for direct inference
# (solid lines) and MCTS (dotted lines), then save the figure.
plt.figure(figsize=(12, 4))
plt.title('Learned Policies for scenario 3 (Prevent sure win next turn)')
# Explicit cycle colors ('C0', 'C1') give each MCTS curve exactly the color of its
# direct-inference counterpart; the named 'blue'/'orange' did not match the default cycle.
plt.plot(list(correct_move_left_probabilities.keys()), list(correct_move_left_probabilities.values()), label='Correct move left', c='C0')
plt.plot(list(correct_move_left_probabilities_mcts.keys()), list(correct_move_left_probabilities_mcts.values()), label='Correct move left (MCTS)', linestyle=':', c='C0')
plt.plot(list(correct_move_right_probabilities.keys()), list(correct_move_right_probabilities.values()), label='Correct move right', c='C1')
plt.plot(list(correct_move_right_probabilities_mcts.keys()), list(correct_move_right_probabilities_mcts.values()), label='Correct move right (MCTS)', linestyle=':', c='C1')
# Axis labels so the figure is readable on its own.
plt.xlabel('Checkpoint')
plt.ylabel('Policy probability')
plt.legend()

plot_manager.save_current_plot("policies_scenario_3.pdf")
plt.show()

### 4.3.1 Direct Inference

In [None]:
# Direct inference (no tree search) with the selected checkpoint's model on the scenario position.
model.inference([state.observation_tensor()], [state.legal_actions_mask()])

In [None]:
# Same direct inference with `model_0` (loaded from checkpoint 0 above) for comparison.
model_0.inference([state.observation_tensor()], [state.legal_actions_mask()])

### 4.3.2 Using MCTS

In [None]:
# MCTS inference on the scenario position with the selected checkpoint's model,
# using the shared UCT_C / max_simulations / temperature settings.
mcts_inference(game, model, state, uct_c=UCT_C, max_simulations=max_simulations, temperature=temperature)

In [None]:
# MCTS inference with `model_0` for comparison. Uses the shared `max_simulations` setting
# instead of a hard-coded 100, so both MCTS calls in this section always run the same budget.
mcts_inference(game, model_0, state, uct_c=UCT_C, max_simulations=max_simulations, temperature=temperature)

# 5. Play against previous generations

In [None]:
# Evaluation helper used to compare two models with the shared MCTS configuration.
# NOTE(review): the literal 100 is presumably the number of evaluation games — confirm
# against EvaluationManager's signature.
evaluation_manager = EvaluationManager(game, 100, mcts_config)

In [None]:
# Compare a late checkpoint against checkpoint 0.
# Checkpoints are loaded through `checkpoint_manager`, as everywhere else in this notebook;
# the previous code called load_checkpoint on `model_manager`, which only hands out
# checkpoint managers. Distinct names are used so that `model_0` (checkpoint 0, used by
# earlier cells) is not silently replaced by checkpoint 339.
# NOTE(review): confirm that checkpoint 339 exists for the selected run.
model_trained = checkpoint_manager.load_checkpoint(339)
model_baseline = checkpoint_manager.load_checkpoint(0)
evaluation_results = evaluation_manager.compare_models(model_trained, model_baseline)

In [None]:
# Report the evaluation outcome as a percentage.
# NOTE(review): this prints 1 - evaluation_results[0]; confirm against
# EvaluationManager.compare_models that this complement really is the trained model's
# win rate (e.g. how draws are counted and which model index 0 refers to).
print(f"Trained model won {1 - evaluation_results[0]:0.2%} of the games")