diff --git a/torchrl/data/replay_buffers/rb_prototype.py b/torchrl/data/replay_buffers/rb_prototype.py index ab37a4a4b1a..860f5ffeb3e 100644 --- a/torchrl/data/replay_buffers/rb_prototype.py +++ b/torchrl/data/replay_buffers/rb_prototype.py @@ -226,7 +226,7 @@ def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: else: stacked_td = tensordicts - index = super().extend(tensordicts) + index = super().extend(stacked_td) stacked_td.set( "index", torch.tensor(index, dtype=torch.int, device=stacked_td.device),