In [None]:
import gym_battleship
import gymnasium as gym
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import BaseCallback
import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.dqn import CnnPolicy
from gym_battleship.environments.battleship import CHANNEL_MAP
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.monitor import Monitor
import os
from typing import Callable
import imageio

In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported before each cell executes; edits to local packages are then picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
def _experiment(features_dim, hidden_layer1, hidden_layer2, net_arch,
                exploration_fraction, missed, proximal_hit,
                repeat_missed, repeat_hit):
    """Assemble one hyperparameter set for a training run.

    Only the values that differ between runs are parameters; the learning
    rate, the epsilon bounds, the move limit and the win/lose/hit rewards are
    the same for every run and are filled in here.

    Args:
        features_dim: Output size of the CNN feature extractor.
        hidden_layer1: Channel count of the first convolution.
        hidden_layer2: Channel count of the second convolution.
        net_arch: Layer sizes passed to the policy as ``net_arch``.
        exploration_fraction: Value passed to DQN as ``exploration_fraction``.
        missed: Reward for the ``'missed'`` key.
        proximal_hit: Reward for the ``'proximal_hit'`` key.
        repeat_missed: Reward for the ``'repeat_missed'`` key.
        repeat_hit: Reward for the ``'repeat_hit'`` key.

    Returns:
        A fresh dict (with its own ``net_arch`` list and ``rewards`` dict).
    """
    return {
        'features_dim': features_dim,
        'hidden_layer1': hidden_layer1,
        'hidden_layer2': hidden_layer2,
        'net_arch': list(net_arch),  # copy so runs never share a mutable list
        'learning_rate': 1e-4,
        'exploration_fraction': exploration_fraction,
        'exploration_initial_eps': 1,
        'exploration_final_eps': 0.1,
        'max_moves': 110,
        'rewards': {
            'win': 100,
            'lose': -30,
            'missed': missed,
            'hit': 5,
            'proximal_hit': proximal_hit,
            'repeat_missed': repeat_missed,
            'repeat_hit': repeat_hit,
        },
    }


# One entry per training run, in the order the runs are executed.
# Columns: features_dim, hidden_layer1, hidden_layer2, net_arch,
#          exploration_fraction, missed, proximal_hit, repeat_missed, repeat_hit
hyperparameters = [
    _experiment(128, 32, 64, [64, 64], 0.1, -0.2, 10, -50, -50),
    _experiment(128, 32, 64, [64, 64], 0.5, -0.2, 10, -50, -50),
    _experiment(128, 32, 64, [64, 64], 0.5, 0, 10, -50, -50),
    _experiment(128, 32, 64, [64, 64], 0.8, 0, 10, -20, -20),
    _experiment(128, 32, 64, [64, 64], 0.8, -0.2, 20, -20, -20),
    # final reward schema below
    _experiment(128, 32, 64, [64, 64], 0.8, -0.2, 20, -20, -3),
    _experiment(128, 64, 128, [128, 128], 0.8, -0.2, 20, -20, -3),
    _experiment(256, 64, 128, [256, 256], 0.8, -0.2, 20, -20, -3),
    _experiment(256, 128, 256, [256, 256], 0.8, -0.2, 20, -20, -3),
    _experiment(256, 128, 256, [256, 256], 0.4, -0.2, 20, -20, -3),
]

