# Reinforcement Learning Assignment: Gravitar



Recurrent Replay Non-Distributed Deeper Denser DQN (R2ND4) is a Reinforcement Learning Agent that was initially designed as a non-distributed version of R2D2 [1] and was later developed further. R2ND4 uses: 
* Double Q-Learning [4];
* Prioritized Replay Buffer [6] using transition sequences of length 120 [1];
* n-step Bellman Rewards and Targets [4];
* Invertible Value Function Rescaling (forward and inverse) [7];
* A Duelling Network Architecture [3];
* A CNN followed by an LSTM for state encodings (with burnin as per R2D2 [1]);
* A Novel Deeper Denser Architecture using Skip-Connections, inspired by D2RL's [2] findings;
* Gradient Clipping as recommended in [3];
* Frame Stacking;
* Observations resized to 84x84 and turned to greyscale using OpenCV.
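
As an illustration of the invertible value function rescaling listed above, the following is a minimal scalar sketch using only the standard library. The names `rescale`, `inverse_rescale` and `VF_EPSILON` are hypothetical and chosen for this example; the notebook's own versions operate on tensors.

```python
import math

VF_EPSILON = 0.001  # assumed to mirror the vfEpsilon hyperparameter


def rescale(x, eps=VF_EPSILON):
    """h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x."""
    return math.copysign(1.0, x) * (math.sqrt(abs(x) + 1.0) - 1.0) + eps * x


def inverse_rescale(y, eps=VF_EPSILON):
    """Closed-form inverse of h, obtained by solving a quadratic in sqrt(|x| + 1)."""
    root = math.sqrt(1.0 + 4.0 * eps * (abs(y) + 1.0 + eps))
    return math.copysign(1.0, y) * (((root - 1.0) / (2.0 * eps)) ** 2 - 1.0)


# Round trip: compressing a return and expanding it again recovers the input.
for value in (-1000.0, -1.0, 0.0, 0.5, 250.0):
    assert abs(inverse_rescale(rescale(value)) - value) < 1e-6 * max(1.0, abs(value))
```

The forward function compresses large returns roughly like a square root, while the small linear term `eps * x` keeps it strictly increasing and therefore invertible.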

When training on a less complex environment, or when faster convergence is the priority, the second (much wider) linear layer in the Value and Advantage networks, which contains the skip connection to the CNN output, can be removed entirely. The model then converges to a lower rolling mean score of 3,000 after 180k episodes (reaching 3,500 after 500k episodes).


## References
The findings of the following papers were relied upon for the design of R2ND4:
* Recurrent Experience Replay in Distributed Reinforcement Learning (R2D2) [1]: https://openreview.net/forum?id=r1lyTjAqYX;
* D2RL: Deep Dense Architectures in Reinforcement Learning [2]: https://arxiv.org/abs/2010.09163;
* Dueling Network Architectures for Deep Reinforcement Learning [3]: https://arxiv.org/abs/1511.06581;
* Rainbow: Combining Improvements in Deep Reinforcement Learning [4]: https://arxiv.org/abs/1710.02298;
* Distributed Prioritized Experience Replay [5]: https://arxiv.org/abs/1803.00933;
* Prioritized Experience Replay [6]: https://arxiv.org/abs/1511.05952;
* Observe and Look Further: Achieving Consistent Performance on Atari [7]: https://arxiv.org/abs/1805.11593.

The codebase below is based upon and borrows from the following sources:
* SEED RL: Scalable and Efficient Deep-RL with Accelerated Central Inference [8]: https://github.com/google-research/seed_rl;
* Reinforcement Learning Assembly (rela) [9]: https://github.com/facebookresearch/rela;
* RL Adventure [10]: https://github.com/higgsfield/RL-Adventure;
* OpenAI Baselines [11]: https://github.com/openai/baselines;
* Distributed Reinforcement Learning [12]: https://github.com/chagmgang/distributed_reinforcement_learning.


## Imports and TensorBoard + CUDA setup

In [1]:
import gym
from gym import spaces
import collections
import random
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import cv2
import operator
from torch.utils.tensorboard import SummaryWriter

# Set the tag for saving and TensorBoard.
tag = "R2ND4"

# Whether to restart training from a checkpoint located at training/{tag}.
restart = False
if not os.path.exists("training"):
    os.makedirs("training")

# Setup TensorBoard to write to runs/{tag}.
writer = SummaryWriter("runs/{}".format(tag))

use_cuda = torch.cuda.is_available()
print(use_cuda)
device   = torch.device("cuda" if use_cuda else "cpu")

False


## Hyperparameters

In [2]:
# Adam optimizer hyperparameters.
learning_rate                = 0.0001
adam_eps                     = 0.001

# Sync rate between target and online networks for Double Q-Learning.
sync_target_every       = 250

# Number of sequences (each of length seq_len_with_burn_in) sampled per batch for the networks and Replay Buffer.
batch_size              = 64

# Hyperparameters used by the Prioritized Replay Buffer
eta                          = 0.9  # Weighting of max (eta) and mean (1-eta) of TD errors when calculating the priority of a batch post sampling to update priorities within buffer.
priority_exponent            = 0.9  # alpha
importance_sampling_exponent = 0.6  # beta
buffer_limit            = 15000     # Size of buffer.

# Epsilon value used in value function scaling and rescaling (taken from R2D2).
vfEpsilon   = 0.001

# Number of steps used for n-step Bellman rewards and targets.
n_steps                 = 5

# Discount factor for future rewards.
gamma                   = 0.99

# Lengths used for storing sequences of experience into the Replay Buffer (See R2D2 for details).
seq_len                 = 80 + n_steps
l_burnin                = 40
seq_len_with_burn_in    = seq_len + l_burnin
overlap = int(seq_len / 2) # 42, since seq_len is 85
seq_len_with_burn_in_minus_overlap    = seq_len_with_burn_in - overlap

# Minimum length of sequence to submit to the buffer (otherwise will be discarded). This is a personal contribution.
minimumLen = 20

# For handling statistics and videos.
video_every             = 100
print_every             = 25
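
To make the sequence-length hyperparameters above concrete, the short sketch below recomputes them in isolation and prints the resulting values. It only repeats the arithmetic from the cell above and makes no assumption about how the acting loop slices episodes.

```python
# Recompute the sequence-length hyperparameters on their own.
n_steps = 5
seq_len = 80 + n_steps                       # 85 transitions used for learning
l_burnin = 40                                # transitions used only to warm up the LSTM
seq_len_with_burn_in = seq_len + l_burnin    # 125 transitions per stored sequence
overlap = int(seq_len / 2)                   # 42 transitions shared by consecutive sequences
seq_len_with_burn_in_minus_overlap = seq_len_with_burn_in - overlap  # 83

print(seq_len, seq_len_with_burn_in, overlap, seq_len_with_burn_in_minus_overlap)
```

Note that the overlap is 42 rather than 40 because `seq_len` already includes the `n_steps` extra transitions.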

## Gym environment wrappers

In [3]:
# Taken from RL Adventure's [10] and OpenAI's Baselines [11] Repos.
class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done  = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert sometimes we stay in lives == 0 condition for a few frames
            # so it's important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset(**kwargs)
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs

# Taken from RL Adventure's [10] and OpenAI's Baselines [11] Repos.
class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(low=0, high=255,
            shape=(self.height, self.width, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]

# Edited from RL Adventure's [10] and OpenAI's Baselines [11] Repos.
class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.
        Returns lazy array, which is much more memory efficient.
        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = collections.deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = spaces.Box(low=0, high=255, shape=(k * shp[2], shp[0], shp[1]), dtype=np.uint8)

    def reset(self):
        ob = self.env.reset().transpose(2,0,1)
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob.transpose(2,0,1))
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))

# Taken from RL Adventure's [10] and OpenAI's Baselines [11] Repos.
class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
        buffers.
        This object should only be converted to numpy array before being passed to the model.
        You'd not believe how complex the previous solution was."""
        self._frames = frames
        self._out = None

    def _force(self):
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=0)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]
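
The frame-stacking mechanism above can be illustrated without Gym. The following is a minimal sketch in which `TinyFrameStack` is a hypothetical stand-in for `FrameStack` and frames are plain labels instead of image arrays: a bounded deque keeps the last `k` frames, and a reset fills it with copies of the first frame.

```python
from collections import deque


class TinyFrameStack:
    """Keeps the k most recent frames; a hypothetical, Gym-free stand-in."""

    def __init__(self, k):
        self.k = k
        self.frames = deque(maxlen=k)  # the oldest frame is dropped automatically

    def reset(self, first_frame):
        # Fill the whole stack with the first frame so the shape is always k.
        for _ in range(self.k):
            self.frames.append(first_frame)
        return list(self.frames)

    def step(self, frame):
        self.frames.append(frame)
        return list(self.frames)


stack = TinyFrameStack(4)
stack.reset("f0")
stack.step("f1")
print(stack.step("f2"))
```

In the real wrapper each entry is a `(1, 84, 84)` array and the stack is concatenated along the channel axis, giving the `(4, 84, 84)` observation the CNN expects.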

## Gym environment setup

In [4]:
# Make environment
env = gym.make('Gravitar-v0')

# Unlike R2D2's reported results [1], using episodic life (with roll-over LSTM states) resulted in worse performance, so this wrapper is disabled for our model (NOTE: requires tweaking in the acting loop to enable roll-over)
# env = EpisodicLifeEnv(env)

# Resize frame to 84x84 and make grayscale.
env = WarpFrame(env)

# Makes env observation into last 4 frames instead.
env = FrameStack(env, 4)

# Enables video recording.
env = gym.wrappers.Monitor(env, "./video", video_callable=lambda episode_id: (episode_id%video_every)==0,force=True)

num_actions = env.action_space.n

## Setup Reproducible Environment and Action Spaces

In [5]:
seed = 742
torch.manual_seed(seed)
env.seed(seed)
random.seed(seed)
np.random.seed(seed)
env.action_space.seed(seed)

[742]

## Prioritized Replay Buffer Prerequisites

In [6]:
# Taken from RL Adventure's [10] and OpenAI's Baselines [11] Repos.
class SegmentTree(object):
    def __init__(self, capacity, operation, neutral_element):
        """Build a Segment Tree data structure.
        https://en.wikipedia.org/wiki/Segment_tree
        Can be used as regular array, but with two
        important differences:
            a) setting item's value is slightly slower.
               It is O(lg capacity) instead of O(1).
            b) user has access to an efficient `reduce`
               operation which reduces `operation` over
               a contiguous subsequence of items in the
               array.
        Parameters
        ----------
        capacity: int
            Total size of the array - must be a power of two.
        operation: lambda obj, obj -> obj
            an operation for combining elements (e.g. sum, max);
            it must form a mathematical group together with the set of
            possible values for array elements.
        neutral_element: obj
            neutral element for the operation above. eg. float('-inf')
            for max and 0 for sum.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
        self._capacity = capacity
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation

    def _reduce_helper(self, start, end, node, node_start, node_end):
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        else:
            if mid + 1 <= start:
                return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
            else:
                return self._operation(
                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
                    self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
                )

    def reduce(self, start=0, end=None):
        """Returns result of applying `self.operation`
        to a contiguous subsequence of the array.
            self.operation(arr[start], operation(arr[start+1], operation(... arr[end])))
        Parameters
        ----------
        start: int
            beginning of the subsequence
        end: int
            end of the subsequences
        Returns
        -------
        reduced: obj
            result of reducing self.operation over the specified range of array elements.
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        end -= 1
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        # index of the leaf
        idx += self._capacity
        self._value[idx] = val
        idx //= 2
        while idx >= 1:
            self._value[idx] = self._operation(
                self._value[2 * idx],
                self._value[2 * idx + 1]
            )
            idx //= 2

    def __getitem__(self, idx):
        assert 0 <= idx < self._capacity
        return self._value[self._capacity + idx]

# Taken from RL Adventure's [10] and OpenAI's Baselines [11] Repos.
class SumSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super(SumSegmentTree, self).__init__(
            capacity=capacity,
            operation=operator.add,
            neutral_element=0.0
        )

    def sum(self, start=0, end=None):
        """Returns arr[start] + ... + arr[end]"""
        return super(SumSegmentTree, self).reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """Find the highest index `i` in the array such that
            sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum
        if array values are probabilities, this function
        allows to sample indexes according to the discrete
        probability efficiently.
        Parameters
        ----------
        prefixsum: float
            upperbound on the sum of array prefix
        Returns
        -------
        idx: int
            highest index satisfying the prefixsum constraint
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        idx = 1
        while idx < self._capacity:  # while non-leaf
            if self._value[2 * idx] > prefixsum:
                idx = 2 * idx
            else:
                prefixsum -= self._value[2 * idx]
                idx = 2 * idx + 1
        return idx - self._capacity

# Taken from RL Adventure's [10] and OpenAI's Baselines [11] Repos.
class MinSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super(MinSegmentTree, self).__init__(
            capacity=capacity,
            operation=min,
            neutral_element=float('inf')
        )

    def min(self, start=0, end=None):
        """Returns min(arr[start], ...,  arr[end])"""

        return super(MinSegmentTree, self).reduce(start, end)
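
To show what `find_prefixsum_idx` computes, here is a minimal reference sketch that uses a linear cumulative sum and a binary search instead of a tree. It is not the segment tree implementation above (which handles updates in O(log n)); the function name `find_prefixsum_idx_linear` is hypothetical.

```python
from bisect import bisect_right
from itertools import accumulate


def find_prefixsum_idx_linear(priorities, mass):
    """Return the highest i such that sum(priorities[:i]) <= mass."""
    cumulative = list(accumulate(priorities))
    # Clamp so that mass == total still maps to the last element.
    return min(bisect_right(cumulative, mass), len(priorities) - 1)


priorities = [1.0, 2.0, 3.0, 4.0]
# Masses drawn uniformly from [0, 10) land on index i with probability priorities[i] / 10.
print([find_prefixsum_idx_linear(priorities, m) for m in (0.5, 1.0, 2.9, 3.0, 9.99)])
```

Drawing `mass` uniformly from `[0, sum(priorities))` and mapping it through this function samples each index in proportion to its priority, which is what the replay buffer relies on.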

## Prioritized Replay Buffer

In [7]:
# Edited from RL Adventure's [10] and OpenAI's Baselines [11] Repos.
class ReplayBuffer(object):
    def __init__(self, size):
        """Create Replay buffer.
        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0

    def __len__(self):
        return len(self._storage)

    def push(self, seq):
        if self._next_idx >= len(self._storage):
            self._storage.append(seq)
        else:
            self._storage[self._next_idx] = seq
        self._next_idx = (self._next_idx + 1) % self._maxsize

    def _encode_sample(self, idxes):
        s_lst = torch.from_numpy(np.array([[self._storage[idx][0][time] for idx in idxes] for time in range(seq_len_with_burn_in+1)])).to(device).float()/255.0
        a_lst = torch.tensor([[[self._storage[idx][1][time]] for idx in idxes] for time in range(seq_len_with_burn_in+1)]).to(device)
        r_lst = torch.tensor([[[self._storage[idx][2][time]] for idx in idxes] for time in range(seq_len_with_burn_in+1)]).to(device)
        done_mask_lst = torch.tensor([[[self._storage[idx][3][time]] for idx in idxes] for time in range(seq_len_with_burn_in+1)]).to(device)
        h0, c0 = torch.cat([self._storage[idx][4]["h0"] for idx in idxes]).squeeze(1).unsqueeze(0).to(device), torch.cat([self._storage[idx][4]["c0"] for idx in idxes]).squeeze(1).unsqueeze(0).to(device)
        return s_lst, a_lst, r_lst, done_mask_lst, {"h0": h0, "c0": c0}

    def sample(self, batch_size):
        """Sample a batch of experiences.
        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        """
        idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
        return self._encode_sample(idxes)

# Edited from RL Adventure's [10] and OpenAI's Baselines [11] Repos.
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.
        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)
        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha > 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def push(self, *args, **kwargs):
        """See ReplayBuffer.push"""
        idx = self._next_idx
        super(PrioritizedReplayBuffer, self).push(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta=importance_sampling_exponent):
        """Sample a batch of experiences.
        compared to ReplayBuffer.sample
        it also returns importance weights and idxes
        of sampled experiences.
        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)
        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32
            denoting importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32
            indexes in buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = torch.tensor(weights, device=device)
        encoded_sample = self._encode_sample(idxes)
        return encoded_sample, [weights, idxes]

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.
        sets priority of transition at index idxes[i] in buffer
        to priorities[i].
        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to
            transitions at the sampled idxes denoted by
            variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
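
The importance-sampling weights computed in `sample`, and the `eta`-weighted priority described in the hyperparameters, can be sketched on plain lists as follows. The function names are hypothetical, and `sequence_priority` is an assumption based on the `eta` comment (a mix of the max and mean absolute TD error), since the code that calls `update_priorities` is not shown here.

```python
def importance_weights(priorities, sampled, alpha=0.9, beta=0.6):
    """Normalised importance-sampling weights for the sampled indices."""
    scaled = [p ** alpha for p in priorities]  # stored as priority ** alpha
    total = sum(scaled)
    n = len(scaled)
    # The least likely item gets the largest raw weight; dividing by it caps weights at 1.
    max_weight = (min(scaled) / total * n) ** (-beta)
    return [(scaled[i] / total * n) ** (-beta) / max_weight for i in sampled]


def sequence_priority(td_errors, eta=0.9):
    """Assumed mix: eta * max|delta| + (1 - eta) * mean|delta| over a sequence."""
    abs_errors = [abs(d) for d in td_errors]
    return eta * max(abs_errors) + (1 - eta) * sum(abs_errors) / len(abs_errors)


print(importance_weights([1.0, 2.0, 4.0], sampled=[0, 1, 2]))
print(sequence_priority([1.0, -3.0]))
```

Sequences with higher priority are sampled more often, so they receive smaller weights to compensate for the bias this introduces into the gradient.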


## Setup Prioritized Replay Buffer

In [8]:
memory = PrioritizedReplayBuffer(buffer_limit, priority_exponent)

## Duelling LSTM Q-Network

In [9]:
# Based off Facebook's RL Assembly [9] Repo.
class DuellingLSTMNet(nn.Module):
    def __init__(self, device, num_action):
        super().__init__()

        # Parameters
        self.frame_stack = 4
        self.conv_out_dim = 3136
        self.hid_dim = 512
        self.num_lstm_layer = 1
        self.num_action = num_action

        # Convolutional state encoder.
        self.net = nn.Sequential(
            nn.Conv2d(self.frame_stack, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
        ).to(device)

        # LSTM to take in encoded states from CNN (and the hidden state at beginning of sequence) and output a state representation with longer term dependencies and information.
        self.lstm = nn.LSTM(
            self.conv_out_dim, self.hid_dim, num_layers=self.num_lstm_layer
        ).to(device)

        # If passing actions and rewards too (instead) like R2D2 does.
        # self.lstm = nn.LSTM(
        #     self.conv_out_dim + num_action + 1, self.hid_dim, num_layers=self.num_lstm_layer
        # ).to(device)

        # Value function fully connected layers: {

        # This design is inspired by the findings of the D2RL paper [2], which allows our network to become much deeper than usual.
        # Directly connected to the LSTM output.
        self.fc_v_1 = nn.Sequential(
            nn.Linear(self.hid_dim, self.hid_dim),
            nn.ReLU()
        ).to(device)
        
        # Contains a skip connection to the CNN output.
        self.fc_v_2 = nn.Sequential(
             nn.Linear(self.hid_dim + self.conv_out_dim, self.hid_dim),
             nn.ReLU()
        ).to(device)

        # Contains a skip connection to the LSTM output.
        self.fc_v_3 = nn.Sequential(
            nn.Linear(self.hid_dim + self.hid_dim, self.hid_dim),
            nn.ReLU()
        ).to(device)
        
        # Contains a skip connection to the LSTM output.
        self.fc_v_4 = nn.Linear(self.hid_dim + self.hid_dim, 1).to(device)

        # }

        # Advantage function fully connected layers: {

        # This design is inspired by the findings of the D2RL paper [2], which allows our network to become much deeper than usual.
        # Directly connected to the LSTM output.
        self.fc_a_1 = nn.Sequential(
            nn.Linear(self.hid_dim, self.hid_dim),
            nn.ReLU()
        ).to(device)

        # Contains a skip connection to the CNN output.
        self.fc_a_2 = nn.Sequential(
            nn.Linear(self.hid_dim + self.conv_out_dim, self.hid_dim),
            nn.ReLU()
        ).to(device)
        # Contains a skip connection to the LSTM output.
        self.fc_a_3 = nn.Sequential(
            nn.Linear(self.hid_dim + self.hid_dim, self.hid_dim),
            nn.ReLU()
        ).to(device)

        # Contains a skip connection to the LSTM output.
        self.fc_a_4 = nn.Linear(self.hid_dim + self.hid_dim, self.num_action).to(device)

        # }

        self.lstm.flatten_parameters()
    
    def get_h0(self, batchsize):
        """
        Retrieve initial hidden state of LSTM.
        """
        shape = (self.num_lstm_layer, batchsize, self.hid_dim)
        hid = {"h0": torch.zeros(*shape, device=device), "c0": torch.zeros(*shape, device=device)}
        return hid

    def duel(self, v, a):
        """
        Takes in Q-value outputs from Value and Advantage networks, and produces the Duelling Q-Networks outputs.
        """
        q = v + a - a.mean(2, keepdim=True)
        return q

    def _conv_forward(self, s):
        """
        Send observation through CNN.
        """
        assert s.dim() == 4  # [batch, c, h, w]
        x = self.net(s)  # state to representation
        x = x.view(s.size(0), self.conv_out_dim)
        return x

    def advantage(self, o, s):
        """
        Retrieve Q-values for each action from the current state.
        Taken from the D2RL [2] paper's insights and architecture.
        o: LSTM output
        s: CNN output
        """
        a = self.fc_a_1(o)
        a = torch.cat([a, s], dim=2)
        a = self.fc_a_2(a)
        a = torch.cat([a, o], dim=2)
        a = self.fc_a_3(a)
        a = torch.cat([a, o], dim=2)
        a = self.fc_a_4(a)
        return a

    def value(self, o, s):
        """
        Retrieve Q-value for the value of the current state.
        Taken from the D2RL [2] paper's insights and architecture.
        o: LSTM output
        s: CNN output
        """
        v = self.fc_v_1(o)
        v = torch.cat([v, s], dim=2)
        v = self.fc_v_2(v)
        v = torch.cat([v, o], dim=2)
        v = self.fc_v_3(v)
        v = torch.cat([v, o], dim=2)
        v = self.fc_v_4(v)
        return v

    def act(self, obs, a, r, hid, epsilon_greedy=True):
        """
        Retrieve the action an agent(s) should take according to the Duelling Network (in this case simply the Advantage Network), with epsilon greedy policy support.
        Simultaneously return the next hidden state of the LSTM.
        """
        x = self._conv_forward(obs)
        # x: [batch, conv_out_dim]
        x = x.unsqueeze(0)
        # x: [1, batch, conv_out_dim]

        # If passing actions and rewards too (instead) like R2D2 does.
        # x = torch.cat([x,a,r], axis = 2)

        o, (h, c) = self.lstm(x, (hid["h0"], hid["c0"]))
        if epsilon_greedy and random.random() < epsilon:
          greedy_action = random.randrange(self.num_action)
        else:
          a = self.advantage(o, x)
          a = a.squeeze(0)
          # a: [batch, num_action]
          legal_a = (1 + a - a.min())
          greedy_action = legal_a.argmax(1).detach()
        
          # If instead using this network to process batches of actions (in a distributed setting for example):
          # if epsilon_greedy:
            # random_actions = torch.randint(low=0, high=self.num_action, size=(batch_size,))
            # probs = torch.rand(batch_size)
            # greedy_action = torch.where(probs < epsilon, random_actions, greedy_action)

        return int(greedy_action), {"h0": h.detach(), "c0": c.detach()}

    def unroll_rnn(self, obs, a, r, hid):
        """
        Send observation through both CNN and LSTM.
        """
        s = obs
        assert s.dim() == 5  # [seq, batch, c, h, w]
        seq, batch, c, h, w = s.size()
        s = s.view(seq * batch, c, h, w)
        # Send through CNN.
        x = self.net(s)
        x = x.view(seq, batch, self.conv_out_dim)

        # If passing actions and rewards too (instead) like R2D2 does.
        # x = torch.cat([x,a,r], axis = 2)

        o, (h, c) = self.lstm(x, (hid["h0"], hid["c0"]))

        # o: LSTM output, x: CNN output.
        return o, x, {"h0": h, "c0": c}

    def forward(self, obs, a, r, hid):
        """
        Send observation through the network and return Q-values.
        return:
            q(s, a): [seq, batch, num_action]
            hid(s_n): [batch] (n=seq, final hidden state, used for LSTM burnin)
        """
        o, x, hid = self.unroll_rnn(obs, a, r, hid)
        # o: LSTM output, x: CNN output.
        a = self.advantage(o, x)
        v = self.value(o, x)
        q = self.duel(v, a)
        return q, hid
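
The aggregation performed by `duel` can be checked on a single state with plain Python. This is a minimal sketch with a hypothetical function name `duel_scalar`; the network's own method does the same thing across the action dimension of a `[seq, batch, num_action]` tensor.

```python
def duel_scalar(v, advantages):
    """Q(s, a) = V(s) + A(s, a) - mean over actions of A(s, .)."""
    mean_advantage = sum(advantages) / len(advantages)
    return [v + a - mean_advantage for a in advantages]


q_values = duel_scalar(2.0, [1.0, 3.0, 5.0])
print(q_values)  # the mean of the Q-values equals the state value, 2.0
```

Subtracting the mean advantage shifts every Q-value by the same amount, so the greedy action is unchanged. This is consistent with `act` taking the argmax over the advantage stream alone.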

## Q-Networks Setup for Double Q Learning

In [10]:
q = DuellingLSTMNet(device, num_actions)
q_target = DuellingLSTMNet(device, num_actions)

## Setup Optimizer for Online Network Backpropagation

In [11]:
optimizer = optim.Adam(q.parameters(), lr=learning_rate, eps=adam_eps)

## Recommence Training from Checkpoint and Sync Networks

In [12]:
start_ep = 0
if restart:
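    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    params = torch.load('training/save{}.chkpt'.format(tag), map_location=device)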
    q.load_state_dict(params['q'])
    optimizer.load_state_dict(params["optimizer"])
    start_ep = params["n_episode"]

## Sync Online and Target Q-Networks

In [13]:
q_target.load_state_dict(q.state_dict())

<All keys matched successfully>

## Training Helper Functions

In [14]:
# Edited from Google's SEED RL [8] Repo.
def value_function_rescaling(x):
  """Value function rescaling used in R2D2 paper [1], see table 2, or Proposition A.2 in paper "Observe and Look Further" [7]."""
  return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.) - 1.) + vfEpsilon * x

# Edited from Google's SEED RL [8] Repo.
def inverse_value_function_rescaling(x):
  """Inverse of the above function. See Proposition A.2 in paper "Observe and Look Further" [7]."""
  return torch.sign(x) * (
      torch.square(((torch.sqrt(
          1. + 4. * vfEpsilon * (torch.abs(x) + 1. + vfEpsilon))) - 1.) / (2. * vfEpsilon)) -
      1.)
  
# Edited from Google's SEED RL [8] Repo.
def n_step_bellman_target(rewards, done_mask, q_target, gamma, n_steps):
  r"""Computes n-step Bellman targets.
  See section 2.3 of R2D2 [1] paper (which does not mention the logic around end of
  episode).
  Args:
    rewards: <float32>[time, batch_size] tensor. This is r_t in the equations
      below.
    done_mask: <float32>[time, batch_size] tensor. This is done_mask_t in the
      equations below. done_mask_t should be 0 if the episode is done just after
      experiencing reward r_t, and 1 otherwise.
    q_target: <float32>[time, batch_size] tensor. This is Q_target(s_{t+1}, a*)
      (where a* is an action chosen by the caller).
    gamma: Exponential RL discounting.
    n_steps: The number of steps to look ahead for computing the Bellman
      targets.
  Returns:
    y_t targets as <float32>[time, batch_size] tensor.
    When n_steps=1, this is just:
    $$r_t + \gamma * done_mask_t * Q_{target}(s_{t+1}, a^*)$$
    In the general case, this is:
    $$(\sum_{i=0}^{n-1} \gamma ^ {i} * notdone_{t, i-1} * r_{t + i}) +
      \gamma ^ n * notdone_{t, n-1} * Q_{target}(s_{t + n}, a^*) $$
    where notdone_{t,i} is defined as:
    $$notdone_{t,i} = \prod_{k=0}^{k=i} done_mask_{t+k}$$
    The last n_step-1 targets cannot be computed with n_step returns, since we
    run out of Q_{target}(s_{t+n}). Instead, they will use n_steps-1, .., 1 step
    returns. For those last targets, the last Q_{target}(s_{t}, a^*) is re-used
    multiple times.
    However, in this implementation the last n_steps-1 are truncated in the
    training function and are not used.
    
  """
  # We append n_steps - 1 times the last q_target. They are divided by gamma **
  # k to correct for the fact that they are at a 'fake' index, and will
  # therefore end up being multiplied back by gamma ** k in the loop below.
  # We prepend 0s that will be discarded at the first iteration below.
  bellman_target = torch.cat(
      [torch.zeros_like(q_target[0:1]), q_target] +
      [q_target[-1:] / gamma ** k
       for k in range(1, n_steps)],
      axis=0)
  # Pad rewards with n_steps 0s and done_mask with n_steps 1s (i.e. "not done").
  # They will be used to compute the last n_steps-1 targets (having 0 rewards
  # in the padding is important).
  done_mask = torch.cat([done_mask] + [torch.ones_like(done_mask[0:1])] * n_steps, axis=0)
  rewards = torch.cat([rewards] + [torch.zeros_like(rewards[0:1])] * n_steps,
                      axis=0)
  # Iteratively build the n_steps targets. After the i-th iteration (1-based),
  # bellman_target is effectively the i-step returns.
  for _ in range(n_steps):
    rewards = rewards[:-1]
    done_mask = done_mask[:-1]
    bellman_target = (
        rewards + gamma * done_mask * bellman_target[1:])

  return bellman_target
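
The recursion above can be checked against a direct summation of the docstring formula. The following is a minimal plain-Python sketch for a single sequence (no batching, no tensors). The name `n_step_targets_reference` and the argument name `not_done` are illustrative and not part of the notebook; `not_done` plays the role of `done_mask` (1.0 while the episode continues, 0.0 once it has ended).

```python
def n_step_targets_reference(rewards, not_done, q_next, gamma, n_steps):
    """Direct-summation n-step targets for one sequence (hypothetical helper).

    rewards[t]  : r_t
    not_done[t] : 1.0 if the episode continues after r_t, else 0.0
    q_next[t]   : Q_target(s_{t+1}, a*), in unscaled value space
    """
    T = len(rewards)
    targets = []
    for t in range(T):
        # Near the end of the sequence fewer than n_steps remain, so fall back
        # to a shorter return that bootstraps from the last available q_next.
        n = min(n_steps, T - t)
        total, discount, alive = 0.0, 1.0, 1.0
        for i in range(n):
            total += discount * alive * rewards[t + i]
            alive *= not_done[t + i]   # becomes 0.0 once the episode has ended
            discount *= gamma
        total += discount * alive * q_next[t + n - 1]
        targets.append(total)
    return targets


# Worked example with gamma = 0.5 and 2-step returns.
rewards = [1.0, 2.0, 3.0]
q_next = [10.0, 20.0, 30.0]

no_termination = n_step_targets_reference(rewards, [1.0, 1.0, 1.0], q_next, 0.5, 2)
print(no_termination)  # [7.0, 11.0, 18.0]

# Episode ends right after the second reward: nothing beyond that point
# contributes to the targets for t = 0 and t = 1.
with_termination = n_step_targets_reference(rewards, [1.0, 0.0, 1.0], q_next, 0.5, 2)
print(with_termination)  # [2.0, 2.0, 18.0]
```

The last entry of each result is only a 1-step return, which is why the training function discards the final `n_steps-1` targets.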

## Training Function
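
The training step moves between rescaled and raw value space: the target network's outputs are passed through the inverse rescaling, the n-step target is built in raw space, and the result is rescaled again before the loss. A minimal scalar sketch of that round trip follows. It uses plain `math` rather than `torch`, the names `h`, `h_inv` and `EPS` are illustrative, and `EPS = 1e-3` is an assumed value (the notebook's `vfEpsilon` is defined elsewhere).

```python
import math

EPS = 1e-3  # assumed value; stands in for the notebook's vfEpsilon


def h(x, eps=EPS):
    """Scalar version of the value function rescaling."""
    return math.copysign(1.0, x) * (math.sqrt(abs(x) + 1.0) - 1.0) + eps * x


def h_inv(x, eps=EPS):
    """Scalar version of the inverse rescaling."""
    root = math.sqrt(1.0 + 4.0 * eps * (abs(x) + 1.0 + eps))
    return math.copysign(1.0, x) * (((root - 1.0) / (2.0 * eps)) ** 2 - 1.0)


# The two functions should undo each other up to floating-point error.
for x in [-250.0, -1.0, 0.0, 0.5, 3000.0]:
    assert abs(h_inv(h(x)) - x) <= 1e-6 * max(1.0, abs(x))

# 1-step target in rescaled space: h(r + gamma * h_inv(Q_target)).
reward, gamma, q_target_rescaled = 100.0, 0.997, 5.0  # illustrative numbers
target = h(reward + gamma * h_inv(q_target_rescaled))
print(target)
```

Rescaling compresses large returns (roughly like a square root) while staying invertible, so the network regresses onto a narrower range of values than the raw game scores.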

In [15]:
def train(q, q_target, memory, optimizer):
    # If you want to train multiple iterations per episode
    # for _ in range(5):
        (s,a,r,done_mask,hids), (weights, idxes)  = memory.sample(batch_size)
        
        # WARNING: 
        # s also has last s_prime attached!
        # done_mask has 1.0 prepended once!

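        # one_hot_a and r_scaled are only used when the previous action and reward are fed to the network as extra inputs, as R2D2 does.
        # squeeze(2) removes only the singleton action dimension, so a batch size of 1 is not squeezed away.
        one_hot_a = torch.nn.functional.one_hot(a, num_classes=num_actions).squeeze(2)
        r_scaled = r/250.0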

        # Handle LSTM burnin.
        if l_burnin:
          with torch.no_grad():
              _, online_hid = q(s[:l_burnin],one_hot_a[:l_burnin], r_scaled[:l_burnin], hids)
              _, target_hid = q_target(s[:l_burnin],one_hot_a[:l_burnin], r_scaled[:l_burnin], hids)
        else:
          online_hid = hids 
          target_hid = hids
        s,one_hot_a,r,r_scaled, done_mask, a = s[l_burnin:], one_hot_a[l_burnin:], r[l_burnin:],r_scaled[l_burnin:],  done_mask[l_burnin:], a[l_burnin:]

        # Retrieve online network's Q-values.
        q_out, _ = q(s,one_hot_a,r_scaled, online_hid)

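        # Get Q-values of the actions actually taken in the replayed sequence.
        q_a = q_out[:-1].gather(2,a[1:])*done_mask[:-1] #replay_q
        # Truncate last n_steps-1 Q-values where the n-step Bellman targets are "incomplete".
        # Slicing by explicit length also works for n_steps == 1, where [:-n_steps+1] would be [:0].
        q_a = q_a[:q_a.shape[0] - (n_steps - 1)]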

        # Compute n-step Bellman targets.
        with torch.no_grad():
          # Double Q-Learning {
          # This means using argmax from online Q-network, but corresponding Q-values from target network to increase stability
          greedy_action = q_out[1:].max(2)[1].unsqueeze(2)
          target_q_prime, _ = q_target(s,one_hot_a,r_scaled, target_hid)
          target_q_prime = target_q_prime[1:]
          max_q_prime = target_q_prime.gather(2, greedy_action) * done_mask[1:]
          # }

          max_q_prime = inverse_value_function_rescaling(max_q_prime)
          
          target = n_step_bellman_target(r[1:], done_mask[1:], max_q_prime, gamma, n_steps)
          
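          target = value_function_rescaling(target)
          # Drop the last n_steps-1 "incomplete" targets (safe for n_steps == 1 as well).
          target = target[:target.shape[0] - (n_steps - 1)]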

        # Calculate TD errors and batch-wise weighted loss.
        abs_td_errors = torch.abs(q_a - target.detach()).float()
        loss = 0.5 * abs_td_errors.square().sum(0)
        loss = (loss * weights).mean()

        # Perform an optimizer step using clipped gradients, as per Dueling Network Architectures for Deep Reinforcement Learning.
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(q.parameters(), 10)
        optimizer.step()

        # Update Prioritized Replay Buffer Priorities.
        with torch.no_grad():
          priorities = torch.max(abs_td_errors, axis=0).values * eta + torch.mean(abs_td_errors, axis=0) * (1 - eta)
          memory.update_priorities(idxes, priorities)

        # Clean up memory
        del s,a,r,done_mask,hids, q_out, q_a, greedy_action, max_q_prime, target, abs_td_errors, priorities, loss, weights, idxes
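
The priority update in the training function mixes the maximum and the mean absolute TD error of each sequence. A minimal plain-Python sketch of that mixture for one sequence is given here. The function name `sequence_priority` is illustrative, and the default `eta=0.9` is the value reported in the R2D2 paper [1], assumed here because the notebook's `eta` is defined elsewhere.

```python
def sequence_priority(abs_td_errors, eta=0.9):
    """Mix of max and mean absolute TD error over one sequence (hypothetical helper)."""
    mean_error = sum(abs_td_errors) / len(abs_td_errors)
    return eta * max(abs_td_errors) + (1.0 - eta) * mean_error


errors = [0.0, 4.0, 2.0]  # illustrative |TD error| values along one sequence

# With eta = 0.5 the max (4.0) and the mean (2.0) are weighted equally.
print(sequence_priority(errors, eta=0.5))  # 3.0

# eta = 1.0 uses the max only, eta = 0.0 the mean only.
print(sequence_priority(errors, eta=1.0))  # 4.0
print(sequence_priority(errors, eta=0.0))  # 2.0
```

A larger `eta` makes a single large error dominate the priority of the whole sequence, while a smaller `eta` favours sequences whose error is large on average.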

## Setup Temporary and Helper Variables

In [16]:
# Defined for memory efficiency of the Replay Buffer.
zeroFloatList = [0.0]
zeroIntList   = [0]

# Variables for statistic printing.
score         = 0.0
marking       = []

## Acting and Learning Loop
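
The exploration rate used in this loop depends only on the episode number, so it can be pulled out into a small function and inspected on its own. The sketch below reproduces the same arithmetic; the function name `epsilon_for_episode` is illustrative and does not exist in the notebook.

```python
def epsilon_for_episode(n_episode):
    """Epsilon schedule as a pure function of the episode index (hypothetical helper)."""
    if n_episode < 200_000:
        # 0.50 - n_episode / 200000 falls to 0.02 at episode 96,000, then is clamped.
        return max(0.02, 0.50 - 0.01 * (n_episode / 2000))
    # Linear decay from 0.02 at episode 200,000 to 0.0 at episode 400,000.
    return max(0.0, 0.02 - 0.02 * (n_episode - 200_000) / 200_000)


for episode in [0, 50_000, 96_000, 150_000, 200_000, 300_000, 400_000, 500_000]:
    print(episode, round(epsilon_for_episode(episode), 4))
```

Printing the schedule at a few episodes makes the three stages visible: the decay from 50% to 2%, the plateau at 2% until episode 200k, and the final decay to fully greedy action selection.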

In [None]:
for n_episode in range(start_ep, int(1e32)):
    # 3-stage epsilon schedule:
    if n_episode < 200000:
        # Linear annealing from 50% to 2% (reached at episode 96k), then constant until episode 200k.
        epsilon = max(0.02, 0.50 - 0.01*(n_episode/2000))
    else:
        # Linear annealing from 2% to 0% between episodes 200k and 400k, then constant at 0%.
        epsilon = max(0.0, 0.02 - 0.02*(n_episode-200000)/200000)
    
    s = env.reset()

    # Perform 1 random action at beginning of episode.
    a = random.randrange(num_actions)
    s, r, done, _ = env.step(a)
    if done:
        continue
    s_list,a_list,r_list, done_mask_list = [], [a], [r], [1.0]
   
    # Get LSTM initial hidden state.
    seq1_initial_episode_hid = q.get_h0(1)
    hid = seq1_initial_episode_hid

    # Reset variables.
    done = False
    score = 0.
    current_seq1_len = 0
    notFirstSeq = False

    with torch.no_grad():
        while True:
            current_seq1_len += 1
            
            # Retrieve action to take from online Q-network.
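            # Inputs: current observation scaled to [0, 1], previous action (one-hot), previous reward (scaled by 1/250).
            obs = torch.from_numpy(np.array(s)).to(device).unsqueeze(0).float()/255.0
            prev_a = torch.nn.functional.one_hot(torch.tensor([a]).view(1,1), num_classes=num_actions).to(device)
            prev_r = torch.tensor([r]).view(1,1,1).to(device)/250.0
            a, hid = q.act(obs, prev_a, prev_r, hid)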
            
            # Step environment.
            s_prime, r, done, _ = env.step(a)
            
            done_mask = 0.0 if done else 1.0
            s_list.append(s)
            a_list.append(a)
            r_list.append(r)
            done_mask_list.append(done_mask)

            score += r

            if done:
              deltaEndSeq0 = seq_len_with_burn_in_minus_overlap - current_seq1_len
              if notFirstSeq:
                # Pad end of episode with zeros and dones and build seq0.
                s0_list = s_list + [s_prime] * (deltaEndSeq0+1)
                a0_list = a_list + zeroIntList * deltaEndSeq0
                r0_list = r_list + zeroFloatList * deltaEndSeq0
                done0_mask_list = done_mask_list + zeroFloatList * deltaEndSeq0
                
                # NOTE: seq0 is always long enough if this is not the first sequence (chunk) of the episode.
                # Submit longer sequence seq0 into buffer.
                memory.push((s0_list,a0_list,r0_list, done0_mask_list, seq0_initial_episode_hid))
                
                # If seq1 is long enough.
                if current_seq1_len > l_burnin + minimumLen:
                  # Pad end of episode with zeros and dones and build seq1.
                  s1_list = s0_list[overlap:] + [s0_list[-2]] * overlap
                  a1_list = a0_list[overlap:] + zeroIntList * overlap
                  r1_list = r0_list[overlap:] + zeroFloatList * overlap
                  done1_mask_list = done0_mask_list[overlap:] + zeroFloatList * overlap

                  # Submit shorter sequence seq1 into buffer.
                  memory.push((s1_list,a1_list,r1_list, done1_mask_list, seq1_initial_episode_hid))
              
              # If this is the first sequence (chunk) of the episode.
              elif current_seq1_len > l_burnin + minimumLen:
                deltaEndSeq1 = deltaEndSeq0 + overlap
                # Pad end of episode with zeros and dones and build seq1.
                s1_list = s_list + [s_prime] * (deltaEndSeq1+1)
                a1_list = a_list + zeroIntList * deltaEndSeq1
                r1_list = r_list + zeroFloatList * deltaEndSeq1
                done1_mask_list = done_mask_list + zeroFloatList * deltaEndSeq1

                # Submit the only sequence into buffer.
                memory.push((s1_list,a1_list,r1_list, done1_mask_list, seq1_initial_episode_hid))
              break
            
            # When longer of 2 sequences (seq0) reaches correct length, submit to buffer and handle sequence switching.
            if seq_len_with_burn_in_minus_overlap == current_seq1_len: # If seq0len = seq_len_with_burn_in : handle switch over
              # Unless seq0 does not yet exist.
              if notFirstSeq:
                memory.push((s_list[:seq_len_with_burn_in] + [s_prime],a_list[:seq_len_with_burn_in+1],r_list[:seq_len_with_burn_in+1], done_mask_list[:seq_len_with_burn_in+1], seq0_initial_episode_hid))
                s_list,a_list,r_list, done_mask_list = s_list[overlap:], a_list[overlap:], r_list[overlap:],  done_mask_list[overlap:]

              current_seq1_len -= overlap
              notFirstSeq = True
              seq0_initial_episode_hid = seq1_initial_episode_hid
              seq1_initial_episode_hid = hid

            # The next state becomes the current state for the next iteration.
            s = s_prime

    # If the buffer is big enough, begin training at the end of each new episode.
    if len(memory)>4000:  
        train(q, q_target, memory, optimizer)

    marking.append(score)

    # Sync target and online Q-networks periodically.
    if n_episode%sync_target_every==0 and n_episode!=0:
        q_target.load_state_dict(q.state_dict())

    # Save model checkpoint periodically.
    if n_episode%10000 == 0:
        torch.save({'q':q.state_dict(), 'optimizer':optimizer.state_dict(), 'n_episode':n_episode}, 'training/save{0}.chkpt'.format(tag))
    
    # Handle statistics printing and empty cache.
    if n_episode%100 == 0:
        print("marking, episode: {}, score: {:.1f}, mean_score: {:.2f}, std_score: {:.2f}".format(
            n_episode, score, np.array(marking).mean(), np.array(marking).std()))
        
        # TensorBoard integration.
        writer.add_scalar('score',
                           np.array(marking).mean(),
                            n_episode)
        writer.add_scalar("eps", epsilon, n_episode)
        
        marking = []
        torch.cuda.empty_cache()

    if n_episode%print_every==0 and n_episode!=0:
        print("episode: {}, score: {:.1f}, epsilon: {:.2f}".format(n_episode, score, epsilon))

marking, episode: 0, score: 0.0, mean_score: 0.00, std_score: 0.00
