In [None]:
# Change the working directory to the parent folder before anything is imported.
# NOTE(review): presumably this makes the project-local `alpha_one` package importable
# from a notebook stored one level below the repository root — confirm.
# `%cd ..` is relative, so re-running this cell without restarting the kernel moves up again.
%cd ..
# Enable autoreload (mode 2) so edited modules are re-imported before each cell executes.
%reload_ext autoreload
%autoreload 2

In [None]:
import pyspiel
from alpha_one.game.information_set import InformationSetGenerator
import numpy as np

from open_spiel.python.algorithms import mcts
from open_spiel.python.algorithms.alpha_zero import evaluator as evaluator_lib
from open_spiel.python.algorithms.alpha_zero import model as model_lib
from open_spiel.python.algorithms.mcts import SearchNode

from alpha_one.alg.imperfect_information import BasicImperfectInformationMCTSEvaluator, BasicOmniscientMCTSEvaluator
from alpha_one.alg.mcts import ImperfectInformationMCTSBot
from alpha_one.game.observer import OmniscientObserver
from alpha_one.model.model_manager import OpenSpielCheckpointManager

In [None]:
def build_model(game, model_type, nn_width, nn_depth, learning_rate, weight_decay, model_saves_path):
    """Create a network for ``game`` through ``model_lib.Model.build_model``.

    The input shape and the number of outputs are read from the game
    (``observation_tensor_shape()`` and ``num_distinct_actions()``). All other
    arguments are forwarded unchanged; ``model_saves_path`` is handed over as
    the ``path`` keyword.

    Returns:
        Whatever ``model_lib.Model.build_model`` returns.
    """
    input_shape = game.observation_tensor_shape()
    output_size = game.num_distinct_actions()
    return model_lib.Model.build_model(
        model_type,
        input_shape,
        output_size,
        nn_width=nn_width,
        nn_depth=nn_depth,
        weight_decay=weight_decay,
        learning_rate=learning_rate,
        path=model_saves_path,
    )

In [None]:
def initialize_bot(game, model, uct_c, max_simulations, policy_epsilon, policy_alpha, n_rollouts=100):
    """Build an ``mcts.MCTSBot`` that evaluates leaves with random rollouts.

    Args:
        game: Game object handed to ``mcts.MCTSBot``.
        model: Accepted for interface compatibility only; it is not used, because
            the bot is built with a ``RandomRolloutEvaluator``.
        uct_c: Exploration constant forwarded to the bot.
        max_simulations: Simulation budget forwarded to the bot.
        policy_epsilon: First element of the Dirichlet-noise pair, or ``None``.
        policy_alpha: Second element of the Dirichlet-noise pair, or ``None``.
        n_rollouts: Rollouts per evaluation used by the random-rollout evaluator.
            Defaults to 100.

    Returns:
        The configured ``mcts.MCTSBot`` (PUCT child selection, ``solve=False``,
        ``verbose=False``).
    """
    # Noise is only enabled when both of its parameters are given.
    # Identity comparison is the idiomatic (and safe) way to test for None.
    if policy_epsilon is None or policy_alpha is None:
        noise = None
    else:
        noise = (policy_epsilon, policy_alpha)

    evaluator = mcts.RandomRolloutEvaluator(n_rollouts=n_rollouts)

    return mcts.MCTSBot(
        game,
        uct_c,
        max_simulations,
        evaluator,
        solve=False,
        dirichlet_noise=noise,
        child_selection_fn=mcts.SearchNode.puct_value,
        verbose=False)

In [None]:
# Game selection and network hyperparameters.
game_name = "leduc_poker"
# Derived from game_name so that switching games cannot leave the save path
# pointing at a different game's directory.
model_saves_path = f'../model_saves/{game_name}'
nn_width = 10
nn_depth = 5
learning_rate = 0.001
weight_decay = 0.0001

model_type = 'mlp'

game = pyspiel.load_game(game_name)

# The model is built for the selected game's observation shape and action count.
model = build_model(game, model_type, nn_width, nn_depth, learning_rate, weight_decay, model_saves_path)

In [None]:
# Search settings shared by the agents defined in the following cells.
uct_c = 5                # exploration constant handed to every MCTS bot
max_simulations = 100    # `max_simulations` argument of every MCTS bot
optimism = 0.1           # Only for IIG-MCTS. Whether guessing states is biased towards good outcomes

# Root-noise pair forwarded to initialize_bot (noise is disabled if either is None).
policy_epsilon = 0.25
policy_alpha = 1

