In [None]:
%pip install gym
%pip install pandas
%pip install stable_baselines3
%pip install shimmy

In [None]:
import gym
import gym_battleship
import pandas as pd
from stable_baselines3 import DQN

In [None]:
# Create a Battleship environment on a 10x10 board and start an episode;
# the value returned by reset() is displayed as this cell's output.
env = gym.make(
    "Battleship-v0",
    board_size=(10, 10),
)
env.reset()

In [None]:
from stable_baselines3.dqn import CnnPolicy
import stable_baselines3


In [None]:
import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class BattleshipCNN(BaseFeaturesExtractor):
    """Two-layer CNN feature extractor for Battleship board observations.

    Intended as ``features_extractor_class`` for an SB3 policy. Observations are
    treated as channel-first tensors of shape ``(C, H, W)``. The board height
    and width are taken from ``observation_space`` rather than assumed, so the
    extractor matches whatever ``board_size`` the environment was created with.

    Args:
        observation_space: Observation space whose ``shape`` is ``(C, H, W)``.
        features_dim: Length of the feature vector produced by ``forward``.

    Raises:
        ValueError: If ``observation_space.shape`` is not three-dimensional.
    """

    def __init__(self, observation_space, features_dim=128):
        super().__init__(observation_space, features_dim)

        obs_shape = tuple(observation_space.shape)
        if len(obs_shape) != 3:
            raise ValueError(
                f"BattleshipCNN expects a (C, H, W) observation space, got shape {obs_shape}"
            )
        n_channels = obs_shape[0]

        # 3x3 convolutions with stride 1 and padding 1 preserve H and W.
        self.cnn = nn.Sequential(
            nn.Conv2d(n_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        # Size the linear layer by pushing one dummy observation of the real
        # (C, H, W) shape through the conv stack, instead of assuming a 10x10 board.
        with th.no_grad():
            sample = th.zeros((1, *obs_shape))
            conv_output_dim = self.cnn(sample).shape[1]

        self.linear = nn.Sequential(
            nn.Linear(conv_output_dim, features_dim),
            nn.ReLU()
        )

    def forward(self, obs):
        """Map a batch of observations ``(B, C, H, W)`` to ``(B, features_dim)`` features."""
        return self.linear(self.cnn(obs))


In [None]:
# Use the custom BattleshipCNN in place of the policy's default feature extractor.
policy_kwargs = dict(
    features_extractor_class=BattleshipCNN,
    features_extractor_kwargs=dict(features_dim=128),
)

# buffer_size matches the 100_000-step training budget passed to model.learn()
# in the next cell. SB3's DQN default (1_000_000 transitions) would preallocate
# replay storage ten times larger than that single learn() call can ever fill.
# NOTE(review): no seed is passed. SB3's `seed=` is forwarded to the env's
# seed/reset API, so verify this env supports it before adding one.
model = DQN(
    "CnnPolicy",
    env,
    policy_kwargs=policy_kwargs,
    learning_rate=1e-4,
    buffer_size=100_000,
    verbose=1,
)


In [None]:
# Train the agent for 100k environment steps; with verbose=1 SB3 prints a
# progress table every `log_interval` episodes.
model.learn(
    total_timesteps=100_000,
    log_interval=10,
)

In [None]:
# One value per line: observation space, its leading (channel) dimension,
# and the action space.
print(
    env.observation_space,
    env.observation_space.shape[0],
    env.action_space,
    sep="\n",
)

In [None]:
# NOTE(review): despite the names, these are sizes, not Space objects.
# ACTION_SPACE is the number of discrete actions; OBSERVATION_SPACE is the
# leading (channel) dimension of the observation shape.
ACTION_SPACE, OBSERVATION_SPACE = env.action_space.n, env.observation_space.shape[0]

In [None]:
# model.learn() stepped this same env object, so it may be left in the middle
# of an episode from training. Start a fresh episode before playing by hand so
# the manual shots below begin from a clean board.
env.reset()
print(env.board_generated)

In [None]:
# Render the current state of the board.
env.render()

In [None]:
# Fire the first shot at coordinate (0, 0); the step result is the cell output.
action = 0, 0
env.step(action)

In [None]:
# Render the board after the shot at (0, 0).
env.render()

In [None]:
# Fire at (1, 0) (step result discarded) and render the board.
env.step((1, 0))
env.render()

In [None]:
# Fire at (0, 1) (step result discarded) and render the board.
env.step((0, 1))
env.render()

In [None]:
# Fire at (0, 2) and then (0, 3) (step results discarded), then render the board.
env.step((0, 2))
env.step((0, 3))
env.render()

In [None]:
# Fire at (0, 4) (step result discarded) and render the board.
env.step((0, 4))
env.render()

In [None]:
# Debugging aid: list the attributes available on the env object.
# NOTE(review): consider removing from the final notebook.
print(dir(env))

In [None]:
# Fire at (0, 5) and show each element of the step result on its own line.
# NOTE(review): unpacks four values (obs, reward, done, info). Envs following
# the gym>=0.26 step API return five, so verify for this env.
observation, reward, done, info = env.step((0, 5))
print(observation, reward, done, info, sep="\n")

In [None]:
# Render the board after the shot at (0, 5).
env.render()

In [None]:
# Print the env's `board` attribute, for comparison with `board_generated`.
print(env.board)

In [None]:
# Fire at (9, 5) and show each element of the step result on its own line.
# NOTE(review): same four-value step unpacking as the (0, 5) shot above.
observation, reward, done, info = env.step((9, 5))
print(observation, reward, done, info, sep="\n")