In [1]:
from sklearn.datasets import fetch_openml

# Fetch MNIST from OpenML as plain NumPy arrays (as_frame=False):
# mnist.data holds the flattened images, mnist.target the digit labels.
mnist = fetch_openml(name="mnist_784", version=1, as_frame=False)

In [2]:
# Shared imports for the cells below. The 90% / 10% train/test split itself is
# done inside MNISTEnv.__init__ (train_test_split with test_size=0.1), not here.
import numpy as np
from sklearn.model_selection import train_test_split

# Sanity check: every sample is a flattened 28x28 image, matching the
# observation space declared by MNISTEnv.
assert mnist.data.shape[1] == 28 * 28, f"unexpected sample shape {mnist.data.shape}"
print(mnist.data[0].shape)

(784,)


In [3]:
import gymnasium as gym
from typing import Literal, Any

from numpy.typing import NDArray

# Custom Gymnasium environment that frames MNIST as a binary classifier for a
# single digit (3 by default): action 1 means "this is the target digit",
# action 0 means "it is not". Reward is +1 for a correct call, -1 otherwise.
class MNISTEnv(gym.Env):
    """One ordered pass over an MNIST split, answering "is this the target digit?".

    Observations are flattened 28x28 images: float32 vectors of length 784 with
    pixel values declared in [0, 255]. Each ``step`` grades the action against the
    image the agent was just shown, then advances to the next image. The episode
    terminates once the final image of the split has been graded; the terminal
    observation is an all-zero vector of the same shape.

    Args:
        train_type: ``"train"`` iterates the training split; any other value
            (normally ``"test"``) iterates the held-out split.
        seed: ``random_state`` passed to ``train_test_split``. Train and test
            environments must share the same seed for their splits to be disjoint.
        target_digit: The digit for which action 1 is the correct answer.
        test_size: Fraction of the data held out as the test split.

    Note:
        Data is read from the module-level ``mnist`` object loaded earlier in the
        notebook; the split is recomputed for every instance.
    """

    index: int  # position (within the active split) of the image currently shown
    train_type: Literal["train", "test"]
    target_digit: int
    X_train: NDArray[np.float32]
    X_test: NDArray[np.float32]
    y_train: NDArray[np.int64]
    y_test: NDArray[np.int64]

    def __init__(
        self,
        train_type: Literal["train", "test"],
        seed: int = 1337,
        target_digit: int = 3,
        test_size: float = 0.1,
    ):
        super().__init__()

        # each observation is a flattened 28x28 image
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(784,), dtype=np.float32)

        # binary decision: 1 = "is the target digit", 0 = "is not"
        self.action_space = gym.spaces.Discrete(2)

        X_train, X_test, y_train, y_test = train_test_split(
            mnist.data, mnist.target, test_size=test_size, random_state=seed
        )
        self.X_train = X_train.astype(np.float32)
        self.X_test = X_test.astype(np.float32)
        # cast labels to integers so they compare directly against target_digit
        self.y_train = y_train.astype(np.int64)
        self.y_test = y_test.astype(np.int64)
        self.train_type = train_type
        self.target_digit = target_digit

        self.index = 0

    def _active_split(self) -> tuple[NDArray[np.float32], NDArray[np.int64]]:
        """Return ``(images, labels)`` for the split selected by ``train_type``."""
        if self.train_type == "train":
            return self.X_train, self.y_train
        return self.X_test, self.y_test

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[NDArray[np.float32], dict[str, Any]]:
        """Rewind to the first image of the active split.

        Returns:
            ``(observation, info)``: the first image and an empty info dict.
        """
        super().reset(seed=seed)
        self.index = 0
        images, _ = self._active_split()
        return images[self.index], {}

    def step(self, action: int) -> tuple[NDArray[np.float32], float, bool, bool, dict[str, Any]]:
        """Grade ``action`` on the current image, then advance to the next one.

        Returns:
            ``(next_obs, reward, terminated, truncated, info)``. ``next_obs`` is the
            next image of the split, or an all-zero vector once the split is
            exhausted. ``reward`` is +1.0 for a correct answer and -1.0 otherwise.
            ``terminated`` is True after the last image has been graded.
            ``truncated`` is always False.
        """
        images, labels = self._active_split()

        # Grade against the image the agent actually observed (the current index).
        is_target = labels[self.index] == self.target_digit
        if (action == 1 and is_target) or (action == 0 and not is_target):
            reward = 1.0
        else:
            reward = -1.0

        self.index += 1
        # Terminate only after the final sample (index len-1) has been graded.
        terminated = self.index >= len(images)
        if terminated:
            # Keep the terminal observation's shape/dtype consistent with observation_space.
            next_state = np.zeros(self.observation_space.shape, dtype=np.float32)
        else:
            # The next observation is the next image; the next action is graded on it.
            next_state = images[self.index]

        return next_state, reward, terminated, False, {}
        

In [9]:
from stable_baselines3 import DQN

# Number of passes over the training split. The env walks the training images in
# order, one image per timestep, so total timesteps scale with the split size.
TRAIN_EPOCHS = 10
# Seed for DQN's own randomness (network init, exploration, replay sampling).
DQN_SEED = 1337

# The env's train/test split seed is left at its default so it matches the
# test env built in the evaluation cell (disjoint train/test data).
train_env = MNISTEnv(train_type="train")

model = DQN("MlpPolicy", train_env, verbose=1, seed=DQN_SEED)

model.learn(total_timesteps=TRAIN_EPOCHS * len(train_env.X_train), progress_bar=True);

<stable_baselines3.dqn.dqn.DQN at 0x7b94581ae390>

In [10]:
# Persist the trained agent so the evaluation cell can reload it via DQN.load.
model.save(path="dqn-mnist")

In [12]:
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

# Monitor records per-episode stats, which evaluate_policy expects (it warns otherwise).
test_env = Monitor(MNISTEnv(train_type="test"))
model = DQN.load("dqn-mnist")

# One episode is a single fixed-order pass over the test split and evaluate_policy
# acts deterministically by default, so further episodes would only repeat it.
episode_rewards, episode_lengths = evaluate_policy(
    model, test_env, n_eval_episodes=1, return_episode_rewards=True
)
mean_reward = float(np.mean(episode_rewards))
std_reward = float(np.std(episode_rewards))

# Each step scores +1 (correct) or -1 (wrong), so reward = correct - wrong and
# length = correct + wrong; hence correct = (reward + length) / 2.
total_steps = sum(episode_lengths)
accuracy = (sum(episode_rewards) + total_steps) / (2 * total_steps)

print(f"Mean reward: {mean_reward} +/- {std_reward}")
print(f"Test accuracy: {accuracy:.4f}")