In [None]:
%load_ext autoreload
%autoreload 2

In [15]:
from simulator.game.connect import Action, Config, State

from alphazero_implementation.models.games.connect4 import CNNModel

config = Config(6, 7, 4)


# path = "/Users/pveron/Code/alphazero-implementation/checkpoints/run_156/model-epoch=1999.ckpt"
path = "/Users/pveron/Code/alphazero-implementation/lightning_logs/alphazero/run_168_iter200_episodes100_sims100/checkpoints/epoch=1999-step=1431360.ckpt"
path = "/Users/pveron/Code/alphazero-implementation/lightning_logs/alphazero/run_169_CNNModel_iter200_episodes100_sims100/checkpoints/epoch=399-step=218810.ckpt"
model = CNNModel.load_from_checkpoint(  # type: ignore[arg-type]
    path,
    height=config.height,
    width=config.width,
    max_actions=config.width,
    num_players=config.num_players,
).eval()


In [16]:
def raw_policy(state: State) -> tuple[dict[Action, float], list[float]]:
    [policy], [value] = model.predict([state])
    return policy, value


In [17]:
from alphazero_implementation.core.search.mcts import AlphaZeroSearch, Node


def mcts_improved_policy(
    state: State, num_simulations: int = 100
) -> tuple[dict[Action, float], float]:
    mcts = AlphaZeroSearch(model=model, num_simulations=num_simulations)
    return mcts.run(Node(state))


In [18]:
def print_policy(policy: dict[Action, float]):
    for action, prob in policy.items():
        print(f"{action.column}: {prob:.4f}")


In [19]:
def compare_policies(state: State, print_policy: bool = False):
    # Compare the three different policy functions
    state_policy, state_value = raw_policy(state)

    mcts_policy, mcts_value = mcts_improved_policy(state)

    if print_policy:
        print("Raw policy:")
        print(state_policy)
        print()

        print("MCTS policy:")
        print(mcts_policy)
        print()

        print("Raw value:")
        print(state_value)
        print()

    return state_policy, mcts_policy, state_value, mcts_value