In [None]:
def make_battleship_cnn(features_dim: int, hidden_layer1: int, hidden_layer2: int):
    """Create a CNN feature-extractor class with the given layer sizes.

    A class (not an instance) is returned so it can be handed to the policy
    through ``features_extractor_class``.

    Args:
        features_dim: Size of the feature vector produced by the extractor.
        hidden_layer1: Number of output channels of the first convolution.
        hidden_layer2: Number of output channels of the second convolution.

    Returns:
        A ``BaseFeaturesExtractor`` subclass named ``BattleshipCNN``.
    """
    class BattleshipCNN(BaseFeaturesExtractor):
        """Two 3x3 convolutions followed by a linear projection to ``features_dim``."""

        def __init__(self, observation_space, features_dim=features_dim):
            super().__init__(observation_space, features_dim)

            # Channel-first observations: the first axis is the channel count.
            n_channels = observation_space.shape[0]

            # kernel_size=3 with padding=1 and stride=1 keeps the spatial size unchanged.
            self.cnn = nn.Sequential(
                nn.Conv2d(n_channels, hidden_layer1, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(hidden_layer1, hidden_layer2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Flatten()
            )

            # Push a dummy batch through the conv stack to discover the flattened
            # size. The dummy takes its shape from the observation space instead of
            # assuming a 10x10 board, so other board sizes work as well.
            with th.no_grad():
                sample = th.zeros((1, *observation_space.shape))
                conv_output_dim = self.cnn(sample).shape[1]

            self.linear = nn.Sequential(
                nn.Linear(conv_output_dim, features_dim),
                nn.ReLU()
            )

        def forward(self, obs):
            """Map a batch of observations to a batch of feature vectors."""
            return self.linear(self.cnn(obs))

    return BattleshipCNN


In [None]:
def plot_q_value_heatmap(q_values, title="Q-Value Heatmap", image=False, board_size=(10, 10)):
    """Draw the Q-values of one state as an annotated board-shaped heatmap.

    Args:
        q_values: Batched torch tensor of Q-values; only the first row is
            plotted. Action ``i`` is placed at ``(i // n_cols, i % n_cols)``.
        title: Title shown above the heatmap.
        image: If True, render off-screen and return the figure as an RGB
            array; otherwise show the figure and return the value grid.
        board_size: ``(rows, columns)`` of the board. Defaults to the 10x10
            board used throughout this notebook.

    Returns:
        ``image=True``: ``uint8`` array of shape ``(height, width, 3)``.
        ``image=False``: float array of shape ``board_size`` with the Q-values.

    Raises:
        ValueError: If fewer than ``rows * columns`` Q-values are supplied.
    """
    n_rows, n_cols = board_size

    # detach() keeps this working even if the tensor was produced with gradients enabled.
    q_vals = q_values[0].detach().cpu().numpy()
    # Row-major reshape is exactly the i -> (i // n_cols, i % n_cols) mapping.
    grid = np.asarray(q_vals[:n_rows * n_cols], dtype=float).reshape(n_rows, n_cols)

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(grid, cmap='coolwarm', aspect='auto')

    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Q-Value', rotation=270, labelpad=20)

    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(np.arange(n_cols))
    ax.set_yticklabels(np.arange(n_rows))

    # Minor ticks sit on the cell borders so the grid lines separate the cells.
    ax.set_xticks(np.arange(n_cols) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_rows) - 0.5, minor=True)
    ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5)

    ax.set_xlabel('Column')
    ax.set_ylabel('Row')
    ax.set_title(title)

    # Print each value inside its cell.
    for (row, col), value in np.ndenumerate(grid):
        ax.text(col, row, f'{value:.2f}',
                ha="center", va="center", color="black", fontsize=8)

    plt.tight_layout()

    if image:
        # Convert plot to image array
        fig.canvas.draw()
        # The canvas buffer already carries its own (height, width, 4) shape, so
        # no manual reshape is needed. Drop alpha and copy so the returned array
        # does not reference the canvas after the figure is closed.
        img = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
        plt.close(fig)
        return img
    else:
        plt.show()
        return grid

