In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
import matplotlib.pyplot as plt
import seaborn as sns
from numpy.ma.core import shape
from torch.nn import MSELoss
from torch.utils import data
from torch.utils.data import Dataset
from tqdm import tqdm
from triton.language import dtype
import torch.optim.lr_scheduler as lr_scheduler

In [2]:
# Verify that a GPU is connected and available, falling back to CPU otherwise.

print(torch.__version__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# torch.cuda.get_device_name(0) raises when no CUDA device is present,
# so only query the name when actually running on CUDA.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
else:
    print('CUDA is not available; using CPU')

2.9.1+cu128
NVIDIA GeForce RTX 4080


In [3]:
import torch
import random
import numpy as np

class MSGameManager:


    def __init__(self):
        self.player_board = torch.zeros([22, 22], dtype=torch.int8)
        self.mine_board = torch.zeros([22, 22], dtype=torch.int8) # Location of All Mines, 9 for mine, 0-8 for clue
        self.flagged_board = torch.zeros([22, 22], dtype=torch.int8)
        self.opened_board = torch.zeros([22, 22], dtype=torch.int8)
        self.number_of_mines = -1

        # Logic bot variables
        self.cells_remaining = set()
        self.inferred_safe = set()
        self.inferred_mine = set()
        self.clue_number = dict()
        self.size = 22
        self.moves_taken = 0
        self.initial_start_coords = None
        self.is_game_over = False
        self.mines_triggered = 0

    def reset_game(self):
        # This method handles resetting ALL state variables for a new game
        self.player_board = torch.zeros([self.size, self.size], dtype=torch.int8)
        self.mine_board = torch.zeros([self.size, self.size], dtype=torch.int8)
        self.flagged_board = torch.zeros([self.size, self.size], dtype=torch.int8)
        self.opened_board = torch.zeros([self.size, self.size], dtype=torch.int8)
        self.number_of_mines = -1 # Set back to default, to be set later
        self.cells_remaining = set()
        self.inferred_safe = set()
        self.inferred_mine = set()
        self.clue_number = dict()
        self.moves_taken = 0
        self.initial_start_coords = None
        self.is_game_over = False
        self.mines_triggered = 0

    def initialize_board(self, difficulty: str):

        self.reset_game()

        self.number_of_mines = 50 if difficulty == "easy" \
            else 80 if difficulty == "medium" \
                else 100

        # Add all cells into the remaining set and reset boards
        self.cells_remaining = set((r, c) for r in range(self.size) for c in range(self.size))
        self.mine_board = torch.zeros([self.size, self.size], dtype=torch.int8)

        # Generate a random starting location
        initial_start = torch.randint(0, self.size, (2, ), dtype=torch.int8)
        start_r, start_c = initial_start[0].item(), initial_start[1].item()
        self.initial_start_coords = (start_r, start_c)

        # Remove the starting cell from the board
        self.cells_remaining.remove(self.initial_start_coords)

        # All possible coordinates excluding the initial start cell
        possible_mine_locations = list(self.cells_remaining)

        # Grab the squares that will be set to mines
        mine_locations = random.sample(possible_mine_locations, self.number_of_mines)

        # Place mines on the mine board
        MINE = 9
        for r, c in mine_locations:
            self.mine_board[r, c] = MINE

        # Calculate and place clues on the mine_board
        self.calculate_clues()

        # Perform the initial move so we can have a starting board
        self.open_cell(start_r, start_c)

    def start_bot_game(self, difficulty: str, allow_mine: bool = False):

        self.initialize_board(difficulty)
        self.is_game_over = False
        self.moves_taken = 0
        self.mines_triggered = 0

        while not self.is_game_over:

            # 1. Get the bot's next move choice
            move = self.get_logic_bot_move()  # ← Assign to variable first

            # 2. CHECK FOR None BEFORE UNPACKING
            if move is None:
                self.is_game_over = True
                break

            r, c = move  # ← Now safe to unpack

            # 3. Execute the move
            success, game_over = self.make_move(r, c, allow_mine)

            if game_over:
                self.is_game_over = True
                break

        # Return the required metrics
        return {
            "success": self.check_win_condition(),
            "moves_taken": self.moves_taken,
            "mines_triggered": self.mines_triggered
        }

    def get_neighbors(self, r, c):
        """Returns a list of valid neighboring coordinates (up to 8)."""
        neighbors = []
        for dr in [-1, 0, 1]:
            for dc in [-1, 0, 1]:
                if dr == 0 and dc == 0:
                    continue
                nr, nc = r + dr, c + dc

                # Check for boundary conditions
                if 0 <= nr < self.size and 0 <= nc < self.size:
                    neighbors.append((nr, nc))
        return neighbors

    def calculate_clues(self):
        """Iterates over the board and calculates the clue number for non-mine cells."""
        MINE = 9

        for r in range(self.size):
            for c in range(self.size):

                # Skip square if it is a mine
                if self.mine_board[r, c] == MINE:
                    continue

                mine_count = 0
                for nr, nc in self.get_neighbors(r, c):
                    # Check if the neighbor is a mine
                    if self.mine_board[nr, nc] == MINE:
                        mine_count += 1

                # Store the clue number (0-8) on the mine_board
                self.mine_board[r, c] = mine_count

    def open_cell(self, r, c, allow_mine: bool = False):
        """
        Performs the action of opening a cell, updating all state boards,
        and handling the chain reaction if a blank (clue 0) is hit.
        Returns: True if the move was successful (not a mine), False otherwise.
        """
        if self.opened_board[r, c] == 1:
            return True # Already opened, treat as successful non-fatal move

        # 1. Update general state
        self.opened_board[r, c] = 1
        coord = (r, c)
        if coord in self.cells_remaining:
             self.cells_remaining.remove(coord)

        # Also remove from inferred sets if it was picked
        if coord in self.inferred_safe:
            self.inferred_safe.remove(coord)

        # 2. Get the ground truth value
        value = self.mine_board[r, c].item()

        MINE_CODE = 9

        if value == MINE_CODE: # Mine is hit
            self.is_game_over = True if not allow_mine else False
            self.mines_triggered += 1
            self.player_board[r, c] = -1 # A code for 'detonated mine'
            return False # Game over, move was unsuccessful

        elif value == 0: # Blank cell (Clue 0) - start chain reaction
            self.player_board[r, c] = 9 # Use 9 for revealed blank
            self.clue_number[coord] = 0

            # --- Chain Reaction Logic (using BFS for cascade) ---
            # ... (The full chain reaction logic from the previous answer goes here) ...

            queue = [(r, c)]
            visited = set([(r, c)])

            while queue:
                curr_r, curr_c = queue.pop(0)

                for nr, nc in self.get_neighbors(curr_r, curr_c):
                    neighbor_coord = (nr, nc)
                    if neighbor_coord in visited or self.opened_board[nr, nc] == 1:
                        continue

                    visited.add(neighbor_coord)
                    true_value = self.mine_board[nr, nc].item()

                    if true_value == 0:
                        # Continue the cascade (reveal blank, add to queue)
                        self.player_board[nr, nc] = 9
                        self.opened_board[nr, nc] = 1
                        self.clue_number[neighbor_coord] = 0
                        if neighbor_coord in self.cells_remaining:
                            self.cells_remaining.remove(neighbor_coord)
                        queue.append(neighbor_coord)

                    elif 1 <= true_value <= 8:
                        # Stop the cascade (reveal clue)
                        self.player_board[nr, nc] = true_value
                        self.opened_board[nr, nc] = 1
                        self.clue_number[neighbor_coord] = true_value
                        if neighbor_coord in self.cells_remaining:
                            self.cells_remaining.remove(neighbor_coord)

            return True

        elif 1 <= value <= 8: # Clue cell
            self.player_board[r, c] = value
            self.clue_number[coord] = value
            return True

        return True # Should not be reached, but ensures return value

    def get_logic_bot_move(self):
        """
        The main move logic for the simple bot.
        1. Runs inferences to find safe cells.
        2. Picks an inferred safe cell if available, otherwise picks randomly.
        Returns: (r, c) tuple of the next move.
        """
        # Step 1: Run inferences until the board is stable [cite: 24]
        self.run_logic_inferences()

        # Step 2: Select a cell to open [cite: 18]
        if self.inferred_safe:
            # Pick one of the inferred safe cells
            r, c = self.inferred_safe.pop()
        elif self.cells_remaining:
            # No safe inference found, pick a cell from the remaining pool at random [cite: 18]
            r, c = random.choice(list(self.cells_remaining))
            # Note: The assignment implies the cells_remaining only holds cells *not* inferred as mine.
        else:
            # Game is likely won or stuck
            return None

        return r, c

    def make_move(self, r, c, allow_mine: bool = False):
        """
        Executes a move at (r, c) and updates game state and metrics.
        Returns: Tuple (success: bool, is_game_over: bool)
        """
        if self.is_game_over or self.opened_board[r, c] == 1:
            return True, self.is_game_over # Cannot make move or already open

        # If the move is flagged, we assume the bot unflags it first (optional)
        if self.flagged_board[r, c] == 1:
             self.flagged_board[r, c] = 0 # Unflag before opening

        self.moves_taken += 1 # Increment step counter

        move_successful = self.open_cell(r, c, allow_mine)

        # Check for win condition after a successful move
        if move_successful and self.check_win_condition():
            self.is_game_over = True

        return move_successful, self.is_game_over

    def check_win_condition(self):
        """Checks if all safe cells have been opened."""
        total_cells = self.size * self.size
        cells_opened = torch.sum(self.opened_board).item()

        # Win condition: Number of opened cells equals (Total Cells - Number of Mines)
        return cells_opened == (total_cells - self.number_of_mines)

    def run_logic_inferences(self):
        """
        Applies the two main Minesweeper logic rules iteratively until no new inferences are found.
        Returns True if any new inference was made, False otherwise.
        """
        inferences_made = False

        # [cite_start]Loop until a full pass yields no new inferences [cite: 24]
        while True:
            new_inferences_in_pass = False

            # Iterate over all cells where a clue has been revealed
            for r, c in list(self.clue_number.keys()):
                clue = self.clue_number[(r, c)]
                neighbors = self.get_neighbors(r, c)

                # --- Categorize Neighbors ---
                unrevealed_neighbors = set()
                mines_inferred_count = 0
                safe_inferred_count = 0

                for nr, nc in neighbors:
                    coord = (nr, nc)
                    if coord in self.inferred_mine:
                        mines_inferred_count += 1
                    elif coord in self.inferred_safe or self.opened_board[nr, nc] == 1:
                        safe_inferred_count += 1
                    elif self.opened_board[nr, nc] == 0:
                        # This cell is truly unrevealed and un-inferred
                        unrevealed_neighbors.add(coord)

                num_unrevealed = len(unrevealed_neighbors)

                # 1. Mine Inference (Rule 1)
                # [cite_start]If (cell clue) - (# neighbors inferred to be mines) = (# unrevealed neighbors) [cite: 22]
                if clue - mines_inferred_count == num_unrevealed and num_unrevealed > 0:
                    for nr, nc in unrevealed_neighbors:
                        coord = (nr, nc)

                        # Only infer if not already marked as mine
                        if coord not in self.inferred_mine:
                            self.inferred_mine.add(coord)

                            # --- FIX: FLAG THE INFERRED MINE ---
                            self.flagged_board[nr, nc] = 1 # Mark the cell as flagged

                            # [cite_start]Remove them from cells_remaining [cite: 22]
                            if coord in self.cells_remaining:
                                self.cells_remaining.remove(coord)

                            new_inferences_in_pass = True

                # 2. Safe Inference (Rule 2)
                # [cite_start]If ((# neighbors) - (cell clue)) - (# neighbors revealed or inferred to be safe) = (# unrevealed neighbors) [cite: 23]
                num_non_mines_required = len(neighbors) - clue

                # Total cells known to be safe (opened + inferred safe)
                total_known_safe = safe_inferred_count + num_unrevealed

                # If the total known safe cells (including all unrevealed) equals
                # the total number of non-mines possible
                if num_non_mines_required == total_known_safe and num_unrevealed > 0:
                    for nr, nc in unrevealed_neighbors:
                        coord = (nr, nc)

                        # Only infer if not already marked as safe
                        if coord not in self.inferred_safe and self.opened_board[nr, nc] == 0:
                            self.inferred_safe.add(coord)
                            # [cite_start]Remove from cells_remaining [cite: 23]
                            if coord in self.cells_remaining:
                                self.cells_remaining.remove(coord)

                            new_inferences_in_pass = True

            # [cite_start]If no new inferences were made in this full pass, the loop terminates [cite: 24]
            if not new_inferences_in_pass:
                break
            else:
                inferences_made = True

        return inferences_made

    # def get_nn_input_state(self):
    #     """
    #     Creates the 4-channel input tensor (4, 22, 22) for the Neural Network.
    #     """
    #     size = self.size
    #
    #     # Channel 1: Clue Values (0-8)
    #     # Use player_board, but replace the 'blank' code (9) with 0 for clean clue representation
    #     clues_channel = self.player_board.clone().float()
    #     clues_channel[clues_channel == 9] = 0.0
    #     clues_channel[clues_channel == -1] = 0.0 # Ignore detonated mines if applicable
    #
    #     # Channel 2: Opened Mask (1/0)
    #     opened_mask_channel = self.opened_board.clone().float()
    #
    #     # Channel 3: Flag Mask (1/0)
    #     flag_mask_channel = self.flagged_board.clone().float()
    #
    #     # Channel 4: Global Context (e.g., Mine Density)
    #     # Represented as the ratio of mines to total cells
    #     mine_density = self.number_of_mines / (size * size)
    #     global_context_channel = torch.full((size, size), mine_density, dtype=torch.float32)
    #
    #     # Stack the channels
    #     nn_input = torch.stack([
    #         clues_channel,
    #         opened_mask_channel,
    #         flag_mask_channel,
    #         global_context_channel
    #     ], dim=0)
    #
    #     return nn_input

    def get_nn_input_state(self):
        """
        Creates the 10-channel input tensor for the Neural Network using one-hot encoding for clues.
        Channels: 9 Clue One-Hot + 1 Opened Mask.
        """
        size = self.size

        # Start with the player_board, clean up the codes (9 for blank, -1 for mine)
        clue_board = self.player_board.clone().float()

        # Treat detonated mines (-1) and untouched mines (implicit 0) as 0 for this step
        clue_board[clue_board == -1] = 0.0

        # --- 1. Clue Identity Channels (9 channels: 0 through 8) ---
        clue_channels = []

        for clue_value in range(9):
            if clue_value == 0:
                # FIX: Only select Revealed Blanks (9.0).
                # Do NOT include Unopened (0.0).
                mask = (clue_board == 9.0)
            else:
                mask = (clue_board == clue_value)

            clue_channels.append(mask.float())

        # Convert list of tensors to a single tensor of shape (9, 22, 22)
        clue_input = torch.stack(clue_channels, dim=0)


        # --- 2. Board State Channel (1 channel) ---
        # Channel 10: Opened Mask (Already 1/0)
        # Use unsqueeze(0) to give it a channel dimension (1, 22, 22)
        opened_mask_channel = self.opened_board.clone().float().unsqueeze(0)

        # --- Final Stack ---
        # Stack the 9 clue channels and 1 opened mask channel
        nn_input = torch.cat([
            clue_input,
            opened_mask_channel,
        ], dim=0)

        # Resulting shape: (10, 22, 22)
        return nn_input

    def get_safety_label(self):
        """
        Creates ground-truth MINE map.
        - 1.0 for mines (unopened)
        - 0.0 for safe, unopened cells
        - -1.0 for opened cells (masked out)
        """
        MINE_CODE = 9

        # 1. Start with All Zeros
        # (We assume everything is safe/0.0 initially)
        mine_label = torch.zeros((self.size, self.size), dtype=torch.float32)

        # 2. Mark Mines as 1.0
        mine_label[self.mine_board == MINE_CODE] = 1.0

        # 3. Mask out opened cells with -1.0
        mine_label[self.opened_board == 1] = -1.0

        return mine_label


    def generate_random_walk_data(self, difficulty: str, num_games: int, min_moves: int = 20, max_moves: int = 150):
        """
        Generates training data (X, Y) for Task 1 (Mine Prediction).
        The policy takes a limited number of random, safe steps (not using logic bot)
        to ensure data contains varied mid-game states.

        Returns:
            X_data: (N, 4, 22, 22) - board states
            Y_data: (N, 22, 22) - ground truth safety maps
        """
        X_data = [] # List of (4, 22, 22) input tensors
        Y_data = [] # List of (22, 22) output safety label tensors

        for game_idx in range(num_games):
            self.reset_game()
            self.initialize_board(difficulty)

            if game_idx == 0:
                print(f"Generating Task 1 Random Walk data ({min_moves}-{max_moves} steps)...")

            # 1. Determine the maximum number of moves for this run
            max_steps_for_game = random.randint(min_moves, max_moves)

            self.moves_taken = 0
            self.is_game_over = False
            moves_made = 0

            # Get list of all unopened cells after initial cascade
            current_unopened_cells = list(set((r, c) for r in range(self.size) for c in range(self.size) if self.opened_board[r, c] == 0))

            # Loop to take moves up to the max_steps_for_game threshold
            while not self.is_game_over and moves_made < max_steps_for_game:

                if not current_unopened_cells:
                    break

                # --- DATA COLLECTION POINT (BEFORE THE MOVE) ---

                # 1. Capture the current state (X)
                current_input_state = self.get_nn_input_state() # 4-channel input

                # 2. Capture the ground truth label (Y)
                current_safety_label = self.get_safety_label()

                # Store the data pair
                X_data.append(current_input_state)
                Y_data.append(current_safety_label)

                # --- MOVE DECISION: Random Safe Move ---

                # Find only the SAFE, unopened cells for random selection
                safe_unopened_cells = [(r, c) for r, c in current_unopened_cells if self.mine_board[r, c].item() != 9]

                if not safe_unopened_cells:
                    # No safe moves left, forced to click a mine, so we stop data collection here.
                    self.is_game_over = True
                    break

                # 3. Select a random safe move
                r, c = random.choice(safe_unopened_cells)

                # 4. Execute the move (allow_mine=False since we only select safe cells)
                success, game_over = self.make_move(r, c, allow_mine=False)

                if game_over:
                    # Should not happen since we checked for mines, but break if win condition hit
                    self.is_game_over = True
                    break

                moves_made += 1

                # Update the list of unopened cells for the next iteration
                current_unopened_cells = list(set((r, c) for r in range(self.size) for c in range(self.size) if self.opened_board[r, c] == 0))


            if (game_idx + 1) % 100 == 0:
                print(f"Completed {game_idx + 1} games. Total data points: {len(X_data)}")

        # Convert lists to final PyTorch tensors
        X_tensor = torch.stack(X_data)
        Y_tensor = torch.stack(Y_data)

        return X_tensor, Y_tensor

    def generate_training_data(self, difficulty: str, num_games: int, allow_mine: bool = False):

        # Lists to store the collected Input (X) and Output (Y) tensors
        X_data = [] # List of (4, 22, 22) input tensors
        Y_data = [] # List of (22, 22) output safety label tensors

        for game_idx in range(num_games):

            # Start a fresh game
            self.reset_game()
            self.initialize_board(difficulty)

            # Print after initialize_board sets number_of_mines
            if game_idx == 0:
                mode_str = "with mines allowed" if allow_mine else "standard mode"
                print(f"Generating data for {difficulty.upper()} difficulty ({self.number_of_mines} mines) - {mode_str}...")

            self.is_game_over = False
            self.moves_taken = 0
            self.mines_triggered = 0

            # Logic Bot Play Loop
            while not self.is_game_over:

                # --- DATA COLLECTION POINT (BEFORE THE MOVE) ---

                # 1. Capture the current state (X)
                current_input_state = self.get_nn_input_state()

                # 2. Capture the ground truth label (Y)
                current_safety_label = self.get_safety_label()

                # Store the data pair
                X_data.append(current_input_state)
                Y_data.append(current_safety_label)

                # --- BOT DECISION AND EXECUTION ---

                # 3. Get the bot's next move choice (r, c)
                move = self.get_logic_bot_move()

                # CHECK FOR None BEFORE UNPACKING
                if move is None:
                    # Game ended (won or no moves left)
                    self.is_game_over = True
                    break

                r, c = move  # Now safe to unpack

                # 4. Execute the move (updates game state)
                success, game_over = self.make_move(r, c, allow_mine)

            if (game_idx + 1) % 100 == 0:
                print(f"Completed {game_idx + 1} games. Total data points: {len(X_data)}")

        # Convert lists to final PyTorch tensors
        X_tensor = torch.stack(X_data)
        Y_tensor = torch.stack(Y_data)

        return X_tensor, Y_tensor

    def get_frontier_mask(self):
        """
        Uses 2D Convolution to find the 'Frontier': Unopened cells adjacent to Opened cells.
        Returns: (22, 22) tensor where 1.0 = Frontier, 0.0 = Background/Opened
        """
        # Ensure opened_board is float for conv2d
        # Add Batch and Channel dims: (1, 1, 22, 22)
        opened_input = self.opened_board.float().unsqueeze(0).unsqueeze(0)

        # 3x3 Kernel of 1s to aggregate neighbors
        kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32)

        # Convolve: padding=1 ensures output size matches board size
        neighbor_count = F.conv2d(opened_input, kernel, padding=1)

        # Squeeze back to (22, 22)
        neighbor_count = neighbor_count.squeeze()

        # Frontier Definition:
        # 1. Must have at least one opened neighbor (neighbor_count > 0)
        # 2. Must NOT be opened itself (self.opened_board == 0)
        frontier_mask = (neighbor_count > 0) & (self.opened_board == 0)

        return frontier_mask.float()

    def get_constrained_safety_label(self):
        """
        Generates Y label for Safe Prediction (1=Safe, 0=Mine),
        BUT masks out everything except the Frontier.

        Returns:
             Y tensor where:
             1.0 = Safe Frontier
             0.0 = Mine Frontier
            -1.0 = Everything else (Fog, Opened cells) -> Ignored by Loss
        """
        MINE_CODE = 9

        # 1. Get the Frontier Mask
        frontier = self.get_frontier_mask()

        # 2. Start with a mask of -1.0 (Ignore everything)
        label = torch.full((self.size, self.size), -1.0, dtype=torch.float32)

        # 3. Apply labels ONLY where Frontier is active
        # Identify mines and safe cells
        is_mine = (self.mine_board == MINE_CODE)
        is_safe = (self.mine_board != MINE_CODE)

        # Set Frontier Mines to 0.0
        label[frontier.bool() & is_mine] = 0.0

        # Set Frontier Safe cells to 1.0
        label[frontier.bool() & is_safe] = 1.0

        return label

    def generate_frontier_training_data(self, difficulty: str, num_games: int, min_moves: int = 20, max_moves: int = 150):
        """
        Generates data by playing 'realistically':
        Only clicking safe cells that are on the frontier.
        """
        X_data = []
        Y_data = []

        for game_idx in range(num_games):
            self.reset_game()
            self.initialize_board(difficulty)

            if game_idx == 0:
                print(f"Generating Frontier-Constrained Data ({difficulty})...")

            max_steps = random.randint(min_moves, max_moves)
            steps = 0

            while not self.is_game_over and steps < max_steps:

                # --- 1. Identify Valid Moves (Safe Frontier) ---
                # Get the frontier mask
                frontier = self.get_frontier_mask()

                # Filter for SAFE cells on the frontier
                # Note: We cheat slightly here to generate 'Safe' training data,
                # ensuring the bot sees valid gameplay sequences.
                safe_frontier_indices = torch.nonzero(frontier.bool() & (self.mine_board != 9), as_tuple=False)

                if len(safe_frontier_indices) == 0:
                    # If frontier is purely mines (or empty at start), we must pick ANY safe unopened cell to continue
                    # (This happens rarely, usually at start or end game)
                    unopened = (self.opened_board == 0) & (self.mine_board != 9)
                    safe_indices = torch.nonzero(unopened, as_tuple=False)
                    if len(safe_indices) == 0: break
                    move_idx = random.randint(0, len(safe_indices) - 1)
                    r, c = safe_indices[move_idx].tolist()
                else:
                    # Pick a random SAFE FRONTIER cell (Realistic play)
                    move_idx = random.randint(0, len(safe_frontier_indices) - 1)
                    r, c = safe_frontier_indices[move_idx].tolist()

                # --- 2. Capture Data ---
                X_data.append(self.get_nn_input_state()) # Your existing 10-channel input
                Y_data.append(self.get_constrained_safety_label()) # NEW Frontier-only label

                # --- 3. Make Move ---
                self.make_move(r, c, allow_mine=False)
                steps += 1

            if (game_idx + 1) % 100 == 0:
                print(f"Generated {game_idx + 1} games. Total samples: {len(X_data)}")

        return torch.stack(X_data), torch.stack(Y_data)


    def generate_critic_training_data(self, difficulty: str, num_games: int, actor: nn.Module):
        """
        Generates training data for Critic using 5-channel input.

        Input: (5, 22, 22) where channel 4 is the selected move
        Output: number of moves survived
        """
        X_data = []  # List of (5, 22, 22) tensors
        Y_data = []  # List of survival counts

        for game_idx in range(num_games):
            self.reset_game()
            self.initialize_board(difficulty)

            if game_idx == 0:
                print(f"Generating Critic data for {difficulty.upper()} (5-channel approach)...")

            # Store game trajectory
            game_inputs = []

            while not self.is_game_over:
                # Get current 4-channel state
                current_state = self.get_nn_input_state()  # (4, 22, 22)

                # Get move from actor
                if actor is None:
                    move = self.get_logic_bot_move()
                else:
                    pass
                    move = self.get_actor_move(actor)  # Your trained actor

                if move is None:
                    break

                r, c = move

                # Create move channel (one-hot encoding of the move)
                move_channel = torch.zeros((self.size, self.size), dtype=torch.float32)
                move_channel[r, c] = 1.0

                # Combine into 5-channel input
                full_input = torch.cat([
                    current_state,  # (4, 22, 22)
                    move_channel.unsqueeze(0)  # (1, 22, 22)
                ], dim=0)  # Result: (5, 22, 22)

                game_inputs.append(full_input)

                # Execute move (allow continuing after mines)
                success, game_over = self.make_move(r, c, allow_mine=True)

            # Retrospectively label with survival counts
            total_moves = len(game_inputs)
            for i in range(total_moves):
                moves_survived = total_moves - i

                X_data.append(game_inputs[i])
                Y_data.append(moves_survived)

            if (game_idx + 1) % 100 == 0:
                print(f"Completed {game_idx + 1} games. Total data points: {len(X_data)}")

        return torch.stack(X_data), torch.tensor(Y_data, dtype=torch.float32)

    def generate_actor_training_data(self, difficulty: str, num_games: int, critic_model):
        """
        Generate training data for the Actor network.

        Process:
        1. For each game state, evaluate ALL possible moves using the Critic
        2. Create targets based on Critic's predictions
        3. Actor learns to predict these values directly from board state

        Returns:
            X_data: (N, 4, 22, 22) - board states
            Y_data: (N, 22, 22) - value map for each cell
        """
        X_data = []  # Board states
        Y_data = []  # Value maps (predicted survival for each cell)

        for game_idx in range(num_games):
            self.reset_game()
            self.initialize_board(difficulty)

            if game_idx == 0:
                print(f"Generating Actor training data for {difficulty.upper()}...")

            while not self.is_game_over:
                current_state = self.get_nn_input_state()  # (4, 22, 22)

                # Create value map by querying Critic for all possible moves
                value_map = self.create_value_map_from_critic(current_state, critic_model)

                # Store this training pair
                X_data.append(current_state)
                Y_data.append(value_map)

                # Make a move using current Actor (or logic bot initially)
                move = self.get_best_move_from_value_map(value_map)

                if move is None:
                    break

                r, c = move
                success, game_over = self.make_move(r, c, allow_mine=False)

                if not success:
                    break

            if (game_idx + 1) % 100 == 0:
                print(f"Actor data: Completed {game_idx + 1} games. Data points: {len(X_data)}")

        return (
            torch.stack(X_data),   # (N, 4, 22, 22)
            torch.stack(Y_data)    # (N, 22, 22)
        )


    def create_value_map_from_critic(self, current_state, critic_model):
        """
        Query the Critic for all possible moves to create a value map.

        Returns:
            value_map: (22, 22) tensor where each cell contains predicted survival
                       -1 for already opened cells
        """
        value_map = torch.full((self.size, self.size), -1.0, dtype=torch.float32)

        critic_model.eval()
        with torch.no_grad():
            for r in range(self.size):
                for c in range(self.size):
                    if self.opened_board[r, c] == 0:  # Unopened cell
                        # Create 5-channel input for this move
                        move_channel = torch.zeros((self.size, self.size), dtype=torch.float32)
                        move_channel[r, c] = 1.0

                        full_input = torch.cat([
                            current_state,
                            move_channel.unsqueeze(0)
                        ], dim=0).unsqueeze(0)  # (1, 5, 22, 22)

                        full_input = full_input.to(device)

                        # Get Critic's prediction
                        predicted_survival = critic_model(full_input)
                        value_map[r, c] = predicted_survival.item()

        return value_map


    def get_best_move_from_value_map(self, value_map):
        """
        Select the best move from value map.
        """
        # Get valid moves (value >= 0)
        valid_moves = []
        values = []

        for r in range(self.size):
            for c in range(self.size):
                if value_map[r, c] >= 0:  # Valid move
                    valid_moves.append((r, c))
                    values.append(value_map[r, c])

        if not valid_moves:
            return None

        # Select move with highest value
        values_tensor = torch.tensor(values)
        best_idx = torch.argmax(values_tensor).item()

        return valid_moves[best_idx]


    def get_actor_move(self, actor):
        """
        Get the best move from the Actor network.

        Args:
            actor: Trained Actor network

        Returns:
            (r, c) tuple of selected move, or None if no moves available
        """
        actor.eval()
        with torch.no_grad():
            # Get current board state
            current_state = self.get_nn_input_state().to(device)  # (4, 22, 22)

            # Get Actor's value predictions for all cells
            value_map = actor(current_state.unsqueeze(0)).squeeze(0)  # (22, 22)

            value_map = value_map.reshape((self.size, self.size))

            # Get all valid (unopened) moves and their predicted values
            valid_moves = []
            values = []

            for r in range(self.size):
                for c in range(self.size):
                    if self.opened_board[r, c] == 0:  # Cell is unopened
                        valid_moves.append((r, c))
                        values.append(value_map[r, c].item())

            if not valid_moves:
                return None

            # Select move with highest predicted value
            values_tensor = torch.tensor(values)
            best_idx = torch.argmax(values_tensor).item()

            return valid_moves[best_idx]


    def train_critc(self, critic, train_loader, test_loader, hp):
        epoch_over_training = []
        epoch_over_testing = []

        # Hyperparameter setup
        epochs = hp['epochs']
        learning_rate = hp['learning_rate']
        decay_rate = hp['decay_rate']

        c_dropout = hp['cnn_dropout']
        f_dropout = hp['linear_dropout']

        print('######## Beginning training for MS Critic ##########')

        model = critic
        model.to(device)

        loss_function = nn.MSELoss()

        optimizer = optim.AdamW(model.parameters(),
                               lr=learning_rate,
                               weight_decay=decay_rate
                               )

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=5,
            min_lr=1e-6
        )

        # Have references to variables outside of the epoch loop

        avg_training_loss = 0
        avg_testing_loss = 0

        # Epoch Loop
        for epoch in range(epochs):
            print(f'----- Epoch: {epoch + 1}/{epochs} -----')

            avg_training_loss = 0
            avg_testing_loss = 0

            model.train()

            for x, Y in tqdm(train_loader, desc='Training', unit=' batch'):
                # Transfer images to GPU
                x = x.to(device)
                Y = Y.to(device)

                # Zero out gradients
                optimizer.zero_grad()

                # Send images to model
                x_pred = model(x)

                Y = Y.unsqueeze(1)

                # Calc loss
                loss = loss_function(x_pred, Y)

                # Calc gradient and update weights
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    avg_training_loss += loss.item()

            # Switch to eval mode
            model.eval()

            with torch.no_grad():
                for x, Y in tqdm(test_loader, desc='Testing', unit=' batches'):
                    # Move the images to the GPU
                    x = x.to(device)
                    Y = Y.to(device)

                    # Get logits and sum up total loss
                    x_pred = model(x)
                    Y = Y.unsqueeze(1)
                    avg_testing_loss += loss_function(x_pred, Y).item()

            # Get training loss
            avg_training_loss /= len(train_loader)

             # Get testing loss
            avg_testing_loss /= len(test_loader)

            # Monitor learning
            scheduler.step(avg_testing_loss)

            # Switch model back to training mode
            model.train()

            epoch_over_training.append({
                "epoch": epoch,
                "training_loss": avg_training_loss
                })

            epoch_over_testing.append({
                "epoch": epoch,
                "testing_loss": avg_testing_loss
                })


            print("")

            print(f'   -> Training Loss: {avg_training_loss: .4f}\n')
            print(f'   -> Testing Loss: {avg_testing_loss: .4f}\n')


        return epoch_over_training, epoch_over_testing


    def train_actor(self, actor, train_loader, test_loader, hp):
        epoch_over_training = []
        epoch_over_testing = []

        # Hyperparameter setup
        epochs = hp['epochs']
        learning_rate = hp['learning_rate']
        decay_rate = hp['decay_rate']

        c_dropout = hp['cnn_dropout']
        f_dropout = hp['linear_dropout']

        print('######## Beginning training for MS Actor ##########')

        model = actor
        model.to(device)

        pos_weight_tensor = torch.tensor(30.0, device=device)

        loss_function = masked_bce_loss

        optimizer = optim.AdamW(model.parameters(),
                               lr=learning_rate,
                               weight_decay=decay_rate
                               )

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=5,
            min_lr=1e-6
        )

        # Have references to variables outside of the epoch loop

        avg_training_loss = 0
        avg_testing_loss = 0


        # Epoch Loop
        for epoch in range(epochs):
            print(f'----- Epoch: {epoch + 1}/{epochs} -----')

            avg_training_loss = 0
            avg_testing_loss = 0

            model.train()

            for x, Y in tqdm(train_loader, desc='Training', unit=' batch'):
                # Transfer images to GPU
                x = x.to(device)
                Y = Y.to(device)

                # Zero out gradients
                optimizer.zero_grad()

                # Send images to model
                x_pred = model(x)

                # Calc loss
                loss = loss_function(x_pred, Y, pos_weight_tensor)

                # Calc gradient and update weights
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    avg_training_loss += loss.item()

            # Switch to eval mode
            model.eval()

            with torch.no_grad():
                for x, Y in tqdm(test_loader, desc='Testing', unit=' batches'):
                    # Move the images to the GPU
                    x = x.to(device)
                    Y = Y.to(device)

                    # Get logits and sum up total loss
                    x_pred = model(x)
                    avg_testing_loss += loss_function(x_pred, Y, pos_weight_tensor).item()

            # Get training loss
            avg_training_loss /= len(train_loader)

             # Get testing loss
            avg_testing_loss /= len(test_loader)

            scheduler.step(avg_testing_loss)

            # Switch model back to training mode
            model.train()

            epoch_over_training.append({
                "epoch": epoch,
                "training_loss": avg_training_loss
                })

            epoch_over_testing.append({
                "epoch": epoch,
                "testing_loss": avg_testing_loss
                })


            print("")

            print(f'   -> Training Loss: {avg_training_loss: .4f}\n')
            print(f'   -> Testing Loss: {avg_testing_loss: .4f}\n')


        return epoch_over_training, epoch_over_testing


    def critic_and_actor_training_loop(self, hp_c, hp_a, num_of_games = 50, epochs=10, difficulty: str = 'medium'):

        old_actor: nn.Module | None = None
        old_critic: nn.Module | None = None

        for epoch in range(epochs):
            # Get new models
            critic = Critic()
            actor = Actor()

            # Generate Initial Critic data based off of logic bot
            x, y = self.generate_critic_training_data(num_games=num_of_games, difficulty=difficulty, actor=old_actor)

            # Convert to dataloader
            critic_train_loader, critic_test_loader = get_dataloader(x, y, 64)

            # Train Critic
            self.train_critc(critic=critic, train_loader=critic_train_loader, test_loader=critic_test_loader, hp=hp_c)

            # Get actor data
            x, y = self.generate_actor_training_data('medium', num_of_games, critic)

            # Convert to dataloader
            actor_train_loader, actor_test_loader = get_dataloader(x, y, 64)

            # Train the actor
            self.train_actor(actor=actor, train_loader=actor_train_loader, test_loader=actor_test_loader, hp=hp_a)

            old_actor = actor


        return old_actor, old_critic

