In [1]:
import math
import os
import random
import tkinter as tk
from base64 import b64encode
from tkinter import messagebox

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from IPython.display import HTML
from scipy.signal import convolve2d
from torch import nn

plt.style.use('fivethirtyeight')


# Hyperparameters and runtime settings shared by the network, the tree search and the training loop
config_map = {
    # Always a torch.device, so the value has the same type on CPU-only and CUDA machines
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'num_filters': 128,              # Number of convolutional filters used in residual blocks
    'num_residual_blocks': 8,        # Number of residual blocks used in network
    'exploration_constant': 2,       # Exploration constant used in PUCT calculation
    'selection_temperature': 1.25,   # Selection temperature. A greater temperature is a more uniform distribution
    'dirichlet_alpha': 1.,           # Alpha parameter for Dirichlet noise. Larger values mean more uniform noise
    'dirichlet_epsilon': 0.25,       # Weight of dirichlet noise
    'learning_rate': 0.001,          # Adam learning rate
    'training_epochs': 1,            # How many full training epochs
    'games_per_epoch': 1,            # How many self-played games per epoch
    'minibatch_size': 128,           # Size of each minibatch used in learning update
    'num_minibatches': 4,            # How many minibatches to accumulate per learning step
    'mcts_initial_iterations': 50,   # Number of Monte Carlo tree search iterations initially
    'mcts_max_iterations': 150,      # Maximum number of MCTS iterations
    'mcts_search_increment': 1,      # After each epoch, how much should search iterations be increased by
}

# Attribute-style view over a plain settings dictionary
class Config:
    "Expose every key/value pair of a mapping as an attribute on the instance."

    def __init__(self, mapping):
        # Each entry is a (name, value) pair; unpack it straight into setattr
        for entry in mapping.items():
            setattr(self, *entry)

# Shared settings object: every key of config_map becomes an attribute (e.g. config.device)
config = Config(config_map)

