In [None]:
%matplotlib inline
import gym
from gym.wrappers import Monitor
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from time import sleep

In [None]:
# Detect whether matplotlib is using an inline (notebook) backend; the IPython
# display helpers are only imported in that case.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Turn on matplotlib's interactive mode.
plt.ion()

In [None]:
class DQN(nn.Module):
    """Small fully connected Q-network operating on flattened RGB screens.

    Args:
        img_height: height in pixels of the images fed to the network.
        img_width: width in pixels of the images fed to the network.
        num_actions: number of Q-values (one per action) produced by the
            output layer. Defaults to 3, the value that used to be hard-coded.
    """
    def __init__(self, img_height, img_width, num_actions=3):
        super().__init__()

        # Just a few simple hidden layers.
        # The input size assumes 3 colour channels per pixel.
        self.fc1 = nn.Linear(in_features=img_height*img_width*3, out_features=24)
        self.fc2 = nn.Linear(in_features=24, out_features=32)
        self.out = nn.Linear(in_features=32, out_features=num_actions)

    def forward(self, t):
        """Run a forward pass.

        Args:
            t: batched input tensor; every dimension after the first (batch)
                dimension is flattened before the first linear layer.

        Returns:
            Tensor of shape (batch, num_actions) holding one Q-value per action.
        """
        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        # No activation on the output layer: Q-values are unbounded.
        t = self.out(t)
        return t

In [None]:
# One transition stored in replay memory: the state the agent was in, the
# action it took, the state it ended up in and the reward it received.
Experience = namedtuple('Experience', ['state', 'action', 'next_state', 'reward'])

In [None]:
class ReplayMemory():
    """Fixed-capacity store of experiences used to train the DQN.

    Once the store is full, each new experience overwrites an existing slot
    chosen by ``push_count % capacity``.
    """
    def __init__(self, capacity):
        # push_count: total number of experiences ever pushed
        self.push_count = 0
        # memory: the list that actually holds the stored experiences
        self.memory = []
        # capacity: maximum number of experiences kept at any time
        self.capacity = capacity

    def push(self, experience):
        """Store ``experience``, overwriting an old entry once the store is full."""
        if len(self.memory) >= self.capacity:
            slot = self.push_count % self.capacity
            self.memory[slot] = experience
        else:
            self.memory.append(experience)
        self.push_count += 1

    def sample(self, batch_size):
        """Return ``batch_size`` stored experiences drawn at random without replacement."""
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        """Return True when at least ``batch_size`` experiences are stored."""
        return batch_size <= len(self.memory)

In [None]:
class EpsilonGreedyStrategy():
    """Exponentially decaying exploration-rate (epsilon) schedule.

    Args:
        start: exploration rate at step 0.
        end: value the exploration rate decays towards.
        decay: decay constant applied per step inside the exponential.
    """
    def __init__(self, start, end, decay):
        self.start, self.end, self.decay = start, end, decay

    def get_exploration_rate(self, current_step):
        """Return the exploration rate for ``current_step``.

        The agent compares this rate against a random number to decide
        whether to explore or to exploit.
        """
        decay_factor = math.exp(-1. * current_step * self.decay)
        span = self.start - self.end
        return self.end + span * decay_factor

In [None]:
class Agent():
    """Epsilon-greedy agent that picks actions either at random or from a policy network.

    Args:
        strategy: object exposing ``get_exploration_rate(current_step)``.
        num_actions: number of possible actions the agent can take from a state.
        device: torch device on which randomly chosen action tensors are created.
    """
    def __init__(self, strategy, num_actions, device):
        # corresponds to the current step number in the environment
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device

    def select_action(self, state, policy_net):
        """Choose an action for ``state``.

        Args:
            state: input passed to ``policy_net`` when exploiting.
            policy_net: callable mapping ``state`` to a (1, num_actions)
                tensor of Q-values; only called when exploiting.

        Returns:
            A (1, 1) tensor holding the index of the chosen action.
        """
        # Use the strategy and device given to this agent rather than
        # whatever module-level variables happen to exist.
        rate = self.strategy.get_exploration_rate(self.current_step)
        # increment the agent's current step by 1
        self.current_step += 1

        if rate > random.random():
            # Explore: pick one of the available actions uniformly at random.
            return torch.tensor(
                [[random.randrange(self.num_actions)]],
                device=self.device,
                dtype=torch.long,
            )

        # Exploit: pick the action with the highest Q-value. Gradient tracking
        # is switched off because the network is only used for inference here.
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)