In [4]:
def start_nn_game(manager, model, difficulty: str, device: torch.device, allow_mines: bool = False):
    """
    Play one complete game with a trained per-cell scoring network.

    Each turn the network scores the board, sigmoid turns the logits into a
    (size, size) safety map, and `get_nn_move_choice` picks the best-scoring
    unopened cell to play.

    Args:
        manager: MSGameManager to play on. It is reset before the board is
            initialised, so the same manager can safely be reused across games.
        model: network mapping a (1, C, size, size) state batch to
            size * size logits.
        difficulty: difficulty key passed to `initialize_board`.
        device: device the model lives on.
        allow_mines: forwarded to `make_move`; when True the game continues
            after a mine is hit.

    Returns:
        dict with "success" (from `check_win_condition`), "moves_taken" and
        "mines_triggered".
    """
    # Clear boards and counters left over from any previous game on this manager.
    manager.reset_game()
    manager.initialize_board(difficulty)
    manager.is_game_over = False

    model.eval()

    with torch.no_grad():
        while not manager.is_game_over:
            # Current board state with a leading batch dimension.
            X_batch = manager.get_nn_input_state().to(device).unsqueeze(0)

            # Per-cell logits -> probabilities -> (size, size) safety map.
            logits = model(X_batch)
            safety_map = torch.sigmoid(logits).view(manager.size, manager.size)

            r, c = get_nn_move_choice(safety_map, manager.opened_board.to(device))

            if r is None:
                # No unopened cell is left to play.
                manager.is_game_over = True
                break

            manager.make_move(r, c, allow_mines)

    return {
        "success": manager.check_win_condition(),
        "moves_taken": manager.moves_taken,
        "mines_triggered": manager.mines_triggered
    }


