In [8]:
%reload_ext autoreload
%autoreload 2

import sys
from pathlib import Path

src_path = Path("../..").resolve()
# Make the project root importable (the next cell imports `alphazero_simple`).
# The membership check keeps re-runs of this cell from piling up duplicates.
src_entry = str(src_path)
if src_entry not in sys.path:
    sys.path.append(src_entry)


In [None]:
import torch

from alphazero_simple.connect4_game import Connect4Game
from alphazero_simple.monte_carlo_tree_search import MCTS
from alphazero_simple.resnet import ResNet

import os

# ResNet hyper-parameters; they must match the run that produced the checkpoint,
# otherwise load_state_dict below rejects the weights.
# NOTE(review): presumably (number of residual blocks, channel width) — confirm
# against ResNet.__init__.
RESNET_DEPTH = 9
RESNET_WIDTH = 128
# NOTE(review): presumably the number of simulations per MCTS search — confirm
# against MCTS.__init__.
MCTS_SIMULATIONS = 600

# Default checkpoint; override with the ALPHAZERO_CHECKPOINT environment variable
# so the notebook runs on machines other than the author's.
DEFAULT_CHECKPOINT_PATH = "/Users/pveron/Code/alphazero-implementation/lightning_logs/alphazero_less_simple/run_278_ResNet_iter200_episodes100_sims100/checkpoints/epoch=1919-step=3941250.ckpt"
# Leading prefix carried by parameter names in the checkpoint's state_dict
# (presumably the LightningModule attribute holding the network).
STATE_DICT_PREFIX = "model."

# Initialize game and get dimensions
game = Connect4Game()
board_size = game.get_board_size()
action_size = game.get_action_size()

# Preferred accelerator: CUDA, then Apple MPS, then CPU.
# NOTE(review): the model below is never moved to `device`; it stays on CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# Create model instance
model = ResNet(board_size, action_size, RESNET_DEPTH, RESNET_WIDTH)

mcts = MCTS(game, model, MCTS_SIMULATIONS)

# Load saved weights. map_location="cpu" matches where the model lives and lets
# a checkpoint saved on a GPU load on a machine without one.
checkpoint_path = os.environ.get("ALPHAZERO_CHECKPOINT", DEFAULT_CHECKPOINT_PATH)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
state_dict = checkpoint["state_dict"]
# Strip only the *leading* prefix; str.replace would also rewrite any
# "model." that appears later inside a parameter name.
new_state_dict = {
    (key[len(STATE_DICT_PREFIX):] if key.startswith(STATE_DICT_PREFIX) else key): tensor
    for key, tensor in state_dict.items()
}
model.load_state_dict(new_state_dict)
model.eval()

In [10]:
def print_board(board):
    """Print a Connect Four board as ASCII art.

    Each cell is drawn as "." when it equals 0, "X" when it equals 1 and "O"
    for any other value. A dashed separator and the column indices follow
    the grid.

    Args:
        board: 2-D array-like exposing ``.shape`` (e.g. a NumPy array);
            rows are printed in stored order.
    """
    for row in board:
        symbols = "".join(
            f"{'.' if cell == 0 else 'X' if cell == 1 else 'O'} " for cell in row
        )
        print("| " + symbols + "|")

    n_columns = board.shape[1]
    print("-" * (2 * n_columns + 3))
    column_labels = "".join(f"{column} " for column in range(n_columns))
    print("| " + column_labels + "|")

In [11]:
def _grid(*rows):
    """Convert ASCII rows ("." empty, "X" -> 1, "O" -> -1) to a list-of-lists grid."""
    symbol_value = {".": 0, "X": 1, "O": -1}
    return [[symbol_value[symbol] for symbol in row] for row in rows]


# Tactical positions. Row 0 of each grid is the top of the board (pieces rest
# on the bottom row); "player" is the side to move and "expected_move" the
# column the engine should pick.
test_cases = [
    # Empty board: X opens in the centre column.
    dict(
        grid=_grid(
            ".......",
            ".......",
            ".......",
            ".......",
            ".......",
            ".......",
        ),
        player=1,
        expected_move=3,
    ),
    # X completes a vertical four in column 5.
    dict(
        grid=_grid(
            ".......",
            ".......",
            ".......",
            ".....X.",
            ".....X.",
            "OOO..X.",
        ),
        player=1,
        expected_move=5,
    ),
    # X completes a vertical four in column 1.
    dict(
        grid=_grid(
            ".......",
            ".......",
            ".......",
            ".X...O.",
            ".X...O.",
            ".X...O.",
        ),
        player=1,
        expected_move=1,
    ),
    # X completes a horizontal four on the bottom row with column 3.
    dict(
        grid=_grid(
            ".......",
            ".......",
            ".......",
            ".....O.",
            ".....O.",
            "XXX..O.",
        ),
        player=1,
        expected_move=3,
    ),
    # X completes a bottom-row four (column 1 here; column 5 also wins).
    dict(
        grid=_grid(
            ".......",
            ".......",
            ".......",
            ".......",
            "..OOO..",
            "..XXX..",
        ),
        player=1,
        expected_move=1,
    ),
    # O completes a horizontal four on the bottom row with column 5.
    dict(
        grid=_grid(
            ".......",
            ".......",
            ".......",
            ".......",
            "..XXX..",
            ".XOOO..",
        ),
        player=-1,
        expected_move=5,
    ),
    # O completes a vertical four in column 5.
    dict(
        grid=_grid(
            ".......",
            ".......",
            ".......",
            ".....O.",
            ".....OX",
            ".X.X.OX",
        ),
        player=-1,
        expected_move=5,
    ),
    # O completes a vertical four in column 5 (X threatens column 2).
    dict(
        grid=_grid(
            ".......",
            ".......",
            ".......",
            ".....O.",
            ".....O.",
            "XX.X.OX",
        ),
        player=-1,
        expected_move=5,
    ),
    # Crowded position: X wins on the diagonal via column 5
    # (column 3 also wins horizontally).
    dict(
        grid=_grid(
            "O......",
            "O...X..",
            "X.O.O..",
            "X.X.XX.",
            "XOXXOOO",
            "OOXOOXX",
        ),
        player=1,
        expected_move=5,
    ),
    # X completes a vertical four in column 1 (O threatens column 3).
    dict(
        grid=_grid(
            ".......",
            ".......",
            ".......",
            "OX.O...",
            "XX.O...",
            "XX.O..O",
        ),
        player=1,
        expected_move=1,
    ),
]


