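# Archive buffer

Check that `ArchiveBuffer.sample_trajectory` returns only trajectories leading to archived cells, and that every such trajectory can be sampled.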

Define the spaces and the cell factory.

In [1]:
from gym import spaces
from go_explore.cells import CellIsObs

# Observations and goals live in the same 2-D box with bounds [0, 10].
obs_space = spaces.Box(low=0, high=10, shape=(2,))
observation_space = spaces.Dict({"observation": obs_space, "goal": obs_space})

# Degenerate 1-D action space (low == high == 0); this notebook only ever feeds zero actions.
action_space = spaces.Box(low=0, high=0, shape=(1,))

# Presumably each observation is its own cell (per the `CellIsObs` name) — TODO confirm.
cell_factory = CellIsObs(obs_space)


Create the archive buffer with two parallel environments.

In [2]:
from go_explore.archive import ArchiveBuffer

# A capacity of 100 comfortably exceeds the 2 x 6 transitions fed in the next cell.
archive = ArchiveBuffer(
    observation_space=observation_space,
    action_space=action_space,
    cell_factory=cell_factory,
    buffer_size=100,
    n_envs=2,  # one trajectory per environment
)


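Feed the buffer with one 6-step trajectory per environment. Both start at [0, 0]. The second one reaches [0, 3] and [0, 4] in more steps than the first.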

In [3]:
import numpy as np

# One row per environment; each row is the sequence of 2-D observations that env visits.
trajectories = np.array(
    [
        [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]],
        [[0, 0], [1, 0], [1, 1], [1, 2], [1, 3], [0, 3], [0, 4]],
    ]
)
n_envs, n_obs, obs_dim = trajectories.shape

# Goals play no role in the checks below, but they must match the goal space:
# shape (n_envs, obs_dim), not (n_envs, 1), which only worked through broadcasting.
goal = np.zeros((n_envs, obs_dim))

for step in range(n_obs - 1):
    # Mark every env as done on the final transition only.
    is_last_step = step == n_obs - 2
    archive.add(
        obs={"observation": trajectories[:, step], "goal": goal},
        next_obs={"observation": trajectories[:, step + 1], "goal": goal},
        action=np.zeros((n_envs, 1)),
        reward=np.zeros(n_envs),
        done=np.full(n_envs, float(is_last_step)),
        infos=[{} for _ in range(n_envs)],
    )


Try the `sample_trajectory` method.

In [4]:
# Each call draws one archived trajectory at random, as an array of observations.
trajectory = archive.sample_trajectory()
trajectory

array([[0., 1.],
       [0., 2.],
       [0., 3.],
       [0., 4.],
       [0., 5.],
       [0., 6.]], dtype=float32)

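Here is the set of trajectories the archive may return. Each one is a prefix of a fed trajectory, starting after the initial observation [0, 0]. The second environment's prefixes that pass through [0, 3] and [0, 4] are excluded, because the first environment reached those cells in fewer steps.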

In [5]:
# Every trajectory the archive may return. Each is a prefix of a fed trajectory and
# starts after the initial observation [0, 0].
# Env 1's prefixes continuing past [1, 3] (through [0, 3], [0, 4]) are absent: env 0
# reached those cells in fewer steps. Presumably the archive keeps only the shortest
# known trajectory to each cell — TODO confirm against ArchiveBuffer.
possible_trajectories = [
    # Prefixes of env 0's trajectory.
    [[0, 1]],
    [[0, 1], [0, 2]],
    [[0, 1], [0, 2], [0, 3]],
    [[0, 1], [0, 2], [0, 3], [0, 4]],
    [[0, 1], [0, 2], [0, 3], [0, 4], [0, 5]],
    [[0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]],
    # Prefixes of env 1's trajectory, up to the last newly discovered cell.
    [[1, 0]],
    [[1, 0], [1, 1]],
    [[1, 0], [1, 1], [1, 2]],
    [[1, 0], [1, 1], [1, 2], [1, 3]],
]


Check that every sampled trajectory is possible, and that every possible trajectory gets sampled at least once.

In [6]:
# `sample_trajectory` is random and unseeded. With only 30 draws, missing one of the 10
# possible trajectories is likely (roughly 1 in 3 if sampling were uniform), so draw
# enough samples that a miss is vanishingly unlikely.
N_SAMPLES = 500

# Convert to nested int lists (convenient to compare by value with `in`).
sampled_trajectories = [
    archive.sample_trajectory().astype(int).tolist() for _ in range(N_SAMPLES)
]

impossible = [t for t in sampled_trajectories if t not in possible_trajectories]
never_sampled = [t for t in possible_trajectories if t not in sampled_trajectories]
assert not impossible, f"sampled trajectories that should not exist: {impossible}"
assert not never_sampled, f"possible trajectories never sampled: {never_sampled}"
