In [None]:
# Move the working directory two levels up so that the `alpha_one` imports
# below resolve. NOTE: the path is relative, so re-running this cell without
# restarting the kernel moves the working directory up by two more levels.
%cd ../..
# Enable the autoreload extension; mode 2 reloads all imported modules before
# each cell execution so that edits to project code are picked up.
%reload_ext autoreload
%autoreload 2

In [None]:
import pyspiel
import math
import numpy as np
import matplotlib.pyplot as plt
from alpha_one.model.model_manager import OpenSpielCheckpointManager
from alpha_one.train import MCTSConfig
from alpha_one.utils.state_to_value import state_to_value
from alpha_one.utils.mcts_II import initialize_bot_alphaone, ii_mcts_agent
from alpha_one.utils.determinized_mcts import initialize_bot, compute_mcts_policy
from alpha_one.game.information_set import InformationSetGenerator

In [None]:
# Load the OpenSpiel game that all following cells evaluate on.
game_name = "leduc_poker"
game = pyspiel.load_game(game_name)
# Show the maximum utility reported by the game, as a reference scale for the
# returns printed at the end of the notebook.
print(game.max_utility())

In [None]:
# Load the trained model used by the 'd-mcts_NN' player.
# The game name is taken from `game_name` (defined in the game-loading cell)
# instead of being hard-coded a second time, so the checkpoint always matches
# the game that is actually being played.
model_manager = OpenSpielCheckpointManager(game_name, 'LP-local-12')
# NOTE(review): index -1 presumably selects the most recent checkpoint of the
# run 'LP-local-12' -- confirm against OpenSpielCheckpointManager.
dmcts_model = model_manager.load_checkpoint(-1)

In [None]:
# Search budget and exploration constant (square root of two) that are passed
# as the first two positional arguments to MCTSConfig further down.
max_mcts_simulations = 100
UCT_C = math.sqrt(2.0)

In [None]:
# Search configuration shared by both determinized-MCTS players.
# NOTE(review): the meaning of the four positional arguments after the
# simulation budget (0, None, None, None) is not visible from this notebook --
# check the MCTSConfig definition before changing them.
dmcts_mcts_config = MCTSConfig(
    UCT_C,
    max_mcts_simulations,
    0,
    None,
    None,
    None,
    determinized_MCTS=True,
    omniscient_observer=True,
    use_reward_policy=True,
)

In [None]:
# Seat assignment: player type -> player id. The game loop only looks up the
# ids of players that actually have to act (ids >= 0), so a type mapped to a
# negative id (here 'random') is never selected.
player_setup = dict([
    ('d-mcts_NN', 1),
    ('d-mcts_random', 0),
    ('random', -11),
])
# Inverse lookup: player id -> player type. If two types shared an id, the
# one listed later would win.
player_setup_reverse = dict(zip(player_setup.values(), player_setup.keys()))

In [None]:
N_GAMES = 100     # number of evaluation games to play
N_ROLLOUTS = 100  # value passed as `n_rollouts` for the 'd-mcts_random' player


def sample_chance_action(state):
    """Sample a chance action from the game's own chance-outcome distribution.

    Uses ``state.chance_outcomes()`` (a list of ``(action, probability)``
    pairs) rather than a uniform draw over the legal actions, so the sampling
    stays correct for games whose chance nodes are not uniform.
    """
    actions, probabilities = zip(*state.chance_outcomes())
    return np.random.choice(actions, p=probabilities)


def select_player_action(state, information_set_generator):
    """Choose the action of the player to move, according to ``player_setup``.

    Args:
        state: current (non-chance, non-terminal) game state.
        information_set_generator: generator tracking the game history; it is
            handed to ``compute_mcts_policy`` for the MCTS based players.

    Returns:
        The chosen action. For the MCTS players this is the argmax of the
        policy returned by ``compute_mcts_policy`` (assumes the policy is
        indexed by action id -- TODO confirm).

    Raises:
        ValueError: if no known player type is assigned to the current player.
            Without this, an unhandled type would leave the state unchanged
            and the surrounding game loop would never terminate.
    """
    player_id = state.current_player()
    player_type = player_setup_reverse.get(player_id)

    if player_type == 'd-mcts_NN':
        policy = compute_mcts_policy(game, dmcts_model, state,
                                     information_set_generator,
                                     dmcts_mcts_config)
        return np.argmax(policy)

    if player_type == 'd-mcts_random':
        policy = compute_mcts_policy(game, dmcts_model, state,
                                     information_set_generator,
                                     dmcts_mcts_config,
                                     use_NN=False,
                                     n_rollouts=N_ROLLOUTS)
        return np.argmax(policy)

    if player_type == 'random':
        return np.random.choice(state.legal_actions())

    raise ValueError(
        f"No supported player type configured for player {player_id}: {player_type!r}")


# Play N_GAMES full games and collect the terminal returns of every game.
game_returns = []
for _ in range(N_GAMES):
    state = game.new_initial_state()
    information_set_generator = InformationSetGenerator(game)
    while not state.is_terminal():
        if state.is_chance_node():
            action = sample_chance_action(state)
        else:
            action = select_player_action(state, information_set_generator)

        # Bookkeeping is identical for every kind of actor: record the action,
        # advance the state, then record the resulting observation.
        information_set_generator.register_action(action)
        state.apply_action(action)
        information_set_generator.register_observation(state)

    game_returns.append(state.returns())
            
            

In [None]:
# Stack the per-game returns into an array with one row per game and average
# over the games, which yields one mean return per player.
game_returns = np.array(game_returns)
average_return = game_returns.mean(axis=0)

# Report the mean return next to the player type seated at ids 0 and 1.
print(
    "Average return:",
    "---------------",
    f"  {player_setup_reverse[0]}: {average_return[0]}",
    f"  {player_setup_reverse[1]}: {average_return[1]}",
    sep="\n",
)