Prioritized Experience Replay (PER) is a type of experience replay in reinforcement learning in which transitions with high expected learning progress, as measured by the magnitude of their temporal-difference (TD) error, are replayed more frequently than others.

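Experience replay (DQN, Mnih et al. 2013): https://arxiv.org/abs/1312.5602 — Prioritized Experience Replay (Schaul et al. 2015): https://arxiv.org/abs/1511.05952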

In [3]:
import numpy as np
import random
from collections import deque

In [None]:
class ReplayMemory:
    """Fixed-capacity FIFO buffer of experiences with uniform sampling.

    Once ``max_memory_size`` experiences are stored, each new ``add`` evicts
    the oldest one (``deque(maxlen=...)`` semantics).
    """

    def __init__(self, max_memory_size):
        self.replay_memory = deque(maxlen=max_memory_size)

    def add(self, experience):
        """Append an experience tuple (s, a, r, ns, d) to the replay memory."""
        self.replay_memory.append(experience)

    def sample(self, batch_size):
        """Return a list of experiences drawn uniformly at random.

        If the memory holds more than ``batch_size`` experiences, ``batch_size``
        of them are drawn *with replacement* (``random.choices``), so the batch
        may contain duplicates. Otherwise every stored experience is returned,
        in insertion order.

        The result is always a new list. Handing out the internal deque would
        let callers mutate the buffer by accident, and iterating it while
        ``add`` runs raises ``RuntimeError``.
        """
        if len(self.replay_memory) > batch_size:
            return random.choices(self.replay_memory, k=batch_size)
        return list(self.replay_memory)

![image.png](attachment:image.png)

In [51]:
class PrioritizedReplayMemory:
    """Replay buffer that samples experiences in proportion to their priority.

    Proportional prioritization: experience ``i`` is drawn with probability
    ``P(i) = p_i**alpha / sum_k p_k**alpha``. The sampling bias this introduces
    is corrected with importance-sampling (IS) weights
    ``w_i = (N * P(i))**(-beta)``, normalized by their maximum.

    Experiences and priorities live in two parallel deques with the same
    ``maxlen``, so an experience and its priority are always evicted together.
    Sampling is O(N) per call (no sum-tree), which is fine for modest sizes.
    """

    def __init__(self, max_memory_size, alpha=0.5, beta=0.5, epsilon=0.01):
        self.max_memory_size = max_memory_size
        self.replay_memory = deque(maxlen=self.max_memory_size)
        self.priorities = deque(maxlen=self.max_memory_size)

        # Added to |TD error| so no experience ends up with zero priority
        # (and thus zero chance of ever being replayed again).
        self.epsilon = epsilon
        self.alpha = alpha  # alpha = 0 corresponds to the uniform case
        self.beta = beta    # should steadily increase up to 1 by the end of agent training

    def add(self, experience):
        """Add an experience tuple (s, a, r, ns, d) to the replay memory.

        New samples have no TD error yet because they have not been evaluated
        by the networks. Following PER, they get the maximum priority seen so
        far (1 for the very first sample), which favours them during sampling.
        """
        self.replay_memory.append(experience)
        self.priorities.append(max(self.priorities, default=1))

    def sample(self, batch_size, return_indices=False):
        """Draw a prioritized batch together with its importance-sampling weights.

        Args:
            batch_size: Number of experiences to draw, with replacement. If the
                memory holds ``batch_size`` or fewer experiences, all of them
                are returned, in insertion order.
            return_indices: If True, also return the buffer positions of the
                sampled experiences. ``update_priorities`` needs these.

        Returns:
            ``(batch, is_weights)``, or ``(batch, is_weights, batch_indices)``
            when ``return_indices`` is True. ``batch`` is a list of
            experiences, ``is_weights`` is a float numpy array aligned with
            it, and ``batch_indices`` is a list of ints.

        Note:
            Indices are positions in the deques. Once the buffer is full, every
            ``add`` shifts all positions by one. Call ``update_priorities``
            before adding new experiences.
        """
        size = len(self.replay_memory)
        if size == 0:
            batch_indices = []
            is_weights = np.empty(0)
        else:
            probabilities = self.get_probabilities()
            if size > batch_size:
                # Sample *indices* rather than experiences so the caller can
                # later update the priorities of exactly these entries.
                batch_indices = random.choices(range(size), weights=probabilities, k=batch_size)
            else:
                batch_indices = list(range(size))
            is_weights = self.compute_ISw(probabilities[batch_indices])

        # A plain list: experience tuples hold arrays of different shapes,
        # which would make np.array(...) ragged.
        batch = [self.replay_memory[i] for i in batch_indices]
        if return_indices:
            return batch, is_weights, batch_indices
        return batch, is_weights

    def get_probabilities(self):
        """Return the sampling probabilities P(i) as a numpy array, one per stored experience."""
        scaled_priorities = np.array(self.priorities, dtype=float) ** self.alpha
        return scaled_priorities / scaled_priorities.sum()

    def compute_ISw(self, batch_probabilities):
        """Return IS weights ``(N * P(i))**(-beta)`` normalized so the largest weight is 1."""
        weights = (len(self.replay_memory) * np.asarray(batch_probabilities, dtype=float)) ** -self.beta
        # Normalizing by the max only ever scales updates down, which keeps training stable.
        return weights / weights.max()

    def update_priorities(self, batch_indices, errors):
        """Set the priorities of the given buffer positions from their TD errors.

        The new priority is ``|error| + epsilon``. The absolute value matters:
        a negative TD error would otherwise give a negative priority, and
        ``** alpha`` would then turn the probabilities into NaN. If an index
        appears more than once, the last error for it wins.

        Args:
            batch_indices: Buffer positions, as returned by
                ``sample(..., return_indices=True)``.
            errors: TD errors aligned with ``batch_indices`` (any array-like;
                flattened), or a single scalar applied to all of them.
        """
        new_priorities = np.abs(np.asarray(errors, dtype=float).reshape(-1)) + self.epsilon
        priorities = np.array(self.priorities, dtype=float)
        priorities[batch_indices] = new_priorities
        self.priorities = deque(priorities.tolist(), maxlen=self.max_memory_size)
