## DQN Example (Cartpole)

Adapted from Algorithm 4.1 in "Foundations of Deep Reinforcement Learning" (without experience replay).

In [None]:
import sys
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

### Policy Calculation

In [None]:
def epsilon_greedy_policy(env, q_network, epsilon, state):
    """Pick an action epsilon-greedily with respect to ``q_network``.

    A uniform sample is drawn first; if it is not greater than ``epsilon`` a
    random action index in ``range(env.action_space.n)`` is returned,
    otherwise the index of the largest Q-value predicted for ``state``.

    Returns:
        A ``(action, None)`` pair. The second slot is always ``None``; the
        trajectory sampler in this file unpacks it as an action log-prob.
    """
    act_greedily = np.random.rand() > epsilon
    if not act_greedily:
        # Exploration branch: uniform random action index.
        return np.random.choice(env.action_space.n), None

    # Exploitation branch: no gradients are needed for action selection.
    with torch.no_grad():
        state_tensor = torch.tensor(state, dtype=torch.float32)
        greedy_action = np.argmax(q_network(state_tensor).numpy())
    return greedy_action, None

### Value Calculation

In [None]:
class QNetwork(nn.Module):
    """Single-hidden-layer MLP mapping a state to one Q-value per action.

    Architecture: Linear(state_dim, h_dim) -> ReLU -> Linear(h_dim, action_dim).
    """

    def __init__(self, state_dim, h_dim, action_dim):
        super().__init__()
        layers = [
            nn.Linear(state_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, action_dim),
        ]
        self.q_nn = nn.Sequential(*layers)

    def forward(self, state):
        """Return the network output for ``state`` (last dim is ``action_dim``)."""
        return self.q_nn(state)
    
def dqn_cumulative_discounted_reward(trajectory, q_network, gamma=.99):
    """Append one-step bootstrapped Q-learning targets to a trajectory.

    For every transition the target is
    ``r + gamma * (1 - d) * max_a' Q(s_next, a')``, with the Q-values taken
    from ``q_network`` under ``torch.no_grad()``.

    Args:
        trajectory: 7-item sequence ``(s, a, _, r, s_next, _, d)``.
        q_network: callable mapping a float32 tensor built from ``s_next`` to
            a 2-D tensor of Q-values (the max is taken over dim 1).
        gamma: discount factor; ``None`` falls back to 0.99.

    Returns:
        ``(s, a, None, r, s_next, None, d, cdr)`` where ``cdr`` is the list of
        targets, in the same order as the transitions.
    """
    states, actions, _, rewards, next_states, _, dones = trajectory
    discount = 0.99 if gamma is None else gamma

    with torch.no_grad():
        next_q = q_network(torch.tensor(next_states, dtype=torch.float32))
        # Largest Q-value per row, as plain Python floats.
        best_next_q = next_q.max(dim=1).values.tolist()

    cdr = [
        reward + discount * (1 - done) * q_next
        for reward, done, q_next in zip(rewards, dones, best_next_q)
    ]
    return states, actions, None, rewards, next_states, None, dones, cdr

### Trajectory Generation

In [None]:
def sample_trajectory(env, policy_fn, max_steps=200, break_when_done=True):
    """Roll out one episode and return it as SARSA-style transition arrays.

    The environment is reset, then ``policy_fn(state)`` (returning
    ``(action, action_logprob_or_None)``) is queried for at most ``max_steps``
    states. Consecutive recorded steps are paired so that every transition
    carries its successor state and successor action; consequently the
    result holds one transition fewer than the number of recorded steps.

    Once the episode has ended (``terminated`` or ``truncated``) and
    ``break_when_done`` is true, the final state is recorded together with
    the policy's action for it, but the environment is *not* stepped again:
    stepping a finished gym environment is undefined, and only that entry's
    state and action are ever read (as ``s_next`` / ``a_next``).

    Args:
        env: environment whose ``reset()`` returns ``(obs, info)`` and whose
            ``step(a)`` returns ``(obs, reward, terminated, truncated, info)``.
        policy_fn: callable ``state -> (action, action_logprob)``.
        max_steps: maximum number of policy queries.
        break_when_done: stop after the episode ends; if false the rollout
            keeps stepping until ``max_steps``, as before.

    Returns:
        ``[s, a, a_logprob, r, s_next, a_next, d]`` where every entry is a
        NumPy array over transitions, except ``a_logprob`` which is ``None``
        when the policy returns no log-probs and a stacked tensor otherwise.
        ``d`` is 1 only for transitions that terminated the episode.

    Raises:
        ValueError: if fewer than two steps were recorded, so no transition
            can be formed (e.g. ``max_steps < 2``).
    """
    s, _ = env.reset()

    episode_over = False
    gym_trajectory = []
    for _ in range(max_steps):
        a, a_logprob = policy_fn(s)
        if episode_over and break_when_done:
            # Only (s, a) of this entry is used, as the successor of the last
            # real step; reward and done flag are placeholders never read.
            gym_trajectory.append([list(s), a, a_logprob, 0.0, 0])
            break
        s_next, r, terminated, truncated, _ = env.step(a)
        # d marks true termination only, so truncated episodes still bootstrap.
        gym_trajectory.append([list(s), a, a_logprob, r, int(terminated)])
        if terminated or truncated:
            episode_over = True
        s = s_next

    if len(gym_trajectory) < 2:
        raise ValueError(
            "sample_trajectory needs at least two recorded steps to form a "
            "transition; got %d" % len(gym_trajectory)
        )

    # Pair each step with the following one to obtain (s, a, r, s', a', d).
    sarsa_trajectory = [
        [s_t, a_t, logprob_t, r_t, s_tp1, a_tp1, d_t]
        for (s_t, a_t, logprob_t, r_t, d_t), (s_tp1, a_tp1, _, _, _)
        in zip(gym_trajectory, gym_trajectory[1:])
    ]

    s, a, a_logprob, r, s_next, a_next, d = list(zip(*sarsa_trajectory))
    trajectory = [
        np.array(s), np.array(a),
        None if a_logprob[0] is None else torch.stack(a_logprob),
        np.array(r),
        np.array(s_next), np.array(a_next),
        np.array(d)
    ]
    return trajectory

