In [None]:
# install packages
try:
    __import__('numpy')
    print('numpy is already installed')
except ImportError:
    print("package not found installing")
    %pip install numpy

try:
    __import__('pygame')
    print('numpy is already installed')
except ImportError:
    print("package not found installing")
    %pip install pygame

In [5]:
import numpy as np
import pygame
import sys
import math
import random
import threading

# ---------------------- Initialization ----------------------
# Grid size and the one-character tokens stored in each board cell
ROW_COUNT, COLUMN_COUNT = 6, 7
PLAYER, AI, EMPTY = 'X', 'O', ' '

# RGB colour triples used when drawing
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
RED = (255, 0, 0)
YELLOW = (255, 255, 0)
BLUE = (0, 0, 255)

# Pixel geometry: every cell is a SQUARESIZE x SQUARESIZE square and each disc
# leaves a 5 px margin inside its square.
SQUARESIZE = 100
RADIUS = int(SQUARESIZE / 2 - 5)
width = SQUARESIZE * COLUMN_COUNT
height = SQUARESIZE * (ROW_COUNT + 1)  # one extra row on top for the move preview

In [9]:
# ---------------------- Board Functions ----------------------
def create_board():
    """Return a new ROW_COUNT x COLUMN_COUNT NumPy string array with every cell set to EMPTY."""
    # Cells hold the EMPTY token (not zeros) so the `== EMPTY` checks used by
    # the drop and validity helpers work on a fresh board.
    shape = (ROW_COUNT, COLUMN_COUNT)
    return np.full(shape, EMPTY, dtype=str)

def is_valid_location(board, col):
    """Tell whether column ``col`` can still take a piece.

    A column is playable while its cell in the last row (index
    ``ROW_COUNT - 1``) is still EMPTY.
    """
    last_row = board[ROW_COUNT - 1]
    return last_row[col] == EMPTY

def get_next_open_row(board, col):
    """Return the lowest row index in column ``col`` whose cell is EMPTY.

    Rows are scanned from index 0 upwards; ``None`` is returned when no cell
    in the column is EMPTY.
    """
    open_rows = (row for row in range(ROW_COUNT) if board[row][col] == EMPTY)
    return next(open_rows, None)

def drop_piece(board, row, col, piece):
    """Write ``piece`` into ``board[row][col]`` in place.

    Args:
        board: the grid to modify.
        row: target row index, normally the value returned by
            ``get_next_open_row``.
        col: target column index.
        piece: token to store in the cell.

    Raises:
        ValueError: if ``row`` is ``None``, which is what
            ``get_next_open_row`` returns for a full column.
    """
    # On a NumPy board, indexing with None inserts a new axis instead of
    # failing: board[None][0] = piece would overwrite every cell of the grid.
    # Refuse explicitly so a full column can never corrupt the board.
    if row is None:
        raise ValueError("cannot drop a piece: no open row in column %r" % (col,))
    board[row][col] = piece

def winning_move(board, piece):
    """Return True if ``piece`` occupies four consecutive cells in a line.

    Horizontal, vertical and both diagonal directions are checked. The grid
    dimensions are taken from ``board`` itself, so any 2-D board size works;
    a board too small to hold a line of four in some direction simply has no
    win in that direction.

    Args:
        board: 2-D grid of cell tokens (NumPy array or nested sequence).
        piece: the token to look for.

    Returns:
        bool: True when at least one line of four exists, else False.
    """
    mask = np.asarray(board) == piece
    rows, cols = mask.shape
    span = 3  # offset between the first and the last cell of a line of four

    # (row step, column step): horizontal, vertical, rising diagonal, falling diagonal
    for d_row, d_col in ((0, 1), (1, 0), (1, 1), (-1, 1)):
        # Number of valid starting cells along each axis for this direction.
        n_rows = rows - span * abs(d_row)
        n_cols = cols - span * d_col
        if n_rows <= 0 or n_cols <= 0:
            continue
        # A falling diagonal starts `span` rows up so that stepping down stays in bounds.
        first_row = span if d_row < 0 else 0
        # line[i, j] ends up True when all four cells of the line starting at
        # (first_row + i, j) hold `piece`.
        line = np.ones((n_rows, n_cols), dtype=bool)
        for step in range(4):
            r0 = first_row + step * d_row
            c0 = step * d_col
            line &= mask[r0:r0 + n_rows, c0:c0 + n_cols]
        if line.any():
            return True
    return False