In [None]:
def masked_predict(model, obs):
    """Choose the action with the highest Q-value among legal actions.

    Args:
        model: Trained DQN model; its ``q_net`` is queried directly.
        obs: Single (unbatched) channel-first observation.

    Returns:
        ``(action, q_values)`` where ``action`` is the chosen flat action index
        (int) and ``q_values`` is the unmasked ``(1, n_actions)`` tensor.
    """
    # Build the batch on the model's device; a model loaded onto a GPU would
    # otherwise be fed a CPU tensor and fail with a device mismatch.
    obs_tensor = th.as_tensor(obs, dtype=th.float32, device=model.device).unsqueeze(0)

    with th.no_grad():
        q_values = model.q_net(obs_tensor)

    # A cell counts as legal where the LEGAL_MOVE channel equals 0.
    # NOTE(review): the 0-means-legal encoding is taken from this code as-is;
    # confirm against the environment's observation definition.
    legal_cells = obs[CHANNEL_MAP.LEGAL_MOVE.value, :, :] == 0
    legal_mask = th.as_tensor(legal_cells, dtype=th.bool, device=q_values.device).flatten()

    # set illegal actions to very negative Q-value
    q_values_masked = q_values.clone()
    q_values_masked[0, ~legal_mask] = -float('inf')
    action = q_values_masked.argmax(dim=1).item()

    return action, q_values

In [None]:
def get_q_values(model, obs):
    """Get Q-values for all actions given a state.

    Args:
        model: Trained DQN model; its ``q_net`` is queried directly.
        obs: Single (unbatched) observation.

    Returns:
        Tensor of shape ``(1, n_actions)`` computed without gradient tracking.
    """
    # Build the batch on the model's device; a model loaded onto a GPU would
    # otherwise be fed a CPU tensor and fail with a device mismatch.
    obs_tensor = th.as_tensor(obs, dtype=th.float32, device=model.device).unsqueeze(0)

    with th.no_grad():
        q_values = model.q_net(obs_tensor)

    return q_values

In [None]:

class TrainingMetricsCallback(BaseCallback):
    """Record mean episode reward and length over consecutive episode windows.

    Finished episodes are read from the per-step ``infos`` that the training
    loop exposes through ``self.locals``; the ``"episode"`` entry (keys ``"r"``
    and ``"l"``) is the one produced by the Monitor wrapper. The model's own
    ``ep_info_buffer`` is left untouched, so the library's built-in episode
    statistics are not disturbed and ``check_freq`` is not limited by that
    buffer's maximum length.

    Attributes:
        check_freq: Number of episodes averaged into each data point.
        ep_rew_means: Mean episode reward of each completed window.
        ep_len_means: Mean episode length of each completed window.
        episodes: Cumulative episode count at the end of each window.
        current_episode_count: Episodes covered by completed windows so far.
    """

    def __init__(self, check_freq: int = 100, verbose=0):
        """
        Args:
            check_freq: Number of episodes per averaging window (must be >= 1).
            verbose: If > 0, print a summary line after every window.

        Raises:
            ValueError: If ``check_freq`` is smaller than 1.
        """
        super().__init__(verbose)
        if check_freq < 1:
            raise ValueError("check_freq must be at least 1")
        self.check_freq = check_freq # number of episodes we're averaging over for mean reward and game length
        self.ep_rew_means = []
        self.ep_len_means = []
        self.episodes = []
        self.current_episode_count = 0
        # Episodes finished since the last completed window.
        self._pending_rewards = []
        self._pending_lengths = []

    def _on_step(self) -> bool:
        """Collect episodes that ended on this step; always lets training continue."""
        for info in self.locals.get("infos", ()):
            episode = info.get("episode")
            if episode is None:
                continue  # no episode finished in this environment on this step
            self._pending_rewards.append(episode['r'])
            self._pending_lengths.append(episode['l'])
            if len(self._pending_rewards) >= self.check_freq:
                self._close_window()

        return True

    def _close_window(self) -> None:
        """Average the pending episodes into one data point and start a new window."""
        ep_rew_mean = np.mean(self._pending_rewards)
        ep_len_mean = np.mean(self._pending_lengths)
        self._pending_rewards = []
        self._pending_lengths = []

        self.current_episode_count += self.check_freq
        self.ep_rew_means.append(ep_rew_mean)
        self.ep_len_means.append(ep_len_mean)
        self.episodes.append(self.current_episode_count)

        if self.verbose > 0:
            print(f"Current episode count: {self.current_episode_count}: "
                    f"Mean reward = {ep_rew_mean:.2f}, Mean length = {ep_len_mean:.2f} for past {self.check_freq} episodes.")