In [None]:
class FreewayManager():
    """Environment manager around Gym's ``Freeway-v0`` that produces image-based states.

    Args:
        device: torch device on which reward and screen tensors are placed.
        movie: currently unused. It was the output directory for gym's
            ``Monitor`` wrapper, which is disabled.
    """

    # (height, width) every processed screen is resized to.
    RESIZE_SHAPE = (80, 90)

    def __init__(self, device, movie):
        self.device = device
        self.env = gym.make('Freeway-v0').unwrapped
        self.env.reset()
        # Last processed screen; None means we are at the start of an episode.
        self.current_screen = None
        # Whether the last step ended the episode.
        self.done = False

    def reset(self):
        """Reset the environment and this manager's per-episode bookkeeping."""
        self.env.reset()
        self.current_screen = None
        # Clear the flag as well so a finished episode is not reported after a reset.
        self.done = False

    def close(self):
        """Close the underlying environment."""
        self.env.close()

    def render(self, mode='human'):
        """Render the environment in ``mode`` and return whatever the env returns."""
        return self.env.render(mode)

    def num_actions_available(self):
        """Return the number of actions available to an agent in the environment."""
        return self.env.action_space.n

    def take_action(self, action):
        """Step the environment with ``action`` and return the shaped reward.

        Args:
            action: tensor holding a single action index; ``.item()`` converts
                it to the plain Python number passed to ``env.step``.

        Returns:
            A 1-element float32 tensor on ``self.device`` with the reward.
            Action 1 earns an extra 0.01 on top of the environment reward.
        """
        action_id = action.item()
        # Only the reward and the done flag are needed from the step result.
        _, reward, self.done, _ = self.env.step(action_id)
        if action_id == 1:
            reward += 0.01

        # Pin the dtype so rewards are float32 whatever numeric type the
        # environment returns.
        return torch.tensor([reward], device=self.device, dtype=torch.float32)

    def just_starting(self):
        """Return True if no screen has been captured yet for this episode."""
        return self.current_screen is None

    def get_state(self):
        """Return the current state as a processed image tensor.

        A state is the difference between the current processed screen and
        the previous one. At the start of an episode and once the episode is
        done, an all-zero (black) screen is returned instead.
        """
        if self.just_starting() or self.done:
            self.current_screen = self.get_processed_screen()
            black_screen = torch.zeros_like(self.current_screen)
            return black_screen
        else:
            s1 = self.current_screen
            s2 = self.get_processed_screen()
            self.current_screen = s2
            return s2 - s1

    def get_screen_height(self):
        """Return the height (dimension 2) of a processed screen."""
        screen = self.get_processed_screen()
        return screen.shape[2]

    def get_screen_width(self):
        """Return the width (dimension 3) of a processed screen."""
        screen = self.get_processed_screen()
        return screen.shape[3]

    def get_processed_screen(self):
        """Render the env as an RGB array, then crop, rescale and resize it."""
        # Move the last axis first; assumes render returns height x width x channel.
        screen = self.render('rgb_array').transpose((2, 0, 1))
        screen = self.crop_screen(screen)
        return self.transform_screen_data(screen)

    def crop_screen(self, screen):
        """Strip the top and bottom 7% of rows (axis 1) from ``screen``."""
        screen_height = screen.shape[1]

        top = int(screen_height * 0.07)
        bottom = int(screen_height * 0.93)
        screen = screen[:, top:bottom, :]
        return screen

    def transform_screen_data(self, screen):
        """Convert a cropped screen array into a batched tensor on ``self.device``."""
        # Convert to a contiguous float32 array, rescale by 1/255, convert to tensor.
        screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
        screen = torch.from_numpy(screen)

        # Chain the torchvision image transforms.
        resize = T.Compose([
            T.ToPILImage(),
            T.Resize(self.RESIZE_SHAPE),
            T.ToTensor(),
        ])

        # unsqueeze adds a batch dimension since the processed images are
        # passed to the DQN in batches.
        return resize(screen).unsqueeze(0).to(self.device)

