Skip to content

Commit

Permalink
[JIT] Enable @unused syntax for ignoring properties (#45261)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #45261

**Summary**
This commit enables `unused` syntax for ignoring
properties. Inoring properties is more intuitive with this feature enabled.
`ignore` is not supported because class type properties cannot be
executed in Python (because they exist only as TorchScript types) like
an `ignored` function and module properties that cannot be scripted
are not added to the `ScriptModule` wrapper so that they
may execute in Python.

**Test Plan**
This commit updates the existing unit tests for class type and module
properties to test properties ignored using `unused`.

Test Plan: Imported from OSS

Reviewed By: navahgar, Krovatkin, mannatsingh

Differential Revision: D23971881

Pulled By: SplitInfinity

fbshipit-source-id: 8d3cc1bbede7753d6b6f416619e4660c56311d33
  • Loading branch information
Meghan Lele authored and facebook-github-bot committed Sep 29, 2020
1 parent 5f49d14 commit 09b3e16
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 7 deletions.
11 changes: 10 additions & 1 deletion test/jit/test_class_type.py
Expand Up @@ -1167,7 +1167,7 @@ def free_function(x: int) -> int:

@torch.jit.script
class Properties(object):
__ignored_properties__ = ["unsupported"]
__jit_unused_properties__ = ["unsupported"]

def __init__(self, a: int):
self.a = a
Expand All @@ -1180,6 +1180,15 @@ def attr(self) -> int:
def unsupported(self) -> int:
return sum([self.a])

@torch.jit.unused
@property
def unsupported_2(self) -> int:
return sum([self.a])

@unsupported_2.setter
def unsupported_2(self, value):
self.a = sum([self.a])

@attr.setter
def attr(self, value: int):
self.a = value + 3
Expand Down
11 changes: 10 additions & 1 deletion test/test_jit_py3.py
Expand Up @@ -621,7 +621,7 @@ def if_function(inp: torch.Tensor) -> Any:

def test_module_properties(self):
class ModuleWithProperties(torch.nn.Module):
__ignored_properties__ = ["ignored_attr"]
__jit_unused_properties__ = ["ignored_attr"]

def __init__(self, a: int):
super().__init__()
Expand All @@ -639,6 +639,15 @@ def attr(self):
def ignored_attr(self):
return sum([self.a])

@torch.jit.unused
@property
def ignored_attr_2(self):
return sum([self.a])

@ignored_attr_2.setter
def ignored_attr_2(self, value):
self.a = sum([self.a])

@attr.setter
def attr(self, a: int):
if a > 0:
Expand Down
9 changes: 9 additions & 0 deletions torch/_jit_internal.py
Expand Up @@ -390,6 +390,15 @@ def forward(self, x):
# exception raised
m(torch.rand(100))
"""
if isinstance(fn, property):
prop = fn
setattr(prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED) # noqa: B010

if prop.fset:
setattr(prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED) # noqa: B010

return prop

fn._torchscript_modifier = FunctionModifiers.UNUSED
return fn

Expand Down
2 changes: 1 addition & 1 deletion torch/fx/graph_module.py
Expand Up @@ -164,7 +164,7 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph):
# continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
#
# Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
__ignored_properties__ = ['graph']
__jit_unused_properties__ = ['graph']

@property
def graph(self):
Expand Down
2 changes: 1 addition & 1 deletion torch/jit/_script.py
Expand Up @@ -276,7 +276,7 @@ class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore
contain methods, attributes, parameters, and
constants. These can be accessed the same as on a normal ``nn.Module``.
"""
__ignored_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name']
__jit_unused_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name']

def __init__(self):
super(ScriptModule, self).__init__()
Expand Down
4 changes: 2 additions & 2 deletions torch/jit/frontend.py
Expand Up @@ -142,12 +142,12 @@ def get_class_properties(cls, self_name):
props = inspect.getmembers(
cls, predicate=lambda m: isinstance(m, property))
# Any property that should not compiled must be in this list on the Module.
ignored_properties = getattr(cls, "__ignored_properties__", [])
unused_properties = getattr(cls, "__jit_unused_properties__", [])

# Create Property TreeView objects from inspected property objects.
properties = []
for prop in props:
if prop[0] not in ignored_properties:
if prop[0] not in unused_properties and not should_drop(prop[1].fget):
getter = get_jit_def(prop[1].fget, f"__{prop[0]}_getter", self_name=self_name)
setter = get_jit_def(prop[1].fset, f"__{prop[0]}_setter", self_name=self_name) if prop[1].fset else None
properties.append(Property(getter.range(), Ident(getter.range(), prop[0]), getter, setter))
Expand Down
2 changes: 1 addition & 1 deletion torch/nn/modules/rnn.py
Expand Up @@ -24,7 +24,7 @@ def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tens
class RNNBase(Module):
__constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
'batch_first', 'dropout', 'bidirectional']
__ignored_properties__ = ['all_weights']
__jit_unused_properties__ = ['all_weights']

mode: str
input_size: int
Expand Down

0 comments on commit 09b3e16

Please sign in to comment.