In [96]:
import torch

# Tiny synthetic batch of transitions used by the benchmark cells below.
torch.manual_seed(0)  # reproducible batches and displayed buffer contents on Restart & Run All

test_size = 2
obs_dim = 3
act_dim = 1
obs1_buf = torch.randn([test_size, obs_dim], dtype=torch.float32)
obs2_buf = torch.randn([test_size, obs_dim], dtype=torch.float32)
acts_buf = torch.randn([test_size, act_dim], dtype=torch.float32)
rews_buf = torch.randn([test_size, 1], dtype=torch.float32)
# done is an episode-termination flag, so draw 0/1 values instead of Gaussian noise.
done_buf = torch.randint(0, 2, (test_size, 1)).to(torch.float32)


In [97]:
class ReplayBuffer:
    """
    A simple FIFO experience replay buffer that stores one transition per ``store`` call
    (modified from https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py).

    Storage is preallocated as float32 tensors with ``size`` rows. Once the buffer is
    full, new transitions overwrite the oldest ones in ring-buffer order.
    """

    def __init__(self, obs_dim, act_dim, size):
        # One preallocated row per transition; rewards and done flags are (size, 1) columns.
        self.obs1_buf = torch.zeros([size, obs_dim], dtype=torch.float32)
        self.obs2_buf = torch.zeros([size, obs_dim], dtype=torch.float32)
        self.acts_buf = torch.zeros([size, act_dim], dtype=torch.float32)
        self.rews_buf = torch.zeros([size, 1], dtype=torch.float32)
        self.done_buf = torch.zeros([size, 1], dtype=torch.float32)
        # ptr: next row to write; size: number of valid rows; max_size: capacity.
        self.ptr, self.size, self.max_size = 0, 0, size

    def store(self, obs, next_obs, act, rew, done):
        """Write a single transition at row ``ptr`` and advance the ring pointer.

        Each argument must be assignable to one row of its buffer (e.g. ``obs`` of
        shape ``(obs_dim,)``; ``rew``/``done`` a scalar or shape ``(1,)``).
        """
        self.obs1_buf[self.ptr] = obs
        self.obs2_buf[self.ptr] = next_obs
        self.acts_buf[self.ptr] = act
        self.rews_buf[self.ptr] = rew
        self.done_buf[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self, batch_size=32):
        """Sample ``batch_size`` stored transitions uniformly at random, with replacement.

        Returns:
            Tuple ``(obs1, obs2, acts, rews, done)`` of tensors whose leading
            dimension is ``batch_size``.

        Raises:
            ValueError: if the buffer holds no transitions yet.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        # Sample with torch so the buffer depends only on torch (numpy is not imported here).
        idxs = torch.randint(0, self.size, (batch_size,))
        return (self.obs1_buf[idxs], self.obs2_buf[idxs], self.acts_buf[idxs], self.rews_buf[idxs], self.done_buf[idxs])


def update_mean(data, cur_mean, cur_steps):
    """Fold a batch of samples into a running mean.

    Args:
        data: tensor whose first dimension indexes samples; the mean is taken over dim 0.
        cur_mean: running mean of the ``cur_steps`` samples seen so far.
        cur_steps: number of samples already summarized by ``cur_mean``.

    Returns:
        The mean over all ``cur_steps + data.shape[0]`` samples. An empty batch
        returns ``cur_mean`` unchanged (the plain formula would yield NaN).
    """
    # NOTE(review): this function is defined in two cells of the notebook; the later one shadows this.
    new_steps = data.shape[0]
    if new_steps == 0:
        return cur_mean
    return (torch.mean(data, 0) * new_steps + cur_mean * cur_steps) / (cur_steps + new_steps)


def update_var(data, cur_var, cur_steps, cur_mean=None):
    """Fold a batch of samples into a running (population) variance.

    Args:
        data: tensor whose first dimension indexes samples; statistics are taken over dim 0.
        cur_var: population variance of the ``cur_steps`` samples seen so far.
        cur_steps: number of samples already summarized by ``cur_var``.
        cur_mean: optional running mean of those same samples, taken BEFORE this
            batch is folded in. When given, the between-group term is added and the
            result is the exact variance of all samples combined. When omitted,
            only the pooled within-group variance is returned (legacy approximation).

    Returns:
        The updated variance. An empty batch returns ``cur_var`` unchanged.
    """
    # NOTE(review): this function is defined in two cells of the notebook; the later one shadows this.
    new_steps = data.shape[0]
    if new_steps == 0:
        return cur_var
    total = cur_steps + new_steps
    # Population (biased) variance matches the per-sample weighting below and stays
    # finite for a single-row batch, where the unbiased estimator divides by zero.
    batch_var = torch.var(data, 0, unbiased=False)
    combined = (batch_var * new_steps + cur_var * cur_steps) / total
    if cur_mean is not None:
        # Between-group term of the parallel variance combination (Chan et al.).
        delta = torch.mean(data, 0) - cur_mean
        combined = combined + delta * delta * (cur_steps * new_steps) / (total * total)
    return combined



In [98]:
class ReplayBuffer2:
    """
    A FIFO experience replay buffer that stores whole batches of transitions at once
    (modified from https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py).

    Uses the same storage layout as ``ReplayBuffer``. Its ``store`` takes tensors whose
    first dimension indexes transitions. It writes each buffer with at most two slice
    assignments instead of making one Python call per transition.
    """

    def __init__(self, obs_dim, act_dim, max_size):
        # Preallocated float32 storage, one row per transition.
        self.obs1_buf = torch.zeros([max_size, obs_dim], dtype=torch.float32)
        self.obs2_buf = torch.zeros([max_size, obs_dim], dtype=torch.float32)
        self.acts_buf = torch.zeros([max_size, act_dim], dtype=torch.float32)
        self.rews_buf = torch.zeros([max_size, 1], dtype=torch.float32)
        self.done_buf = torch.zeros([max_size, 1], dtype=torch.float32)
        # ptr: next row to write, kept in [0, max_size); size: number of valid rows.
        self.ptr, self.size, self.max_size = 0, 0, max_size

    def store(self, obs, next_obs, act, rew, done):
        """Append a batch of transitions, wrapping around the end of the storage.

        Every argument has the batch size as its first dimension, and each row must
        match one row of the corresponding buffer (e.g. ``rew`` of shape ``(n, 1)``).
        If the batch is larger than ``max_size``, only its last ``max_size`` rows are
        kept. That is exactly what storing the rows one at a time would leave behind.
        """
        batch = (obs, next_obs, act, rew, done)
        insert_size = obs.shape[0]
        if insert_size == 0:
            return

        if insert_size > self.max_size:
            # The oldest rows would be overwritten within this same call, so drop them
            # up front. Advancing ptr past them keeps the final layout identical to
            # storing every row individually.
            skip = insert_size - self.max_size
            self.ptr = (self.ptr + skip) % self.max_size
            batch = tuple(x[skip:] for x in batch)
            insert_size = self.max_size

        # Split the write into the part that fits before the end of storage (head) and
        # the part that wraps around to the beginning (tail).
        head = min(insert_size, self.max_size - self.ptr)
        tail = insert_size - head
        storage = (self.obs1_buf, self.obs2_buf, self.acts_buf, self.rews_buf, self.done_buf)
        for buf, data in zip(storage, batch):
            buf[self.ptr:self.ptr + head] = data[:head]
            if tail:
                buf[:tail] = data[head:]

        self.ptr = (self.ptr + insert_size) % self.max_size
        self.size = min(self.size + insert_size, self.max_size)

    def sample_batch(self, batch_size=32):
        """Sample ``batch_size`` stored transitions uniformly at random, with replacement.

        Returns:
            Tuple ``(obs1, obs2, acts, rews, done)`` of tensors whose leading
            dimension is ``batch_size``.

        Raises:
            ValueError: if the buffer holds no transitions yet.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        # Sample with torch so the buffer depends only on torch (numpy is not imported here).
        idxs = torch.randint(0, self.size, (batch_size,))
        return (self.obs1_buf[idxs], self.obs2_buf[idxs], self.acts_buf[idxs], self.rews_buf[idxs], self.done_buf[idxs])


