In [1]:
import numpy as np
import torch

from alphazero_simple.connect4_game import Connect4Game
from alphazero_simple.resnet import ResNet

# Build the Connect 4 environment and read the dimensions the network is sized from.
game = Connect4Game()
board_size = game.get_board_size()
action_size = game.get_action_size()

# Architecture hyper-parameters passed to ResNet; the printed model shows
# 9 residual blocks with 128 channels each, so these must match the checkpoint.
NUM_RES_BLOCKS = 9
NUM_CHANNELS = 128

# Create model instance
model = ResNet(board_size, action_size, NUM_RES_BLOCKS, NUM_CHANNELS)

# Load saved weights
# NOTE(review): absolute, machine-specific path — consider making this configurable.
checkpoint_path = "/Users/pveron/Code/alphazero-implementation/lightning_logs/alphazero_less_simple/run_278_ResNet_iter200_episodes100_sims100/checkpoints/epoch=1919-step=3941250.ckpt"
# map_location="cpu" lets the checkpoint deserialize even when the device it was
# saved from (e.g. a GPU) is not available on the machine running this notebook.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust;
# switch to weights_only=True once the checkpoint contents are confirmed compatible.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
# Inference mode (fixes BatchNorm statistics); as the last expression this also
# displays the model summary in the notebook.
model.eval()


  checkpoint = torch.load(checkpoint_path)


ResNet(
  (input_conv): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (residual_blocks): ModuleList(
    (0-8): 9 x ResBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (policy_head): Sequential(
    (0): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Flatten(start_dim=1, end_dim=-1)
    (4): Linear(in_features=1344, out_features=7, bias=True)
  )
  (value_head): Sequential(
    (0): Conv2d(128, 3, kern

In [11]:
import sys

from alphazero_simple.monte_carlo_tree_search import MCTS


def print_board(board: np.ndarray):
    """Print a 2-D board to stdout, followed by a ruler of column indices.

    Cells equal to 0 are drawn as ".", cells equal to 1 as "X", and any other
    value as "O". Each row is framed by "|" characters; below the grid a dashed
    separator and a line with the column numbers are printed.

    Args:
        board: 2-D array whose second dimension is the number of columns.
    """

    def glyph(cell):
        # Map one cell value to the character shown for it.
        if cell == 0:
            return "."
        return "X" if cell == 1 else "O"

    for cells in board:
        print("| " + "".join(glyph(cell) + " " for cell in cells) + "|")

    column_count = board.shape[1]
    # Two characters per column plus the three framing characters ("| " and "|").
    print("-" * (column_count * 2 + 3))
    print("| " + "".join(f"{index} " for index in range(column_count)) + "|")


def get_human_move(valid_moves):
    """Ask the user for a column until a playable one is entered.

    The accepted range is derived from ``len(valid_moves)`` rather than being
    fixed, so the prompt and the bounds check always agree with the board.

    Args:
        valid_moves: Sequence with one entry per column; a truthy entry means
            the column can still be played.

    Returns:
        The chosen column index as an ``int``.

    Raises:
        SystemExit: If the user submits an empty line (used to quit the game).
    """
    last_column = len(valid_moves) - 1
    while True:
        raw = input(f"Enter your move (0-{last_column}):")

        # An empty line means "quit". Test the input itself instead of matching
        # the text of the ValueError raised by int(), which is not a stable API.
        if raw == "":
            print("\nExiting game...")
            sys.exit()

        try:
            move = int(raw)
        except ValueError:
            print(f"Please enter a number between 0 and {last_column}.")
            continue

        # The lower bound also prevents negative indices from wrapping around.
        if 0 <= move <= last_column and valid_moves[move]:
            return move
        print("Invalid move. Try again.")


# Set up the tree search and the starting position.
mcts = MCTS(game, model, 100)
state = game.get_init_board()
current_player = -1  # 1 = human (X), -1 = AI (O); starting at -1 means the AI opens

print("\nGame starts! You are X, AI is O")
print_board(state)

while True:
    valid_moves = game.get_valid_moves(state)

    if current_player != 1:
        # AI's turn: search from the canonical view of the position.
        print("AI is thinking...")
        canonical_board = game.get_canonical_board(state, current_player)
        # NOTE(review): the literal 1 is presumably the side to move in the
        # canonical view — confirm against MCTS.run.
        root = mcts.run(canonical_board, 1)
        # Show the search's root value next to the raw network estimate.
        _, [predicted_value] = model.predict([canonical_board])
        print("root.value():", root.value(), "predicted_value:", predicted_value)
        action = root.select_action(temperature=0)
    else:
        # Human's turn
        action = get_human_move(valid_moves)

    # Apply the move; get_next_state also hands back the updated current_player.
    state, current_player = game.get_next_state(state, current_player, action)
    print_board(state)

    # A reward of None means the game is still running.
    reward = game.get_reward_for_player(state, current_player)
    if reward is None:
        continue

    if reward == 1 or reward == -1:
        # The human wins when current_player won and is the human, or when
        # current_player lost and is not the human.
        print(
            "Game Over! You win!"
            if (reward == 1) == (current_player == 1)
            else "Game Over! AI wins!"
        )
    else:
        print("Game Over! It's a draw!")
    break



Game starts! You are X, AI is O
| . . . . . . . |
| . . . . . . . |
| . . . . . . . |
| . . . . . . . |
| . . . . . . . |
| . . . . . . . |
-----------------
| 0 1 2 3 4 5 6 |
AI is thinking...
root.value(): 0.341232116445899 predicted_value: 0.28778875
| . . . . . . . |
| . . . . . . . |
| . . . . . . . |
| . . . . . . . |
| . . . . . . . |
| . . . O . . . |
-----------------
| 0 1 2 3 4 5 6 |
invalid literal for int() with base 10: ''

Exiting game...


SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
