Skip to content

Commit d25bd54

Browse files
author
Vincent Moens
committed
[BugFix] _PASSTHROUGH_MEMO for passthrough tensorclass
ghstack-source-id: 0bfbfc9 Pull Request resolved: #1231
1 parent c35d7aa commit d25bd54

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tensordict/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2582,17 +2582,20 @@ def _is_non_tensor(cls: type):
25822582
return out
25832583

25842584

2585+
_PASSTHROUGH_MEMO = {}
2586+
2587+
25852588
def _pass_through_cls(cls: type):
25862589
out = None
25872590
is_dynamo = is_compiling()
25882591
if not is_dynamo:
2589-
out = _NON_TENSOR_MEMO.get(cls)
2592+
out = _PASSTHROUGH_MEMO.get(cls)
25902593
if out is None:
25912594
out = bool(getattr(cls, "_is_non_tensor", False)) or getattr(
25922595
cls, "_pass_through", False
25932596
)
25942597
if not is_dynamo:
2595-
_NON_TENSOR_MEMO[cls] = out
2598+
_PASSTHROUGH_MEMO[cls] = out
25962599
return out
25972600

25982601

0 commit comments

Comments
 (0)