In [248]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
import matplotlib.pyplot as plt
import seaborn as sns
from numpy.ma.core import shape
from torch.nn import MSELoss
from torch.utils import data
from torch.utils.data import Dataset
from tqdm import tqdm
from triton.language import dtype
import torch.optim.lr_scheduler as lr_scheduler

In [None]:
# Verify that GPU is connected and available

print(torch.__version__)

# Fall back to the CPU when no CUDA device is visible.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# torch.cuda.get_device_name raises when CUDA is unavailable, so only query
# it when a CUDA device was actually selected above.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
else:
    print('CUDA not available; using CPU')

In [278]:
import torch
import random
import numpy as np

class MSGameManager:


    def __init__(self):
        self.player_board = torch.zeros([22, 22], dtype=torch.int8)
        self.mine_board = torch.zeros([22, 22], dtype=torch.int8) # Location of All Mines, 9 for mine, 0-8 for clue
        self.flagged_board = torch.zeros([22, 22], dtype=torch.int8)
        self.opened_board = torch.zeros([22, 22], dtype=torch.int8)
        self.number_of_mines = -1

        # Logic bot variables
        self.cells_remaining = set()
        self.inferred_safe = set()
        self.inferred_mine = set()
        self.clue_number = dict()
        self.size = 22
        self.moves_taken = 0
        self.initial_start_coords = None
        self.is_game_over = False
        self.mines_triggered = 0

    def reset_game(self):
        # This method handles resetting ALL state variables for a new game
        self.player_board = torch.zeros([self.size, self.size], dtype=torch.int8)
        self.mine_board = torch.zeros([self.size, self.size], dtype=torch.int8)
        self.flagged_board = torch.zeros([self.size, self.size], dtype=torch.int8)
        self.opened_board = torch.zeros([self.size, self.size], dtype=torch.int8)
        self.number_of_mines = -1 # Set back to default, to be set later
        self.cells_remaining = set()
        self.inferred_safe = set()
        self.inferred_mine = set()
        self.clue_number = dict()
        self.moves_taken = 0
        self.initial_start_coords = None
        self.is_game_over = False
        self.mines_triggered = 0

    def initialize_board(self, difficulty: str):

        self.reset_game()

        self.number_of_mines = 50 if difficulty == "easy" \
            else 80 if difficulty == "medium" \
                else 100

        # Add all cells into the remaining set and reset boards
        self.cells_remaining = set((r, c) for r in range(self.size) for c in range(self.size))
        self.mine_board = torch.zeros([self.size, self.size], dtype=torch.int8)

        # Generate a random starting location
        initial_start = torch.randint(0, self.size, (2, ), dtype=torch.int8)
        start_r, start_c = initial_start[0].item(), initial_start[1].item()
        self.initial_start_coords = (start_r, start_c)

        # Remove the starting cell from the board
        self.cells_remaining.remove(self.initial_start_coords)

        # All possible coordinates excluding the initial start cell
        possible_mine_locations = list(self.cells_remaining)

        # Grab the squares that will be set to mines
        mine_locations = random.sample(possible_mine_locations, self.number_of_mines)

        # Place mines on the mine board
        MINE = 9
        for r, c in mine_locations:
            self.mine_board[r, c] = MINE

        # Calculate and place clues on the mine_board
        self.calculate_clues()

        # Perform the initial move so we can have a starting board
        self.open_cell(start_r, start_c)


    def get_neighbors(self, r, c):
        """Returns a list of valid neighboring coordinates (up to 8)."""
        neighbors = []
        for dr in [-1, 0, 1]:
            for dc in [-1, 0, 1]:
                if dr == 0 and dc == 0:
                    continue
                nr, nc = r + dr, c + dc

                # Check for boundary conditions
                if 0 <= nr < self.size and 0 <= nc < self.size:
                    neighbors.append((nr, nc))
        return neighbors

    def calculate_clues(self):
        """Iterates over the board and calculates the clue number for non-mine cells."""
        MINE = 9

        for r in range(self.size):
            for c in range(self.size):

                # Skip square if it is a mine
                if self.mine_board[r, c] == MINE:
                    continue

                mine_count = 0
                for nr, nc in self.get_neighbors(r, c):
                    # Check if the neighbor is a mine
                    if self.mine_board[nr, nc] == MINE:
                        mine_count += 1

                # Store the clue number (0-8) on the mine_board
                self.mine_board[r, c] = mine_count

    def open_cell(self, r, c, allow_mine: bool = False):
        """
        Performs the action of opening a cell, updating all state boards,
        and handling the chain reaction if a blank (clue 0) is hit.
        Returns: True if the move was successful (not a mine), False otherwise.
        """
        if self.opened_board[r, c] == 1:
            return True # Already opened, treat as successful non-fatal move

        # 1. Update general state
        self.opened_board[r, c] = 1
        coord = (r, c)
        if coord in self.cells_remaining:
             self.cells_remaining.remove(coord)

        # Also remove from inferred sets if it was picked
        if coord in self.inferred_safe:
            self.inferred_safe.remove(coord)

        # 2. Get the ground truth value
        value = self.mine_board[r, c].item()

        MINE_CODE = 9

        if value == MINE_CODE: # Mine is hit
            self.is_game_over = True if not allow_mine else False
            self.mines_triggered += 1
            self.player_board[r, c] = -1 # A code for 'detonated mine'
            return False # Game over, move was unsuccessful

        elif value == 0: # Blank cell (Clue 0) - start chain reaction
            self.player_board[r, c] = 9 # Use 9 for revealed blank
            self.clue_number[coord] = 0

            # --- Chain Reaction Logic (using BFS for cascade) ---
            # ... (The full chain reaction logic from the previous answer goes here) ...

            queue = [(r, c)]
            visited = set([(r, c)])

            while queue:
                curr_r, curr_c = queue.pop(0)

                for nr, nc in self.get_neighbors(curr_r, curr_c):
                    neighbor_coord = (nr, nc)
                    if neighbor_coord in visited or self.opened_board[nr, nc] == 1:
                        continue

                    visited.add(neighbor_coord)
                    true_value = self.mine_board[nr, nc].item()

                    if true_value == 0:
                        # Continue the cascade (reveal blank, add to queue)
                        self.player_board[nr, nc] = 9
                        self.opened_board[nr, nc] = 1
                        self.clue_number[neighbor_coord] = 0
                        if neighbor_coord in self.cells_remaining:
                            self.cells_remaining.remove(neighbor_coord)
                        queue.append(neighbor_coord)

                    elif 1 <= true_value <= 8:
                        # Stop the cascade (reveal clue)
                        self.player_board[nr, nc] = true_value
                        self.opened_board[nr, nc] = 1
                        self.clue_number[neighbor_coord] = true_value
                        if neighbor_coord in self.cells_remaining:
                            self.cells_remaining.remove(neighbor_coord)

            return True

        elif 1 <= value <= 8: # Clue cell
            self.player_board[r, c] = value
            self.clue_number[coord] = value
            return True

        return True # Should not be reached, but ensures return value

    def _run_logic_inferences(self):
        """
        Applies the two main Minesweeper logic rules iteratively until no new inferences are found.
        Returns True if any new inference was made, False otherwise.
        """
        inferences_made = False

        # Loop until a full pass yields no new inferences [cite: 24]
        while True:
            new_inferences_in_pass = False

            # Iterate over all cells where a clue has been revealed
            for r, c in list(self.clue_number.keys()):
                clue = self.clue_number[(r, c)]
                neighbors = self.get_neighbors(r, c)

                # --- Categorize Neighbors ---
                unrevealed_neighbors = set()
                mines_inferred_count = 0
                safe_inferred_count = 0

                for nr, nc in neighbors:
                    coord = (nr, nc)
                    if coord in self.inferred_mine:
                        mines_inferred_count += 1
                    elif coord in self.inferred_safe or self.opened_board[nr, nc] == 1:
                        safe_inferred_count += 1
                    elif self.opened_board[nr, nc] == 0:
                        # This cell is truly unrevealed and un-inferred
                        unrevealed_neighbors.add(coord)

                num_unrevealed = len(unrevealed_neighbors)

                # 1. Mine Inference (Rule 1)
                # If clue - (known mines) == (unrevealed cells)
                if clue - mines_inferred_count == num_unrevealed and num_unrevealed > 0:
                    for nr, nc in unrevealed_neighbors:
                        coord = (nr, nc)

                        # Only infer if not already marked as mine
                        if coord not in self.inferred_mine:
                            self.inferred_mine.add(coord)
                            # Remove from cells_remaining as it is now determined
                            if coord in self.cells_remaining:
                                self.cells_remaining.remove(coord)

                            new_inferences_in_pass = True

                # 2. Safe Inference (Rule 2)
                # If (Total Neighbors - Clue) - (known safe cells) == (unrevealed cells)
                # A simpler check: if all mines are accounted for by inferred mines/flags
                num_non_mines_required = len(neighbors) - clue

                # Total cells known to be safe (opened + inferred safe)
                total_known_safe = safe_inferred_count + num_unrevealed

                # If the total known safe cells (including all unrevealed) equals
                # the total number of non-mines possible
                if num_non_mines_required == total_known_safe and num_unrevealed > 0:
                    for nr, nc in unrevealed_neighbors:
                        coord = (nr, nc)

                        # Only infer if not already marked as safe
                        if coord not in self.inferred_safe and self.opened_board[nr, nc] == 0:
                            self.inferred_safe.add(coord)
                            # Remove from cells_remaining as it is now determined
                            if coord in self.cells_remaining:
                                self.cells_remaining.remove(coord)

                            new_inferences_in_pass = True

            # If no new inferences were made in this full pass, the loop terminates [cite: 24]
            if not new_inferences_in_pass:
                break
            else:
                inferences_made = True

        return inferences_made


    def get_logic_bot_move(self):
        """
        The main move logic for the simple bot.
        1. Runs inferences to find safe cells.
        2. Picks an inferred safe cell if available, otherwise picks randomly.
        Returns: (r, c) tuple of the next move.
        """
        # Step 1: Run inferences until the board is stable [cite: 24]
        self._run_logic_inferences()

        # Step 2: Select a cell to open [cite: 18]
        if self.inferred_safe:
            # Pick one of the inferred safe cells
            r, c = self.inferred_safe.pop()
        elif self.cells_remaining:
            # No safe inference found, pick a cell from the remaining pool at random [cite: 18]
            r, c = random.choice(list(self.cells_remaining))
            # Note: The assignment implies the cells_remaining only holds cells *not* inferred as mine.
        else:
            # Game is likely won or stuck
            return None

        return r, c

    def make_move(self, r, c, allow_mine: bool = False):
        """
        Executes a move at (r, c) and updates game state and metrics.
        Returns: Tuple (success: bool, is_game_over: bool)
        """
        if self.is_game_over or self.opened_board[r, c] == 1:
            return True, self.is_game_over # Cannot make move or already open

        # If the move is flagged, we assume the bot unflags it first (optional)
        if self.flagged_board[r, c] == 1:
             self.flagged_board[r, c] = 0 # Unflag before opening

        self.moves_taken += 1 # Increment step counter

        move_successful = self.open_cell(r, c, allow_mine)

        # Check for win condition after a successful move
        if move_successful and self.check_win_condition():
            self.is_game_over = True

        return move_successful, self.is_game_over

    def check_win_condition(self):
        """Checks if all safe cells have been opened."""
        total_cells = self.size * self.size
        cells_opened = torch.sum(self.opened_board).item()

        # Win condition: Number of opened cells equals (Total Cells - Number of Mines)
        return cells_opened == (total_cells - self.number_of_mines)

    def start_bot_game(self, difficulty: str, allow_mine: bool = False):

        self.initialize_board(difficulty) # Initializes board, places mines, and performs initial click
        self.is_game_over = False
        self.moves_taken = 0
        self.mines_triggered = 0

        # Ensure the initial move is tracked as a move taken (if it wasn't a blank that cascaded)
        # Note: initialize_board already opened the first cell. We start tracking from the second move.

        # The main simulation loop
        while not self.is_game_over:

            # 1. Get the bot's next move choice
            r, c = self.get_logic_bot_move()

            if r is None:
                # Bot could not find any remaining cell to click (likely won)
                self.is_game_over = True
                break

            # 2. Execute the move
            success, game_over = self.make_move(r, c, allow_mine)

            # If the game is over (due to a mine hit or win), the loop terminates

        # Return the required metrics for comparison
        return {
            "success": self._check_win_condition(),
            "moves_survived": self.moves_taken,
            "mines_triggered": self.mines_triggered
        }

    def run_logic_inferences(self):
        """
        Applies the two main Minesweeper logic rules iteratively until no new inferences are found.
        Returns True if any new inference was made, False otherwise.
        """
        inferences_made = False

        # [cite_start]Loop until a full pass yields no new inferences [cite: 24]
        while True:
            new_inferences_in_pass = False

            # Iterate over all cells where a clue has been revealed
            for r, c in list(self.clue_number.keys()):
                clue = self.clue_number[(r, c)]
                neighbors = self.get_neighbors(r, c)

                # --- Categorize Neighbors ---
                unrevealed_neighbors = set()
                mines_inferred_count = 0
                safe_inferred_count = 0

                for nr, nc in neighbors:
                    coord = (nr, nc)
                    if coord in self.inferred_mine:
                        mines_inferred_count += 1
                    elif coord in self.inferred_safe or self.opened_board[nr, nc] == 1:
                        safe_inferred_count += 1
                    elif self.opened_board[nr, nc] == 0:
                        # This cell is truly unrevealed and un-inferred
                        unrevealed_neighbors.add(coord)

                num_unrevealed = len(unrevealed_neighbors)

                # 1. Mine Inference (Rule 1)
                # [cite_start]If (cell clue) - (# neighbors inferred to be mines) = (# unrevealed neighbors) [cite: 22]
                if clue - mines_inferred_count == num_unrevealed and num_unrevealed > 0:
                    for nr, nc in unrevealed_neighbors:
                        coord = (nr, nc)

                        # Only infer if not already marked as mine
                        if coord not in self.inferred_mine:
                            self.inferred_mine.add(coord)

                            # --- FIX: FLAG THE INFERRED MINE ---
                            self.flagged_board[nr, nc] = 1 # Mark the cell as flagged

                            # [cite_start]Remove them from cells_remaining [cite: 22]
                            if coord in self.cells_remaining:
                                self.cells_remaining.remove(coord)

                            new_inferences_in_pass = True

                # 2. Safe Inference (Rule 2)
                # [cite_start]If ((# neighbors) - (cell clue)) - (# neighbors revealed or inferred to be safe) = (# unrevealed neighbors) [cite: 23]
                num_non_mines_required = len(neighbors) - clue

                # Total cells known to be safe (opened + inferred safe)
                total_known_safe = safe_inferred_count + num_unrevealed

                # If the total known safe cells (including all unrevealed) equals
                # the total number of non-mines possible
                if num_non_mines_required == total_known_safe and num_unrevealed > 0:
                    for nr, nc in unrevealed_neighbors:
                        coord = (nr, nc)

                        # Only infer if not already marked as safe
                        if coord not in self.inferred_safe and self.opened_board[nr, nc] == 0:
                            self.inferred_safe.add(coord)
                            # [cite_start]Remove from cells_remaining [cite: 23]
                            if coord in self.cells_remaining:
                                self.cells_remaining.remove(coord)

                            new_inferences_in_pass = True

            # [cite_start]If no new inferences were made in this full pass, the loop terminates [cite: 24]
            if not new_inferences_in_pass:
                break
            else:
                inferences_made = True

        return inferences_made

    def get_nn_input_state(self):
        """
        Creates the 4-channel input tensor (4, 22, 22) for the Neural Network.
        """
        size = self.size

        # Channel 1: Clue Values (0-8)
        # Use player_board, but replace the 'blank' code (9) with 0 for clean clue representation
        clues_channel = self.player_board.clone().float()
        clues_channel[clues_channel == 9] = 0.0
        clues_channel[clues_channel == -1] = 0.0 # Ignore detonated mines if applicable

        # Channel 2: Opened Mask (1/0)
        opened_mask_channel = self.opened_board.clone().float()

        # Channel 3: Flag Mask (1/0)
        flag_mask_channel = self.flagged_board.clone().float()

        # Channel 4: Global Context (e.g., Mine Density)
        # Represented as the ratio of mines to total cells
        mine_density = self.number_of_mines / (size * size)
        global_context_channel = torch.full((size, size), mine_density, dtype=torch.float32)

        # Stack the channels
        nn_input = torch.stack([
            clues_channel,
            opened_mask_channel,
            flag_mask_channel,
            global_context_channel
        ], dim=0)

        return nn_input

    def get_safety_label(self):
        """
        Creates ground-truth safety map.
        - 1.0 for safe, unopened cells
        - 0.0 for mines (unopened)
        - -1.0 for opened cells (to be masked in loss)
        """
        MINE_CODE = 9

        # Start with mine detection: 1.0 for safe, 0.0 for mine
        safety_label = (self.mine_board != MINE_CODE).float()

        # Mask out opened cells with -1 so they're ignored in loss
        safety_label[self.opened_board == 1] = -1.0

        return safety_label

    def generate_training_data(self, difficulty: str, num_games: int, allow_mine: bool = False):

        # Lists to store the collected Input (X) and Output (Y) tensors
        X_data = [] # List of (4, 22, 22) input tensors
        Y_data = [] # List of (22, 22) output safety label tensors

        for game_idx in range(num_games):

            # Start a fresh game
            self.reset_game()
            self.initialize_board(difficulty)

            # Print after initialize_board sets number_of_mines
            if game_idx == 0:
                mode_str = "with mines allowed" if allow_mine else "standard mode"
                print(f"Generating data for {difficulty.upper()} difficulty ({self.number_of_mines} mines) - {mode_str}...")

            self.is_game_over = False
            self.moves_taken = 0
            self.mines_triggered = 0

            # Logic Bot Play Loop
            while not self.is_game_over:

                # --- DATA COLLECTION POINT (BEFORE THE MOVE) ---

                # 1. Capture the current state (X)
                current_input_state = self.get_nn_input_state()

                # 2. Capture the ground truth label (Y)
                current_safety_label = self.get_safety_label()

                # Store the data pair
                X_data.append(current_input_state)
                Y_data.append(current_safety_label)

                # --- BOT DECISION AND EXECUTION ---

                # 3. Get the bot's next move choice (r, c)
                move = self.get_logic_bot_move()

                # CHECK FOR None BEFORE UNPACKING
                if move is None:
                    # Game ended (won or no moves left)
                    self.is_game_over = True
                    break

                r, c = move  # Now safe to unpack

                # 4. Execute the move (updates game state)
                success, game_over = self.make_move(r, c, allow_mine)

            if (game_idx + 1) % 100 == 0:
                print(f"Completed {game_idx + 1} games. Total data points: {len(X_data)}")

        # Convert lists to final PyTorch tensors
        X_tensor = torch.stack(X_data)
        Y_tensor = torch.stack(Y_data)

        return X_tensor, Y_tensor

    def generate_critic_training_data(self, difficulty: str, num_games: int, actor: nn.Module):
        """
        Generates training data for Critic using 5-channel input.

        Input: (5, 22, 22) where channel 4 is the selected move
        Output: number of moves survived
        """
        X_data = []  # List of (5, 22, 22) tensors
        Y_data = []  # List of survival counts

        for game_idx in range(num_games):
            self.reset_game()
            self.initialize_board(difficulty)

            if game_idx == 0:
                print(f"Generating Critic data for {difficulty.upper()} (5-channel approach)...")

            # Store game trajectory
            game_inputs = []

            while not self.is_game_over:
                # Get current 4-channel state
                current_state = self.get_nn_input_state()  # (4, 22, 22)

                # Get move from actor
                if actor is None:
                    move = self.get_logic_bot_move()
                else:
                    pass
                    move = self.get_actor_move(actor)  # Your trained actor

                if move is None:
                    break

                r, c = move

                # Create move channel (one-hot encoding of the move)
                move_channel = torch.zeros((self.size, self.size), dtype=torch.float32)
                move_channel[r, c] = 1.0

                # Combine into 5-channel input
                full_input = torch.cat([
                    current_state,  # (4, 22, 22)
                    move_channel.unsqueeze(0)  # (1, 22, 22)
                ], dim=0)  # Result: (5, 22, 22)

                game_inputs.append(full_input)

                # Execute move (allow continuing after mines)
                success, game_over = self.make_move(r, c, allow_mine=True)

            # Retrospectively label with survival counts
            total_moves = len(game_inputs)
            for i in range(total_moves):
                moves_survived = total_moves - i

                X_data.append(game_inputs[i])
                Y_data.append(moves_survived)

            if (game_idx + 1) % 100 == 0:
                print(f"Completed {game_idx + 1} games. Total data points: {len(X_data)}")

        return torch.stack(X_data), torch.tensor(Y_data, dtype=torch.float32)

    def generate_actor_training_data(self, difficulty: str, num_games: int, critic_model):
        """
        Generate training data for the Actor network.

        Process:
        1. For each game state, evaluate ALL possible moves using the Critic
        2. Create targets based on Critic's predictions
        3. Actor learns to predict these values directly from board state

        Returns:
            X_data: (N, 4, 22, 22) - board states
            Y_data: (N, 22, 22) - value map for each cell
        """
        X_data = []  # Board states
        Y_data = []  # Value maps (predicted survival for each cell)

        for game_idx in range(num_games):
            self.reset_game()
            self.initialize_board(difficulty)

            if game_idx == 0:
                print(f"Generating Actor training data for {difficulty.upper()}...")

            while not self.is_game_over:
                current_state = self.get_nn_input_state()  # (4, 22, 22)

                # Create value map by querying Critic for all possible moves
                value_map = self.create_value_map_from_critic(current_state, critic_model)

                # Store this training pair
                X_data.append(current_state)
                Y_data.append(value_map)

                # Make a move using current Actor (or logic bot initially)
                move = self.get_best_move_from_value_map(value_map)

                if move is None:
                    break

                r, c = move
                success, game_over = self.make_move(r, c, allow_mine=False)

                if not success:
                    break

            if (game_idx + 1) % 100 == 0:
                print(f"Actor data: Completed {game_idx + 1} games. Data points: {len(X_data)}")

        return (
            torch.stack(X_data),   # (N, 4, 22, 22)
            torch.stack(Y_data)    # (N, 22, 22)
        )


    def create_value_map_from_critic(self, current_state, critic_model):
        """
        Query the Critic for all possible moves to create a value map.

        All candidate moves are scored in one batched forward pass rather
        than one forward pass per cell.

        Args:
            current_state: 4-channel board state from ``get_nn_input_state``.
            critic_model: Model taking (N, 5, size, size) inputs. Assumed to
                return exactly one value per input sample -- TODO confirm
                against the Critic definition.

        Returns:
            value_map: (size, size) float tensor where each unopened cell
                       contains the predicted survival, and already opened
                       cells contain -1.
        """
        # NOTE(review): -1 doubles as the "opened" marker, and consumers treat
        # values < 0 as invalid moves; a negative Critic prediction for an
        # unopened cell would be indistinguishable from that marker.
        value_map = torch.full((self.size, self.size), -1.0, dtype=torch.float32)

        critic_model.eval()

        # Row/column indices of every unopened cell, in row-major order.
        rows, cols = torch.nonzero(self.opened_board == 0, as_tuple=True)
        num_moves = rows.numel()
        if num_moves == 0:
            return value_map

        # One one-hot move channel per candidate move.
        move_channels = torch.zeros((num_moves, 1, self.size, self.size), dtype=torch.float32)
        move_channels[torch.arange(num_moves), 0, rows, cols] = 1.0

        # Pair the same board state with every candidate move: (N, 5, size, size).
        full_input = torch.cat([
            current_state.unsqueeze(0).expand(num_moves, -1, -1, -1),
            move_channels
        ], dim=1).to(device)

        with torch.no_grad():
            # reshape fails loudly if the Critic does not return one value per move.
            predicted_survival = critic_model(full_input).reshape(num_moves)

        value_map[rows, cols] = predicted_survival.to(device=value_map.device, dtype=torch.float32)

        return value_map


    def get_best_move_from_value_map(self, value_map):
        """
        Select the best move from value map.
        """
        # Get valid moves (value >= 0)
        valid_moves = []
        values = []

        for r in range(self.size):
            for c in range(self.size):
                if value_map[r, c] >= 0:  # Valid move
                    valid_moves.append((r, c))
                    values.append(value_map[r, c])

        if not valid_moves:
            return None

        # Select move with highest value
        values_tensor = torch.tensor(values)
        best_idx = torch.argmax(values_tensor).item()

        return valid_moves[best_idx]


    def get_actor_move(self, actor):
        """
        Ask the Actor network for the next move.

        The Actor is switched to eval mode and run once on the current board
        state; its output is reshaped to (size, size) and the unopened cell
        with the highest score is chosen.

        Args:
            actor: Actor network

        Returns:
            (r, c) tuple of selected move, or None if no moves available
        """
        side = self.size

        actor.eval()
        with torch.no_grad():
            board_state = self.get_nn_input_state().to(device)

            # Add a batch dimension for the forward pass, then drop it again.
            scores = actor(board_state.unsqueeze(0)).squeeze(0)
            scores = scores.reshape((side, side))

            # Only unopened cells are candidates, scanned row by row.
            candidates = [
                (r, c)
                for r in range(side)
                for c in range(side)
                if self.opened_board[r, c] == 0
            ]
            if not candidates:
                return None

            candidate_scores = torch.tensor([scores[r, c].item() for r, c in candidates])
            return candidates[torch.argmax(candidate_scores).item()]


    def train_critc(self, critic, train_loader, test_loader, hp):
        """Train the Critic with MSE loss and AdamW, evaluating after every epoch.

        Args:
            critic: Model to train; moved to the module-level `device`.
            train_loader: Iterable of (input, target) batches for training.
            test_loader: Iterable of (input, target) batches for evaluation.
            hp: Hyperparameter dict. Keys read: 'epochs', 'learning_rate',
                'decay_rate', 'cnn_dropout', 'linear_dropout'.

        Returns:
            (training_history, testing_history): lists with one dict per epoch,
            {"epoch", "training_loss"} and {"epoch", "testing_loss"}. Each
            loss is the mean of the per-batch losses for that epoch.
        """
        training_history = []
        testing_history = []

        # Hyperparameter setup
        epochs = hp['epochs']
        learning_rate = hp['learning_rate']
        decay_rate = hp['decay_rate']

        # Looked up but not used below; a missing key still raises KeyError.
        c_dropout = hp['cnn_dropout']
        f_dropout = hp['linear_dropout']

        print('######## Beginning training for MS Critic ##########')

        model = critic
        model.to(device)

        loss_function = nn.MSELoss()
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=decay_rate)

        # Halve the learning rate when the test loss stops improving.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
        )

        for epoch in range(epochs):
            print(f'----- Epoch: {epoch + 1}/{epochs} -----')

            # ---- Training pass ----
            training_loss_sum = 0
            model.train()

            for inputs, targets in tqdm(train_loader, desc='Training', unit=' batch'):
                inputs = inputs.to(device)
                # Add a trailing dimension so targets line up with the model output.
                targets = targets.to(device).unsqueeze(1)

                optimizer.zero_grad()
                loss = loss_function(model(inputs), targets)
                loss.backward()
                optimizer.step()

                training_loss_sum += loss.item()

            # ---- Evaluation pass (no gradients) ----
            testing_loss_sum = 0
            model.eval()

            with torch.no_grad():
                for inputs, targets in tqdm(test_loader, desc='Testing', unit=' batches'):
                    inputs = inputs.to(device)
                    targets = targets.to(device).unsqueeze(1)
                    testing_loss_sum += loss_function(model(inputs), targets).item()

            # Mean per-batch losses for this epoch
            avg_training_loss = training_loss_sum / len(train_loader)
            avg_testing_loss = testing_loss_sum / len(test_loader)

            # The scheduler monitors the test loss
            scheduler.step(avg_testing_loss)

            # Switch model back to training mode
            model.train()

            training_history.append({"epoch": epoch, "training_loss": avg_training_loss})
            testing_history.append({"epoch": epoch, "testing_loss": avg_testing_loss})

            print("")

            print(f'   -> Training Loss: {avg_training_loss: .4f}\n')
            print(f'   -> Testing Loss: {avg_testing_loss: .4f}\n')

        return training_history, testing_history


    def train_actor(self, actor, train_loader, test_loader, hp):
        """Train ``actor`` with the masked BCE objective and record losses.

        The model is moved to the notebook-level ``device`` and optimised in
        place with AdamW. ``ReduceLROnPlateau`` halves the learning rate
        (patience 5, floor 1e-6) when the mean test loss stops improving.

        Args:
            actor: ``nn.Module`` producing logits that ``masked_bce_loss`` can
                compare against the targets.
            train_loader: Sized iterable of ``(x, Y)`` batches used for the
                optimisation steps.
            test_loader: Sized iterable of ``(x, Y)`` batches used only for
                evaluation.
            hp: Mapping with ``'epochs'``, ``'learning_rate'`` and
                ``'decay_rate'`` (AdamW weight decay). An optional
                ``'pos_weight'`` overrides the positive-class weight passed to
                the loss (default 30.0). Any other keys are ignored, so dropout
                settings are not required here.

        Returns:
            ``(epoch_over_training, epoch_over_testing)``: two lists holding one
            dict per epoch, ``{"epoch": i, "training_loss": ...}`` and
            ``{"epoch": i, "testing_loss": ...}``, with ``i`` zero-based. Each
            loss is the mean of the per-batch losses.

        Raises:
            KeyError: If one of the three required ``hp`` keys is missing.
            ZeroDivisionError: If either loader yields no batches.
        """
        epoch_over_training = []
        epoch_over_testing = []

        # Hyperparameter setup. Dropout values are deliberately not read: the
        # actor is built by the caller, so they could not be applied here.
        epochs = hp['epochs']
        learning_rate = hp['learning_rate']
        decay_rate = hp['decay_rate']
        pos_weight = hp.get('pos_weight', 30.0)

        print('######## Beginning training for MS Actor ##########')

        model = actor
        model.to(device)

        # float() keeps the weight a floating tensor even if an int is supplied
        pos_weight_tensor = torch.tensor(float(pos_weight), device=device)

        loss_function = masked_bce_loss

        optimizer = optim.AdamW(model.parameters(),
                                lr=learning_rate,
                                weight_decay=decay_rate
                                )

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=5,
            min_lr=1e-6
        )

        # Epoch Loop
        for epoch in range(epochs):
            print(f'----- Epoch: {epoch + 1}/{epochs} -----')

            avg_training_loss = 0
            avg_testing_loss = 0

            model.train()

            for x, Y in tqdm(train_loader, desc='Training', unit=' batch'):
                x = x.to(device)
                Y = Y.to(device)

                optimizer.zero_grad()

                x_pred = model(x)
                loss = loss_function(x_pred, Y, pos_weight_tensor)

                loss.backward()
                optimizer.step()

                # .item() returns a plain Python float, detached from autograd
                avg_training_loss += loss.item()

            # Evaluate without dropout / batch-stat updates and without grads
            model.eval()

            with torch.no_grad():
                for x, Y in tqdm(test_loader, desc='Testing', unit=' batches'):
                    x = x.to(device)
                    Y = Y.to(device)

                    x_pred = model(x)
                    avg_testing_loss += loss_function(x_pred, Y, pos_weight_tensor).item()

            avg_training_loss /= len(train_loader)
            avg_testing_loss /= len(test_loader)

            # The scheduler monitors the held-out loss
            scheduler.step(avg_testing_loss)

            # Leave the model in training mode for the next epoch / the caller
            model.train()

            epoch_over_training.append({
                "epoch": epoch,
                "training_loss": avg_training_loss
                })

            epoch_over_testing.append({
                "epoch": epoch,
                "testing_loss": avg_testing_loss
                })

            print("")

            print(f'   -> Training Loss: {avg_training_loss: .4f}\n')
            print(f'   -> Testing Loss: {avg_testing_loss: .4f}\n')

        return epoch_over_training, epoch_over_testing


    def critic_and_actor_training_loop(self, hp_c, hp_a, num_of_games = 5, epochs=10, difficulty: str = 'medium'):
        """Alternate critic and actor training for ``epochs`` rounds.

        Every round builds a fresh ``Critic`` and ``Actor``, then:

        1. generates critic data with the previous round's actor (``None`` in
           the first round; presumably the generator then falls back to the
           logic bot - confirm in ``generate_critic_training_data``),
        2. trains the new critic on it,
        3. generates actor data using the freshly trained critic,
        4. trains the new actor on it.

        Args:
            hp_c: Hyperparameter mapping forwarded to ``train_critc``.
            hp_a: Hyperparameter mapping forwarded to ``train_actor``.
            num_of_games: Number of games requested from each data generator.
            epochs: Number of critic/actor rounds to run.
            difficulty: Difficulty label forwarded to both data generators.

        Returns:
            ``(actor, critic)`` from the last completed round, or
            ``(None, None)`` when ``epochs`` is 0.
        """
        latest_actor: nn.Module | None = None
        latest_critic: nn.Module | None = None

        for _ in range(epochs):
            # Fresh, untrained models for this round
            critic = Critic()
            actor = Actor()

            # Critic data is produced with the actor of the previous round
            x, y = self.generate_critic_training_data(num_games=num_of_games, difficulty=difficulty, actor=latest_actor)
            critic_train_loader, critic_test_loader = get_dataloader(x, y, 64)

            self.train_critc(critic=critic, train_loader=critic_train_loader, test_loader=critic_test_loader, hp=hp_c)

            # Actor data must use the requested difficulty, not a fixed one
            x, y = self.generate_actor_training_data(difficulty, num_of_games, critic)
            actor_train_loader, actor_test_loader = get_dataloader(x, y, 64)

            self.train_actor(actor=actor, train_loader=actor_train_loader, test_loader=actor_test_loader, hp=hp_a)

            # Remember both models so the caller receives the trained pair
            latest_actor = actor
            latest_critic = critic

        return latest_actor, latest_critic