# Not referenced by any other cell visible in this notebook.
temperature = 1
temperature_drop = 10

# When True, mcts_agent prints the search tree of every MCTS run.
verbose = False

In [None]:
# Load the network used by the 'blind' player type in the game loop.
# NOTE(review): the run name 'KP-local-11-blind-1' looks like a Kuhn-poker run while
# `game_name` is configured separately — confirm a checkpoint exists for the selected game.
model_manager = OpenSpielCheckpointManager(game_name, 'KP-local-11-blind-1')
# presumably -1 selects the most recent checkpoint — TODO confirm against OpenSpielCheckpointManager
blind_model = model_manager.load_checkpoint(-1)

In [None]:
def print_game_tree(node, level = 0):
    """Print ``node`` and all of its descendants, depth first.

    Each line shows the node followed by its ``total_reward`` and is prefixed
    with two spaces per tree level.
    """
    indent = '  ' * level
    print(indent, node, node.total_reward)
    for child in node.children:
        print_game_tree(child, level + 1)

In [None]:
def mcts_agent(state, information_set_generator):
    """Score the actions of the player to move with one independent MCTS per candidate state.

    Every state in the current player's information set is searched with its own
    freshly initialised bot. For each root child that was visited at least once,
    its mean reward is added to the entry of the corresponding action, so the
    result accumulates mean rewards over all searched states. Actions that were
    never visited keep the value 0.

    Relies on the module-level names ``game``, ``model``, ``uct_c``,
    ``max_simulations``, ``policy_epsilon``, ``policy_alpha`` and ``verbose``.

    Returns:
        A numpy array with one accumulated value per distinct action of the game.
    """
    acting_player = state.current_player()
    candidate_states = information_set_generator.calculate_information_set(acting_player)
    policy = np.zeros(game.num_distinct_actions())

    for candidate in candidate_states:
        search_bot = initialize_bot(game, model, uct_c, max_simulations, policy_epsilon, policy_alpha)
        root = search_bot.mcts_search(candidate)
        if verbose:
            print_game_tree(root)

        for child in root.children:
            visits = child.explore_count
            if visits <= 0:
                continue
            # Per the original author's note: a node that is not a leaf spends one
            # of its visits on being unfolded, so that visit is excluded from the mean.
            if child.outcome is None and visits != 1:
                divisor = visits - 1
            else:
                divisor = visits
            policy[child.action] += child.total_reward / divisor
    return policy


In [None]:
def ii_mcts_agent(information_set_generator):
    """Run the imperfect-information MCTS bot and score the actions of one guessed state.

    The children of the search root are treated as candidate ("guessed") states.
    One of them is sampled with probability proportional to its visit count, and
    the mean ``actual_reward`` of each of its visited children is returned as the
    value of the corresponding action. Actions whose child was never visited keep
    the value 0 instead of triggering a division by zero.

    Relies on the module-level names ``ii_mcts_bot``, ``game`` and ``game_name``,
    and prints which state was guessed.

    Returns:
        A numpy array with one entry per element of the guessed state's
        ``legal_actions_mask()``.
    """
    root, _ = ii_mcts_bot.mcts_search(information_set_generator)

    # Visit counts of the candidate states, normalised into a sampling distribution.
    guess_policy = np.zeros(len(root.children))
    for c in root.children:
        guess_policy[c.action] = c.explore_count
    guess_policy /= np.sum(guess_policy)
    guess_state = np.random.choice(range(len(root.children)), p=guess_policy)
    guessed_node = root.children[guess_state]

    if game_name == 'kuhn_poker':
        guessed = information_set_generator.calculate_information_set()[guess_state]
        opponent = 1 - information_set_generator.current_player()
        opponent_card = guessed.observation_tensor(opponent)[2:5].index(1)
        print(f"II guessing state {guess_state} (opponent has card {opponent_card})")
    else:
        # The observer is only needed to describe the guessed state in this branch.
        omniscient_observer = OmniscientObserver(game)
        print(f"II guessing state {guess_state}: {omniscient_observer.get_observation_string(guessed_node.state)}")

    policy = np.zeros(len(guessed_node.state.legal_actions_mask()))
    for c in guessed_node.children:
        # Skip children the search never visited: they have no mean reward yet.
        if c.explore_count > 0:
            policy[c.action] = c.actual_reward / c.explore_count
    return policy

