## Imports

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append('..')

## DQN Algo Outline

In [None]:
from collections import namedtuple
import random
import time

import torch
from torch import nn
from torch import optim
import numpy as np

import gym
from src import agent, train, utils

In [None]:
# Seed every RNG this notebook draws from so that Restart & Run All is
# repeatable: torch (e.g. Tensor.random_ in dummy_transitions), NumPy
# (np.random.* calls in later cells) and the stdlib `random` module
# (random actions and benchmark-frame sampling).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

## Explore

### Collect 10 random frames

In [None]:
# Hyper-parameters consumed by initialize() and the training loop.
replay_mem_size = 10 ** 6  # first argument handed to utils.ReplayMemory
batch_size = 32            # second argument handed to utils.ReplayMemory
num_episodes = 10 ** 3     # number of episodes to train for

In [None]:
def frames_to_phi(frames):
    """Build one input tensor from a sequence of raw frames.

    Every frame is passed through ``utils.process_frame`` and the results
    are joined with ``torch.cat`` along its default (first) dimension.
    """
    return torch.cat([utils.process_frame(frame) for frame in frames])

def get_rand_phis(k, n):
    """Sample ``k`` stacked-frame inputs from ``n`` steps of random play.

    A fresh Pong environment is stepped ``n`` times with uniformly random
    actions and every returned frame is recorded.  ``k`` distinct end
    indices are then drawn from ``range(3, n)`` and each one is turned into
    a phi built from that frame plus the three frames recorded before it.

    Args:
        k: number of phis to return; ``random.sample`` raises ``ValueError``
            when it exceeds ``n - 3``.
        n: number of random environment steps to record.

    Returns:
        A list of ``k`` tensors as produced by ``frames_to_phi``.
    """
    env = gym.envs.make('Pong-v4')
    frames = []

    try:
        env.reset()

        for _ in range(n):
            action = random.randrange(env.action_space.n)
            frame, _reward, done, _info = env.step(action)
            frames.append(frame)

            if done:
                # Start a new game. The reset frame itself is not recorded,
                # so a sampled window may straddle two episodes.
                env.reset()
    finally:
        # Release the environment even if stepping fails part-way through.
        env.close()

    idxs = random.sample(range(3, n), k)
    return [frames_to_phi(frames[i - 3:i + 1]) for i in idxs]

def get_frames_avg_qval(phis, agt):
    """Return the mean of the agent's best values over ``phis`` as a float.

    The phis are stacked into a single batch, scored with
    ``agt.get_best_values`` and the resulting tensor is averaged.
    """
    batch = torch.stack(phis)
    return agt.get_best_values(batch).mean().item()

In [None]:
def initialize(replay_mem_size, batch_size):
    """Create the objects a training run needs.

    Returns:
        Tuple ``(agent, replay memory, observation history, Pong env,
        training stats)``, constructed in that order.
    """
    components = (
        agent.DQNAgent(),
        utils.ReplayMemory(replay_mem_size, batch_size),
        utils.ObsHistory(),
        gym.envs.make('Pong-v4'),
        TrainingStats(),
    )
    return components

def act_step(obs_history, agt, env):
    """Let the agent choose an action for the current phi and apply it.

    Returns:
        ``(action, next_frame, reward, done)``; the fourth value returned
        by ``env.step`` is discarded.
    """
    action = agt.act(obs_history.get_phi())
    next_frame, reward, done, _info = env.step(action)
    return action, next_frame, reward, done

def store_step(s_t, a_t, r_t, done, s_t1, obs_history, replay_mem):
    """Record the outcome of one environment step.

    The new frame ``s_t1`` is pushed onto the observation history, then
    ``(s_t, a_t, r_t, done)`` is written to the replay memory.
    """
    obs_history.store(s_t1)

    transition = (s_t, a_t, r_t, done)
    replay_mem.store(*transition)

def gradient_step(replay_mem, agt):
    """Run one optimisation step on a sampled mini-batch, if possible.

    Returns:
        The loss as a Python float, or ``None`` when the replay memory does
        not yet hold more than ``sample_size + 3`` entries.
    """
    enough_samples = replay_mem.size() > replay_mem.sample_size + 3
    if not enough_samples:
        return None

    batch = replay_mem.sample()

    agt.optimizer.zero_grad()
    batch_loss = train.mini_batch_loss(batch, agt)
    batch_loss.backward()
    agt.optimizer.step()

    return batch_loss.item()

