# Priority Replay

[Prioritized Experience Replay](https://arxiv.org/pdf/1511.05952.pdf) Tom Schaul, John Quan, Ioannis Antonoglou and David Silver


In [5]:
import numpy as np

## ProbabilityBag

We need to prioritize memory replay with the following requirements:
 * Memories are stored with a relative priority
 * There is an upper limit to the number of memories stored, and when the memory is full, things of lower priority fall off the bottom
 * We can sample the memories, and the probability of any given item being selected is proportional to its priority
 * Things with lower priority still have a chance of being selected
 * Memories are pulled out of the memory pile, then put back in with their updated priorities


The interface will be something like this:

In [None]:
class ProbabilityBag:
    def __init__(self, max_size):
        pass
    
    def pop_batch(self, n):
        """Remove a randomly selected batch of n items from the bag.  The probability
        of an item being selected is proportional to its priority.
        """
        
    def push_batch(self, items):
        """Insert a group of items.  The priority of each item is the first member of the tuple.
        Example:
            item = (priority, state, action, reward)
            items = [item]
            probability_bag.push_batch(items)
        """
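One way to fill in this interface is sketched below. It is a minimal, hypothetical implementation, not the paper's method. It assumes every item is a tuple whose first member is a strictly positive numeric priority, so low-priority items always keep a nonzero chance of selection. Items live in a plain Python list, the lowest-priority items are evicted once `max_size` is exceeded, and sampling uses `numpy.random.Generator.choice`. The optional `rng` parameter and the `__len__` method are additions for testing.

```python
import numpy as np


class ProbabilityBag:
    """Hypothetical sketch of a priority-weighted bag.

    Assumes each item is a tuple whose first member is a priority > 0.
    """

    def __init__(self, max_size, rng=None):
        self.max_size = max_size
        self.items = []  # tuples of the form (priority, ...)
        self.rng = rng if rng is not None else np.random.default_rng()

    def __len__(self):
        return len(self.items)

    def push_batch(self, items):
        """Insert items, then drop the lowest-priority ones if over max_size."""
        self.items.extend(items)
        if len(self.items) > self.max_size:
            # Keep the max_size highest-priority items; the rest fall off the bottom.
            self.items.sort(key=lambda item: item[0], reverse=True)
            del self.items[self.max_size:]

    def pop_batch(self, n):
        """Remove and return up to n items, sampled without replacement
        with probability weighted by priority."""
        n = min(n, len(self.items))
        if n == 0:
            return []
        # Assumes all priorities are positive, so p is well defined and
        # every item has a nonzero chance of being picked.
        priorities = np.array([item[0] for item in self.items], dtype=float)
        p = priorities / priorities.sum()
        idx = self.rng.choice(len(self.items), size=n, replace=False, p=p)
        chosen = set(idx.tolist())
        batch = [self.items[i] for i in idx]
        self.items = [item for i, item in enumerate(self.items) if i not in chosen]
        return batch
```

Recomputing the probabilities on every `pop_batch` costs time linear in the bag size, which is fine for small memories; for large ones, the paper describes a sum-tree for its proportional variant so that sampling does not require a full pass.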

## Sampling with probability
NumPy has a function, `np.random.choice`, for selecting items with given probabilities. First, turn the weights into probabilities that sum to 1:

In [3]:
a = [1,2,3,4,5,6,7,8]
weights = [3, 3, 3, 2, 2, 1, 1, 1]
sum_w = sum(weights)
p_a = [w / sum_w for w in weights]

In [4]:
sum_w, p_a

(16, [0.1875, 0.1875, 0.1875, 0.125, 0.125, 0.0625, 0.0625, 0.0625])
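The normalization above can also be done in one vectorized step on a NumPy array. The sketch below uses the `np.random.default_rng` generator API (available since NumPy 1.17) rather than the legacy `np.random.choice`, and checks the proportionality requirement empirically: with replacement, each draw picks item `i` with probability `p_a[i]`, so frequencies over many draws should approach `p_a`.

```python
import numpy as np

weights = np.array([3, 3, 3, 2, 2, 1, 1, 1], dtype=float)
p_a = weights / weights.sum()  # broadcasting divides every weight by the total

# Draw positions 0..7 many times with replacement and count how often each appears.
rng = np.random.default_rng(42)
draws = rng.choice(len(weights), size=100_000, p=p_a)
freq = np.bincount(draws, minlength=len(weights)) / draws.size

print(p_a)
print(np.round(freq, 3))  # should be close to p_a
```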

In [10]:
np.random.choice(a, 3, p = p_a)

array([5, 2, 5])

Notice that we got 5 twice. The `replace` argument controls this: it defaults to `True`, which puts each selected item back into the choice pool. Set it to `False` to keep an item from being selected again once it has been chosen.

In [15]:
[np.random.choice(a, 3, p=p_a, replace=False) for _ in range(5)]

[array([1, 2, 8]),
 array([2, 3, 4]),
 array([4, 2, 5]),
 array([4, 1, 6]),
 array([2, 1, 4])]

OK, no batch contains a duplicate, although the same item can still show up in different batches.
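For `pop_batch`, the selected items also have to leave the bag. A minimal sketch of one way to do that, assuming the values and their weights are kept in parallel NumPy arrays: sample positions instead of values, then drop those positions from both arrays with `np.delete`. Sampling positions also avoids ambiguity when two memories happen to have equal values. One caveat: without replacement, each later pick is made from the renormalized remaining pool, so an item's overall chance of appearing somewhere in a batch is not exactly proportional to its priority.

```python
import numpy as np

a = np.array([1, 2, 3, 4, 5, 6, 7, 8])
weights = np.array([3, 3, 3, 2, 2, 1, 1, 1], dtype=float)
rng = np.random.default_rng(0)

# Sample positions rather than values, so the chosen items can be removed.
idx = rng.choice(len(a), size=3, replace=False, p=weights / weights.sum())
batch = a[idx]

# Remove the chosen positions, keeping weights aligned with the remaining items.
remaining = np.delete(a, idx)
remaining_w = np.delete(weights, idx)
```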

## Accessing the priority

In [16]:
items = np.array([[1,2,3],[4,5,6],[7,8,9]])

In [20]:
items[:,0]

array([1, 4, 7])
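Slicing `items[:, 0]` works here because every entry is an integer. Real memories mix types, as in the `(priority, state, action, reward)` tuples from the `push_batch` docstring, and NumPy then promotes the whole array to a common dtype. In the hypothetical example below the states are strings, so the promoted dtype is a string type and the priority column is no longer numeric. Keeping priorities in a separate float array avoids that and makes it easy to find the lowest-priority entries with `np.argsort`.

```python
import numpy as np

# Hypothetical memories in the (priority, state, action, reward) layout;
# the states are plain strings for illustration.
memories = [
    (0.9, "s0", 1, 0.0),
    (0.1, "s1", 0, 1.0),
    (0.5, "s2", 1, -1.0),
]

# np.array(memories) would promote everything to strings, so its [:, 0]
# column would not be numeric. Collect the priorities as floats instead.
priorities = np.array([m[0] for m in memories], dtype=float)

# Indices from lowest to highest priority; the first entries are the ones
# that would fall off the bottom when the memory is full.
order = np.argsort(priorities)
lowest = memories[order[0]]
```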