In [None]:
def omniscient_agent(state):
    """Score the actions available in ``state`` with the omniscient MCTS bot.

    Runs ``omniscient_bot.mcts_search`` on the true state and stores the mean
    reward of every visited root child under its action. Children that were never
    visited keep the value 0, and a child visited exactly once is averaged over
    that single visit, so no division by zero can occur (same rules as
    ``mcts_agent``).

    Relies on the module-level names ``omniscient_bot`` and ``game``.

    Returns:
        A numpy array with one value per distinct action of the game.
    """
    root = omniscient_bot.mcts_search(state)
    policy = np.zeros(game.num_distinct_actions())
    for c in root.children:
        visits = c.explore_count
        if visits <= 0:
            continue
        if c.outcome is not None or visits == 1:
            policy[c.action] = c.total_reward / visits
        else:
            # If node is not a leaf, one explore count is used to unfold it.
            # To get a proper average, we have to subtract that here.
            policy[c.action] = c.total_reward / (visits - 1)
    return policy

In [None]:
# Seat assignment: player type -> player id. A negative id means the type does not play.
player_setup = {
    'd-mcts': 0,
    'iig-mcts': 1,
    'omniscient': -1,
    'blind': -1
}

# Exactly one agent must sit on seat 0 and one on seat 1. Merely counting distinct
# ids would accept a doubly-assigned seat (e.g. 0, 1, 1) or an invalid seat (e.g. 0, 2).
seated_player_ids = sorted(player_id for player_id in player_setup.values() if player_id >= 0)
assert seated_player_ids == [0, 1], f"Player Setup misconfigured: expected one agent each on seats 0 and 1, got seats {seated_player_ids}"

# Reverse lookup: player id -> player type (the game loop and the summary use ids 0 and 1).
player_setup_reverse = {player_id:player_type for player_type, player_id in player_setup.items()}

In [None]:
# Play 100 games between the configured agents and record the final returns of each game.
game_returns = []
for _ in range(100):
    actions = []
    state = game.new_initial_state()
    information_set_generator = InformationSetGenerator(game)

    # Fresh bots for every game. They are bound to module-level names because
    # ii_mcts_agent and omniscient_agent look them up as globals.
    ii_mcts_bot = ImperfectInformationMCTSBot(game,
                                              uct_c,
                                              max_simulations,
                                              BasicImperfectInformationMCTSEvaluator(),
                                              optimism=optimism,
                                              solve=False,
                                              child_selection_fn=SearchNode.puct_value)

    omniscient_bot = mcts.MCTSBot(game, uct_c, max_simulations, BasicOmniscientMCTSEvaluator(game), solve=False)

    while not state.is_terminal():
        acting_player = state.current_player()

        if acting_player < 0:
            # Environment state: sample from the game's own chance distribution
            # instead of assuming that all chance outcomes are equally likely.
            outcomes, probabilities = zip(*state.chance_outcomes())
            action = np.random.choice(outcomes, p=probabilities)
        else:
            current_player_type = player_setup_reverse[acting_player]

            if current_player_type == 'd-mcts':
                label = "Policy D-MCTS"
                policy = mcts_agent(state, information_set_generator)

            elif current_player_type == 'iig-mcts':
                label = "Policy IIG-MCTS"
                try:
                    policy = ii_mcts_agent(information_set_generator)
                except Exception:
                    # Report which player was to move, then propagate the error unchanged.
                    print(acting_player)
                    raise

            elif current_player_type == 'omniscient':
                label = "Policy Omniscient"
                policy = omniscient_agent(state)

            elif current_player_type == 'blind':
                label = "Blind"
                _, policy = blind_model.inference([state.observation_tensor(acting_player)], [state.legal_actions_mask()])
                policy = policy[0]

            else:
                raise ValueError(f"Unknown player type: {current_player_type!r}")

            # Illegal actions must never win the argmax. The builtin `bool` is used
            # because the `np.bool` alias was removed in NumPy 1.24.
            legal_actions_mask = np.array(state.legal_actions_mask(), dtype=bool)
            policy[~legal_actions_mask] = float('-inf')
            print(f"{label}: {policy}")
            action = np.argmax(policy)

        # Record the move, then advance both the game state and the information-set tracker.
        action_str = state.action_to_string(acting_player, action)
        actions.append(action_str)

        information_set_generator.register_action(action)
        state.apply_action(action)
        information_set_generator.register_observation(state)

    print(actions)
    print(state.returns())
    print()
    game_returns.append(state.returns())



In [None]:
# Stack the recorded per-game returns (one row per game) and average them per player.
game_returns = np.array(game_returns)
average_return = game_returns.mean(axis=0)

print("Average return:")
print("---------------")
for seat in (0, 1):
    print(f"  {player_setup_reverse[seat]}: {average_return[seat]}")