In [103]:
def start_nn_game(manager, model, difficulty: str, device: torch.device):
    """
    Play one full game in which every move is chosen by ``model``.

    Args:
        manager: Game object providing ``initialize_board``,
            ``get_nn_input_state``, ``make_move``, ``check_win_condition`` and
            the attributes ``size``, ``opened_board``, ``is_game_over``,
            ``moves_taken`` and ``mines_triggered``.
        model: Network returning one logit per board cell for a batch of one
            state (the output is reshaped to ``size x size``).
        difficulty: Difficulty label forwarded to ``initialize_board``.
        device: Device the state and the opened-cell board are moved to.

    Returns:
        Dict with the keys ``"success"``, ``"moves_taken"`` and
        ``"mines_triggered"`` read from ``manager`` after the game.
    """
    manager.initialize_board(difficulty)
    manager.is_game_over = False

    # Inference only: evaluation mode and no gradient tracking
    model.eval()

    with torch.no_grad():
        while not manager.is_game_over:
            # Current board encoding with a leading batch axis of size 1
            board_batch = manager.get_nn_input_state().to(device).unsqueeze(0)

            # Logits -> per-cell probabilities laid out like the board
            safety_map = torch.sigmoid(model(board_batch)).view(manager.size, manager.size)

            opened = manager.opened_board.to(device)
            row, col = get_nn_move_choice(safety_map, opened, threshold=0.5)

            if row is None:
                # Nothing left to open: end the game here
                manager.is_game_over = True
                break

            # make_move's two-element result is not used; the loop condition
            # reads manager.is_game_over instead.
            _success, _game_over = manager.make_move(row, col)

    # Collected in this order: win check first, then the counters
    outcome = {}
    outcome["success"] = manager.check_win_condition()
    outcome["moves_taken"] = manager.moves_taken
    outcome["mines_triggered"] = manager.mines_triggered
    return outcome

