In [None]:
! pip install gym[atari,accept-rom-license]

In [None]:
%%writefile dqn_pong.py

import wrappers
import dqn_model

import argparse
import time
import numpy as np
import collections

import torch
import torch.nn as nn
import torch.optim as optim


DEFAULT_ENV_NAME = "PongNoFrameskip-v4"
# Training stops once the mean reward over the last 100 episodes exceeds this.
MEAN_REWARD_BOUND = 19

GAMMA = 0.99                 # discount factor in the Bellman target
BATCH_SIZE = 32              # transitions per gradient step
REPLAY_SIZE = 10000          # replay buffer capacity (oldest entries are evicted)
LEARNING_RATE = 1e-4
SYNC_TARGET_FRAMES = 1000    # copy net -> tgt_net every this many frames
REPLAY_START_SIZE = 10000    # transitions to collect before training starts

# Epsilon falls by 1 / EPSILON_DECAY_LAST_FRAME per frame from EPSILON_START,
# and is floored at EPSILON_FINAL.
EPSILON_DECAY_LAST_FRAME = 150000
EPSILON_START = 1.0
EPSILON_FINAL = 0.01

# One transition (s, a, r, done, s'). The field order matters: the replay
# buffer unpacks these tuples positionally.
Experience = collections.namedtuple(
    "Experience", field_names=["state", "action", "reward", "done", "new_state"])

class ExperienceBuffer:
  """Bounded FIFO store of transitions with uniform random sampling.

  Backed by ``collections.deque(maxlen=capacity)``: once full, each append
  silently evicts the oldest transition.
  """

  def __init__(self, capacity):
    self.buffer = collections.deque(maxlen=capacity)

  def __len__(self):
    return len(self.buffer)

  def append(self, experience):
    """Store one transition (a 5-field tuple such as ``Experience``)."""
    self.buffer.append(experience)

  def sample(self, batch_size):
    """Draw ``batch_size`` distinct stored transitions, column-wise.

    Returns:
      A tuple ``(states, actions, rewards, dones, next_states)`` of numpy
      arrays. ``rewards`` is float32, ``dones`` is uint8, and the other
      dtypes follow the stored data.
    """
    picks = np.random.choice(len(self.buffer), batch_size, replace=False)
    chosen = (self.buffer[i] for i in picks)
    states, actions, rewards, dones, next_states = zip(*chosen)
    return (
        np.array(states),
        np.array(actions),
        np.array(rewards, dtype=np.float32),
        np.array(dones, dtype=np.uint8),
        np.array(next_states),
    )


class Agent():
  """Epsilon-greedy actor that steps an environment and records transitions.

  ``env`` must follow the 4-tuple ``step`` API that the wrappers in this
  project return: ``(obs, reward, done, info)``. Every step, whether the
  action was random or greedy, is appended to ``exp_buffer`` as an
  ``Experience``.
  """

  def __init__(self, env, exp_buffer):
    self.env = env
    self.exp_buffer = exp_buffer
    self._reset()

  def _reset(self):
    """Start a new episode on this agent's own env and zero the reward tally."""
    self.state = self.env.reset()
    self.total_reward = 0.0

  @torch.no_grad()
  def play_step(self, net, epsilon = 0.0, device = "cpu"):
    """Take one environment step and store the resulting transition.

    Args:
      net: Q-network mapping a batch of states to per-action Q-values.
      epsilon: probability of taking a uniformly random action instead of
        the greedy one.
      device: device that the state tensor is moved to before calling ``net``.

    Returns:
      The episode's total (undiscounted) reward if this step ended the
      episode (the env is then reset), otherwise None.
    """
    done_reward = None

    if np.random.random() < epsilon:
      # Explore: random action.
      action = self.env.action_space.sample()
    else:
      # Exploit: add a batch dimension of 1 and take the arg-max Q action.
      # np.asarray replaces np.array(..., copy=False), which raises under
      # NumPy 2 whenever a copy is required (always, for a list input).
      state_a = np.asarray([self.state])
      state_v = torch.tensor(state_a).to(device)
      q_vals_v = net(state_v)
      _, act_v = torch.max(q_vals_v, dim = 1)
      action = act_v.item()

    # Step the env for BOTH exploratory and greedy actions.
    new_state, reward, is_done, _ = self.env.step(action)
    self.total_reward += reward

    self.exp_buffer.append(Experience(self.state, action, reward, is_done, new_state))
    self.state = new_state

    if is_done:
      done_reward = self.total_reward
      self._reset()

    return done_reward