# ---------------------- MCTS Components ----------------------
class Node:
    """One position in the MCTS search tree.

    Attributes:
        board: private copy of the position this node represents.
        parent: the node this one was expanded from (None for the root).
        move: the column that was played to reach this node (None for the root).
        player: the piece whose turn it is in this position.
        children: nodes already expanded from this position.
        wins: net result of the simulations that passed through this node,
            always from the AI's point of view (+1 per AI win, -1 per human
            win, 0 per draw).
        visits: number of simulations that passed through this node.
        winner: the piece that has already won in this position, or None.
        untried_moves: playable columns not yet expanded; empty once the
            position is already won.
    """

    def __init__(self, board, parent=None, move=None, player=None):
        """Create a node for ``board``.

        Args:
            board: position to copy into the node.
            parent: parent node, if any.
            move: column played to reach this position, if any.
            player: piece to move in this position; defaults to AI so that
                ``Node(board)`` builds a root for the AI's turn.
        """
        self.board = board.copy()
        self.parent = parent
        self.move = move
        self.player = AI if player is None else player
        self.children = []
        self.wins = 0
        self.visits = 0
        # A position that is already won is terminal: it must not be expanded.
        self.winner = next(
            (piece for piece in (AI, PLAYER) if winning_move(self.board, piece)),
            None,
        )
        if self.winner is None:
            self.untried_moves = [
                col for col in range(COLUMN_COUNT) if is_valid_location(self.board, col)
            ]
        else:
            self.untried_moves = []

    def select_child(self):
        """Return the child with the best UCT score, or None if there are no children.

        ``wins`` is stored from the AI's point of view, so the exploitation
        term is negated when it is the human's turn in this position: each
        side is assumed to pick what is best for itself.
        """
        exploration_weight = 1.41  # roughly sqrt(2)
        sign = 1 if self.player == AI else -1
        # max(..., 1) keeps log() defined if this is called on an unvisited node.
        log_parent_visits = math.log(max(self.visits, 1))
        best_score = -float('inf')
        best_child = None

        for child in self.children:
            if child.visits == 0:
                score = float('inf')
            else:
                exploit = sign * child.wins / child.visits
                explore = exploration_weight * math.sqrt(log_parent_visits / child.visits)
                score = exploit + explore
            if score > best_score:
                best_score = score
                best_child = child

        return best_child

    def add_child(self, move):
        """Expand ``move`` (a column index) and return the new child node.

        The piece dropped belongs to the side to move in this position, and
        the child is created with the other side to move.
        """
        new_board = self.board.copy()
        row = get_next_open_row(new_board, move)
        drop_piece(new_board, row, move, self.player)
        child = Node(new_board, parent=self, move=move, player=_opponent(self.player))
        self.children.append(child)
        if move in self.untried_moves:
            self.untried_moves.remove(move)
        return child


def _opponent(piece):
    """Return the piece of the side playing against ``piece``."""
    return PLAYER if piece == AI else AI


def _rollout(board, player):
    """Play uniformly random moves on ``board`` (modified in place) until the game ends.

    ``player`` moves first. Returns the winning piece, or None when the board
    fills up without a winner.
    """
    while True:
        valid_moves = [col for col in range(COLUMN_COUNT) if is_valid_location(board, col)]
        if not valid_moves:
            return None  # board is full, tie game
        move = random.choice(valid_moves)
        drop_piece(board, get_next_open_row(board, move), move, player)
        if winning_move(board, player):
            return player
        player = _opponent(player)


def mcts(root, simulations=1000):
    """Run Monte Carlo Tree Search from ``root`` and return the chosen column.

    Each simulation performs selection, expansion, a random rollout and
    backpropagation. The returned move is that of the root child with the
    highest visit count.

    Args:
        root: Node for the current position.
        simulations: number of simulations to run.

    Returns:
        The column index to play, or None when the root has no children
        (for example because the board is full or already decided).
    """
    for _ in range(simulations):
        node = root

        # --- Selection: descend while the node is fully expanded ---
        while not node.untried_moves and node.children:
            node = node.select_child()

        # --- Expansion: try one new move from the reached node ---
        if node.untried_moves:
            node = node.add_child(random.choice(node.untried_moves))

        # --- Simulation: random playout, unless the position is already won ---
        if node.winner is not None:
            winner = node.winner
        else:
            winner = _rollout(node.board.copy(), node.player)

        # --- Backpropagation: record the result up to the root ---
        while node is not None:
            node.visits += 1
            if winner == AI:
                node.wins += 1
            elif winner == PLAYER:
                node.wins -= 1
            node = node.parent

    # Choose the move from the root with the highest visit count
    if not root.children:
        return None
    best_child = max(root.children, key=lambda c: c.visits)
    return best_child.move