def get_nn_move_choice(safety_map, opened_board, threshold=0.5):
    """
    Pick a random unopened cell whose safety score reaches ``threshold``.

    When no unopened cell reaches the threshold, the unopened cell with the
    highest score is chosen instead (ties broken at random).

    Args:
        safety_map: 2-D tensor of per-cell safety scores.
        opened_board: Tensor of the same shape; 0 marks an unopened cell and
            1 an opened one.
        threshold: Minimum score for a cell to count as "safe".

    Returns:
        ``(row, col)`` as Python ints, or ``(None, None)`` when the fallback
        finds every cell opened.
    """
    # Candidates: unopened AND predicted safe
    candidates = torch.nonzero((opened_board == 0) & (safety_map >= threshold), as_tuple=True)
    n_candidates = len(candidates[0])

    if n_candidates != 0:
        # Uniform random pick among the safe candidates
        pick = torch.randint(0, n_candidates, (1,)).item()
        return candidates[0][pick].item(), candidates[1][pick].item()

    # Fallback: nothing looks safe, so take the least risky unopened cell.
    # Opened cells get a sentinel score that can never be the maximum.
    scores = safety_map.masked_fill(opened_board == 1, -1e9)
    best = scores.max()

    if best < -1e8:
        # Only sentinels left: every cell is already opened
        return None, None

    rows, cols = torch.nonzero(scores == best, as_tuple=True)
    pick = 0 if len(rows) <= 1 else torch.randint(0, len(rows), (1,)).item()
    return rows[pick].item(), cols[pick].item()

