We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c35d7aa commit d25bd54Copy full SHA for d25bd54
tensordict/utils.py
@@ -2582,17 +2582,20 @@ def _is_non_tensor(cls: type):
2582
return out
2583
2584
2585
+_PASSTHROUGH_MEMO = {}
2586
+
2587
2588
def _pass_through_cls(cls: type):
2589
out = None
2590
is_dynamo = is_compiling()
2591
if not is_dynamo:
- out = _NON_TENSOR_MEMO.get(cls)
2592
+ out = _PASSTHROUGH_MEMO.get(cls)
2593
if out is None:
2594
out = bool(getattr(cls, "_is_non_tensor", False)) or getattr(
2595
cls, "_pass_through", False
2596
)
2597
- _NON_TENSOR_MEMO[cls] = out
2598
+ _PASSTHROUGH_MEMO[cls] = out
2599
2600
2601
0 commit comments