In [3]:
import pygame
import sys
import random
import numpy as np
import torch
import copy
from backgammon import BackgammonBoard, Game
from BackModel import ResidualBlock, BackModel
from MCTS import MCTS_Searcher, MCTSNode

pygame 2.6.1 (SDL 2.28.4, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [None]:
# ---------------- CONFIG ----------------
# Window size in pixels.
WIDTH = 1000
HEIGHT = 600

# Palette (RGB).
BOARD_COLOR = (181, 136, 99)   # margin around the playing surface
POINT_COLOR = (240, 217, 181)  # playing-surface fill
TEXT_COLOR = (0, 0, 0)

FPS = 30

# Number of MCTS simulations per AI move; passed to MCTS_Searcher as
# ``n_simulations`` (it is a simulation count, not a tree depth).
MCTS_DEPTH = 1000

# Build the network and load its trained weights onto the CPU.
# NOTE: torch.load unpickles the checkpoint -- only load files you trust.
model = BackModel(num_resnets=4, num_skips=4)
state = torch.load("model.pth", map_location="cpu")
model.load_state_dict(state)
model.eval()  # inference mode for every forward pass below
MODEL = model

# ---------------- PYGAME ----------------
pygame.init()
FONT, BIG_FONT = (pygame.font.SysFont("Arial", size) for size in (22, 40))

screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Backgammon Game - You vs AI")

clock = pygame.time.Clock()


def draw_board(game, eval_text="Bot's evaluation of position: N/A"):
    """Draw the board, checkers and status text for *game* onto ``screen``.

    Layout: points 12..23 run left-to-right along the top half, points
    11..0 run left-to-right along the bottom half.  Positive board counts
    are drawn as white checkers, negative counts as black.  The caller is
    responsible for ``pygame.display.flip()``.

    Args:
        game: the Game being shown; reads ``board.board``, ``current_player``,
            ``dice``, ``broken_pieces`` and ``collected_pieces``.
        eval_text: banner drawn in the top-left margin.
    """
    screen.fill(BOARD_COLOR)
    pygame.draw.rect(screen, POINT_COLOR, (50, 50, WIDTH - 100, HEIGHT - 100))

    # Point outlines: top row (points 12-23) and bottom row (points 11-0).
    point_width = (WIDTH - 100) // 12
    half_height = (HEIGHT - 100) // 2
    for i in range(12):
        left = 50 + i * point_width
        pygame.draw.rect(screen, (200, 200, 200), (left, 50, point_width, half_height), 1)
        pygame.draw.rect(screen, (200, 200, 200), (left, HEIGHT // 2, point_width, half_height), 1)

    # Checkers.  A stack's centres may span the half-board minus one radius
    # at each end; tall stacks are squeezed to fit there instead of spilling
    # into the opposite row (stacks of up to 10 keep the 25 px spacing).
    radius = 10
    max_span = half_height - 2 * radius
    for i in range(24):
        count = game.board.board[i]
        n_checkers = int(abs(count))
        if i < 12:  # bottom row, mirrored left-right
            x = 50 + (11 - i) * point_width + point_width // 2
            y_start, direction = HEIGHT - 60, -1
        else:  # top row
            x = 50 + (i - 12) * point_width + point_width // 2
            y_start, direction = 60, 1

        spacing = 25 if n_checkers <= 1 else min(25, max_span / (n_checkers - 1))
        color = (255, 255, 255) if count > 0 else (0, 0, 0)
        for j in range(n_checkers):
            y = int(round(y_start + direction * spacing * j))
            pygame.draw.circle(screen, color, (x, y), radius)

    # Status text.  Spread over both margins (left-aligned at x=50,
    # right-aligned at WIDTH-50) so the strings neither overlap each other
    # nor run off the right edge of the window.
    info_text = f"Player: {'You' if game.current_player == 1 else 'AI'} | Dice: {game.dice}"
    broken_text = f"Broken (Player): {game.broken_pieces[1]} | Broken (AI): {game.broken_pieces[-1]}"
    borne_text = f"Borne Off (Player): {game.collected_pieces[1]} | Borne Off (AI): {game.collected_pieces[-1]}"

    def blit(text, y, right_aligned=False):
        surface = FONT.render(text, True, TEXT_COLOR)
        x = WIDTH - 50 - surface.get_width() if right_aligned else 50
        screen.blit(surface, (x, y))

    blit(eval_text, 20)
    blit(info_text, 20, right_aligned=True)
    blit(broken_text, HEIGHT - 40)
    blit(borne_text, HEIGHT - 40, right_aligned=True)


def evaluate_position(game):
    """Score *game*'s current position with the value output of ``MODEL``.

    The network's first output (policy logits) is discarded, and the
    forward pass runs with gradient tracking disabled.

    Returns:
        The scalar value output as a Python float.  Its perspective and
        range depend on how BackModel was trained (not visible here).
    """
    with torch.no_grad():
        batch = game.get_input_matrix().unsqueeze(0)  # prepend a batch axis of size 1
        _policy, value = MODEL(batch)
    return float(value.item())


_CAPTION = "Backgammon Game - You vs AI"


def _eval_text(game):
    """Return the evaluation banner shown at the top of the board."""
    return f"Bot's evaluation of position: {evaluate_position(game):.2f}"


def _quit():
    """Shut pygame down and exit the process."""
    pygame.quit()
    sys.exit()


def _refresh_after_move(game):
    """Pass the turn if the mover has no dice or legal moves left.

    Returns the evaluation banner for the resulting position.
    """
    # Short-circuit keeps get_legal_moves() from being called once the dice
    # are used up, matching the original control flow.
    if len(game.dice) == 0 or game.get_legal_moves() == []:
        game.switch_player()
    return _eval_text(game)


def _point_at(pos, bottom_base=11):
    """Map a screen position to a board point index, or None if off-board.

    Only the inner playing area (strictly inside the 50 px margin) counts.
    Top-row columns map to ``12 + col``; bottom-row columns to
    ``bottom_base - col`` (11 matches the layout drawn by draw_board).
    """
    x, y = pos
    if not (50 < x < WIDTH - 50 and 50 < y < HEIGHT - 50):
        return None
    col = (x - 50) // ((WIDTH - 100) // 12)
    if y > HEIGHT // 2:  # bottom row
        return bottom_base - col
    return 12 + col


def _find_drop_move(legal_moves, start, drop_point):
    """Return the legal ``(start, end, die)`` move for a drag-and-drop, or None.

    Dropping on a board point (``drop_point`` is an index) selects a normal
    move from ``start`` ending on that point.  Dropping outside the playing
    area (``drop_point`` is None) selects a bear-off move, i.e. one whose end
    is ``< 0`` or ``>= 24``.
    """
    for (s, e, d) in legal_moves:
        if s != start:
            continue
        bearing_off = e >= 24 or e < 0
        if drop_point is None:
            if bearing_off:
                return (s, e, d)
        elif not bearing_off and e == drop_point:
            return (s, e, d)
    return None


def _highlight_rect(target_point, point_width):
    """Return the rect outlining a legal destination while dragging.

    Bear-off destinations are shown as a strip in the right margin: top half
    for ends >= 24, bottom half for ends < 0.
    """
    half_height = (HEIGHT - 100) // 2
    if target_point >= 24:
        return (WIDTH - 50, 50, 50, half_height)
    if target_point < 0:
        return (WIDTH - 50, HEIGHT // 2, 50, half_height)
    if target_point < 12:  # bottom row, mirrored to match draw_board
        return (50 + (11 - target_point) * point_width, HEIGHT // 2, point_width, half_height)
    return (50 + (target_point - 12) * point_width, 50, point_width, half_height)


def _play_game(searcher):
    """Run one game until ``game.game_over`` and return the finished Game.

    The human is ``current_player == 1`` and moves checkers with positive
    board counts by drag-and-drop; the AI moves via ``searcher``.
    """
    pygame.display.set_caption(_CAPTION)
    game = Game()
    game.roll_dice()
    game.get_legal_moves()  # return value unused; presumably refreshes game.legal_moves
    eval_text = _eval_text(game)

    point_width = (WIDTH - 100) // 12
    dragging_piece = False
    drag_start_point = None
    drag_pos = (0, 0)
    highlight_targets = []  # legal destinations of the piece being dragged

    while not game.game_over:
        draw_board(game, eval_text)
        if dragging_piece:
            for target_point in highlight_targets:
                pygame.draw.rect(screen, (0, 255, 0), _highlight_rect(target_point, point_width), 4)
            pygame.draw.circle(screen, (255, 0, 0), drag_pos, 15)
        pygame.display.flip()
        clock.tick(FPS)

        if game.current_player == 1:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    _quit()
                if game.current_player != 1:
                    # A move earlier in this batch ended the turn; drop the
                    # remaining input rather than apply it with the AI's dice.
                    continue

                if event.type == pygame.MOUSEBUTTONDOWN:
                    if game.broken_pieces[1] > 0:
                        # Re-enter from the bar by clicking the target point.
                        # NOTE(review): the bottom row maps to 12 - col here
                        # (11 - col elsewhere) and the move is matched as
                        # (-1, point, point); this appears to follow Game's
                        # bar-entry encoding -- confirm against backgammon.py.
                        point = _point_at(event.pos, bottom_base=12)
                        if point is not None and any(
                            s == -1 and e == point and d == point
                            for (s, e, d) in game.legal_moves
                        ):
                            game.play_one_move(-1, point, point)
                            pygame.display.set_caption(_CAPTION)
                            eval_text = _refresh_after_move(game)
                    else:
                        point = _point_at(event.pos)
                        if point is not None and game.board.board[point] > 0:
                            dragging_piece = True
                            drag_start_point = point
                            drag_pos = event.pos
                            highlight_targets = [e for (s, e, d) in game.legal_moves if s == point]

                elif event.type == pygame.MOUSEMOTION and dragging_piece:
                    drag_pos = event.pos

                elif event.type == pygame.MOUSEBUTTONUP and dragging_piece:
                    dragging_piece = False
                    highlight_targets = []
                    move = _find_drop_move(game.legal_moves, drag_start_point, _point_at(event.pos))
                    if move is None:
                        pygame.display.set_caption("Illegal move!")
                    else:
                        game.play_one_move(*move)
                        pygame.display.set_caption(_CAPTION)
                        eval_text = _refresh_after_move(game)

        else:  # AI turn
            # Honour QUIT and discard clicks made while the AI is to move, so
            # they are not replayed as human input on the next turn.
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    _quit()
            pygame.time.delay(500)

            root, best_child = searcher.search(copy.deepcopy(game))
            if root != best_child:  # root returned as-is means no move was chosen
                s, e, d = best_child.last_move
                game.play_one_move(s, e, d)
            eval_text = _refresh_after_move(game)

        game.check_game_over()

    return game


def _wait_for_restart(game):
    """Show the winner and a Restart button; return once it is clicked."""
    screen.fill(BOARD_COLOR)
    # Assumes check_game_over() returns 1 when the human side has won -- confirm.
    winner_text = f"Winner: {'You' if game.check_game_over() == 1 else 'AI'}"
    text_surface = BIG_FONT.render(winner_text, True, TEXT_COLOR)
    screen.blit(text_surface, (WIDTH // 2 - 100, HEIGHT // 2 - 50))

    button_rect = pygame.Rect(WIDTH // 2 - 75, HEIGHT // 2 + 50, 150, 50)
    pygame.draw.rect(screen, (0, 200, 0), button_rect)
    button_text = FONT.render("Restart", True, (255, 255, 255))
    screen.blit(button_text, (WIDTH // 2 - button_text.get_width() // 2, HEIGHT // 2 + 65))
    pygame.display.flip()

    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                _quit()
            if event.type == pygame.MOUSEBUTTONDOWN and button_rect.collidepoint(event.pos):
                return
        clock.tick(FPS)  # throttle the wait loop instead of spinning a CPU core


def main():
    """Play games against the AI until the window is closed.

    Each Restart click starts a fresh game with a fresh searcher.  Restart is
    a loop rather than a recursive call, so the stack does not grow.
    """
    while True:
        searcher = MCTS_Searcher(MODEL, n_simulations=MCTS_DEPTH)
        game = _play_game(searcher)
        _wait_for_restart(game)





# Entry point: start the interactive game loop.
if __name__ == "__main__":
    main()