In [109]:
def start_nn_game(manager, model, difficulty: str, device: torch.device):
    """
    Play one full game in which every move is chosen by ``model``.

    Args:
        manager: Game object providing ``initialize_board``,
            ``get_nn_input_state``, ``make_move``, ``check_win_condition`` and
            the attributes ``size``, ``opened_board``, ``is_game_over``,
            ``moves_taken`` and ``mines_triggered``.
        model: Network returning one logit per board cell for a batch of one
            state (the output is reshaped to ``size x size``).
        difficulty: Difficulty label forwarded to ``initialize_board``.
        device: Device the state and the opened-cell board are moved to.

    Returns:
        Dict with the keys ``"success"``, ``"moves_taken"`` and
        ``"mines_triggered"`` read from ``manager`` after the game.
    """
    manager.initialize_board(difficulty)
    manager.is_game_over = False

    # Inference only: evaluation mode and no gradient tracking
    model.eval()

    with torch.no_grad():
        while not manager.is_game_over:
            # Current board encoding with a leading batch axis of size 1
            board_batch = manager.get_nn_input_state().to(device).unsqueeze(0)

            # Logits -> per-cell probabilities laid out like the board
            safety_map = torch.sigmoid(model(board_batch)).view(manager.size, manager.size)

            opened = manager.opened_board.to(device)
            row, col = get_nn_move_choice(safety_map, opened)

            if row is None:
                # Nothing left to open: end the game here
                manager.is_game_over = True
                break

            # make_move's two-element result is not used; the loop condition
            # reads manager.is_game_over instead.
            _success, _game_over = manager.make_move(row, col)

    # Collected in this order: win check first, then the counters
    outcome = {}
    outcome["success"] = manager.check_win_condition()
    outcome["moves_taken"] = manager.moves_taken
    outcome["mines_triggered"] = manager.mines_triggered
    return outcome


def get_nn_move_choice(safety_map, opened_board):
    """
    Return the unopened cell with the highest safety score.

    Args:
        safety_map: 2-D tensor of per-cell safety scores.
        opened_board: Tensor of the same shape; cells equal to 1 are treated
            as opened and excluded.

    Returns:
        ``(row, col)`` as Python ints (a random one when several cells share
        the top score), or ``(None, None)`` when every cell is opened.
    """
    # Opened cells receive a sentinel score that can never win the max
    scores = safety_map.masked_fill(opened_board == 1, -1e9)
    top = scores.max()

    if top < -1e8:
        # Only sentinels remain, i.e. there is no unopened cell
        return None, None

    rows, cols = torch.nonzero(scores == top, as_tuple=True)

    # Break ties at random; a unique maximum needs no random draw
    pick = 0 if len(rows) <= 1 else torch.randint(0, len(rows), (1,)).item()
    return rows[pick].item(), cols[pick].item()

