# Load Modules and Configure Display

In [26]:
import gym
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [27]:
import math
import random
from collections import namedtuple
from itertools import count
from PIL import Image

In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.functional as F
import torchvision.transforms as T

In [32]:
if ('inline' in matplotlib.get_backend()):
    from IPython import display

# Deep Q-Network

In [5]:
class DQN(nn.Module):
    """Fully connected Q-network mapping an RGB screen to per-action Q-values.

    ``forward`` flattens its input from dimension 1 onward, so each batch
    element must contain ``3 * img_height * img_width`` values, e.g. a tensor
    of shape ``(batch, 3, img_height, img_width)``.

    Args:
        img_height: Height in pixels of the input screen.
        img_width: Width in pixels of the input screen.
        num_actions: Number of Q-values produced per input (default 2).
    """

    def __init__(self, img_height, img_width, num_actions=2):
        super(DQN, self).__init__()
        n_pixels = img_height * img_width * 3  # 3 channels for an RGB image
        self.fc1 = nn.Linear(n_pixels, 24)
        self.fc2 = nn.Linear(24, 32)
        self.fc3 = nn.Linear(32, num_actions)

    def forward(self, x):
        """Return Q-values of shape ``(batch, num_actions)`` for the batch ``x``."""
        x = torch.flatten(x, start_dim=1)
        # Use nn.functional explicitly: the module-level ``F`` alias refers to
        # ``torch.functional``, which does not provide ``relu``.
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        # Output layer: raw (un-activated) Q-values, one per action.
        return self.fc3(x)

# Replay Memory

In [8]:
# A single transition recorded by the agent: the observed state, the action
# taken, the resulting next state, and the reward received.
Experience = namedtuple(
    'Experience',
    ('state', 'action', 'next_state', 'reward'),
)

