27 changes: 17 additions & 10 deletions test/test_rb.py
@@ -2555,12 +2555,19 @@ def test_rb_indexing(self, explicit):

def _rbtype(datatype):
if datatype in ("pytree", "tensorclass"):
return [ReplayBuffer, PrioritizedReplayBuffer]
return [
(ReplayBuffer, RandomSampler),
(PrioritizedReplayBuffer, RandomSampler),
(ReplayBuffer, SamplerWithoutReplacement),
(PrioritizedReplayBuffer, SamplerWithoutReplacement),
]
return [
ReplayBuffer,
PrioritizedReplayBuffer,
TensorDictReplayBuffer,
TensorDictPrioritizedReplayBuffer,
(ReplayBuffer, RandomSampler),
(ReplayBuffer, SamplerWithoutReplacement),
(PrioritizedReplayBuffer, None),
(TensorDictReplayBuffer, RandomSampler),
(TensorDictReplayBuffer, SamplerWithoutReplacement),
(TensorDictPrioritizedReplayBuffer, None),
]


@@ -2598,19 +2605,19 @@ def _make_data(self, datatype, datadim):
batch_size=shape,
)

datatype_rb_pairs = [
[datatype, rbtype]
datatype_rb_tuples = [
[datatype, *rbtype]
for datatype in ["pytree", "tensordict", "tensorclass"]
for rbtype in _rbtype(datatype)
]

@pytest.mark.parametrize("datatype,rbtype", datatype_rb_pairs)
@pytest.mark.parametrize("datatype,rbtype,sampler_cls", datatype_rb_tuples)
@pytest.mark.parametrize("datadim", [1, 2])
@pytest.mark.parametrize("storage_cls", [LazyMemmapStorage, LazyTensorStorage])
def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls, sampler_cls):
data = self._make_data(datatype, datadim)
if rbtype not in (PrioritizedReplayBuffer, TensorDictPrioritizedReplayBuffer):
rbtype = functools.partial(rbtype, sampler=RandomSampler())
rbtype = functools.partial(rbtype, sampler=sampler_cls())
else:
rbtype = functools.partial(rbtype, alpha=0.9, beta=1.1)

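For context, the test now parametrizes over (buffer class, sampler class) tuples rather than buffer classes alone. Below is a minimal sketch, not part of the diff, of how one such pair becomes a working buffer, mirroring the functools.partial pattern in test_rb_multidim; the storage size and data are illustrative only.

import functools

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

# Build a buffer factory from one (buffer class, sampler class) pair,
# the same way test_rb_multidim wraps rbtype with functools.partial.
make_rb = functools.partial(ReplayBuffer, sampler=SamplerWithoutReplacement())
rb = make_rb(storage=LazyTensorStorage(100), batch_size=8)
rb.extend(torch.arange(32))  # pytree data: a plain tensor works
sample = rb.sample()         # 8 elements drawn without replacement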
32 changes: 22 additions & 10 deletions torchrl/data/replay_buffers/samplers.py
@@ -177,18 +177,20 @@ def _get_sample_list(self, storage: Storage, len_storage: int):
device = self._sample_list.device
else:
device = storage.device if hasattr(storage, "device") else None

if self.shuffle:
self._sample_list = torch.randperm(len_storage, device=device)
_sample_list = torch.randperm(len_storage, device=device)
else:
self._sample_list = torch.arange(len_storage, device=device)
_sample_list = torch.arange(len_storage, device=device)
self._sample_list = _sample_list

def _single_sample(self, len_storage, batch_size):
index = self._sample_list[:batch_size]
self._sample_list = self._sample_list[batch_size:]

# check if we have enough elements for one more batch, assuming same batch size
# will be used each time sample is called
if self._sample_list.numel() == 0 or (
if self._sample_list.shape[0] == 0 or (
self.drop_last and len(self._sample_list) < batch_size
):
self.ran_out = True
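For reference, a standalone sketch of the consumption pattern implemented by _single_sample above; sizes are illustrative. Each call pops the next batch_size entries off the permutation and flags ran_out when nothing is left or, with drop_last, when fewer than batch_size entries remain.

import torch

sample_list = torch.randperm(10)  # the shuffled _sample_list
batch_size, drop_last = 4, True
ran_out = False
while not ran_out:
    index, sample_list = sample_list[:batch_size], sample_list[batch_size:]
    ran_out = sample_list.shape[0] == 0 or (
        drop_last and sample_list.shape[0] < batch_size
    )
    print(index.tolist(), "ran_out:", ran_out)
# With drop_last=True the trailing 2 elements are dropped rather than
# returned as a short batch.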
@@ -201,7 +203,6 @@ def _storage_len(self, storage):
return len(storage)

def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
storage = storage.flatten()
len_storage = self._storage_len(storage)
if len_storage == 0:
raise RuntimeError(_EMPTY_STORAGE_ERROR)
@@ -217,6 +218,8 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
)
self.len_storage = len_storage
index = self._single_sample(len_storage, batch_size)
if storage.ndim > 1:
index = torch.unravel_index(index, storage.shape)
# we 'always' return the indices. The 'drop_last' just instructs the
# sampler to turn to 'ran_out = True` whenever the next sample
# will be too short. This will be read by the replay buffer
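A short sketch of the flat-to-multidim conversion introduced above, with illustrative shapes: the permutation is drawn over the flattened storage, and each batch of flat indices is unravelled into per-dimension coordinates whenever the storage has more than one batch dimension.

import torch

storage_shape = (10, 4)        # e.g. 10 time slots across 4 buffer rows
perm = torch.randperm(10 * 4)  # without-replacement order over the flat storage
flat_batch = perm[:6]          # one batch of flat indices
coords = torch.unravel_index(flat_batch, storage_shape)
# coords is a tuple (dim0_idx, dim1_idx) suitable for indexing a [10, 4, ...] storage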
@@ -834,10 +837,7 @@ def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
)
# Using transpose ensures the start and stop are sorted the same way
stop_idx = end.transpose(0, -1).nonzero()
# beginnings = torch.cat([end[-1:], end[:-1]], 0)
# start_idx = beginnings.transpose(0, -1).nonzero()
# start_idx = torch.cat([start_idx[:, -1:], start_idx[:, :-1]], -1)
stop_idx = torch.cat([stop_idx[:, -1:], stop_idx[:, :-1]], -1)
stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
# First build the start indices as the stop + 1, we'll shift it later
start_idx = stop_idx.clone()
start_idx[:, 0] += 1
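The in-place column swap above replaces the previous torch.cat reordering. A small sketch with an illustrative mask: nonzero() on the transposed end mask yields coordinates with the time dimension last, and swapping the first and last columns moves the time coordinate back to the front.

import torch

end = torch.zeros(5, 3, dtype=torch.bool)  # (time, batch)
end[4, 0] = end[2, 1] = True               # trajectory ends
stop_idx = end.transpose(0, -1).nonzero()  # columns: (batch, time)
stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
# columns are now (time, batch): tensor([[4, 0], [2, 1]])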
@@ -991,6 +991,8 @@ def _sample_slices(
storage_length: int,
traj_idx: torch.Tensor | None = None,
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
# start_idx and stop_idx are 2d tensors organized like a non-zero

def get_traj_idx(lengths=lengths):
return torch.randint(lengths.shape[0], (num_slices,), device=lengths.device)

@@ -999,7 +1001,7 @@ def get_traj_idx(lengths=lengths):
idx = lengths == seq_length
if not idx.any():
raise RuntimeError(
"Did not find a single trajectory with sufficient length."
f"Did not find a single trajectory with sufficient length (length range: {lengths.min()} - {lengths.max()} / required={seq_length}))."
)
if (
isinstance(seq_length, torch.Tensor)
@@ -1307,14 +1309,24 @@ def sample(
seq_length, num_slices = self._adjusted_batch_size(batch_size)
indices, _ = SamplerWithoutReplacement.sample(self, storage, num_slices)
storage_length = storage.shape[0]

# traj_idx will either be a single tensor or a tuple that can be reorganized
# like a non-zero through stacking.
def tuple_to_tensor(traj_idx, lengths=lengths):
if isinstance(traj_idx, tuple):
traj_idx = torch.arange(len(storage), device=lengths.device).view(
storage.shape
)[traj_idx]

Review comment (Collaborator Author):
I compared this with np.ravel_multi_index using

torch.as_tensor(np.ravel_multi_index(tuple(idx.numpy() for idx in unravelled), shape))

Runtimes are roughly equivalent, with a slight advantage for the numpy version.

Review comment (Collaborator Author):
This is way slower, about 2.5x the runtime of the numpy solution:

def ravel_multi_index(x, shape):
    out = 0
    shape_modif = np.cumprod(list(reversed((*shape, 1))))
    for i, idx in enumerate(reversed(x)):
        out += idx * shape_modif[i]
    return out

Review comment (Collaborator Author):
A more vectorized version still underperforms the arange and numpy versions:

def ravel_multi_index(x, shape):
    out = 0
    shape_modif = torch.flipud(
        torch.cumprod(torch.tensor(list(reversed((*shape[1:], 1)))), 0)
    ).unsqueeze(0)
    return (torch.stack(x, -1) * shape_modif).sum(-1)
return traj_idx

idx, info = self._sample_slices(
lengths,
start_idx,
stop_idx,
seq_length,
num_slices,
storage_length,
traj_idx=indices,
traj_idx=tuple_to_tensor(indices),
)
return idx, info

8 changes: 7 additions & 1 deletion torchrl/data/replay_buffers/writers.py
@@ -81,7 +81,13 @@ def _replicate_index(self, index):
if self._storage.ndim == 1:
return index
mesh = torch.stack(
torch.meshgrid(*(torch.arange(dim) for dim in self._storage.shape[1:])), -1
torch.meshgrid(
*(
torch.arange(dim, device=index.device)
for dim in self._storage.shape[1:]
)
),
-1,
).flatten(0, -2)
if _is_int(index):
index0 = torch.as_tensor(int(index)).expand(mesh.shape[0], 1)
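A sketch of the mesh built in _replicate_index above, with an illustrative trailing shape: it enumerates every coordinate of the storage's trailing dimensions so a write index can be replicated across all of them, and the device argument added in this diff keeps the mesh on the same device as the incoming index.

import torch

trailing_shape = (2, 3)       # stands in for storage.shape[1:]
device = torch.device("cpu")  # stands in for index.device
mesh = torch.stack(
    torch.meshgrid(
        *(torch.arange(dim, device=device) for dim in trailing_shape),
        indexing="ij",  # explicit here to avoid the meshgrid indexing warning
    ),
    -1,
).flatten(0, -2)
# mesh.shape == (6, 2): every (i, j) coordinate of the trailing (2, 3) dims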