In [None]:
%cd ..
%reload_ext autoreload
%autoreload 2

In [None]:
import pyspiel
import numpy as np

from alpha_one.model.model_manager import OpenSpielModelManager, OpenSpielModelConfig
from alpha_one.utils.mcts import initialize_bot, compute_mcts_policy_reward, investigate_node
from alpha_one.game.observer import OmniscientObserver, get_observation_tensor_shape
from alpha_one.alg.imperfect_information import AlphaZeroOmniscientMCTSEvaluator, BasicOmniscientMCTSEvaluator
from alpha_one.game.information_set import InformationSetGenerator
from alpha_one.utils.play import InteractiveGameMachine
from open_spiel.python.observation import make_observation
from open_spiel.python.algorithms import mcts
from open_spiel.python.algorithms.alpha_zero import model as model_lib

# 1. Setup Game

In [None]:
# Load the OpenSpiel game by name; `game_name` is also passed to
# `OpenSpielModelManager` further down.
game_name = 'leduc_poker'
game = pyspiel.load_game(game_name)

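# 2. Setup Omniscient MCTS

Both subsections below assign `mcts_bot`. Run only the variant you want to play against: whichever cell executes last determines the bot used in section 3.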

In [None]:
# Search settings shared by both bot variants below:
#   UCT_C           -> exploration constant handed to the MCTS bot
#   max_simulations -> number of simulations per search call
UCT_C, max_simulations = 3, 800

## 2.1. Basic Omniscient MCTS 

In [None]:
# Basic omniscient evaluator from alpha_one, plugged into OpenSpiel's generic
# MCTS bot with PUCT child selection.
evaluator = BasicOmniscientMCTSEvaluator(game)
mcts_bot = mcts.MCTSBot(
    game,
    UCT_C,
    max_simulations,
    evaluator,
    solve=False,  # do not back up solved (proven) values through the tree
    child_selection_fn=mcts.SearchNode.puct_value,
)

## 2.2. AlphaZero Omniscient MCTS (Untrained)

In [None]:
# Network configuration for an MLP that consumes the omniscient observation tensor.
config = OpenSpielModelConfig(
    game,
    'mlp',
    get_observation_tensor_shape(game, omniscient_observer=True),
    64,    # NOTE(review): presumably hidden-layer width — confirm against OpenSpielModelConfig
    2,     # presumably number of hidden layers — confirm
    5e-3,  # presumably weight decay — confirm
    5e-3,  # presumably learning rate — confirm
    omniscient_observer=True,
)

In [None]:
# Build a model (untrained, per this section's title) and wrap it in an MCTS bot.
# NOTE: this rebinds `mcts_bot`, replacing any bot built in section 2.1.
# NOTE(review): the meaning of 'KP' and of checkpoint name 'x' is not evident
# here — they look like placeholders; confirm.
model_manager = OpenSpielModelManager(game_name, 'KP')
model = (
    model_manager
    .get_checkpoint_manager('x')
    .build_model(config)
)
mcts_bot = initialize_bot(
    game,
    model,
    UCT_C,
    max_simulations,
    omniscient_observer=True,
)

# 3. Play Game

In [None]:
# Interactive driver for `game`; the cells below use it to step through a match.
game_machine = InteractiveGameMachine(game)

In [None]:
# Start a new game (presumably resets `game_machine.state` — confirm).
game_machine.new_game()

In [None]:
# Trailing `;` hides the return value's repr without rebinding IPython's `_`
# output-history variable (assigning `_` stops IPython from updating `_`/`__`/`___`).
game_machine.list_player_actions();

In [None]:
# Trailing `;` hides the return value's repr without rebinding IPython's `_`
# output-history variable (assigning `_` stops IPython from updating `_`/`__`/`___`).
game_machine.get_observations();

In [None]:
# Run a search from the current state with whichever `mcts_bot` was built last.
root = mcts_bot.mcts_search(game_machine.state)
# NOTE(review): the helper's name suggests it may return reward information in
# addition to a policy — confirm what `policy` actually holds.
policy = compute_mcts_policy_reward(game, game_machine.state, root)
print(policy)

# Inspect the search tree rooted at `root` (presumably prints a summary — confirm).
investigate_node(root)

In [None]:
# Display the current player as reported by the machine's information-set generator.
game_machine.information_set_generator.current_player()

In [None]:
# NOTE(review): presumably prompts for / waits on the next action and applies it — confirm.
game_machine.await_action()

In [None]:
# Record the new state in the information-set generator owned by `game_machine`
# (the same object queried for `current_player()` earlier). There is no standalone
# `information_set_generator` variable in this notebook, so the bare name raises
# NameError on a fresh kernel.
# NOTE(review): confirm `await_action()` does not already register the observation.
game_machine.information_set_generator.register_observation(game_machine.state)

In [None]:
# Finish the current game. Trailing `;` hides the return value's repr without
# rebinding IPython's `_` output-history variable.
game_machine.finish_game();