<a href="https://colab.research.google.com/github/mooithub/pyemotion_rl/blob/master/my_examples/kms_examples/DeepRLBook/Chapter06/my_dqn_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DQN

In [1]:
!git clone https://github.com/psygrammer/pyemotion_rl.git

Cloning into 'pyemotion_rl'...
remote: Enumerating objects: 336, done.[K
remote: Counting objects: 100% (336/336), done.[K
remote: Compressing objects: 100% (265/265), done.[K
remote: Total 336 (delta 125), reused 132 (delta 26), pack-reused 0[K
Receiving objects: 100% (336/336), 16.59 MiB | 6.62 MiB/s, done.
Resolving deltas: 100% (125/125), done.


In [32]:
import cv2
import gym
import gym.spaces
import numpy as np
import collections


class FireResetEnv(gym.Wrapper):
    """Wrapper that performs the start-up actions on every reset.

    After resetting the wrapped environment, ``reset`` steps it once with
    action 1 (asserted to be ``'FIRE'`` at construction time) and once with
    action 2, and returns the observation that follows.
    """

    def __init__(self, env=None):
        """For environments where the user need to press FIRE for the game to start."""
        super(FireResetEnv, self).__init__(env)
        action_meanings = env.unwrapped.get_action_meanings()
        # Check the length first so a too-short action list fails the assert
        # instead of raising IndexError on the [1] lookup below.
        assert len(action_meanings) >= 3
        assert action_meanings[1] == 'FIRE'

    def step(self, action):
        """Forward the action to the wrapped environment unchanged."""
        return self.env.step(action)

    def reset(self):
        """Reset the wrapped env, step with actions 1 and 2, return the observation.

        If either of those steps ends the episode the environment is reset
        again. When that happens after the second step, the observation of the
        fresh episode is returned rather than the last frame of the finished one.
        """
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            # The previous `obs` belongs to the finished episode; replace it
            # with the observation produced by this reset.
            obs = self.env.reset()
        return obs


class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action `skip` times and max-pool the newest raw frames.

    Rewards from the repeated steps are summed. The returned observation is
    the element-wise maximum over the (at most two) frames currently held in
    the internal buffer.
    """

    def __init__(self, env=None, skip=4):
        """Return only every `skip`-th frame"""
        super(MaxAndSkipEnv, self).__init__(env)
        self._skip = skip
        # Holds at most the two latest raw observations; older ones fall out.
        self._obs_buffer = collections.deque(maxlen=2)

    def step(self, action):
        """Apply `action` up to `skip` times, stopping early when done."""
        done = None
        reward_sum = 0.0
        for _ in range(self._skip):
            frame, step_reward, done, info = self.env.step(action)
            self._obs_buffer.append(frame)
            reward_sum += step_reward
            if done:
                break
        # Element-wise max over the buffered frames.
        pooled = np.stack(self._obs_buffer).max(axis=0)
        return pooled, reward_sum, done, info

    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        first_frame = self.env.reset()
        self._obs_buffer.append(first_frame)
        return first_frame


class ProcessFrame84(gym.ObservationWrapper):
    """Convert raw RGB frames into 84x84 single-channel uint8 images."""

    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        """Run every observation through the static `process` pipeline."""
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Grayscale, downscale and crop one frame.

        Accepts a frame holding 210*160*3 or 250*160*3 values; any other size
        trips the assertion. Returns an array of shape (84, 84, 1), dtype uint8.
        """
        # Only two source heights are supported; the width is always 160.
        if frame.size == 210 * 160 * 3:
            rows = 210
        elif frame.size == 250 * 160 * 3:
            rows = 250
        else:
            assert False, "Unknown resolution."
        rgb = np.reshape(frame, [rows, 160, 3]).astype(np.float32)

        # Weighted sum of the three colour channels.
        red, green, blue = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
        gray = red * 0.299 + green * 0.587 + blue * 0.114

        # cv2 takes (width, height): the result has 110 rows and 84 columns.
        shrunk = cv2.resize(gray, (84, 110), interpolation=cv2.INTER_AREA)

        # Keep rows 18..101 so the image becomes 84x84, then add a channel axis.
        cropped = shrunk[18:102, :]
        with_channel = np.reshape(cropped, [84, 84, 1])
        return with_channel.astype(np.uint8)


class ImageToPyTorch(gym.ObservationWrapper):
    """Move the last (channel) axis of image observations to the front."""

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        hwc = self.observation_space.shape
        # Same dimensions, with the last one moved to position 0.
        chw = (hwc[-1], hwc[0], hwc[1])
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=chw, dtype=np.float32)

    def observation(self, observation):
        """Return `observation` with axis 2 moved to axis 0."""
        return np.moveaxis(observation, source=2, destination=0)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Convert observations to float32 and divide them by 255."""

    def observation(self, obs):
        """Return a float32 copy of `obs` scaled by 1/255."""
        as_float = np.array(obs).astype(np.float32)
        return as_float / 255.0


class BufferWrapper(gym.ObservationWrapper):
    """Keep a rolling stack of recent observations along axis 0.

    The observation space bounds are the wrapped env's bounds repeated
    `n_steps` times along axis 0. The buffer is created in `reset`.
    """

    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        inner_space = env.observation_space
        stacked_low = inner_space.low.repeat(n_steps, axis=0)
        stacked_high = inner_space.high.repeat(n_steps, axis=0)
        self.observation_space = gym.spaces.Box(
            stacked_low, stacked_high, dtype=dtype)

    def reset(self):
        """Zero the buffer, reset the inner env and push its first observation."""
        template = self.observation_space.low
        self.buffer = np.zeros_like(template, dtype=self.dtype)
        first_obs = self.env.reset()
        return self.observation(first_obs)

    def observation(self, observation):
        """Shift the buffer by one slot and store `observation` in the last one.

        The internal buffer object itself is returned, not a copy.
        """
        # Drop the oldest entry by sliding everything one slot towards index 0.
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer


def make_env(env_name):
    """Create the gym environment `env_name` wrapped in the preprocessing stack.

    Wrappers are applied innermost-first: MaxAndSkipEnv, FireResetEnv,
    ProcessFrame84, ImageToPyTorch, BufferWrapper (4 steps), ScaledFloatFrame.
    """
    wrapped = gym.make(env_name)
    # These wrappers take only the environment, so they can be applied in a loop.
    for wrapper_cls in (MaxAndSkipEnv, FireResetEnv, ProcessFrame84, ImageToPyTorch):
        wrapped = wrapper_cls(wrapped)
    stacked = BufferWrapper(wrapped, 4)
    return ScaledFloatFrame(stacked)

In [14]:
class DQN:
  """Placeholder for the Q-network; no layers or methods are defined yet."""

  def __init__(self):
    """Create an empty instance; nothing is initialised."""
    return None

  

In [15]:
# Instantiate the (currently empty) DQN placeholder; it is later handed to Agent.play_step.
net = DQN()

In [37]:
class Agent:
  """Holds an environment and exposes a (placeholder) single-step routine."""

  def __init__(self, env):
    """Store the environment the agent will act in."""
    self.env = env

  def play_step(self, net, epsilon, device):
    """Print one random action and one random observation sample; return 0.

    `net`, `epsilon` and `device` are accepted but not used, and the
    environment is not stepped: both values are drawn with `sample()` from
    the environment's action and observation spaces.
    """
    print("action:")
    sampled_action = self.env.action_space.sample()
    print(sampled_action)

    print("obs:")
    sampled_obs = self.env.observation_space.sample()
    print(sampled_obs)

    # No reward is collected, so the step reward is always zero.
    return 0


In [38]:
# Build the Pong environment with the full wrapper stack defined in make_env.
env = make_env("PongNoFrameskip-v4")

In [39]:
# Agent only stores the environment here; nothing happens until play_step is called.
agent = Agent(env)

In [None]:
# Placeholder arguments for Agent.play_step, which currently ignores both.
# NOTE(review): the following cell assigns the same values again, so this cell looks redundant.
epsilon=0
device=0

In [19]:
# Running reward total, plus the placeholder epsilon/device values passed to play_step.
reward = 0
epsilon = 0
device = 0

In [40]:
# play_step currently always returns 0, so the total is unchanged by this call.
# NOTE: `+=` accumulates across re-runs of this cell once play_step returns real rewards.
reward += agent.play_step(net, epsilon, device=device)

action:
0
obs:
[[[0.7812051  0.46217105 0.8005907  ... 0.97476923 0.32244882 0.8649571 ]
  [0.03037664 0.83315784 0.7733953  ... 0.40451014 0.94008327 0.22292797]
  [0.5756634  0.8324113  0.8930624  ... 0.32795784 0.6076059  0.76291233]
  ...
  [0.5160297  0.7955999  0.18357807 ... 0.09747858 0.9735003  0.4396744 ]
  [0.49294585 0.2666546  0.10881287 ... 0.62490076 0.3491353  0.9185548 ]
  [0.33584905 0.374239   0.72649926 ... 0.89930695 0.61525065 0.84411216]]

 [[0.66039854 0.0204379  0.64591    ... 0.9523724  0.32150406 0.71167195]
  [0.81265    0.53409725 0.74865955 ... 0.938989   0.63491154 0.5634564 ]
  [0.5699733  0.16118458 0.8824068  ... 0.29843187 0.14388335 0.39418966]
  ...
  [0.75718826 0.06654327 0.07573684 ... 0.9610781  0.6548166  0.328281  ]
  [0.6346106  0.97894937 0.18874396 ... 0.60629845 0.7081806  0.46559456]
  [0.57174975 0.80954623 0.25579658 ... 0.78857106 0.34897763 0.128554  ]]

 [[0.6346261  0.34278873 0.11005142 ... 0.8135684  0.9158217  0.57338256]
  [0.46

In [6]:
# Show the accumulated reward as the cell output.
reward

0