In [None]:
# Remember to install tianshou first
!pip install tianshou==0.4.8
!pip install gym

# Overview
Replay Buffer is a very common module in DRL implementations. In Tianshou, you can think of the Buffer module as a specialized form of Batch that tracks all data trajectories and, on top of basic storage, provides utilities such as sampling.

There are many kinds of Buffer modules in Tianshou; the two most basic ones are ReplayBuffer and VectorReplayBuffer. The latter is specially designed for parallelized environments (introduced in tutorial L3).

# Usages

## Basic usages as a batch
A buffer stores all of its data in a single batch, organized as a circular queue: once the buffer is full, new data overwrites the oldest entries.

In [None]:
from tianshou.data import Batch, ReplayBuffer
# a buffer is initialised with its maxsize set to 10 (older data will be discarded if more data flow in).
print("========================================")
buf = ReplayBuffer(size=10)
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))

# add 3 steps of data into ReplayBuffer sequentially
print("========================================")
for i in range(3):
  buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))

# add another 10 steps of data into ReplayBuffer sequentially
print("========================================")
for i in range(3, 13):
  buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))

ReplayBuffer()
maxsize: 10, data length: 0
ReplayBuffer(
    info: Batch(),
    act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    done: array([False, False, False, False, False, False, False, False, False,
                 False]),
    rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),
    obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),
)
maxsize: 10, data length: 3
ReplayBuffer(
    info: Batch(),
    act: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    obs: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    done: array([False, False, False, False, False, False, False, False, False,
                 False]),
    rew: array([10., 11., 12.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]),
    obs_next: array([11, 12, 13,  4,  5,  6,  7,  8,  9, 10]),
)
maxsize: 10, data length: 10
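The overwrite pattern in the last printout follows from circular-queue indexing: the write position advances by one slot per step and wraps around modulo `maxsize`. The toy sketch below is plain Python. The class name `CircularBuffer` is hypothetical, and this is not Tianshou's implementation. It reproduces the same layout for 13 additions to a buffer of size 10:

```python
class CircularBuffer:
    """Hypothetical toy buffer illustrating circular-queue storage (not Tianshou code)."""

    def __init__(self, maxsize):
        self.maxsize = maxsize
        self.data = [0] * maxsize  # zero-filled slots, like the printed arrays above
        self._index = 0            # next slot to write
        self._size = 0             # number of valid entries

    def add(self, value):
        ptr = self._index
        self.data[ptr] = value
        # advance the write pointer, wrapping around at maxsize
        self._index = (self._index + 1) % self.maxsize
        self._size = min(self._size + 1, self.maxsize)
        return ptr

    def __len__(self):
        return self._size


buf = CircularBuffer(maxsize=10)
for i in range(13):
    buf.add(i)
print(buf.data, len(buf))  # [10, 11, 12, 3, 4, 5, 6, 7, 8, 9] 10
```

Steps 10, 11 and 12 land in slots 0, 1 and 2, which is why the `obs` column above starts with `10, 11, 12`.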


Just like Batch, ReplayBuffer supports concatenation, splitting, advanced slicing and indexing, etc.

In [None]:
print(buf[-1])
print(buf[-3:])
# Try more methods you find useful in Batch yourself.

Batch(
    obs: array(9),
    act: array(9),
    rew: array(9.),
    done: array(False),
    obs_next: array(10),
    info: Batch(),
    policy: Batch(),
)
Batch(
    obs: array([7, 8, 9]),
    act: array([7, 8, 9]),
    rew: array([7., 8., 9.]),
    done: array([False, False, False]),
    obs_next: array([ 8,  9, 10]),
    info: Batch(),
    policy: Batch(),
)


A ReplayBuffer can also be serialized with `pickle` and saved to local disk while still keeping track of its trajectories. This is extremely helpful in offline DRL settings. The cell below round-trips a buffer through `pickle` in memory.

In [None]:
import pickle
_buf = pickle.loads(pickle.dumps(buf))
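To actually write the buffer to disk, the same `pickle` round trip can go through a file. Below is a minimal sketch with hypothetical helper names `save_buffer` and `load_buffer`; a plain dict stands in for the ReplayBuffer so that the snippet runs without Tianshou installed.

```python
import os
import pickle
import tempfile


def save_buffer(buf, path):
    """Hypothetical helper: write any picklable buffer object to `path`."""
    with open(path, "wb") as f:
        pickle.dump(buf, f)


def load_buffer(path):
    """Hypothetical helper: read back an object written by save_buffer."""
    with open(path, "rb") as f:
        return pickle.load(f)


# A plain dict stands in for a ReplayBuffer in this sketch.
stand_in = {"obs": [0, 1, 2], "done": [False, False, True]}
path = os.path.join(tempfile.mkdtemp(), "buffer.pkl")
save_buffer(stand_in, path)
restored = load_buffer(path)
print(restored == stand_in)  # True
```

With Tianshou installed, passing the `buf` from the cells above to `save_buffer` should work the same way, since `pickle.dump` and `pickle.dumps` serialize identically.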

## Understanding reserved keys for buffer
As explained above, ReplayBuffer is specially designed to support the implementation of DRL algorithms. For convenience, Tianshou therefore reserves seven keys in Batch:

*   `obs`
*   `act`
*   `rew`
*   `done`
*   `obs_next`
*   `info`
*   `policy`

Apart from `policy`, which is specific to Tianshou, the meaning of these reserved keys is consistent with their meaning in [OPENAI Gym](https://gym.openai.com/). We recommend that you use only these seven keys when adding batched data to a ReplayBuffer and put anything else under `info`, because some of the keys are tracked by the buffer (e.g. the "done" value is tracked to determine a trajectory's start and end indices, together with its total reward and episode length).

```
buf.add(Batch(......, extra_info=0)) # This is okay but not recommended.
buf.add(Batch(......, info={"extra_info":0})) # Recommended.
```
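If your environment loop produces extra fields, one way to follow this recommendation is to move every non-reserved key under `info` before building the Batch. Below is a minimal sketch on plain dicts; `move_extra_keys_into_info` is a hypothetical helper, not part of Tianshou.

```python
RESERVED_KEYS = {"obs", "act", "rew", "done", "obs_next", "info", "policy"}


def move_extra_keys_into_info(step):
    """Hypothetical helper: return a copy of `step` in which every non-reserved
    key has been moved under step["info"]."""
    packed = {k: v for k, v in step.items() if k in RESERVED_KEYS}
    info = dict(packed.get("info", {}))  # keep any info that was already there
    for key, value in step.items():
        if key not in RESERVED_KEYS:
            info[key] = value
    packed["info"] = info
    return packed


step = {"obs": 0, "act": 1, "rew": 0.5, "done": False, "obs_next": 1, "extra_info": 0}
packed_step = move_extra_keys_into_info(step)
print(packed_step)
# {'obs': 0, 'act': 1, 'rew': 0.5, 'done': False, 'obs_next': 1, 'info': {'extra_info': 0}}
```

The packed dict could then be passed on as `buf.add(Batch(**packed_step))`.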


## Data sampling
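We keep a replay buffer in DRL for one purpose: sampling data from it for training. `ReplayBuffer.sample()` and `ReplayBuffer.split(..., shuffle=True)` can both fulfill this need. As the output below shows, `sample(batch_size)` returns both the sampled batch and the indices the samples were drawn from.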

In [None]:
buf.sample(5)

(Batch(
     obs: array([10, 11,  4,  3,  8]),
     act: array([10, 11,  4,  3,  8]),
     rew: array([10., 11.,  4.,  3.,  8.]),
     done: array([False, False, False, False, False]),
     obs_next: array([11, 12,  5,  4,  9]),
     info: Batch(),
     policy: Batch(),
 ), array([0, 1, 4, 3, 8]))
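The sampling step can be sketched in plain Python as below. This sketch assumes uniform sampling with replacement; check Tianshou's source if the exact scheme matters to you. `sample` here is a hypothetical function, not the library method.

```python
import random


def sample(storage, size, batch_size, rng=random):
    """Draw `batch_size` indices uniformly, with replacement, from the first `size`
    valid slots; return the sampled items together with their indices."""
    indices = rng.choices(range(size), k=batch_size)
    return [storage[i] for i in indices], indices


obs = [10, 11, 12, 3, 4, 5, 6, 7, 8, 9]  # the obs column of the full buffer above
batch, indices = sample(obs, len(obs), batch_size=5, rng=random.Random(0))
print(batch, indices)
```

Returning the indices alongside the data lets later code navigate from each sampled step to its neighbours in the same trajectory.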

## Trajectory tracking
Compared to Batch, a unique feature of ReplayBuffer is that it can help you track the environment trajectories.

First, let us simulate a situation where we add three trajectories to the buffer, the last of which is not finished yet. Since the buffer holds only 10 steps and we add 13, the oldest steps are overwritten.

In [None]:
buf = ReplayBuffer(size=10)
# Add the first trajectory (length is 3) into ReplayBuffer
print("========================================")
for i in range(3):
  result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==2 else False, obs_next=i + 1, info={}))
  print(result)
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))
# Add the second trajectory (length is 5) into ReplayBuffer
print("========================================")
for i in range(3, 8):
  result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==7 else False, obs_next=i + 1, info={}))
  print(result)
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))
# Add the third trajectory (length is 5, still not finished) into ReplayBuffer
print("========================================")
for i in range(8, 13):
  result = buf.add(Batch(obs=i, act=i, rew=i, done=False, obs_next=i + 1, info={}))
  print(result)
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))

(array([0]), array([0.]), array([0]), array([0]))
(array([1]), array([0.]), array([0]), array([0]))
(array([2]), array([3.]), array([3]), array([0]))
ReplayBuffer(
    info: Batch(),
    act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    done: array([False, False,  True, False, False, False, False, False, False,
                 False]),
    rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),
    obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),
)
maxsize: 10, data length: 3
(array([3]), array([0.]), array([0]), array([3]))
(array([4]), array([0.]), array([0]), array([3]))
(array([5]), array([0.]), array([0]), array([3]))
(array([6]), array([0.]), array([0]), array([3]))
(array([7]), array([25.]), array([5]), array([3]))
ReplayBuffer(
    info: Batch(),
    act: array([0, 1, 2, 3, 4, 5, 6, 7, 0, 0]),
    obs: array([0, 1, 2, 3, 4, 5, 6, 7, 0, 0]),
    done: array([False, False,  True, False, False, False, False,  True, False,
              

### Episode length and reward tracking
Notice that `ReplayBuffer.add()` returns a tuple of four arrays on every call: `(current_index, episode_reward, episode_length, episode_start_index)`. `episode_reward` and `episode_length` are only meaningful when the added step finishes a trajectory; as the output above shows, they are 0 for intermediate steps. This lets you log per-episode statistics directly from the return value of `add()` without recomputing them.
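The bookkeeping behind these return values can be sketched as below. `EpisodeTracker` is a hypothetical class and a plausible reconstruction, not Tianshou's code. It accumulates reward and length at every step. Whenever `done` is set, it resets both and moves the episode start to the next slot. Fed the rewards of the first two trajectories above, it reproduces the `(2, 3, 3, 0)` and `(7, 25, 5, 3)` results.

```python
class EpisodeTracker:
    """Hypothetical sketch of the per-episode bookkeeping done on each add()."""

    def __init__(self, maxsize):
        self.maxsize = maxsize
        self._index = 0      # next slot to write
        self._ep_rew = 0.0   # reward accumulated in the current episode
        self._ep_len = 0     # steps taken in the current episode
        self._ep_idx = 0     # slot where the current episode started

    def add(self, rew, done):
        ptr = self._index
        self._index = (self._index + 1) % self.maxsize
        self._ep_rew += rew
        self._ep_len += 1
        if done:
            result = (ptr, self._ep_rew, self._ep_len, self._ep_idx)
            # the next episode starts at the slot after this one
            self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, self._index
            return result
        # unfinished episode: reward and length are reported as 0
        return (ptr, 0.0, 0, self._ep_idx)


tracker = EpisodeTracker(maxsize=10)
results = [tracker.add(rew=i, done=(i in (2, 7))) for i in range(8)]
print(results[2])  # (2, 3.0, 3, 0)
print(results[7])  # (7, 25.0, 5, 3)
```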



### Episode index management
In the ReplayBuffer above, we can access any step by indexing.


In [None]:
print(buf)
data = buf[6]
print(data)

ReplayBuffer(
    info: Batch(),
    act: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    obs: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    done: array([False, False, False, False, False, False, False,  True, False,
                 False]),
    rew: array([10., 11., 12.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]),
    obs_next: array([11, 12, 13,  4,  5,  6,  7,  8,  9, 10]),
)
Batch(
    obs: array(6),
    act: array(6),
    rew: array(6.),
    done: array(False),
    obs_next: array(7),
    info: Batch(),
    policy: Batch(),
)


Step "6" is not the start of its episode: the second trajectory we added occupies steps 3-7. How do we find the earliest index of that episode?

This may seem easy, but it is not. We cannot simply look at the "done" flags: index 2, just before step "3", originally ended the first trajectory with done=True, but it has since been overwritten by the unfinished third trajectory, so step "3" is now preceded by a "False" flag. There are many such corner cases to consider, and things get even trickier with more advanced buffers like VectorReplayBuffer, where the data is not stored in a single circular queue.

Luckily, all ReplayBuffer instances help you identify step indexes through a unified API.

In [None]:
# Search for the previous index of index "6"
now_index = 6
while True:
  prev_index = buf.prev(now_index)
  print(prev_index)
  if prev_index == now_index:
    break
  else:
    now_index = prev_index

5
4
3
3


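Using `ReplayBuffer.prev()`, we find that the earliest step of that episode is step "3": `prev()` returns its input unchanged once it reaches the start of an episode, which is why the loop stops there. Similarly, `ReplayBuffer.next()` helps us identify the last index of an episode, regardless of which kind of ReplayBuffer we are using.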

In [None]:
# next step of indexes [4,5,6,7,8,9] are:
print(buf.next([4,5,6,7,8,9]))

[5 6 7 7 9 0]
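The navigation rules can be reconstructed from the outputs above. A step's predecessor belongs to an earlier episode if it is marked done, or if it is the most recently written slot, because then the step itself holds the oldest data still stored. The sketch below uses hypothetical `prev_index`/`next_index` functions on plain lists and assumes a single circular buffer that is already full. It is not Tianshou's implementation, but it reproduces both printed results.

```python
def prev_index(i, done, last_index):
    """prev()-like navigation on one full circular buffer."""
    j = (i - 1) % len(done)
    # slot j belongs to an earlier episode if it ended one (done) or if it is the
    # newest slot, in which case slot i holds the oldest data still stored
    if done[j] or j == last_index:
        return i
    return j


def next_index(i, done, last_index):
    """next()-like navigation: stay put at an episode end or at the newest slot."""
    if done[i] or i == last_index:
        return i
    return (i + 1) % len(done)


done = [False] * 10
done[7] = True    # the second trajectory ends at index 7
last_index = 2    # obs 12, the most recently added step, sits at index 2

chain = [6]
while prev_index(chain[-1], done, last_index) != chain[-1]:
    chain.append(prev_index(chain[-1], done, last_index))
print(chain)  # [6, 5, 4, 3]
print([next_index(i, done, last_index) for i in [4, 5, 6, 7, 8, 9]])  # [5, 6, 7, 7, 9, 0]
```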


We can also search for the indices of steps that are labeled "done: False" but are currently the last stored step of their trajectory, i.e. the ends of trajectories that are still unfinished.

In [None]:
print(buf.unfinished_index())

[2]
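For a single circular buffer, this check can be sketched as below: only the newest step can end an unfinished trajectory, so it suffices to test its done flag. `unfinished_index` and its `next_write` parameter are hypothetical names for this sketch, not Tianshou's code.

```python
def unfinished_index(done, next_write, size):
    """Return [newest index] if the newest step does not end its episode, else []."""
    if size == 0:
        return []
    newest = (next_write - 1) % len(done)  # slot written by the most recent add()
    return [] if done[newest] else [newest]


done = [False] * 10
done[7] = True
# after 13 additions to a buffer of size 10, the write pointer sits at slot 3
print(unfinished_index(done, next_write=3, size=10))  # [2]
```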


The APIs above are helpful when calculating quantities like GAE and n-step returns in DRL algorithms ([Example usage in Tianshou](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L384)). Because every buffer type exposes the same API, such code stays modular and works with any kind of ReplayBuffer.
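The reward part of an n-step return should only sum rewards within one episode, which is where `next()`-style navigation helps. Below is a minimal sketch on plain lists with hypothetical helper names. It discounts up to `n` rewards and stops early at an episode end or at the newest stored step. The bootstrapped value term that a full n-step target adds is deliberately left out, and this is not Tianshou's implementation.

```python
def next_index(i, done, last_index):
    """next()-like navigation: stay at i if it ends an episode or is the newest step."""
    if done[i] or i == last_index:
        return i
    return (i + 1) % len(done)


def n_step_reward_sum(i, n, gamma, rew, done, last_index):
    """Discounted sum of up to n rewards starting at step i, within one episode."""
    total, discount = 0.0, 1.0
    for _ in range(n):
        total += discount * rew[i]
        nxt = next_index(i, done, last_index)
        if nxt == i:  # episode ended (or stored data ran out) before n steps
            break
        discount *= gamma
        i = nxt
    return total


rew = [10.0, 11.0, 12.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
done = [False] * 10
done[7] = True
print(n_step_reward_sum(5, 3, 0.5, rew, done, last_index=2))  # 5 + 0.5*6 + 0.25*7 = 9.75
print(n_step_reward_sum(6, 3, 0.5, rew, done, last_index=2))  # 6 + 0.5*7 = 9.5
```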

# Further Reading
## Other Buffer Modules

*   PrioritizedReplayBuffer, which helps you implement [prioritized experience replay](https://arxiv.org/abs/1511.05952)
*   CachedReplayBuffer, one main buffer with several cached buffers (higher sample efficiency in some scenarios)
*   ReplayBufferManager, a base class that can be inherited (may help you manage multiple buffers).

Check the documentation and the source code for more details.

## Support for step stacking to use RNNs in DRL
There is an option called `stack_num` (defaults to 1) when initialising a ReplayBuffer, which stacks consecutive steps and may help you use RNNs in your algorithm. Check the documentation for details.
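One way to build such a stack is to walk backwards with `prev()`-style navigation. The sketch assumes, without checking Tianshou's source, that the first step of an episode is repeated when fewer than `stack_num` earlier steps exist; `stack_indices` and the hard-coded `prev_table` are hypothetical. It uses the second trajectory (steps 3-7) of the buffer above.

```python
def stack_indices(i, stack_num, prev):
    """Return the `stack_num` indices ending at step i, oldest first.

    `prev` follows ReplayBuffer.prev()-like semantics: it returns its argument
    unchanged at the first step of an episode, so near an episode start the
    first step is simply repeated.
    """
    indices = [i]
    for _ in range(stack_num - 1):
        indices.insert(0, prev(indices[0]))
    return indices


# prev() results for the second trajectory (steps 3-7), hard-coded for illustration
prev_table = {3: 3, 4: 3, 5: 4, 6: 5, 7: 6}
obs = [10, 11, 12, 3, 4, 5, 6, 7, 8, 9]

print(stack_indices(6, 4, prev_table.__getitem__))  # [3, 4, 5, 6]
idx = stack_indices(4, 4, prev_table.__getitem__)
print(idx, [obs[j] for j in idx])  # [3, 3, 3, 4] [3, 3, 3, 4]
```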