# Chapter 6: Deep Q-Networks

We start by improving on Q-iteration, being cleverer about which states to focus on and how to update their values. We then combine these techniques with neural networks for a whole new level of power...

## Q-learning

In [1]:
import gym
import collections
from torch.utils.tensorboard import SummaryWriter

ENV_NAME = "FrozenLake-v1" # environment id handed to gym.make for both training and testing
GAMMA = 0.9 # discount factor applied to the best next-state value in the Bellman approximation
ALPHA = 0.2 # learning rate - "blending" parameter in the value update step later
TEST_EPISODES = 20 # number of test episodes averaged after every update to measure progress

class Agent:
    """Tabular Q-learning agent.

    Keeps a table of action values keyed by ``(state, action)`` and refines it
    one environment transition at a time. Missing entries read as 0.0.
    """

    def __init__(self):
        self.env = gym.make(ENV_NAME)
        self.state = self.env.reset()
        # Only action values are stored - no reward or transition tables.
        self.values = collections.defaultdict(float)

    def sample_env(self):
        """Take one random action in the agent's own environment.

        Returns the tuple ``(old_state, action, reward, new_state)``. When the
        step ends the episode the environment is reset, so the agent's current
        state becomes the fresh start state while the returned ``new_state`` is
        still the one reported by the finishing step.
        """
        previous = self.state
        chosen = self.env.action_space.sample()
        following, reward, finished, _ = self.env.step(chosen)
        if finished:
            self.state = self.env.reset()
        else:
            self.state = following
        return previous, chosen, reward, following

    def best_value_and_action(self, state):
        """Return ``(value, action)`` for the highest-valued action in ``state``.

        Ties go to the lowest-numbered action because a later candidate only
        replaces the current best when it is strictly larger.
        """
        top_value = top_action = None
        for candidate in range(self.env.action_space.n):
            candidate_value = self.values[(state, candidate)]
            if top_value is not None and not top_value < candidate_value:
                continue
            top_value, top_action = candidate_value, candidate
        return top_value, top_action

    def value_update(self, s, a, r, next_s):
        """Blend the Bellman target for one transition into the value table."""
        target = r + GAMMA*self.best_value_and_action(next_s)[0]
        key = (s, a)
        current = self.values[key]
        # Move the stored value a fraction ALPHA of the way towards the target.
        self.values[key] = (1-ALPHA)*current + ALPHA*target

    def play_episode(self, env):
        """Play one greedy episode on ``env`` and return the summed reward."""
        episode_reward = 0.0
        observation = env.reset()
        finished = False
        while not finished:
            _, greedy_action = self.best_value_and_action(observation)
            observation, step_reward, finished, _ = env.step(greedy_action)
            episode_reward += step_reward
        return episode_reward


In [4]:
# Separate environment for evaluation so test episodes do not disturb the
# agent's own training environment.
test_env = gym.make(ENV_NAME)
agent = Agent()
writer = SummaryWriter(comment="-q-learning")

iter_no = 0
best_reward = 0.0
while True:
    iter_no += 1

    # One random step in the environment, followed by one table update.
    s, a, r, next_s = agent.sample_env()
    agent.value_update(s, a, r, next_s)

    # Average greedy-policy reward over a batch of test episodes.
    episode_rewards = [agent.play_episode(test_env) for _ in range(TEST_EPISODES)]
    reward = sum(episode_rewards) / TEST_EPISODES
    writer.add_scalar("reward", reward, iter_no)

    if reward > best_reward:
        print(f"Best reward updated {best_reward:.3f} -> {reward:.3f}")
        best_reward = reward

    # Keep training until the average test reward exceeds 0.80.
    if not reward > 0.80:
        continue
    print(f"Solved in {iter_no} iterations!")
    writer.close()
    break