# def get_nn_move_choice(mine_map, opened_board):
#     # Mask opened cells with a HIGH value so they aren't picked as 'minimum'
#     masked_map = mine_map.clone()
#     masked_map[opened_board == 1] = 100.0
#
#     # Find the MINIMUM score (Lowest probability of being a mine)
#     min_score = masked_map.min()
#
#     # Safety Check: If the 'best' move has a high chance of being a mine, don't move.
#     # (e.g., if the lowest probability on the board is 0.5, you are guessing)
#     # if min_score > 0.5: # Only move if < 5% chance of mine
#     #     return None, None
#
#     # Get coordinates
#     r_idx, c_idx = torch.where(masked_map == min_score)
#
#     # If there are ties, pick one randomly
#     if len(r_idx) > 1:
#         choice = torch.randint(0, len(r_idx), (1,)).item()
#         return r_idx[choice].item(), c_idx[choice].item()
#     else:
#         return r_idx[0].item(), c_idx[0].item()

def get_nn_move_choice(safety_map, opened_board):
    """
    Return the unopened cell with the highest safety score.

    Cells where `opened_board == 1` are excluded by overwriting their score
    with -1e9. When several cells share the best score, one of them is
    chosen uniformly at random using the global torch RNG.

    Args:
        safety_map: (size, size) tensor of per-cell safety scores.
        opened_board: tensor of the same shape; 1 marks an opened cell.

    Returns:
        (r, c) as Python ints, or (None, None) when every cell is masked out.
    """
    scores = safety_map.clone()
    scores[opened_board == 1] = -1e9

    best = scores.max()
    if best < -1e8:
        return None, None

    rows, cols = torch.where(scores == best)
    # Index on the CPU regardless of where the map lives.
    rows, cols = rows.cpu(), cols.cpu()

    pick = torch.randint(0, len(rows), (1,)).item() if len(rows) > 1 else 0
    return rows[pick].item(), cols[pick].item()


