In [None]:
%matplotlib inline

In [None]:
# Initialize hyperparameters to the numbers in Mnih et al. (2013),
# "Playing Atari with Deep Reinforcement Learning".
import itertools

# Replay memory capacity N: the paper keeps the one million most recent
# frames (10 million is the total number of training frames, not the capacity).
N = int(1e6)

# Epsilon-greedy schedule: annealed linearly from 1.0 to 0.1 over the first
# million frames, then held fixed at 0.1.
EPSILON_START = 1.0
EPSILON_END = 0.1
EPSILON_ANNEAL_FRAMES = int(1e6)


def epsilon_at(frame):
    """Return the exploration rate epsilon for a 0-based frame index.

    Linear from EPSILON_START at frame 0 down to EPSILON_END at
    EPSILON_ANNEAL_FRAMES, and constant at EPSILON_END afterwards.
    """
    fraction = min(frame / EPSILON_ANNEAL_FRAMES, 1.0)
    # max() clamps float round-off so the floor is exactly EPSILON_END
    return max(EPSILON_END,
               EPSILON_START - fraction * (EPSILON_START - EPSILON_END))


# Infinite generator of per-frame epsilons: it keeps yielding EPSILON_END once
# annealing is done instead of raising StopIteration after 1M frames.
epsilon_generator = (epsilon_at(i) for i in itertools.count())

# mini-batch size = 32
MINI_BATCH_SIZE = 32
# optimizer: RMSProp

In [None]:
# Deep Q-learning with experience replay (Algorithm 1 of Mnih et al., 2013).
# NOTE: this cell is still pseudocode -- M (episodes), T (steps per episode),
# K (mini-batch size), lam (discount factor), uniform_random, sample,
# action_space, argmax and make_q_net are not defined yet.

# initialize replay memory to capacity N, sampling mini-batches of size K
replay_memory = ReplayMemory(N, K)
# initialize action-value function with random weights
Q = make_q_net()

for episode in range(M):  # loop over episodes
    # initialise sequence s_1 = {x_1} and preprocessed sequence phi_1 = phi(s_1)
    tau = [env.reset()]
    phi = [preprocess(tau)]

    for t in range(T):  # loop over steps in episode
        # anneal the exploration rate (schedule from the hyperparameter cell)
        epsilon = next(epsilon_generator)

        # with probability epsilon select a random action a_t
        delta = uniform_random(0, 1)
        if delta < epsilon:
            a_t = sample(action_space)
        # otherwise select a_t = argmax_a Q(phi_t, a)
        else:
            a_t = argmax(Q(phi[t]))

        # execute action; gym returns (observation, reward, done, info)
        x_tp1, r_t, done, _ = env.step(a_t)

        # s_{t+1} = s_t, a_t, x_{t+1}
        tau.append((a_t, x_tp1))
        # prepare next input to the Q network
        phi.append(preprocess(tau))

        # store transition (phi_t, a_t, r_t, phi_{t+1}) plus the terminal flag
        replay_memory.add((phi[t], a_t, r_t, phi[t + 1], done))

        # only learn once the memory holds enough transitions for a mini-batch
        if len(replay_memory.transitions) >= K:
            mini_batch = replay_memory.sample(K)

            grad = 0
            for phi_j, a_j, r_j, phi_jp1, done_j in mini_batch:
                # TD target: y_j = r_j if phi_{j+1} is terminal, otherwise
                # r_j + lam * max_a' Q(phi_{j+1}, a') -- the max *value*,
                # not the index of the best action
                if done_j:
                    y_j = r_j
                else:
                    y_j = r_j + lam * max(Q(phi_jp1))
                # keyword arguments match the signature get_grad(Q, y, a, phi)
                grad += get_grad(Q=Q, y=y_j, a=a_j, phi=phi_j)

            grad = grad / K
            Q.gradient_descent(grad)

        if done:
            break

### Make DQN agent

In [None]:
class DQNAgent:
    """Skeleton DQN agent for Pong (work in progress).

    Owns a replay memory, a Q-network, an exploration rate and its own
    ``PongNoFrameskip-v4`` environment. ``act``, ``reset``, ``_train`` and
    ``_update_epsilon`` are unimplemented stubs that currently return None.
    """

    def __init__(self, memory_capacity=int(1e6), batch_size=32, epsilon=.9):
        """Create the agent.

        Args:
            memory_capacity: maximum number of transitions kept in the replay
                memory.
            batch_size: default mini-batch size drawn from the replay memory.
            epsilon: initial exploration rate. NOTE(review): the
                hyperparameter cell's schedule starts at 1.0 -- confirm .9 is
                intended.
        """
        self.s_seq = []    # presumably the raw history s_t (paper notation)
        self.phi_seq = []  # presumably the preprocessed history phi_t
        # ReplayMemory needs a capacity and a sample size; the previous
        # argument-less call raised TypeError.
        self.replay_mem = ReplayMemory(memory_capacity, batch_size)

        self.q_net = QNet()
        self.epsilon = epsilon

        self.env = gym.envs.make('PongNoFrameskip-v4')

    def act(self):
        """Select an action. TODO: not implemented; returns None."""
        pass

    def reset(self):
        """Reset the agent's episode state. TODO: not implemented; returns None."""
        pass

    def _train(self):
        """Run one training update. TODO: not implemented; returns None."""
        pass

    def _update_epsilon(self):
        """Anneal the exploration rate. TODO: not implemented; returns None."""
        pass

In [None]:
import numpy as np
from collections import deque

class ReplayMemory:
    """Fixed-capacity experience-replay buffer.

    Holds at most ``N`` transitions. Once it is full, each new transition
    evicts the oldest one. Mini-batches are drawn uniformly at random without
    replacement.
    """

    def __init__(self, N=int(1e6), sample_size=32):
        """Create an empty memory.

        Args:
            N: capacity (maximum number of stored transitions).
            sample_size: default mini-batch size returned by ``sample()``.
        """
        self.N = N
        self.sample_size = sample_size
        # maxlen makes the deque drop its oldest item automatically on append
        self.transitions = deque(maxlen=N)

    def sample(self, k=None):
        """Return ``k`` distinct stored transitions chosen uniformly at random.

        Args:
            k: number of transitions to draw; defaults to ``self.sample_size``.

        Returns:
            A list of ``k`` transitions.

        Raises:
            ValueError: if ``k`` exceeds the number of stored transitions.
        """
        # np.random.choice needs a 1-D array and cannot draw tuples of
        # arbitrary objects; random.sample works on any sequence.
        import random

        if k is None:
            k = self.sample_size
        return random.sample(self.transitions, k)

    def add(self, e):
        """Store transition ``e``, evicting the oldest one when at capacity."""
        self.transitions.append(e)

### Test random agent on Pong

In [None]:
import gym
import numpy as np
# Pong environment without built-in frame skipping (gym.make is the same
# factory as gym.envs.make).
env = gym.make('PongNoFrameskip-v4')

In [None]:
# Roll out a uniformly random policy and keep every observation it sees.
N_RANDOM_STEPS = 1000

obss = [env.reset()]

for _ in range(N_RANDOM_STEPS):
    a = np.random.choice(env.action_space.n)
    obs, rew, done, _ = env.step(a)
    obss.append(obs)

    if done:
        # Record the first frame of the new episode as well, so `obss`
        # holds every observation returned by the environment.
        obss.append(env.reset())

### Preprocess frames

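1. Convert each frame to grayscale.
2. Downsample each frame to 110 × 84 pixels (height × width).
3. Crop the central 84 × 84 square.
4. Stack 4 consecutive processed frames into one network input.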

In [None]:
from torchvision import transforms

# Built once and reused: the transform chain is identical for every frame.
_FRAME_PIPELINE = transforms.Compose([
    transforms.ToPILImage(),       # numpy ndarray (H x W x C) -> PIL image
    transforms.Grayscale(),        # RGB -> single-channel grayscale
    transforms.Resize((110, 84)),  # downsample to 110 x 84 (height x width)
    transforms.CenterCrop(84),     # keep the central 84 x 84 region
    transforms.ToTensor(),         # PIL image -> float tensor (1, 84, 84) in [0, 1]
])


def process_frame(frame):
    """Turn a game frame into a small, square, grayscale image tensor.

    Args:
        frame: raw RGB frame as returned by the environment.

    Returns:
        Float tensor of shape (1, 84, 84).
    """
    return _FRAME_PIPELINE(frame)


def make_input(frames, n_frames=4):
    """Stack preprocessed frames into a single Q-network input.

    Args:
        frames: sequence of raw game frames; must contain exactly
            ``n_frames`` items.
        n_frames: number of frames stacked along the channel axis
            (4 by default).

    Returns:
        The ``process_frame`` outputs concatenated along dim 0; with
        ``process_frame``'s (1, 84, 84) tensors this is (n_frames, 84, 84).
    """
    # torch is otherwise imported only in a later cell
    import torch

    assert len(frames) == n_frames, (
        f"expected {n_frames} frames, got {len(frames)}")

    return torch.cat([process_frame(fr) for fr in frames])

### Build the Q network

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

class QNet(nn.Module):
    """Convolutional Q-network (architecture of Mnih et al., 2013).

    Maps a stack of ``in_frames`` preprocessed 84 x 84 frames to one Q-value
    per action.

    Args:
        n_actions: size of the output layer. When None (the default), it is
            read from the global ``env`` created in an earlier cell, so
            ``QNet()`` keeps working as before.
        in_frames: number of stacked input frames (input channels).
    """

    def __init__(self, n_actions=None, in_frames=4):
        super().__init__()
        if n_actions is None:
            n_actions = env.action_space.n
        self.conv1 = nn.Conv2d(in_frames, 16,
                               kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(16, 32,
                               kernel_size=4, stride=2)
        # An 84 x 84 input shrinks to 20 x 20 after conv1 and 9 x 9 after
        # conv2, so the flattened feature map has 32 * 9 * 9 = 2592 entries.
        self.fc1 = nn.Linear(32 * 9 * 9, 256)
        self.fc2 = nn.Linear(256, n_actions)

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions).

        Expects ``x`` of shape (batch, in_frames, 84, 84); other spatial
        sizes do not match ``fc1``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.flatten(start_dim=1)  # (batch, 32, 9, 9) -> (batch, 2592)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

### Test Q network on the first 4 frames of Pong

In [None]:
# Build a freshly initialised Q-network and stack the first four frames
# collected by the random agent with the `make_input` helper.
q_net = QNet()

x = make_input(obss[0:4])
x = torch.unsqueeze(x, 0)  # add a batch dimension: (1, 4, 84, 84)

In [None]:
x.shape  # same torch.Size as x.size()

In [None]:
# Inspection-only forward pass: no_grad skips building the autograd graph.
with torch.no_grad():
    y = q_net(x)
y

In [None]:
# Index of the highest Q-value (greedy action) for each batch element.
y.argmax(dim=1)

In [None]:
def preprocess(tau):
    """Map the history ``tau`` to the Q-network input phi(tau).

    TODO: not implemented yet -- always returns None.
    """
    return None

In [None]:
def get_grad(Q, y, a, phi):
    """Gradient of the squared TD error for one transition.

    Args:
        Q: the action-value function.
        y: TD target for the transition.
        a: action taken.
        phi: preprocessed state the action was taken in.

    TODO: not implemented yet -- always returns None.
    """
    return None