In [0]:
!apt-get install -y xvfb python-opengl > /dev/null 2>&1

In [0]:
!pip install gym pyvirtualdisplay > /dev/null 2>&1

In [0]:
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display
# Start a 400x300 virtual display that is not shown on screen (visible=0).
# Presumably this exists so env.render() has a display to draw to on a
# headless machine -- confirm against the runtime being used.
# NOTE(review): the name `display` is rebound to the IPython display module
# further down (inside `if is_ipython:`), after which this handle can no
# longer be reached by name.
display = Display(visible=0, size=(400, 300))
display.start()

In [0]:
%matplotlib inline

In [0]:
from google.colab import drive
# Mount Google Drive; the GIFs and checkpoints written further down are
# saved under /content/gdrive/My Drive/ModelFiles/.
drive.mount('/content/gdrive')

In [0]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
from PIL import Image
import time
import imageio

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from collections import deque

# Work on the unwrapped CartPole-v1 environment: the screen helpers below
# read env.state and env.x_threshold directly from it.
# NOTE(review): presumably `.unwrapped` also removes gym's episode step
# limit, so episodes only end when the environment reports done -- confirm.
env = gym.make('CartPole-v1').unwrapped

# set up matplotlib
# is_ipython is True when the active matplotlib backend name contains
# 'inline'; only then is `display` rebound to the IPython display module.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Interactive mode, so figures can be refreshed while training runs.
plt.ion()

# if gpu is to be used
# Every tensor and network in this notebook is created on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
# Echo the selected torch device (cuda or cpu) as the cell output.
device

In [0]:
class ReplayMemory(object):
    """Fixed-capacity circular buffer of (state, action, next_state, reward, done).

    Frames are quantised to steps of 1/255 and kept as signed 16-bit integers,
    so the *difference* frames used in this notebook (values in [-1, 1]) keep
    their sign. Sampled frames are converted back to float32 on the same scale
    they were stored with, so the network sees identical value ranges when
    acting and when training.

    Args:
        size: Maximum number of transitions kept; the oldest are overwritten.
        frame_height: Height of every stored frame.
        frame_width: Width of every stored frame.
        channels: Number of channels of every stored frame.
        batch_size: Number of transitions returned by ``get_MiniBatch``.
        storage_device: Device holding the buffers. Defaults to the
            notebook-level ``device``.
    """

    def __init__(self, size, frame_height=84, frame_width=84, channels=3, batch_size=32,
                 storage_device=None):
        if size < 1:
            raise ValueError("Replay memory size must be at least 1")
        self.size = size

        self.frame_height = frame_height
        self.frame_width = frame_width
        self.channels = channels
        self.batch_size = batch_size
        # The global is only looked up when no explicit device is supplied.
        self.device = device if storage_device is None else torch.device(storage_device)

        frame_buffer_shape = (self.size, self.channels, self.frame_height, self.frame_width)
        # float32 rewards/flags: supports negative or fractional rewards and
        # lets callers compute `1 - done` without integer wrap-around.
        self.rewards = torch.zeros(self.size, device=self.device, dtype=torch.float32)
        self.actions = torch.zeros(self.size, device=self.device, dtype=torch.long)
        self.states = torch.zeros(frame_buffer_shape, device=self.device, dtype=torch.int16)
        self.next_states = torch.zeros(frame_buffer_shape, device=self.device, dtype=torch.int16)
        self.done_flags = torch.zeros(self.size, device=self.device, dtype=torch.float32)

        # Indices drawn by the most recent get_MiniBatch call.
        self.indices = torch.zeros(self.batch_size, device=self.device, dtype=torch.long)

        self.current_position = 0
        self.current_size = 0

    def _quantise(self, frame):
        """Convert a float frame to the int16 storage format (units of 1/255)."""
        return torch.round(255 * frame.float()).to(device=self.device, dtype=torch.int16)

    def add_experience(self, state, action, next_state, reward, done):
        """Saves a transition, overwriting the oldest one once the buffer is full.

        Args:
            state: Tensor of shape (channels, frame_height, frame_width).
            action: Integer action, or a one-element tensor holding it.
            next_state: Tensor with the same shape as ``state``.
            reward: Number, or a one-element tensor.
            done: Bool/number, or a one-element tensor; non-zero means terminal.

        Raises:
            ValueError: If ``state`` or ``next_state`` has the wrong shape.
        """
        expected_shape = (self.channels, self.frame_height, self.frame_width)
        for label, frame in (("state", state), ("next_state", next_state)):
            if tuple(frame.shape) != expected_shape:
                raise ValueError(
                    "Frame dimension is of wrong size: {} has shape {}, expected {}".format(
                        label, tuple(frame.shape), expected_shape))

        slot = self.current_position
        # int()/float() accept plain Python values as well as one-element tensors.
        self.actions[slot] = int(action)
        self.rewards[slot] = float(reward)
        self.done_flags[slot] = float(done)
        self.states[slot] = self._quantise(state)
        self.next_states[slot] = self._quantise(next_state)

        # Count the slot just written; saturate at capacity.
        self.current_size = min(self.current_size + 1, self.size)
        self.current_position = (slot + 1) % self.size

    def get_MiniBatch(self):
        """Sample ``batch_size`` stored transitions uniformly with replacement.

        Returns:
            Tuple ``(states, actions, rewards, next_states, done_flags)`` with
            shapes (B, C, H, W) float32, (B, 1) int64, (B,) float32,
            (B, C, H, W) float32 and (B,) float32, all on the storage device.

        Raises:
            ValueError: If fewer than ``batch_size`` transitions are stored.
        """
        if self.current_size < self.batch_size:
            raise ValueError("Not enough experience in memory")

        self.indices = torch.randint(0, self.current_size, (self.batch_size,), device=self.device)
        picked = self.indices

        # Both frames come from the SAME transition index.
        batch_states = self.states[picked].float() / 255.0
        batch_next_states = self.next_states[picked].float() / 255.0
        batch_actions = self.actions[picked].unsqueeze(1)

        return batch_states, batch_actions, self.rewards[picked], \
            batch_next_states, self.done_flags[picked]

    def __len__(self):
        """Number of transitions currently stored."""
        return self.current_size

