In [1]:
%%capture
!pip install --upgrade pip
!pip install gymnasium
!pip install pfrl

In [2]:
import pfrl
import torch
from torch import nn
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import tensorflow as tf
import random

In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
assert x_train.shape == (50000, 32, 32, 3)
assert x_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


  and should_run_async(code)




In [4]:
class CifarEnv(gym.Env):

    def __init__(self,):

        self.observation_space = spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)

        self.action_space = spaces.Discrete(10)
        self.expected_action = 0

        # assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = None
        self.x, self.y = (x_train, y_train)
        self.random = True
        self.images_per_episode = 1
        self.dataset_idx = 0

    def _get_info(self):
      return 0

    def step(self, action):
        done = False
        reward = int(action == self.expected_action)

        obs = self._next_obs()

        self.step_count += 1
        if self.step_count >= self.images_per_episode:
            done = True

        return obs, reward, done, {}

    def reset(self, seed=None, options=None):
        self.step_count = 0
        obs = self._next_obs()
        return obs

    def _next_obs(self):
        if self.random:
            next_obs_idx = random.randint(0, len(self.x) - 1)
            self.expected_action = int(self.y[next_obs_idx])
            obs = self.x[next_obs_idx]

        else:
            obs = self.x[self.dataset_idx]
            self.expected_action = int(self.y[self.dataset_idx])

            self.dataset_idx += 1
            if self.dataset_idx >= len(self.x):
                raise StopIteration()

        return obs

class CifarEnvTest(gym.Env):

    def __init__(self,):

        self.observation_space = spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)

        self.action_space = spaces.Discrete(10)
        self.expected_action = 0

        # assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = None
        self.x, self.y = (x_test, y_test)
        self.random = True
        self.images_per_episode = 10000
        self.dataset_idx = 0

    def _get_info(self):
      return 0

    def step(self, action):
        done = False
        reward = int(action == self.expected_action)

        obs = self._next_obs()

        self.step_count += 1
        if self.step_count >= self.images_per_episode:
            done = True

        return obs, reward, done, {}

    def reset(self, seed=None, options=None):
        self.step_count = 0
        obs = self._next_obs()
        return obs

    def _next_obs(self):
        if self.random:
            next_obs_idx = random.randint(0, len(self.x) - 1)
            self.expected_action = int(self.y[next_obs_idx])
            obs = self.x[next_obs_idx]

        else:
            obs = self.x[self.dataset_idx]
            self.expected_action = int(self.y[self.dataset_idx])

            self.dataset_idx += 1
            if self.dataset_idx >= len(self.x):
                raise StopIteration()

        return obs

In [5]:
env = CifarEnv()
test_env = CifarEnvTest()

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistributionalDuelingDQN(nn.Module, pfrl.q_function.StateQFunction):
    """Distributional dueling fully-connected Q-function with discrete actions."""

    def __init__(
        self,
        n_actions,
        n_atoms=41,
        v_min = 0,
        v_max = 1, 
        n_input_channels=3,
        activation=torch.relu,
        bias=0.1,
    ):
        assert n_atoms >= 2
        assert v_min < v_max

        self.n_actions = n_actions
        self.n_input_channels = n_input_channels
        self.activation = activation
        self.n_atoms = n_atoms

        super().__init__()
        self.z_values = torch.linspace(v_min, v_max, n_atoms, dtype=torch.float32)

        self.network = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 64 x 16 x 16
            nn.BatchNorm2d(64),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 128 x 8 x 8
            nn.BatchNorm2d(128),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 256 x 4 x 4
            nn.BatchNorm2d(256),
            )

        self.main_stream = nn.Linear(4096, 1024)
        self.a_stream = nn.Linear(512, n_actions * n_atoms)
        self.v_stream = nn.Linear(512, n_atoms)

    def forward(self, x):

        h = self.network(x)

        # Advantage
        batch_size = x.shape[0]

        h = self.activation(self.main_stream(h.view(batch_size, -1)))
        h_a, h_v = torch.chunk(h, 2, dim=1)
        ya = self.a_stream(h_a).reshape((batch_size, self.n_actions, self.n_atoms))

        mean = ya.sum(dim=1, keepdim=True) / self.n_actions

        ya, mean = torch.broadcast_tensors(ya, mean)
        ya -= mean

        # State value
        ys = self.v_stream(h_v).reshape((batch_size, 1, self.n_atoms))
        ya, ys = torch.broadcast_tensors(ya, ys)
        q = F.softmax(ya + ys, dim=2)

        self.z_values = self.z_values.to(x.device)
        return pfrl.action_value.DistributionalDiscreteActionValue(q, self.z_values)