Best reward updated 0.000 -> 0.050
Best reward updated 0.050 -> 0.100
Best reward updated 0.100 -> 0.150
Best reward updated 0.150 -> 0.200
Best reward updated 0.200 -> 0.250
Best reward updated 0.250 -> 0.300
Best reward updated 0.300 -> 0.350
Best reward updated 0.350 -> 0.450
Best reward updated 0.450 -> 0.500
Best reward updated 0.500 -> 0.600
Best reward updated 0.600 -> 0.650
Best reward updated 0.650 -> 0.700
Best reward updated 0.700 -> 0.750
Best reward updated 0.750 -> 0.800
Best reward updated 0.800 -> 0.850
Solved in 4961 iterations!


## A more challenging problem: Atari games

In [1]:
# Wrappers

import cv2
import gym
import gym.spaces
import numpy as np
import collections

# Press FIRE button in games where that triggers game start
class FireResetEnv(gym.Wrapper):
    """Wrapper that issues actions 1 (FIRE) and 2 after every reset.

    Construction fails unless the wrapped environment reports "FIRE" as the
    meaning of action 1 and exposes at least three actions.
    """

    def __init__(self, env=None):
        super().__init__(env)
        meanings = env.unwrapped.get_action_meanings()
        assert meanings[1] == "FIRE"
        assert len(meanings) >= 3

    def step(self, action):
        """Pass the action straight through to the wrapped environment."""
        return self.env.step(action)

    def reset(self):
        """Reset, then step with action 1 and action 2, resetting again if either ends the episode.

        Returns the observation produced by the action-2 step.
        NOTE(review): if that step ends the episode the environment is reset
        but the observation returned is still the pre-reset one - confirm
        this is intended.
        """
        self.env.reset()
        for startup_action in (1, 2):
            obs, _, done, _ = self.env.step(startup_action)
            if done:
                self.env.reset()
        return obs


# Combine repetition of actions during K frames, and pixels from two consecutive frames
class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action for ``skip`` frames and max-pool the last two frames.

    The reward returned by ``step`` is the sum over the repeated frames; the
    observation is the element-wise maximum of the frames currently held in a
    two-slot buffer.
    """

    def __init__(self, env=None, skip=4):
        super().__init__(env)
        # Only the two most recent raw frames are kept for pooling.
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip # 4 is arbitrary, but canonically works well for Atari games

    def step(self, action):
        """Apply ``action`` up to ``skip`` times, stopping early when the episode ends."""
        reward_sum = 0.0
        done = None
        for _ in range(self._skip):
            frame, step_reward, done, info = self.env.step(action)
            self._obs_buffer.append(frame)
            reward_sum += step_reward
            if done:
                break
        # Brightest value per pixel across the buffered frames.
        pooled = np.stack(self._obs_buffer).max(axis=0)
        return pooled, reward_sum, done, info

    def reset(self):
        """Reset the wrapped environment and seed the frame buffer with its first frame."""
        self._obs_buffer.clear()
        first_frame = self.env.reset()
        self._obs_buffer.append(first_frame)
        return first_frame


# Rescale frames and convert to greyscale
class ProcessFrame84(gym.ObservationWrapper):
    """Convert RGB frames to 84x84 single-channel uint8 greyscale images."""

    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8) # pixel intensities

    def observation(self, obs):
        """Return the processed version of a raw frame from the wrapped environment."""
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Greyscale, downscale and crop one frame.

        Declared as a static method so it can be called on the class or on an
        instance alike.

        Args:
            frame: array holding either 210*160*3 or 250*160*3 values.

        Returns:
            Array of shape (84, 84, 1) and dtype uint8.

        Raises:
            AssertionError: if the frame has any other number of elements.
        """
        if frame.size == 210*160*3:
            img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
        elif frame.size == 250*160*3:
            img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
        else:
            # Raised explicitly (not via `assert`) so the check survives `python -O`.
            raise AssertionError("Unknown resolution")
        img = 0.299*img[:, :, 0] + 0.587*img[:, :, 1] + 0.114*img[:, :, 2] # "colorimetric" greyscale averaging
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA) # 84 wide, 110 tall
        x_t = resized_screen[18:102, :] # crop top and bottom, leaving 84 rows
        x_t = np.reshape(x_t, [84, 84, 1]) # add third dimension
        return x_t.astype(np.uint8)


