Skip to content

Commit 539651f

Browse files
author
Vincent Moens
committed
[Feature] Allow neg dims during LazyStack
ghstack-source-id: dd2a9bd Pull Request resolved: #1240
1 parent 2ad9f95 commit 539651f

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

tensordict/_lazy.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,13 @@ def __init__(
250250
td0 = tensordicts[0]
251251
device = td0.device
252252
if stack_dim < 0:
253-
raise RuntimeError(
254-
f"stack_dim must be non negative, got stack_dim={stack_dim}"
255-
)
253+
ndim = td0.ndim
254+
try:
255+
stack_dim = _maybe_correct_neg_dim(stack_dim, ndim=ndim + 1, shape=None)
256+
except Exception:
257+
raise RuntimeError(
258+
f"Couldn't infer stack dim from negative value, got stack_dim={stack_dim}"
259+
)
256260
_batch_size = td0.batch_size
257261
if stack_dim > len(_batch_size):
258262
raise RuntimeError(

test/test_tensordict.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9395,6 +9395,19 @@ def test_lazy_stacked_insert(self, dim, index, device):
93959395
with pytest.raises(ValueError, match="Batch sizes in tensordicts differs"):
93969396
lstd.insert(index, TensorDict({"a": torch.ones(17)}, [17], device=device))
93979397

9398+
def test_neg_dim_lazystack(self):
9399+
td0 = TensorDict(batch_size=(3, 5))
9400+
td1 = TensorDict(batch_size=(4, 5))
9401+
assert lazy_stack([td0, td1], -1).shape == (-1, 5, 2)
9402+
assert lazy_stack([td0, td1], -2).shape == (-1, 2, 5)
9403+
assert lazy_stack([td0, td1], -3).shape == (2, -1, 5)
9404+
with pytest.raises(RuntimeError):
9405+
assert lazy_stack([td0, td1], -4)
9406+
with pytest.raises(RuntimeError):
9407+
assert lazy_stack([td0, td1, TensorDict()])
9408+
with pytest.raises(RuntimeError):
9409+
assert lazy_stack([TensorDict(), td0, td1])
9410+
93989411
@pytest.mark.parametrize(
93999412
"reduction", ["sum", "nansum", "mean", "nanmean", "std", "var", "prod"]
94009413
)

0 commit comments

Comments
 (0)