# Lab 01 : Deep Q-Learning (DQN) - demo

Author: Adam Paszke<br>
https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html

CartPole environment (OpenAI Gym):<br>
https://github.com/openai/gym/wiki/CartPole-v0

In [1]:
import gym # OpenAI library
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# Load the CartPole environment from OpenAI Gym. `.unwrapped` removes gym's
# wrappers (presumably including the per-episode step limit) and exposes the
# raw env, whose `state` and `x_threshold` attributes the screen helpers read.
env = gym.make('CartPole-v0').unwrapped

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Notebook ("inline") matplotlib backends need IPython's display module to
# refresh the training plot.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
    

In [2]:
# Replay memory that stores past transitions (s, a, s', r) for training.

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity experience-replay buffer.

    Transitions live in a list used as a ring buffer: the list grows until it
    holds ``capacity`` items, after which every new transition overwrites the
    oldest stored one.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []    # stored Transition objects
        self.position = 0   # slot the next push writes to

    def push(self, *args):
        """Saves a transition."""
        item = Transition(*args)
        if len(self.memory) < self.capacity:
            # Still filling up: the write slot is always the end of the list.
            self.memory.append(item)
        else:
            # Full: overwrite the oldest entry.
            self.memory[self.position] = item
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


In [3]:
# Class of Deep Q-Networks (DQN)

