# DQN-based Algorithm Tutorial

This notebook builds a pipeline to train and test DQN-based algorithms. Several extensions of DQN are available, and you can enable any combination of the options below. 

- PER (Prioritized Experience Replay)
- Double DQN
- Dueling DQN
- Multi-step learning
- Distributional RL
- Noisy Networks
- Rainbow (Apply all of the above)

The agent has been trained and tested on Gymnasium Atari environments, where the input is raw pixels. 

## Tasks
- ALE/Pong-v5
- ALE/Breakout-v5
- ALE/Enduro-v5
- ALE/DemonAttack-v5

## Environment

In this notebook, we use Gymnasium Atari environments that provide graphical observations. 

The agent takes an F x H x W state (with grayscale applied) as input and returns $ Q(s, a) $ as output. (F: # of frames, H: height, W: width)
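As a rough illustration of the F x H x W layout, the sketch below stacks the most recent frames with a `deque`. It is a pure-Python stand-in rather than the `FrameStack` wrapper itself; `make_frame`, `F_FRAMES`, and the tiny 2 x 3 frame size are hypothetical names and values chosen to keep the example small, and filling the stack with copies of the first frame on reset is an assumption.

```python
from collections import deque

F_FRAMES = 4   # assumed number of stacked frames (num_stack=4 in this notebook)
H, W = 2, 3    # tiny hypothetical frame size; the notebook uses 84 x 84

def make_frame(value: float):
  """Build an H x W grayscale frame filled with one value."""
  return [[value] * W for _ in range(H)]

# on reset, the stack is assumed to be filled with copies of the first frame
stack = deque([make_frame(0.0)] * F_FRAMES, maxlen=F_FRAMES)

# each step pushes the newest frame and drops the oldest one
for t in range(1, 3):
  stack.append(make_frame(float(t)))

state = list(stack)  # nested list with shape F x H x W
shape = (len(state), len(state[0]), len(state[0][0]))
```

After two steps, the last slot of `state` holds the newest frame and the first slots still hold the reset frame.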

In [None]:
import gymnasium as gym
import random
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
import json

from tqdm import tqdm
from datetime import datetime
from collections import namedtuple, deque
from itertools import count
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
from gymnasium.wrappers.frame_stack import FrameStack

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# define environment (editable)
envname = "ALE/Pong-v5"
#envname = "ALE/Breakout-v5"
#envname = "ALE/Enduro-v5"
#envname = "ALE/DemonAttack-v5"

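def preprocess_env(env: gym.Env):
  """
  Preprocess an Atari environment: 
  84 x 84 grayscale frames scaled to [0, 1], stacked 4 at a time. 
  """
  env = AtariPreprocessing(
    env,
    frame_skip=1,  # no extra skipping: ALE v5 environments already skip frames
    screen_size=84,
    terminal_on_life_loss=False,
    grayscale_obs=True,
    grayscale_newaxis=False,
    scale_obs=True
  )
  # stack the 4 most recent frames so the agent can observe motion
  env = FrameStack(env, num_stack=4)
  return env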

env = gym.make(envname)
env = preprocess_env(env)

# plot settings
is_ipython = "inline" in matplotlib.get_backend()
if is_ipython:
  from IPython import display

# device setting (cpu or cuda!)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device", device)

## Choose Training Options

In [None]:
Option = namedtuple("Option", (
  "PER", 
  "DDQN", 
  "DUELING", 
  "MULTISTEP", 
  "DISTRIBUTIONAL", 
  "NOISY"
))

# editable
op = Option(
  PER=True, 
  DDQN=True, 
  DUELING=True, 
  MULTISTEP=True, 
  DISTRIBUTIONAL=True, 
  NOISY=True
)

## Hyperparameters

- `NUM_EPISODES`: number of training episodes. 
- `TEST_EPISODES`: number of test episodes run at each evaluation during training. 
- `TEST_FREQ`: testing frequency. 
- `MEMORY_CAPACITY`: maximum capacity of replay memory. 
- `BATCH_SIZE`: number of transitions sampled for each training step. 
- `GAMMA`: discount factor used when calculating the estimated return. 
- `EPS_START`: starting epsilon value in e-greedy policy. 
- `EPS_END`: final epsilon value in e-greedy policy. 
- `EXPLORE_FRAME`: frame number at which epsilon starts to decay. 
- `GREEDY_FRAME`: frame number at which epsilon stops decaying. 
- `POLICY_UPDATE_FREQ`: how often the policy net is trained. 
- `TARGET_UPDATE_FREQ`: how often the target net is synchronized with the policy net. 
- `LR`: learning rate. 
- `PER_EPS`: very small number that keeps every priority positive. 
- `PER_ALPHA`: between 0 (uniform sampling) and 1 (fully priority-based). 
- `PER_BETA_INIT`: initial beta, between 0 (no correction) and 1 (full correction toward uniform sampling). 
- `PER_BETA_INC`: increment added to beta each time a batch is sampled. 
- `MULTISTEP`: number of steps n used for multi-step learning. 
- `NOISE_SIGMA`: initial sigma for the noisy network. 
- `DIST_N`: number of quantiles in distributional RL (QR-DQN). 
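To make the roles of `EPS_START`, `EPS_END`, `EXPLORE_FRAME`, and `GREEDY_FRAME` concrete, here is a minimal sketch of the linear epsilon schedule written as a closed-form function. `epsilon_at` is a hypothetical helper, not part of the notebook; the training loop decrements epsilon step by step instead, so the two may differ by one decay step at the boundaries.

```python
# values copied from the hyperparameter cell of this notebook
EPS_START, EPS_END = 0.9, 0.1
EXPLORE_FRAME, GREEDY_FRAME = 1000, 100000

def epsilon_at(step: int) -> float:
  """Hypothetical closed-form version of the linear epsilon schedule."""
  # fraction of the decay window that has elapsed, clamped to [0, 1]
  progress = (step - EXPLORE_FRAME) / (GREEDY_FRAME - EXPLORE_FRAME)
  progress = min(1.0, max(0.0, progress))
  return EPS_START - (EPS_START - EPS_END) * progress
```

Epsilon stays at `EPS_START` until `EXPLORE_FRAME`, decays linearly, and stays at `EPS_END` from `GREEDY_FRAME` onward.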

In [None]:
HyperParameter = namedtuple("HyperParameter", (
  "NUM_EPISODES",
  "TEST_EPISODES",
  "TEST_FREQ",
  "MEMORY_CAPACITY",
  "BATCH_SIZE",
  "GAMMA",
  "EPS_START",
  "EPS_END",
  "EXPLORE_FRAME",
  "GREEDY_FRAME",
  "POLICY_UPDATE_FREQ",
  "TARGET_UPDATE_FREQ",
  "LR",
  "PER_EPS",
  "PER_ALPHA",
  "PER_BETA_INIT",
  "PER_BETA_INC", 
  "MULTISTEP", 
  "NOISE_SIGMA", 
  "DIST_N"
))

# editable
hp = HyperParameter(
  NUM_EPISODES=1500,
  TEST_EPISODES=5, 
  TEST_FREQ=100,
  MEMORY_CAPACITY=50000,
  BATCH_SIZE=64,
  GAMMA=0.99,
  EPS_START=0.9,
  EPS_END=0.1,
  EXPLORE_FRAME=1000,
  GREEDY_FRAME=100000,
  POLICY_UPDATE_FREQ=4,
  TARGET_UPDATE_FREQ=1000,
  LR=0.0001, 
  PER_EPS=0.00000001,
  PER_ALPHA=0.6, 
  PER_BETA_INIT=0.4, 
  PER_BETA_INC=0.0001, 
  MULTISTEP=5, 
  NOISE_SIGMA=0.4, 
  DIST_N=32
)

## Replay Memory

Replay memory stores the transitions that the agent observed. Stored transitions are used later to train the agent. 

A transition consists of (`state`, `action`, `reward`, `next_state`, `done`). 
- `state`: current state, in which the agent acts. 
- `action`: action taken in the current state. 
- `reward`: reward received for the action (the immediate reward, not the return). 
- `next_state`: state after the action. zero state if terminated. 
- `done`: indicates whether the episode ended after the action. 

In this notebook, we support two strategies for sampling transitions. Both sample at random from the memory, which breaks the correlation between consecutive transitions. 

`UniformReplayMemory`: uniform sampling. All transitions have the same priority, so every transition is equally likely to be sampled. 

In [None]:
Transition = namedtuple("Transition", ("state", "action", "reward", "next_state", "done"))

class UniformReplayMemory(object):
  """
  Replay Memory that uses uniform sampling strategy. 
  """

  def __init__(self, capacity: int):
    """
    Init memory which removes oldest data when overflows. 
    """
    self.memory = deque([], maxlen=capacity)

  def push(self, *args):
    """
    Push new transition data into memory. 
    - Transition: (state, action, reward, next_state, done)
    """
    
    self.memory.append(Transition(*args))

  def sample(self, batch_size: int):
    """
    Return random samples according to uniform sampling. 
    Return form is Transition of samples. 
    ex) return.state is list of states. 
    """
    sample = random.sample(self.memory, batch_size)
    return Transition(*zip(*sample))

  def __len__(self):
    """
    Return len of memory. 
    """
    return len(self.memory)


`PriorityReplayMemory`: prioritized experience replay. Each transition has its own priority, and transitions are sampled with probability proportional to it. 
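Priority-weighted sampling can be sketched with prefix sums: draw a value in [0, total priority] and find the first transition whose cumulative priority reaches it. The sketch below uses `bisect` on a plain list instead of a sum tree, so it illustrates the lookup rule only and not the tree's logarithmic-time updates; `find_index`, `stratified_sample`, and the example priorities are hypothetical.

```python
import random
from bisect import bisect_left
from itertools import accumulate

def find_index(cumulative, value):
  """Return the first index whose cumulative priority is >= value."""
  return bisect_left(cumulative, value)

def stratified_sample(priorities, batch_size, rng=random):
  """Draw one index per equal-width segment of the total priority."""
  cumulative = list(accumulate(priorities))
  segment = cumulative[-1] / batch_size
  picks = []
  for i in range(batch_size):
    value = rng.uniform(segment * i, segment * (i + 1))
    # guard against float rounding pushing value past the last entry
    picks.append(min(find_index(cumulative, value), len(priorities) - 1))
  return picks

priorities = [1.0, 2.0, 3.0, 4.0]          # hypothetical priorities, total = 10
cumulative = list(accumulate(priorities))  # [1.0, 3.0, 6.0, 10.0]
```

Splitting the range into one segment per sample spreads the batch over low and high priorities instead of drawing every value independently.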

The sampling probability of transition $ i $ is calculated from its priority $ p_i $ as below. 
$$ P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha} $$

Prioritized sampling changes the data distribution, so an Importance Sampling (IS) weight is calculated to correct the resulting bias. $ N $ is the number of stored transitions. 
$$ w_i = \left( \frac{1}{N} \frac{1}{P(i)} \right)^\beta $$

After the loss is calculated, the priority is updated from the absolute TD error $ \delta_i $. 
$$ p_i \leftarrow |\delta_i| + \epsilon $$

Each sample's gradient is scaled by its IS weight when it is accumulated into the update $ \Delta $. 
$$ \Delta \leftarrow \Delta + w_i \nabla_\theta L_{\theta,i} $$
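A small worked example may help with the formulas above. The sketch computes $ P(i) $ and the normalized IS weights for four hypothetical TD errors; `td_errors` is made up, and the weights are normalized by their maximum over all four transitions here, whereas the notebook normalizes within each sampled batch.

```python
ALPHA, BETA = 0.6, 0.4   # same values as PER_ALPHA and PER_BETA_INIT in this notebook
EPS = 1e-8               # same value as PER_EPS

td_errors = [0.5, 1.0, 2.0, 4.0]  # hypothetical absolute TD errors

# p_i^alpha, the quantity stored for each transition
scaled = [(e + EPS) ** ALPHA for e in td_errors]
total = sum(scaled)

# P(i): sampling probability of each transition
probs = [s / total for s in scaled]

# w_i = (1/N * 1/P(i))^beta, then normalized by the largest weight
n = len(td_errors)
weights = [(1.0 / (n * p)) ** BETA for p in probs]
max_w = max(weights)
weights = [w / max_w for w in weights]
```

Transitions with a large TD error are sampled more often but receive a smaller weight, which partly offsets the over-sampling.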

In [None]:
class SumTree(object):
  """
  Sumtree class for PER memory. 
  SumTree is a binary tree and each node stores sum of values of the child nodes. 
  Each leaf node shows data stored in the memory, and contains priority of the data. 
  """
  
  def __init__(self, capacity: int):
    """
    Initialize SumTree with given capacity. 
    - tree size is (2 * capacity - 1), which has (capacity) leaf nodes. 
    - data size is (capacity)
    """
    self.capacity = capacity
    
    # define tree and data
    self.tree = np.zeros(2 * self.capacity - 1)
    self.data = np.zeros(self.capacity, dtype=object)
    
    # data count and pointer
    self.data_cnt = 0
    self.data_idx = 0
    
  def add(self, value: float, data: Transition):
    """
    Add new data to SumTree.
    """
    # update data array
    self.data[self.data_idx] = data
    
    # update leaf and tree
    tree_idx = self.data_idx + self.capacity - 1
    self.update(tree_idx, value)
    
    # update data count and pointer
    self.data_idx += 1
    if self.data_idx >= self.capacity:
      self.data_idx = 0
    if self.data_cnt < self.capacity:
      self.data_cnt += 1
  
  def update(self, tree_idx: int, value: float):
    """
    Update value of leaf node and spread to its parent recursively. 
    """
    # update leaf
    diff = value - self.tree[tree_idx]
    self.tree[tree_idx] = value
    
    # update parent (until root)
    while tree_idx > 0:
      parent = (tree_idx - 1) // 2
      self.tree[parent] += diff
      tree_idx = parent
  
  def _retrieve(self, value: float):
    """
    (Wrapped by get_leaf())
    Find tree index that satisfies given cumulative sum. 
    Make sure that value is in [0, sum(priority)]
    """
    tree_idx = 0
    while True:
      left_idx = tree_idx * 2 + 1
      right_idx = left_idx + 1
      
      # if tree_idx is leaf, return it
      if left_idx >= len(self.tree):
        return tree_idx
      
      # choose left or right child
      if value <= self.tree[left_idx]:
        tree_idx = left_idx
      else:
        value -= self.tree[left_idx]
        tree_idx = right_idx
        
  def get_leaf(self, value: float):
    """
    Find tree index that satisfies given cumulative sum. 
    Make sure that value is in [0, sum(priority)]
    - Return: (tree_idx, value, data) tuple
    """
    tree_idx = self._retrieve(value)
    return tree_idx, self.tree[tree_idx], self.data[tree_idx - self.capacity + 1]

  def total_priority(self):
    """
    Return total priority (same to root node's value)
    """
    return self.tree[0]
  
class PriorityReplayMemory(object):
  """
  Memory that supports Priority Experience Replay. 
  store (state, action, reward, next state, done) in memory 
  and use random samples by priority. 
  """
  
  def __init__(self, capacity: int, eps: float, alpha: float, beta_init: float, beta_inc: float):
    """
    Initialize PER memory
    
    - eps: very small value that makes all priorities positive. 
    - alpha: hyperparameter between 0(uniform) and 1(priority-based)
    - beta: hyperparameter between 0(no correction) and 1(correct to uniform-like)
    """
    # hyperparameters
    self.eps = eps
    self.alpha = alpha
    self.beta = beta_init
    self.beta_inc = beta_inc
    
    # sumtree
    self.capacity = capacity
    self.sumtree = SumTree(self.capacity)
    
  def push(self, error: float, *args):
    """
    Push new transition data into sumtree. 
    """
    # translate priority to weight value
    priority = (error + self.eps) ** self.alpha
    self.sumtree.add(priority, Transition(*args))
    
  def sample(self, batch_size: int):
    """
    Sample transitions according to their priorities. 
    - Return: (batch, tree_idxs, IS_weights)
      - tree_idxs: tree indexes of transitions in sampled batch
      - IS_weights: weights to calculate loss
    """
    sample = []
    tree_idxs = []
    priorities = []
    
    segment = self.sumtree.total_priority() / batch_size
    for i in range(batch_size):
      left = segment * i
      right = segment * (i + 1)
      
      # get random leaf in the segment
      value = np.random.uniform(left, right)
      tree_idx, priority, transition = self.sumtree.get_leaf(value)
      
      # add to list
      sample.append(transition)
      tree_idxs.append(tree_idx)
      priorities.append(priority)
    
    # calc importance sampling weights (IS weights)
    priorities = np.array(priorities) / self.sumtree.total_priority()
    is_weights = np.power(self.sumtree.data_cnt * priorities, -self.beta)
    is_weights /= is_weights.max()
    
    # update beta
    if self.beta < 1.0:
      self.beta = np.min([1., self.beta + self.beta_inc])
    
    return Transition(*zip(*sample)), tree_idxs, is_weights
  
  def update(self, tree_idx: int, error: float):
    """
    Update existing priority to newly-calculated error. 
    """
    priority = (error + self.eps) ** self.alpha
    self.sumtree.update(tree_idx, priority)
    
  def __len__(self):
    """
    Return length of memory
    """
    return self.sumtree.data_cnt

## Q Network

In this notebook, we use an environment that provides graphical observations and discrete actions. 

So, we use a simple network with a few convolutional layers followed by fully-connected layers. 

### Dueling Network
A dueling network has a feature extraction network, a state-value head, and an advantage head. 

Q(s, a) is calculated according to the equation below. 
- V(s): scalar state value. 
- Advantage(s, a): (action_num) vector holding the advantage of each action over the others. 
$$ Q(s, a_i) = V(s) + Advantage(s, a_i) - E_j[Advantage(s, a_j)] $$
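The equation above can be checked by hand on a tiny example. This is a minimal pure-Python sketch with made-up numbers; `dueling_q` is a hypothetical helper, not the notebook's `QNetwork`.

```python
def dueling_q(value: float, advantages: list) -> list:
  """Combine V(s) and Advantage(s, a) into Q(s, a) with mean-centering."""
  mean_adv = sum(advantages) / len(advantages)
  return [value + a - mean_adv for a in advantages]

# V(s) = 1 and three actions with advantages 1, 2, 3
q_values = dueling_q(1.0, [1.0, 2.0, 3.0])

# shifting every advantage by the same constant leaves Q unchanged
q_shifted = dueling_q(1.0, [11.0, 12.0, 13.0])
```

Subtracting the mean advantage is what makes the split between V and Advantage well defined: a constant offset in the advantage head cannot leak into Q.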

### Noisy Network
Noisy network adds noise to the weights and biases of its Linear layers. 

$$
\begin{gather}
y = (\mu^w + \sigma^w \odot \epsilon^w)x + (\mu^b + \sigma^b \odot \epsilon^b) \notag \\
\mu^w, \sigma^w \in R^{p \times q}, \mu^b, \sigma^b \in R^{p} \text{: learnable parameters} \notag \\
\epsilon^w \in R^{p \times q}, \epsilon^b \in R^{p} \text{: random variables for each forward} \notag
\end{gather}
$$

There are two randomization strategies. 
1. Independent Gaussian noise: sample $ pq + p $ random variables independently. 
2. Factorised Gaussian noise: sample only $ p + q $ random variables as two vectors (of size p and q), and derive the $ pq + p $ noise values from them. Here $ \epsilon_i $ comes from the size-p (output) vector and $ \epsilon_j $ from the size-q (input) vector. 
  $$ 
  \begin{aligned}
  \epsilon_{i,j}^w &= f(\epsilon_i)f(\epsilon_j) \\ 
  \epsilon_{i}^b &= f(\epsilon_i) \\ 
  \text{where } f(x) &= \operatorname{sgn}(x) \sqrt{|x|}
  \end{aligned}
  $$
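The factorised strategy can be sketched without tensors. The code below samples p + q Gaussians and expands them into a p x q weight-noise matrix and a size-p bias-noise vector; `factorised_noise` is a hypothetical helper written for illustration and uses Python's `random` module rather than PyTorch.

```python
import math
import random

def f(x: float) -> float:
  """f(x) = sgn(x) * sqrt(|x|)"""
  return math.copysign(math.sqrt(abs(x)), x)

def factorised_noise(p: int, q: int, rng: random.Random):
  """Sample p + q Gaussians and expand them to p x q weight noise and p bias noise."""
  eps_out = [rng.gauss(0.0, 1.0) for _ in range(p)]  # one per output unit
  eps_in = [rng.gauss(0.0, 1.0) for _ in range(q)]   # one per input unit
  eps_w = [[f(eo) * f(ei) for ei in eps_in] for eo in eps_out]
  eps_b = [f(eo) for eo in eps_out]
  return eps_w, eps_b

eps_w, eps_b = factorised_noise(p=3, q=2, rng=random.Random(0))
```

Because the weight noise is an outer product of two vectors, only p + q random numbers are drawn per forward pass instead of pq + p.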

### Distributional RL (QR-DQN)
QR-DQN's output is a (BATCH_SIZE, N, NUM_ACTIONS) tensor that holds N quantile values (supports) for each action. 
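To act greedily on such an output, the quantile values of each action are averaged into a Q value. Here is a minimal sketch for a single state using nested lists; `q_from_quantiles` and the numbers are hypothetical.

```python
def q_from_quantiles(quantiles):
  """quantiles[i][a] is the i-th quantile value of action a; return the mean per action."""
  n = len(quantiles)
  n_actions = len(quantiles[0])
  return [sum(quantiles[i][a] for i in range(n)) / n for a in range(n_actions)]

# hypothetical output for one state: N = 4 quantiles, 2 actions
quantiles = [
  [0.0, 1.0],
  [1.0, 1.0],
  [2.0, 1.0],
  [5.0, 1.0],
]
q_values = q_from_quantiles(quantiles)
greedy_action = max(range(len(q_values)), key=lambda a: q_values[a])
```

Action 0 has a wide return distribution and action 1 a deterministic one, but only their means matter for the greedy choice.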

In [None]:
class NoisyLinear(nn.Linear):
  """
  Linear layer with parameter perturbation. 
  """
  
  def __init__(self, in_features: int, out_features: int):
    """
    Define new linear layer with noise. 
    """
    super(NoisyLinear, self).__init__(in_features, out_features)

    # init sigma for weight and bias
    sigma = hp.NOISE_SIGMA / np.sqrt(in_features)
    self.sigma_weight = nn.Parameter(torch.Tensor(out_features, in_features).fill_(sigma))
    self.sigma_bias = nn.Parameter(torch.Tensor(out_features).fill_(sigma))
    
    # random variables
    self.register_buffer("eps_in", torch.zeros((1, in_features)))
    self.register_buffer("eps_out", torch.zeros((out_features, 1)))
    self.register_buffer("eps_weight", torch.zeros(self.sigma_weight.shape))
    self.register_buffer("eps_bias", torch.zeros(self.sigma_bias.shape))
  
  def f(self, x: torch.Tensor):
    """
    f(x) = sign(x) * sqrt(abs(x))
    """
    return torch.sign(x) * torch.sqrt(torch.abs(x))

  def forward(self, x: torch.Tensor):
    """
    Forward with randomized variables. 
    """
    # x shape is (B, in_features)
    assert len(x.shape) == 2
    
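    # generate random variables (on the same device as the layer's parameters)
    self.eps_in = torch.randn(self.eps_in.shape, device=self.weight.device)
    self.eps_out = torch.randn(self.eps_out.shape, device=self.weight.device)
    
    # factorised noise: eps_w = f(eps_out) f(eps_in), eps_b = f(eps_out)
    self.eps_weight = torch.mul(self.f(self.eps_out), self.f(self.eps_in))
    self.eps_bias = self.f(self.eps_out).squeeze(1)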
    
    # return with noisy weight and bias
    weight = self.weight + self.sigma_weight * self.eps_weight
    bias = self.bias + self.sigma_bias * self.eps_bias
    return F.linear(x, weight, bias)

In [None]:
class QNetwork(nn.Module):
  """
  Q network that gets graphical input and discrete output. 
  """

  def __init__(
    self, 
    dim_observation: tuple, 
    n_actions: int, 
    dueling: bool, 
    noisy: bool, 
    distributional: bool
  ):
    """
    Build a network that takes a (C, H, W) = dim_observation input 
    and outputs values for n_actions actions. 
    """
    super(QNetwork, self).__init__()
    C, H, W = dim_observation
    assert H == 84
    assert W == 84

    self.dueling = dueling
    self.noisy = noisy
    self.distributional = distributional
    
    linear = NoisyLinear if self.noisy else nn.Linear
    
    if self.dueling:
      # dueling net: feature, state, and advantage
      self.feature = nn.Sequential(
        nn.Conv2d(in_channels=C, out_channels=32, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
        nn.ReLU(),
        nn.Flatten()
      )
      self.state = nn.Sequential(
        linear(3136, 256), 
        nn.ReLU()
      )
      self.advantage = nn.Sequential(
        linear(3136, 256), 
        nn.ReLU()
      )
      if self.distributional:
        # state output: [BATCH_SIZE, N]
        self.state += nn.Sequential(
          linear(256, hp.DIST_N), 
          nn.Unflatten(1, (hp.DIST_N, 1))
        )
        # advantage output: [BATCH_SIZE, N, N_ACTIONS]
        self.advantage += nn.Sequential(
          linear(256, hp.DIST_N * n_actions), 
          nn.Unflatten(1, (hp.DIST_N, n_actions))
        )
      else:
        # state output: [BATCH_SIZE, 1]
        self.state += nn.Sequential(
          linear(256, 1)
        )
        # advantage output: [BATCH_SIZE, N_ACTIONS]
        self.advantage += nn.Sequential(
          linear(256, n_actions)
        )
    else:
      # q network
      self.seqmodel = nn.Sequential(
        nn.Conv2d(in_channels=C, out_channels=32, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
        nn.ReLU(),
        nn.Flatten(),
        linear(3136, 512),
        nn.ReLU()
      )
      if self.distributional:
        # output: [BATCH_SIZE, N, N_ACTIONS]
        self.seqmodel += nn.Sequential(
          linear(512, hp.DIST_N * n_actions), 
          nn.Unflatten(1, (hp.DIST_N, n_actions))
        )
      else:
        # output: [BATCH_SIZE, N_ACTIONS]
        self.seqmodel += nn.Sequential(
          linear(512, n_actions)
        )
      
  def forward(self, x: torch.Tensor):
    """
    Forward.
    """
    if self.dueling:
      # get feature
      feature = self.feature(x)

      # get state value and advantage
      state = self.state(feature)
      advantage = self.advantage(feature)

      # calc output and return
      average = torch.mean(advantage, dim=-1, keepdim=True)
      return state + advantage - average
    else:
      return self.seqmodel(x)
    
  def get_noise(self):
    """
    If this is noisy network, returns scalar value that indicates noise: 
    the sum of ||sigma_weight||_2 + ||sigma_bias||_2 over all noisy layers. 
    """
    noise = 0.0
    for module in self.modules():
      if isinstance(module, NoisyLinear):
        noise += torch.sqrt(torch.sum(module.sigma_weight ** 2)).item()
        noise += torch.sqrt(torch.sum(module.sigma_bias ** 2)).item()
    return noise

## Training: Utility Functions

- `select_action()`: select the agent's action with an e-greedy policy. 
- `save_plot()`: plot epsilon (or noise), loss, number of frames, and score. 
- `save_model()`: store model, hyperparameters, and training info. 
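The e-greedy rule used by `select_action()` can be sketched on a plain list of Q values. `select_action_sketch` is a hypothetical stand-in that takes Q values directly instead of calling the policy net.

```python
import random

def select_action_sketch(q_values, epsilon, rng=random):
  """With probability epsilon pick a uniformly random action, else the greedy one."""
  if rng.random() > epsilon:
    # greedy action: index of the largest Q value
    return max(range(len(q_values)), key=lambda a: q_values[a])
  # random action
  return rng.randrange(len(q_values))
```

With `epsilon` close to 1 the agent mostly explores, and as it decays toward `EPS_END` the greedy branch is taken more and more often.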

In [None]:
# policy net and target net
dim_observation = env.observation_space.shape
n_actions = int(env.action_space.n)

policy_net = QNetwork(
  dim_observation, 
  n_actions, 
  dueling=op.DUELING, 
  noisy=op.NOISY, 
  distributional=op.DISTRIBUTIONAL
).to(device)
target_net = QNetwork(
  dim_observation, 
  n_actions, 
  dueling=op.DUELING, 
  noisy=op.NOISY, 
  distributional=op.DISTRIBUTIONAL
).to(device)
target_net.load_state_dict(policy_net.state_dict())

# replay memory
if op.PER:
  # PER memory
  memory = PriorityReplayMemory(
    capacity=hp.MEMORY_CAPACITY, 
    eps=hp.PER_EPS, 
    alpha=hp.PER_ALPHA, 
    beta_init=hp.PER_BETA_INIT, 
    beta_inc=hp.PER_BETA_INC
  )
else:
  # simple uniform memory
  memory = UniformReplayMemory(hp.MEMORY_CAPACITY)
  
# quantiles for distributional RL
if op.DISTRIBUTIONAL:
  tau_hats = torch.tensor([[[(2 * i + 1) / (2 * hp.DIST_N)] for i in range(hp.DIST_N)]], device=device)
  tau_hats = tau_hats.expand(hp.BATCH_SIZE, hp.DIST_N, 1)

# save directory
start_datetime = datetime.now()
dirname = start_datetime.strftime("%Y%m%d-%H%M%S")
path = os.path.join(os.getcwd(), "dqn", dirname)

# adamw optimizer
optimizer = optim.AdamW(policy_net.parameters(), lr=hp.LR, amsgrad=True)

# training variables
train_epsilons = []
train_losses = []
train_frames = []
train_scores = []
test_frames = []
test_scores = []
test_noises = []
epsilon = hp.EPS_START
epsilon_decay = (hp.EPS_START - hp.EPS_END) / (hp.GREEDY_FRAME - hp.EXPLORE_FRAME)
steps = 0

# if we use multi-step option, define transition buffer
if op.MULTISTEP:
  multistep_buffer = Transition(
    state=deque(maxlen=hp.MULTISTEP),
    action=deque(maxlen=hp.MULTISTEP),
    reward=deque(maxlen=hp.MULTISTEP),
    next_state=deque(maxlen=hp.MULTISTEP),
    done=deque(maxlen=hp.MULTISTEP)
  )

def select_greedy(state: torch.Tensor):
  """
  Select agent's action by greedy policy. 
  """
  if op.DISTRIBUTIONAL:
    with torch.no_grad():
      return torch.argmax(torch.mean(policy_net(state), dim=1, keepdim=False), dim=1, keepdim=True)
  else:
    with torch.no_grad():
      return torch.argmax(policy_net(state), dim=1, keepdim=True)

def select_action(state: torch.Tensor):
  """
  Select agent's action by e-greedy policy. 
  """
  sample = random.random()
  if sample > epsilon:
    # greedy action
    return select_greedy(state)
  else:
    # random action
    return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

def save_plot():
  """
  Plot loss, epsilon, average frames, and score and save the figures. 
  """
  plt.figure(figsize=(16, 12))
  plt.clf()
  plt.ion()
  
  plt.subplot(2, 2, 1)
  if op.NOISY:
    plt.title("Noise")
    plt.xlabel("Test Iter")
    plt.ylabel("Noise")
    plt.plot(*zip(*test_noises), label="test")
    plt.legend()
    plt.grid()
  else:
    plt.title("Epsilon")
    plt.xlabel("Frames")
    plt.ylabel("Epsilon")
    plt.plot(*zip(*train_epsilons), label="train")
    plt.legend()
    plt.grid()
  
  plt.subplot(2, 2, 2)
  plt.title("Loss")
  plt.xlabel("Learning Step")
  plt.ylabel("Loss")
  plt.plot(*zip(*train_losses), label="train")
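  if len(train_losses) >= 100:
    # moving average over windows of 100 losses (needs at least 100 points)
    x, y = zip(*train_losses)
    x = x[99:]
    y = torch.mean(torch.tensor(y).unfold(0, 100, 1), dim=1)
    plt.plot(x, y, label="train-avg100")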
  plt.legend()
  plt.grid()
  
  plt.subplot(2, 2, 3)
  plt.title("# of Frames")
  plt.xlabel("Episode")
  plt.ylabel("# of Frames")
  plt.plot(*zip(*train_frames), label="train")
  plt.plot(*zip(*test_frames), label="test")
  plt.legend()
  plt.grid()
  
  plt.subplot(2, 2, 4)
  plt.title("Score")
  plt.xlabel("Episode")
  plt.ylabel("Score")
  plt.plot(*zip(*train_scores), label="train")
  plt.plot(*zip(*test_scores), label="test")
  plt.legend()
  plt.grid()
  
  plt.ioff()
  plt.savefig(os.path.join(path, "plot.png"))
  
  if is_ipython:
    display.clear_output(wait=True)
    display.display(plt.gcf())

def save_model():
  """
  Save model, hyperparameters, and training info.
  """
  # save model
  torch.save({
    "policy_net": policy_net.state_dict(),
    "target_net": target_net.state_dict()
  }, os.path.join(path, "model.pt"))
  
  # save training options and hyperparameters
  with open(os.path.join(path, "option.json"), "w") as w:
    json.dump(dict([
      ("algorithm", op._asdict()), 
      ("hparam", hp._asdict())
    ]), w, indent=2)
  
  # save training info
  with open(os.path.join(path, "info.json"), "w") as w:
    json.dump(dict([
      ("env", envname), 
      ("test_frames", test_frames), 
      ("test_scores", test_scores), 
      ("test_noises", test_noises), 
      ("steps", steps), 
      # total_seconds() keeps counting past 24 hours, unlike .seconds
      ("training_time", (datetime.now() - start_datetime).total_seconds())
    ]), w, indent=2)

## Training: Optimizing Function

- `optimize_model()`: execute one step of optimizing. We use Huber loss in this notebook.
  
  $$ L(\delta) = \begin{cases} 
    \frac{1}{2}\delta^2 & \text{for } |\delta| \leq 1 \\ 
    |\delta| - \frac{1}{2} & \text{otherwise}
  \end{cases} $$
  
  - Original DQN learns according to the equation below. 
  $$ \delta = Q_{policy}(s_t, a_t) - (r_t + \gamma \max_a Q_{target}(s_{t+1}, a)) $$
  
  - DDQN selects the next action with the policy net and evaluates it with the target net. 
  $$ \delta = Q_{policy}(s_t, a_t) - (r_t + \gamma Q_{target}(s_{t+1}, \operatorname{argmax}_{a} Q_{policy}(s_{t+1}, a))) $$
  
  - Multi-step DQN learns according to the equations below. 
    - n is the hyperparameter for multi-step learning. 
    - $ G_{t:t+n} $ is the discounted cumulative reward over n steps. 
    - T is the timestep at which the episode ends. 
  $$ G_{t:t+n} = \sum_{i=0}^{\min(n-1, T-t)} \gamma^{i} r_{t+i} $$
  $$ \delta = Q_{policy}(s_t, a_t) - (G_{t:t+n} + \gamma^{n} \max_a Q_{target}(s_{t+n}, a)) $$
  
  - Distributional DQN learns according to the equations below. 
    
      ||C51|QR-DQN(used)|
      |-|-|-|
      |Fixes|support|quantile fraction (probability)|
      |Learns|probability of each support|support (quantile value)|
    
    1. Quantile
        - quantile fraction $ \tau_i = \frac{i}{N} \text{ for } i = 1, \dots, N $ where N is the number of quantiles. 
        - in this notebook, we use the midpoint $ \hat{\tau}_i = \frac{2(i - 1) + 1}{2N} \text{ for } i = 1, \dots, N $, because placing the supports at these midpoints minimizes the Wasserstein distance. 
        - quantile regression $ Z_i = F_Z^{-1}(\hat{\tau}_i) $
    2. Objective
        - the objective of QR-DQN is to minimize the Wasserstein distance. In this notebook, we use p = 1. 
        $$ 
        \begin{aligned}
        &W_p(U, Y) = \left( \int_0^1 |F_Y^{-1}(w) - F_U^{-1}(w)|^p dw \right)^{\frac{1}{p}} \\
        &\text{U, Y: probability distributions} \\
        &\text{F: cumulative distribution function}
        \end{aligned}
        $$
        - The target is calculated as below, where $ a^* $ is the greedy action at the next state $ x' $. 
          $$ T\theta_j \leftarrow r + \gamma \theta_j(x', a^*) \quad \forall j $$
        - The loss is calculated as below. 
          1. the loss increases as the predicted quantiles move away from the target distribution. 
          2. the loss is asymmetric: for a high quantile ($ \hat{\tau}_i $ near 1), predictions below the target are penalized more, and for a low quantile, predictions above the target are penalized more. 
          $$ 
          \begin{aligned}
          &Loss = \sum_{i=1}^N E_j\left[ \rho_{\hat{\tau}_i}(T \theta_j - \theta_i(x, a)) \right] \\
          &\rho_\tau(u) = \begin{cases} 
            Huber(u)(1 - \tau) \quad &\text{for } u < 0 \\ 
            Huber(u)(\tau) &\text{for } u \geq 0
          \end{cases}
          \end{aligned}
          $$
    3. Network
        - the output of the network is (BATCH_SIZE, N, ACTION_NUM), which holds N quantile values for each action. 
        - the return distribution is $ Z_\theta(x, a) = \frac{1}{N} \sum_{i=1}^N \delta_{\theta_i(x, a)} $, so the mean Q value is $ Q(x, a) = \frac{1}{N} \sum_{i=1}^N \theta_i(x, a) $. 
  
  - Noisy DQN can be combined with any of the other options. 
    - DQN or DDQN
    - 1-step TD or n-step TD
    - General or Distributional q value
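The quantile Huber loss from the QR-DQN part of the list above can be written out for a single (state, action) pair. This is a minimal pure-Python sketch under the assumption that the Huber threshold is 1; `huber` and `quantile_huber_loss` are hypothetical helpers and do not replace the batched tensor code in `optimize_model()`.

```python
def huber(u: float, kappa: float = 1.0) -> float:
  """Huber loss with threshold kappa (kappa = 1 is assumed here)."""
  if abs(u) <= kappa:
    return 0.5 * u * u
  return kappa * (abs(u) - 0.5 * kappa)

def quantile_huber_loss(pred, target):
  """pred[i]: predicted quantile values theta_i; target[j]: target values T theta_j."""
  n = len(pred)
  loss = 0.0
  for i in range(n):
    tau_hat = (2 * i + 1) / (2 * n)  # midpoint, 0-indexed form of (2(i-1)+1)/(2N)
    per_target = 0.0
    for t in target:
      u = t - pred[i]
      # asymmetric weight: (1 - tau) for u < 0, tau for u >= 0
      weight = (1.0 - tau_hat) if u < 0 else tau_hat
      per_target += weight * huber(u)
    loss += per_target / len(target)  # expectation over j
  return loss

loss_example = quantile_huber_loss([0.0, 1.0], [0.0, 1.0])
```

With two quantiles predicted at 0 and 1 and targets at 0 and 1, each quantile is pulled toward the other target with weight 0.25, which gives a total loss of 0.125.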

In [None]:
def optimize_model():
  """
  Optimize policy model. 
  """
  # if not enough data, quit
  if len(memory) < hp.BATCH_SIZE:
    return
  
  # sample batch
  if op.PER:
    # get from PER memory
    batch, tree_idxs, is_weights = memory.sample(hp.BATCH_SIZE)
  else:
    # get from Simple memory
    batch = memory.sample(hp.BATCH_SIZE)
  
  # get training data from the batch
  state_batch = torch.cat(batch.state)
  action_batch = torch.cat(batch.action)
  reward_batch = torch.cat(batch.reward)
  next_state_batch = torch.cat(batch.next_state)
  done_batch = torch.cat(batch.done)
  
  if op.DISTRIBUTIONAL:
    # calc Z(s_t, a)
    current_z = policy_net(state_batch).gather(2, action_batch.unsqueeze(2).expand(hp.BATCH_SIZE, hp.DIST_N, 1))
    
    if op.DDQN:
      # calc r_t + gamma * Z_target(s_t+1, argmax_a Q(s_t+1, a))
      with torch.no_grad():
        next_action = torch.argmax(torch.mean(policy_net(next_state_batch), dim=1, keepdim=False), dim=1, keepdim=True)
        next_z = target_net(next_state_batch).gather(2, next_action.unsqueeze(2).expand(hp.BATCH_SIZE, hp.DIST_N, 1))
    else:
      # calc r_t + gamma * Z_target(s_t+1, argmax_a Q_target(s_t+1, a))
      with torch.no_grad():
        next_z = target_net(next_state_batch)
        next_action = torch.argmax(torch.mean(next_z, dim=1, keepdim=False), dim=1, keepdim=True)
        next_z = next_z.gather(2, next_action.unsqueeze(2).expand(hp.BATCH_SIZE, hp.DIST_N, 1))
    expected_z = reward_batch.unsqueeze(2) + (hp.GAMMA ** hp.MULTISTEP if op.MULTISTEP else hp.GAMMA) \
      * (1.0 - done_batch.unsqueeze(2)) * next_z
    expected_z = torch.transpose(expected_z, 1, 2)
    
    # calculate loss (dim=1: prediction dim, dim=2: target dim)
    criterion = nn.HuberLoss(reduction="none")
    diff = expected_z - current_z
    loss = criterion(current_z, expected_z)
    loss *= torch.abs(tau_hats - (diff < 0).float())
    loss = torch.mean(torch.sum(loss, dim=1, keepdim=False), dim=1, keepdim=True)
    
    if op.PER:
      # calculate TD error
      error = torch.mean(torch.sum(torch.abs(diff), dim=1, keepdim=False), dim=1, keepdim=True)
      
      # using calculated errors, update priorities of transitions
      for idx in range(hp.BATCH_SIZE):
        memory.update(tree_idxs[idx], error[idx].item())

      # calc loss weighted by is_weight
      is_weights = torch.tensor(is_weights, device=device, dtype=torch.float32).unsqueeze(1)
      loss = torch.mul(loss, is_weights)
      loss = torch.mean(loss)
    else:
      # get huber loss
      loss = torch.mean(loss)
  else:
    # calc Q(s_t, a)
    current_q = policy_net(state_batch).gather(1, action_batch)
    
    if op.DDQN:
      # calc r_t + gamma * Q_target(s_t+1, argmax_a Q(s_t+1, a))
      with torch.no_grad():
        next_action = torch.argmax(policy_net(next_state_batch), dim=1, keepdim=True)
        next_q = target_net(next_state_batch).gather(1, next_action)
    else:
      # calc r_t + gamma * max_a Q_target(s_t+1, a)
      with torch.no_grad():
        next_q = target_net(next_state_batch).max(1, keepdim=True).values
    expected_q = reward_batch + (hp.GAMMA ** hp.MULTISTEP if op.MULTISTEP else hp.GAMMA) * (1.0 - done_batch) * next_q
  
    # calculate loss
    if op.PER:
      # calculate TD error
      error = torch.abs(current_q - expected_q)
      
      # using calculated errors, update priorities of transitions
      for idx in range(hp.BATCH_SIZE):
        memory.update(tree_idxs[idx], error[idx].item())

      # calc loss weighted by is_weight
      criterion = nn.HuberLoss(reduction="none")
      is_weights = torch.tensor(is_weights, device=device, dtype=torch.float32).unsqueeze(1)
      loss = criterion(current_q, expected_q)
      loss = torch.mul(loss, is_weights)
      loss = torch.mean(loss)
    else:
      # get huber loss
      criterion = nn.HuberLoss()
      loss = criterion(current_q, expected_q)
  
  loss_value = loss.item()
  
  # optimize model
  optimizer.zero_grad()
  loss.backward()
  torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
  optimizer.step()
  
  return loss_value

## Training: Testing Function

- `test_model()`: run policy model in the environment with greedy policy. 

In [None]:
def test_model():
  """
  Test policy model and save the result in training variables. 
  """
  # testing variables
  frames = []
  scores = []
  
  # repeat for TEST_EPISODES episodes
  for _ in range(1, hp.TEST_EPISODES + 1):
    # initialize environment and state
    state, _ = env.reset()
    state = torch.tensor(np.array(state), device=device, dtype=torch.float32).unsqueeze(0)
    score = 0
    
    # start an episode
    for frame in count():
      # select greedy action
      action = select_greedy(state)
      
      # act to next state
      observation, reward, terminated, truncated, _ = env.step(action.item())
      score += reward
      done = terminated or truncated
      
      # update state
      state = torch.tensor(np.array(observation), device=device, dtype=torch.float32).unsqueeze(0)

      # check end condition
      if done:
        frames.append(frame)
        scores.append(score)
        break
      
  # return average frames, scores, and noise
  return np.mean(np.array(frames)), np.mean(np.array(scores)), (policy_net.get_noise() if op.NOISY else None)

## Training

During training, the agent acts in the environment to create transitions, and is trained on them with the `optimize_model()` function. 
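When the multi-step option is enabled, n consecutive transitions are folded into a single transition before being stored. The sketch below shows that folding on toy data; `n_step_transition`, `GAMMA = 0.5`, and `N_STEP = 3` are hypothetical choices made so the arithmetic is easy to follow, and states are plain integers instead of frame tensors.

```python
from collections import deque

GAMMA = 0.5   # hypothetical discount; the notebook uses hp.GAMMA
N_STEP = 3    # hypothetical n; the notebook uses hp.MULTISTEP

def n_step_transition(buffer):
  """Fold n consecutive (state, action, reward, next_state, done) tuples into one."""
  state, action = buffer[0][0], buffer[0][1]
  # discounted sum of the n rewards: r_0 + gamma * r_1 + gamma^2 * r_2 + ...
  reward = sum((GAMMA ** i) * r for i, (_, _, r, _, _) in enumerate(buffer))
  next_state, done = buffer[-1][3], buffer[-1][4]
  return state, action, reward, next_state, done

buffer = deque(maxlen=N_STEP)
for t, r in enumerate([1.0, 0.0, 2.0]):
  # toy transition: s_t = t, action 0, reward r, s_{t+1} = t + 1, not done
  buffer.append((t, 0, r, t + 1, False))

transition = n_step_transition(buffer)
```

The folded transition starts at the oldest state, ends at the newest next state, and carries the reward 1 + 0.5 * 0 + 0.25 * 2 = 1.5.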

In [None]:
# create training directory
os.makedirs(path)

for episode in tqdm(range(1, hp.NUM_EPISODES + 1)):
  # initialize environment and state
  state, _ = env.reset()
  state = torch.tensor(np.array(state), device=device, dtype=torch.float32).unsqueeze(0)
  score = 0
  
  # start an episode
  for frame in count():
    steps += 1
    
    if op.NOISY:
      # noisy net with greedy policy
      action = select_greedy(state)
    else:
      # select e-greedy action
      action = select_action(state)
    
      # update epsilon
      if steps < hp.EXPLORE_FRAME:
        train_epsilons = [(1, hp.EPS_START), (steps, hp.EPS_START)]
      elif steps < hp.GREEDY_FRAME:
        epsilon -= epsilon_decay
        train_epsilons = [(1, hp.EPS_START), (hp.EXPLORE_FRAME, hp.EPS_START), (steps, epsilon)]
      else:
        train_epsilons = [(1, hp.EPS_START), (hp.EXPLORE_FRAME, hp.EPS_START), (hp.GREEDY_FRAME, epsilon), (steps, epsilon)]
    
    # act to next state
    observation, reward, terminated, truncated, _ = env.step(action.item())
    score += reward
    reward = torch.tensor([[reward]], device=device, dtype=torch.float32)
    done = terminated or truncated
    
    # get next state
    next_state = torch.tensor(np.array(observation), device=device, dtype=torch.float32).unsqueeze(0)

    # multi-step
    if op.MULTISTEP:
      # append transition to buffer
      multistep_buffer.state.append(state)
      multistep_buffer.action.append(action)
      multistep_buffer.reward.append(reward)
      multistep_buffer.next_state.append(next_state)
      multistep_buffer.done.append(done)
            
      # if buffer full and not done, calculate n-step overall transition
      available = len(multistep_buffer.state) == hp.MULTISTEP \
        and (True not in [multistep_buffer.done[i] for i in range(hp.MULTISTEP - 1)])
      if available:
        state_total = multistep_buffer.state[0]
        action_total = multistep_buffer.action[0]
        reward_total = multistep_buffer.reward
        reward_total = sum([(hp.GAMMA ** i) * r for i, r in enumerate(reward_total)])
        next_state_total = multistep_buffer.next_state[-1]
        done_total = multistep_buffer.done[-1]
      else:
        state_total = None
        action_total = None
        reward_total = None
        next_state_total = None
        done_total = None  
    else:
      # no multi-step: use the single-step (1-step TD) transition as is
      state_total = state
      action_total = action
      reward_total = reward
      next_state_total = next_state
      done_total = done
    
    if state_total is not None:
      # add transition to the memory
      if op.PER:
        if op.DISTRIBUTIONAL:
          # calculate TD error
          with torch.no_grad():
            # calc Z(s_t, a)
            current_z = policy_net(state_total).gather(2, action_total.unsqueeze(2).expand(1, hp.DIST_N, 1))
            
            if done_total:
              expected_z = reward_total.unsqueeze(2)
            else:
              if op.DDQN:
                # calc r_t + gamma * Z_target(s_t+1, argmax_a Q(s_t+1, a))
                next_action = torch.argmax(torch.mean(policy_net(next_state_total), dim=1, keepdim=False), dim=1, keepdim=True)
                next_z = target_net(next_state_total).gather(2, next_action.unsqueeze(2).expand(1, hp.DIST_N, 1))
              else:
                # calc r_t + gamma * Z_target(s_t+1, argmax_a Q_target(s_t+1, a))
                next_z = target_net(next_state_total)
                next_action = torch.argmax(torch.mean(next_z, dim=1, keepdim=False), dim=1, keepdim=True)
                next_z = next_z.gather(2, next_action.unsqueeze(2).expand(1, hp.DIST_N, 1))
              expected_z = reward_total.unsqueeze(2) + (hp.GAMMA ** hp.MULTISTEP if op.MULTISTEP else hp.GAMMA) * next_z
              expected_z = torch.transpose(expected_z, 1, 2)
            
            # calculate TD error (dim=1: prediction dim, dim=2: target dim)
            error = torch.mean(torch.sum(torch.abs(expected_z - current_z), dim=1, keepdim=False), dim=1, keepdim=True).item()
        else:
          # calculate TD error
          with torch.no_grad():
            # current Q_current
            current_q = policy_net(state_total).gather(1, action_total)
            
            # calculate Q_expected = r + gamma * Q_next
            if done_total:
              expected_q = reward_total
            else:
              if op.DDQN:
                next_action = torch.argmax(policy_net(next_state_total), dim=1, keepdim=True)
                next_q = target_net(next_state_total).gather(1, next_action)
              else:
                next_q = target_net(next_state_total).max(1, keepdim=True).values
              expected_q = reward_total + (hp.GAMMA ** hp.MULTISTEP if op.MULTISTEP else hp.GAMMA) * next_q
            
            # calculate td error
            error = torch.abs(current_q - expected_q).item()
          
        # push transition to priority memory (with calculated priority)
        done_tensor = torch.tensor([[1.0 if done_total else 0.0]], device=device)
        memory.push(error, state_total, action_total, reward_total, next_state_total, done_tensor)
      else:
        # push transition to simple memory
        done_tensor = torch.tensor([[1.0 if done_total else 0.0]], device=device)
        memory.push(state_total, action_total, reward_total, next_state_total, done_tensor)
      
    # update state
    state = next_state
    
    # if time to update policy net, optimize
    if steps % hp.POLICY_UPDATE_FREQ == 0:
      loss = optimize_model()
      train_losses.append((steps // hp.POLICY_UPDATE_FREQ, loss))
    
    # if time to update target net, synchronize
    if steps % hp.TARGET_UPDATE_FREQ == 0:
      target_net.load_state_dict(policy_net.state_dict())
    
    # check end condition
    if done:
      train_frames.append((episode, frame))
      train_scores.append((episode, score))
      if episode % hp.TEST_FREQ == 0:
        # add to training variables
        mean_frame, mean_score, noise = test_model()
        test_frames.append((episode, mean_frame))
        test_scores.append((episode, mean_score))
        if noise is not None:
          test_noises.append((episode, noise))
        
        # plot and save model
        save_plot()
        save_model()
      break

env.close()
save_plot()
save_model()
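
The epsilon-greedy branch of the loop above holds epsilon at its start value during an initial exploration phase, lowers it by a fixed amount per step, and then keeps it constant. A closed-form sketch of that schedule follows. The function `epsilon_at` and its parameter names are hypothetical, and it assumes the per-step decrement equals `(eps_start - eps_end) / (greedy_frame - explore_frame)`, which this section does not show.

```python
def epsilon_at(step, eps_start, eps_end, explore_frame, greedy_frame):
    """Linearly anneal epsilon between explore_frame and greedy_frame."""
    if step < explore_frame:
        # pure exploration phase: epsilon stays at its start value
        return eps_start
    if step >= greedy_frame:
        # annealing finished: epsilon stays at its final value
        return eps_end
    fraction = (step - explore_frame) / (greedy_frame - explore_frame)
    return eps_start + fraction * (eps_end - eps_start)


# illustrative settings, not the notebook's hyperparameters
schedule = [epsilon_at(s, 1.0, 0.1, 100, 1100) for s in (0, 100, 600, 1100, 5000)]
print(schedule)
```

The loop above updates epsilon incrementally instead, so its values may differ from this closed form by one decrement at the phase boundaries.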

## Test

In this block, the trained agent plays in the environment. With `render_mode="human"`, the environment is rendered so that the agent's play can be watched. 

In [None]:
env = gym.make(envname, render_mode="human")
#env = gym.make(envname)
env = preprocess_env(env)

scores = []

# if you want to load from trained model, edit this (editable)
load_dirname = None

if load_dirname is not None:
  # load models
  path = os.path.join(os.getcwd(), "dqn", load_dirname)
  checkpoint = torch.load(os.path.join(path, "model.pt"), map_location=device)
  
  policy_net.load_state_dict(checkpoint["policy_net"])
  target_net.load_state_dict(checkpoint["target_net"])

# repeat for TEST_EPISODES episodes
for episode in range(1, hp.TEST_EPISODES + 1):
  # initialize environment and state
  state, _ = env.reset()
  state = torch.tensor(np.array(state), device=device, dtype=torch.float32).unsqueeze(0)
  score = 0
  
  # start an episode
  for _ in count():
    # select greedy action
    action = select_greedy(state)
    
    # act to next state
    observation, reward, terminated, truncated, _ = env.step(action.item())
    score += reward
    done = terminated or truncated
    
    # update state
    state = torch.tensor(np.array(observation), device=device, dtype=torch.float32).unsqueeze(0)

    # check end condition
    if done:
      print(f"Episode {episode}: {score}")
      scores.append(score)
      break

env.close()

print(f"Average: {sum(scores) / hp.TEST_EPISODES}")
print(f"Max: {max(scores)}")
print(f"Min: {min(scores)}")