In [20]:
final_situations = [
    {
        "grid": [
            [0, 0, 0, -1, -1, 1, -1],
            [-1, -1, -1, -1, -1, 1, -1],
            [-1, -1, -1, -1, -1, 1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 0,
        "expected_move": 3,
    },
    {
        "grid": [
            [1, 1, 1, -1, -1, 0, -1],
            [-1, -1, -1, -1, -1, 0, -1],
            [-1, -1, -1, -1, -1, 0, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 0,
        "expected_move": 5,
    },
    {
        "grid": [
            [-1, 0, 0, 0, -1, -1, -1],
            [-1, 1, 1, 1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 0,
        "expected_move": 0,
    },
    {
        "grid": [
            [0, 0, 0, -1, -1, 1, -1],
            [-1, 0, -1, -1, -1, 1, -1],
            [-1, -1, -1, -1, -1, 1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 1,
        "expected_move": 5,
    },
    {
        "grid": [
            [0, 1, 0, 1, -1, -1, -1],  # Row 0 (bottom)
            [1, 0, 1, 0, -1, -1, -1],  # Row 1
            [0, -1, -1, -1, -1, -1, -1],  # Row 2
            [1, -1, -1, -1, -1, -1, -1],  # Row 3
            [-1, -1, -1, -1, -1, -1, -1],  # Row 4
            [-1, -1, -1, -1, -1, -1, -1],  # Row 5 (top)
        ],
        "player": 1,
        "expected_move": 1,
    },
    {
        "grid": [
            [0, 1, 0, 1, -1, 0, 1],  # Row 0 (bottom)
            [0, 0, 0, -1, -1, 1, 0],  # Row 1
            [1, 1, -1, -1, -1, 1, -1],  # Row 2
            [-1, -1, -1, -1, -1, -1, -1],  # Row 3
            [-1, -1, -1, -1, -1, -1, -1],  # Row 4
            [-1, -1, -1, -1, -1, -1, -1],  # Row 5 (top)
        ],
        "player": 0,
        "expected_move": 3,
    },
    {
        "grid": [
            [0, 1, 0, 1, 1, -1, -1],  # Row 0 (bottom)
            [1, 0, 0, 1, -1, -1, -1],  # Row 1
            [0, 0, 1, 0, -1, -1, -1],  # Row 2
            [-1, -1, -1, -1, -1, -1, -1],  # Row 3
            [-1, -1, -1, -1, -1, -1, -1],  # Row 4
            [-1, -1, -1, -1, -1, -1, -1],  # Row 5 (top)
        ],
        "player": 1,
        "expected_move": 1,
    },
]

In [21]:
def print_grid(grid: list[list[int]], use_colors: bool = True):
    """Print the Connect 4 grid with nice formatting.

    Args:
        grid: 2D list representing the game grid where:
             -1 = empty
              0 = player 1 (typically red)
              1 = player 2 (typically yellow)
        use_colors: Whether to use ANSI color codes (default: True)
    """
    # Constants for grid display
    EMPTY = "âšª"
    P1 = "ðŸŸ¡"
    P2 = "ðŸ”´"
    VERTICAL = "â”‚"
    HORIZONTAL = "â”€"
    BOTTOM_LEFT = "â””"
    BOTTOM_RIGHT = "â”˜"

    # Print column numbers
    print("  ", end="")
    for col in range(len(grid[0])):
        print(f" {col} ", end="")
    print()

    # Print the grid rows
    for row in grid:
        print(f" {VERTICAL}", end="")
        for cell in row:
            piece = EMPTY if cell == -1 else (P1 if cell == 0 else P2)
            print(f" {piece}", end="")
        print(f" {VERTICAL}")

    # Print bottom border
    width = len(grid[0]) * 3 - 1
    print(f" {BOTTOM_LEFT}{HORIZONTAL * width}{BOTTOM_RIGHT}")

In [28]:
from typing import Any


def print_comparisons(situations: list[dict[str, Any]]):
    for situation in situations:
        json = {"config": {"count": 4, "height": 6, "width": 7}, **situation}
        state = State.from_json(json)
        state_policy, mcts_policy, state_value, mcts_value = compare_policies(state)

        print(f"Grid played by {state.player}:")
        print_grid(situation["grid"])
        print()

        # Check if the move with highest probability matches expected move
        raw_best_move = max(state_policy.items(), key=lambda x: x[1])[0]
        mcts_best_move = max(mcts_policy.items(), key=lambda x: x[1])[0]
        expected_move = situation.get("expected_move")

        print(f"Raw value: {state_value}")
        print(
            f"Raw best move: {raw_best_move.column} (prob: {state_policy[raw_best_move]:.2f})"
        )
        print(f"MCTS value: {mcts_value}")
        print(
            f"MCTS best move: {mcts_best_move.column} (prob: {mcts_policy[mcts_best_move]:.2f})"
        )
        print(f"Expected move: {expected_move}")

        if (
            raw_best_move.column != expected_move
            or mcts_best_move.column != expected_move
        ):
            if raw_best_move.column != expected_move:
                print(
                    f"Raw policy chose column {raw_best_move.column} (prob: {state_policy[raw_best_move]:.2f}) but expected {expected_move}"
                )

            if mcts_best_move.column != expected_move:
                print(
                    f"MCTS policy chose column {mcts_best_move.column} (prob: {mcts_policy[mcts_best_move]:.2f}) but expected {expected_move}"
                )

In [29]:
print_comparisons(final_situations)


Grid played by 0:
   0  1  2  3  4  5  6 
 â”‚ ðŸŸ¡ ðŸŸ¡ ðŸŸ¡ âšª âšª ðŸ”´ âšª â”‚
 â”‚ âšª âšª âšª âšª âšª ðŸ”´ âšª â”‚
 â”‚ âšª âšª âšª âšª âšª ðŸ”´ âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â””â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”˜

Raw value: [0.15005777776241302, -0.15005643665790558]
Raw best move: 3 (prob: 1.00)
MCTS value: 0.9915005777776241
MCTS best move: 3 (prob: 1.00)
Expected move: 3
Grid played by 0:
   0  1  2  3  4  5  6 
 â”‚ ðŸ”´ ðŸ”´ ðŸ”´ âšª âšª ðŸŸ¡ âšª â”‚
 â”‚ âšª âšª âšª âšª âšª ðŸŸ¡ âšª â”‚
 â”‚ âšª âšª âšª âšª âšª ðŸŸ¡ âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â””â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”˜

Raw value: [-0.07849454879760742, 0.07849317044019699]
Raw best move: 5 (prob: 1.00)
MCTS value: 0.989215054512024
MCTS best move: 5 (prob: 1.00)
Expected mov

In [None]:
for situation in final_situations[:2]:
    json = {"config": {"count": 4, "height": 6, "width": 7}, **situation}
    state = State.from_json(json)
    print(model._states_to_tensor([state]))

In [41]:
start_situations = [
    {
        "grid": [
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 0,
    },
    {
        "grid": [
            [0, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 1,
    },
    {
        "grid": [
            [0, -1, -1, -1, 1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 0,
    },
    {
        "grid": [
            [0, -1, -1, -1, 1, -1, -1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 1,
    },
    {
        "grid": [
            [0, -1, -1, -1, 1, -1, 1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 0,
    },
    {
        "grid": [
            [0, -1, -1, -1, 1, -1, 1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 1,
    },
    {
        "grid": [
            [0, -1, -1, -1, 1, 1, 1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 0,
        "expected_move": 3,
    },
    {
        "grid": [
            [0, 0, -1, -1, 1, 1, 1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 1,
        "expected_move": 3,
    },
    {
        "grid": [
            [0, -1, -1, 0, 1, 1, 1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 1,
    },
    {
        "grid": [
            [0, -1, -1, 0, 1, 1, 1],
            [-1, -1, -1, -1, 0, 1, -1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 0,
    },
    {
        "grid": [
            [0, -1, 0, 0, 1, 1, 1],
            [-1, -1, -1, -1, 0, 1, -1],
            [-1, -1, -1, -1, 0, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 1,
    },
    {
        "grid": [
            [0, -1, 0, 0, 1, 1, 1],
            [-1, -1, -1, -1, 0, 1, -1],
            [-1, -1, -1, -1, 0, 1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1],
        ],
        "player": 0,
    },
]

print_comparisons(start_situations)


Grid played by 0:
   0  1  2  3  4  5  6 
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â””â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”˜

Raw value: [0.31763574481010437, -0.317638635635376]
Raw best move: 0 (prob: 0.39)
MCTS value: 0.2959209460951388
MCTS best move: 1 (prob: 0.58)
Expected move: None
Raw policy chose column 0 (prob: 0.39) but expected None
MCTS policy chose column 1 (prob: 0.58) but expected None
Grid played by 1:
   0  1  2  3  4  5  6 
 â”‚ ðŸŸ¡ âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â”‚ âšª âšª âšª âšª âšª âšª âšª â”‚
 â””â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”˜

Raw value: [0.2978994846343994, -0.2979014217853546]