# Create a stack of consecutive frames and return as a single observation
# (this allows the network to "see" velocity, movement etc which are not apparent from an isolated frame)
class BufferWrapper(gym.ObservationWrapper):
    """Keep a rolling stack of recent observations along axis 0.

    The observation space is the wrapped space with its bounds repeated
    ``n_steps`` times along the first axis.
    """

    def __init__(self, env, n_steps, dtype=np.float32):
        super().__init__(env)
        self.dtype = dtype
        base_space = env.observation_space
        stacked_low = base_space.low.repeat(n_steps, axis=0)
        stacked_high = base_space.high.repeat(n_steps, axis=0)
        self.observation_space = gym.spaces.Box(stacked_low, stacked_high, dtype=dtype)

    def reset(self):
        """Zero the stack, reset the wrapped environment and push its first observation."""
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        first_observation = self.env.reset()
        return self.observation(first_observation)

    def observation(self, observation):
        """Shift the stack by one slot, append ``observation`` and return the stack.

        The same array object is returned on every call and is modified in place.
        """
        stack = self.buffer
        stack[:-1] = stack[1:] # drop the oldest entry
        stack[-1] = observation # newest entry goes last
        return stack


# Rearrange dimensions to PyTorch convention (channel-first)
class ImageToPyTorch(gym.ObservationWrapper):
    """Move the last (channel) axis of each observation to the front."""

    def __init__(self, env):
        super().__init__(env)
        source_shape = self.observation_space.shape
        # (H, W, C) -> (C, H, W)
        channel_first = (source_shape[-1], source_shape[0], source_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=channel_first, dtype=np.float32
        )

    def observation(self, observation):
        """Return ``observation`` with axis 2 moved to position 0."""
        return np.moveaxis(observation, 2, 0)


# Convert observation data from bytes to float and rescale to 0-1
class ScaledFloatFrame(gym.ObservationWrapper):
    """Cast observations to float32 and divide them by 255."""

    def observation(self, obs):
        """Return a float32 copy of ``obs`` divided by 255.0."""
        as_array = np.array(obs)
        as_float = as_array.astype(np.float32)
        return as_float / 255.0


# Function to create environment and apply all wrappers
def make_env(env_name, n_steps=4):
    """Create a gym environment and apply the full preprocessing wrapper chain.

    Args:
        env_name: environment id passed to ``gym.make``.
        n_steps: number of consecutive frames stacked by ``BufferWrapper``.

    Returns:
        The wrapped environment.
    """
    env = gym.make(env_name)
    env = MaxAndSkipEnv(env)
    env = FireResetEnv(env)
    env = ProcessFrame84(env)
    env = ImageToPyTorch(env)
    # BufferWrapper requires the stack depth; omitting it raised a TypeError.
    env = BufferWrapper(env, n_steps)
    env = ScaledFloatFrame(env)
    return env

In [None]:
# The model

import torch
import torch.nn as nn
import numpy as np

class DQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    Args:
        input_shape: shape of one observation as (channels, height, width).
        n_actions: number of outputs, one value per action.
    """

    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        # The conv output is flattened by hand in forward(), so the fully-connected
        # layers live in a separate Sequential whose input size is measured by
        # pushing a dummy observation through the conv stack.
        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    # These two must be methods of the class: when nested inside __init__ they
    # are never bound, so self._get_conv_out raises and the module has no forward().
    def _get_conv_out(self, shape):
        """Return the number of values the conv stack emits for one input of ``shape``."""
        with torch.no_grad(): # size probe only, no gradients needed
            o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        """Map a batch of observations to a (batch, n_actions) tensor of action values."""
        conv_out = self.conv(x).view(x.size(0), -1) # keep batch in 0th dimension, flatten everything else into 1st
        return self.fc(conv_out)