Generating Critic data for MEDIUM (5-channel approach)...


In [256]:
from torch.utils.data import Dataset

class MSDataset(Dataset):
    """Tensor-backed dataset of ``(board state, label)`` pairs.

    Inputs are stored as float32. Labels with more than one dimension are
    flattened to one row per sample; 1-D labels (one scalar per sample) are
    kept as they are. No augmentation is applied.
    """

    def __init__(self, X: torch.Tensor, y: torch.Tensor):
        self.X = X.float()

        labels = y.float()
        # Multi-dimensional labels become (N, features); scalar-per-sample
        # labels of shape (N,) must keep their shape.
        self.y = labels.view(y.size(0), -1) if y.ndim > 1 else labels

        # Length of the last input axis
        self.size = X.size(-1)

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        # Hand out contiguous copies so callers never alias the stored tensors
        sample = self.X[idx].clone().contiguous()
        label = self.y[idx].clone().contiguous()
        return sample, label

In [225]:
class MSSafeSquares(nn.Module):
    """CNN feature extractor + MLP head producing one logit per board cell.

    Args:
        lin_dropout: Dropout probability used after every hidden linear layer.
        cnn_dropout: ``Dropout2d`` probability used after every conv stage.
        in_channels: Number of channels of the board encoding (default 4).
        board_size: Side length of the square board (default 22). The output
            has ``board_size * board_size`` logits. The board must be large
            enough to survive the unpadded convolutions and the two poolings.

    The forward pass maps ``(batch, in_channels, board_size, board_size)`` to
    raw logits of shape ``(batch, board_size * board_size)``; no sigmoid is
    applied.
    """

    def __init__(self, lin_dropout=0, cnn_dropout=0, in_channels=4, board_size=22):

        super().__init__()

        self.BOARD_SIZE = board_size

        cnn_layer_1_size = 32
        cnn_layer_2_size = 64
        cnn_layer_3_size = 128
        cnn_layer_4_size = 256

        lin_layer_1_size = 2000
        lin_layer_2_size = 1500
        lin_layer_3_size = 1000
        lin_layer_4_size = 500

        # NOTE: layer order is kept stable so existing state_dict keys
        # (CNN.0.weight, classifier.1.weight, ...) still load.
        self.CNN = nn.Sequential(
            # 5x5 kernel with padding 1 shrinks each side by 2
            nn.Conv2d(in_channels=in_channels, out_channels=cnn_layer_1_size, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(cnn_layer_1_size),
            nn.ReLU(),
            nn.Dropout2d(p=cnn_dropout),


            nn.Conv2d(in_channels=cnn_layer_1_size, out_channels=cnn_layer_2_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(cnn_layer_2_size),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=cnn_dropout),

            nn.Conv2d(in_channels=cnn_layer_2_size, out_channels=cnn_layer_3_size, kernel_size=3, stride=1),
            nn.BatchNorm2d(cnn_layer_3_size),
            nn.ReLU(),
            # ceil_mode keeps the last partial window instead of dropping it
            nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
            nn.Dropout2d(p=cnn_dropout),

            nn.Conv2d(in_channels=cnn_layer_3_size, out_channels=cnn_layer_4_size, kernel_size=3, stride=1),
            nn.BatchNorm2d(cnn_layer_4_size),
            nn.ReLU(),
            nn.Dropout2d(p=cnn_dropout),
        )

        # Infer the flattened feature count from a dummy forward pass. The
        # probe runs in eval mode so it neither updates the BatchNorm running
        # statistics nor draws dropout random numbers, and it stays on the
        # CPU, where the freshly built layers live.
        self.CNN.eval()
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, board_size, board_size)
            total_count = self.CNN(probe).flatten(1).size(1)
        self.CNN.train()

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=total_count, out_features=lin_layer_1_size),
            nn.BatchNorm1d(lin_layer_1_size),
            nn.ReLU(),
            nn.Dropout(lin_dropout),

            nn.Linear(in_features=lin_layer_1_size, out_features=lin_layer_2_size),
            nn.BatchNorm1d(lin_layer_2_size),
            nn.ReLU(),
            nn.Dropout(lin_dropout),

            nn.Linear(in_features=lin_layer_2_size, out_features=lin_layer_3_size),
            nn.BatchNorm1d(lin_layer_3_size),
            nn.ReLU(),
            nn.Dropout(lin_dropout),

            nn.Linear(in_features=lin_layer_3_size, out_features=lin_layer_4_size),
            nn.BatchNorm1d(lin_layer_4_size),
            nn.ReLU(),
            nn.Dropout(lin_dropout),
        )

        self.output_layer = nn.Linear(in_features=lin_layer_4_size, out_features=self.BOARD_SIZE * self.BOARD_SIZE)

    def forward(self, x):
        """Return per-cell logits of shape ``(batch, BOARD_SIZE ** 2)``."""
        features = self.CNN(x)
        hidden = self.classifier(features)

        return self.output_layer(hidden)

def masked_bce_loss(predictions, targets: torch.Tensor, pos_weight):
    """
    BCE-with-logits loss that ignores masked cells (where target == -1).

    Args:
        predictions: Raw logits, same shape as ``targets``.
        targets: Float tensor with 1 / 0 labels and -1 for cells to ignore.
        pos_weight: Tensor weight applied to the positive (target == 1) term.

    Returns:
        Scalar tensor: mean loss over the unmasked cells. If every cell is
        masked the result is 0 (with zero gradient) rather than NaN.
    """
    # 1.0 where the cell contributes to the loss, 0.0 where it is masked
    mask = (targets != -1.0).float()

    # Clamp turns the -1 markers into a valid BCE target; those cells are
    # zeroed out by the mask below, so the substituted value never matters.
    loss = nn.functional.binary_cross_entropy_with_logits(
        predictions,
        targets.clamp(0, 1),
        pos_weight=pos_weight,
        reduction='none',
    )

    # Average over valid cells only. The denominator is clamped to 1 so a
    # fully masked batch yields 0 / 1 instead of 0 / 0.
    valid_cells = mask.sum().clamp(min=1)
    return (loss * mask).sum() / valid_cells

In [None]:
# Load the saved training inputs (X) for each difficulty from ./data.
# NOTE(review): torch.load unpickles the file contents - only load files from
# a trusted source.
MS_X_easy = torch.load('./data/X_train_easy.pt')
MS_X_med = torch.load('./data/X_train_medium.pt')
MS_X_hard = torch.load('./data/X_train_hard.pt')

# Matching labels (Y) for the inputs above; presumably sample i of each Y file
# belongs to sample i of the X file of the same difficulty - TODO confirm.
MS_Y_easy = torch.load('./data/Y_train_easy.pt')
MS_Y_med = torch.load('./data/Y_train_medium.pt')
MS_Y_hard = torch.load('./data/Y_train_hard.pt')


In [235]:
from torch.utils.data import DataLoader

def get_dataloader(X_dataset, Y_dataset, batch_size):
    """Split the samples 80/20 by position and wrap both parts in loaders.

    The first 80% (rounded down) of the samples form the training set and the
    remainder the test set; the data is not shuffled before splitting.

    Args:
        X_dataset: Input tensor, samples along the first axis.
        Y_dataset: Label tensor with the same number of samples.
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader)``: ``DataLoader`` objects over
        ``MSDataset`` instances. Only the training loader shuffles; both use
        5 worker processes and pinned memory.
    """
    n_samples = len(X_dataset)
    n_train = int(0.8 * n_samples)
    split_sizes = [n_train, n_samples - n_train]

    # Positional split: leading chunk for training, trailing chunk for testing
    X_train, X_test = torch.split(X_dataset, split_sizes)
    Y_train, Y_test = torch.split(Y_dataset, split_sizes)

    train_dataset = MSDataset(X_train, Y_train)
    test_dataset = MSDataset(X_test, Y_test)

    # Options shared by both loaders
    shared_options = dict(batch_size=batch_size, num_workers=5, pin_memory=True)

    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **shared_options)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **shared_options)

    return train_loader, test_loader

# Per-difficulty loaders (batch size 64). The easy loaders are built from the
# tensors loaded from ./data, like medium and hard, instead of from leftover
# kernel variables, so this cell works after Restart & Run All.
MS_train_loader_easy, MS_test_loader_easy = get_dataloader(MS_X_easy, MS_Y_easy, 64)
MS_train_loader_medium, MS_test_loader_medium = get_dataloader(MS_X_med, MS_Y_med, 64)
MS_train_loader_hard, MS_test_loader_hard = get_dataloader(MS_X_hard, MS_Y_hard, 64)


# Combined dataset: all three difficulties stacked along the sample axis.
# get_dataloader splits by position, so the test part is the tail of this
# concatenation.
MS_X_all = torch.concat([MS_X_easy, MS_X_med, MS_X_hard], dim=0)
MS_Y_all = torch.concat([MS_Y_easy, MS_Y_med, MS_Y_hard], dim=0)

MS_train_loader_all, MS_test_loader_all = get_dataloader(MS_X_all, MS_Y_all, 64)