def update_mean(data, cur_mean, cur_steps):
    """Fold a batch of samples into a running mean.

    Args:
        data: tensor whose first dimension indexes samples; the mean is taken over dim 0.
        cur_mean: running mean of the ``cur_steps`` samples seen so far.
        cur_steps: number of samples already summarized by ``cur_mean``.

    Returns:
        The mean over all ``cur_steps + data.shape[0]`` samples. An empty batch
        returns ``cur_mean`` unchanged (the plain formula would yield NaN).
    """
    # NOTE(review): this function is defined in two cells of the notebook; this later one wins.
    new_steps = data.shape[0]
    if new_steps == 0:
        return cur_mean
    return (torch.mean(data, 0) * new_steps + cur_mean * cur_steps) / (cur_steps + new_steps)


def update_var(data, cur_var, cur_steps, cur_mean=None):
    """Fold a batch of samples into a running (population) variance.

    Args:
        data: tensor whose first dimension indexes samples; statistics are taken over dim 0.
        cur_var: population variance of the ``cur_steps`` samples seen so far.
        cur_steps: number of samples already summarized by ``cur_var``.
        cur_mean: optional running mean of those same samples, taken BEFORE this
            batch is folded in. When given, the between-group term is added and the
            result is the exact variance of all samples combined. When omitted,
            only the pooled within-group variance is returned (legacy approximation).

    Returns:
        The updated variance. An empty batch returns ``cur_var`` unchanged.
    """
    # NOTE(review): this function is defined in two cells of the notebook; this later one wins.
    new_steps = data.shape[0]
    if new_steps == 0:
        return cur_var
    total = cur_steps + new_steps
    # Population (biased) variance matches the per-sample weighting below and stays
    # finite for a single-row batch, where the unbiased estimator divides by zero.
    batch_var = torch.var(data, 0, unbiased=False)
    combined = (batch_var * new_steps + cur_var * cur_steps) / total
    if cur_mean is not None:
        # Between-group term of the parallel variance combination (Chan et al.).
        delta = torch.mean(data, 0) - cur_mean
        combined = combined + delta * delta * (cur_steps * new_steps) / (total * total)
    return combined

