# Replay Buffer
## Usages
### Basic usages as a batch
A buffer usually stores all of its data in a single Batch, organised as a circular queue: once the buffer is full, new data overwrites the oldest entries.

In [5]:
from tianshou.data import Batch, ReplayBuffer
# a buffer is initialised with its maxsize set to 10 (older data will be discarded if more data flow in)
print("==========================================")
buf = ReplayBuffer(size=10)
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))

# add 3 steps of data into ReplayBuffer sequentially
print("==========================================")
for i in range(3):
    buf.add(Batch(obs=i, act=i, rew=i, done=0, terminated=0, truncated=0, obs_next=i + 1, info={}))
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))

# add another 10 steps of data into ReplayBuffer sequentially
print("==========================================")
for i in range(3, 13):
    buf.add(Batch(obs=i, act=i, rew=i, done=0, terminated=0, truncated=0, obs_next=i + 1, info={}))
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))

ReplayBuffer()
maxsize: 10, data length: 0
ReplayBuffer(
    obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),
    obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    terminated: array([False, False, False, False, False, False, False, False, False,
                       False]),
    info: Batch(),
    rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),
    act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    truncated: array([False, False, False, False, False, False, False, False, False,
                      False]),
    done: array([False, False, False, False, False, False, False, False, False,
                 False]),
)
maxsize: 10, data length: 3
ReplayBuffer(
    obs_next: array([11, 12, 13,  4,  5,  6,  7,  8,  9, 10]),
    obs: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    terminated: array([False, False, False, False, False, False, False, False, False,
                       False]),
    info: Batch(),
    rew: array([10., 11., 12.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]),
    act
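
The overwrite pattern in the output above (steps 10 to 12 landing in slots 0 to 2) follows from writing at a position that wraps around modulo the buffer size. Below is a minimal sketch of that idea in plain Python. `RingBuffer` and its attributes are hypothetical names used only for illustration; this is not Tianshou's implementation.

```python
class RingBuffer:
    """Toy fixed-size circular queue (illustrative only, not Tianshou's class)."""

    def __init__(self, size):
        self.maxsize = size
        self.data = [0] * size  # pre-allocated storage, zero-filled like the arrays above
        self._index = 0         # slot that the next item will be written to
        self._size = 0          # number of valid items currently stored

    def add(self, item):
        self.data[self._index] = item
        # wrap around to slot 0 once the end of the storage is reached
        self._index = (self._index + 1) % self.maxsize
        self._size = min(self._size + 1, self.maxsize)

    def __len__(self):
        return self._size


ring = RingBuffer(10)
for i in range(13):
    ring.add(i)
print(ring.data)  # [10, 11, 12, 3, 4, 5, 6, 7, 8, 9]
print(len(ring))  # 10
```

The printed list has the same layout as the `obs` array of the full buffer: the three oldest steps have been replaced, and the length stays capped at the maximum size.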

Just like Batch, ReplayBuffer supports concatenation, splitting, advanced slicing and indexing, etc.

In [6]:
print(buf[-1])
print(buf[-3:])

Batch(
    obs: array(9),
    act: array(9),
    rew: array(9.),
    terminated: array(False),
    truncated: array(False),
    done: array(False),
    obs_next: array(10),
    info: Batch(),
    policy: Batch(),
)
Batch(
    obs: array([7, 8, 9]),
    act: array([7, 8, 9]),
    rew: array([7., 8., 9.]),
    terminated: array([False, False, False]),
    truncated: array([False, False, False]),
    done: array([False, False, False]),
    obs_next: array([ 8,  9, 10]),
    info: Batch(),
    policy: Batch(),
)


A ReplayBuffer can also be serialised (for example with pickle) and saved to local disk, and it still keeps track of the trajectories after being loaded back. This is especially helpful in offline DRL settings.

In [7]:
import pickle
_buf = pickle.loads(pickle.dumps(buf))
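
The cell above only round-trips the buffer through an in-memory byte string. To actually write it to disk, the same pickle calls can target a file. The sketch below assumes nothing beyond the standard library; `save_buffer` and `load_buffer` are hypothetical helper names, and a plain dictionary stands in for the buffer so that the example runs without Tianshou.

```python
import os
import pickle
import tempfile


def save_buffer(buffer, path):
    """Write any picklable object (e.g. a replay buffer) to `path`."""
    with open(path, "wb") as f:
        pickle.dump(buffer, f)


def load_buffer(path):
    """Read back an object previously written by `save_buffer`."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Stand-in object; in the notebook you would pass `buf` instead.
demo = {"obs": [0, 1, 2], "done": [False, False, True]}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "buffer.pkl")
    save_buffer(demo, path)
    restored = load_buffer(path)

print(restored == demo)  # True
```

As with any pickle file, only load files that come from a source you trust.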

### Data sampling
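We keep a replay buffer in DRL for one purpose: to sample data from it for training. `ReplayBuffer.sample()` and `ReplayBuffer.split(..., shuffle=True)` can both fulfill this need. As the output below shows, `sample()` returns the sampled Batch together with the indices it was drawn from.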

In [8]:
buf.sample(5)

(Batch(
     obs: array([5, 8, 3, 3, 3]),
     act: array([5, 8, 3, 3, 3]),
     rew: array([5., 8., 3., 3., 3.]),
     terminated: array([False, False, False, False, False]),
     truncated: array([False, False, False, False, False]),
     done: array([False, False, False, False, False]),
     obs_next: array([6, 9, 4, 4, 4]),
     info: Batch(),
     policy: Batch(),
 ),
 array([5, 8, 3, 3, 3]))
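
The output above contains index 3 three times, which is what one would expect from drawing indices uniformly with replacement. The following sketch mimics that behaviour on a plain list. The function names and the `rng` parameter are assumptions made for this illustration and are not taken from Tianshou's source.

```python
import random


def sample_indices(length, batch_size, rng=random):
    """Draw `batch_size` storage indices uniformly, with replacement."""
    return [rng.randrange(length) for _ in range(batch_size)]


def sample(storage, batch_size, rng=random):
    """Return the sampled items together with the indices they came from."""
    indices = sample_indices(len(storage), batch_size, rng)
    batch = [storage[i] for i in indices]
    return batch, indices


storage = [10, 11, 12, 3, 4, 5, 6, 7, 8, 9]  # same layout as `obs` in the full buffer
batch, indices = sample(storage, 5, random.Random(0))  # seeded for reproducibility
print(batch, indices)
```

Returning the indices alongside the data is convenient because the same indices can later be passed to index-based lookups such as the trajectory helpers described further down.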

### Trajectory tracking
Compared to a Batch, a unique feature of ReplayBuffer is that it can help you track environment trajectories.

First, let us simulate a situation in which we add three trajectories to the buffer, the last of which is not yet finished.

In [14]:
buf = ReplayBuffer(size=10)
# add the first trajectory (length is 3) into ReplayBuffer
print("=========================================")
for i in range(3):
    is_last = i == 2  # the first trajectory ends at step 2
    result = buf.add(Batch(obs=i, act=i, rew=i, done=is_last, terminated=is_last, truncated=is_last, obs_next=i + 1, info={}))
    print(result)
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))

# add the second trajectory (length is 5) into ReplayBuffer
print("=========================================")
for i in range(3, 8):
    is_last = i == 7  # the second trajectory ends at step 7
    result = buf.add(Batch(obs=i, act=i, rew=i, done=is_last, terminated=is_last, truncated=is_last, obs_next=i + 1, info={}))
    print(result)
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))

# add the third trajectory (length is 5, still not finished) into ReplayBuffer
print("=========================================")
for i in range(8, 13):
    result = buf.add(Batch(obs=i, act=i, rew=i, done=False, terminated=False, truncated=False, obs_next=i + 1, info={}))
    print(result)
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))

(array([0]), array([0.]), array([0]), array([0]))
(array([1]), array([0.]), array([0]), array([0]))
(array([2]), array([3.]), array([3]), array([0]))
ReplayBuffer(
    obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),
    obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    terminated: array([False, False,  True, False, False, False, False, False, False,
                       False]),
    info: Batch(),
    rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),
    act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    truncated: array([False, False,  True, False, False, False, False, False, False,
                      False]),
    done: array([False, False,  True, False, False, False, False, False, False,
                 False]),
)
maxsize: 10, data length: 3
(array([3]), array([0.]), array([0]), array([3]))
(array([4]), array([0.]), array([0]), array([3]))
(array([5]), array([0.]), array([0]), array([3]))
(array([6]), array([0.]), array([0]), array([3]))
(array([7]), array([25.]), array([5]), a

#### Episode length and reward tracking
Notice that `ReplayBuffer.add()` returns a tuple of four arrays each time it is called: `(current_index, episode_reward, episode_length, episode_start_index)`. `episode_reward` and `episode_length` are valid only when a trajectory has just finished; for all other steps they are 0, as the output above shows.
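
To make the meaning of the four return values concrete, here is a small sketch that reproduces the bookkeeping with plain Python numbers instead of arrays. `EpisodeTracker` is a hypothetical class written for this illustration under the assumption that the counters are accumulated per step and reset on `done`; it is not Tianshou's code.

```python
class EpisodeTracker:
    """Toy bookkeeping for (index, episode_reward, episode_length, episode_start)."""

    def __init__(self, size):
        self.maxsize = size
        self._index = 0     # slot that the next step will be written to
        self._ep_rew = 0.0  # reward accumulated in the current episode
        self._ep_len = 0    # number of steps in the current episode
        self._ep_start = 0  # slot where the current episode began

    def add(self, rew, done):
        ptr = self._index
        self._index = (self._index + 1) % self.maxsize
        self._ep_rew += rew
        self._ep_len += 1
        if done:
            result = (ptr, self._ep_rew, self._ep_len, self._ep_start)
            # reset the counters; the next episode starts at the next slot
            self._ep_rew, self._ep_len, self._ep_start = 0.0, 0, self._index
            return result
        # reward and length are reported as 0 while the episode is running
        return (ptr, 0.0, 0, self._ep_start)


tracker = EpisodeTracker(10)
results = [tracker.add(rew=i, done=(i in (2, 7))) for i in range(8)]
print(results[2])  # (2, 3.0, 3, 0)
print(results[7])  # (7, 25.0, 5, 3)
```

The two printed tuples agree with the values returned by `buf.add()` at the end of the first and second trajectories.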

#### Episode index management
In the ReplayBuffer above, we can access any data step by indexing.

In [15]:
print(buf)
data = buf[6]
print(data)

ReplayBuffer(
    obs_next: array([11, 12, 13,  4,  5,  6,  7,  8,  9, 10]),
    obs: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    terminated: array([False, False, False, False, False, False, False,  True, False,
                       False]),
    info: Batch(),
    rew: array([10., 11., 12.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]),
    act: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    truncated: array([False, False, False, False, False, False, False,  True, False,
                      False]),
    done: array([False, False, False, False, False, False, False,  True, False,
                 False]),
)
Batch(
    obs: array(6),
    act: array(6),
    rew: array(6.),
    terminated: array(False),
    truncated: array(False),
    done: array(False),
    obs_next: array(7),
    info: Batch(),
    policy: Batch(),
)


Now we know that step "6" is not the start of an episode (the start should be step "3", since steps 3-7 form the second trajectory we added to the ReplayBuffer). We wonder what the earliest index of that episode is.

This may seem easy, but actually it is not. We cannot simply look at the "done" flag, because the third trajectory is not finished yet and has overwritten the first one, so step "3" is surrounded by "False" flags. There are many things to consider, and things can get even nastier if you are using a more advanced ReplayBuffer like VectorReplayBuffer, because there the data is not stored in a simple circular queue.

Luckily, all ReplayBuffer instances help you identify step indexes through a unified API.

In [16]:
# search for the previous index of index "6"
now_index = 6
while True:
    prev_index = buf.prev(now_index)
    print(prev_index)
    if prev_index == now_index:
        break
    now_index = prev_index

5
4
3
3


Using `ReplayBuffer.prev()`, we know that the earliest step of that episode is step "3". Similarly, `ReplayBuffer.next()` helps us identify the last index of an episode regardless of which kind of ReplayBuffer we are using.

In [17]:
# next step of indexes [4, 5, 6, 7, 8, 9] are:
print(buf.next([4, 5, 6, 7, 8, 9]))

[5 6 7 7 9 0]
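
For a single circular queue, the results of `prev()` and `next()` shown above can be reproduced from two pieces of state: the `done` flags and the most recently written slot. The sketch below is a plain-Python illustration under that assumption. The function names and the `last_index` argument are hypothetical, and other buffer types such as VectorReplayBuffer need more elaborate logic.

```python
def prev_index(index, done, last_index):
    """Previous step of `index` in the same episode, or `index` itself at an episode start."""
    candidate = (index - 1) % len(done)
    # The preceding slot belongs to another episode if it is flagged done,
    # or if it is the most recently written slot (its successor holds older data).
    if done[candidate] or candidate == last_index:
        return index
    return candidate


def next_index(index, done, last_index):
    """Next step of `index` in the same episode, or `index` itself at an episode end."""
    if done[index] or index == last_index:
        return index
    return (index + 1) % len(done)


def episode_start(index, done, last_index):
    """Walk backwards with `prev_index` until the index stops changing."""
    while True:
        previous = prev_index(index, done, last_index)
        if previous == index:
            return index
        index = previous


# State of the buffer above: only slot 7 is flagged done, slot 2 was written last.
done = [False] * 10
done[7] = True
last_index = 2

print(episode_start(6, done, last_index))                             # 3
print([next_index(i, done, last_index) for i in [4, 5, 6, 7, 8, 9]])  # [5, 6, 7, 7, 9, 0]
```

Checking against the most recently written slot is what lets the walk stop at step 3 even though no neighbouring "done" flag marks the boundary.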


We can also search for the indexes that are labeled "done: False" but are the last step of a trajectory, i.e. the ends of unfinished episodes.

In [18]:
print(buf.unfinished_index())

[2]
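
For a single circular queue, this result can be understood as a check on the most recently written slot: if that slot is not flagged "done", the episode it belongs to is still running. A minimal sketch of that check follows, using a hypothetical function signature that is not taken from Tianshou.

```python
def unfinished_index(done, last_index, length):
    """Return the last written slot if its episode is still running, else nothing."""
    if length == 0:
        return []  # an empty buffer has no unfinished episode
    return [] if done[last_index] else [last_index]


# State of the buffer above: only slot 7 is flagged done, slot 2 was written last.
done = [False] * 10
done[7] = True

print(unfinished_index(done, last_index=2, length=10))  # [2]
```

A buffer that manages several sub-buffers could report more than one such index, which is presumably why the result is a list rather than a single number.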