In [23]:
class ReplayMemory():
    """Fixed-size ring buffer of experiences for experience replay.

    Once ``limit`` experiences are stored, each new push overwrites the
    oldest remaining entry.

    Args:
        limit: Maximum number of experiences kept; must be at least 1.

    Raises:
        ValueError: If ``limit`` is less than 1.
    """

    def __init__(self, limit):
        if limit < 1:
            # A zero-size buffer would fail later with ZeroDivisionError in push().
            raise ValueError(f"limit must be >= 1, got {limit!r}")
        self.limit = limit
        self.memory = []
        self.push_count = 0  # total pushes so far; selects the slot to overwrite

    def push(self, experience):
        """Store ``experience``, overwriting the oldest entry once full."""
        if len(self.memory) < self.limit:
            self.memory.append(experience)
        else:
            self.memory[self.push_count % self.limit] = experience
        self.push_count += 1

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored experiences chosen at random.

        Raises:
            ValueError: If ``batch_size`` exceeds the number stored
                (raised by ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        """Return True if at least ``batch_size`` experiences are stored."""
        return len(self.memory) >= batch_size

# Epsilon-Greedy Algorithm

$r_t = r_{\min} + (r_{\max} - r_{\min})\, e^{-\lambda t}$, where $r_t$ is the exploration rate at step $t$ and $\lambda$ is the decay rate.

In [22]:
class EpsilonGreedy:
    """Exponentially decaying exploration-rate schedule.

    At step ``t`` the rate is ``r_min + (r_max - r_min) * exp(-r_decay * t)``,
    so it starts at ``r_max`` for ``t == 0`` and approaches ``r_min``.
    """

    def __init__(self, r_max, r_min, r_decay):
        self.max, self.min, self.decay = r_max, r_min, r_decay

    def get_exploration_rate(self, current_step):
        """Return the exploration rate for ``current_step``."""
        span = self.max - self.min
        decay_factor = np.exp(-self.decay * current_step)
        return self.min + span * decay_factor

# Reinforcement Learning Agent

In [24]:
class Agent():
    """Epsilon-greedy agent that picks actions using a policy network.

    Args:
        strategy: Object with ``get_exploration_rate(step)`` giving the
            probability of taking a random action at ``step``.
        num_actions: Number of discrete actions; random actions are drawn
            uniformly from ``range(num_actions)``.
        device: Torch device the returned action tensor is placed on.
    """
    def __init__(self, strategy, num_actions, device):
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device

    def select_action(self, state, policy_net):
        """Choose an action for ``state`` and advance the step counter.

        With probability equal to the strategy's current exploration rate a
        uniformly random action is returned (explore). Otherwise the argmax
        over dim 1 of ``policy_net(state)`` is returned (exploit), computed
        without tracking gradients.

        Returns:
            A tensor of action indices on ``self.device``.
        """
        # Use the agent's own strategy/device rather than same-named globals.
        rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1
        if rate > random.random():
            action = random.randrange(self.num_actions)
            return torch.tensor([action], device=self.device)  # Explore
        with torch.no_grad():
            return policy_net(state).argmax(dim=1).to(self.device)  # Exploit

# Environment Manager

In [33]:
class CartPoleEnvManager():
    """Wrap the CartPole-v0 gym environment and expose screen-based states.

    A state is the difference between two consecutive processed screens. At
    the start of an episode, and after the episode is done, an all-zero
    ("black") tensor of the same shape is returned instead.
    """

    def __init__(self, device):
        self.device = device
        self.env = gym.make('CartPole-v0').unwrapped
        self.env.reset()
        self.current_screen = None  # last processed screen; None until first get_state()
        self.done = False  # set from env.step() in take_action()

    def reset(self):
        """Reset the environment and clear the stored screen and done flag."""
        self.env.reset()
        self.current_screen = None
        self.done = False

    def close(self):
        """Close the underlying environment."""
        self.env.close()

    def render(self, mode='human'):
        """Render the environment with ``mode`` and return the env's result."""
        return self.env.render(mode)

    def num_actions_available(self):
        """Return the number of discrete actions in the action space."""
        return self.env.action_space.n

    def take_action(self, action):
        """Step the environment with ``action`` and return the reward.

        Args:
            action: Single-element tensor holding the action index.

        Returns:
            A one-element tensor containing the step reward, on ``self.device``.
        """
        _, reward, self.done, _ = self.env.step(action.item())
        return torch.tensor([reward], device=self.device)

    def just_starting(self):
        """Return True if no screen has been captured since the last reset."""
        return self.current_screen is None

    def get_state(self):
        """Return the current state as a screen-difference tensor.

        At episode start or once the episode is done, the current screen is
        captured and an all-zero tensor of the same shape is returned.
        Otherwise the new processed screen minus the previous one is returned.
        """
        if self.just_starting() or self.done:
            self.current_screen = self.get_processed_screen()
            black_screen = torch.zeros_like(self.current_screen)
            return black_screen
        s1 = self.current_screen
        s2 = self.get_processed_screen()
        self.current_screen = s2
        return s2 - s1

    def get_screen_height(self):
        """Return the height (dimension 2) of a processed screen."""
        return self.get_processed_screen().shape[2]

    def get_screen_width(self):
        """Return the width (dimension 3) of a processed screen."""
        return self.get_processed_screen().shape[3]

    def get_processed_screen(self):
        """Render, crop, and transform the current frame into a batched tensor."""
        # Assumes the rgb_array render is (H, W, C); reorder to (C, H, W).
        screen = self.render('rgb_array').transpose((2, 0, 1))
        screen = self.crop_screen(screen)
        return self.transform_screen_data(screen)

    def crop_screen(self, screen):
        """Keep only the rows between 40% and 80% of the screen height.

        Args:
            screen: Array indexed as (channels, height, width).
        """
        screen_height = screen.shape[1]
        # Slice bounds must be integers: scale first, then truncate.
        top = int(screen_height * 0.4)
        bottom = int(screen_height * 0.8)
        return screen[:, top:bottom, :]

    def transform_screen_data(self, screen):
        """Rescale a (C, H, W) screen to [0, 1] and resize it to (1, C, 40, 90).

        Assumes pixel values lie in 0..255 (they are divided by 255).
        """
        screen = np.ascontiguousarray(screen, dtype=np.float32) / 255  # Rescale
        screen = torch.from_numpy(screen)
        resize = T.Compose([T.ToPILImage(), T.Resize((40, 90)), T.ToTensor()])
        return resize(screen).unsqueeze(0).to(self.device)  # Add batch dimension