# No large buffer is allocated here: the buffers used by the benchmarks are created
# in the next cell (capacity 10), which would immediately replace a buffer made here.

In [99]:
# Small capacity: repeated stores of test_size rows wrap around the storage quickly.
replay_buf = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=10)
replay_buf2 = ReplayBuffer2(obs_dim=obs_dim, act_dim=act_dim, max_size=10)

In [94]:
%%timeit

for i in range(1):
    for obs1, obs2, acts, rews, done in zip(obs1_buf, obs2_buf, acts_buf, rews_buf, done_buf):
        replay_buf.store(obs1, obs2, acts, rews, done)

49.3 ms ± 57.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [95]:
%%timeit
for i in range(1):
    replay_buf2.store(obs1_buf, obs2_buf, acts_buf, rews_buf, done_buf)

31.3 µs ± 340 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [28]:
%%timeit
# Reference timing: only the five slice assignments, bypassing store().
# NOTE: replay_buf.ptr / replay_buf.size are NOT updated, so this silently
# overwrites the first test_size rows of replay_buf's storage.
replay_buf.obs1_buf[:test_size] = obs1_buf
replay_buf.obs2_buf[:test_size] = obs2_buf
replay_buf.acts_buf[:test_size] = acts_buf
replay_buf.rews_buf[:test_size] = rews_buf
replay_buf.done_buf[:test_size] = done_buf

16.3 µs ± 119 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [140]:
# Draw a fresh batch of transitions and push it through the batched store().
obs1_buf = torch.randn([test_size, obs_dim], dtype=torch.float32)
obs2_buf = torch.randn([test_size, obs_dim], dtype=torch.float32)
acts_buf = torch.randn([test_size, act_dim], dtype=torch.float32)
rews_buf = torch.randn([test_size, 1], dtype=torch.float32)
# done is an episode-termination flag, so draw 0/1 values instead of Gaussian noise.
done_buf = torch.randint(0, 2, (test_size, 1)).to(torch.float32)

replay_buf2.store(obs1_buf, obs2_buf, acts_buf, rews_buf, done_buf)

In [141]:
replay_buf2.obs1_buf  # all max_size rows of stored observations; rows never written are still zero

tensor([[-1.1975,  1.6623,  0.1507],
        [ 1.5804, -1.8119,  1.0720],
        [ 0.9737,  1.9037, -0.2180],
        [ 0.3617,  1.4736, -0.0153],
        [-0.5894, -1.0121, -1.8025],
        [ 0.6768,  1.0882, -1.2854],
        [-0.3231, -0.2042,  0.3877],
        [ 0.4944, -2.0150, -0.2871],
        [-0.0999, -0.4783, -0.0402],
        [ 0.2463, -1.0972, -0.6492]])

In [142]:
replay_buf2.acts_buf  # all max_size rows of stored actions; rows never written are still zero

tensor([[ 0.5947],
        [-1.5061],
        [ 0.1282],
        [ 0.6314],
        [ 0.7548],
        [ 0.9051],
        [-0.6399],
        [-0.1565],
        [ 1.6891],
        [-0.0991]])

In [143]:
replay_buf2.size  # number of valid rows; equals max_size once the buffer has wrapped

10