In [226]:
def training_loop(train_loader, test_loader, hp: dict):
    """Build an ``MSSafeSquares`` model, train it and report epoch losses.

    Training uses the masked BCE loss with a positive-class weight of 30,
    AdamW, and a ``ReduceLROnPlateau`` scheduler (factor 0.5, patience 5,
    floor 1e-6) driven by the mean test loss.

    Args:
        train_loader: Sized iterable of ``(x, Y)`` batches for optimisation.
        test_loader: Sized iterable of ``(x, Y)`` batches for evaluation.
        hp: Dict with ``'epochs'``, ``'learning_rate'``, ``'decay_rate'``
            (AdamW weight decay), ``'cnn_dropout'`` and ``'linear_dropout'``.

    Returns:
        ``(model, epoch_over_training, epoch_over_testing)``: the trained
        model plus two lists with one dict per epoch
        (``{"epoch": i, "training_loss": ...}`` and
        ``{"epoch": i, "testing_loss": ...}``, ``i`` zero-based).
    """
    training_history = []
    testing_history = []

    # Hyperparameters
    n_epochs = hp['epochs']
    lr = hp['learning_rate']
    weight_decay = hp['decay_rate']
    conv_dropout = hp['cnn_dropout']
    dense_dropout = hp['linear_dropout']

    print('######## Beginning training for MS Safe Square Predictor ##########')

    model = MSSafeSquares(lin_dropout=dense_dropout, cnn_dropout=conv_dropout)
    model.to(device)

    # Extra weight on positive targets, forwarded to the loss on every call
    pos_weight_tensor = torch.tensor(30.0, device=device)
    criterion = masked_bce_loss

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
    )

    for epoch in range(n_epochs):
        print(f'----- Epoch: {epoch + 1}/{n_epochs} -----')

        # Running sums of per-batch losses, turned into means below
        train_loss = 0
        test_loss = 0

        # --- optimisation pass ---
        model.train()
        for inputs, labels in tqdm(train_loader, desc='Training', unit=' batch'):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), labels, pos_weight_tensor)
            batch_loss.backward()
            optimizer.step()

            train_loss += batch_loss.item()

        # --- evaluation pass (no gradients, eval-mode layers) ---
        model.eval()
        with torch.no_grad():
            for inputs, labels in tqdm(test_loader, desc='Testing', unit=' batches'):
                inputs = inputs.to(device)
                labels = labels.to(device)
                test_loss += criterion(model(inputs), labels, pos_weight_tensor).item()

        train_loss /= len(train_loader)
        test_loss /= len(test_loader)

        # The scheduler watches the held-out loss
        scheduler.step(test_loss)

        # Back to training mode for the next epoch / the caller
        model.train()

        training_history.append({"epoch": epoch, "training_loss": train_loss})
        testing_history.append({"epoch": epoch, "testing_loss": test_loss})

        print("")

        print(f'   -> Training Loss: {train_loss: .4f}\n')
        print(f'   -> Testing Loss: {test_loss: .4f}\n')

    return model, training_history, testing_history

In [236]:
# Settings for the easy-board run; training_loop reads exactly these five keys.
hyperparameters = dict(
    epochs=100,
    learning_rate=1e-2,
    decay_rate=1e-3,       # AdamW weight decay
    cnn_dropout=0.0,
    linear_dropout=0.25,
)

# Train on the easy loaders; `train` / `test` hold the per-epoch loss records.
model_easy, train, test = training_loop(
    MS_train_loader_easy,
    MS_test_loader_easy,
    hyperparameters,
)

######## Beginning training for MS Safe Square Predictor ##########
----- Epoch: 1/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.18 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 352.89 batches/s]



   -> Training Loss:  1.4672

   -> Testing Loss:  1.3955

----- Epoch: 2/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 227.02 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 355.77 batches/s]



   -> Training Loss:  1.3096

   -> Testing Loss:  1.2867

----- Epoch: 3/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 230.50 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 349.30 batches/s]



   -> Training Loss:  1.2461

   -> Testing Loss:  1.2281

----- Epoch: 4/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 230.67 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 350.02 batches/s]



   -> Training Loss:  1.2073

   -> Testing Loss:  1.2117

----- Epoch: 5/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 226.77 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 365.88 batches/s]



   -> Training Loss:  1.1920

   -> Testing Loss:  1.1985

----- Epoch: 6/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 220.37 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 343.16 batches/s]



   -> Training Loss:  1.1765

   -> Testing Loss:  1.1815

----- Epoch: 7/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 229.31 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 355.69 batches/s]



   -> Training Loss:  1.1532

   -> Testing Loss:  1.1500

----- Epoch: 8/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 226.60 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 358.17 batches/s]



   -> Training Loss:  1.1279

   -> Testing Loss:  1.1318

----- Epoch: 9/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.66 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 352.23 batches/s]



   -> Training Loss:  1.1120

   -> Testing Loss:  1.1223

----- Epoch: 10/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 226.23 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 353.23 batches/s]



   -> Training Loss:  1.1013

   -> Testing Loss:  1.1115

----- Epoch: 11/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.58 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 344.59 batches/s]



   -> Training Loss:  1.0890

   -> Testing Loss:  1.1006

----- Epoch: 12/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.77 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 353.26 batches/s]



   -> Training Loss:  1.0791

   -> Testing Loss:  1.0977

----- Epoch: 13/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 226.41 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 351.35 batches/s]



   -> Training Loss:  1.0730

   -> Testing Loss:  1.0956

----- Epoch: 14/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 220.26 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 357.66 batches/s]



   -> Training Loss:  1.0696

   -> Testing Loss:  1.0849

----- Epoch: 15/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.95 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 347.51 batches/s]



   -> Training Loss:  1.0594

   -> Testing Loss:  1.0798

----- Epoch: 16/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.46 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 353.32 batches/s]



   -> Training Loss:  1.0554

   -> Testing Loss:  1.0795

----- Epoch: 17/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 225.36 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 351.99 batches/s]



   -> Training Loss:  1.0525

   -> Testing Loss:  1.0815

----- Epoch: 18/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.32 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 346.95 batches/s]



   -> Training Loss:  1.0485

   -> Testing Loss:  1.0672

----- Epoch: 19/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 229.31 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 350.62 batches/s]



   -> Training Loss:  1.0435

   -> Testing Loss:  1.0645

----- Epoch: 20/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 228.81 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 355.70 batches/s]



   -> Training Loss:  1.0379

   -> Testing Loss:  1.0721

----- Epoch: 21/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 227.33 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 345.03 batches/s]



   -> Training Loss:  1.0364

   -> Testing Loss:  1.0568

----- Epoch: 22/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.28 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 349.49 batches/s]



   -> Training Loss:  1.0325

   -> Testing Loss:  1.0519

----- Epoch: 23/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 226.05 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 343.52 batches/s]



   -> Training Loss:  1.0278

   -> Testing Loss:  1.0515

----- Epoch: 24/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.04 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 355.89 batches/s]



   -> Training Loss:  1.0237

   -> Testing Loss:  1.0501

----- Epoch: 25/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.49 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 349.98 batches/s]



   -> Training Loss:  1.0212

   -> Testing Loss:  1.0501

----- Epoch: 26/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.67 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 348.64 batches/s]



   -> Training Loss:  1.0190

   -> Testing Loss:  1.0421

----- Epoch: 27/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.20 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 343.62 batches/s]



   -> Training Loss:  1.0152

   -> Testing Loss:  1.0413

----- Epoch: 28/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.97 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 356.34 batches/s]



   -> Training Loss:  1.0108

   -> Testing Loss:  1.0346

----- Epoch: 29/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 225.40 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 352.64 batches/s]



   -> Training Loss:  1.0094

   -> Testing Loss:  1.0325

----- Epoch: 30/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.12 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 348.75 batches/s]



   -> Training Loss:  1.0068

   -> Testing Loss:  1.0353

----- Epoch: 31/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.52 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 346.71 batches/s]



   -> Training Loss:  1.0042

   -> Testing Loss:  1.0279

----- Epoch: 32/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 226.72 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 352.37 batches/s]



   -> Training Loss:  1.0019

   -> Testing Loss:  1.0280

----- Epoch: 33/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.47 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 353.26 batches/s]



   -> Training Loss:  0.9996

   -> Testing Loss:  1.0282

----- Epoch: 34/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.19 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 350.54 batches/s]



   -> Training Loss:  0.9986

   -> Testing Loss:  1.0226

----- Epoch: 35/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.20 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 340.92 batches/s]



   -> Training Loss:  0.9962

   -> Testing Loss:  1.0219

----- Epoch: 36/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 219.38 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 351.87 batches/s]



   -> Training Loss:  0.9948

   -> Testing Loss:  1.0207

----- Epoch: 37/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.22 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 348.18 batches/s]



   -> Training Loss:  0.9927

   -> Testing Loss:  1.0137

----- Epoch: 38/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.14 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 347.63 batches/s]



   -> Training Loss:  0.9906

   -> Testing Loss:  1.0122

----- Epoch: 39/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.36 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 351.54 batches/s]



   -> Training Loss:  0.9900

   -> Testing Loss:  1.0113

----- Epoch: 40/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 226.41 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 356.97 batches/s]



   -> Training Loss:  0.9871

   -> Testing Loss:  1.0122

----- Epoch: 41/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.60 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 350.21 batches/s]



   -> Training Loss:  0.9886

   -> Testing Loss:  1.0110

----- Epoch: 42/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.83 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 349.72 batches/s]



   -> Training Loss:  0.9851

   -> Testing Loss:  1.0089

----- Epoch: 43/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 227.82 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 351.72 batches/s]



   -> Training Loss:  0.9831

   -> Testing Loss:  1.0017

----- Epoch: 44/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 225.42 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 350.40 batches/s]



   -> Training Loss:  0.9830

   -> Testing Loss:  1.0117

----- Epoch: 45/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.09 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 347.87 batches/s]



   -> Training Loss:  0.9813

   -> Testing Loss:  1.0081

----- Epoch: 46/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.27 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 347.50 batches/s]



   -> Training Loss:  0.9804

   -> Testing Loss:  0.9996

----- Epoch: 47/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.67 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 354.27 batches/s]



   -> Training Loss:  0.9779

   -> Testing Loss:  0.9979

----- Epoch: 48/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 228.43 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 347.57 batches/s]



   -> Training Loss:  0.9776

   -> Testing Loss:  0.9905

----- Epoch: 49/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.44 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 349.72 batches/s]



   -> Training Loss:  0.9755

   -> Testing Loss:  0.9889

----- Epoch: 50/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 228.59 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 350.65 batches/s]



   -> Training Loss:  0.9748

   -> Testing Loss:  0.9905

----- Epoch: 51/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.81 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 340.36 batches/s]



   -> Training Loss:  0.9728

   -> Testing Loss:  0.9865

----- Epoch: 52/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.57 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 348.92 batches/s]



   -> Training Loss:  0.9705

   -> Testing Loss:  0.9766

----- Epoch: 53/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.15 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 351.16 batches/s]



   -> Training Loss:  0.9701

   -> Testing Loss:  0.9886