In [0]:
class DQN(nn.Module):
    """Three-layer convolutional Q-network with a linear head.

    Args:
        env: Object exposing ``action_space.n``; sets the number of outputs.
        h: Height of the input frames (default 40, the size used in this notebook).
        w: Width of the input frames (default 80).

    Raises:
        ValueError: If the input is too small to survive the three convolutions.
    """

    def __init__(self, env, h=40, w=80):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        def conv2d_size_out(size, kernel_size=5, stride=2):
            # Output length of an unpadded convolution along one axis.
            return (size - kernel_size) // stride + 1

        conv_h = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        conv_w = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        if conv_h < 1 or conv_w < 1:
            raise ValueError("Input of size {}x{} is too small for this network".format(h, w))
        # For the default 40x80 input this is 2 * 7 * 32 = 448 features.
        self.head = nn.Linear(conv_h * conv_w * 32, env.action_space.n)

    def forward(self, x):
        """Map a (batch, 3, h, w) tensor to (batch, n_actions) Q-values."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Flatten everything but the batch dimension before the linear head.
        return self.head(x.view(x.size(0), -1))

In [0]:
# Frame preprocessing: tensor -> PIL image -> resize so the smaller edge is
# 40 pixels (bicubic) -> float tensor.
# Image.BICUBIC is used instead of the Image.CUBIC alias, which newer Pillow
# releases no longer provide; both name the same filter.
resize_screen = T.Compose([T.ToPILImage(),
                    T.Resize(40, interpolation=Image.BICUBIC),
                    T.ToTensor()])

# This is based on the code from gym.
# Width in pixels assumed for the rendered frame when locating the cart.
screen_width = 600


def get_cart_location():
    """Return the cart's horizontal position as an integer pixel column.

    Converts ``env.state[0]`` from world units into pixels, using a world
    that spans ``2 * env.x_threshold`` across ``screen_width`` pixels and is
    centred on the screen.
    """
    pixels_per_unit = screen_width / (env.x_threshold * 2)
    screen_centre = screen_width / 2.0
    cart_x = env.state[0]
    return int(cart_x * pixels_per_unit + screen_centre)  # MIDDLE OF CART


def get_screen():
    """Render the environment and return a cropped, resized frame.

    The rendered RGB frame is put in channel-first order, cropped to rows
    160..319, cropped to a 320-column window around the cart (clamped at the
    left/right edges), scaled to [0, 1], resized with ``resize_screen`` and
    moved to ``device``. No batch dimension is added.
    """
    # Channel-first (CHW) layout, as torch expects.
    frame = env.render(mode='rgb_array').transpose((2, 0, 1))
    # Drop the rows above and below the cart.
    frame = frame[:, 160:320]

    view_width = 320
    half_view = view_width // 2
    cart_x = get_cart_location()
    if cart_x < half_view:
        # Cart near the left edge: keep the leftmost window.
        columns = slice(view_width)
    elif cart_x > (screen_width - half_view):
        # Cart near the right edge: keep the rightmost window.
        columns = slice(-view_width, None)
    else:
        columns = slice(cart_x - half_view, cart_x + half_view)
    frame = frame[:, :, columns]

    # Float in [0, 1], made contiguous so torch can wrap the array.
    scaled = np.ascontiguousarray(frame, dtype=np.float32) / 255
    return resize_screen(torch.from_numpy(scaled)).to(device)


# Sanity check: reset the environment, show one extracted frame and print
# its shape. Each get_screen() call below renders the environment again.
# print (get_screen().shape)
env.reset()
plt.figure()
# permute(1, 2, 0) moves the channel axis last, the layout imshow expects.
plt.imshow(get_screen().cpu().permute(1, 2, 0).numpy(),
           interpolation='none')
plt.title('Example extracted screen')
plt.show()
print (get_screen().shape)

In [0]:
# Training hyperparameters.
BATCH_SIZE = 128      # transitions per optimisation step
GAMMA = 0.999         # discount factor applied to the next-state value
EPS_START = 0.9       # exploration rate at step 0
EPS_END = 0.05        # exploration rate approached as steps_done grows
EPS_DECAY = 200       # time constant (in steps) of the exponential decay
TARGET_UPDATE = 10    # episodes between copies of policy_net into target_net

# Two identical networks; the target network starts as a copy of the policy
# network and is refreshed from it in the training loop.
policy_net = DQN(env).to(device)
target_net = DQN(env).to(device)
target_net.load_state_dict(policy_net.state_dict())
# NOTE(review): eval() is left disabled, so target_net stays in training mode.
# target_net.eval()

# Only the policy network's parameters are optimised.
optimizer = optim.RMSprop(policy_net.parameters())
# Frame size must match what get_screen() produces -- 40x80 is assumed here.
replay_memory = ReplayMemory(10000, frame_height=40, frame_width=80, batch_size=BATCH_SIZE)


# Global count of select_action calls; drives the epsilon decay.
steps_done = 0


def select_action(state, num_actions):
    """Pick an action epsilon-greedily and advance the global step counter.

    Args:
        state: Single observation without a batch dimension.
        num_actions: Number of discrete actions to choose from when exploring.

    Returns:
        A (1, 1) long tensor holding the chosen action index.
    """
    global steps_done
    draw = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if draw <= eps_threshold:
        # Explore: uniformly random action.
        return torch.tensor([[random.randrange(num_actions)]], device=device, dtype=torch.long)

    # Exploit: action with the largest predicted Q-value.
    with torch.no_grad():
        return policy_net(state.unsqueeze(0)).max(1)[1].view(1, 1)


episode_durations = []

def plot_durations(clear_screen=True):
    """Plot episode durations and, once available, their 100-episode mean.

    Args:
        clear_screen: When True (and the inline backend is active), replace
            the previous cell output with the refreshed figure.
    """
    plt.figure(2)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        # Pad with zeros so the mean curve lines up with the episode axis.
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    # Use the IPython display module through its unambiguous alias: the bare
    # name `display` is first bound to the virtual display object and is only
    # rebound to IPython's module when the inline backend is active.
    if clear_screen and is_ipython:
        ipythondisplay.clear_output(wait=True)
        ipythondisplay.display(plt.gcf())

In [0]:
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps into whole minutes and seconds.

    Args:
        start_time: Start timestamp in seconds.
        end_time: End timestamp in seconds.

    Returns:
        Tuple ``(minutes, seconds)``, both truncated to integers.
    """
    total_seconds = end_time - start_time
    whole_minutes = int(total_seconds / 60)
    leftover_seconds = int(total_seconds - (whole_minutes * 60))
    return (whole_minutes, leftover_seconds)

