In [None]:
%load_ext autoreload
%autoreload 2

In [7]:
import os

from simulator.game.connect import Action, Config, State

from alphazero_implementation.models.games.connect4 import CNNModel

# Board configuration for the model and the positions below.
# NOTE(review): positional args assumed to be (height, width, count-to-connect),
# i.e. a standard 6x7 Connect Four board with 4 in a row to win — confirm.
config = Config(6, 7, 4)

# Set ALPHAZERO_CHECKPOINT to evaluate a different run without editing the notebook.
# The fallback is the original author's local checkpoint and only exists on that machine.
path = os.environ.get(
    "ALPHAZERO_CHECKPOINT",
    "/Users/pveron/Code/alphazero-implementation/checkpoints/run_156/model-epoch=1999.ckpt",
)
model = CNNModel.load_from_checkpoint(  # type: ignore[arg-type]
    path,
    height=config.height,
    width=config.width,
    # Connect Four actions are column drops, so there is one action per column.
    max_actions=config.width,
    num_players=config.num_players,
).eval()


In [8]:
# Hand-built test position, in the JSON format accepted by State.from_json.
# NOTE(review): presumably -1 marks an empty cell, 0/1 are player indices and
# row 0 is the bottom row — confirm against State.from_json. If so, player 0
# (to move) has three in a row along the bottom and wins by playing column 3,
# which makes this a quick sanity check for both policies.
state_json = {
    "config": {"count": 4, "height": 6, "width": 7},
    "grid": [
        [0, 0, 0, -1, -1, 1, -1],
        [-1, -1, -1, -1, -1, 1, -1],
        [-1, -1, -1, -1, -1, 1, -1],
        [-1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1],
    ],
    "player": 0,
}

# The model was loaded for `config`'s board size; fail fast if this position disagrees.
assert state_json["config"]["height"] == config.height
assert state_json["config"]["width"] == config.width
assert len(state_json["grid"]) == config.height
assert all(len(row) == config.width for row in state_json["grid"])

state = State.from_json(state_json)

In [9]:
def raw_policy(state: State) -> dict[Action, float]:
    """Return the policy the network predicts for ``state``, without any tree search.

    ``model.predict`` works on a batch, so the state is wrapped in a one-element
    list and the single policy is unpacked from the result. The second item
    returned by ``predict`` is ignored.
    """
    policies, _ = model.predict([state])
    (single_policy,) = policies
    return single_policy


In [10]:
from alphazero_implementation.search.mcts import AlphaZeroSearch, Node


def mcts_improved_policy(
    state: State, num_simulations: int = 100
) -> dict[Action, float]:
    """Run AlphaZero-style MCTS from ``state`` and return the search's policy.

    A fresh search object and a fresh root node are created on every call, so
    no tree is reused between calls.

    Args:
        state: Position to search from.
        num_simulations: Number of simulations passed to ``AlphaZeroSearch``.
    """
    search = AlphaZeroSearch(model=model, num_simulations=num_simulations)
    root = Node(state)
    return search.run(root)


In [None]:
def print_policy(policy: dict[Action, float]):
    """Print one ``<column>: <probability>`` line per action, in dict iteration order."""
    formatted = (f"{act.column}: {p:.4f}" for act, p in policy.items())
    for line in formatted:
        print(line)


# Compare the raw network policy with the MCTS-improved policy for the same state.
state_policy = raw_policy(state)
mcts_policy = mcts_improved_policy(state)

# Print each policy under its title, followed by a blank separator line.
for title, policy in (("Raw policy:", state_policy), ("MCTS policy:", mcts_policy)):
    print(title)
    print_policy(policy)
    print()