def plot_training_metrics(callback, hyperparameter_index, output_dir='./dqn_models'):
    """Plot, save and show the reward and episode-length curves of one run.

    Two PNG files are written to ``output_dir``:
    ``model-{hyperparameter_index}-reward.png`` and
    ``model-{hyperparameter_index}-length.png``.

    Args:
        callback: Object exposing ``episodes``, ``ep_rew_means`` and
            ``ep_len_means`` sequences of equal length.
        hyperparameter_index: Run label used in the titles and file names.
        output_dir: Directory for the PNG files; created if it does not exist.
    """
    # savefig does not create missing directories; without this the first plot
    # of a fresh checkout would fail with FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)

    # (values, line style, y-axis label, title, file-name suffix)
    curves = (
        (callback.ep_rew_means, 'b-', 'Episode Reward Mean',
         f'Model {hyperparameter_index}: Episode Reward Mean', 'reward'),
        (callback.ep_len_means, 'r-', 'Mean Episode Length',
         f'Model {hyperparameter_index}: Episode Length Mean', 'length'),
    )

    for values, line_style, y_label, title, suffix in curves:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(callback.episodes, values, line_style)
        ax.set_xlabel('Episodes')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True)
        plt.tight_layout()
        fig.savefig(os.path.join(output_dir, f'model-{hyperparameter_index}-{suffix}.png'), dpi=150)
        plt.show()
        # Close explicitly: this is called once per run inside the training loop.
        plt.close(fig)
    


In [None]:
# Training budget and bookkeeping shared by every run.
TOTAL_TIMESTEPS = 200000
METRICS_WINDOW = 10  # episodes averaged into each point of the training curves
MODEL_SAVE_DIR = "./dqn_models-test1"

# plot_training_metrics writes its figures into ./dqn_models. Create it up front
# so a missing directory cannot abort the sweep after the first run has trained.
os.makedirs("./dqn_models", exist_ok=True)

# Train one DQN agent per hyperparameter set. Runs are numbered from 1 and saved
# as dqn-<run number>.
for index, hyperparameter in enumerate(hyperparameters):
    run_number = index + 1

    # Not closed here: the last `env` stays bound for any later cell that reuses it.
    env = gym.make(
        'Battleship-v0',
        board_size=(10, 10),
        reward_dictionary=hyperparameter['rewards'],
        episode_steps=hyperparameter['max_moves'],
    )

    CustomCNN = make_battleship_cnn(
        hyperparameter['features_dim'],
        hyperparameter['hidden_layer1'],
        hyperparameter['hidden_layer2']
    )

    policy_kwargs = dict(
        features_extractor_class=CustomCNN,
        features_extractor_kwargs=dict(features_dim=hyperparameter['features_dim']),
        net_arch=hyperparameter['net_arch'],
        normalize_images=False
    )

    model = DQN(
        "CnnPolicy",
        env,
        policy_kwargs=policy_kwargs,
        learning_rate=hyperparameter['learning_rate'],
        # adjustable parameters
        exploration_fraction=hyperparameter['exploration_fraction'],
        exploration_initial_eps=hyperparameter['exploration_initial_eps'],
        exploration_final_eps=hyperparameter['exploration_final_eps'],
        verbose=0,
    )

    # Create callback to track training metrics
    metrics_callback = TrainingMetricsCallback(check_freq=METRICS_WINDOW)

    model.learn(total_timesteps=TOTAL_TIMESTEPS, log_interval=10, callback=metrics_callback)
    model.save(f"{MODEL_SAVE_DIR}/dqn-{run_number}")

    # Plot training metrics
    plot_training_metrics(metrics_callback, run_number)
        

In [None]:
models_dir = "./dqn_models"