In [2]:
class Connect4Engine:
    """Engine for Connect 4 game with methods for game-related tasks.

    A board is a (num_rows, num_cols) integer array holding 1 and -1 for the two
    players' pieces and 0 for an empty cell. A dropped piece lands on the empty
    cell with the largest row index in its column.
    """

    def __init__(self):
        self.num_rows = 6
        self.num_cols = 7

    def get_next_state(self, current_state, selected_action, to_play=1):
        """Return a copy of the board with `to_play`'s piece dropped into column `selected_action`.

        Raises AssertionError if the position is already won, the board is full,
        or the column is not a valid action. `current_state` is not modified.
        """
        # Preconditions
        assert self.evaluate(current_state) == 0
        assert np.sum(abs(current_state)) != self.num_rows * self.num_cols
        assert selected_action in self.get_valid_actions(current_state)

        # The piece settles on the empty cell with the largest row index
        row_index = np.where(current_state[:, selected_action] == 0)[0][-1]

        # Apply the action on a copy so the caller's board is untouched
        new_state = current_state.copy()
        new_state[row_index, selected_action] = to_play
        return new_state

    def get_valid_actions(self, current_state):
        """Return an integer array with the indices of the columns that can still be played.

        The array is empty once either player has four in a row.
        """
        # If the game is over, there are no valid moves.
        # The empty result is integer-typed so it stays usable as an index array.
        if self.evaluate(current_state) != 0:
            return np.array([], dtype=int)

        # A column is playable while it holds fewer than num_rows pieces
        col_sums = np.sum(np.abs(current_state), axis=0)
        return np.where((col_sums // self.num_rows) == 0)[0]

    def evaluate(self, current_state):
        "Evaluate the current position. Returns 1 for player 1 win, -1 for player 2, and 0 otherwise."
        # Summing four neighbouring cells gives +4 / -4 exactly when one player owns all of them
        line_kernel = np.ones((1, 4), dtype=int)
        diagonal_kernel = np.eye(4, dtype=int)
        kernels = (
            line_kernel,                  # horizontal
            line_kernel.T,                # vertical
            diagonal_kernel,              # main diagonal
            np.fliplr(diagonal_kernel),   # anti diagonal
        )
        window_sums = [convolve2d(current_state, kernel, mode='valid') for kernel in kernels]

        # Player 1 is checked first, matching the documented return priority
        if any((sums == 4).any() for sums in window_sums):
            return 1
        if any((sums == -4).any() for sums in window_sums):
            return -1

        # No winner
        return 0

    def play(self, current_state, selected_action, to_play=1):
        """Execute an action in the current state. Return the next state, reward, and termination flag.

        The reward is `evaluate` of the resulting board. The game is over when that
        reward is non-zero or when every cell of the board is occupied.
        """
        # Obtain the new state and reward
        next_state = self.get_next_state(current_state, selected_action, to_play)
        reward = self.evaluate(next_state)

        # A draw requires a completely full board; stopping one piece early would
        # take away the final move, which can still complete a line
        board_full = np.sum(np.abs(next_state)) >= self.num_rows * self.num_cols
        game_over = bool(reward != 0 or board_full)
        return next_state, reward, game_over

    def encode_state(self, current_state):
        """Convert the state to a float32 tensor with 3 channels.

        The channels mark cells equal to 1, 0 and -1 respectively. A 3-dimensional
        input is treated as a stack of boards and the channel axis is moved to
        position 1.
        """
        encoded_state = np.stack((current_state == 1, current_state == 0, current_state == -1)).astype(np.float32)
        if len(current_state.shape) == 3:
            encoded_state = np.swapaxes(encoded_state, 0, 1)
        return encoded_state

    def reset(self):
        "Reset the game board: return an all-zero int8 array of shape (num_rows, num_cols)."
        return np.zeros([self.num_rows, self.num_cols], dtype=np.int8)

In [3]:
class ResidualNeuralNetwork(nn.Module):
    "Two-headed network: a shared convolutional trunk feeding a value head and a policy head."

    def __init__(self, game_engine, model_config):
        super().__init__()

        # Board geometry; one policy output per column
        rows, cols = game_engine.num_rows, game_engine.num_cols
        self.board_size = (rows, cols)
        trunk_channels = model_config.num_filters
        policy_channels = trunk_channels // 4
        value_channels = trunk_channels // 32

        # Shared trunk (built first so parameter initialisation order is unchanged)
        self.base = ConvolutionBase(model_config)

        # Policy head: conv -> batch norm -> ReLU -> flatten -> one logit per column
        self.policy_head = nn.Sequential(
            nn.Conv2d(trunk_channels, policy_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(policy_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(policy_channels * rows * cols, cols),
        )

        # Value head: same shape of pipeline, ending in a single tanh-squashed scalar
        self.value_head = nn.Sequential(
            nn.Conv2d(trunk_channels, value_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(value_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(value_channels * rows * cols, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        "Return (value, policy logits) for a batch of encoded boards."
        features = self.base(x)
        value = self.value_head(features)
        policy_logits = self.policy_head(features)
        return value, policy_logits

class ConvolutionBase(nn.Module):
    "Shared trunk: an input convolution followed by a stack of residual blocks."

    def __init__(self, model_config):
        super().__init__()

        width = model_config.num_filters
        depth = model_config.num_residual_blocks

        # Stem: lifts the 3 input planes to `width` feature maps
        stem_layers = [
            nn.Conv2d(3, width, kernel_size=3, padding=1),
            nn.BatchNorm2d(width),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*stem_layers)

        # Residual tower, registered as a ModuleList so its parameters are tracked
        tower = []
        for _ in range(depth):
            tower.append(ResidualBlock(width))
        self.res_blocks = nn.ModuleList(tower)

    def forward(self, x):
        "Apply the stem, then every residual block in order."
        features = self.conv(x)
        for residual in self.res_blocks:
            features = residual(features)
        return features

class ResidualBlock(nn.Module):
    "Residual block: two 3x3 conv + batch-norm stages with an identity skip connection."

    def __init__(self, num_filters):
        super().__init__()

        # Both convolutions keep the channel count and the spatial size
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(num_filters, num_filters, **conv_kwargs)
        self.batch_norm_1 = nn.BatchNorm2d(num_filters)
        self.conv_2 = nn.Conv2d(num_filters, num_filters, **conv_kwargs)
        self.batch_norm_2 = nn.BatchNorm2d(num_filters)
        self.relu = nn.ReLU()

    def forward(self, x):
        "Return relu(F(x) + x), where F is conv-bn-relu-conv-bn."
        hidden = self.conv_1(x)
        hidden = self.batch_norm_1(hidden)
        hidden = self.relu(hidden)
        hidden = self.conv_2(hidden)
        hidden = self.batch_norm_2(hidden)
        # Skip connection is added before the final activation
        return self.relu(hidden + x)

In [4]:
class MonteCarloTreeSearch:
    "Network-guided Monte Carlo tree search: PUCT selection, network evaluation of leaves, value backup."

    def __init__(self, neural_net, game_engine, model_config):
        "Initialize Monte Carlo Tree Search with a given neural network, game instance, and configuration."
        self.neural_net = neural_net
        self.game_engine = game_engine
        self.model_config = model_config

    def _predict(self, state):
        "Evaluate one board with the network (eval mode, no gradients); return (value as float, logits tensor)."
        state_tensor = torch.tensor(self.game_engine.encode_state(state), dtype=torch.float).unsqueeze(0).to(self.model_config.device)
        with torch.no_grad():
            self.neural_net.eval()
            value, logits = self.neural_net(state_tensor)
        return value.item(), logits

    def _column_mask(self, actions):
        "Boolean mask over all columns that is True exactly at the given action indices."
        mask = np.full(self.game_engine.num_cols, False)
        mask[actions] = True
        return mask

    def search(self, initial_state, total_iterations, selection_temperature=None):
        """Performs a search for the desired number of iterations, returns an action and the tree root.

        Args:
            initial_state: board seen from the side of the player to move (that player's pieces are 1).
            total_iterations: number of select / expand / backpropagate passes.
            selection_temperature: temperature for the final move choice; None uses the configured value.

        Returns:
            (action, root) where root is the search tree's root node.

        Raises:
            ValueError: if `initial_state` has no valid action.
        """
        valid_actions = self.game_engine.get_valid_actions(initial_state)
        if len(valid_actions) == 0:
            raise ValueError("Cannot search from a state with no valid actions.")

        # Create the root
        root = TreeNode(None, initial_state, 1, self.game_engine, self.model_config)

        # Expand the root, adding noise to each action
        value, logits = self._predict(initial_state)

        # Get action probabilities
        action_probs = F.softmax(logits.view(self.game_engine.num_cols), dim=0).cpu().numpy()

        # Calculate and add Dirichlet noise
        noise = np.random.dirichlet([self.model_config.dirichlet_alpha] * self.game_engine.num_cols)
        action_probs = ((1 - self.model_config.dirichlet_epsilon) * action_probs) + self.model_config.dirichlet_epsilon * noise

        # Drop unavailable actions, then renormalize so the remaining priors sum to 1
        action_probs = action_probs[self._column_mask(valid_actions)]
        action_probs /= np.sum(action_probs)

        # Create a child for each possible action
        for action, prob in zip(valid_actions, action_probs):
            # Children store the board negated, i.e. from the next player's side
            child_state = -self.game_engine.get_next_state(initial_state, action)
            root.children[action] = TreeNode(root, child_state, -1, self.game_engine, self.model_config)
            root.children[action].prob = prob

        # Since we're not backpropagating, manually increase visits
        root.n_visits = 1
        # Set value as neural network prediction also as it will slightly improve the accuracy of the value target later
        root.total_score = value

        # Begin search
        for _ in range(total_iterations):
            current_node = root

            # Phase 1: Selection
            # While not currently on a leaf node, select a new node using PUCT score
            while not current_node.is_leaf():
                current_node = current_node.select_child()

            # Phase 2: Expansion
            # A leaf that is not already known to be terminal gets expanded
            if not current_node.is_terminal():
                current_node.expand()

            if current_node.is_leaf():
                # Still no children: the position is terminal, so use the exact game result
                if current_node.n_visits == 0:
                    # expand() just stored the result in total_score; clear it so the
                    # backpropagation below counts this visit exactly once
                    current_node.total_score = 0
                value = self.game_engine.evaluate(current_node.state)
            else:
                # Freshly expanded node: evaluate it with the network
                value, logits = self._predict(current_node.state)

                # Priors come from the logits of THIS node's legal moves (its children's
                # keys), not the root's: the two sets differ once a column fills up
                node_actions = list(current_node.children.keys())
                mask = self._column_mask(node_actions)
                action_probs = F.softmax(logits.view(self.game_engine.num_cols)[mask], dim=0).cpu().numpy()
                for child, prob in zip(current_node.children.values(), action_probs):
                    child.prob = prob

            # Phase 3: Backpropagation
            # Backpropagate the value of the leaf to the root
            current_node.backpropagate(value)

        # Select action with specified selection_temperature
        if selection_temperature is None:
            selection_temperature = self.model_config.selection_temperature
        return self.select_action(root, selection_temperature), root

    def select_action(self, root, selection_temperature=None):
        """Select an action from the root based on visit counts, adjusted by selection_temperature.

        Temperature 0 picks the most visited child, np.inf picks uniformly at random,
        anything else samples proportionally to visits ** (1 / temperature).
        None falls back to the configured temperature.
        """
        if selection_temperature is None:
            selection_temperature = self.model_config.selection_temperature
        action_counts = {key: val.n_visits for key, val in root.children.items()}
        if selection_temperature == 0:
            return max(action_counts, key=action_counts.get)
        elif selection_temperature == np.inf:
            return np.random.choice(list(action_counts.keys()))
        else:
            distribution = np.array([*action_counts.values()]) ** (1 / selection_temperature)
            return np.random.choice([*action_counts.keys()], p=distribution/sum(distribution))

class TreeNode:
    def __init__(self, parent, state, to_play, game_engine, model_config):
        "Represents a node in the MCTS, holding the game state and statistics for MCTS to operate."
        # Tree structure and position
        self.parent = parent
        self.state = state
        self.to_play = to_play

        # Shared collaborators
        self.model_config = model_config
        self.game_engine = game_engine

        # Search statistics
        self.prob = 0           # prior probability assigned by the search
        self.children = {}      # action -> TreeNode
        self.n_visits = 0
        self.total_score = 0

    def expand(self):
        "Add one child per valid action; with no valid action, store the game result in total_score instead."
        legal_moves = self.game_engine.get_valid_actions(self.state)

        # Nothing to play: record the engine's verdict and leave the node childless
        if len(legal_moves) == 0:
            self.total_score = self.game_engine.evaluate(self.state)
            return

        next_player = -self.to_play
        for move in legal_moves:
            # The child sees the board negated, i.e. from the side of the player who moves next
            flipped_board = -self.game_engine.get_next_state(self.state, move)
            self.children[move] = TreeNode(self, flipped_board, next_player, self.game_engine, self.model_config)

    def select_child(self):
        "Return the child with the highest PUCT score (first one wins ties), or None without children."
        winner, winner_score = None, -np.inf
        for candidate in self.children.values():
            score = self.calculate_puct(candidate)
            if score > winner_score:
                winner, winner_score = candidate, score
        return winner

    def calculate_puct(self, child):
        "PUCT score of `child` as seen from this node."
        # The child's mean value is from the child's side; map it to [0, 1] and invert it
        # so that it expresses how good the move is for this node's player
        parent_view_value = 1 - (child.get_value() + 1) / 2
        prior_bonus = child.prob * math.sqrt(self.n_visits) / (child.n_visits + 1)
        return parent_view_value + self.model_config.exploration_constant * prior_bonus

    def backpropagate(self, value):
        "Add `value` to this node and walk up to the root, flipping the sign at every level."
        node, signed_value = self, value
        while node is not None:
            node.total_score += signed_value
            node.n_visits += 1
            node, signed_value = node.parent, -signed_value

    def is_leaf(self):
        "True when the node has no children."
        return len(self.children) == 0

    def is_terminal(self):
        "True when the node has been visited and still has no children."
        return (self.n_visits != 0) and self.is_leaf()

    def get_value(self):
        "Mean backed-up value of this node; 0 for an unvisited node."
        return 0 if self.n_visits == 0 else self.total_score / self.n_visits

    def __str__(self):
        "Multi-line summary of the node for debugging."
        return (
            f"State:\n{self.state}\nProb: {self.prob}\nTo play: {self.to_play}"
            f"\nNumber of children: {len(self.children)}\nNumber of visits: {self.n_visits}"
            f"\nTotal score: {self.total_score}"
        )

In [5]:
class AlphaMCTS:
    "Self-play trainer: owns the network, the tree search, a fixed-size training memory and the optimizer."

    def __init__(self, game_engine, model_config, verbose=True):
        """Build the network, search and training buffers.

        Args:
            game_engine: game rules object (board size, moves, encoding).
            model_config: settings object read for network, search and optimizer hyperparameters.
            verbose: when True, self_play prints a one-line progress summary after each game.
        """
        self.network = ResidualNeuralNetwork(game_engine, model_config).to(model_config.device)
        self.mcts = MonteCarloTreeSearch(self.network, game_engine, model_config)
        self.game_engine = game_engine
        self.model_config = model_config

        # Losses and optimizer
        self.loss_cross_entropy = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=model_config.learning_rate, weight_decay=0.0001)

        # Pre-allocate the training memory on the configured device.
        # Capacity is exactly the number of samples consumed by one learn() call.
        state_shape = game_engine.encode_state(game_engine.reset()).shape
        self.max_memory = model_config.minibatch_size * model_config.num_minibatches
        self.state_memory = torch.zeros(self.max_memory, *state_shape).to(model_config.device)
        self.value_memory = torch.zeros(self.max_memory, 1).to(model_config.device)
        self.policy_memory = torch.zeros(self.max_memory, game_engine.num_cols).to(model_config.device)
        self.current_memory_index = 0
        self.memory_full = False

        # MCTS search iterations
        self.search_iterations = model_config.mcts_initial_iterations

        # Logging
        self.verbose = verbose
        self.total_games = 0

    def train(self, training_epochs):
        "Train the AlphaMCTS agent for a specified number of training epochs."
        # For each training epoch
        for _ in range(training_epochs):

            # Play specified number of games
            for _ in range(self.model_config.games_per_epoch):
                self.self_play()

            # At the end of each epoch, increase the number of MCTS search iterations (capped)
            self.search_iterations = min(self.model_config.mcts_max_iterations, self.search_iterations + self.model_config.mcts_search_increment)

    def self_play(self):
        "Perform one episode of self-play, storing a training sample per move and learning whenever the memory fills."
        state = self.game_engine.reset()
        done = False
        while not done:
            # Search for a move
            action, root = self.mcts.search(state, self.search_iterations)

            # Value target is the value of the MCTS root node
            value = root.get_value()

            # Visit counts used to compute policy target
            visits = np.zeros(self.game_engine.num_cols)
            for child_action, child in root.children.items():
                visits[child_action] = child.n_visits
            # Normalize so the distribution sums to 1
            visits /= np.sum(visits)

            # Append state + value & policy targets to memory
            self.append_to_memory(state, value, visits)

            # If memory is full, perform a learning step
            if self.memory_full:
                self.learn()

            # Perform action in game
            state, _, done = self.game_engine.play(state, action)

            # Flip the board so the next mover again sees its own pieces as 1
            state = -state

        # Increment total games played
        self.total_games += 1

        # Logging if verbose
        if self.verbose:
            print("\rTotal Games:", self.total_games, "Items in Memory:", self.current_memory_index, "Search Iterations:", self.search_iterations, end="")

    def append_to_memory(self, state, value, visits):
        """
        Append state and MCTS results to memory buffers.

        Two samples are written per call: the position itself and its left-right
        mirror image (with the visit distribution mirrored accordingly).

        Args:
            state (array-like): Current game state.
            value (float): MCTS value for the game state.
            visits (array-like): MCTS visit counts for available moves.
        """
        # Calculate the encoded states
        encoded_state = np.array(self.game_engine.encode_state(state))
        encoded_state_augmented = np.array(self.game_engine.encode_state(state[:, ::-1]))

        # Stack states and visits
        states_stack = np.stack((encoded_state, encoded_state_augmented), axis=0)
        visits_stack = np.stack((visits, visits[::-1]), axis=0)

        # Convert the stacks to tensors
        state_tensor = torch.tensor(states_stack, dtype=torch.float).to(self.model_config.device)
        visits_tensor = torch.tensor(visits_stack, dtype=torch.float).to(self.model_config.device)
        value_tensor = torch.tensor(np.array([value, value]), dtype=torch.float).to(self.model_config.device).unsqueeze(1)

        # Store in the pre-allocated memory
        self.state_memory[self.current_memory_index:self.current_memory_index + 2] = state_tensor
        self.value_memory[self.current_memory_index:self.current_memory_index + 2] = value_tensor
        self.policy_memory[self.current_memory_index:self.current_memory_index + 2] = visits_tensor

        # Increment index, handle overflow
        self.current_memory_index = (self.current_memory_index + 2) % self.max_memory

        # The write index wrapping around means every slot has been filled
        if (self.current_memory_index == 0) or (self.current_memory_index == 1):
            self.memory_full = True


    def learn(self):
        "Update the neural network by extracting minibatches from memory and performing one step of optimization for each one."
        self.network.train()

        # Create a randomly shuffled list of batch indices
        batch_indices = np.arange(self.max_memory)
        np.random.shuffle(batch_indices)

        for batch_index in range(self.model_config.num_minibatches):
            # Get minibatch indices
            start = batch_index * self.model_config.minibatch_size
            end = start + self.model_config.minibatch_size
            mb_indices = batch_indices[start:end]

            # Slice memory tensors
            mb_states = self.state_memory[mb_indices]
            mb_value_targets = self.value_memory[mb_indices]
            mb_policy_targets = self.policy_memory[mb_indices]

            # Network predictions
            value_preds, policy_logits = self.network(mb_states)

            # Loss calculation
            policy_loss = self.loss_cross_entropy(policy_logits, mb_policy_targets)
            value_loss = self.loss_mse(value_preds.view(-1), mb_value_targets.view(-1))
            loss = policy_loss + value_loss

            # Backward pass and parameter update.
            # Optimizers apply gradients with step(); they have no play() method.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # The buffered samples are used up; return the network to inference mode
        self.memory_full = False
        self.network.eval()

In [6]:
class Evaluator:
    "Scores the policy head on generated positions that have a known immediate target move."

    def __init__(self, alpha_mcts, num_examples=500, verbose=True):
        self.network = alpha_mcts.network
        self.game = alpha_mcts.game_engine
        self.model_config = alpha_mcts.model_config
        self.accuracies = []
        self.num_examples = num_examples
        self.verbose = verbose

        # Build the fixed evaluation set once, up front
        self.generate_examples()

    def select_action(self, state):
        "Scripted move choice: an immediate win on `state`, else an immediate win on the negated board, else random."
        candidates = self.game.get_valid_actions(state)

        # Pass 1 uses the board as given (a win for the mover); pass 2 uses the
        # negated board (a move the opponent would win with, i.e. a block)
        for board in (state, -state):
            for move in candidates:
                _, outcome, _ = self.game.play(board, move)
                if outcome == 1:
                    return move

        # Neither pass found a decisive move
        return random.choice(candidates)

    def generate_examples(self):
        "Generate both example sets and store them as the tensors X_target / y_target."
        win_pairs = self.generate_examples_for_condition('win')
        block_pairs = self.generate_examples_for_condition('block')

        # Split the (state, action) pairs into parallel sequences
        win_states, win_actions = zip(*win_pairs)
        block_states, block_actions = zip(*block_pairs)

        all_states = np.concatenate([win_states, block_states], axis=0)
        all_actions = np.concatenate([win_actions, block_actions], axis=0)

        encoded = np.stack([self.game.encode_state(board) for board in all_states], axis=0)
        device = self.model_config.device
        self.X_target = torch.tensor(encoded, dtype=torch.float).to(device)
        self.y_target = torch.tensor(all_actions, dtype=torch.long).to(device)

    def generate_examples_for_condition(self, condition):
        "Play scripted games until `num_examples` (state, action) pairs for 'win' or 'block' are collected."
        # 'win' pairs are recorded on player 1's decisive move, 'block' pairs on
        # player -1's decisive move (stored with the board negated)
        turn_order = ((1, 'win'), (-1, 'block'))
        collected = []
        while len(collected) < self.num_examples:
            board = self.game.reset()
            playing = True
            while playing:
                for player, label in turn_order:
                    # Each side chooses from a board on which its own pieces are 1
                    view = board if player == 1 else -board
                    move = self.select_action(view)
                    after, reward, done = self.game.play(board, move, to_play=player)

                    if condition == label and reward == player:
                        collected.append((view, move))
                        playing = False
                        break

                    if done:
                        playing = False
                        break

                    board = after
        return collected

    def evaluate(self):
        "Measure how often the policy's argmax matches the target move and append the accuracy to self.accuracies."
        self.network.eval()
        with torch.no_grad():
            _, logits = self.network(self.X_target)
            hits = logits.argmax(dim=1) == self.y_target
            accuracy = hits.float().mean().item()

        self.accuracies.append(accuracy)
        if self.verbose:
            print(f"Initial Evaluation Accuracy: {100 * accuracy:.1f}%")

In [7]:
# game_engine = Connect4Engine()
# alpha_mcts = AlphaMCTS(game_engine, config)
# evaluator = Evaluator(alpha_mcts)

# # Evaluate pre-training
# evaluator.evaluate()

# # Main training/evaluation loop
# for _ in range(config.training_epochs):
#     alpha_mcts.train(1)
#     evaluator.evaluate()

# # Save trained weights
# torch.save(alpha_mcts.network.state_dict(), 'alpha_mcts-network-weights-new.pth')


In [8]:
# # Plot data
# x_values = np.linspace(0, 101 * len(evaluator.accuracies), len(evaluator.accuracies))
# y_values = [acc * 100 for acc in evaluator.accuracies]

# # Create plot
# plt.figure(figsize=(10, 6))
# plt.plot(x_values, y_values, linewidth=2, marker='o', markersize=4, linestyle='-', color='#636EFA')

# # Formatting
# plt.xlabel('\nNumber of Games', fontsize=16)
# plt.ylabel('Policy Evaluation Accuracy (%)', fontsize=16)
# plt.title('Policy Evaluation\n', fontsize=24)
# plt.grid(True, linestyle='--', linewidth=0.5, color='gray')

# plt.show()

In [9]:
# Define the game, AlphaMCTS, and evaluator
game_engine = Connect4Engine()
alpha_mcts = AlphaMCTS(game_engine, config)
evaluator = Evaluator(alpha_mcts)

# Load the pre-trained weights.
# The checkpoint location can be overridden with the ALPHAMCTS_WEIGHTS environment
# variable so the notebook is not tied to one machine; the original path stays the default.
file_path = os.environ.get(
    "ALPHAMCTS_WEIGHTS",
    "C:/Users/Rohan Arya/OneDrive/Desktop/FinalYearProject/alphamcts-network-weights-new.pth",
)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
pre_trained_weights = torch.load(file_path, map_location=config.device)
alpha_mcts.network.load_state_dict(pre_trained_weights)

# Evaluate the pre-trained model
evaluator.evaluate()

Initial Evaluation Accuracy: 85.1%


In [10]:
class AlphaMCTSAgent:
    "Playing agent that picks moves with a trained AlphaMCTS instance, with or without tree search."

    def __init__(self, alpha_mcts):
        """Wrap `alpha_mcts` for play.

        Side effects: puts the network in eval mode and sets `dirichlet_epsilon` to 0
        on the wrapped instance's model_config.
        """
        self.alpha_mcts = alpha_mcts
        self.alpha_mcts.network.eval()

        # Remove noise from move calculations
        self.alpha_mcts.model_config.dirichlet_epsilon = 0

    def select_action(self, state, search_iterations=1000):
        """Choose a move for `state`.

        Args:
            state: board seen from the side of the player to move.
            search_iterations: number of MCTS iterations; 0 skips the search and
                plays the valid move with the highest raw policy probability.

        Returns:
            The chosen column index.
        """
        # With search enabled the tree search does its own encoding, so nothing is
        # encoded or moved to the device here
        if search_iterations != 0:
            action, _ = self.alpha_mcts.mcts.search(state, search_iterations)
            return action

        # Policy-only move: a single forward pass of the network
        engine = self.alpha_mcts.game_engine
        state_tensor = torch.tensor(engine.encode_state(state), dtype=torch.float).to(self.alpha_mcts.model_config.device)
        with torch.no_grad():
            _, logits = self.alpha_mcts.network(state_tensor.unsqueeze(0))

        # Get action probs and keep only the valid actions
        action_probs = F.softmax(logits.view(-1), dim=0).cpu().numpy()
        valid_actions = engine.get_valid_actions(state)
        valid_action_probs = action_probs[valid_actions]
        return valid_actions[np.argmax(valid_action_probs)]

In [11]:
# GUI to play against agent using input box, choosing values from 0 to 6 for moves

# Constructing the agent puts the network in eval mode and sets dirichlet_epsilon to 0
# on alpha_mcts.model_config
agent = AlphaMCTSAgent(alpha_mcts)
# Constants for board dimensions
BOARD_WIDTH = 7    # number of columns drawn
BOARD_HEIGHT = 6   # number of rows drawn
CELL_SIZE = 60     # side length of one board cell, in canvas units
DISK_RADIUS = CELL_SIZE // 2 - 5   # disk radius, leaving a 5-unit gap to the cell edge
WINDOW_PADDING = 20   # margin between the canvas edge and the board

# Function to handle human move
def human_move():
    """Read the column typed into the entry box, play it for the human, then let the agent reply.

    Invalid input (non-numeric text or an unplayable column) shows an error dialog
    and leaves the game state untouched.
    """
    global state, turn, done

    # Get the action entered by the player; non-numeric text is reported like any other invalid move
    try:
        action = int(entry.get())
    except ValueError:
        messagebox.showerror("Invalid Move", "Please enter a valid move.")
        return

    # Check if the action is valid
    if action not in game_engine.get_valid_actions(state):
        messagebox.showerror("Invalid Move", "Please enter a valid move.")
        return

    # Perform the human move (the mover's pieces are 1 on `state`)
    next_state, reward, done = game_engine.play(state, action)

    # Draw from the agent's side, as agent_move does, so the human's disks are always
    # -1 (yellow) and the colours do not swap between turns
    update_board(-next_state)

    # Check for game end
    if done:
        # play() scores the position for the mover (the human here) while end_game
        # reads +1 as an agent win, so the sign is flipped
        end_game(-reward)
    else:
        state = -next_state
        turn = 1 - turn
        agent_move()

# Function to handle agent move
def agent_move():
    "Let the agent pick and play a move on the global board, redraw it, and finish or hand over the turn."
    global state, turn, done

    # Ask the agent for a move using 1000 search iterations
    chosen = agent.select_action(state, 1000)

    # Apply it and show the resulting position
    board_after, reward, done = game_engine.play(state, chosen)
    update_board(board_after)

    # Game finished: report the result and leave the state as it is
    if done:
        end_game(reward)
        return

    # Otherwise flip the board for the next mover and switch turns
    state = -board_after
    turn = 1 - turn

# Function to draw the board
def draw_board():
    "Draw the empty grid: one blue square with a white hole per cell, column by column."
    for col in range(BOARD_WIDTH):
        left = col * CELL_SIZE + WINDOW_PADDING
        right = left + CELL_SIZE
        for row in range(BOARD_HEIGHT):
            top = row * CELL_SIZE + WINDOW_PADDING
            bottom = top + CELL_SIZE
            # Cell background, then the hole inset by 5 units on every side
            canvas.create_rectangle(left, top, right, bottom, fill="blue", outline="black")
            canvas.create_oval(left + 5, top + 5, right - 5, bottom - 5, fill="white", outline="black")

# Function to draw the disks
def update_board(board):
    "Redraw every disk: cells equal to 1 become red disks, cells equal to -1 yellow, anything else is left empty."
    # Remove the disks drawn by the previous call (they all carry the "disk" tag)
    canvas.delete("disk")

    for col in range(BOARD_WIDTH):
        for row in range(BOARD_HEIGHT):
            cell = board[row][col]
            if cell == 1:
                colour = "red"
            elif cell == -1:
                colour = "yellow"
            else:
                continue

            # Disk is centred in its cell
            x_center = col * CELL_SIZE + CELL_SIZE // 2 + WINDOW_PADDING
            y_center = row * CELL_SIZE + CELL_SIZE // 2 + WINDOW_PADDING
            canvas.create_oval(x_center - DISK_RADIUS, y_center - DISK_RADIUS,
                               x_center + DISK_RADIUS, y_center + DISK_RADIUS,
                               fill=colour, outline="black", tags="disk")

# Function to handle end of the game
def end_game(reward):
    """Show the result dialog, then stop the Tk main loop.

    A reward of -1 is reported as a win for the human, 1 as a loss, and any
    other value as a draw.
    """
    # NOTE(review): both the human and the agent handlers pass the engine's
    # reward straight in; confirm its sign convention matches these messages.
    verdict = "You win!" if reward == -1 else ("You lose!" if reward == 1 else "It's a draw!")
    messagebox.showinfo("Game Over", verdict)
    root.quit()

# Create the Tkinter window
root = tk.Tk()
root.title("Connect4 Game")

# Create a canvas to draw the board; it is sized to the grid plus
# WINDOW_PADDING on every side.
canvas = tk.Canvas(root, width=BOARD_WIDTH * CELL_SIZE + 2 * WINDOW_PADDING,
                   height=BOARD_HEIGHT * CELL_SIZE + 2 * WINDOW_PADDING)
canvas.pack()

# Draw the static grid once; disks are drawn separately by update_board().
draw_board()

# Create a label and entry widget for the player's move
move_label = tk.Label(root, text="Enter your move (0-6):")
move_label.pack()
entry = tk.Entry(root)
entry.pack()

# Create a button to submit the move; pressing it invokes human_move().
submit_button = tk.Button(root, text="Submit", command=human_move)
submit_button.pack()

# Initial setup: fresh board from the engine plus the module-level flags that
# the move handlers read and update through `global`.
state = game_engine.reset()
turn = 0
done = False
update_board(state)

# Start the main loop; it runs until end_game() calls root.quit().
root.mainloop()


In [12]:
# GUI where the player plays against the agent by clicking on the board
agent = AlphaMCTSAgent(alpha_mcts)
# Constants for board dimensions and canvas layout
BOARD_WIDTH = 7   # number of columns
BOARD_HEIGHT = 6  # number of rows
CELL_SIZE = 60    # side length of one board square on the canvas
DISK_RADIUS = CELL_SIZE // 2 - 5  # disks are inset 5 units from the square's edge
WINDOW_PADDING = 20  # blank margin between the canvas border and the grid

# Function to handle human move
def human_move(event):
    """Handle a mouse click on the canvas as the human player's move.

    The clicked column is derived from ``event.x``. Because the grid is drawn
    ``WINDOW_PADDING`` away from the canvas edge, that offset is removed before
    dividing by ``CELL_SIZE``; clicks in the margins or on a full column are
    ignored, as are clicks once the game is over.

    Args:
        event: Tk mouse event; only its ``x`` attribute is read.
    """
    global state, turn, done

    # Ignore clicks once the game has finished.
    if done:
        return

    # Determine the column clicked (grid starts WINDOW_PADDING from the edge).
    col = (event.x - WINDOW_PADDING) // CELL_SIZE

    # Clicks in the left/right margin fall outside the board.
    if col < 0 or col >= BOARD_WIDTH:
        return

    # A column with no empty cell cannot take another disk.
    if all(state[row][col] != 0 for row in range(BOARD_HEIGHT)):
        return

    # Perform the human move
    next_state, reward, done = game_engine.play(state, col)
    update_board(next_state)

    # NOTE(review): end_game() reports -1 as "You win!"; confirm that matches
    # the sign of the reward game_engine.play returns for the player who moved.
    if done:
        end_game(reward)
    else:
        # Negate the board before handing it to the agent, then let it reply.
        state = -next_state
        turn = 1 - turn
        agent_move()

# Function to handle agent move
def agent_move():
    """Let the agent play one move and update the shared game globals.

    Asks ``agent`` for an action on the current ``state`` (passing 1000 as the
    second argument to ``select_action``), applies it through ``game_engine``,
    and redraws the board. If the game is over, ``end_game`` is called with the
    reward; otherwise ``state`` is negated and ``turn`` is toggled.
    """
    global state, turn, done

    chosen_action = agent.select_action(state, 1000)
    board_after, outcome, done = game_engine.play(state, chosen_action)
    update_board(board_after)

    if done:
        end_game(outcome)
        return

    # Game continues: hand the (negated) board back for the next click.
    state, turn = -board_after, 1 - turn

# Function to draw the board
def draw_board():
    """Paint the static grid: a blue square with a white circular hole per cell.

    Uses the module-level ``canvas`` and layout constants. Cells are created
    column by column, top row first.
    """
    hole_inset = 5  # gap between a square's edge and its hole
    for col_idx in range(BOARD_WIDTH):
        left = col_idx * CELL_SIZE + WINDOW_PADDING
        right = left + CELL_SIZE
        for row_idx in range(BOARD_HEIGHT):
            top = row_idx * CELL_SIZE + WINDOW_PADDING
            bottom = top + CELL_SIZE
            canvas.create_rectangle(left, top, right, bottom, fill="blue", outline="black")
            canvas.create_oval(left + hole_inset, top + hole_inset,
                               right - hole_inset, bottom - hole_inset,
                               fill="white", outline="black")

# Function to draw the disks
def update_board(board):
    """Redraw every disk so the canvas matches ``board``.

    All canvas items tagged "disk" are removed first. A cell equal to 1 is
    drawn as a red disk, a cell equal to -1 as a yellow disk, and any other
    value is left empty. ``board`` is indexed as ``board[row][col]``.
    """
    canvas.delete("disk")

    half_cell = CELL_SIZE // 2
    for col_idx in range(BOARD_WIDTH):
        cx = col_idx * CELL_SIZE + half_cell + WINDOW_PADDING
        for row_idx in range(BOARD_HEIGHT):
            cell = board[row_idx][col_idx]
            if cell == 1:
                colour = "red"
            elif cell == -1:
                colour = "yellow"
            else:
                continue  # empty cell: nothing to draw
            cy = row_idx * CELL_SIZE + half_cell + WINDOW_PADDING
            canvas.create_oval(cx - DISK_RADIUS, cy - DISK_RADIUS,
                               cx + DISK_RADIUS, cy + DISK_RADIUS,
                               fill=colour, outline="black", tags="disk")

# Function to handle end of the game
def end_game(reward):
    """Show the result dialog, then stop the Tk main loop.

    A reward of -1 is reported as a win for the human, 1 as a loss, and any
    other value as a draw.
    """
    # NOTE(review): both the human and the agent handlers pass the engine's
    # reward straight in; confirm its sign convention matches these messages.
    verdict = "You win!" if reward == -1 else ("You lose!" if reward == 1 else "It's a draw!")
    messagebox.showinfo("Game Over", verdict)
    root.quit()

# Create the Tkinter window
root = tk.Tk()
root.title("Connect4 Game")

# Create a canvas to draw the board; it is sized to the grid plus
# WINDOW_PADDING on every side.
canvas = tk.Canvas(root, width=BOARD_WIDTH * CELL_SIZE + 2 * WINDOW_PADDING,
                   height=BOARD_HEIGHT * CELL_SIZE + 2 * WINDOW_PADDING)
canvas.pack()

# Draw the static grid once; disks are drawn separately by update_board().
draw_board()

# Bind left mouse clicks on the canvas to human_move(), which receives the
# Tk event and derives the column from its x coordinate.
canvas.bind("<Button-1>", human_move)

# Initial setup: fresh board from the engine plus the module-level flags that
# the move handlers read and update through `global`.
state = game_engine.reset()
turn = 0
done = False
update_board(state)

# Start the main loop; it runs until end_game() calls root.quit().
root.mainloop()

In [13]:
# GUI to watch game between AlphaMCTS agent and Minimax algorithm

import tkinter as tk
from tkinter import messagebox
import numpy as np

# AlphaZero-style player used by alpha_zero_move()
agent = AlphaMCTSAgent(alpha_mcts)
# Constants for board dimensions and canvas layout
BOARD_WIDTH = 7   # number of columns
BOARD_HEIGHT = 6  # number of rows
CELL_SIZE = 60    # side length of one board square on the canvas
DISK_RADIUS = CELL_SIZE // 2 - 5  # disks are inset 5 units from the square's edge
WINDOW_PADDING = 20  # blank margin between the canvas border and the grid

# Function to handle AlphaZero agent move
def alpha_zero_move():
    """Play one move for the AlphaZero agent and update the shared game globals.

    Asks ``agent`` for an action on the current ``state`` (passing 1000 as the
    second argument to ``select_action``), applies it through ``game_engine``,
    and redraws the board. If the game is over, ``end_game`` is called with the
    reward; otherwise ``state`` is negated and ``turn`` is toggled.
    """
    global state, turn, done

    chosen_action = agent.select_action(state, 1000)
    board_after, outcome, done = game_engine.play(state, chosen_action)
    update_board(board_after)

    if done:
        end_game(outcome)
        return

    # Game continues: hand the (negated) board over to the other player.
    state, turn = -board_after, 1 - turn

# Function to handle Minimax agent move
def minimax_move():
    """Play one move for the Minimax player and update the shared game globals.

    Picks the action returned by a depth-3 ``minimax`` search on the current
    ``state``, applies it through ``game_engine`` and redraws the board. If the
    game is over, ``end_game`` is called with the reward; otherwise ``state``
    is negated and ``turn`` is toggled.

    This handler plays exactly one move. The opponent's reply is left to
    ``play_game``, which alternates the players by ``turn`` on a timer, so each
    move is displayed on its own tick.
    """
    global state, turn, done

    # Let the Minimax algorithm select its action
    _, action = minimax(state, 3, float('-inf'), float('inf'), True)

    # Perform the move
    next_state, reward, done = game_engine.play(state, action)
    update_board(next_state)

    # NOTE(review): end_game() maps 1 to an AlphaZero win; confirm that matches
    # the sign of the reward game_engine.play returns when Minimax is the mover.
    if done:
        end_game(reward)
    else:
        state = -next_state
        turn = 1 - turn

# Minimax algorithm implementation
def minimax(state, depth, alpha, beta, maximizing_player):
    """Depth-limited minimax search with alpha-beta pruning.

    Args:
        state: Position to search from, in whatever form ``game_engine`` accepts.
        depth: Remaining plies to look ahead; 0 scores ``state`` directly.
        alpha: Best score the maximizing side is already assured of.
        beta: Best score the minimizing side is already assured of.
        maximizing_player: True when the side to move maximizes the score.

    Returns:
        A ``(score, action)`` tuple. ``action`` is ``None`` when ``state`` is
        scored directly (depth exhausted or no valid actions).
    """
    if depth == 0:
        return game_engine.evaluate(state), None

    valid_actions = game_engine.get_valid_actions(state)
    if len(valid_actions) == 0:
        return game_engine.evaluate(state), None

    # NOTE(review): the GUI move handlers negate the board between turns, but
    # child positions are searched here un-negated; confirm this matches what
    # game_engine.play / game_engine.evaluate expect.
    best_score = float('-inf') if maximizing_player else float('inf')
    best_action = None
    for action in valid_actions:
        next_state, _, terminal = game_engine.play(state, action)
        if terminal:
            # The game ended with this move: score it as a leaf and do not
            # keep searching moves that could never be played.
            score = game_engine.evaluate(next_state)
        else:
            score, _ = minimax(next_state, depth - 1, alpha, beta, not maximizing_player)

        if maximizing_player:
            if score > best_score:
                best_score = score
                best_action = action
            alpha = max(alpha, best_score)
        else:
            if score < best_score:
                best_score = score
                best_action = action
            beta = min(beta, best_score)

        if alpha >= beta:
            break  # Cut-off: the opponent already has a better alternative

    return best_score, best_action

# Function to draw the board
def draw_board():
    """Paint the static grid: a blue square with a white circular hole per cell.

    Uses the module-level ``canvas`` and layout constants. Cells are created
    column by column, top row first.
    """
    hole_inset = 5  # gap between a square's edge and its hole
    for col_idx in range(BOARD_WIDTH):
        left = col_idx * CELL_SIZE + WINDOW_PADDING
        right = left + CELL_SIZE
        for row_idx in range(BOARD_HEIGHT):
            top = row_idx * CELL_SIZE + WINDOW_PADDING
            bottom = top + CELL_SIZE
            canvas.create_rectangle(left, top, right, bottom, fill="blue", outline="black")
            canvas.create_oval(left + hole_inset, top + hole_inset,
                               right - hole_inset, bottom - hole_inset,
                               fill="white", outline="black")

# Function to draw the disks
def update_board(board):
    """Redraw every disk so the canvas matches ``board``.

    All canvas items tagged "disk" are removed first. A cell equal to 1 is
    drawn as a red disk, a cell equal to -1 as a yellow disk, and any other
    value is left empty. ``board`` is indexed as ``board[row][col]``.
    """
    canvas.delete("disk")

    half_cell = CELL_SIZE // 2
    for col_idx in range(BOARD_WIDTH):
        cx = col_idx * CELL_SIZE + half_cell + WINDOW_PADDING
        for row_idx in range(BOARD_HEIGHT):
            cell = board[row_idx][col_idx]
            if cell == 1:
                colour = "red"
            elif cell == -1:
                colour = "yellow"
            else:
                continue  # empty cell: nothing to draw
            cy = row_idx * CELL_SIZE + half_cell + WINDOW_PADDING
            canvas.create_oval(cx - DISK_RADIUS, cy - DISK_RADIUS,
                               cx + DISK_RADIUS, cy + DISK_RADIUS,
                               fill=colour, outline="black", tags="disk")

# Function to handle end of the game
def end_game(reward):
    """Show the match result in a dialog, then stop the Tk main loop.

    Args:
        reward: 1 is reported as an AlphaZero win, -1 as a Minimax win, and
            any other value as a draw.
    """
    # NOTE(review): the move handlers pass the engine's reward straight in;
    # confirm its sign convention matches the mapping above.
    if reward == 1:
        message = "AlphaZero wins!"
    elif reward == -1:
        # -1 is the "Minimax wins" code, so the dialog must name Minimax as
        # the winner instead of saying it lost.
        message = "Minimax wins!"
    else:
        message = "It's a draw!"

    messagebox.showinfo("Game Over", message)
    root.quit()

# Create the Tkinter window
root = tk.Tk()
root.title("Connect4 Game")

# Create a canvas to draw the board; it is sized to the grid plus
# WINDOW_PADDING on every side.
canvas = tk.Canvas(root, width=BOARD_WIDTH * CELL_SIZE + 2 * WINDOW_PADDING,
                   height=BOARD_HEIGHT * CELL_SIZE + 2 * WINDOW_PADDING)
canvas.pack()

# Draw the static grid once; disks are drawn separately by update_board().
draw_board()

# Initial setup: fresh board from the engine plus the module-level flags that
# the move handlers and play_game() share. turn == 0 means AlphaZero moves
# next; any other value means Minimax does.
state = game_engine.reset()
turn = 0
done = False
update_board(state)     

# Start the game loop: AlphaZero vs Minimax
def play_game():
    """Play one tick of the AlphaZero-vs-Minimax match, then reschedule itself.

    Dispatches on the module-level ``turn``: 0 calls ``alpha_zero_move``,
    anything else calls ``minimax_move``. Those handlers update ``state``,
    ``turn`` and ``done`` and already call ``end_game`` with the engine's
    reward when the game finishes, so this function only decides whether
    another tick is needed.
    """
    # Nothing left to do once an earlier tick has finished the game.
    if done:
        return

    if turn == 0:
        alpha_zero_move()
    else:
        minimax_move()

    # The move handlers announce the result themselves. Calling end_game()
    # again here would pop up a second dialog whose winner is only guessed
    # from ``turn`` and can never be a draw.
    if not done:
        root.after(1000, play_game)  # Schedule the next move after 1 second

# Start the game loop: the first play_game() tick fires once the Tk main loop
# is running.
root.after(1000, play_game)  # Start the game loop after 1 second

# Start the main loop; it runs until end_game() calls root.quit().
root.mainloop()

: 