In [1]:
from timeit import timeit
from collections import deque
import random
import numpy as np


In [2]:
class Memory:
    """Uniform (unprioritized) replay buffer backed by a bounded ``deque``.

    Each call to ``add`` stores one item, expected to be a tuple of fields;
    ``replay`` returns one stacked float64 array per field.
    """

    def __init__(self, maxlen=1000000):
        # A bounded deque silently discards its oldest entry once it is full.
        self.buffer = deque(maxlen=maxlen)

    def add(self, data):
        """Append one item (a tuple of fields) to the buffer."""
        self.buffer.append(data)

    def _atleast2d(self, sample):
        """Convert ``sample`` to float64, promoting a (size,) vector to a (size, 1) column.

        Inputs that are already 2-D or higher are only converted, not reshaped.
        """
        arr = np.asarray(sample, dtype=np.float64)
        if arr.ndim >= 2:
            return arr
        return np.expand_dims(arr, axis=1)

    def replay(self, size):
        """Draw ``size`` distinct items uniformly at random and stack them field by field.

        ``random.sample`` raises ``ValueError`` if ``size`` exceeds the number of stored items.
        """
        batch = random.sample(self.buffer, size)
        # zip(*batch) transposes the list of item tuples into one tuple per field.
        return [self._atleast2d(field) for field in zip(*batch)]

# Smoke test for the deque-backed buffer: fill it with 10,000 random items,
# show a 3-item batch, then time 10,000 replays of 100 items each.
m = Memory()
for _step in range(10000):
    m.add((np.random.random_sample(2),
           np.random.random_sample((2, 2)),
           np.random.randint(10)))
print(m.replay(3))
print(timeit(lambda: m.replay(100), number=10000))

[array([[0.37235918, 0.06632025],
       [0.745442  , 0.16088088],
       [0.91000152, 0.61211501]]), array([[[0.06057086, 0.11134378],
        [0.58673288, 0.41892727]],

       [[0.59526653, 0.68742928],
        [0.01005504, 0.17475121]],

       [[0.22998924, 0.8636551 ],
        [0.81916981, 0.63225377]]]), array([[7.],
       [8.],
       [9.]])]
2.0559278


In [3]:
class CycleMemory:
    """Fixed-capacity ring buffer that preallocates one float64 array per field.

    ``shapes`` gives the per-item shape of each field as a tuple (e.g. ``(2,)``);
    field ``k`` of every stored item lives in ``buffers[k]``.
    """

    def __init__(self, shapes, maxlen=1000000):
        self.buffers = [np.zeros((maxlen,) + shape, dtype=np.float64) for shape in shapes]
        self.begin = 0        # slot the next add() writes to
        self.size = 0         # number of filled slots, saturates at maxlen
        self.maxlen = maxlen

    def add(self, data):
        """Write item ``data`` (one value per field) into the current slot, overwriting once full."""
        slot = self.begin
        for k in range(len(self.buffers)):
            self.buffers[k][slot] = data[k]
        self.begin = (slot + 1) % self.maxlen
        self.size = min(self.maxlen, self.size + 1)

    def replay(self, size):
        """Pick ``size`` distinct filled slots uniformly at random; return one array per field."""
        chosen = np.random.choice(np.arange(self.size), size, replace=False)
        return [field[chosen] for field in self.buffers]
    
# Smoke test for the preallocated ring buffer: fill it with 10,000 random items,
# show a 3-item batch, then time 10,000 replays of 100 items each.
m = CycleMemory(((2,), (2, 2), (1,)))
for _step in range(10000):
    m.add((np.random.random_sample(2),
           np.random.random_sample((2, 2)),
           np.random.randint(10)))
print(m.replay(3))
print(timeit(lambda: m.replay(100), number=10000))

[array([[0.61193401, 0.33160072],
       [0.10319404, 0.12240532],
       [0.87117798, 0.16350619]]), array([[[0.36142822, 0.37810365],
        [0.10004135, 0.51572993]],

       [[0.02454501, 0.96184037],
        [0.61052739, 0.01314555]],

       [[0.64731077, 0.43732689],
        [0.17671534, 0.17246948]]]), array([[6.],
       [4.],
       [8.]])]