def calc_loss(batch, net, tgt_net, device = "cpu", gamma = None):
  """Compute the DQN mean-squared TD error for a sampled batch.

  loss = MSE(Q(s, a), r + gamma * max_a' Q_tgt(s', a')), where the bootstrap
  term is zeroed for transitions that ended an episode.

  Args:
    batch: ``(states, actions, rewards, dones, next_states)`` array-likes, as
      returned by ``ExperienceBuffer.sample``.
    net: online Q-network; gradients flow through it.
    tgt_net: target Q-network; evaluated under ``torch.no_grad()``.
    device: device that all batch tensors are moved to.
    gamma: discount factor; defaults to the module-level ``GAMMA``.

  Returns:
    A scalar loss tensor attached to ``net``'s graph.
  """
  if gamma is None:
    gamma = GAMMA

  states, actions, rewards, dones, next_states = batch

  states_v = torch.tensor(np.asarray(states)).to(device)
  # gather() requires int64 indices.
  actions_v = torch.tensor(np.asarray(actions), dtype = torch.int64).to(device)
  next_states_v = torch.tensor(np.asarray(next_states)).to(device)
  rewards_v = torch.tensor(np.asarray(rewards), dtype = torch.float32).to(device)
  done_mask = torch.tensor(np.asarray(dones), dtype = torch.bool).to(device)

  # Q(s, a) for the action actually taken in each transition.
  state_action_values = net(states_v).gather(1, actions_v.unsqueeze(dim = -1)).squeeze(dim = -1)

  # Bootstrapped target from the target network; no gradient flows here.
  with torch.no_grad():
    next_state_values = tgt_net(next_states_v).max(1)[0]
    # A terminal transition's target is just its reward.
    next_state_values[done_mask] = 0.0

  expected_state_action_values = rewards_v + gamma * next_state_values

  return nn.MSELoss()(state_action_values, expected_state_action_values)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--cuda", default = False, action = "store_true", help ="Use Cuda")
  parser.add_argument("--env", default = DEFAULT_ENV_NAME, help = "name of the environment")

  args = parser.parse_args()

  device = torch.device("cuda" if args.cuda else "cpu")

  env = wrappers.make_env(args.env)

  # DQN indexes input_shape[0] and unpacks it into torch.zeros(1, *shape),
  # so it needs the observation *shape*, not the gym Box space itself.
  obs_shape = env.observation_space.shape
  n_actions = env.action_space.n

  net = dqn_model.DQN(obs_shape, n_actions).to(device)

  # tgt_net supplies Q'(s', a') for the Bellman target. It is never given to
  # the optimizer; it only receives net's weights via load_state_dict.
  tgt_net = dqn_model.DQN(obs_shape, n_actions).to(device)
  # Start both networks from identical weights.
  tgt_net.load_state_dict(net.state_dict())

  print(net)

  buffer = ExperienceBuffer(REPLAY_SIZE)
  agent = Agent(env, buffer)
  epsilon = EPSILON_START

  optimizer = optim.Adam(net.parameters(), lr = LEARNING_RATE)

  total_rewards = []
  frame_idx = 0
  ts_frame = 0
  ts = time.time()
  best_m_reward = None

  while True:
    frame_idx += 1

    # Epsilon falls linearly per frame and is floored at EPSILON_FINAL.
    epsilon = max(EPSILON_FINAL, EPSILON_START - (frame_idx / EPSILON_DECAY_LAST_FRAME))

    reward = agent.play_step(net, epsilon, device = device)

    if reward is not None:
      # An episode just finished: report progress and checkpoint on improvement.
      total_rewards.append(reward)
      speed = (frame_idx - ts_frame) / (time.time() - ts)
      ts_frame = frame_idx
      ts = time.time()

      # Mean reward over (up to) the last 100 episodes.
      m_reward = np.mean(total_rewards[-100:])

      print("%d: done %d games, reward %.3f, " "eps %.2f, speed %.2f f/s" % (
                frame_idx, len(total_rewards), m_reward, epsilon,
                speed
            ))

      if best_m_reward is None or best_m_reward < m_reward:
        torch.save(net.state_dict(), args.env + "-best_%.0f.dat" % m_reward)

        if best_m_reward is not None:
          print("Best reward updated %.3f -> %.3f" % (best_m_reward, m_reward))

        best_m_reward = m_reward

      # Stop once the environment counts as solved.
      if m_reward > MEAN_REWARD_BOUND:
        print("Solved in %d frames!" % frame_idx)
        break

    # Don't train until the replay buffer holds enough transitions.
    if len(buffer) < REPLAY_START_SIZE:
      continue

    # Every SYNC_TARGET_FRAMES frames, copy the online weights into tgt_net.
    if frame_idx % SYNC_TARGET_FRAMES == 0:
      tgt_net.load_state_dict(net.state_dict())

    optimizer.zero_grad()
    batch = buffer.sample(BATCH_SIZE)
    loss = calc_loss(batch, net, tgt_net, device = device)
    loss.backward()
    optimizer.step()

In [None]:
%%writefile dqn_model.py
import torch
import torch.nn as nn
import numpy as np


class DQN(nn.Module):
    """Convolutional Q-network: one Q-value per action for an image input.

    Args:
        input_shape: shape of a single observation; ``input_shape[0]`` is used
            as the number of input channels.
        n_actions: number of discrete actions (width of the output layer).
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()

        in_channels = input_shape[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        flat_features = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_conv_out(self, shape):
        """Return the flattened size of the conv stack's output for one sample."""
        probe = torch.zeros(1, *shape)
        return int(self.conv(probe).numel())

    def forward(self, x):
        features = self.conv(x)
        # Flatten everything after x's leading dimension.
        flat = features.view(x.size(0), -1)
        return self.fc(flat)