obs_size = env.observation_space.low.size
n_actions = env.action_space.n
q_func = DistributionalDuelingDQN(10)

In [17]:
# Use Adam to optimize q_func. eps=1e-2 is for stability.
optimizer = torch.optim.Adam(q_func.parameters(), eps=1e-2)
# Set the discount factor that discounts future rewards.
gamma = 1

# Use epsilon-greedy for exploration
explorer = pfrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.3, random_action_func=env.action_space.sample)
explorer_2 = pfrl.explorers.LinearDecayEpsilonGreedy(start_epsilon=0.9, end_epsilon=0.01, decay_steps=50000, random_action_func = env.action_space.sample)
# DQN uses Experience Replay.
# Specify a replay buffer and its capacity.
replay_buffer = pfrl.replay_buffers.PrioritizedReplayBuffer(capacity=10 ** 6)

# Since observations from CartPole-v0 is numpy.float64 while
# As PyTorch only accepts numpy.float32 by default, specify
# a converter as a feature extractor function phi.
phi = lambda x:np.resize(x/255.0,(3,32,32)).astype(np.float32, copy=False)

# Set the device id to use GPU. To use CPU only, set it to -1.
gpu = 0

agent = pfrl.agents.CategoricalDoubleDQN(
        q_func,
        optimizer,
        replay_buffer,
        gpu= gpu,
        gamma=gamma,
        explorer=explorer_2,
        minibatch_size=32,
        replay_start_size=10000,
        target_update_interval=10000,
        update_interval=1,
        batch_accumulator="mean",
        phi=phi,
    )

In [18]:
# Set up the logger to print info messages for understandability.
import logging
import sys
import time
start_time = time.time()
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')
eval_ep = 10000
pfrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=100000,           # Train the agent for 2000 steps
    eval_n_steps=None,       # We evaluate for episodes, not time
    eval_n_episodes=1,       # 10 episodes are sampled for each evaluation
    train_max_episode_len=1,  # Maximum length of each episode
    eval_max_episode_len=10000,
    eval_interval=20000,   # Evaluate the agent after every 1000 steps
    outdir='result',      # Save everything to 'result' directory
    eval_env = test_env,
)

(<pfrl.agents.categorical_double_dqn.CategoricalDoubleDQN at 0x794f099066e0>,
 [{'average_q': 0.18282156,
   'average_loss': 0.04435203477100003,
   'cumulative_steps': 20000,
   'n_updates': 10001,
   'rlen': 20000,
   'eval_score': 3327.0},
  {'average_q': 0.2665403,
   'average_loss': 0.03305000337830279,
   'cumulative_steps': 40000,
   'n_updates': 30001,
   'rlen': 40000,
   'eval_score': 3878.0},
  {'average_q': 0.33086088,
   'average_loss': 0.032989617517450824,
   'cumulative_steps': 60000,
   'n_updates': 50001,
   'rlen': 60000,
   'eval_score': 4119.0},
  {'average_q': 0.40977734,
   'average_loss': 0.03238331888627727,
   'cumulative_steps': 80000,
   'n_updates': 70001,
   'rlen': 80000,
   'eval_score': 4268.0},
  {'average_q': 0.41221794,
   'average_loss': 0.02787562897079624,
   'cumulative_steps': 100000,
   'n_updates': 90001,
   'rlen': 100000,
   'eval_score': 4347.0}])

In [19]:
print("DQN Training Time:", time.time() - start_time)

DQN Training Time: 1967.3128566741943