----- Epoch: 54/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.17 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 353.62 batches/s]



   -> Training Loss:  0.9697

   -> Testing Loss:  0.9885

----- Epoch: 55/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 219.66 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 347.52 batches/s]



   -> Training Loss:  0.9683

   -> Testing Loss:  0.9806

----- Epoch: 56/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 225.26 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 355.36 batches/s]



   -> Training Loss:  0.9694

   -> Testing Loss:  0.9799

----- Epoch: 57/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.00 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 346.45 batches/s]



   -> Training Loss:  0.9649

   -> Testing Loss:  0.9862

----- Epoch: 58/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.20 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 350.15 batches/s]



   -> Training Loss:  0.9650

   -> Testing Loss:  0.9759

----- Epoch: 59/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.82 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 348.20 batches/s]



   -> Training Loss:  0.9665

   -> Testing Loss:  0.9771

----- Epoch: 60/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.20 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 349.67 batches/s]



   -> Training Loss:  0.9657

   -> Testing Loss:  0.9750

----- Epoch: 61/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.48 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 353.16 batches/s]



   -> Training Loss:  0.9635

   -> Testing Loss:  0.9732

----- Epoch: 62/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.71 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 338.28 batches/s]



   -> Training Loss:  0.9623

   -> Testing Loss:  0.9738

----- Epoch: 63/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.34 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 346.83 batches/s]



   -> Training Loss:  0.9625

   -> Testing Loss:  0.9735

----- Epoch: 64/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.03 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 343.59 batches/s]



   -> Training Loss:  0.9614

   -> Testing Loss:  0.9750

----- Epoch: 65/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.20 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 350.40 batches/s]



   -> Training Loss:  0.9620

   -> Testing Loss:  0.9644

----- Epoch: 66/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.49 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 355.48 batches/s]



   -> Training Loss:  0.9606

   -> Testing Loss:  0.9625

----- Epoch: 67/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 220.70 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 349.50 batches/s]



   -> Training Loss:  0.9585

   -> Testing Loss:  0.9675

----- Epoch: 68/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 219.33 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 349.77 batches/s]



   -> Training Loss:  0.9590

   -> Testing Loss:  0.9684

----- Epoch: 69/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 220.24 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 353.90 batches/s]



   -> Training Loss:  0.9598

   -> Testing Loss:  0.9624

----- Epoch: 70/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 218.78 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 352.30 batches/s]



   -> Training Loss:  0.9577

   -> Testing Loss:  0.9670

----- Epoch: 71/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.23 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 345.32 batches/s]



   -> Training Loss:  0.9564

   -> Testing Loss:  0.9595

----- Epoch: 72/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 220.17 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 342.34 batches/s]



   -> Training Loss:  0.9578

   -> Testing Loss:  0.9666

----- Epoch: 73/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.15 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 352.18 batches/s]



   -> Training Loss:  0.9553

   -> Testing Loss:  0.9639

----- Epoch: 74/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.24 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 356.28 batches/s]



   -> Training Loss:  0.9554

   -> Testing Loss:  0.9678

----- Epoch: 75/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.76 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 340.15 batches/s]



   -> Training Loss:  0.9556

   -> Testing Loss:  0.9610

----- Epoch: 76/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 219.40 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 352.97 batches/s]



   -> Training Loss:  0.9527

   -> Testing Loss:  0.9681

----- Epoch: 77/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 226.65 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 351.85 batches/s]



   -> Training Loss:  0.9534

   -> Testing Loss:  0.9639

----- Epoch: 78/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.86 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 340.15 batches/s]



   -> Training Loss:  0.9335

   -> Testing Loss:  0.9434

----- Epoch: 79/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 219.81 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 350.70 batches/s]



   -> Training Loss:  0.9283

   -> Testing Loss:  0.9385

----- Epoch: 80/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.41 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 342.73 batches/s]



   -> Training Loss:  0.9262

   -> Testing Loss:  0.9398

----- Epoch: 81/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 225.37 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 348.57 batches/s]



   -> Training Loss:  0.9240

   -> Testing Loss:  0.9387

----- Epoch: 82/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 226.33 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 345.77 batches/s]



   -> Training Loss:  0.9228

   -> Testing Loss:  0.9355

----- Epoch: 83/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 226.06 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 344.70 batches/s]



   -> Training Loss:  0.9211

   -> Testing Loss:  0.9341

----- Epoch: 84/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 218.10 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 339.68 batches/s]



   -> Training Loss:  0.9199

   -> Testing Loss:  0.9344

----- Epoch: 85/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.54 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 329.38 batches/s]



   -> Training Loss:  0.9181

   -> Testing Loss:  0.9357

----- Epoch: 86/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.57 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 342.83 batches/s]



   -> Training Loss:  0.9177

   -> Testing Loss:  0.9335

----- Epoch: 87/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.61 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 344.01 batches/s]



   -> Training Loss:  0.9167

   -> Testing Loss:  0.9307

----- Epoch: 88/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.24 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 354.57 batches/s]



   -> Training Loss:  0.9160

   -> Testing Loss:  0.9312

----- Epoch: 89/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.15 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 345.15 batches/s]



   -> Training Loss:  0.9151

   -> Testing Loss:  0.9283

----- Epoch: 90/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.96 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 358.24 batches/s]



   -> Training Loss:  0.9145

   -> Testing Loss:  0.9312

----- Epoch: 91/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.59 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 345.64 batches/s]



   -> Training Loss:  0.9133

   -> Testing Loss:  0.9342

----- Epoch: 92/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 220.48 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 348.79 batches/s]



   -> Training Loss:  0.9125

   -> Testing Loss:  0.9317

----- Epoch: 93/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 221.80 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 348.42 batches/s]



   -> Training Loss:  0.9117

   -> Testing Loss:  0.9331

----- Epoch: 94/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 219.90 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 350.95 batches/s]



   -> Training Loss:  0.9111

   -> Testing Loss:  0.9292

----- Epoch: 95/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.34 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 358.50 batches/s]



   -> Training Loss:  0.9109

   -> Testing Loss:  0.9293

----- Epoch: 96/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 224.81 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 340.56 batches/s]



   -> Training Loss:  0.9017

   -> Testing Loss:  0.9194

----- Epoch: 97/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 223.34 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 354.73 batches/s]



   -> Training Loss:  0.8985

   -> Testing Loss:  0.9182

----- Epoch: 98/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 219.02 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 351.94 batches/s]



   -> Training Loss:  0.8973

   -> Testing Loss:  0.9192

----- Epoch: 99/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 225.45 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 354.73 batches/s]



   -> Training Loss:  0.8964

   -> Testing Loss:  0.9167

----- Epoch: 100/100 -----


Training: 100%|██████████| 1201/1201 [00:05<00:00, 222.18 batch/s]
Testing: 100%|██████████| 301/301 [00:00<00:00, 345.78 batches/s]


   -> Training Loss:  0.8954

   -> Testing Loss:  0.9151






In [241]:
# Play a fixed number of games with the trained model and report the mean
# value of the 'moves_taken' entry returned by each game.
num_of_games = 100

moves_per_game = []
for _ in range(num_of_games):
    # A fresh MSGameManager is constructed for every game.
    results = start_nn_game(MSGameManager(), model_easy, 'easy', device)
    moves_per_game.append(results['moves_taken'])

avg_moves = sum(moves_per_game) / num_of_games

print(avg_moves)

6.16


In [279]:
class Critic(nn.Module):
    """Convolutional network that maps a 5-channel 22x22 board to one scalar.

    The input is pushed through four Conv2d/BatchNorm2d/ReLU stages (with a
    single 2x2 max-pool after the second stage), flattened, passed through four
    Linear/BatchNorm1d/ReLU/Dropout stages and finally projected to a single
    output value per sample.

    Args:
        linear_dropout: Dropout probability applied after every fully
            connected stage of ``Classifier``. Defaults to 0 (no dropout).
    """

    def __init__(self, linear_dropout=0):
        super().__init__()

        self.BOARD_SIZE = 22

        cnn_layer_1_size = 32
        cnn_layer_2_size = 32
        cnn_layer_3_size = 64
        cnn_layer_4_size = 64

        self.CNN = nn.Sequential(
            nn.Conv2d(in_channels=5, out_channels=cnn_layer_1_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_1_size),
            nn.ReLU(),

            nn.Conv2d(in_channels=cnn_layer_1_size, out_channels=cnn_layer_2_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_2_size),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=cnn_layer_2_size, out_channels=cnn_layer_3_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_3_size),
            nn.ReLU(),

            nn.Conv2d(in_channels=cnn_layer_3_size, out_channels=cnn_layer_4_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_4_size),
            nn.ReLU(),
        )

        # Measure the flattened width of the CNN output with a dummy board.
        # The probe runs in eval mode so the BatchNorm layers do not fold the
        # all-zero dummy input into their running statistics, and it stays on
        # the CPU because the freshly built layers live there (the previous
        # `test_input.to(device)` discarded its result and was a no-op).
        was_training = self.CNN.training
        self.CNN.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 5, self.BOARD_SIZE, self.BOARD_SIZE)
            total_count = self.CNN(probe).flatten(start_dim=1).size(1)
        self.CNN.train(was_training)

        lin_layer_1_size = 2000
        lin_layer_2_size = 2000
        lin_layer_3_size = 1500
        lin_layer_4_size = 500

        self.Classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=total_count, out_features=lin_layer_1_size),
            nn.BatchNorm1d(num_features=lin_layer_1_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_1_size, out_features=lin_layer_2_size),
            nn.BatchNorm1d(num_features=lin_layer_2_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_2_size, out_features=lin_layer_3_size),
            nn.BatchNorm1d(num_features=lin_layer_3_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_3_size, out_features=lin_layer_4_size),
            nn.BatchNorm1d(num_features=lin_layer_4_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),
        )

        self.output_layer = nn.Linear(in_features=lin_layer_4_size, out_features=1)

    def forward(self, x):
        """Return one value per sample.

        Args:
            x: Tensor of shape (batch, 5, 22, 22).

        Returns:
            Tensor of shape (batch, 1).
        """
        features = self.CNN(x)
        hidden = self.Classifier(features)

        return self.output_layer(hidden)