def save_params(agt, episodes, save_path):
    """Write a training checkpoint to ``save_path`` with ``torch.save``.

    The checkpoint is a dict holding the Q-network state dict, the
    optimizer state dict and the ``episodes`` value passed in.
    """
    checkpoint = {}
    checkpoint['model_state_dict'] = agt.qnet.state_dict()
    checkpoint['optimizer_state_dict'] = agt.optimizer.state_dict()
    checkpoint['episodes'] = episodes

    torch.save(checkpoint, save_path)

def load_params(agt, load_path):
    """Restore an agent from a checkpoint with the layout ``save_params`` writes.

    Args:
        agt: agent whose ``qnet`` and ``optimizer`` are updated in place.
        load_path: path or file-like object accepted by ``torch.load``.

    Returns:
        The ``'episodes'`` value stored in the checkpoint.
    """
    checkpoint = torch.load(load_path)

    agt.qnet.load_state_dict(checkpoint['model_state_dict'])
    agt.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return checkpoint['episodes']

def reset_episode(env, obs_history):
    """Start a new episode and return ``(first_frame, False)``.

    The environment is reset and the observation history is reset with the
    frame that the environment returned.
    """
    first_frame = env.reset()
    obs_history.reset(first_frame)
    return first_frame, False

In [None]:
def frames_to_phi(frames):
    """Preprocess every frame and concatenate the results into one tensor.

    NOTE(review): this re-defines ``frames_to_phi`` from an earlier cell
    with the same behaviour; the later definition is the one left bound.
    """
    processed = list(map(utils.process_frame, frames))
    return torch.cat(processed)

def get_rand_phis(k, n):
    """Collect ``n`` frames of random Pong play and return ``k`` random phis.

    Each phi is built by ``frames_to_phi`` from four consecutive recorded
    frames ending at an index drawn without replacement from ``range(3, n)``.

    Args:
        k: how many phis to sample; must not exceed ``n - 3`` or
            ``random.sample`` raises ``ValueError``.
        n: how many random steps to take in the environment.

    Returns:
        List of ``k`` phi tensors.
    """
    env = gym.envs.make('Pong-v4')
    frames = []

    try:
        env.reset()

        for _ in range(n):
            action = random.randrange(env.action_space.n)
            frame, _reward, done, _info = env.step(action)
            frames.append(frame)

            if done:
                # Begin a new game; the frame returned by reset is not
                # recorded, so windows can span an episode boundary.
                env.reset()
    finally:
        # Always close the environment so it is not leaked.
        env.close()

    idxs = random.sample(range(3, n), k)
    return [frames_to_phi(frames[i - 3:i + 1]) for i in idxs]

In [None]:
class TrainingStats:
    """Per-episode training metrics plus a fixed batch of benchmark frames.

    ``ep_rewards``, ``steps_per_ep``, ``ep_avg_train_losses`` and
    ``benchmark_qvals`` are lists that grow by one entry per ``store`` call.
    ``benchmark_frames`` is a stacked batch of 10 phis sampled once from
    10000 steps of random play; it is the input for ``get_frames_avg_qval``.
    """

    def __init__(self):
        self.ep_rewards = []
        self.ep_avg_train_losses = []
        self.steps_per_ep = []
        self.benchmark_qvals = []
        self.benchmark_frames = torch.stack(get_rand_phis(10, 10000))

    def store(self, agt, ep_reward, ep_steps, ep_loss, episode_num):
        """Record one finished episode and print a short summary.

        Args:
            agt: agent scored on the benchmark frames.
            ep_reward: total reward collected in the episode.
            ep_steps: number of steps taken in the episode.
            ep_loss: sum of the training losses over the episode.
            episode_num: episode index, used only in the printed summary.
        """
        # An episode without steps has no meaningful average; avoid a
        # ZeroDivisionError and record 0.0 instead.
        avg_ep_loss = ep_loss / ep_steps if ep_steps else 0.0

        self.ep_rewards.append(ep_reward)
        self.steps_per_ep.append(ep_steps)
        self.ep_avg_train_losses.append(avg_ep_loss)
        # Track the benchmark Q-value alongside the other per-episode series.
        self.benchmark_qvals.append(self.get_frames_avg_qval(agt))

        print('Episode {}:'.format(episode_num))
        print('Reward: {}'.format(ep_reward))
        print('Steps: {}'.format(ep_steps))
        print('Avg loss: {:.5f}'.format(avg_ep_loss))
        print('===========================================')

    def get_frames_avg_qval(self, agt):
        """Return the agent's mean best value over the benchmark frames."""
        # Pure evaluation: no autograd graph is needed for the benchmark.
        with torch.no_grad():
            qvals = agt.get_best_values(self.benchmark_frames)

        return torch.mean(qvals).item()

In [None]:
# Number of episodes the training loop runs.
num_episodes = int(1e3)