def calculate_bot_stats(model: nn.Module, num_of_games: int):
    """
    Benchmark a bot on every difficulty under both rule sets and print summaries.

    With mines allowed (`mode=True`), the average number of mines set off is
    reported. Under real rules (`mode=False`), the win rate and the
    mean/variance/std of moves over won games are reported.

    Args:
        model: trained network to evaluate, or None to evaluate the logic bot.
        num_of_games: games played per (rule set, difficulty) pair.

    Raises:
        ValueError: if num_of_games is not positive.
    """
    if num_of_games <= 0:
        raise ValueError(f'num_of_games must be positive, got {num_of_games}')

    # Explicit None check: nn.Module truthiness is unreliable (an empty nn.Sequential is falsy).
    use_nn = model is not None
    bot_type = 'NN' if use_nn else 'Logic Bot'

    for mode in [True, False]:

        print('\n--------------------Real Game Rules----------------------' if not mode else '\n--------------Mines can be detonated----------------')

        for difficulty in ['easy', 'medium', 'hard']:

            moves_per_game = []
            games_won = 0
            mines_set_off = []

            for _ in range(num_of_games):
                # Fresh manager per game for both bots. The logic-bot path must not
                # depend on a global `manager` defined in a later cell.
                game_manager = MSGameManager()
                if use_nn:
                    stats = start_nn_game(game_manager, model, difficulty=difficulty, device=device, allow_mines=mode)
                else:
                    stats = game_manager.start_bot_game(difficulty=difficulty, allow_mine=mode)

                mines_set_off.append(stats['mines_triggered'])

                if stats['success']:
                    games_won += 1
                    moves_per_game.append(stats["moves_taken"])

            if mode:
                avg_mines_set_off = np.average(mines_set_off)
                print(f'On {difficulty} mode, the average number of mines that were set off was {avg_mines_set_off}')
                continue

            win_rate = games_won / num_of_games * 100
            if moves_per_game:
                avg_moves_per_won_game = np.average(moves_per_game)
                # np.var / np.std with default ddof=0 (population statistics).
                variance = np.var(moves_per_game)
                std_dev = np.std(moves_per_game)
            else:
                avg_moves_per_won_game = variance = std_dev = 0

            print(f'On {difficulty}: the {bot_type} won {win_rate:.2f}% of the games, averaged {avg_moves_per_won_game:.2f} moves per won game which had a variance of {variance:.2f} and std deviation {std_dev:.2f}')