class Actor(nn.Module):
    """Convolutional network that maps a 4-channel 22x22 board to 484 logits.

    The input is pushed through four Conv2d/BatchNorm2d/ReLU stages (with a
    single 2x2 max-pool after the second stage), flattened, passed through four
    Linear/ReLU/Dropout stages and finally projected to
    ``BOARD_SIZE * BOARD_SIZE`` raw outputs (presumably one per board cell --
    confirm against the training loop).

    Args:
        linear_dropout: Dropout probability applied after every fully
            connected stage of ``Classifier``. Defaults to 0 (no dropout).
    """

    def __init__(self, linear_dropout=0):
        super().__init__()

        self.BOARD_SIZE = 22

        cnn_layer_1_size = 32
        cnn_layer_2_size = 32
        cnn_layer_3_size = 64
        cnn_layer_4_size = 64

        self.CNN = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=cnn_layer_1_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_1_size),
            nn.ReLU(),

            nn.Conv2d(in_channels=cnn_layer_1_size, out_channels=cnn_layer_2_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_2_size),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=cnn_layer_2_size, out_channels=cnn_layer_3_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_3_size),
            nn.ReLU(),

            nn.Conv2d(in_channels=cnn_layer_3_size, out_channels=cnn_layer_4_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=cnn_layer_4_size),
            nn.ReLU(),
        )

        # Measure the flattened width of the CNN output with a dummy board.
        # The probe runs in eval mode so the BatchNorm layers do not fold the
        # all-zero dummy input into their running statistics, and it stays on
        # the CPU because the freshly built layers live there (the previous
        # `test_input.to(device)` discarded its result and was a no-op).
        was_training = self.CNN.training
        self.CNN.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 4, self.BOARD_SIZE, self.BOARD_SIZE)
            total_count = self.CNN(probe).flatten(start_dim=1).size(1)
        self.CNN.train(was_training)

        lin_layer_1_size = 2000
        lin_layer_2_size = 2000
        lin_layer_3_size = 1500
        lin_layer_4_size = 500

        self.Classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=total_count, out_features=lin_layer_1_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_1_size, out_features=lin_layer_2_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_2_size, out_features=lin_layer_3_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),

            nn.Linear(in_features=lin_layer_3_size, out_features=lin_layer_4_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=linear_dropout),
        )

        self.output_layer = nn.Linear(in_features=lin_layer_4_size, out_features=self.BOARD_SIZE * self.BOARD_SIZE)

    def forward(self, x):
        """Return the raw (un-normalized) outputs for each sample.

        Args:
            x: Tensor of shape (batch, 4, 22, 22).

        Returns:
            Tensor of shape (batch, 484); no softmax/sigmoid is applied here.
        """
        features = self.CNN(x)
        hidden = self.Classifier(features)

        return self.output_layer(hidden)


In [280]:
# Settings handed to the training loop as `hp_a`. They differ from `hp_c`
# only in the learning rate (1e-2 here versus 1e-4 below).
hp_a = dict(
    epochs=5,
    learning_rate=1e-2,
    decay_rate=1e-3,
    cnn_dropout=0.0,
    linear_dropout=0.25,
)

# Settings handed to the training loop as `hp_c`.
hp_c = dict(
    epochs=5,
    learning_rate=1e-4,
    decay_rate=1e-3,
    cnn_dropout=0.0,
    linear_dropout=0.25,
)

# NOTE(review): in the recorded run below, the Actor loss drops to 0.0000
# after the first epoch and the cell ends in a KeyboardInterrupt -- verify the
# training loop before relying on the returned models.
manager = MSGameManager()

critic, actor = manager.critic_and_actor_training_loop(hp_a=hp_a, hp_c=hp_c)

Generating Critic data for MEDIUM (5-channel approach)...
######## Beginning training for MS Critic ##########
----- Epoch: 1/5 -----


Training: 100%|██████████| 12/12 [00:00<00:00, 27.34 batch/s]
Testing: 100%|██████████| 3/3 [00:00<00:00,  8.28 batches/s]



   -> Training Loss:  12417.3310

   -> Testing Loss:  9302.8196

----- Epoch: 2/5 -----


Training: 100%|██████████| 12/12 [00:00<00:00, 26.65 batch/s]
Testing: 100%|██████████| 3/3 [00:00<00:00,  8.06 batches/s]



   -> Training Loss:  12090.2078

   -> Testing Loss:  8511.6816

----- Epoch: 3/5 -----


Training: 100%|██████████| 12/12 [00:00<00:00, 27.06 batch/s]
Testing: 100%|██████████| 3/3 [00:00<00:00,  8.26 batches/s]



   -> Training Loss:  12030.1906

   -> Testing Loss:  8510.4564

----- Epoch: 4/5 -----


Training: 100%|██████████| 12/12 [00:00<00:00, 27.41 batch/s]
Testing: 100%|██████████| 3/3 [00:00<00:00,  8.28 batches/s]



   -> Training Loss:  11873.5893

   -> Testing Loss:  8632.2128

----- Epoch: 5/5 -----


Training: 100%|██████████| 12/12 [00:00<00:00, 27.09 batch/s]
Testing: 100%|██████████| 3/3 [00:00<00:00,  8.28 batches/s]



   -> Training Loss:  11785.3199

   -> Testing Loss:  8640.1350

Generating Actor training data for MEDIUM...
######## Beginning training for MS Actor ##########
----- Epoch: 1/5 -----


Training: 100%|██████████| 1/1 [00:00<00:00,  2.73 batch/s]
Testing: 100%|██████████| 1/1 [00:00<00:00,  2.75 batches/s]



   -> Training Loss:  20.7987

   -> Testing Loss:  0.0000

----- Epoch: 2/5 -----


Training: 100%|██████████| 1/1 [00:00<00:00,  2.74 batch/s]
Testing: 100%|██████████| 1/1 [00:00<00:00,  2.78 batches/s]



   -> Training Loss:  0.0000

   -> Testing Loss:  0.0000

----- Epoch: 3/5 -----


Training: 100%|██████████| 1/1 [00:00<00:00,  2.71 batch/s]
Testing: 100%|██████████| 1/1 [00:00<00:00,  2.76 batches/s]



   -> Training Loss:  0.0000

   -> Testing Loss:  0.0000

----- Epoch: 4/5 -----


Training: 100%|██████████| 1/1 [00:00<00:00,  2.72 batch/s]
Testing: 100%|██████████| 1/1 [00:00<00:00,  2.79 batches/s]



   -> Training Loss:  0.0000

   -> Testing Loss:  0.0000

----- Epoch: 5/5 -----


Training: 100%|██████████| 1/1 [00:00<00:00,  2.70 batch/s]
Testing: 100%|██████████| 1/1 [00:00<00:00,  2.72 batches/s]



   -> Training Loss:  0.0000

   -> Testing Loss:  0.0000

Generating Critic data for MEDIUM (5-channel approach)...
######## Beginning training for MS Critic ##########
----- Epoch: 1/5 -----


Training: 100%|██████████| 12/12 [00:00<00:00, 26.98 batch/s]
Testing: 100%|██████████| 3/3 [00:00<00:00,  8.24 batches/s]



   -> Training Loss:  14639.4238

   -> Testing Loss:  3744.6535

----- Epoch: 2/5 -----


Training: 100%|██████████| 12/12 [00:00<00:00, 26.74 batch/s]
Testing: 100%|██████████| 3/3 [00:00<00:00,  8.06 batches/s]



   -> Training Loss:  14423.3140

   -> Testing Loss:  3391.7636

----- Epoch: 3/5 -----


Training: 100%|██████████| 12/12 [00:00<00:00, 27.20 batch/s]
Testing: 100%|██████████| 3/3 [00:00<00:00,  8.14 batches/s]



   -> Training Loss:  14335.0847

   -> Testing Loss:  3420.0146

----- Epoch: 4/5 -----


Training: 100%|██████████| 12/12 [00:00<00:00, 27.02 batch/s]
Testing: 100%|██████████| 3/3 [00:00<00:00,  8.28 batches/s]



   -> Training Loss:  14340.2143

   -> Testing Loss:  3495.1217

----- Epoch: 5/5 -----


Training: 100%|██████████| 12/12 [00:00<00:00, 26.85 batch/s]
Testing: 100%|██████████| 3/3 [00:00<00:00,  8.20 batches/s]



   -> Training Loss:  14326.2770

   -> Testing Loss:  3512.9125

Generating Actor training data for MEDIUM...
######## Beginning training for MS Actor ##########
----- Epoch: 1/5 -----


Training: 100%|██████████| 1/1 [00:00<00:00,  2.52 batch/s]
Testing: 100%|██████████| 1/1 [00:00<00:00,  2.75 batches/s]



   -> Training Loss:  20.7902

   -> Testing Loss:  0.0000

----- Epoch: 2/5 -----


Training: 100%|██████████| 1/1 [00:00<00:00,  2.68 batch/s]
Testing: 100%|██████████| 1/1 [00:00<00:00,  2.75 batches/s]



   -> Training Loss:  0.0000

   -> Testing Loss:  0.0000

----- Epoch: 3/5 -----


Training: 100%|██████████| 1/1 [00:00<00:00,  2.72 batch/s]
Testing: 100%|██████████| 1/1 [00:00<00:00,  2.73 batches/s]



   -> Training Loss:  0.0000

   -> Testing Loss:  0.0000

----- Epoch: 4/5 -----


Training: 100%|██████████| 1/1 [00:00<00:00,  2.71 batch/s]
Testing: 100%|██████████| 1/1 [00:00<00:00,  2.72 batches/s]



   -> Training Loss:  0.0000

   -> Testing Loss:  0.0000

----- Epoch: 5/5 -----


Training: 100%|██████████| 1/1 [00:00<00:00,  2.71 batch/s]
Testing: 100%|██████████| 1/1 [00:00<00:00,  2.66 batches/s]



   -> Training Loss:  0.0000

   -> Testing Loss:  0.0000

Generating Critic data for MEDIUM (5-channel approach)...


KeyboardInterrupt: 