From 09b3e16b40167c3a0765d9a47147d0be3cc9181f Mon Sep 17 00:00:00 2001 From: Meghan Lele Date: Tue, 29 Sep 2020 10:20:00 -0700 Subject: [PATCH] [JIT] Enable @unused syntax for ignoring properties (#45261) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45261 **Summary** This commit enables `unused` syntax for ignoring properties. Inoring properties is more intuitive with this feature enabled. `ignore` is not supported because class type properties cannot be executed in Python (because they exist only as TorchScript types) like an `ignored` function and module properties that cannot be scripted are not added to the `ScriptModule` wrapper so that they may execute in Python. **Test Plan** This commit updates the existing unit tests for class type and module properties to test properties ignored using `unused`. Test Plan: Imported from OSS Reviewed By: navahgar, Krovatkin, mannatsingh Differential Revision: D23971881 Pulled By: SplitInfinity fbshipit-source-id: 8d3cc1bbede7753d6b6f416619e4660c56311d33 --- test/jit/test_class_type.py | 11 ++++++++++- test/test_jit_py3.py | 11 ++++++++++- torch/_jit_internal.py | 9 +++++++++ torch/fx/graph_module.py | 2 +- torch/jit/_script.py | 2 +- torch/jit/frontend.py | 4 ++-- torch/nn/modules/rnn.py | 2 +- 7 files changed, 34 insertions(+), 7 deletions(-) diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py index dda6916b5591..7c9e323163e6 100644 --- a/test/jit/test_class_type.py +++ b/test/jit/test_class_type.py @@ -1167,7 +1167,7 @@ def free_function(x: int) -> int: @torch.jit.script class Properties(object): - __ignored_properties__ = ["unsupported"] + __jit_unused_properties__ = ["unsupported"] def __init__(self, a: int): self.a = a @@ -1180,6 +1180,15 @@ def attr(self) -> int: def unsupported(self) -> int: return sum([self.a]) + @torch.jit.unused + @property + def unsupported_2(self) -> int: + return sum([self.a]) + + @unsupported_2.setter + def unsupported_2(self, value): + self.a = sum([self.a]) + @attr.setter def attr(self, value: int): self.a = value + 3 diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py index 4de5db884035..212b03d9658b 100644 --- a/test/test_jit_py3.py +++ b/test/test_jit_py3.py @@ -621,7 +621,7 @@ def if_function(inp: torch.Tensor) -> Any: def test_module_properties(self): class ModuleWithProperties(torch.nn.Module): - __ignored_properties__ = ["ignored_attr"] + __jit_unused_properties__ = ["ignored_attr"] def __init__(self, a: int): super().__init__() @@ -639,6 +639,15 @@ def attr(self): def ignored_attr(self): return sum([self.a]) + @torch.jit.unused + @property + def ignored_attr_2(self): + return sum([self.a]) + + @ignored_attr_2.setter + def ignored_attr_2(self, value): + self.a = sum([self.a]) + @attr.setter def attr(self, a: int): if a > 0: diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 5fa2ee639a9f..e9fb21c5e854 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -390,6 +390,15 @@ def forward(self, x): # exception raised m(torch.rand(100)) """ + if isinstance(fn, property): + prop = fn + setattr(prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED) # noqa: B010 + + if prop.fset: + setattr(prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED) # noqa: B010 + + return prop + fn._torchscript_modifier = FunctionModifiers.UNUSED return fn diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index e635819550ad..9c7d50b1d9dc 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -164,7 +164,7 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph): # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 # # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway - __ignored_properties__ = ['graph'] + __jit_unused_properties__ = ['graph'] @property def graph(self): diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 4d28a5f2ad13..0adbefc02cee 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -276,7 +276,7 @@ class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore contain methods, attributes, parameters, and constants. These can be accessed the same as on a normal ``nn.Module``. """ - __ignored_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name'] + __jit_unused_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name'] def __init__(self): super(ScriptModule, self).__init__() diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 4cfba50d0466..fdf1e613461e 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -142,12 +142,12 @@ def get_class_properties(cls, self_name): props = inspect.getmembers( cls, predicate=lambda m: isinstance(m, property)) # Any property that should not compiled must be in this list on the Module. - ignored_properties = getattr(cls, "__ignored_properties__", []) + unused_properties = getattr(cls, "__jit_unused_properties__", []) # Create Property TreeView objects from inspected property objects. properties = [] for prop in props: - if prop[0] not in ignored_properties: + if prop[0] not in unused_properties and not should_drop(prop[1].fget): getter = get_jit_def(prop[1].fget, f"__{prop[0]}_getter", self_name=self_name) setter = get_jit_def(prop[1].fset, f"__{prop[0]}_setter", self_name=self_name) if prop[1].fset else None properties.append(Property(getter.range(), Ident(getter.range(), prop[0]), getter, setter)) diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index e6589b9ef1d9..b8da2a877dd9 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -24,7 +24,7 @@ def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tens class RNNBase(Module): __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias', 'batch_first', 'dropout', 'bidirectional'] - __ignored_properties__ = ['all_weights'] + __jit_unused_properties__ = ['all_weights'] mode: str input_size: int