Skip to content

Commit 1d642b0

Browse files
author
Vincent Moens
committed
[BugFix] Fix improper name setting in __setitem__
ghstack-source-id: 08d0bcf Pull-Request-resolved: #1313
1 parent 5f26a8b commit 1d642b0

File tree

4 files changed

+36
-2
lines changed

4 files changed

+36
-2
lines changed

tensordict/_td.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1882,6 +1882,7 @@ def rep(leaf):
18821882
batch_size=new_batch_size,
18831883
call_on_nested=True,
18841884
propagate_lock=True,
1885+
names=self._maybe_names(),
18851886
)
18861887

18871888
def _repeat(self, *repeats: int) -> TensorDictBase:
@@ -1895,6 +1896,7 @@ def rep(leaf):
18951896
batch_size=new_batch_size,
18961897
call_on_nested=True,
18971898
propagate_lock=True,
1899+
names=self._maybe_names(),
18981900
)
18991901

19001902
def _transpose(self, dim0, dim1):

tensordict/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11712,7 +11712,9 @@ def _validate_value_devicefree(
1171211712
if has_names:
1171311713
if value.names[: self.batch_dims] != self.names:
1171411714
# we clone not to corrupt the value
11715-
value = value.clone(False).refine_names(*self.names)
11715+
value = value.clone(False).refine_names(
11716+
*(self.names + value.names[self.batch_dims :])
11717+
)
1171611718
else:
1171711719
if value._has_names():
1171811720
self.names = value.names[: self.batch_dims]

tensordict/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1360,7 +1360,9 @@ def assert_close(
13601360
if not _is_tensor_collection(type(actual)) or not _is_tensor_collection(
13611361
type(expected)
13621362
):
1363-
raise TypeError("assert_allclose inputs must be of TensorDict type")
1363+
raise TypeError(
1364+
f"assert_allclose inputs must be of TensorDict type, got {type(actual)} and {type(expected)}"
1365+
)
13641366

13651367
from tensordict._lazy import LazyStackedTensorDict
13661368

test/test_tensordict.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,6 +2232,34 @@ def test_requires_grad(self, device):
22322232
# First stacked tensor has requires_grad == True
22332233
assert list(stacked_td.values())[0].requires_grad is True
22342234

2235+
def test_refine_names_setitem_subtd(self):
2236+
batch_size = 1
2237+
seq_len = 2
2238+
n_agents = 3
2239+
td = TensorDict(
2240+
{
2241+
"agents": TensorDict(
2242+
{
2243+
"obs": torch.zeros((batch_size, seq_len, n_agents, 5)),
2244+
"dones": torch.zeros((batch_size, seq_len, n_agents, 1)),
2245+
},
2246+
batch_size=(batch_size, seq_len, n_agents),
2247+
names=[None, "time", "other"],
2248+
),
2249+
"dones": torch.zeros((batch_size, seq_len)),
2250+
},
2251+
batch_size=(batch_size, seq_len),
2252+
names=[None, "time"],
2253+
)
2254+
#
2255+
td["agents"] = td["agents"].repeat_interleave(2, dim=-1)
2256+
assert len(td["agents"].names) == 3
2257+
assert td["agents"].names[-1] == "other"
2258+
td["agents"] = td["agents"].repeat(1, 1, 2)
2259+
assert td["agents"].names[-1] == "other"
2260+
td["agents"] = torch.cat((td["agents"], td["agents"]), dim=2)
2261+
assert td["agents"].names[-1] == "other"
2262+
22352263
def test_rename_key_nested(self):
22362264
td = TensorDict(a={"b": {"c": 0}})
22372265
td.rename_key_(("a", "b", "c"), ("a", "b"))

0 commit comments

Comments
 (0)