In this notebook we will implement the prioritized replay buffer described in Schaul et al.'s [paper](https://arxiv.org/abs/1511.05952). Each experience is assigned a priority based on its TD error ($\delta$), and experiences with a larger $|\delta|$ are given a higher priority. New experiences, for which no TD error has been evaluated yet, are assigned the maximal priority seen so far. The paper describes two variants of prioritized replay:

- Proportional prioritization: $p_i = |\delta_i| + \epsilon$, "where $\epsilon$ is a small positive constant that prevents the edge-case of transitions not being revisited once their error is zero."

- Rank-based prioritization: $p_i = 1/\mathrm{rank}(i)$, where $\mathrm{rank}(i)$ is the position of experience $i$ when the memories are sorted in order of decreasing $|\delta|$.
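Both priority definitions can be computed directly from a vector of TD errors. The NumPy sketch below is illustrative only: the function names and the default `eps=1e-6` are assumptions, not values taken from the paper.

```python
import numpy as np

def proportional_priorities(td_errors, eps=1e-6):
    # p_i = |delta_i| + eps; eps (an assumed value) keeps transitions with
    # zero TD error from never being sampled again
    return np.abs(td_errors) + eps

def rank_priorities(td_errors):
    # p_i = 1 / rank(i), where rank 1 goes to the largest |delta|
    td_errors = np.asarray(td_errors)
    order = np.argsort(-np.abs(td_errors))          # indices by decreasing |delta|
    ranks = np.empty(len(td_errors), dtype=np.int64)
    ranks[order] = np.arange(1, len(td_errors) + 1)
    return 1.0 / ranks

deltas = np.array([0.5, -2.0, 0.0, 1.0])
print(proportional_priorities(deltas))  # approx. [0.5, 2.0, 1e-6, 1.0]
print(rank_priorities(deltas))          # [1/3, 1, 1/4, 1/2]
```

Note that the rank-based priorities depend only on the ordering of $|\delta|$, not on its magnitude; ties are broken arbitrarily by `argsort` in this sketch.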

Schaul et al. report that on most games of the Atari 2600 suite the proportional and rank-based variants perform similarly, although "there are games where one of them remains close to the Double DQN baseline while the other one leads to a big boost, for example Double Dunk or Surround for the rank-based variant, and Alien, Asterix, Enduro, Phoenix or Space Invaders for the proportional variant."

In this notebook we will implement proportional prioritization using a sum-tree.

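A sum-tree is a binary tree in which the value of each parent node is equal to the sum of the values of its two children. In our case, the leaves of the tree correspond to the indices of the memory buffer, and the value of each leaf is the priority of the experience stored at the corresponding index. The root therefore holds the total priority $\sum_k p_k$, so an experience can be sampled with probability $p_i / \sum_k p_k$ by drawing a value uniformly from $[0, \sum_k p_k]$ and descending from the root: go left if the value is at most the left child's total, otherwise subtract that total and go right.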
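To make the array layout concrete, the sketch below builds a sum-tree over four leaf priorities bottom-up. It assumes, for simplicity, that the number of leaves is a power of two, and uses the indexing of the `sum_tree` class in this notebook: root at index 0, children of node `i` at `2i+1` and `2i+2`, leaves in the last `n` slots. `build_sum_tree` is a hypothetical helper, not part of the notebook's code.

```python
import numpy as np

def build_sum_tree(leaf_priorities):
    # Assumes len(leaf_priorities) is a power of two, so the leaves fill one level.
    n = len(leaf_priorities)
    tree = np.zeros(2 * n - 1)               # n - 1 internal nodes + n leaves
    tree[n - 1:] = leaf_priorities           # leaves occupy the last n slots
    for i in range(n - 2, -1, -1):           # fill internal nodes from the bottom up
        tree[i] = tree[2 * i + 1] + tree[2 * i + 2]
    return tree

tree_vals = build_sum_tree([6, 8, 10, 12])
print(tree_vals)  # root 36, internal nodes 14 and 22, then the leaves 6, 8, 10, 12
```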

In [1]:
import numpy as np

In [2]:
class sum_tree:
    def __init__(self, num_memories):
        self.num_memories = num_memories
        # choose enough levels that the bottom level has at least num_memories leaves
        self.num_levels = int(np.ceil(np.log2(num_memories)+1))
        self.num_vertices = int(2**self.num_levels-1)
        self.tree_array = np.zeros(shape=self.num_vertices) # the 0-th entry is the root; children of i: 2i+1, 2i+2
        
    def get_children(self, i):
        # get the values of the children of the i-th node; leaves have no children
        try:
            return [self.tree_array[2*i+1], self.tree_array[2*i+2]]
        except IndexError:
            return [None, None]
    
    def get_parent(self, i):
        # get the value of the parent of the i-th node (the root has no parent)
        if i == 0:
            return None
        return self.tree_array[(i-1)//2]
    
    def get_value(self, i):
        # get the value of i-th node
        return self.tree_array[i]
    
    def update_val(self, i, val):
        # update the value of the i-th node
        # this also requires updating all of its ancestors to maintain the sum-tree property
        self.tree_array[i] = val
        if i != 0:
            parent = (i-1)//2
            new_parent_val = sum(self.get_children(parent))
            self.update_val(parent, new_parent_val)
            
    def get_sample_id(self, priority, current_index=0):
        # get the tree index of the leaf selected by `priority`, traversing the tree from the node at current_index;
        # note that this is an index into tree_array, not into the memory buffer

        print(priority, current_index)  # trace the traversal path (shown in the output below)
        if priority > self.get_value(current_index):
            raise ValueError("priority should not exceed the value of the node at current_index")

        left_c, right_c = self.get_children(current_index)

        if left_c is None:
            # we have reached a leaf node
            return current_index

        if priority <= left_c:
            # the target falls within the left subtree's total
            return self.get_sample_id(priority, 2*current_index+1)
        else:
            # skip the left subtree's mass and continue in the right subtree
            return self.get_sample_id(priority-left_c, 2*current_index+2)
        
    def max_leaf_value(self):
        # get the maximum value amongst all the leaves;
        # the leaves start at index 2**(num_levels-1) - 1, right after the internal nodes
        return max(self.tree_array[2**(self.num_levels-1)-1:])
        

In [3]:
tree = sum_tree(4)
print(tree.num_levels, tree.num_vertices, tree.tree_array)

3 7 [0. 0. 0. 0. 0. 0. 0.]


In [4]:
leaf_indices = range(3,7)
priorities = [2*x for x in leaf_indices]
priorities

[6, 8, 10, 12]

In [5]:
for ind, val in zip(leaf_indices, priorities):
    tree.update_val(ind, val)

In [6]:
tree.tree_array

array([36., 14., 22.,  6.,  8., 10., 12.])
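`update_val` restores the sum-tree property by recomputing each ancestor from its two children. An equivalent iterative formulation adds the change in the leaf's value to every ancestor on the path to the root; both touch O(log n) nodes. The sketch below applies it to a copy of the array printed above. `update_leaf` is a hypothetical helper written for illustration, and it deliberately uses a separate variable name so it does not overwrite `tree`.

```python
import numpy as np

def update_leaf(tree, i, new_val):
    # Add the change in the leaf's value to every ancestor up to the root.
    delta = new_val - tree[i]
    tree[i] = new_val
    while i > 0:
        i = (i - 1) // 2          # move to the parent
        tree[i] += delta
    return tree

values = np.array([36., 14., 22., 6., 8., 10., 12.])
update_leaf(values, 4, 2.0)       # leaf 4: 8 -> 2, so node 1 and the root both drop by 6
print(values)                     # node values 30, 8, 22, 6, 2, 10, 12
```

One trade-off: propagating a difference can accumulate floating-point drift over many updates, whereas recomputing each parent from its children, as `update_val` does, keeps every internal node equal to the sum of its children's stored values.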

In [7]:
tree.get_children(3)

[None, None]

In [8]:
sample_id = tree.get_sample_id(20)
sample_id

20 0
6.0 2
6.0 5


5
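`get_sample_id` returns an index into `tree_array` (here 5), not a slot in the memory buffer. In a complete tree the leaves come after the `num_vertices // 2` internal nodes, so tree index 5 corresponds to buffer slot 5 - 3 = 2, the experience with priority 10. The self-contained sketch below re-implements the same descent iteratively (`sample_leaf` is a hypothetical name) and checks empirically that uniform draws from [0, total) select each leaf with a frequency close to $p_i / \sum_k p_k$.

```python
import numpy as np

def sample_leaf(tree, u):
    # Iterative version of the descent in get_sample_id: go left if u fits
    # within the left subtree's total, otherwise subtract it and go right.
    i = 0
    while 2 * i + 1 < len(tree):      # stop once node i has no children (a leaf)
        left = 2 * i + 1
        if u <= tree[left]:
            i = left
        else:
            u -= tree[left]
            i = left + 1
    return i

demo_tree = np.array([36., 14., 22., 6., 8., 10., 12.])
first_leaf = len(demo_tree) // 2      # number of internal nodes in a complete tree
print(sample_leaf(demo_tree, 20.0) - first_leaf)   # 2

rng = np.random.default_rng(0)
counts = np.zeros(len(demo_tree) - first_leaf)
for u in rng.uniform(0.0, demo_tree[0], size=36_000):
    counts[sample_leaf(demo_tree, u) - first_leaf] += 1
print(counts / counts.sum())          # close to [6, 8, 10, 12] / 36
```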

In [9]:
tree.max_leaf_value()

12.0

In [10]:
tree.tree_array[2**(tree.num_levels-1)-1:]

array([ 6.,  8., 10., 12.])
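Putting the pieces together, a proportional replay buffer can store experiences in a ring buffer whose slots are the sum-tree's leaves. The class below is a hypothetical minimal sketch, not the notebook's implementation. It assumes the capacity is a power of two, gives new experiences the maximal priority seen so far (starting from an assumed 1.0), and samples a minibatch by splitting [0, total) into equal segments and drawing one value per segment. It omits the priority exponent $\alpha$ and the importance-sampling weights that Schaul et al. also use.

```python
import numpy as np

class ProportionalReplay:
    """Hypothetical minimal proportional replay buffer backed by an array sum-tree.

    Assumes `capacity` is a power of two, so the leaves fill the last tree level.
    """

    def __init__(self, capacity, eps=1e-6, seed=None):
        self.capacity = capacity
        self.eps = eps                          # assumed value; keeps zero-error items sampleable
        self.tree = np.zeros(2 * capacity - 1)  # root at 0; leaves at capacity-1 .. 2*capacity-2
        self.data = [None] * capacity
        self.next_slot = 0                      # ring-buffer write position
        self.size = 0
        self.max_priority = 1.0                 # assumed initial "maximal priority seen so far"
        self.rng = np.random.default_rng(seed)

    def _set(self, slot, priority):
        # Write a leaf, then recompute each ancestor from its two children.
        i = slot + self.capacity - 1            # buffer slot -> tree index
        self.tree[i] = priority
        while i > 0:
            i = (i - 1) // 2
            self.tree[i] = self.tree[2 * i + 1] + self.tree[2 * i + 2]

    def _find(self, u):
        # Descend from the root to the leaf whose cumulative range contains u.
        i = 0
        while 2 * i + 1 < len(self.tree):
            left = 2 * i + 1
            if u <= self.tree[left]:
                i = left
            else:
                u -= self.tree[left]
                i = left + 1
        slot = i - (self.capacity - 1)          # tree index -> buffer slot
        # Guard: round-off could in principle land on an empty (zero-priority) slot.
        return min(slot, self.size - 1)

    def add(self, experience):
        # New experiences get the maximal priority seen so far.
        self.data[self.next_slot] = experience
        self._set(self.next_slot, self.max_priority)
        self.next_slot = (self.next_slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        if self.size == 0:
            raise ValueError("cannot sample from an empty buffer")
        # Split [0, total) into batch_size equal segments and draw one value from each.
        bounds = np.linspace(0.0, self.tree[0], batch_size + 1)
        us = self.rng.uniform(bounds[:-1], bounds[1:])
        slots = [self._find(u) for u in us]
        return slots, [self.data[s] for s in slots]

    def update_priorities(self, slots, td_errors):
        # Proportional prioritization: p_i = |delta_i| + eps.
        for slot, delta in zip(slots, td_errors):
            p = abs(delta) + self.eps
            self._set(slot, p)
            self.max_priority = max(self.max_priority, p)

# Usage: fill part of the buffer, sample a minibatch, then feed back new TD errors.
buf = ProportionalReplay(capacity=8, seed=0)
for k in range(5):
    buf.add(("transition", k))
slots, batch = buf.sample(4)
buf.update_priorities(slots, td_errors=[0.3, -1.2, 0.0, 2.5])
```

Recomputing parents from their children in `_set`, rather than propagating a difference, keeps every internal node consistent with its children, which is what the descent in `_find` relies on.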