In [None]:
# Move the working directory one level up before importing project modules.
# NOTE(review): presumably this makes the `alpha_one` package importable from the
# repository root -- confirm. `%cd ..` is not idempotent: re-running this cell
# moves up another directory each time, so run it once per kernel.
%cd ..
# Reload edited modules automatically before each cell execution.
%reload_ext autoreload
%autoreload 2

In [None]:
import pyspiel
from alpha_one.game.information_set import InformationSetGenerator
import numpy as np

from open_spiel.python.algorithms import mcts
from open_spiel.python.algorithms.alpha_zero import evaluator as evaluator_lib
from open_spiel.python.algorithms.alpha_zero import model as model_lib
from open_spiel.python.algorithms.mcts import SearchNode

from alpha_one.alg.imperfect_information import BasicImperfectInformationMCTSEvaluator
from alpha_one.alg.mcts import ImperfectInformationMCTSBot

In [None]:
def build_model(game, model_type, nn_width, nn_depth, learning_rate, weight_decay, model_saves_path):
    """Build a network for ``game`` via ``model_lib.Model.build_model``.

    The input shape and the number of outputs are queried from the game itself
    (``observation_tensor_shape()`` and ``num_distinct_actions()``); all other
    arguments are forwarded unchanged.

    Args:
        game: Game object providing ``observation_tensor_shape()`` and
            ``num_distinct_actions()``.
        model_type: Forwarded as the first positional argument (model kind).
        nn_width: Forwarded as ``nn_width``.
        nn_depth: Forwarded as ``nn_depth``.
        learning_rate: Forwarded as ``learning_rate``.
        weight_decay: Forwarded as ``weight_decay``.
        model_saves_path: Forwarded as ``path``.

    Returns:
        Whatever ``model_lib.Model.build_model`` returns.
    """
    input_shape = game.observation_tensor_shape()
    num_actions = game.num_distinct_actions()
    hyperparameters = dict(
        nn_width=nn_width,
        nn_depth=nn_depth,
        weight_decay=weight_decay,
        learning_rate=learning_rate,
        path=model_saves_path,
    )
    return model_lib.Model.build_model(model_type, input_shape, num_actions, **hyperparameters)

In [None]:
def initialize_bot(game, model, uct_c, max_simulations, policy_epsilon, policy_alpha, n_rollouts=100):
    """Create an ``mcts.MCTSBot`` that evaluates leaves with random rollouts.

    Args:
        game: Game the bot searches.
        model: Unused. Kept so existing callers that pass a model keep working;
            leaf evaluation is done by ``mcts.RandomRolloutEvaluator`` instead.
        uct_c: Exploration constant passed to the bot.
        max_simulations: Number of simulations per search, passed to the bot.
        policy_epsilon: First element of the Dirichlet-noise tuple, or ``None``.
        policy_alpha: Second element of the Dirichlet-noise tuple, or ``None``.
        n_rollouts: Number of random rollouts per leaf evaluation
            (defaults to the previously hard-coded value of 100).

    Returns:
        The configured ``mcts.MCTSBot``.
    """
    # Identity check rather than `== None`: equality can be overridden (or be
    # element-wise for arrays), which would silently disable the noise.
    if policy_epsilon is None or policy_alpha is None:
        noise = None
    else:
        noise = (policy_epsilon, policy_alpha)

    evaluator = mcts.RandomRolloutEvaluator(n_rollouts=n_rollouts)

    return mcts.MCTSBot(
        game,
        uct_c,
        max_simulations,
        evaluator,
        solve=False,
        dirichlet_noise=noise,
        child_selection_fn=mcts.SearchNode.puct_value,
        verbose=False)

In [None]:
# Game to load through pyspiel and the location handed to the model as `path`.
game_name = "kuhn_poker"
model_saves_path = '../model_saves/kuhn_poker'
# Network size and optimizer settings forwarded to build_model.
nn_width = 10
nn_depth = 5
learning_rate = 0.001
weight_decay = 0.0001

# Model kind forwarded as the first argument of model_lib.Model.build_model.
model_type = 'mlp'

# `game` and `model` are module-level objects read by the agent functions below.
game = pyspiel.load_game(game_name)

# NOTE(review): `model` is passed to initialize_bot, which does not use it
# (the bot evaluates with random rollouts) -- confirm whether it is still needed.
model = build_model(game, model_type, nn_width, nn_depth, learning_rate, weight_decay, model_saves_path)

In [None]:
# Exploration constant handed to both MCTS bots.
uct_c = 3
# Simulations per search, handed to both MCTS bots.
max_simulations = 100
# (policy_epsilon, policy_alpha) is passed as the `dirichlet_noise` tuple of
# mcts.MCTSBot; setting either to None disables the noise.
policy_epsilon = 0.25
policy_alpha = 1
# NOTE(review): `temperature` and `temperature_drop` are only referenced by
# commented-out sampling code in the game loop -- currently unused.
temperature = 1
temperature_drop = 10
# When True, mcts_agent prints every search tree it builds.
verbose = False