### Training

In [None]:
class DQN:
    """Minimal DQN agent: a Q-network and the Adam optimizer that trains it.

    Args:
        state_dim: size of the state vector fed to the Q-network.
        action_dim: number of discrete actions (Q-network outputs).
        h_dim: hidden layer width of the Q-network.
        lr_alpha: Adam learning rate.
    """

    def __init__(self, state_dim, action_dim, h_dim=64, lr_alpha=.01):
        self.q_network = QNetwork(state_dim, h_dim, action_dim)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr_alpha)

    def train(self, batch):
        """Take one gradient step on a batch of transitions.

        Args:
            batch: 8-item sequence ``(s, a, _, r, s_next, _, d, cdr)``; only
                the states ``s``, the taken actions ``a`` (integer indices)
                and the targets ``cdr`` are used here.

        Returns:
            The mean-squared-error loss (Python float) between ``Q(s, a)``
            and ``cdr`` before the parameter update.
        """
        s, a, _, _, _, _, _, cdr = batch

        q_all = self.q_network(
            torch.tensor(s, dtype=torch.float32)
        )
        # The target was produced by the action actually taken, so the
        # prediction must be Q(s, a) for that action -- not max_a Q(s, a),
        # which would mis-credit rewards earned by exploratory actions.
        actions = torch.as_tensor(a, dtype=torch.int64)
        q_preds = q_all.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        targets = torch.tensor(cdr, dtype=torch.float32)
        loss = F.mse_loss(q_preds, targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

In [None]:
# Agent and environment for CartPole (4 state values, 2 actions).
dqn = DQN(state_dim=4, action_dim=2)
env = gym.make("CartPole-v0")

# Exploration starts fully random and decays multiplicatively per episode.
epsilon, epsilon_decay = 1., .999
# Number of episodes so far whose trajectory had 199 transitions.
count = 0


def policy_fn(state):
    """Epsilon-greedy action for ``state`` using the current global epsilon."""
    return epsilon_greedy_policy(env, dqn.q_network, epsilon, state)


def value_fn(trajectory):
    """Attach bootstrapped Q-learning targets to ``trajectory``."""
    return dqn_cumulative_discounted_reward(trajectory, dqn.q_network)


for episode in range(1, 1001):
    # One episode = one on-policy rollout followed by one gradient step.
    trajectory = sample_trajectory(env, policy_fn)
    trajectory = value_fn(trajectory)
    loss = dqn.train(trajectory)

    trajectory_len = len(trajectory[0])
    print("%5d %5d %5d %7.2f,  epsilon: %5.3f" % (episode, trajectory_len, count, loss, epsilon))

    if trajectory_len == 199:
        count += 1
        # Stop once more than five such episodes have been seen.
        if count > 5:
            break
    epsilon *= epsilon_decay

### Cartpole Visualization

In [None]:
%env HV_DOC_HTML=true # required for colab
import holoviews as hv

def cartpole_visualize(states, step=4):
    """Build a holoviews HoloMap animating the pole for a sequence of states.

    Every ``step``-th row of ``states`` is drawn. Column 0 is used as the
    cart position and column 2 as the pole angle; the pole is drawn as a
    unit-length segment from ``(position, 0)`` to
    ``(position + sin(angle), cos(angle))`` above a fixed baseline.

    Returns:
        ``hv.HoloMap`` keyed by the original row index (``i * step``).
    """
    hv.extension("bokeh") # required for colab
    sampled = states[::step]

    def cartpole2xy(rows):
        # One [(x0, y0), (x1, y1)] pole segment per sampled state.
        return [
            [(position, 0), (position + np.sin(angle), np.cos(angle))]
            for position, angle in zip(rows[:, 0], rows[:, 2])
        ]

    def cartpole_draw(curve):
        # Overlay the pole segment on the track and fix the axes/size.
        baseline = [(-2.4, 0), (2.4, 0)]
        overlay = hv.Overlay([hv.Curve(baseline), hv.Curve(curve)])
        overlay = overlay.redim(
            x=hv.Dimension("x", range=(-2.4, 2.4)),
            y=hv.Dimension("y", range=(-.05, 1.05))
        )
        return overlay.opts(height=150, width=400)

    charts = [cartpole_draw(segment) for segment in cartpole2xy(sampled)]
    frames = {index * step: chart for index, chart in enumerate(charts)}
    return hv.HoloMap(frames)

# Animate the states of the last trajectory sampled during training.
cartpole_visualize(trajectory[0])