In [1]:
#importing libraries
import numpy as np
from open_spiel.python.algorithms.alpha_zero import model as model_lib
from open_spiel.python.algorithms.alpha_zero import evaluator as evaluator_lib
import pyspiel
from open_spiel.python.algorithms import mcts
import math

In [2]:
# build the tensorflow model
def build_model(game, model_type):
    return model_lib.Model.build_model(
      model_type, game.observation_tensor_shape(), game.num_distinct_actions(),
      nn_width=32, nn_depth=2, weight_decay=1e-4, learning_rate=0.01, path=None)

In [3]:
def executeEpisode(game, temperature):
    UCT_C = math.sqrt(2)
    rng = np.random.RandomState(42)
    train_inputs = []
    state = game.new_initial_state()
    
    mcts_bot = mcts.MCTSBot(
      game,
      UCT_C,
      max_simulations=5,
      solve=True,
      random_state=rng,
      evaluator=evaluator_lib.AlphaZeroEvaluator(game, model))                                      
        
    while not state.is_terminal():
        root = mcts_bot.mcts_search(state)
        policy = np.zeros(game.num_distinct_actions())
        
        for c in root.children:
            policy[c.action] = c.explore_count
        policy = policy ** (1 / temperature)
        policy /= policy.sum()
        action = np.random.choice(len(policy), p=policy)
        obs = state.observation_tensor()
        act_mask = state.legal_actions_mask()
    
        train_inputs.append(model_lib.TrainInput(obs, act_mask, policy, value=1))              
   
        state.apply_action(action) 
    
    #TODO: Use state.player_reward(0) to get actual player reward after rollout
    
    return train_inputs

In [4]:
game = pyspiel.load_game("go")
model = build_model(game, 'mlp')
print("Num variables:", model.num_trainable_variables)
model.print_trainable_variables()
losses = []
for i in range(200):
    train_inputs = executeEpisode(game, 1)   
    loss = model.update(train_inputs)                                  
    print(i, loss)
    losses.append(loss)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Num variables: 61387
torso_0_dense/kernel:0: (1444, 32)
torso_0_dense/bias:0: (32,)
torso_1_dense/kernel:0: (32, 32)
torso_1_dense/bias:0: (32,)
policy_dense/kernel:0: (32, 32)
policy_dense/bias:0: (32,)
policy/kernel:0: (32, 362)
policy/bias:0: (362,)
value_dense/kernel:0: (32, 32)
value_dense/bias:0: (32,)
value/kernel:0: (32, 1)
value/bias:0: (1,)
0 Losses(total: 6.197, policy: 4.468, value: 1.718, l2: 0.011)
1 Losses(total: 4.847, policy: 4.836, value: 0.000, l2: 0.011)
2 Losses(total: 5.328, policy: 5.317, value: 0.000, l2: 0.010)
3 Losses(total: 4.205, policy: 4.195, value: 0.000, l2: 0.010)
4 Losses(total: 4.374, policy: 4.364, value: 0.000, l2: 0.011)
5 Losses(total: 4.488, policy: 4.477, value: 0.000, l2: 0.011)
6 Losses(total: 4.632, policy: 4.621, value: 0.000, l2: 0.011)
7 Losses(total: 4.905, policy: 4.893, value: 0.000, l2: 0.011)
8 Losses(total: 4.425, policy: 4.413, value: 0.000, l2: 0.012)

In [66]:
state = game.new_initial_state()
selected_actions = []
while not state.is_terminal():
    obs = state.observation_tensor()
    value, policy = model.inference([obs], [state.legal_actions_mask()])
    action = policy.argmax(-1)
    state.apply_action(policy.argmax(-1))
    selected_actions.append(action)

In [68]:
print(f"The winner is: Player {1 if state.player_reward(0) == 1 else 2}!")

The winner is: Player 1!


In [73]:
print("Selected actions:")
print(np.array(selected_actions).squeeze().tolist())

Selected actions:
[163, 358, 15, 337, 33, 342, 322, 58, 252, 8, 9, 57, 27, 268, 24, 148, 48, 360, 160, 194, 357, 0, 303, 21, 291, 14, 289, 16, 288, 361, 43, 361, 139, 361, 42, 361, 114, 361, 111, 361, 112, 361, 118, 361, 113, 361, 116, 361, 117, 361, 119, 361, 110, 361, 40, 361, 121, 13, 120, 361, 44, 361, 115, 361, 38, 261, 1, 361, 87, 361, 359, 361, 45, 361, 86, 348, 287, 145, 46, 140, 332, 355, 108, 138, 104, 129, 90, 154, 41, 353, 137, 335, 149, 143, 109, 130, 98, 96, 136, 278, 106, 345, 331, 52, 135, 152, 352, 323, 356, 245, 279, 193, 94, 53, 298, 354, 82, 133, 350, 134, 92, 280, 272, 304, 271, 204, 300, 343, 150, 141, 29, 155, 147, 125, 320, 338, 356, 207, 229, 85, 328, 127, 12, 28, 301, 101, 324, 361, 361]