In [None]:
# Four independent, initially empty lists for per-episode metrics.
# NOTE(review): the visible cells do not read these; TrainingStats keeps
# its own lists of the same names.
ep_rewards, ep_avg_train_losses, steps_per_ep, benchmark_qvals = (
    [] for _ in range(4)
)

In [None]:
# Build the agent, replay memory, observation history, environment and
# stats tracker, then train for `num_episodes` episodes.
agt, replay_mem, obs_history, env, train_stats = \
    initialize(replay_mem_size, batch_size)

for episode in range(num_episodes):  # loop over episodes
    s_t, done = reset_episode(env, obs_history)
    
    # Per-episode accumulators handed to train_stats.store() below.
    ep_reward = 0
    ep_train_loss = 0
    ep_steps = 0
    
    while not done:  # loop over steps in episode
        a_t, s_t1, r_t, done = act_step(obs_history, agt, env)
        store_step(s_t, a_t, r_t, done, s_t1, obs_history, replay_mem)
        
        s_t = s_t1
        
        loss_val = gradient_step(replay_mem, agt)
        
        ep_reward += r_t
        ep_steps += 1
        # gradient_step returns None until the replay memory is large enough.
        if loss_val is not None:
            ep_train_loss += loss_val 
    
    train_stats.store(agt, ep_reward, ep_steps, ep_train_loss, episode)
    
    # Checkpoint after every 10th episode (episode indices 9, 19, 29, ...).
    if episode % 10 == 9:
        checkpoint_name = 'dqn_agt_{}.pt'.format(episode)
        save_params(agt, episode, checkpoint_name)
        print('Model saved.')
    

## Manually Go Through One Iteration

In [None]:
import numpy as np
import torch

import gym
from src import agent, dqn, train, utils

#### Create Gym Environment

In [None]:
# Pong environment used by the manual walkthrough cells.
env = gym.make('Pong-v4')

In [None]:
# Watch a random policy play: 20000 steps, rendering after each one and
# resetting the environment whenever an episode ends.
init_obs = env.reset()

for i in range(20000):
    a = random.randrange(env.action_space.n)
    _, _, done, _ = env.step(a)
    # Brief pause between steps (presumably to keep the rendering watchable).
    time.sleep(.01)
    env.render()
    
    if done:
        env.reset()

In [None]:
# Reset the environment and display the shape of the first observation.
init_obs = env.reset()
init_obs.shape

In [None]:
# Take one step with action 0; keep the observation, reward and done flag.
s1, r, done, _ = env.step(0)

In [None]:
# Display the shape of the observation returned by env.step.
s1.shape

#### Initialize DQN objects

In [None]:
# Sizes passed to utils.ReplayMemory in the walkthrough.
replay_mem_size = 10 ** 6
mini_batch_size = 32

In [None]:
# Instantiate the agent, replay memory and observation history for the walkthrough.
agt = agent.DQNAgent()
replay_memory = utils.ReplayMemory(replay_mem_size, mini_batch_size)
obs_history = utils.ObsHistory()

#### Begin new episode

In [None]:
# Start a fresh episode and hand the first observation to the history.
obs_init = env.reset()  # reset environment to start new episode
obs_history.reset(obs_init)  # reset observations for new episode
done = False

#### Choose action

In [None]:
# Pick an action for the current phi, apply it, and push the new observation
# into the history.
# NOTE(review): act_step() in an earlier cell reads the phi via
# obs_history.get_phi(); here the attribute obs_history.phi is used —
# confirm which one utils.ObsHistory exposes.
phi = obs_history.phi
a = agt.act(phi)
obs, rew, done, _ = env.step(a)
obs_history.store(obs)

#### Store Transition

In [None]:
# Store the transition (phi, a, rew, phi_1, done) as a single tuple.
# NOTE(review): store_step() in an earlier cell calls
# replay_mem.store(s_t, a_t, r_t, done) with four positional arguments —
# confirm which signature utils.ReplayMemory expects.
phi_1 = obs_history.phi
replay_memory.store((phi, a, rew, phi_1, done))

#### Gradient Step

[x] Dummy transitions function for testing.

[x] Make `r` reward vector from transitions.

