In [1]:
%%capture
!pip install --upgrade pip
!pip install gymnasium
!pip install pfrl

In [2]:
import pfrl
import torch
from torch import nn
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import tensorflow as tf
import random

In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
assert x_train.shape == (50000, 32, 32, 3)
assert x_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
     8192/170498071 [..............................] - ETA: 0s

  and should_run_async(code)




In [4]:
class CifarEnv(gym.Env):

    def __init__(self,):

        self.observation_space = spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)

        self.action_space = spaces.Discrete(10)
        self.expected_action = 0

        # assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = None
        self.x, self.y = (x_train, y_train)
        self.random = True
        self.images_per_episode = 1
        self.dataset_idx = 0

    def _get_info(self):
      return 0

    def step(self, action):
        done = False
        reward = int(action == self.expected_action)

        obs = self._next_obs()

        self.step_count += 1
        if self.step_count >= self.images_per_episode:
            done = True

        return obs, reward, done, {}

    def reset(self, seed=None, options=None):
        self.step_count = 0
        obs = self._next_obs()
        return obs

    def _next_obs(self):
        if self.random:
            next_obs_idx = random.randint(0, len(self.x) - 1)
            self.expected_action = int(self.y[next_obs_idx])
            obs = self.x[next_obs_idx]

        else:
            obs = self.x[self.dataset_idx]
            self.expected_action = int(self.y[self.dataset_idx])

            self.dataset_idx += 1
            if self.dataset_idx >= len(self.x):
                raise StopIteration()

        return obs

class CifarEnvTest(gym.Env):

    def __init__(self,):

        self.observation_space = spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)

        self.action_space = spaces.Discrete(10)
        self.expected_action = 0

        # assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = None
        self.x, self.y = (x_test, y_test)
        self.random = True
        self.images_per_episode = 10000
        self.dataset_idx = 0

    def _get_info(self):
      return 0

    def step(self, action):
        done = False
        reward = int(action == self.expected_action)

        obs = self._next_obs()

        self.step_count += 1
        if self.step_count >= self.images_per_episode:
            done = True

        return obs, reward, done, {}

    def reset(self, seed=None, options=None):
        self.step_count = 0
        obs = self._next_obs()
        return obs

    def _next_obs(self):
        if self.random:
            next_obs_idx = random.randint(0, len(self.x) - 1)
            self.expected_action = int(self.y[next_obs_idx])
            obs = self.x[next_obs_idx]

        else:
            obs = self.x[self.dataset_idx]
            self.expected_action = int(self.y[self.dataset_idx])

            self.dataset_idx += 1
            if self.dataset_idx >= len(self.x):
                raise StopIteration()

        return obs

In [5]:
env = CifarEnv()
test_env = CifarEnvTest()

In [6]:
class QFunction(torch.nn.Module):

    def __init__(self,):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 64 x 16 x 16
            nn.BatchNorm2d(64),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 128 x 8 x 8
            nn.BatchNorm2d(128),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 256 x 4 x 4
            nn.BatchNorm2d(256),

            nn.Flatten(),
            nn.Linear(256*4*4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 10))

    def forward(self, x):
        h = self.network(x)
        return pfrl.action_value.DiscreteActionValue(h)

obs_size = env.observation_space.low.size
n_actions = env.action_space.n
q_func = QFunction()

In [7]:
# Use Adam to optimize q_func. eps=1e-2 is for stability.
optimizer = torch.optim.Adam(q_func.parameters(), eps=1e-2)
# Set the discount factor that discounts future rewards.
gamma = 1

# Use epsilon-greedy for exploration
explorer = pfrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.3, random_action_func=env.action_space.sample)
explorer_2 = pfrl.explorers.LinearDecayEpsilonGreedy(start_epsilon=0.9, end_epsilon=0.01, decay_steps=50000, random_action_func = env.action_space.sample)
# DQN uses Experience Replay.
# Specify a replay buffer and its capacity.
replay_buffer = pfrl.replay_buffers.PrioritizedReplayBuffer(capacity=10 ** 6)

# Since observations from CartPole-v0 is numpy.float64 while
# As PyTorch only accepts numpy.float32 by default, specify
# a converter as a feature extractor function phi.
phi = lambda x:np.resize(x,(3,32,32)).astype(np.float32, copy=False)

# Set the device id to use GPU. To use CPU only, set it to -1.
gpu = 0

# Now create an agent that will interact with the environment.
agent = pfrl.agents.DoubleDQN(
    q_func,
    optimizer,
    replay_buffer,
    gamma,
    explorer_2,
    replay_start_size=10000,
    update_interval=1,
    target_update_interval=10000,
    phi = phi,
    gpu=gpu,
)

In [8]:
# Set up the logger to print info messages for understandability.
import logging
import sys
import time
start_time = time.time()
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')
eval_ep = 10000
pfrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=100000,           # Train the agent for 2000 steps
    eval_n_steps=None,       # We evaluate for episodes, not time
    eval_n_episodes=1,       # 10 episodes are sampled for each evaluation
    train_max_episode_len=1,  # Maximum length of each episode
    eval_max_episode_len=10000,
    eval_interval=20000,   # Evaluate the agent after every 1000 steps
    outdir='result',      # Save everything to 'result' directory
    eval_env = test_env,
)

(<pfrl.agents.double_dqn.DoubleDQN at 0x79286776be50>,
 [{'average_q': 0.21237372,
   'average_loss': 0.0036870955361519007,
   'cumulative_steps': 20000,
   'n_updates': 10001,
   'rlen': 20000,
   'eval_score': 3700.0},
  {'average_q': 0.34901989,
   'average_loss': 0.003966008767019957,
   'cumulative_steps': 40000,
   'n_updates': 30001,
   'rlen': 40000,
   'eval_score': 4345.0},
  {'average_q': 0.44182006,
   'average_loss': 0.003636181562906131,
   'cumulative_steps': 60000,
   'n_updates': 50001,
   'rlen': 60000,
   'eval_score': 4661.0},
  {'average_q': 0.5190564,
   'average_loss': 0.0026982996793230994,
   'cumulative_steps': 80000,
   'n_updates': 70001,
   'rlen': 80000,
   'eval_score': 4890.0},
  {'average_q': 0.5582751,
   'average_loss': 0.001845650130417198,
   'cumulative_steps': 100000,
   'n_updates': 90001,
   'rlen': 100000,
   'eval_score': 4983.0}])

In [9]:
print("DQN Training Time:", time.time() - start_time)

DQN Training Time: 1795.4009108543396


In [14]:
!zip -r /content/result.zip /content/result

  and should_run_async(code)


  adding: content/result/ (stored 0%)
  adding: content/result/scores.txt (deflated 53%)
  adding: content/result/100000_finish/ (stored 0%)
  adding: content/result/100000_finish/optimizer.pt (deflated 10%)
  adding: content/result/100000_finish/model.pt (deflated 8%)
  adding: content/result/100000_finish/target_model.pt (deflated 8%)
  adding: content/result/best/ (stored 0%)
  adding: content/result/best/optimizer.pt (deflated 10%)
  adding: content/result/best/model.pt (deflated 8%)
  adding: content/result/best/target_model.pt (deflated 8%)


In [15]:
from google.colab import files
files.download("/content/result.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>