# ---------------------- Drawing Functions ----------------------
def draw_board(board, message=None):
    """Redraw the whole window for ``board`` and flip it to the display.

    Args:
        board: grid of PLAYER / AI / EMPTY tokens.
        message: optional text rendered in the strip above the grid.
    """
    half = SQUARESIZE / 2
    disc_colors = {PLAYER: RED, AI: YELLOW}

    screen.fill(WHITE)

    # Optional banner in the top strip
    if message:
        banner = FONT.render(message, 1, BLACK)
        screen.blit(banner, (40, 10))

    # First pass: blue grid squares with a black hole in each.
    # This pass must finish before any disc is drawn, otherwise a later
    # square would paint over a disc.
    for col in range(COLUMN_COUNT):
        left = col * SQUARESIZE
        for row in range(ROW_COUNT):
            top = row * SQUARESIZE + SQUARESIZE
            pygame.draw.rect(screen, BLUE, (left, top, SQUARESIZE, SQUARESIZE))
            hole_center = (int(left + half), int(top + half))
            pygame.draw.circle(screen, BLACK, hole_center, RADIUS)

    # Second pass: the discs. The y coordinate is measured from the bottom of
    # the window so that board row 0 appears at the bottom.
    for col in range(COLUMN_COUNT):
        for row in range(ROW_COUNT):
            disc_color = disc_colors.get(board[row][col])
            if disc_color is None:
                continue
            disc_center = (int(col * SQUARESIZE + half),
                           height - int(row * SQUARESIZE + half))
            pygame.draw.circle(screen, disc_color, disc_center, RADIUS)

    pygame.display.update()

# ---------------------- Main Game Loop ----------------------
def main():
    """Run one interactive Connect 4 game: human (PLAYER) against the MCTS AI.

    The starting side is chosen at random. The human moves with the mouse;
    the AI move is computed by ``mcts`` in a worker thread while this loop
    keeps draining the pygame event queue. When the game is decided the final
    board is shown with a result banner for three seconds, then pygame is shut
    down. Closing the window ends the function immediately.

    Uses the module-level ``screen`` surface and writes the module-level
    ``ai_col`` variable.
    """
    global ai_col

    board = create_board()
    game_over = False
    end_message = None  # result banner; stays None until the game is decided
    turn = random.choice([PLAYER, AI])
    clock = pygame.time.Clock()
    ai_col = None  # Initialize global AI move variable

    def board_is_full():
        # True when no column can take another piece.
        return all(not is_valid_location(board, col) for col in range(COLUMN_COUNT))

    # Draw the initial board so the player sees it immediately
    draw_board(board)

    while not game_over:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                return  # Return instead of sys.exit()

            # Human player's turn
            if turn == PLAYER and not game_over:
                if event.type == pygame.MOUSEMOTION:
                    # Draw the moving piece on top of the board
                    pygame.draw.rect(screen, BLACK, (0, 0, width, SQUARESIZE))
                    posx = event.pos[0]
                    pygame.draw.circle(screen, RED, (posx, int(SQUARESIZE / 2)), RADIUS)
                    pygame.display.update()

                if event.type == pygame.MOUSEBUTTONDOWN:
                    pygame.draw.rect(screen, BLACK, (0, 0, width, SQUARESIZE))
                    posx = event.pos[0]
                    col = int(math.floor(posx / SQUARESIZE))

                    if is_valid_location(board, col):
                        row = get_next_open_row(board, col)
                        drop_piece(board, row, col, PLAYER)
                        if winning_move(board, PLAYER):
                            end_message = "You win!"
                            game_over = True
                        turn = AI
                        # Single redraw that keeps the banner (if any) on screen.
                        draw_board(board, end_message)

        # AI's turn using MCTS (skipped on a full board: nothing to search)
        if turn == AI and not game_over and not board_is_full():
            # Forget the previous turn's result so a failed search can never
            # replay a stale column.
            ai_col = None

            def run_mcts():
                global ai_col
                root = Node(board)
                ai_col = mcts(root, simulations=2000)

            # Run MCTS in a separate thread to keep the UI responsive
            ai_thread = threading.Thread(target=run_mcts)
            ai_thread.start()

            while ai_thread.is_alive():
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        pygame.quit()
                        return  # Return instead of sys.exit()
                pygame.time.wait(100)

            if ai_col is not None and is_valid_location(board, ai_col):
                row = get_next_open_row(board, ai_col)
                drop_piece(board, row, ai_col, AI)
                if winning_move(board, AI):
                    end_message = "AI wins!"
                    game_over = True
                turn = PLAYER
                draw_board(board, end_message)

        # A full board is a tie only if nobody has just won on the last cell.
        if not game_over and board_is_full():
            end_message = "It's a tie!"
            game_over = True
            draw_board(board, end_message)

        if game_over:
            # Leave the final board and banner visible before closing.
            pygame.time.wait(3000)
            pygame.quit()
            return  # Return instead of sys.exit()

        clock.tick(30)

# Initialize Pygame display settings for Jupyter.
# `screen` and `FONT` are module-level names read by draw_board() and main(),
# so they must exist before main() is called below.
pygame.init()
screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("Connect 4 - MCTS vs Human")
FONT = pygame.font.SysFont("monospace", 50)

# Run the game loop (in a Jupyter Notebook cell).
# main() calls pygame.quit() before it returns; re-running this cell calls
# pygame.init() again and starts a new game.
main()