[x] Make $\max_{a'}Q(s', a')$ vector from transitions.

[x] Make `y` target vector from transitions.

In [None]:
from collections import namedtuple
import torch
from torch import nn
import numpy as np

# Record type for one transition: (phi, a, r, phi_1, done).
Transition = namedtuple(
    typename='Transition',
    field_names=('phi', 'a', 'r', 'phi_1', 'done'),
)

In [None]:
def dummy_transitions(n):
    """Generate ``n`` random ``Transition`` records for testing.

    Each record holds two independently filled ``(4, 84, 84)`` tensors with
    integer values drawn via ``random_(0, 255)``, an action from
    ``randint(0, 6)``, a reward from ``randint(0, 2)`` and a random boolean
    ``done`` flag.
    """
    def _random_phi():
        return torch.empty(4, 84, 84).random_(0, 255)

    batch = []
    for _ in range(n):
        # Draw order (phi, phi_1, a, r, done) is kept so seeded runs match.
        phi = _random_phi()
        phi_1 = _random_phi()
        a = np.random.randint(0, 6)
        r = np.random.randint(0, 2)
        done = bool(np.random.randint(0, 2))
        batch.append(Transition(phi, a, r, phi_1, done))

    return batch

In [None]:
# Build three dummy transitions and unzip them into one tuple per field.
# Note: this rebinds the notebook-level names phi, a, r, phi_1 and done.
transitions = dummy_transitions(3)
    
phi, a, r, phi_1, done = zip(*transitions)

In [None]:
def make_y(transitions, agt, gamma=.99):
    """Build the vector of Q-learning targets for a batch of transitions.

    For a terminal transition the target is the reward ``r``.  Otherwise it
    is ``r + gamma * max_a' Q(phi_1, a')``, where the max comes from
    ``agt.get_best_values`` evaluated on the *next* state ``phi_1``.

    Args:
        transitions: iterable of records exposing ``r``, ``phi_1``, ``done``.
        agt: object exposing ``get_best_values(batch)``.
        gamma: discount factor applied to the bootstrap term.

    Returns:
        1-D tensor with one target per transition.
    """
    targets = []

    # Targets are constants for the loss, so no autograd graph is needed.
    with torch.no_grad():
        for tr in transitions:
            if tr.done:
                targets.append(tr.r)
            else:
                # Bootstrap from the successor state, not the current one.
                next_batch = tr.phi_1.unsqueeze(0)
                best_next = agt.get_best_values(next_batch).item()
                targets.append(tr.r + gamma * best_next)

    return torch.Tensor(targets)

In [None]:
# Target vector for the dummy transitions.
y = make_y(transitions, agt)

In [None]:
def get_max_vals(transitions, agt):
    """Return ``agt.get_best_values`` for the stacked ``phi`` of each transition."""
    batch = torch.stack([tr.phi for tr in transitions])
    return agt.get_best_values(batch)

In [None]:
# Agent's best values for the current-state phis of the dummy transitions.
qmax = get_max_vals(transitions, agt)

In [None]:
# Mean-squared-error criterion with mean reduction.
loss = nn.MSELoss(reduction='mean')

In [None]:
# Evaluate the criterion on the targets `y` and the values `qmax`.
loss(y, qmax)

In [None]:
def mini_batch_loss(transitions, agt):
    """Mean-squared error between the targets and the agent's best values.

    NOTE(review): the comparison uses the best value for ``phi`` from
    ``get_max_vals`` rather than the value of the action actually taken —
    confirm this is intended.
    """
    targets = make_y(transitions, agt)
    best_vals = get_max_vals(transitions, agt)

    return nn.MSELoss(reduction='mean')(targets, best_vals)

In [None]:
# Loss for the dummy mini-batch.
mb_loss = mini_batch_loss(transitions, agt)

#### Optimizer and gradient step

In [None]:
from torch import optim

# Stand-alone RMSprop optimizer over the Q-network parameters for this walkthrough.
# NOTE(review): gradient_step() in an earlier cell uses agt.optimizer instead.
optimizer = optim.RMSprop(agt.qnet.parameters())

In [None]:
# Clear any gradients left on the parameters before the backward pass.
optimizer.zero_grad()

In [None]:
# Recompute the mini-batch loss and display it.
mb_loss = mini_batch_loss(transitions, agt)
mb_loss

In [None]:
# Back-propagate the mini-batch loss.
mb_loss.backward()

In [None]:
# Apply one optimizer update using the gradients just computed.
optimizer.step()

### Test random agent on Pong

In [None]:
# Reset the environment and keep the initial observation.
s0 = env.reset()

In [None]:
# Collect observations from 1000 steps of uniformly random play,
# starting with the observation returned by reset.
obss = []
obss.append(env.reset())

for _ in range(1000):
    a = np.random.choice(env.action_space.n)
    obs, rew, done, _ = env.step(a)
    obss.append(obs)
    
    if done:
        # NOTE(review): the post-reset observation is assigned to `obs`
        # but never appended to `obss`.
        obs = env.reset()

In [None]:
# Display the number of discrete actions in the environment.
env.action_space.n