In [None]:
import numpy as np


def mcts_visit_distribution(node, action_size):
    """Turn an MCTS root's child visit counts into a probability vector.

    Args:
        node: root returned by ``mcts.run``; ``node.children`` maps an action
            index to a child exposing ``visit_count``.
        action_size: length of the returned vector.

    Returns:
        Float array of length ``action_size``. It sums to 1 when any child was
        visited; otherwise it is all zeros (avoids a 0/0 division yielding NaN).
    """
    counts = np.zeros(action_size, dtype=float)
    for action, child in node.children.items():
        counts[action] = child.visit_count
    total = counts.sum()
    return counts / total if total > 0 else counts


# Compare the raw network prediction and the MCTS choice on every test position.
for case_number, test_case in enumerate(test_cases, start=1):
    board = np.array(test_case["grid"])
    player = test_case["player"]
    # Presumably re-expresses the board from `player`'s point of view.
    canonical_board = game.get_canonical_board(board, player)

    # model.predict works on a batch of boards; unpack the single result.
    [policy], [value] = model.predict([canonical_board])

    # Network's greedy move (argmax of the raw policy, illegal moves included).
    predicted_move = np.argmax(policy)

    # NOTE(review): canonical_board is already oriented for `player`, yet
    # `player` is also passed as the side to move — confirm MCTS.run expects
    # this rather than always 1.
    node = mcts.run(canonical_board, player)
    action_probs = mcts_visit_distribution(node, game.get_action_size())

    print(f"\nTest Case {case_number}:")
    print(f"Player: {'X' if player == 1 else 'O'}")
    print_board(board)
    print(f"Expected move: {test_case['expected_move']}")
    print(f"Predicted move: {predicted_move}")
    print(f"MCTS move: {node.select_action(temperature=0)}")
    print(f"Move probabilities: {policy.round(3)}")
    print(f"MCTS move probabilities: {action_probs.round(3)}")
    print(f"Predicted value: {value:.3f}")


In [None]:
def _suggest_move(game, board, player):
    """Print the network's policy/value for `player` and return its best legal column.

    The board is canonicalized for `player` first, as the evaluation cell does
    before calling the network. Columns whose top cell (row 0) is occupied are
    masked out so a full column is never chosen.
    """
    canonical_board = game.get_canonical_board(board, player)
    # model.predict works on a batch of boards; unpack the single result.
    [policy], [value] = model.predict([canonical_board])
    print(f"policy: {policy.round(3)}, value: {value:.3f}")
    legal = np.asarray(board[0]) == 0
    return int(np.argmax(np.where(legal, policy, -np.inf)))


def _ask_human_move(board):
    """Prompt until the user enters a playable column.

    Returns:
        The chosen column index, or None if the user submits an empty line
        (the signal to quit the game).
    """
    last_column = board.shape[1] - 1
    while True:
        raw = input(f"Your turn! Enter column (0-{last_column}): ")
        if not raw.strip():
            return None
        try:
            move = int(raw)
        except ValueError:
            print(f"Please enter a number between 0 and {last_column}.")
            continue
        if 0 <= move <= last_column and board[0][move] == 0:  # column not full
            return move
        print("Invalid move. Try again.")


def _game_over_message(game, board, player_to_move, human_player):
    """Return the end-of-game message, or None if play continues."""
    reward = game.get_reward_for_player(board, player_to_move)
    if reward in (1, -1):
        # NOTE(review): reward is presumably from `player_to_move`'s
        # perspective, so -1 means the player who just moved has won.
        winner = player_to_move if reward == 1 else -player_to_move
        return "You win!" if winner == human_player else "You lose!"
    if (np.asarray(board[0]) != 0).all():  # every column is full
        return "It's a draw!"
    return None


def play_against_model():
    """Play an interactive game of Connect Four against the network.

    The human plays X (1) and moves first; the model plays O (-1) and picks
    the legal column with the highest raw policy probability (no tree search).
    On the human's turn the model's suggested move is printed as a hint.
    Entering an empty line quits.
    """
    game = Connect4Game()
    board = game.get_init_board()
    human_player = 1  # rendered as "X" by print_board
    model_player = -1  # rendered as "O"
    current_player = 1  # X moves first

    while True:
        print("\nCurrent board:")
        print_board(board)

        suggested = _suggest_move(game, board, current_player)
        if current_player == model_player:
            move = suggested
            print(f"Model plays column {move}")
        else:
            print(f"User should play column {suggested}")
            move = _ask_human_move(board)
            if move is None:
                return

        # Make the move; current_player becomes the side to move next.
        board, current_player = game.get_next_state(board, current_player, move)

        message = _game_over_message(game, board, current_player, human_player)
        if message is not None:
            print("\nFinal board:")
            print_board(board)
            print(message)
            break


# Start the game
play_against_model()