In [None]:
def print_game_tree(node, level=0):
    """Print ``node`` and all of its descendants, one node per line.

    Each line shows the node followed by its ``total_reward``; descendants are
    printed depth-first and indented by two spaces per tree level.

    Args:
        node: Tree node exposing ``children`` and ``total_reward``.
        level: Depth of ``node``; controls the indentation of its line.
    """
    indentation = '  ' * level
    print(indentation, node, node.total_reward)
    child_level = level + 1
    for child in node.children:
        print_game_tree(child, child_level)

In [None]:
def mcts_agent(state, information_set_generator):
    """Score each action by searching every state of the current information set.

    One independent MCTS is run from every state in the information set of the
    player to move. For each root child, its mean value
    (``total_reward / explore_count``) is added to the entry of its action, so
    the result is the sum of per-tree mean values, not a visit count.

    Reads the module-level ``game``, ``model``, ``uct_c``, ``max_simulations``,
    ``policy_epsilon``, ``policy_alpha`` and ``verbose``.

    Args:
        state: Current game state; only ``current_player()`` is used.
        information_set_generator: Provides ``calculate_information_set(player)``.

    Returns:
        Array of length ``game.num_distinct_actions()`` with the summed values.
        Actions that never appear as a root child keep the value 0.
    """
    current_player = state.current_player()
    information_set = information_set_generator.calculate_information_set(current_player)
    policy = np.zeros(game.num_distinct_actions())

    for s in information_set:
        # A fresh bot per state keeps the searches independent of each other.
        bot = initialize_bot(game, model, uct_c, max_simulations, policy_epsilon, policy_alpha)
        root = bot.mcts_search(s)
        if verbose:
            print_game_tree(root)
        for c in root.children:
            # A child that was never visited has no value estimate; dividing by
            # its zero visit count would raise, so it contributes nothing.
            if c.explore_count:
                policy[c.action] += c.total_reward / c.explore_count
    return policy


In [None]:
def ii_mcts_agent(information_set_generator):
    """Score actions with the imperfect-information MCTS bot.

    Runs one search with the module-level ``ii_mcts_bot``, picks the root child
    with the highest mean value (``actual_reward / explore_count``) as the
    guessed state, and returns the mean values of that child's own children,
    indexed by their action.

    Args:
        information_set_generator: Passed unchanged to ``ii_mcts_bot.mcts_search``.

    Returns:
        Array with one entry per child of the guessed node, holding each
        action's mean value (0 for children that were never visited).
    """
    def mean_value(node):
        # Unvisited nodes have explore_count == 0; treat their value as 0
        # instead of dividing by zero.
        return node.actual_reward / node.explore_count if node.explore_count else 0.0

    root, _ = ii_mcts_bot.mcts_search(information_set_generator)
    guess_states_values = [mean_value(c) for c in root.children]
    guess_state = np.argmax(guess_states_values)
    print(f"II guessing state {guess_state}")

    guessed_node = root.children[guess_state]
    # NOTE(review): the array is sized by the number of children but indexed by
    # action id, which assumes action ids are 0..len(children)-1 -- confirm.
    policy = np.zeros(len(guessed_node.children))
    for c in guessed_node.children:
        policy[c.action] = mean_value(c)
    return policy

In [None]:
# Number of games to play between the two agents.
num_games = 100

game_returns = []
for _ in range(num_games):
    actions = []
    state = game.new_initial_state()
    information_set_generator = InformationSetGenerator(game)

    # ii_mcts_agent reads this bot through the module-level name `ii_mcts_bot`,
    # so the name must not change. A fresh bot is created for every game.
    # NOTE(review): the positional `False` is presumably the `solve` flag -- confirm.
    ii_mcts_bot = ImperfectInformationMCTSBot(
        game,
        uct_c,
        max_simulations,
        BasicImperfectInformationMCTSEvaluator(),
        False,
        child_selection_fn=SearchNode.puct_value)

    while not state.is_terminal():
        current_player = state.current_player()

        if current_player < 0:
            # Environment (chance) move: sample from the game's own outcome
            # distribution rather than uniformly over the legal actions, so
            # games with non-uniform chance nodes are simulated correctly.
            outcomes, probabilities = zip(*state.chance_outcomes())
            action = np.random.choice(outcomes, p=probabilities)
        elif current_player == 0:
            # Player 1: one MCTS per state of the information set.
            # Greedy selection; temperature-based sampling is not used.
            policy = mcts_agent(state, information_set_generator)
            print(f"Policy 1: {policy}")
            action = np.argmax(policy)
        else:
            # Player 2: imperfect-information MCTS bot.
            policy = ii_mcts_agent(information_set_generator)
            print(f"Policy 2: {policy}")
            action = np.argmax(policy)

        # Record the move, then keep the information-set generator in sync
        # with the state: action first, observation after it has been applied.
        actions.append(state.action_to_string(current_player, action))
        information_set_generator.register_action(action)
        state.apply_action(action)
        information_set_generator.register_observation(state)

    print(actions)
    print(state.returns())
    print()
    game_returns.append(state.returns())



In [None]:
# Stack the per-game return lists into a 2-D array (one row per game) and
# average over games, giving one mean return per player.
game_returns = np.array(game_returns)
print(f"Average return: {game_returns.mean(axis=0)}")