# Plotting

In [None]:
def plot(values, moving_avg_period, episode, rewards):
    """Redraw the training figure and print a short progress summary.

    Args:
        values: sequence of crossing times to plot.
        moving_avg_period: window size used for the moving-average curve;
            it is also the number printed after "Score:".
        episode: episode number to print.
        rewards: total reward to print.
    """
    plt.figure(2)
    # Clear the figure so curves from earlier calls do not pile up.
    plt.clf()
    plt.title('Training...')
    plt.xlabel('# Of Times Crossing Highway')
    plt.ylabel('Crossing Time')
    plt.plot(values)

    moving_avg = get_moving_average(moving_avg_period, values)
    plt.plot(moving_avg)
    plt.pause(0.001)
    # Guard against an empty series so the summary can still be printed.
    latest_avg = moving_avg[-1] if len(moving_avg) > 0 else 0.0
    print("Episode:", episode, "\n", \
         "Score:", moving_avg_period, " and Rewards:", rewards, "\naverage crossing time:", latest_avg)
    # wait=True defers clearing until the next output arrives.
    if is_ipython: display.clear_output(wait=True)
        
def get_moving_average(period, values):
    """Compute the ``period``-point moving average of ``values``.

    Args:
        period: window size of the moving average.
        values: sequence of numbers.

    Returns:
        A float32 numpy array with the same length as ``values``. The first
        ``period - 1`` entries are 0 because no full window exists for them.
        If there are fewer than ``period`` values, or ``period`` is not
        positive, every entry is 0.
    """
    values = torch.tensor(values, dtype=torch.float)
    # A moving average needs at least one full window, and the window size
    # must be positive.
    if period > 0 and len(values) >= period:
        # unfold yields every length-`period` slice along dimension 0;
        # the mean of each slice is one moving-average value.
        moving_avg = values.unfold(dimension=0, size=period, step=1) \
            .mean(dim=1).flatten(start_dim=0)
        # Pad the front with zeros for the positions without a full window.
        moving_avg = torch.cat((torch.zeros(period-1), moving_avg))
        return moving_avg.numpy()
    else:
        # All zeros, with a length equal to the values that were passed in.
        moving_avg = torch.zeros(len(values))
        return moving_avg.numpy()

# Main Program

In [None]:
# --- Hyper-parameters ---------------------------------------------------
# Number of experiences sampled from replay memory per training step
batch_size = 256
# Discount factor used in the Bellman equation
gamma = 0.999
# Epsilon schedule: starting value, final value and decay rate
eps_start = 1
eps_end = 0.01
eps_decay = 0.001
# How often, in episodes, the target network's weights are replaced by the
# policy network's weights
target_update = 10
# Capacity of the replay memory
memory_size = 100000
# Learning rate used when training the policy net
lr = 0.001
# Number of episodes to play
num_episodes = 1000

# --- Environment, agent and memory --------------------------------------
# Use the GPU if it is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Path handed to the environment manager as its `movie` argument
mov = './atari/freeway/optimizer/_0'
# em -> Environment Manager
em = FreewayManager(device=device, movie=mov)
# Epsilon-greedy exploration schedule built from the values above
strategy = EpsilonGreedyStrategy(start=eps_start, end=eps_end, decay=eps_decay)
# The agent gets the strategy, the number of available actions and the device
agent = Agent(strategy=strategy, num_actions=em.num_actions_available(), device=device)
# Replay memory with a capacity of memory_size
memory = ReplayMemory(capacity=memory_size)

