We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2ad9f95 commit 539651fCopy full SHA for 539651f
tensordict/_lazy.py
@@ -250,9 +250,13 @@ def __init__(
250
td0 = tensordicts[0]
251
device = td0.device
252
if stack_dim < 0:
253
- raise RuntimeError(
254
- f"stack_dim must be non negative, got stack_dim={stack_dim}"
255
- )
+ ndim = td0.ndim
+ try:
+ stack_dim = _maybe_correct_neg_dim(stack_dim, ndim=ndim + 1, shape=None)
256
+ except Exception:
257
+ raise RuntimeError(
258
+ f"Couldn't infer stack dim from negative value, got stack_dim={stack_dim}"
259
+ )
260
_batch_size = td0.batch_size
261
if stack_dim > len(_batch_size):
262
raise RuntimeError(
test/test_tensordict.py
@@ -9395,6 +9395,19 @@ def test_lazy_stacked_insert(self, dim, index, device):
9395
with pytest.raises(ValueError, match="Batch sizes in tensordicts differs"):
9396
lstd.insert(index, TensorDict({"a": torch.ones(17)}, [17], device=device))
9397
9398
+ def test_neg_dim_lazystack(self):
9399
+ td0 = TensorDict(batch_size=(3, 5))
9400
+ td1 = TensorDict(batch_size=(4, 5))
9401
+ assert lazy_stack([td0, td1], -1).shape == (-1, 5, 2)
9402
+ assert lazy_stack([td0, td1], -2).shape == (-1, 2, 5)
9403
+ assert lazy_stack([td0, td1], -3).shape == (2, -1, 5)
9404
+ with pytest.raises(RuntimeError):
9405
+ assert lazy_stack([td0, td1], -4)
9406
9407
+ assert lazy_stack([td0, td1, TensorDict()])
9408
9409
+ assert lazy_stack([TensorDict(), td0, td1])
9410
+
9411
@pytest.mark.parametrize(
9412
"reduction", ["sum", "nansum", "mean", "nanmean", "std", "var", "prod"]
9413
)
0 commit comments