Prioritized Experience Replay is a type of experience replay in reinforcement learning in which transitions with high expected learning progress, as measured by the magnitude of their temporal-difference (TD) error, are replayed more frequently.

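https://arxiv.org/abs/1312.5602 (Playing Atari with Deep Reinforcement Learning, Mnih et al., 2013 - the DQN paper that uses uniform experience replay)
https://arxiv.org/abs/1511.05952 (Prioritized Experience Replay, Schaul et al., 2015)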

In [1]:
import numpy as np
import random
from collections import deque

In [None]:
class ReplayMemory:
    """Fixed-capacity FIFO buffer of experiences with uniform random sampling."""

    def __init__(self, max_memory_size):
        """Create an empty buffer holding at most ``max_memory_size`` items."""
        # A bounded deque discards its oldest entry when a new one is
        # appended at capacity.
        self.replay_memory = deque(maxlen=max_memory_size)

    def add(self, experience):
        """Append one experience, e.g. a tuple (s, a, r, ns, d), to the buffer."""
        self.replay_memory.append(experience)

    def sample(self, batch_size):
        """Return a list of experiences drawn uniformly at random.

        When the buffer holds more than ``batch_size`` items, ``batch_size``
        items are drawn with replacement (the same experience may appear more
        than once). Otherwise every stored experience is returned, in
        insertion order.

        The result is always a new list, so callers can mutate it without
        touching the buffer.
        """
        if len(self.replay_memory) > batch_size:
            return random.choices(self.replay_memory, k=batch_size)
        # Copy instead of handing out the internal deque: the caller must not
        # be able to alter the buffer through the returned batch.
        return list(self.replay_memory)

    
    
    
class PrioritizedReplayMemory:
    """Fixed-capacity replay buffer that samples experiences by priority.

    Each stored experience has a matching entry in ``priorities``. The
    sampling probability of experience ``i`` is
    ``priorities[i] ** alpha / sum(priorities ** alpha)``.
    """

    def __init__(self, max_memory_size):
        """Create an empty buffer holding at most ``max_memory_size`` items."""
        # Both deques share the same maxlen and are appended to together, so
        # they evict their oldest entries in lockstep.
        self.replay_memory = deque(maxlen=max_memory_size)
        self.priorities = deque(maxlen=max_memory_size)

        # Added to every explicit priority so no experience has zero
        # probability of being sampled.
        self.epsilon = 1e-8
        # Priority exponent: 0 gives uniform random sampling.
        self.alpha = 0.5

    def get_probabilities(self):
        """Return the sampling probability of each stored priority.

        Returns:
            A NumPy array with one entry per element of ``self.priorities``.
            It is empty when no priorities are stored, and uniform when the
            scaled priorities sum to zero or less.
        """
        scaled_priorities = np.asarray(self.priorities, dtype=float) ** self.alpha
        total = scaled_priorities.sum()
        if total <= 0:
            # Avoid dividing by zero; fall back to a uniform distribution.
            count = scaled_priorities.size
            return np.full(count, 1.0 / count) if count else scaled_priorities
        return scaled_priorities / total

    def add(self, experience, priority=None):
        """Append one experience, e.g. a tuple (s, a, r, ns, d), and its priority.

        Args:
            experience: The transition to store.
            priority: Typically the TD error of the transition. Its absolute
                value plus ``self.epsilon`` is stored. When omitted, the
                experience gets the largest priority currently stored (1.0
                for an empty buffer) so that new experiences are likely to
                be replayed.
        """
        if priority is None:
            stored_priority = max(self.priorities, default=1.0)
        else:
            stored_priority = abs(priority) + self.epsilon
        self.replay_memory.append(experience)
        self.priorities.append(stored_priority)

    def sample(self, batch_size):
        """Return a list of experiences drawn according to their priorities.

        When the buffer holds more than ``batch_size`` items, ``batch_size``
        items are drawn with replacement, weighted by ``get_probabilities()``.
        Otherwise every stored experience is returned, in insertion order.

        The result is always a new list, so callers can mutate it without
        touching the buffer.
        """
        if len(self.replay_memory) <= batch_size:
            return list(self.replay_memory)
        if len(self.priorities) != len(self.replay_memory):
            # The deques were modified from outside and no longer line up;
            # weights cannot be matched to experiences, so sample uniformly.
            return random.choices(self.replay_memory, k=batch_size)
        weights = self.get_probabilities().tolist()
        return random.choices(self.replay_memory, weights=weights, k=batch_size)

In [6]:
def prioritized_experience_replay(Q_predicted, Q_target, priority_list,
                                  epsilon=1e-8, alpha=0.5):
    """Compute the priority and sampling probability of a transition.

    The priority is ``p_i = |Q_predicted - Q_target| + epsilon`` and the
    sampling probability is ``p_i ** alpha / sum(priority_list ** alpha)``.

    Args:
        Q_predicted: Predicted Q-value(s); a scalar or array-like.
        Q_target: Target Q-value(s), broadcastable against ``Q_predicted``.
        priority_list: Priorities of the stored transitions, used for the
            normalising sum.
        epsilon: Small constant added so that a zero error still yields a
            non-zero priority.
        alpha: Priority exponent; 0 gives uniform random sampling.

    Returns:
        A tuple ``(p_i, prob_i)`` of NumPy scalars or arrays.

    Raises:
        ValueError: If the scaled priorities do not sum to a positive value
            (for example when ``priority_list`` is empty).
    """
    # Use the magnitude of the TD error: a signed error can be negative, and
    # a negative base raised to a fractional alpha is not a real number.
    error = np.abs(np.asarray(Q_predicted, dtype=float)
                   - np.asarray(Q_target, dtype=float))
    p_i = error + epsilon

    normaliser = np.sum(np.asarray(priority_list, dtype=float) ** alpha)
    if not normaliser > 0:
        raise ValueError("priority_list must have a positive total priority")

    prob_i = p_i ** alpha / normaliser
    return p_i, prob_i