In this notebook we will implement the priority replay buffer as described in Schaul et. al.'s [paper](https://arxiv.org/abs/1511.05952), where experiences are assigned a priority based on their TD error ($\delta$): experiences with higher $\delta$ are given a higher priority. New experiences for which we have not yet evaluated a TD error will be assigned the maximimal priority seen so far. In the paper they described two variants of prioritized replay:

- Proportional prioritization : ```p_i = delta_i + \epsilon ```, " where ```epsilon``` is a small positive constant that prevents the edge-case of transitions not being revisited once their error is zero "

- Rank based prioritization: ```p_i = 1/rank_i``` where the rank is assigned to an experience based on its position if the memories were sorted in order of decreasing TD error.  

Schaul et. al. report that on most of the games in the Atari 2600 suite both the Proportional and the Rank based prioritization have similar performance (although, " there are games where one of them remains close to the Double DQN baseline while the other one leads to a big boost, for example Double Dunk or Surround for the rank-based variant, and Alien, Asterix, Enduro, Phoenix or Space
Invaders for the proportional variant. ")

In this notebook, we  will work on implementing "Proportional prioritization" based on a sum-tree.

A sum-tree is  a binary tree where the value of the parent node is equal to the value of its two children. In our case, the leaves of the tree will correspond to the indices of the the memory buffer. The values of the leaves will correspond to the priority of the experience being stored at corresponding index in the buffer. 

In [1]:
import numpy as np
from operator import itemgetter

In [2]:
class sum_tree():
    def __init__(self, num_memories):
        self.num_memories = num_memories
        self.num_levels = int(np.ceil(np.log2(num_memories)+1))
        self.num_vertices = int(2**self.num_levels-1)
        self.tree_array = np.zeros(shape = self.num_vertices) # the 0-th entry is the root; children of i: 2i+1, 2i+2
        
    def get_children(self, i):
        # get the values of the children of the i-th node
        try:
            return [self.tree_array[2*i+1], self.tree_array[2*i+2]]
        except:
            return [None, None]
    
    def get_parent(self, i):
        # get the value of the parent of the i-th node
        return self.tree_array[(i-1)//2]
    
    def get_value(self, i):
        # get the value of i-th node
        return self.tree_array[i]
    
    def update_val(self, i, val):
        # update the value of the i-th node
        # This will also require us to update the value of its parent in order to maintian the sum-tree property
        self.tree_array[i] = val
        if not i==0:
            parent = (i-1)//2
            new_parent_val = sum(self.get_children(parent))
            self.update_val(parent, new_parent_val)        
            
    def get_sample_id(self, priority, current_index = 0):
        # get a sample corresponding to the input priority by traversing the tree starting from node at current_index
        
        # print(priority, current_index)
        if priority > self.get_value(current_index):
            raise ValueError("priority should be less than value of current index")
            
        left_c, right_c = self.get_children(current_index)
        
        if left_c == None:
            # we have reached the leaf node
            return current_index
        
        if priority <= left_c:
            sample_id = self.get_sample_id(priority, 2*current_index+1)
            return sample_id
        else:    
            sample_id = self.get_sample_id(priority-left_c, 2*current_index+2)
            return sample_id
        
    def max_leaf_value(self):
        # get the maximum value amongsts all the leaves
        return max(self.tree_array[2**(self.num_levels-2)+1:])
        

In [3]:
tree = sum_tree(4)
print(tree.num_levels, tree.num_vertices, tree.tree_array)

3 7 [0. 0. 0. 0. 0. 0. 0.]


In [4]:
leaf_indices = range(3,7)
priorities = list(map(lambda x: 2*x, leaf_indices))
priorities

[6, 8, 10, 12]

In [5]:
for ind, val in zip(leaf_indices, priorities):
    tree.update_val(ind, val)

In [6]:
tree.tree_array

array([36., 14., 22.,  6.,  8., 10., 12.])

In [7]:
tree.get_children(3)

[None, None]

In [8]:
sample_id = tree.get_sample_id(20)
sample_id

5

In [9]:
tree.max_leaf_value()

12.0

In [10]:
tree.tree_array[2**(tree.num_levels-2)+1:]

array([ 6.,  8., 10., 12.])

In [11]:
priority_sum = tree.get_value(0)
num_samples = 100000
sample_priorities = (priority_sum)*np.random.random(num_samples)
#print(sample_priorities)
sample_ids = list(map(lambda val: tree.get_sample_id(val), sample_priorities))
#print(sample_ids)

In [12]:
from collections import Counter
sample_counts = Counter(sample_ids)
print('sample counts: {}'.format(sample_counts))
probs = dict(map(lambda KeyValue: (KeyValue[0], KeyValue[1]/num_samples), sample_counts.items()))
print('probability of samples: {}'.format(probs))
expected_probs = dict(map(lambda KeyValue: (KeyValue[0], tree.get_value(KeyValue[0])/priority_sum), sample_counts.items()))
print('expected probability of samples: {}'.format(expected_probs))

sample counts: Counter({6: 33040, 5: 27974, 4: 22362, 3: 16624})
probability of samples: {6: 0.3304, 4: 0.22362, 3: 0.16624, 5: 0.27974}
expected probability of samples: {6: 0.3333333333333333, 4: 0.2222222222222222, 3: 0.16666666666666666, 5: 0.2777777777777778}


In [13]:
class PrioritizedReplay():
    def __init__(self, buffer_size, n_states, n_actions, roll_out, n_agents, alpha = 0.6, epsilon = 0.0001):
        self.memory = [None]*buffer_size
        self.buffer_size = buffer_size
        self.n_agents = n_agents
        self.n_states = n_states
        self.n_actions = n_actions
        self.roll_out = roll_out # roll_out = 1 corresponds to a single step
        self.alpha = alpha # this is the exponent \alpha in eq.1 of the Prioritized Replay paper 
        self.epsilon = epsilon # this the epsilon  that is added to priorities to avoid edge-case issues
                               # see the discussion below eq. 1 in Prioritized Replay paper
        
        # length of an array containg a single memory of any one player
        self.experience_length = 2*n_states+n_actions+roll_out+1 
        
        # index the in memory where the next experience is to be added
        # runs from 0 to buffer_size-1 after which it resets to zero
        self.new_exp_idx = 0
        
        # sum tree for the priorities
        self.priority_tree = sum_tree(buffer_size)
        self.first_leaf_idx = 2**(self.priority_tree.num_levels-1)-1 # index of the first leaf in priority_tree.tree_array 
        
    def add(self, experience_tuple):
        # add a new experience to the memory
        # each tuple consists of (n-1)-steps of state, action, reward, done and the n-state
        # here n is the roll_out length
        self.memory[self.new_exp_idx] = experience_tuple
        
        # get the maximal priority of all experiences in the buffer
        priority = max([*self.priority_tree.tree_array[self.first_leaf_idx:], self.epsilon])
        
        # now update the priority in the priority_tree
        tree_index = self.new_exp_idx + self.first_leaf_idx
        self.priority_tree.update_val(tree_index, priority)
        
        # move the new experience index to the next position
        self.new_exp_idx = (self.new_exp_idx+1)%self.buffer_size
    
    def update_priority(self, TDErrors, experience_idxs):
        # *********** To Implement **************
        # update the priorities of the given experiences
        # priority = abs(TDError)**self.alpha + self.epsilon
        # experience_idxs are the indices of the experiences in self.memory whose priority has to be updated
        # TDErrors are the latest TDErrors of those experiences
        for TDError, idx in zip(TDErrors, experience_idxs):
            priority = abs(TDError)**self.alpha + self.epsilon
            # now update the priority in the priority_tree
            tree_index = idx + self.first_leaf_idx
            self.priority_tree.update_val(tree_index, priority)
            
    def sample(self, batch_size):
        
        # get sum of all priorities 
        priority_sum = self.priority_tree.get_value(0)
        sample_priorities = (priority_sum)*np.random.random(batch_size)
        
        
        batch_idxs = np.array(list(map(lambda val: 
                                      self.priority_tree.get_sample_id(val) - self.first_leaf_idx, sample_priorities)))
        
        batch = np.stack(list(itemgetter(*batch_idxs)(self.memory)))
        print(batch)
        expected_batch_shape = (batch_size, self.n_agents, self.experience_length)
        
        assert batch.shape == expected_batch_shape, 'Shape of the batch is not same as expected. Got: {}, expected: {}!'.format(batch.shape, expected_batch_shape)
        
        states0_batch = batch[:,:,:self.n_states] # shape = (batch_size, n_agents, n_states)
        actions0_batch = batch[:,:, self.n_states: self.n_states+self.n_actions].reshape(batch_size, -1)
        # shape = (batch_size, n_agents*n_actions)
        assert actions0_batch.shape == (batch_size, self.n_agents*self.n_actions),  'actions0 shape is incorrect'
        
        rewards_batch = batch[:,:,self.n_states+self.n_actions:self.n_states+self.n_actions+self.roll_out] # shape = (batch_size, n_agents, roll_out)
        dones = batch[:,0,self.n_states+self.n_actions+self.roll_out:self.n_states+self.n_actions+self.roll_out+1]
        # shape = (batch_size, 1)
        states_fin_batch = batch[:,:,self.n_states+self.n_actions+self.roll_out+1:] # shape = (batch_size, n_agents, n_states)
        
        return  states0_batch, actions0_batch, rewards_batch, dones, states_fin_batch, batch_idxs    

In [14]:
buffer_size = 2**3
n_states = 2
n_actions = 2
roll_out = 1
n_agents = 1
replay_buffer = PrioritizedReplay(buffer_size, n_states, n_actions, roll_out, n_agents)

In [15]:
exp = np.random.random(size = (buffer_size, n_agents, n_states + n_actions + roll_out + 1 + n_states ))

In [16]:
exp[0,0].shape

(8,)

In [17]:
for experience in exp:
    replay_buffer.add(experience)

In [18]:
replay_buffer.priority_tree.tree_array

array([0.0008, 0.0004, 0.0004, 0.0002, 0.0002, 0.0002, 0.0002, 0.0001,
       0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001])

In [19]:
priority_sum = replay_buffer.priority_tree.get_value(0)
num_samples = 100000
sample_priorities = (priority_sum)*np.random.random(num_samples)
#print(sample_priorities)
sample_ids = list(map(lambda val: replay_buffer.priority_tree.get_sample_id(val), sample_priorities))
#print(sample_ids)

from collections import Counter
sample_counts = Counter(sample_ids)
print('sample counts: {}'.format(sample_counts))
probs = dict(map(lambda KeyValue: (KeyValue[0], KeyValue[1]/num_samples), sample_counts.items()))
print('probability of samples: {}'.format(probs))
expected_probs = dict(map(lambda KeyValue: (KeyValue[0], 
                                            replay_buffer.priority_tree.get_value(KeyValue[0])/priority_sum), sample_counts.items()))
print('expected probability of samples: {}'.format(expected_probs))

sample counts: Counter({9: 12569, 13: 12564, 14: 12539, 10: 12504, 7: 12498, 8: 12495, 11: 12433, 12: 12398})
probability of samples: {8: 0.12495, 11: 0.12433, 7: 0.12498, 12: 0.12398, 14: 0.12539, 9: 0.12569, 10: 0.12504, 13: 0.12564}
expected probability of samples: {8: 0.125, 11: 0.125, 7: 0.125, 12: 0.125, 14: 0.125, 9: 0.125, 10: 0.125, 13: 0.125}


In [20]:
*mems, indxs = replay_buffer.sample(2)

[[[0.56128528 0.36002206 0.63071026 0.47259625 0.82141372 0.47588885
   0.43940773 0.37293861]]

 [[0.73405148 0.11983133 0.90191481 0.67229982 0.40811169 0.11024123
   0.75844642 0.34592707]]]


In [21]:
indxs

array([6, 5])

In [22]:
tderrors = np.random.random(2)
print(tderrors)
replay_buffer.update_priority(tderrors, indxs)

[0.33908088 0.09650215]


In [23]:
replay_buffer.priority_tree.tree_array

array([7.69294561e-01, 4.00000000e-04, 7.68894561e-01, 2.00000000e-04,
       2.00000000e-04, 2.46079433e-01, 5.22815128e-01, 1.00000000e-04,
       1.00000000e-04, 1.00000000e-04, 1.00000000e-04, 1.00000000e-04,
       2.45979433e-01, 5.22715128e-01, 1.00000000e-04])

In [24]:
tderrors**0.6+0.0001

array([0.52271513, 0.24597943])