2.5074254000000002


In [20]:
class SumTree:
    """Binary sum tree over ``capacity`` leaf slots, stored as one flat array.

    Layout: node ``i`` has children ``2*i + 1`` and ``2*i + 2`` (parent ``(i - 1) // 2``).
    The first ``capacity - 1`` entries of ``tree`` are internal nodes holding the sum of
    their children; the last ``capacity`` entries are leaves holding individual
    priorities. ``data[k]`` is the payload of leaf ``k + capacity - 1``; slots are
    written in ring-buffer order, so the oldest entry is overwritten once full.
    """

    def __init__(self, capacity=1000000):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.size = 0    # number of filled slots, saturates at capacity
        self.begin = 0   # next slot to (over)write

    def update(self, i, p):
        """Set the priority of tree node ``i`` (a leaf index) to ``p`` and
        propagate the difference up to the root."""
        change = p - self.tree[i]
        while i != 0:
            self.tree[i] += change
            i = (i - 1) // 2
        self.tree[0] += change

    def add(self, data, p):
        """Store ``data`` with priority ``p`` in the next slot (overwriting the oldest once full)."""
        self.data[self.begin] = data
        self.update(self.begin + self.capacity - 1, p)
        self.begin = (self.begin + 1) % self.capacity
        # Previously never updated; keep it in sync like CycleMemory.size.
        self.size = min(self.capacity, self.size + 1)

    def get(self, p):
        """Find the leaf whose cumulative-priority interval contains ``p``.

        ``p`` is expected to lie in ``[0, total()]``.

        Returns:
            ``(tree_index, priority, data)`` for the selected leaf; ``tree_index`` is the
            node index accepted by ``update``.
        """
        i = 0
        n_nodes = len(self.tree)
        while True:
            left = 2 * i + 1
            if left >= n_nodes:
                break
            right = left + 1
            # Never descend into a child with no mass while its sibling has some:
            # floating-point drift (p slightly above a node's sum) or p == 0 would
            # otherwise land on an unfilled slot with priority 0.
            if self.tree[right] <= 0 or (p <= self.tree[left] and self.tree[left] > 0):
                i = left
            else:
                p -= self.tree[left]
                i = right
        return i, self.tree[i], self.data[i - self.capacity + 1]

    def total(self):
        """Sum of all leaf priorities (the root node)."""
        return self.tree[0]
    
# Fill a 10k-slot tree with random payloads/priorities, check that the root equals
# the sum of all leaves, then draw one sample from across the whole priority mass.
# (Drawing p from [0, 1) would only ever probe the first ~1/5000 of the mass.)
m = SumTree(10000)
for _ in range(10000):
    item = (np.random.random_sample(),)
    m.add(item, np.random.random_sample())
assert np.isclose(m.total(), m.tree[m.capacity - 1:].sum())
p = np.random.uniform(0, m.total())
print(m.get(p))

(16387, 0.9572278784560172, (0.7687477743251352,))


In [21]:
# Tiny hand-checkable tree: capacity 8, priorities 1..8, payload 0 in every slot.
m = SumTree(8)
for priority in range(1, 9):
    m.add(0, priority)

In [22]:
# Sum-tree invariant: every internal node j equals tree[2j+1] + tree[2j+2].
# With priorities 1..8 from the previous cell, the root should be 36.
n_internal = len(m.tree) // 2
assert np.allclose(m.tree[:n_internal], m.tree[1::2] + m.tree[2::2])
m.tree

array([36., 10., 26.,  3.,  7., 11., 15.,  1.,  2.,  3.,  4.,  5.,  6.,
        7.,  8.])

