In [None]:
%cd ../..
%reload_ext autoreload
%autoreload 2

In [None]:
import pyspiel
import math
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from alpha_one.model.model_manager import OpenSpielCheckpointManager, AlphaOneCheckpointManager
from alpha_one.model.agent import IIGMCTSAgent, DMCTSAgent, OmniscientAgent, DirectInferenceAgent, RandomAgent
from alpha_one.train import MCTSConfig
from alpha_one.utils.state_to_value import state_to_value
from alpha_one.utils.mcts_II import initialize_bot_alphaone, ii_mcts_agent, IIGMCTSConfig
from alpha_one.utils.play import GameMachine
from alpha_one.utils.determinized_mcts import initialize_bot, compute_mcts_policy
from alpha_one.game.information_set import InformationSetGenerator

# 1. Game Setup

In [None]:
# OpenSpiel game under evaluation. `game_name` is also passed to the
# checkpoint managers and to `state_to_value` further below.
game_name = "leduc_poker"
game = pyspiel.load_game(game_name)


# 2. Agent Setup

In [None]:
# Search settings passed positionally to every MCTS config built below.
max_mcts_simulations = 100  # presumably the number of MCTS simulations per move
UCT_C = math.sqrt(2)        # UCT exploration constant (sqrt(2) is the textbook UCB1 value)

## 2.1. AlphaOne

In [None]:
# Training run whose checkpoints are loaded by AlphaOneCheckpointManager below.
run_name_alpha_one = "LP-local-28"

In [None]:
# AlphaOne-specific settings, forwarded as keyword arguments to IIGMCTSConfig below.
n_previous_observations, optimism = 3, 0.1

In [None]:
model_manager_alpha_one = AlphaOneCheckpointManager(game_name, run_name_alpha_one)

# Each AlphaOne checkpoint unpacks into an (observation model, game model) pair.
# NOTE(review): -1 presumably selects the latest checkpoint and 0 the initial
# (untrained) one — confirm against AlphaOneCheckpointManager.load_checkpoint.
observation_model, game_model = model_manager_alpha_one.load_checkpoint(-1)
observation_model_untrained, game_model_untrained = model_manager_alpha_one.load_checkpoint(0)

In [None]:
# One config shared by the trained and untrained AlphaOne agents. The
# positional arguments match the D-MCTS MCTSConfig further below; the
# keyword arguments switch on the AlphaOne-specific behaviour.
alphaone_mcts_config = IIGMCTSConfig(
    UCT_C, max_mcts_simulations, 0, None, None, None,
    alpha_one=True,
    state_to_value=state_to_value(game_name),
    # Was misspelled `use_reward_polic`. MCTSConfig (omniscient agent below)
    # takes `use_reward_policy`.
    # NOTE(review): confirm IIGMCTSConfig uses the same keyword name.
    use_reward_policy=True,
    n_previous_observations=n_previous_observations,
    optimism=optimism,
)

# Same search settings; the agents differ only in which checkpoint's models they use.
alpha_one_agent = IIGMCTSAgent.from_config(
    game, observation_model, game_model, alphaone_mcts_config
)
untrained_alpha_one_agent = IIGMCTSAgent.from_config(
    game, observation_model_untrained, game_model_untrained, alphaone_mcts_config
)

## 2.2. D-MCTS

In [None]:
# Training run whose checkpoints back the D-MCTS agents below.
run_name_d_mcts = "LP-local-6"

In [None]:
model_manager_dmcts = OpenSpielCheckpointManager(game_name, run_name_d_mcts)

# NOTE(review): -1 presumably loads the latest checkpoint and 0 the initial
# (untrained) one — confirm against OpenSpielCheckpointManager.load_checkpoint.
dmcts_model = model_manager_dmcts.load_checkpoint(-1)
dmcts_model_untrained = model_manager_dmcts.load_checkpoint(0)

In [None]:
# Determinized-MCTS search settings, shared by the trained and untrained agents.
dmcts_mcts_config = MCTSConfig(
    UCT_C,
    max_mcts_simulations,
    0,
    None,
    None,
    None,
    determinized_MCTS=True,
    omniscient_observer=True,
)

# Build both agents from the same config, trained checkpoint first.
d_mcts_agent, untrained_d_mcts_agent = (
    DMCTSAgent(checkpoint_model, dmcts_mcts_config)
    for checkpoint_model in (dmcts_model, dmcts_model_untrained)
)

## 2.3. Omniscient Agent

In [None]:
# Training run for the omniscient agent. This is the same run name as
# `run_name_d_mcts`, so the same run's checkpoints are loaded again below.
run_name_omniscient = "LP-local-6"

