In [1]:
import time
from collections import deque

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist

import gym

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# end-to-end on machines without a usable CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    A transition is an indexable sequence laid out as
    ``(state, next_state, action, reward, done)``; the first three items must
    be NumPy arrays. When the buffer is full, appending evicts the oldest
    transition.
    """

    def __init__(self,max_len=1e6):
        """Create an empty buffer.

        Args:
            max_len: Upper bound on the number of stored transitions. May be
                a float (the default is ``1e6``); the buffer never holds more
                than ``max_len`` items, so a non-integral value effectively
                rounds down.

        Raises:
            ValueError: If ``max_len`` is not at least 1 (this also rejects
                NaN), since such a buffer could never hold a transition.
        """
        if not max_len >= 1:
            raise ValueError(f"max_len must be at least 1, got {max_len!r}")
        self.len = 0
        self.max_len = max_len
        self.data = deque()

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.len

    def append(self,T):
        """Store transition ``T``, evicting the oldest one if the buffer is full.

        Args:
            T: Sequence ``(state, next_state, action, reward, done)`` whose
                first three items are ``np.ndarray`` instances.

        Raises:
            AssertionError: If any of the first three items is not an array.
        """
        for i in range(3):
            assert isinstance(T[i], np.ndarray), (
                f"transition item {i} must be a numpy array, got {type(T[i]).__name__}"
            )

        # Compare with ">" rather than "==": an exact-equality test is never
        # satisfied by a non-integral max_len, which would let the buffer
        # grow without bound.
        if self.len + 1 > self.max_len:
            self.data.popleft()
        else:
            self.len += 1
        self.data.append(T)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Indices come from NumPy's global RNG (``np.random.randint``).

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            dict with float64 arrays under the keys ``'state'``,
            ``'next_state'``, ``'action'`` (each of shape
            ``(batch_size, *item_shape)``), ``'reward'`` and ``'done'``
            (each of shape ``(batch_size, 1)``).

        Raises:
            IndexError: If the buffer is empty.
        """
        if self.len == 0:
            raise IndexError("cannot sample from an empty ReplayBuffer")

        # The oldest transition serves as the shape template for the batch.
        template = self.data[0]
        batch = {
            'state': np.zeros((batch_size,*template[0].shape)),
            'next_state': np.zeros((batch_size,*template[1].shape)),
            'action': np.zeros((batch_size,*template[2].shape)),
            'reward': np.zeros((batch_size,1)),
            'done': np.zeros((batch_size,1)),
        }

        idxs = np.random.randint(0,self.len,size = batch_size)
        for b,i in enumerate(idxs):
            # deque indexing is O(n) away from the ends, so look each
            # transition up once instead of once per field.
            transition = self.data[i]
            # Assigning into the preallocated rows already copies the data.
            batch['state'][b] = transition[0]
            batch['next_state'][b] = transition[1]
            batch['action'][b] = transition[2]
            batch['reward'][b][0] = transition[3]
            batch['done'][b][0] = transition[4]

        return batch

In [4]:
# Deliberately tiny capacity so that eviction of old transitions is easy to observe.
buffer = ReplayBuffer(max_len = 5)

In [5]:
# Push ten synthetic transitions into the buffer; every field is derived
# from the loop index so each transition is easy to recognise later.
for i in range(10):
    T = [
        np.array([i, i]),                # state
        np.array([-i, -i]),              # next_state
        np.array([2*i, 2*i, 2*i]),       # action
        i,                               # reward
        i % 2,                           # done flag
    ]
    buffer.append(T)

In [6]:
# Inspect the stored transitions: ten were appended to a buffer of capacity 5,
# so only the most recent five (i = 5..9) should remain.
buffer.data

deque([[array([5, 5]), array([-5, -5]), array([10, 10, 10]), 5, 1],
       [array([6, 6]), array([-6, -6]), array([12, 12, 12]), 6, 0],
       [array([7, 7]), array([-7, -7]), array([14, 14, 14]), 7, 1],
       [array([8, 8]), array([-8, -8]), array([16, 16, 16]), 8, 0],
       [array([9, 9]), array([-9, -9]), array([18, 18, 18]), 9, 1]])

In [7]:
# Draw a batch of 3 transitions. Indices come from NumPy's global RNG, which is
# not seeded in this notebook, so the batch differs between runs.
sample = buffer.sample(3)

In [8]:
# Sampled states: one row per drawn transition.
sample['state']

array([[6., 6.],
       [7., 7.],
       [5., 5.]])

In [9]:
# Sampled next states; for the synthetic data these are the negated states.
sample['next_state']

array([[-6., -6.],
       [-7., -7.],
       [-5., -5.]])

In [10]:
# Sampled actions; for the synthetic data each entry is twice the state value.
sample['action']

array([[12., 12., 12.],
       [14., 14., 14.],
       [10., 10., 10.]])

In [11]:
# Sampled rewards as a (batch_size, 1) column.
sample['reward']

array([[6.],
       [7.],
       [5.]])

In [12]:
# Sampled done flags as a (batch_size, 1) column of floats (0. or 1.).
sample['done']

array([[0.],
       [1.],
       [1.]])