In [None]:
#importing libraries
import numpy as np
import pyspiel
import math
import matplotlib.pyplot as plt
import torch
from statistics import mean

from open_spiel.python.algorithms.alpha_zero import model as model_lib
from open_spiel.python.algorithms.alpha_zero import evaluator as evaluator_lib
from open_spiel.python.algorithms import mcts

In [None]:
# build the tensorflow model
def build_model(game, model_type):
    """Create the policy/value network for `game`.

    Args:
        game: game object; its `observation_tensor_shape()` gives the network
            input shape and `num_distinct_actions()` the policy output size.
        model_type: architecture name forwarded unchanged to
            `model_lib.Model.build_model` (the training cell passes 'mlp').

    Returns:
        Whatever `model_lib.Model.build_model` returns for these settings.
    """
    input_shape = game.observation_tensor_shape()
    num_actions = game.num_distinct_actions()
    # Fixed hyperparameters for this notebook; path=None is passed as-is.
    hyperparameters = dict(
        nn_width=64,
        nn_depth=2,
        weight_decay=1e-5,
        learning_rate=5e-4,
        path=None,
    )
    return model_lib.Model.build_model(
        model_type, input_shape, num_actions, **hyperparameters)

In [None]:
def executeEpisode(game, temperature):
    """Play one self-play game with MCTS and return its training examples.

    At every position an MCTS search is run, the root children's visit counts
    are turned into a move distribution, a move is sampled from it, and the
    observation, legal-action mask and distribution are recorded. After the
    game ends, every recorded position is labelled with player 0's terminal
    reward.

    Args:
        game: game object; used through `new_initial_state()` and
            `num_distinct_actions()`.
        temperature: visit counts are raised to `1 / temperature` before being
            normalised. `0` is treated as the greedy limit: all probability
            mass is placed (uniformly) on the most-visited action(s).

    Returns:
        list of `model_lib.TrainInput`, one per move played, in move order.

    Note:
        The evaluator is built from the notebook-level `model`, so that
        variable must exist before this function is called.
    """
    UCT_C = math.sqrt(2)
    # NOTE(review): a fresh RandomState with the same seed is created on every
    # call, so the bot's random stream restarts identically each episode.
    rng = np.random.RandomState(42)
    state = game.new_initial_state()

    mcts_bot = mcts.MCTSBot(
        game,
        UCT_C,
        max_simulations=100,
        solve=True,
        random_state=rng,
        evaluator=evaluator_lib.AlphaZeroEvaluator(game, model))

    observations = []
    action_masks = []
    policies = []

    while not state.is_terminal():
        root = mcts_bot.mcts_search(state)

        # Visit count per action; actions without a root child stay at 0.
        policy = np.zeros(game.num_distinct_actions())
        for c in root.children:
            policy[c.action] = c.explore_count

        if temperature == 0:
            # Greedy limit of counts ** (1 / t) as t -> 0; avoids dividing by zero.
            policy = (policy == policy.max()).astype(policy.dtype)
        else:
            policy = policy ** (1 / temperature)
        policy /= policy.sum()

        # The move is drawn from NumPy's global RNG, not from `rng` above.
        action = np.random.choice(len(policy), p=policy)

        # Record the position *before* the move is applied.
        observations.append(state.observation_tensor())
        action_masks.append(state.legal_actions_mask())
        policies.append(policy)

        state.apply_action(action)

    # Every position gets the same value target: player 0's terminal reward,
    # regardless of which player was to move.
    final_game_reward = state.player_reward(0)
    return [
        model_lib.TrainInput(obs, act_mask, pol, value=final_game_reward)
        for obs, act_mask, pol in zip(observations, action_masks, policies)
    ]

# 1 Train agent

In [None]:
# Report the training loss every `print_every` self-play episodes.
print_every = 10
# Number of self-play episodes; the training loop does one model update per episode.
n_playthroughs = 200

In [None]:
# Set up the game and the network that executeEpisode reads as a global.
game = pyspiel.load_game("connect_four")
model = build_model(game, 'mlp')
print("Num variables:", model.num_trainable_variables)
model.print_trainable_variables()

# One self-play episode followed by one model update per iteration.
# `train_inputs` is left holding the last episode's examples after the loop.
losses = []
for episode in range(1, n_playthroughs + 1):
    train_inputs = executeEpisode(game, 1)
    loss = model.update(train_inputs)
    losses.append(loss)
    if episode % print_every == 0:
        print(episode, loss)

# 2 Visualization

In [None]:
def _uniform_baseline_loss(train_input):
    """Loss of a fixed predictor on one training example.

    The predictor outputs value 0.5 and a policy of 1 / (number of legal
    actions) for every action; the result is squared value error minus the
    policy-weighted log of that uniform probability.
    """
    # NOTE(review): confirm 0.5 is the intended constant value prediction.
    value_term = (train_input.value - 0.5) ** 2
    uniform_policy = torch.ones(game.num_distinct_actions(), dtype=torch.float64) / sum(train_input.legals_mask)
    return value_term - torch.tensor(train_input.policy) @ uniform_policy.log()


# Computed over `train_inputs`, i.e. the examples of the last training episode.
baseline_losses = torch.stack(
    [_uniform_baseline_loss(train_input) for train_input in train_inputs])

In [None]:
# Training loss per episode against the constant uniform-policy baseline.
plt.title('Training performance')
plt.plot([l.total for l in losses], label='Train loss')
# axhline does not advance the colour cycle, so without an explicit colour the
# baseline would be drawn in the same default colour as the loss curve.
plt.axhline(baseline_losses.mean().item(), color='tab:orange', linestyle='--',
            label='Uniform distribution baseline')
plt.xlabel('Episode')
plt.ylabel('Total loss')
plt.legend()
plt.show()

In [None]:
# 3 Let learned agent play
# Play one full game in which every move is chosen by MCTS guided by the
# trained `model`. This defines the notebook-level `state` and
# `selected_actions` that the following cells display; without it they fail
# with NameError on a fresh kernel.
play_bot = mcts.MCTSBot(
    game,
    math.sqrt(2),
    max_simulations=100,
    solve=True,
    random_state=np.random.RandomState(42),
    evaluator=evaluator_lib.AlphaZeroEvaluator(game, model))

state = game.new_initial_state()
selected_actions = []
while not state.is_terminal():
    action = play_bot.step(state)  # action chosen by the search for this position
    selected_actions.append(action)
    state.apply_action(action)

In [None]:
# Display `state` as the cell output. This needs a notebook-level `state`;
# the variable of the same name inside executeEpisode is local to that function.
state

In [None]:
# Report the result from player 0's reward. Only an explicit +1 / -1 names a
# winner, so a draw or an unfinished game is no longer reported as a Player 2 win.
p0_reward = state.player_reward(0)
if not state.is_terminal():
    print("The game is not finished yet.")
elif p0_reward == 1:
    print("The winner is: Player 1!")
elif p0_reward == -1:
    print("The winner is: Player 2!")
else:
    print("The game ended in a draw.")

In [None]:
# List the recorded moves as plain Python values; squeezing drops any
# length-1 axes. Needs a notebook-level `selected_actions`.
print("Selected actions:")
print(np.squeeze(np.array(selected_actions)).tolist())