def _model_sort_key(file_name):
    """Sort key that orders ``dqn-<number>`` files by their number.

    Names whose part after ``dqn-`` (extension removed) is not a plain number
    are placed after the numbered ones, in alphabetical order.
    """
    stem = os.path.splitext(file_name)[0]
    suffix = stem[len("dqn-"):]
    if suffix.isdecimal():
        return (0, int(suffix), file_name)
    return (1, 0, file_name)


# os.listdir returns entries in arbitrary order. Sort numerically so the
# position of a file in `models` is stable and dqn-2 precedes dqn-10.
models = sorted(
    (f for f in os.listdir(models_dir) if f.startswith("dqn-")),
    key=_model_sort_key,
)

def run_games(model, env, num_games=1000, mask=True, max_steps=110, deterministic=False):
    """Play ``num_games`` games with ``model`` and aggregate the outcomes.

    A step counts as a hit when its reward equals the environment's ``'hit'``,
    ``'proximal_hit'`` or ``'win'`` reward; every other step counts as a miss.
    A game ending with ``truncated`` is a loss, one ending with ``terminated``
    is a win (``truncated`` takes precedence if both are set). A game that
    reaches ``max_steps`` without either flag is counted as neither and adds
    no entry to the per-game ratios.

    Args:
        model: Trained model. With ``mask=True`` it is queried through
            ``masked_predict``; otherwise through ``model.predict``.
        env: Environment exposing ``unwrapped.reward_dictionary``.
        num_games: Number of games to play (must be positive).
        mask: If True, only legal actions are chosen.
        max_steps: Upper bound on steps per game.
        deterministic: Forwarded to ``model.predict`` when ``mask`` is False.
            The default (False) keeps the model's stochastic prediction.

    Returns:
        Dict with win/loss counts, hit/miss counts, per-game and overall
        hit-to-miss ratios, win rate, average steps and reward per game, and
        (unmasked runs only) action-count samples from up to five games spread
        evenly over the run.

    Raises:
        ValueError: If ``num_games`` is not positive.
    """
    if num_games <= 0:
        raise ValueError("num_games must be a positive integer")

    reward_table = env.unwrapped.reward_dictionary
    hit_rewards = (reward_table['hit'], reward_table['proximal_hit'], reward_table['win'])

    # Games whose action histogram is recorded. Integer arithmetic gives five
    # evenly spaced games for any num_games (fewer if num_games < 5).
    sampled_games = {(k * num_games) // 5 for k in range(5)}

    num_lost = 0
    num_wins = 0
    num_steps = 0
    total_reward = 0
    num_hits = 0
    num_misses = 0
    action_distribution = []
    hit_to_miss_ratios = []

    for game in range(num_games):
        obs, info = env.reset()
        sampling_action_distribution = (not mask) and game in sampled_games
        if sampling_action_distribution:
            action_distribution.append({})
        game_hits = 0
        game_misses = 0

        for _ in range(max_steps):
            if mask:
                action, _q_values = masked_predict(model, obs)
            else:
                action, _states = model.predict(obs, deterministic=deterministic)
                action = int(action)
                if sampling_action_distribution:
                    counts = action_distribution[-1]
                    counts[action] = counts.get(action, 0) + 1

            obs, reward, terminated, truncated, info = env.step(action)
            total_reward += reward
            num_steps += 1

            if reward in hit_rewards:
                num_hits += 1
                game_hits += 1
            else:
                num_misses += 1
                game_misses += 1

            if truncated or terminated:
                if truncated:
                    num_lost += 1
                else:
                    num_wins += 1
                hit_to_miss_ratios.append(game_hits / game_misses if game_misses > 0 else float('inf'))
                break

    return {
        "num_wins": num_wins,
        "num_lost": num_lost,
        "num_hits": num_hits,
        "num_misses": num_misses,
        "hit-to-miss-ratios": hit_to_miss_ratios,
        "average-hit-to-miss-ratio": num_hits / num_misses if num_misses > 0 else float('inf'),
        "win_rate": num_wins / num_games,
        "avg_steps": num_steps / num_games,
        "avg_reward": total_reward / num_games,
        "action_distribution_samples": action_distribution
    }
    
NUM_GAMES = 1000

# Build the evaluation environment explicitly instead of relying on whatever
# `env` an earlier cell happened to leave behind. It uses the last hyperparameter
# set, which is the configuration the training loop's final `env` had.
eval_env = gym.make(
    'Battleship-v0',
    board_size=(10, 10),
    reward_dictionary=hyperparameters[-1]['rewards'],
    episode_steps=hyperparameters[-1]['max_moves'],
)

# Evaluate every saved model twice (with and without action masking) and append
# the results to the log file.
with open("dqn-training.txt", "a") as f:
    for index, model_name in enumerate(models):
        model = DQN.load(os.path.join(models_dir, model_name))

        # 1000 games with action masking and 1000 without action masking
        action_masking_results = run_games(model, eval_env, num_games=NUM_GAMES, mask=True)
        non_action_masking_results = run_games(model, eval_env, num_games=NUM_GAMES, mask=False)

        # The file name is included because the position in `models` alone does
        # not identify which checkpoint was evaluated.
        print("-----------------------------")
        print(f"Model {index + 1} ({model_name}) results: ")
        print(f"Action masking results ({NUM_GAMES} games):")
        print(action_masking_results)
        print(f"Non-action masking results ({NUM_GAMES}):")
        print(non_action_masking_results)

        f.write("-----------------------------\n")
        f.write(f"Model {index + 1} ({model_name}) results:\n")
        f.write("\tAction masking results: \n")
        f.write("\t" + str(action_masking_results) + "\n")
        f.write("\tNon-action masking results: \n")
        f.write("\t" + str(non_action_masking_results) + "\n")
        f.flush()  # Ensure data is written after each model

In [None]:
def record_video(env, model, out_directory, fps=30):
    """Play one game with masked greedy actions and save it as an animation.

    Every frame shows the rendered board on the left and, on the right, the
    Q-value heatmap for the observation the agent will act on next. The first
    frame is captured right after ``env.reset()``; one more frame is added
    after each step until the episode is terminated or truncated.

    Args:
        env: Environment whose unwrapped form supports ``render(mode="image")``.
        model: Trained model queried via ``masked_predict`` / ``get_q_values``.
        out_directory: Passed to ``imageio.mimsave`` as the output location
            (despite the name, callers supply a file path).
        fps: Frame rate forwarded to ``imageio.mimsave``.
    """
    heatmap_title = "Q-Value Heatmap for next move"

    def capture_frame(observation):
        # Board first, then the heatmap for the same observation, joined side by side.
        board = env.unwrapped.render(mode="image")
        q_map = plot_q_value_heatmap(get_q_values(model, observation), title=heatmap_title, image=True)
        return np.concatenate((board, q_map), axis=1)

    obs, info = env.reset()
    frames = [capture_frame(obs)]

    finished = False
    while not finished:
        action, _ = masked_predict(model, obs)
        obs, reward, terminated, truncated, info = env.step(int(action))
        frames.append(capture_frame(obs))
        finished = terminated or truncated

    imageio.mimsave(out_directory, frames, fps=fps)

In [None]:
# Run to visualise, numbered from 1 like the `dqn-<n>` files written by the
# training loop. The checkpoint, the environment settings and the output file
# name all derive from this one number so they cannot drift apart (the video
# was previously written as dqn-14 although dqn-10 was loaded).
VIDEO_MODEL_NUMBER = 10
video_hyperparameters = hyperparameters[VIDEO_MODEL_NUMBER - 1]

model = DQN.load(f"./dqn_models-test1/dqn-{VIDEO_MODEL_NUMBER}")
env = gym.make(
    'Battleship-v0',
    board_size=(10, 10),
    reward_dictionary=video_hyperparameters['rewards'],
    episode_steps=video_hyperparameters['max_moves'],
)

# A low frame rate (half a frame per second) so each move can be read.
record_video(env, model, f"./dqn_models-test1/dqn-{VIDEO_MODEL_NUMBER}-video.gif", fps=0.5)