In [0]:
from skimage.transform import resize
def generate_gif(frame_number, frames_for_gif, reward, path):
    """Resize a sequence of frames to 320x320 and save them as a GIF.

        Args:
            frame_number: Identifier embedded in the file name (the callers
                pass the episode number)
            frames_for_gif: A sequence of (H, W, 3) RGB frames; it is left
                unmodified
            reward: Total reward of the episode, embedded in the file name
            path: String, prefix (directory ending in a separator) under
                which the gif is saved

        Raises:
            ValueError: If frames_for_gif is empty.
    """
    if len(frames_for_gif) == 0:
        raise ValueError("frames_for_gif is empty; nothing to write")

    # Build a new list instead of overwriting the caller's frames, so the
    # original frames survive and the call can be repeated safely.
    resized_frames = [resize(frame, (320, 320, 3), preserve_range=True, order=0)
                      for frame in frames_for_gif]

    file_name = "ATARI_CARTPOLE_DDQN_frame_{0}_reward_{1}.gif".format(frame_number, reward)
    imageio.mimsave(f'{path}{file_name}', resized_frames, duration=1/30)

In [0]:
def optimize_model():
    """Run one double-DQN optimisation step on a sampled mini-batch.

    The policy network chooses the best next action and the target network
    evaluates it; the policy network is then regressed towards
    ``reward + GAMMA * Q_target(next_state, best_action) * (1 - done)``.

    Returns:
        The Huber loss of the step as a Python float.
    """
    states, actions, rewards, new_states, terminal_flags = replay_memory.get_MiniBatch()

    # Q(s, a) for the actions that were actually taken -> shape (B, 1).
    q_value = policy_net(states).gather(1, actions)

    # The target is a constant for this step: no gradient flows through it.
    with torch.no_grad():
        best_next_actions = policy_net(new_states).max(1)[1].unsqueeze(1)
        next_q_value = target_net(new_states).gather(1, best_next_actions).squeeze(1)
        # Flatten to (B,) so the product below stays element-wise instead of
        # broadcasting a (B, 1) column against a (B,) row into (B, B).
        batch_rewards = rewards.float().view(-1)
        not_done = 1.0 - terminal_flags.float().view(-1)
        # Uses the sampled batch rewards, not the last step's `reward` global.
        expected_state_action_values = batch_rewards + GAMMA * next_q_value * not_done

    loss = F.smooth_l1_loss(q_value, expected_state_action_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

In [0]:
# Training loop: runs episodes 1 .. num_episodes - 1, optimising after every
# environment step, refreshing the target network every TARGET_UPDATE
# episodes and recording a greedy rollout every EVAL_FREQ episodes.
num_episodes = 2001
EVAL_FREQ = 200
file_counter = 0
num_actions = env.action_space.n

all_rewards = []   # total reward per training episode (plain floats)
losses = []        # mean optimisation loss per training episode
start_time = time.time()

for i_episode in range(1, num_episodes):
    # Initialize the environment and state
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    # The state is the difference between two consecutive frames.
    state = current_screen - last_screen
    loss = []
    episode_reward = 0.0
    for t in count():
        # Select and perform an action
        action = select_action(state, num_actions)
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)

        # Observe new state
        last_screen = current_screen
        current_screen = get_screen()
        next_state = current_screen - last_screen

        # Store the transition in memory
        replay_memory.add_experience(state=state, action=action, next_state=next_state, reward=reward, done=torch.FloatTensor([done]))

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        if len(replay_memory) > BATCH_SIZE:
            loss.append(optimize_model())

        # Accumulate a plain float so the reward history can be plotted and
        # checkpointed without device tensors.
        episode_reward += reward.item()
        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

    # No optimisation step happens until the memory holds enough transitions;
    # record NaN explicitly instead of averaging an empty list.
    losses.append(np.mean(loss) if loss else float('nan'))
    all_rewards.append(episode_reward)
    # Update the target network
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

    if i_episode % EVAL_FREQ == 0:
        # Greedy rollout (no exploration), recorded as a GIF.
        env.reset()
        last_screen = get_screen()
        current_screen = get_screen()
        state = current_screen - last_screen
        episode_reward = 0.0
        frames_for_gif = []
        frames_for_gif.append(current_screen.cpu().permute(1, 2, 0).numpy())
        # No gradients are needed while evaluating.
        with torch.no_grad():
            for t in count():
                # Select and perform the greedy action
                action = policy_net(state.unsqueeze(0)).max(1)[1].view(1, 1)
                _, reward, done, _ = env.step(action.item())

                # Observe new state
                last_screen = current_screen
                current_screen = get_screen()

                frames_for_gif.append(current_screen.cpu().permute(1, 2, 0).numpy())
                next_state = current_screen - last_screen

                # Move to the next state
                state = next_state

                episode_reward += float(reward)
                if done:
                    break

        generate_gif(i_episode, frames_for_gif, episode_reward, path = f"/content/gdrive/My Drive/ModelFiles/")
        # Parenthesised: `%` binds tighter than `*`, so the unparenthesised
        # form tested `(i_episode % 2) * EVAL_FREQ == 0` and saved on every
        # evaluation instead of every second one.
        if i_episode % (2 * EVAL_FREQ) == 0:

            file_counter += 1
            path = f"/content/gdrive/My Drive/ModelFiles/RLP2CartPole_DDQN1-{file_counter}.pth"

            torch.save({'state_dict': policy_net.state_dict(),
                        'all_rewards': all_rewards,
                        'loss': losses,
                        'optimizer': optimizer.state_dict()}, path)