class DQN(nn.Module):
    """Convolutional Q-network mapping a batch of RGB frames to one Q-value per action.

    Three conv(5x5, stride 2) + BatchNorm + ReLU stages feed a linear head.
    The head's input size is computed from the frame size instead of being
    hard-coded, so the same class works for other resolutions or action counts.
    The defaults reproduce the original network exactly (40x80 input, 448
    head features, 2 outputs).

    Args:
        h: input frame height in pixels (default 40).
        w: input frame width in pixels (default 80).
        outputs: number of discrete actions (default 2 = Left or Right move).
    """

    def __init__(self, h=40, w=80, outputs=2):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        def conv2d_size_out(size, kernel_size=5, stride=2):
            # Output length of an unpadded convolution along one spatial axis.
            return (size - (kernel_size - 1) - 1) // stride + 1

        conv_h = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        conv_w = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        # For the default 40x80 input this is 2 * 7 * 32 = 448.
        self.head = nn.Linear(conv_h * conv_w * 32, outputs)

    def forward(self, x):
        """Return Q-values of shape (batch, outputs) for x of shape (batch, 3, h, w)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))
    

In [4]:
# Helpers that turn rendered frames of the environment into network states.

screen_width = 600  # width in pixels of the rendered frame (value from gym)


def get_cart_location():
    """Return the horizontal pixel coordinate of the middle of the cart.

    Linearly maps the cart position ``env.state[0]`` from the world interval
    of width ``2 * env.x_threshold`` onto ``screen_width`` pixels, with
    position 0 landing in the middle of the frame.
    """
    pixels_per_unit = screen_width / (2 * env.x_threshold)
    return int(screen_width / 2.0 + env.state[0] * pixels_per_unit)

# Frame preprocessing: torchvision converts the cropped [3, 160, 320] screen
# tensor to a PIL image, resizes it to [3, 40, 80] (shorter side -> 40 px) and
# converts it back to a tensor.
# Image.BICUBIC is used because its old alias Image.CUBIC (same value) was
# removed in Pillow 10 and raises AttributeError there.
resize = T.Compose([T.ToPILImage(),
                    T.Resize(40, interpolation=Image.BICUBIC),
                    T.ToTensor()])

def get_screen():
    """Render the current frame and return it as a batched (BCHW) float tensor.

    The frame is cropped to rows 160-319, then to a ``320``-pixel-wide window
    centred on the cart (clamped at the left/right screen edges). Pixel values
    are scaled from 0-255 to [0, 1], passed through ``resize`` and moved to
    ``device``.
    """
    frame = env.render(mode='rgb_array').transpose((2, 0, 1))  # HWC -> CHW (torch order)
    frame = frame[:, 160:320]  # drop the rows above and below the cart

    view_width = 320
    half_view = view_width // 2
    cart_x = get_cart_location()
    if cart_x < half_view:
        # Cart near the left edge: keep the leftmost view_width columns.
        columns = slice(view_width)
    elif cart_x > screen_width - half_view:
        # Cart near the right edge: keep the rightmost view_width columns.
        columns = slice(-view_width, None)
    else:
        columns = slice(cart_x - half_view, cart_x + half_view)
    frame = frame[:, :, columns]

    # Contiguous float32 copy scaled to [0, 1], then a torch tensor.
    pixels = np.ascontiguousarray(frame, dtype=np.float32) / 255
    batch = resize(torch.from_numpy(pixels)).unsqueeze(0)  # add the batch dimension
    return batch.to(device)



In [8]:
# Hyper-parameters
BATCH_SIZE = 128     # transitions per optimisation step
GAMMA = 0.999        # discount factor of the TD target
EPS_START = 0.9      # exploration probability at the first step
EPS_END = 0.05       # exploration probability approached after many steps
EPS_DECAY = 200      # time constant, in steps, of the exponential epsilon decay
TARGET_UPDATE = 10   # sync target_net with policy_net every TARGET_UPDATE episodes

# Two copies of the Q-network, built in this order: policy first, then target.
policy_net, target_net = DQN().to(device), DQN().to(device)

# The target network starts as an exact copy of the policy network. It is not
# trained directly: the optimizer below only receives policy_net's parameters.
# eval() puts its BatchNorm layers in inference mode (running statistics, no
# statistic updates during forward passes); it does not by itself block gradients.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# RMSprop with default settings, optimising the policy network only.
optimizer = optim.RMSprop(policy_net.parameters())

# Experience replay buffer keeping the 10000 most recent transitions.
memory = ReplayMemory(10000)

# Epsilon-greedy action selection over the policy network.
# With probability eps, one of the 2 actions (left/right move) is drawn uniformly
# at random. Otherwise the greedy action argmax_a Q(s, a) is taken. eps decays
# exponentially from EPS_START towards EPS_END as steps_done grows, so the agent
# explores less and less over time.
steps_done = 0  # number of actions selected so far (drives the epsilon decay)


def select_action(state):
    """Return the action for ``state`` as a (1, 1) long tensor on ``device``.

    Increments the global ``steps_done`` on every call.
    """
    global steps_done
    draw = random.random()  # uniform float in [0.0, 1.0)
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if draw <= eps_threshold:
        # Explore: uniformly random action.
        return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)

    # Exploit: index of the largest Q-value; no gradients needed.
    # NOTE(review): policy_net is never switched to eval() in this file, so its
    # BatchNorm layers presumably run in training mode here — confirm intended.
    with torch.no_grad():
        q_values = policy_net(state)
        return q_values.max(1)[1].view(1, 1)

# Durations (in steps) of every finished episode, appended by the training loop.
episode_durations = []


def plot_durations():
    """Plot the duration of every finished episode so far.

    Draws the raw per-episode durations. Once at least 100 episodes exist, it
    also draws their 100-episode moving average: the second line, orange with
    matplotlib's default colour cycle, zero-padded for the first 99 episodes.
    """
    plt.figure(2)
    # plt.figure(2) returns the same figure on every call while it is open.
    # Without clearing, each call would stack two more line artists on top of
    # the old ones (slower redraws, and the colours would keep cycling).
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # 100-episode moving average.
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())
    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        # Clear the previous cell output once new output arrives, so the
        # notebook shows only the latest plot.
        display.clear_output(wait=True)
        

In [9]:
# Training function
def optimize_model():
    """Perform one DQN optimisation step on ``policy_net``.

    Samples ``BATCH_SIZE`` transitions (s, a, s', r) from ``memory``. It then
    computes the TD target ``r + GAMMA * max_a' Q_target(s', a')`` with
    ``target_net``; the max term is 0 when s' is terminal (stored as ``None``).
    Finally it minimises the Huber (smooth L1) loss between that target and
    ``Q_policy(s, a)``. Gradients are clipped element-wise to [-1, 1] before
    the optimizer step.

    Does nothing until the memory holds at least ``BATCH_SIZE`` transitions.
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of tuples
    # (batch.state, batch.action, batch.next_state, batch.reward).
    batch = Transition(*zip(*transitions))

    # True where s' exists, i.e. the transition did not end the episode.
    # torch.bool rather than uint8: uint8 mask indexing is deprecated in PyTorch.
    non_final = [s is not None for s in batch.next_state]
    non_final_mask = torch.tensor(non_final, device=device, dtype=torch.bool)

    state_batch = torch.cat(batch.state)    # s
    action_batch = torch.cat(batch.action)  # a, one action index per row
    reward_batch = torch.cat(batch.reward)  # r

    # Q_policy(s, a): the Q-value of the action actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # max_a' Q_target(s', a'); stays 0 for terminal transitions.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if any(non_final):
        # Guard: torch.cat([]) raises when every sampled transition is terminal.
        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None])
        with torch.no_grad():  # the target is a constant for this update
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

    # Expected Q value: r + gamma * max_a' Q_target(s', a')
    expected_state_action_values = reward_batch + GAMMA * next_state_values

    # Huber loss between Q_policy(s, a) and the TD target.
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Clamp every gradient element to [-1, 1] to guard against exploding gradients.
    nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()
    

In [None]:
# Main training loop: one iteration per episode.
num_episodes = 10000
for i_episode in range(num_episodes):

    # Reset the environment; a state is the difference of two consecutive frames.
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen

    # Roll out the episode until the environment reports it is done.
    for t in count():
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)

        # Observe the next state (None marks a terminal transition).
        last_screen, current_screen = current_screen, get_screen()
        next_state = None if done else current_screen - last_screen

        memory.push(state, action, next_state, reward)
        state = next_state

        # One optimisation step on the policy network; the target network is
        # only refreshed below, every TARGET_UPDATE episodes.
        optimize_model()

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

    # Copy the current policy network weights into the target network.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

print('Complete')
env.render()
env.close()  # close the render window
plt.show()

In [None]:
# Close the environment's render window. The training cell above already calls
# env.close() at its end; presumably this separate cell exists to release the
# window if that cell was interrupted before finishing.
env.close() # close the render window