In [None]:
model_manager_omniscient = OpenSpielCheckpointManager(game_name, run_name_omniscient)

# Only one checkpoint is loaded here (-1, presumably the latest). The
# "untrained" omniscient agent below is built without any model rather
# than from checkpoint 0.
omniscient_model = model_manager_omniscient.load_checkpoint(-1)

In [None]:
# Omniscient MCTS settings: the same positional search parameters as the
# other configs, plus the reward policy and omniscient observer.
omniscient_mcts_config = MCTSConfig(
    UCT_C,
    max_mcts_simulations,
    0,
    use_reward_policy=True,
    omniscient_observer=True,
)

# Both agents share the config; only the trained one is given the loaded model.
omniscient_agent_untrained = OmniscientAgent(game, omniscient_mcts_config)
omniscient_agent = OmniscientAgent(game, omniscient_mcts_config, model=omniscient_model)

## 2.4. Blind Agent

In [None]:
# Training run whose checkpoint backs the DirectInferenceAgent below.
run_name_blind = "LP-local-6-blind-1"

In [None]:
model_manager_blind = OpenSpielCheckpointManager(game_name, run_name_blind)

# NOTE(review): -1 presumably selects the latest checkpoint — confirm.
blind_model = model_manager_blind.load_checkpoint(-1)

In [None]:
# Acts from the blind model alone; unlike the other agents, no MCTS config is passed.
blind_agent = DirectInferenceAgent(blind_model)

## 2.5. Random Agent

In [None]:
# Baseline agent: constructed from the game only, with no model or config.
random_agent = RandomAgent(game)

# 3. Player Setup

Available agents (variables defined above):
 - `alpha_one_agent`
 - `untrained_alpha_one_agent`
 - `d_mcts_agent`
 - `untrained_d_mcts_agent`
 - `omniscient_agent`
 - `omniscient_agent_untrained`
 - `blind_agent`
 - `random_agent`

In [None]:
# Seat assignment: list position == player id used by the competition loop.
# Swap in any of the agents listed above to change the match-up.
player_setup = dict(enumerate([
    omniscient_agent,            # player 0
    omniscient_agent_untrained,  # player 1
]))

# 4. Competition between 2 players

In [None]:
# Wraps the running game. The loop below uses it to start games, query the
# current player/state, apply actions and read the final rewards.
game_machine = GameMachine(game)

In [None]:
# Play `n_games` games between the agents in `player_setup` and collect the
# per-player rewards of each game in `game_returns`. For IIGMCTSAgent players
# we also count how often the state they guessed matches the true game state.
n_games = 100

# The original loop always replaced the agent's returned action with the
# argmax of its policy (greedy play). That remains the default. Set this to
# False to play the action each agent itself returned. This matters e.g. for
# RandomAgent: if it reports a uniform policy (unverified), argmax would
# always pick the lowest-index action.
greedy_actions = True

correct_guess = 0
incorrect_guess = 0
game_returns = []
for _ in tqdm(range(n_games)):
    game_machine.new_game()

    while not game_machine.is_finished():
        player = game_machine.current_player()
        agent = player_setup[player]

        # Information-set agents are handed an InformationSetGenerator;
        # all other agents are handed the game state directly.
        if agent.is_information_set_agent():
            action, policy = agent.next_move(game_machine.get_information_set_generator())
        else:
            action, policy = agent.next_move(game_machine.get_state())

        if greedy_actions:
            action = np.argmax(policy)

        # Score the agent's guess of the hidden state (compared via string form).
        if isinstance(agent, IIGMCTSAgent):
            guessed_state = agent.get_last_guessed_state()
            if str(guessed_state) == str(game_machine.state):
                correct_guess += 1
            else:
                incorrect_guess += 1

        game_machine.play_action(action)

    game_returns.append(game_machine.get_rewards())
            
            

In [None]:
# Summarise the competition: mean return for every seated player and, if any
# IIGMCTSAgent played, the fraction of moves where its guessed state was right.
game_returns = np.asarray(game_returns)
print("Average return:")
print("---------------")
if game_returns.size == 0:
    # Zero games played: averaging would yield a 0-d NaN and indexing it would fail.
    print("  (no games played)")
else:
    # Rows are games, columns are players, so axis 0 averages over games.
    average_return = game_returns.mean(axis=0)
    for player in sorted(player_setup):
        print(f"  {type(player_setup[player]).__name__}: {average_return[player]}")

n_guesses = correct_guess + incorrect_guess
if n_guesses > 0:
    print(f" correct guess probability: {correct_guess/n_guesses:0.2%}")