# --- Networks and optimizer ---------------------------------------------
policy_net = DQN(em.get_screen_height(), em.get_screen_width()).to(device)
target_net = DQN(em.get_screen_height(), em.get_screen_width()).to(device)
# Start the target net from an exact copy of the policy net's weights
target_net.load_state_dict(policy_net.state_dict())
# The target net is only used for inference, so put it in eval mode
target_net.eval()
# Adam optimizes the policy net's parameters with the learning rate above
optimizer = optim.Adam(policy_net.parameters(), lr=lr)

# Training loop
# prev_avg: crossing-time threshold (in timesteps); slower crossings are penalized
prev_avg = 1000
# Duration, in timesteps, of every crossing across all episodes
episode_times = []
for episode in range(num_episodes):
    # Per-episode bookkeeping
    scored = 0            # number of crossings in this episode
    total_rewards = 0     # sum of shaped rewards in this episode
    prev_timesteps = 0    # timestep at which the previous crossing happened

    # Reset the environment, then get the initial state
    em.reset()
    state = em.get_state()

    for timestep in count():
        # The agent uses policy_net to select its action when it exploits
        # the environment rather than explores it.
        em.env.render()
        action = agent.select_action(state, policy_net)
        reward = em.take_action(action)
        next_state = em.get_state()
        # Create an experience, store it in replay memory and advance the state
        memory.push(Experience(state, action, next_state, reward))
        state = next_state
        sleep(0.01)

        # A reward of at least 1 means the agent crossed. `>=` is used so a
        # crossing still counts when the +0.01 bonus for action 1 was not added.
        if reward.item() >= 1:
            scored += 1
            crossing_time = timestep - prev_timesteps

            if crossing_time > prev_avg:
                # In-place update: the tensor already stored in replay memory
                # is this same object, so the stored reward is penalized too.
                reward -= 1

            episode_times.append(crossing_time)
            prev_timesteps = timestep

        total_rewards += reward.item()

        # Train only once replay memory holds at least batch_size experiences
        if memory.can_provide_sample(batch_size):
            experiences = memory.sample(batch_size)
            # Turn the list of experiences into one Experience of tuples
            batch = Experience(*zip(*experiences))

            states = torch.cat(batch.state)
            actions = torch.cat(batch.action)
            rewards = torch.cat(batch.reward)
            next_states = torch.cat(batch.next_state)

            # get_state() returns an all-zero (black) screen when the episode
            # has ended and never returns None, so final states are detected
            # by being all zero.
            non_final_mask = next_states.flatten(start_dim=1).ne(0).any(dim=1)

            # Q-values of the actions that were actually taken
            current_q_values = policy_net(states).gather(1, actions)

            # Final states keep a next-state value of 0
            next_q_values = torch.zeros(batch_size, device=device)
            if non_final_mask.any():
                next_q_values[non_final_mask] = \
                    target_net(next_states[non_final_mask]).max(1)[0].detach()
            # Target: reward plus the discounted best next-state value
            target_q_values = (next_q_values * gamma) + rewards

            # Mean squared error between current and target Q-values.
            # Gradients are zeroed first because PyTorch accumulates them
            # across backward passes.
            loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if em.done:
            if scored > 0:
                plot(episode_times, scored, episode+1, total_rewards)
                prev_avg = timestep / scored + 20
            else:
                # No crossing this episode: there is no average to compute,
                # so keep the previous threshold and skip the plot.
                print("Episode:", episode+1, "\n", \
                     "Score:", scored, " and Rewards:", total_rewards)
            break

    # Sync the target net with the policy net every target_update episodes
    if episode % target_update == 0:
        target_net.load_state_dict(policy_net.state_dict())

# Close the environment once, after all episodes have been played
em.close()