end_time = time.time()

In [0]:
# Report the total training time and write the final checkpoint.
epoch_mins, epoch_secs = epoch_time(start_time, end_time)
print(f'Epoch: {num_episodes+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
path = f"/content/gdrive/My Drive/ModelFiles/RLP2CartPole-{num_episodes}-Time-{epoch_mins}m{epoch_secs}s.pth"

# The per-episode step counts are collected in `episode_durations` by the
# training loop; no variable named `all_timesteps` is ever created, so the
# list is stored under the existing 'all_timesteps' key.
torch.save({'state_dict': policy_net.state_dict(),
            'all_rewards': all_rewards,
            'all_timesteps': episode_durations,
            'optimizer': optimizer.state_dict()}, path)

print('Model Saved')
env.render()
env.close()
plt.ioff()
plt.show()

In [0]:
# Load a saved checkpoint into a fresh policy network for evaluation.
path = f"/content/gdrive/My Drive/ModelFiles/"
path += f"RLP2CartPole-1000-Time-85m34s.pth"
# path += f"RLP2CartPole-200-Time-15m17s.pth"
# map_location remaps stored tensors onto the current device, so a
# checkpoint written on a GPU runtime also loads on a CPU-only one.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model = torch.load(path, map_location=device)
policy_net = DQN(env).to(device)
policy_net.load_state_dict(model['state_dict'])
# Inference mode: batch-norm layers use their running statistics.
policy_net.eval()

In [0]:
# List the entries stored in the loaded checkpoint dictionary.
model.keys()

In [0]:
# The stored rewards may be plain numbers or one-element tensors (possibly
# on the GPU); convert to floats so matplotlib can plot them either way.
plt.plot([float(r) for r in model['all_rewards']])

In [0]:
# Evaluate the loaded policy greedily for 200 episodes and record the total
# reward of each. `frames_for_gif` keeps the frames of the last episode only.
num_actions = env.action_space.n
rewards = []

for i in range(200):
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen
    episode_reward = 0.0
    frames_for_gif = []
    # No gradients are needed during evaluation.
    with torch.no_grad():
        for t in count():
            # The network expects a batch dimension: (3, H, W) -> (1, 3, H, W).
            action = policy_net(state.unsqueeze(0)).max(1)[1].view(1, 1)
            _, reward, done, _ = env.step(action.item())

            # Observe new state
            last_screen = current_screen
            current_screen = get_screen()

            frames_for_gif.append(current_screen.cpu().permute(1, 2, 0).numpy())

            # Move to the next state
            state = current_screen - last_screen

            # Accumulate a plain float so the totals can be plotted directly.
            episode_reward += float(reward)
            if done:
                break
    rewards.append(episode_reward)


In [0]:
# Entries may be plain numbers or one-element (possibly GPU) tensors;
# convert to floats so matplotlib can plot them either way.
plt.plot([float(r) for r in rewards])

In [0]:
# Save the frames of the most recent evaluation episode as a GIF; 10000 is
# only the identifier embedded in the file name.
generate_gif(10000, frames_for_gif, episode_reward, path = f"/content/gdrive/My Drive/ModelFiles/")

In [0]:
# Simple moving average of the training rewards over `window` episodes.
window = 1
# Rewards may be plain numbers or one-element tensors; work with floats.
episode_totals = [float(r) for r in all_rewards]
# One value per full window, starting at the first episode (the earlier
# range skipped the first `window` episodes and averaged truncated windows
# at the end).
moving_average = [np.mean(episode_totals[start:start + window])
                  for start in range(len(episode_totals) - window + 1)]
plt.xlabel('Episode')
# The label follows the actual window size instead of a hard-coded 10.
plt.ylabel(f'Total Reward (SMA {window})')
plt.plot(moving_average)
# plt.legend(['Exponential Decay Rate - 0.95'])
plt.title('Rewards vs Episodes')
print("reward", max(episode_totals))

In [0]:
# Steps per training episode are collected in `episode_durations`;
# `all_timesteps` is never defined anywhere in this notebook.
plt.plot(episode_durations)