In [1]:
%%capture
!pip install --upgrade pip
!pip install gymnasium
!pip install pfrl

In [2]:
import pfrl
import torch
from torch import nn
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import tensorflow as tf
import random

In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
assert x_train.shape == (50000, 32, 32, 3)
assert x_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)

  and should_run_async(code)


In [4]:
class CifarEnv(gym.Env):

    def __init__(self,):

        self.observation_space = spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)

        self.action_space = spaces.Discrete(10)
        self.expected_action = 0

        # assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = None
        self.x, self.y = (x_train, y_train)
        self.random = True
        self.images_per_episode = 1
        self.dataset_idx = 0

    def _get_info(self):
      return 0

    def step(self, action):
        done = False
        reward = int(action == self.expected_action)

        obs = self._next_obs()

        self.step_count += 1
        if self.step_count >= self.images_per_episode:
            done = True

        return obs, reward, done, {}

    def reset(self, seed=None, options=None):
        self.step_count = 0
        obs = self._next_obs()
        return obs

    def _next_obs(self):
        if self.random:
            next_obs_idx = random.randint(0, len(self.x) - 1)
            self.expected_action = int(self.y[next_obs_idx])
            obs = self.x[next_obs_idx]

        else:
            obs = self.x[self.dataset_idx]
            self.expected_action = int(self.y[self.dataset_idx])

            self.dataset_idx += 1
            if self.dataset_idx >= len(self.x):
                raise StopIteration()

        return obs

class CifarEnvTest(gym.Env):

    def __init__(self,):

        self.observation_space = spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)

        self.action_space = spaces.Discrete(10)
        self.expected_action = 0

        # assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = None
        self.x, self.y = (x_test, y_test)
        self.random = True
        self.images_per_episode = 10000
        self.dataset_idx = 0

    def _get_info(self):
      return 0

    def step(self, action):
        done = False
        reward = int(action == self.expected_action)

        obs = self._next_obs()

        self.step_count += 1
        if self.step_count >= self.images_per_episode:
            done = True

        return obs, reward, done, {}

    def reset(self, seed=None, options=None):
        self.step_count = 0
        obs = self._next_obs()
        return obs

    def _next_obs(self):
        if self.random:
            next_obs_idx = random.randint(0, len(self.x) - 1)
            self.expected_action = int(self.y[next_obs_idx])
            obs = self.x[next_obs_idx]

        else:
            obs = self.x[self.dataset_idx]
            self.expected_action = int(self.y[self.dataset_idx])

            self.dataset_idx += 1
            if self.dataset_idx >= len(self.x):
                raise StopIteration()

        return obs

In [5]:
env = CifarEnv()
test_env = CifarEnvTest()

In [6]:
class Backbone(torch.nn.Module):

    def __init__(self,input_shape=3,output_shape=1):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(input_shape, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 64 x 16 x 16
            nn.BatchNorm2d(64),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 128 x 8 x 8
            nn.BatchNorm2d(128),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 256 x 4 x 4
            nn.BatchNorm2d(256),

            nn.Flatten(),
            nn.Linear(256*4*4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512,output_shape)
            )

    def forward(self, x):
        h = self.network(x)
        return h
policy_nn = nn.Sequential(
            Backbone(3,10),
            pfrl.policies.SoftmaxCategoricalHead())

value_nn = Backbone(3,1)

In [7]:
optimizer = torch.optim.Adam(value_nn.parameters(), eps=1e-2)
gamma = 1

# Use epsilon-greedy for exploration
explorer = pfrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.3, random_action_func=env.action_space.sample)
explorer_2 = pfrl.explorers.LinearDecayEpsilonGreedy(start_epsilon=0.9, end_epsilon=0.01, decay_steps=50000, random_action_func = env.action_space.sample)
# DQN uses Experience Replay.
# Specify a replay buffer and its capacity.
replay_buffer = pfrl.replay_buffers.PrioritizedReplayBuffer(capacity=10 ** 5)
phi = lambda x:np.resize(x/255.0,(3,32,32)).astype(np.float32, copy=False)
gpu = 0

agent = pfrl.agents.TRPO(
        policy=policy_nn,
        vf=value_nn,
        vf_optimizer=optimizer,
        gpu=gpu,
        update_interval=1000,
        max_kl=0.01,
        conjugate_gradient_max_iter=20,
        conjugate_gradient_damping=1e-1,
        gamma=gamma,
        lambd=0.97,
        vf_epochs=1,
        vf_batch_size=32,
        entropy_coef=0,
        phi=phi,
    )

In [8]:
# Set up the logger to print info messages for understandability.
import logging
import sys
import time
start_time = time.time()
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')
eval_ep = 10000
pfrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=100000,           # Train the agent for 2000 steps
    eval_n_steps=None,       # We evaluate for episodes, not time
    eval_n_episodes=1,       # 10 episodes are sampled for each evaluation
    train_max_episode_len=1,  # Maximum length of each episode
    eval_max_episode_len=10000,
    eval_interval=20000,   # Evaluate the agent after every 1000 steps
    outdir='result',      # Save everything to 'result' directory
    eval_env = test_env,
)

(<pfrl.agents.trpo.TRPO at 0x7d7edc993b80>,
 [{'average_value': 0.08763197,
   'average_entropy': 2.3000557,
   'average_kl': 0.004696746822446585,
   'average_policy_step_size': 0.05,
   'explained_variance': -0.012148978120053444,
   'eval_score': 1021.0},
  {'average_value': 0.13775407,
   'average_entropy': 2.2525432,
   'average_kl': 0.006308911135420203,
   'average_policy_step_size': 0.45,
   'explained_variance': -0.0011874471062405245,
   'eval_score': 1351.0},
  {'average_value': 0.17597134,
   'average_entropy': 2.0902784,
   'average_kl': 0.006900565637471645,
   'average_policy_step_size': 0.5333333333333333,
   'explained_variance': -0.006330721639808967,
   'eval_score': 1813.0},
  {'average_value': 0.2241693,
   'average_entropy': 1.9300092,
   'average_kl': 0.006863759116044846,
   'average_policy_step_size': 0.534375,
   'explained_variance': 0.032589801761501724,
   'eval_score': 2318.0},
  {'average_value': 0.26609084,
   'average_entropy': 1.7850924,
   'average_kl

In [9]:
print("DQN Training Time:", time.time() - start_time)

DQN Training Time: 2647.733216524124
