From 51413e5e3efb459a2ac22395df6b9715a2b275fe Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Thu, 24 Nov 2022 11:10:54 +0000 Subject: [PATCH 1/2] functorch fixes for old-deps --- tensordict/nn/common.py | 13 ++++++++++--- tensordict/nn/sequence.py | 3 ++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 47cfaf63c..ab4adfa17 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -29,6 +29,13 @@ except ImportError: _has_functorch = False + class FunctionalModule: + pass + + class FunctionalModuleWithBuffers: + pass + + __all__ = [ "TensorDictModule", "TensorDictModuleWrapper", @@ -153,7 +160,7 @@ def __init__( def is_functional(self): return isinstance( self.module, - (functorch.FunctionalModule, functorch.FunctionalModuleWithBuffers), + (FunctionalModule, FunctionalModuleWithBuffers), ) def _write_to_tensordict( @@ -408,7 +415,7 @@ def make_functional_with_buffers(self, clone: bool = True, native: bool = False) def num_params(self): if isinstance( self.module, - (functorch.FunctionalModule, functorch.FunctionalModuleWithBuffers), + (FunctionalModule, FunctionalModuleWithBuffers), ): return len(self.module.param_names) else: @@ -416,7 +423,7 @@ def num_params(self): @property def num_buffers(self): - if isinstance(self.module, (functorch.FunctionalModuleWithBuffers,)): + if isinstance(self.module, FunctionalModuleWithBuffers): return len(self.module.buffer_names) else: return 0 diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index 922f31479..8aa1ae1ed 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -394,6 +394,7 @@ def make_functional_with_buffers(self, clone: bool = True, native: bool = False) is_shared=False) """ + native = native or not _has_functorch if clone: self_copy = deepcopy(self) self_copy.module = copy(self_copy.module) @@ -406,7 +407,7 @@ def make_functional_with_buffers(self, clone: bool = True, native: bool = False) _params, _buffers, ) = module.make_functional_with_buffers(clone=True, native=native) - if native or not _has_functorch: + if native: params[str(i)] = _params buffers[str(i)] = _buffers else: From 1f9051835fb34fe523818e376281c2a5c19725b3 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Thu, 24 Nov 2022 11:42:37 +0000 Subject: [PATCH 2/2] is_functional bug --- tensordict/nn/common.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index ab4adfa17..9fb35cbda 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -160,7 +160,12 @@ def __init__( def is_functional(self): return isinstance( self.module, - (FunctionalModule, FunctionalModuleWithBuffers), + ( + FunctionalModule, + FunctionalModuleWithBuffers, + rlFunctionalModule, + rlFunctionalModuleWithBuffers, + ), ) def _write_to_tensordict(