File tree Expand file tree Collapse file tree 2 files changed +21
-1
lines changed Expand file tree Collapse file tree 2 files changed +21
-1
lines changed Original file line number Diff line number Diff line change @@ -1169,7 +1169,9 @@ def forward(
11691169 if self ._kwargs is not None :
11701170 kwargs .update (
11711171 {
1172- kwarg : tensordict .get (in_key , default = default )
1172+ kwarg : tensordict ._get_tuple_maybe_non_tensor (
1173+ _unravel_key_to_tuple (in_key ), default = default
1174+ )
11731175 for kwarg , in_key in _zip_strict (self ._kwargs , self .in_keys )
11741176 }
11751177 )
Original file line number Diff line number Diff line change @@ -689,6 +689,24 @@ def test_nontensor(self):
689689 out_keys = ["out" ],
690690 )
691691 assert tdm (TensorDict ())["out" ] == "a string!"
692+ tdm = TensorDictModule (
693+ lambda a_string : a_string + " is a string!" ,
694+ in_keys = ["string" ],
695+ out_keys = ["another string" ],
696+ )
697+ assert (
698+ tdm (TensorDict (string = "a string" ))["another string" ]
699+ == "a string is a string!"
700+ )
701+ tdm = TensorDictModule (
702+ lambda string : string + " is a string!" ,
703+ in_keys = {"string" : "key" },
704+ out_keys = ["another string" ],
705+ out_to_in_map = True ,
706+ )
707+ assert (
708+ tdm (TensorDict (key = "a string" ))["another string" ] == "a string is a string!"
709+ )
692710
693711 @pytest .mark .parametrize (
694712 "out_keys" ,
You can’t perform that action at this time.
0 commit comments