In [7]:
from torch.utils.data import Dataset

class MSDataset(Dataset):
    """
    In-memory dataset pairing input tensors with their labels.

    Inputs are cast to float32. Labels are cast to float32 too: 1-D labels
    (one scalar per sample, e.g. Critic targets) are kept as (N,), while
    higher-rank labels are flattened per sample, e.g. (N, 22, 22) -> (N, 484).
    """

    def __init__(self, X: torch.Tensor, y: torch.Tensor):
        self.X = X.float()

        labels = y.float()
        self.y = labels.view(labels.size(0), -1) if y.ndim > 1 else labels

        # Trailing spatial dimension of the inputs.
        self.size = X.size(-1)

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        # Independent, contiguous copies so callers never alias the backing tensors.
        sample = self.X[idx].clone().contiguous()
        label = self.y[idx].clone().contiguous()
        return sample, label

In [5]:
class MSSafeSquares(nn.Module):
    """
    Fully-convolutional per-cell scorer.

    Four 3x3 conv -> BatchNorm -> ReLU -> Dropout2d stages (32, 64, 128, 128
    channels, padding keeps the spatial size) are followed by a 1x1 conv to a
    single channel. The output is flattened to one logit per cell, shape (B, H * W).

    Args:
        lin_dropout: accepted for interface compatibility; unused because
            this network has no linear layers.
        cnn_dropout: Dropout2d probability after each conv stage.
        in_channels: number of input planes (default 10).
    """

    def __init__(self, lin_dropout=0, cnn_dropout=0, in_channels=10):

        super().__init__()

        self.BOARD_SIZE = 22

        channels = [in_channels, 32, 64, 128, 128]

        # Built in the same order as the explicit layer list, so Sequential
        # indices (state_dict keys CNN.0 ... CNN.16) and init order are unchanged.
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Dropout2d(p=cnn_dropout),
            ]
        layers.append(nn.Conv2d(in_channels=channels[-1], out_channels=1, kernel_size=1))

        self.CNN = nn.Sequential(*layers)

    def forward(self, x):

        logits = self.CNN(x)

        # (B, 1, H, W) -> (B, H * W)
        return logits.view(x.size(0), -1)

