Skip to content

Commit feb736f

Browse files
author
Vincent Moens
committed
[BugFix] Fix tensordict.get in TDModule tensor retrieval for NonTensorData
ghstack-source-id: 240b54a Pull Request resolved: #1249
1 parent bf9007c commit feb736f

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

tensordict/nn/common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1169,7 +1169,9 @@ def forward(
11691169
if self._kwargs is not None:
11701170
kwargs.update(
11711171
{
1172-
kwarg: tensordict.get(in_key, default=default)
1172+
kwarg: tensordict._get_tuple_maybe_non_tensor(
1173+
_unravel_key_to_tuple(in_key), default=default
1174+
)
11731175
for kwarg, in_key in _zip_strict(self._kwargs, self.in_keys)
11741176
}
11751177
)

test/test_nn.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,24 @@ def test_nontensor(self):
689689
out_keys=["out"],
690690
)
691691
assert tdm(TensorDict())["out"] == "a string!"
692+
tdm = TensorDictModule(
693+
lambda a_string: a_string + " is a string!",
694+
in_keys=["string"],
695+
out_keys=["another string"],
696+
)
697+
assert (
698+
tdm(TensorDict(string="a string"))["another string"]
699+
== "a string is a string!"
700+
)
701+
tdm = TensorDictModule(
702+
lambda string: string + " is a string!",
703+
in_keys={"string": "key"},
704+
out_keys=["another string"],
705+
out_to_in_map=True,
706+
)
707+
assert (
708+
tdm(TensorDict(key="a string"))["another string"] == "a string is a string!"
709+
)
692710

693711
@pytest.mark.parametrize(
694712
"out_keys",

0 commit comments

Comments
 (0)