In [None]:
%%writefile wrappers.py 

import cv2
import gym
import gym.spaces
import numpy as np
import collections


class FireResetEnv(gym.Wrapper):
    def __init__(self, env=None):
        """For environments where the user need to press FIRE for the game to start.

        Requires action 1 to be 'FIRE' and at least 3 actions.
        """
        super(FireResetEnv, self).__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def step(self, action):
        return self.env.step(action)

    def reset(self, **kwargs):
        """Reset the env, then take action 1 (FIRE) and action 2.

        ``kwargs`` are forwarded to the initial reset only. The recovery resets
        below run without them, so e.g. a fixed seed is not re-applied. If
        either press ends the episode, the env is reset again. When that
        happens after the final press, the fresh reset observation is returned
        instead of the stale terminal one.
        """
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            obs = self.env.reset()
        return obs


class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env=None, skip=4):
        """Repeat each action ``skip`` times; return the max of the last two frames.

        Rewards from the repeated steps are summed. Raises ValueError if
        ``skip`` < 1.
        """
        super(MaxAndSkipEnv, self).__init__(env)
        if skip < 1:
            # With skip == 0, step() would never call the inner env and would
            # then fail on an unbound ``info``.
            raise ValueError("skip must be >= 1, got %r" % (skip,))
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        # Element-wise max over the (up to) two most recent raw frames.
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset(**kwargs)
        self._obs_buffer.append(obs)
        return obs


class ProcessFrame84(gym.ObservationWrapper):
    """Convert raw RGB frames to 84x84x1 uint8 grayscale observations."""

    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Grayscale, resize and crop one frame to shape (84, 84, 1), uint8.

        Accepts frames with exactly 210*160*3 or 250*160*3 elements. Any other
        size raises ValueError.
        """
        if frame.size == 210 * 160 * 3:
            img = np.reshape(frame, [210, 160, 3]).astype(
                np.float32)
        elif frame.size == 250 * 160 * 3:
            img = np.reshape(frame, [250, 160, 3]).astype(
                np.float32)
        else:
            # Raise explicitly: an ``assert`` is stripped under ``python -O``,
            # which would leave ``img`` unbound below.
            raise ValueError("Unknown resolution.")
        # Weighted RGB -> gray (ITU-R BT.601 luma coefficients).
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + \
              img[:, :, 2] * 0.114
        # cv2.resize takes dsize as (width, height): result is 110 rows x 84 cols.
        resized_screen = cv2.resize(
            img, (84, 110), interpolation=cv2.INTER_AREA)
        # Keep rows 18..101 to get an 84x84 region.
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)


class ImageToPyTorch(gym.ObservationWrapper):
    """Move the last (channel) axis of each observation to the front.

    NOTE(review): the declared space ([0, 1], float32) does not describe the
    data this wrapper itself returns (it only transposes its input). In
    make_env it matches the values produced after ScaledFloatFrame.
    """

    def __init__(self, env):
        super().__init__(env)
        src = self.observation_space.shape
        channels_first = (src[-1], src[0], src[1])
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=channels_first,
            dtype=np.float32,
        )

    def observation(self, observation):
        # Axis 2 becomes axis 0; the other axes keep their relative order.
        return np.moveaxis(observation, 2, 0)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Rescale observations to float32 by dividing by 255."""

    def observation(self, obs):
        as_float = np.array(obs).astype(np.float32)
        return as_float / 255.0


class BufferWrapper(gym.ObservationWrapper):
    """Keep a rolling window of the most recent observations along axis 0."""

    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        # Bounds are the inner space's bounds repeated n_steps times along axis 0.
        self.observation_space = gym.spaces.Box(
            old_space.low.repeat(n_steps, axis=0),
            old_space.high.repeat(n_steps, axis=0), dtype=dtype)

    def reset(self, **kwargs):
        # Start every episode from an all-zero history.
        self.buffer = np.zeros_like(
            self.observation_space.low, dtype=self.dtype)
        return self.observation(self.env.reset(**kwargs))

    def observation(self, observation):
        # Shift the window one slot toward index 0 and write the newest frame
        # into the last slot. NOTE(review): assigning into buffer[-1] assumes
        # each incoming observation has a single leading channel (C == 1);
        # confirm before using with multi-channel inputs.
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        # Return a copy: self.buffer is mutated in place on the next step,
        # which would otherwise silently change earlier-returned observations.
        return self.buffer.copy()


def make_env(env_name):
    """Build the preprocessed environment used for DQN training.

    Wrapping order (innermost first): MaxAndSkipEnv, FireResetEnv,
    ProcessFrame84, ImageToPyTorch, BufferWrapper(4), ScaledFloatFrame.
    """
    env = gym.make(env_name)
    for wrapper_cls in (MaxAndSkipEnv, FireResetEnv, ProcessFrame84, ImageToPyTorch):
        env = wrapper_cls(env)
    stacked = BufferWrapper(env, 4)
    return ScaledFloatFrame(stacked)