class CNN_Block(nn.Module):
    """
    Basic residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, add skip, ReLU.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first conv. When stride != 1 or the channel
            count changes, the skip path uses a strided 1x1 conv + BN projection;
            otherwise it is the identity.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()

        # Identity skip by default; replaced by a projection below when shapes differ.
        self.skip = nn.Sequential()

        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=in_planes, out_channels=planes, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            # Non-linearity between the convs; without it the two conv/BN pairs
            # compose into a single linear map.
            nn.ReLU(inplace=True),

            nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(planes),
        )

        if stride != 1 or in_planes != planes:
            # 1x1 projection so the skip output matches the residual branch's shape.
            self.skip = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):

        # Residual branch plus skip connection, then the block activation.
        logits = self.conv_block(x)
        logits += self.skip(x)
        logits = F.relu(logits)

        return logits


class MS_ResNet(nn.Module):
    """
    Residual CNN mapping a board-state stack to one logit per board cell.

    Pipeline:
      1. a 3x3 stem conv (64 channels);
      2. three stages of `CNN_Block`s (64, 128 and 192 channels; strides 1, 2, 2);
      3. global average pooling;
      4. an MLP head producing BOARD_SIZE * BOARD_SIZE logits.

    Args:
        num_blocks: blocks per stage; the first three entries are used.
        linear_dropout: dropout probability in the MLP head.
        cnn_dropout: accepted for interface compatibility; currently unused
            (CNN_Block has no dropout).
        in_channels: number of input planes (default 10).
    """
    def __init__(self, num_blocks: list, linear_dropout=0.25, cnn_dropout=0.25, in_channels=10):
        super().__init__()

        self.BOARD_SIZE = 22

        # Channel count fed into the next block; updated by make_stage as stages are built.
        self.in_planes = 64

        self.image_input_layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=self.in_planes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(self.in_planes),
            nn.ReLU(inplace=True)
        )

        self.stages = nn.Sequential(
            self.make_stage(64, num_blocks[0], 1),
            self.make_stage(128, num_blocks[1], 2),
            self.make_stage(192, num_blocks[2], 2),
        )

        self.lin_layer_1_size = 100
        self.lin_layer_2_size = 500
        self.lin_layer_3_size = 1000
        self.lin_layer_4_size = 500

        self.classifier = nn.Sequential(
            nn.Linear(192, self.lin_layer_1_size),
            nn.LayerNorm(self.lin_layer_1_size),
            nn.ReLU(inplace=True),
            nn.Dropout(linear_dropout),

            nn.Linear(self.lin_layer_1_size, self.lin_layer_2_size),
            nn.LayerNorm(self.lin_layer_2_size),
            nn.ReLU(inplace=True),
            nn.Dropout(linear_dropout),

            nn.Linear(self.lin_layer_2_size, self.lin_layer_3_size),
            nn.LayerNorm(self.lin_layer_3_size),
            nn.ReLU(inplace=True),
            nn.Dropout(linear_dropout),

            nn.Linear(self.lin_layer_3_size, self.lin_layer_4_size),
            nn.LayerNorm(self.lin_layer_4_size),
            nn.ReLU(inplace=True),
            nn.Dropout(linear_dropout),

            nn.Linear(self.lin_layer_4_size, self.BOARD_SIZE * self.BOARD_SIZE)
        )

    def make_stage(self, planes, num_blocks, stride):
        """Build `num_blocks` CNN_Blocks; only the first uses `stride`, the rest use 1."""
        strides = [stride] + [1] * (num_blocks - 1)

        layers = []

        for block_stride in strides:
            layers.append(CNN_Block(self.in_planes, planes, block_stride))
            # The next block's input width equals this block's output width.
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x):

        logits = self.image_input_layer(x)

        logits = self.stages(logits)

        # Global average pool to 1x1. For a 22x22 board the last stage is 6x6,
        # so this equals a 6x6 average pool, but it also works for other board sizes.
        logits = F.adaptive_avg_pool2d(logits, 1)

        logits = logits.view(logits.size(0), -1)

        logits = self.classifier(logits)

        return logits


def masked_bce_loss(predictions, targets: torch.Tensor, pos_weight):
    """
    Binary cross-entropy with logits, averaged only over labelled cells.

    Cells whose target equals -1.0 are ignored. They get a placeholder target
    of 0 so the element-wise loss is finite, and a mask weight of 0.

    Args:
        predictions: raw logits, same shape as `targets`.
        targets: labels, with -1.0 marking cells to ignore.
        pos_weight: weight on the positive term, as in `BCEWithLogitsLoss`.

    Returns:
        Scalar tensor: summed masked loss divided by the number of labelled
        cells (plus 1e-8 so an all-ignored batch yields 0 instead of NaN).
    """
    ignored = targets == -1.0
    weights = (~ignored).float()
    safe_targets = torch.where(ignored, torch.zeros_like(targets), targets)

    per_cell = F.binary_cross_entropy_with_logits(
        predictions, safe_targets, pos_weight=pos_weight, reduction='none'
    )

    return (per_cell * weights).sum() / (weights.sum() + 1e-8)

In [8]:
from torch.utils.data import DataLoader

def get_dataloader(X_dataset, Y_dataset, batch_size, train_fraction=0.8, num_workers=5):
    """
    Split paired tensors into train and test DataLoaders.

    The split is positional, not random: the first `train_fraction` of the
    samples form the training set and the rest the test set. Data generated
    game by game therefore keeps the consecutive positions of a game (mostly)
    on one side of the split. Only the training loader shuffles.

    Args:
        X_dataset: input tensor, first dim = samples.
        Y_dataset: label tensor with the same number of samples.
        batch_size: batch size for both loaders.
        train_fraction: fraction of samples used for training, in (0, 1).
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        (train_loader, test_loader)

    Raises:
        ValueError: if X and Y lengths differ or train_fraction is out of range.
    """
    total_size = len(X_dataset)
    if len(Y_dataset) != total_size:
        raise ValueError(f'X has {total_size} samples but Y has {len(Y_dataset)}')
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(f'train_fraction must be in (0, 1), got {train_fraction}')

    train_size = int(train_fraction * total_size)
    test_size = total_size - train_size

    X_train, X_test = torch.split(X_dataset, [train_size, test_size])
    Y_train, Y_test = torch.split(Y_dataset, [train_size, test_size])

    train_dataset = MSDataset(X_train, Y_train)
    test_dataset = MSDataset(X_test, Y_test)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, test_loader

manager = MSGameManager()

# Load the cached EASY frontier data set from disk.
# To regenerate it instead: manager.generate_frontier_training_data('easy', 1500, 50, 400)
X = torch.load('./data/X_train_easy_front.pt')
Y = torch.load('./data/Y_train_easy_front.pt')

MS_train_loader_easy, MS_test_loader_easy = get_dataloader(X, Y, 64)

In [9]:
def training_loop(train_loader, test_loader, model: nn.Module, hp: dict):
    """
    Train a per-cell safety predictor with the masked, positively-weighted BCE.

    Args:
        train_loader: yields (x, Y) batches; cells with Y == -1 are ignored
            by `masked_bce_loss`.
        test_loader: same format. An empty loader reports a testing loss of 0.
        model: network to train; moved to `device` and trained in place.
        hp: dict with 'epochs', 'learning_rate', 'decay_rate' and
            'pos_weight'; other keys are ignored.

    Returns:
        (model, epoch_over_training, epoch_over_testing), where the lists hold
        {"epoch", "training_loss"} and {"epoch", "testing_loss"} dicts.
    """
    epoch_over_training = []
    epoch_over_testing = []

    epochs = hp['epochs']
    learning_rate = hp['learning_rate']
    decay_rate = hp['decay_rate']

    print('######## Beginning training for MS Safe Square Predictor ##########')

    model.to(device)

    # Force a float tensor: an int hyper-parameter (e.g. 15) would otherwise produce int64.
    pos_weight_tensor = torch.tensor(float(hp['pos_weight']), device=device)

    loss_function = masked_bce_loss

    optimizer = optim.AdamW(model.parameters(),
                           lr=learning_rate,
                           weight_decay=decay_rate
                           )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5,
        min_lr=1e-6
    )

    for epoch in range(epochs):
        print(f'----- Epoch: {epoch + 1}/{epochs} -----')

        avg_training_loss = 0
        avg_testing_loss = 0

        model.train()

        for x, Y in tqdm(train_loader, desc='Training', unit=' batch'):
            x = x.to(device)
            Y = Y.to(device)

            optimizer.zero_grad()
            x_pred = model(x)
            loss = loss_function(x_pred, Y, pos_weight_tensor)
            loss.backward()
            optimizer.step()

            avg_training_loss += loss.item()

        model.eval()

        with torch.no_grad():
            for x, Y in tqdm(test_loader, desc='Testing', unit=' batches'):
                x = x.to(device)
                Y = Y.to(device)

                x_pred = model(x)
                avg_testing_loss += loss_function(x_pred, Y, pos_weight_tensor).item()

        # Guard against empty loaders (e.g. a tiny data set whose 20% test split is empty).
        avg_training_loss /= max(len(train_loader), 1)
        avg_testing_loss /= max(len(test_loader), 1)

        scheduler.step(avg_testing_loss)

        # The model is left in training mode after each epoch.
        model.train()

        epoch_over_training.append({
            "epoch": epoch,
            "training_loss": avg_training_loss
            })

        epoch_over_testing.append({
            "epoch": epoch,
            "testing_loss": avg_testing_loss
            })

        print("")

        print(f'   -> Training Loss: {avg_training_loss: .4f}\n')
        print(f'   -> Testing Loss: {avg_testing_loss: .4f}\n')

    return model, epoch_over_training, epoch_over_testing

In [10]:
hyperparameters = dict(
    epochs=20,
    learning_rate=1e-4,
    decay_rate=1e-3,
    cnn_dropout=0.3,
    linear_dropout=0.6,
    pos_weight=15,
)

# Per-cell safety predictor for EASY boards, placed on the active device.
model_easy = MSSafeSquares(
    lin_dropout=hyperparameters['linear_dropout'],
    cnn_dropout=hyperparameters['cnn_dropout'],
).to(device)

model_easy, train, test = training_loop(
    train_loader=MS_train_loader_easy,
    test_loader=MS_test_loader_easy,
    hp=hyperparameters,
    model=model_easy,
)

######## Beginning training for MS Safe Square Predictor ##########
----- Epoch: 1/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 288.72 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 639.60 batches/s]



   -> Training Loss:  0.7181

   -> Testing Loss:  0.3277

----- Epoch: 2/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 308.47 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 628.31 batches/s]



   -> Training Loss:  0.3411

   -> Testing Loss:  0.2383

----- Epoch: 3/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 313.39 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 618.40 batches/s]



   -> Training Loss:  0.2801

   -> Testing Loss:  0.2117

----- Epoch: 4/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 309.44 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 626.72 batches/s]



   -> Training Loss:  0.2512

   -> Testing Loss:  0.1961

----- Epoch: 5/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 303.80 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 616.41 batches/s]



   -> Training Loss:  0.2339

   -> Testing Loss:  0.1887

----- Epoch: 6/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 308.65 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 633.56 batches/s]



   -> Training Loss:  0.2221

   -> Testing Loss:  0.1821

----- Epoch: 7/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 307.71 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 637.79 batches/s]



   -> Training Loss:  0.2140

   -> Testing Loss:  0.1774

----- Epoch: 8/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 307.05 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 621.58 batches/s]



   -> Training Loss:  0.2068

   -> Testing Loss:  0.1729

----- Epoch: 9/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 308.43 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 617.50 batches/s]



   -> Training Loss:  0.2017

   -> Testing Loss:  0.1713

----- Epoch: 10/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 308.12 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 635.44 batches/s]



   -> Training Loss:  0.1969

   -> Testing Loss:  0.1679

----- Epoch: 11/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 310.50 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 651.20 batches/s]



   -> Training Loss:  0.1934

   -> Testing Loss:  0.1674

----- Epoch: 12/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 309.09 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 610.85 batches/s]



   -> Training Loss:  0.1902

   -> Testing Loss:  0.1659

----- Epoch: 13/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 311.84 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 619.89 batches/s]



   -> Training Loss:  0.1868

   -> Testing Loss:  0.1647

----- Epoch: 14/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 308.51 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 634.96 batches/s]



   -> Training Loss:  0.1841

   -> Testing Loss:  0.1641

----- Epoch: 15/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 306.04 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 629.44 batches/s]



   -> Training Loss:  0.1817

   -> Testing Loss:  0.1631

----- Epoch: 16/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 305.34 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 615.60 batches/s]



   -> Training Loss:  0.1797

   -> Testing Loss:  0.1626

----- Epoch: 17/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 312.10 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 626.51 batches/s]



   -> Training Loss:  0.1776

   -> Testing Loss:  0.1624

----- Epoch: 18/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 307.35 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 608.23 batches/s]



   -> Training Loss:  0.1762

   -> Testing Loss:  0.1607

----- Epoch: 19/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 307.96 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 623.54 batches/s]



   -> Training Loss:  0.1743

   -> Testing Loss:  0.1612

----- Epoch: 20/20 -----


Training: 100%|██████████| 1386/1386 [00:04<00:00, 306.16 batch/s]
Testing: 100%|██████████| 347/347 [00:00<00:00, 622.43 batches/s]


   -> Training Loss:  0.1729

   -> Testing Loss:  0.1599






In [11]:
calculate_bot_stats(model=model_easy, num_of_games=100)  # trained network
calculate_bot_stats(model=None, num_of_games=100)        # logic-bot baseline


--------------Mines can be detonated----------------

On easy mode, the average number of mines that were set off was 4.73

On medium mode, the average number of mines that were set off was 14.66

On hard mode, the average number of mines that were set off was 28.98

--------------------Real Game Rules----------------------
On easy: the NN won 84.00% of the games, averaged 80.04 moves per won game which had a variance of 208.49 and std deviation 14.44

On medium: the NN won 25.00% of the games, averaged 160.04 moves per won game which had a variance of 349.40 and std deviation 18.69

On hard: the NN won 0.00% of the games, averaged 0.00 moves per won game which had a variance of 0.00 and std deviation 0.00


--------------Mines can be detonated----------------

On easy mode, the average number of mines that were set off was 0.51

On medium mode, the average number of mines that were set off was 4.6

On hard mode, the average number of mines that were set off was 18.54

---------------

In [None]:
class Critic(nn.Module):
    """Critic network: scores a stack of board feature planes with one scalar.

    Layout: four 3x3 conv layers (BatchNorm2d + ReLU after each, with one 2x2
    max-pool after the second), then four fully connected layers
    (Linear -> BatchNorm1d -> ReLU -> Dropout) and a final Linear(..., 1).

    Args:
        linear_dropout: Dropout probability applied after each hidden linear layer.
        board_size: Side length of the square input board. Defaults to 22, the
            size this notebook's game manager uses.
        in_channels: Number of input planes. Defaults to 5.

    Input:  float tensor of shape (batch, in_channels, board_size, board_size).
    Output: tensor of shape (batch, 1).

    Note: ``nn.BatchNorm1d`` raises in training mode when a batch holds a single
    sample, so call ``.eval()`` before scoring one board at a time.
    """

    def __init__(self, linear_dropout=0, board_size=22, in_channels=5):
        super().__init__()

        self.BOARD_SIZE = board_size

        cnn_layer_1_size = 32
        cnn_layer_2_size = 32
        cnn_layer_3_size = 64
        cnn_layer_4_size = 64

        self.CNN = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=cnn_layer_1_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_1_size),
            nn.ReLU(),

            nn.Conv2d(in_channels=cnn_layer_1_size, out_channels=cnn_layer_2_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_2_size),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=cnn_layer_2_size, out_channels=cnn_layer_3_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_3_size),
            nn.ReLU(),

            nn.Conv2d(in_channels=cnn_layer_3_size, out_channels=cnn_layer_4_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_4_size),
            nn.ReLU(),
        )

        # Size of the flattened conv output, measured with a dummy forward pass.
        total_count = self._flattened_feature_count(in_channels)

        lin_layer_1_size = 2000
        lin_layer_2_size = 2000
        lin_layer_3_size = 1500
        lin_layer_4_size = 500

        self.Classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=total_count, out_features=lin_layer_1_size),
            nn.BatchNorm1d(num_features=lin_layer_1_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_1_size, out_features=lin_layer_2_size),
            nn.BatchNorm1d(num_features=lin_layer_2_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_2_size, out_features=lin_layer_3_size),
            nn.BatchNorm1d(num_features=lin_layer_3_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_3_size, out_features=lin_layer_4_size),
            nn.BatchNorm1d(num_features=lin_layer_4_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),
        )

        self.output_layer = nn.Linear(in_features=lin_layer_4_size, out_features=1)

    def _flattened_feature_count(self, in_channels):
        """Return the per-sample feature count produced by ``self.CNN``.

        The probe runs with the trunk in eval mode, so the BatchNorm2d layers read
        their running statistics instead of updating them. A training-mode probe
        would fold the all-zeros input into running_mean/running_var and bump
        num_batches_tracked before training starts. The previous mode is restored
        afterwards. The probe stays on CPU, where freshly built parameters live.
        """
        was_training = self.CNN.training
        self.CNN.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, in_channels, self.BOARD_SIZE, self.BOARD_SIZE)
                features = self.CNN(probe)
        finally:
            self.CNN.train(was_training)
        return features.flatten(1).size(1)

    def forward(self, x):
        """Run the conv trunk, the fully connected head, then the scalar output layer."""
        logits = self.CNN(x)
        logits = self.Classifier(logits)

        logits = self.output_layer(logits)

        return logits

class Actor(nn.Module):
    """Actor network: maps a stack of board feature planes to per-position logits.

    Layout: four 3x3 conv layers (BatchNorm2d + ReLU after each, with one 2x2
    max-pool after the second), then four fully connected layers
    (Linear -> ReLU -> Dropout, no batch norm) and a final Linear producing
    board_size * board_size outputs.

    Args:
        linear_dropout: Dropout probability applied after each hidden linear layer.
        board_size: Side length of the square input board. Defaults to 22, the
            size this notebook's game manager uses.
        in_channels: Number of input planes. Defaults to 4.

    Input:  float tensor of shape (batch, in_channels, board_size, board_size).
    Output: tensor of shape (batch, board_size * board_size). Presumably one logit
        per board position in row-major order -- confirm against the caller.
    """

    def __init__(self, linear_dropout=0, board_size=22, in_channels=4):
        super().__init__()

        self.BOARD_SIZE = board_size

        cnn_layer_1_size = 32
        cnn_layer_2_size = 32
        cnn_layer_3_size = 64
        cnn_layer_4_size = 64

        self.CNN = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=cnn_layer_1_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_1_size),
            nn.ReLU(),

            nn.Conv2d(in_channels=cnn_layer_1_size, out_channels=cnn_layer_2_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_2_size),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=cnn_layer_2_size, out_channels=cnn_layer_3_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_3_size),
            nn.ReLU(),

            nn.Conv2d(in_channels=cnn_layer_3_size, out_channels=cnn_layer_4_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_4_size),
            nn.ReLU(),
        )

        # Size of the flattened conv output, measured with a dummy forward pass.
        total_count = self._flattened_feature_count(in_channels)

        lin_layer_1_size = 2000
        lin_layer_2_size = 2000
        lin_layer_3_size = 1500
        lin_layer_4_size = 500

        self.Classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=total_count, out_features=lin_layer_1_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_1_size, out_features=lin_layer_2_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_2_size, out_features=lin_layer_3_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_3_size, out_features=lin_layer_4_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),
        )

        self.output_layer = nn.Linear(in_features=lin_layer_4_size, out_features=self.BOARD_SIZE * self.BOARD_SIZE)

    def _flattened_feature_count(self, in_channels):
        """Return the per-sample feature count produced by ``self.CNN``.

        The probe runs with the trunk in eval mode, so the BatchNorm2d layers read
        their running statistics instead of updating them. A training-mode probe
        would fold the all-zeros input into running_mean/running_var and bump
        num_batches_tracked before training starts. The previous mode is restored
        afterwards. The probe stays on CPU, where freshly built parameters live.
        """
        was_training = self.CNN.training
        self.CNN.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, in_channels, self.BOARD_SIZE, self.BOARD_SIZE)
                features = self.CNN(probe)
        finally:
            self.CNN.train(was_training)
        return features.flatten(1).size(1)

    def forward(self, x):
        """Run the conv trunk, the fully connected head, then the per-position output layer."""
        logits = self.CNN(x)
        logits = self.Classifier(logits)

        logits = self.output_layer(logits)

        return logits


In [None]:
# Hyperparameters common to both networks; each per-network dict copies these
# and overrides only the values that differ.
shared_hp = {
    'epochs': 5,
    'learning_rate': 1e-4,
    'decay_rate': 1e-3,
    'cnn_dropout': 0.0,
    'linear_dropout': 0.25,
}

# Actor: identical to the shared settings except for a larger learning rate.
# NOTE(review): the actor LR is 100x the critic's; actor-critic setups usually use
# an actor LR no larger than the critic's -- confirm this is intended.
hp_a = {**shared_hp, 'learning_rate': 1e-2}

# Critic: exactly the shared settings (an independent copy, not an alias).
hp_c = dict(shared_hp)

# NOTE(review): the Actor/Critic constructors in this notebook accept only
# linear_dropout; whether 'cnn_dropout' is consumed depends on
# critic_and_actor_training_loop, which is defined elsewhere.
manager = MSGameManager()

critic, actor = manager.critic_and_actor_training_loop(hp_a=hp_a, hp_c=hp_c, num_of_games=500)