Skip to content

Commit

Permalink
[JIT] Enable @ignore and @unused syntax for ignoring properties
Browse files Browse the repository at this point in the history
**Summary**
This commit enables `@ignore` and `@unused` syntax for ignoring
properties, but works only if these decorators are applied to a function
before the @Property decorator. Despite this, ignoring properties is
more intuitive with this feature enabled.

**Test Plan**
This commit updates the existing unit tests for class type and module
properties to test properties ignored using `@ignore` and `@unused`.
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 7a545729ca5efff280658a00e32e5053d7619f9d
Pull Request resolved: #45261
  • Loading branch information
Meghan Lele committed Sep 24, 2020
1 parent a25c202 commit 6449867
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
9 changes: 9 additions & 0 deletions test/jit/test_class_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,15 @@ def attr(self) -> int:
def unsupported(self) -> int:
return sum([self.a])

@property
@torch.jit.unused
def unsupported_2(self) -> int:
return sum([self.a])

@unsupported_2.setter
def unsupported_2(self, value):
self.a = sum([self.a])

@attr.setter
def attr(self, value: int):
self.a = value + 3
Expand Down
9 changes: 9 additions & 0 deletions test/test_jit_py3.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,15 @@ def attr(self):
def ignored_attr(self):
return sum([self.a])

@property
@torch.jit.unused
def ignored_attr_2(self):
return sum([self.a])

@ignored_attr_2.setter
def ignored_attr_2(self, value):
self.a = sum([self.a])

@attr.setter
def attr(self, a: int):
if a > 0:
Expand Down
4 changes: 2 additions & 2 deletions torch/jit/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from torch._utils_internal import get_source_lines_and_file

from torch._jit_internal import SourceContext, should_drop, is_static_fn
from torch._jit_internal import SourceContext, should_drop, is_static_fn, is_ignored_fn
import torch.jit.annotations

# Borrowed from cPython implementation
Expand Down Expand Up @@ -147,7 +147,7 @@ def get_class_properties(cls, self_name):
# Create Property TreeView objects from inspected property objects.
properties = []
for prop in props:
if prop[0] not in ignored_properties:
if prop[0] not in ignored_properties and not is_ignored_fn(prop[1].fget):
getter = get_jit_def(prop[1].fget, f"__{prop[0]}_getter", self_name=self_name)
setter = get_jit_def(prop[1].fset, f"__{prop[0]}_setter", self_name=self_name) if prop[1].fset else None
properties.append(Property(getter.range(), Ident(getter.range(), prop[0]), getter, setter))
Expand Down

0 comments on commit 6449867

Please sign in to comment.