In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
from simulator.game.connect import Action, Config, State

from alphazero_implementation.models.games.connect4 import CNNModel

config = Config(6, 7, 4)


path = "/Users/pveron/Code/alphazero-implementation/checkpoints/run_156/model-epoch=1999.ckpt"
model = CNNModel.load_from_checkpoint(  # type: ignore[arg-type]
    path,
    height=config.height,
    width=config.width,
    max_actions=config.width,
    num_players=config.num_players,
).eval()


In [2]:
state = config.sample_initial_state()
json = {
    "config": {"count": 4, "height": 6, "width": 7},
    "grid": [
        [0, 0, 0, -1, -1, 1, -1],
        [-1, -1, -1, -1, -1, 1, -1],
        [-1, -1, -1, -1, -1, 1, -1],
        [-1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1],
    ],
    "player": 0,
}

state = State.from_json(json)

In [3]:
def raw_policy(state: State) -> dict[Action, float]:
    [policy], _ = model.predict([state])
    return policy


In [4]:
from alphazero_implementation.mcts.agent import MCTSAgent


def improved_policy_v1(state: State) -> dict[Action, float]:
    mcts = MCTSAgent(model, 1, 100, state)

    episodes = mcts.run()
    policy = episodes[0].samples[-1].policy
    return policy


In [5]:
from alphazero_implementation.mcts_v2.mcts import AlphaZeroMCTS
from alphazero_implementation.mcts_v2.node import Node


def improved_policy_v2(state: State) -> dict[Action, float]:
    mcts = AlphaZeroMCTS(model)

    root = Node(state)
    policy = mcts.run(root, 100)
    return policy

In [6]:
print("State:")
print(state.grid)
print()


State:
[[ 0  0  0 -1 -1  1 -1]
 [-1 -1 -1 -1 -1  1 -1]
 [-1 -1 -1 -1 -1  1 -1]
 [-1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1]]



In [7]:
def print_policy(policy: dict[Action, float]):
    for action, prob in policy.items():
        print(f"{action.column}: {prob:.4f}")


# Compare the three different policy functions
state_policy = raw_policy(state)

mcts_policy_v1 = improved_policy_v1(state)

mcts_policy_v2 = improved_policy_v2(state)


print("Raw policy:")
print_policy(state_policy)
print()

print("MCTS v1 policy:")
print_policy(mcts_policy_v1)
print()

print("MCTS v2 policy:")
print_policy(mcts_policy_v2)


run took 0.11 seconds to run.
Raw policy:
0: 0.0423
1: 0.1359
2: 0.0286
3: 0.6309
4: 0.0643
5: 0.0581
6: 0.0399

MCTS v1 policy:
0: 0.0057
3: 0.0402
5: 0.9540

MCTS v2 policy:
0: 0.0000
1: 0.0500
2: 0.0000
3: 0.9400
4: 0.0000
5: 0.0000
6: 0.0000


In [11]:
mcts_policy_v2 = improved_policy_v2(state)
mcts_policy_v2

{<simulator.game.connect.Action at 0x144377d70>: 0.0,
 <simulator.game.connect.Action at 0x1442ebc90>: 0.05,
 <simulator.game.connect.Action at 0x1442e80d0>: 0.0,
 <simulator.game.connect.Action at 0x1442e8e10>: 0.94,
 <simulator.game.connect.Action at 0x1442fdb30>: 0.0,
 <simulator.game.connect.Action at 0x1442fc4f0>: 0.0,
 <simulator.game.connect.Action at 0x1442fcf70>: 0.0}