In [58]:
class PrioritizedMemory:
    """Proportional prioritized replay memory built on ``SumTree``.

    Each item is stored with priority ``(|error| + e) ** a``. ``replay`` draws items with
    probability proportional to that stored value and returns importance-sampling weights
    that correct for the non-uniform sampling. The indices returned by ``replay`` are
    ``SumTree`` node indices, which is what ``update`` expects.
    """

    def __init__(self, capacity=1000000):
        self.tree = SumTree(capacity)
        # proportional prioritization: pi = |δi| + e, stored in the tree as pi ** a
        self.e = 0.001      # added to |error| so zero-error items keep a non-zero priority
        self.a = 0.6        # prioritization exponent; a = 0 makes every priority 1 (uniform)
        self.b = 0.4        # importance-sampling exponent, moved toward 1 by every replay()
        self.decay = 0.001  # rate at which replay() moves b toward 1 (despite the name, b grows)

    def _priority(self, error):
        """Map an error value (scalar or array) to the priority stored in the tree."""
        return (np.abs(error) + self.e) ** self.a

    def add(self, data, error):
        """Store ``data`` with a priority derived from ``error``."""
        self.tree.add(data, self._priority(error))

    def update(self, i, error):
        """Reset the priority of tree node(s) ``i`` from new error value(s).

        ``i`` is a tree index as returned by ``replay``: either a single index, or a
        sequence of indices together with a matching sequence of errors (or one scalar
        error applied to all of them).
        """
        if np.ndim(i) == 0:
            self.tree.update(i, self._priority(error))
            return
        priorities = np.broadcast_to(self._priority(error), np.shape(i))
        # Sequential updates: if an index repeats, its last error wins.
        for idx, p in zip(i, priorities):
            self.tree.update(idx, p)

    def replay(self, size):
        """Sample ``size`` items by stratified proportional sampling.

        Returns:
            ``(samples, IS_weights, indices)``: the stored payloads, their normalised
            importance-sampling weights (max weight is 1), and their tree indices.

        Raises:
            ValueError: if ``size`` is not positive or the memory holds no priority mass.
        """
        if size <= 0:
            raise ValueError("size must be a positive integer")
        total = self.tree.total()
        if not total > 0:
            raise ValueError("cannot replay from an empty memory")

        indices = []
        priorities = []
        samples = []
        segment = total / size
        for k in range(size):
            # one draw from each of `size` equal-mass segments of [0, total)
            num = np.random.uniform(segment * k, segment * (k + 1))
            index, priority, sample = self.tree.get(num)
            indices.append(index)
            priorities.append(priority)
            samples.append(sample)

        # importance-sampling (IS) weights: wi = (N * P(i)) ** -β / max(w).
        # N is taken as the batch size here; any constant N cancels in the max-normalisation.
        probs = np.asarray(priorities, dtype=np.float64) / total
        IS_weights = (size * probs) ** -self.b
        IS_weights = IS_weights / np.amax(IS_weights)

        # anneal β toward 1: each call closes a `decay` fraction of the remaining gap
        self.b += self.decay * (1 - self.b)

        return samples, IS_weights, indices
        
# Fill a default-capacity prioritized memory with 10,000 random pairs, each given a
# random error in [0, 1), then draw a 3-item batch: (samples, IS weights, tree indices).
m = PrioritizedMemory()
for _step in range(10000):
    pair = (np.random.random(), np.random.random())
    m.add(pair, np.random.random())
print(m.replay(3))

([(0.051340898331771245, 0.08944069082966277), (0.4859057206552321, 0.3589452722354254), (0.7982979104012583, 0.8808265911141573)], array([0.85149094, 1.        , 0.8967107 ]), [1001693, 1003548, 1007212])


In [64]:
# b is moved toward 1 on every replay() call, so re-running this cell shows it creeping up.
print(f"{m.b}")
m.replay(3)

0.40359101199100356


([(0.5161801708296251, 0.8131248616494949),
  (0.1988314523075394, 0.7893820313656182),
  (0.16880882565731004, 0.8670205680389228)],
 array([1.        , 0.90036991, 0.95638131]),
 [1000636, 1006625, 1006979])