Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support scripting classmethod called with object instances #49967

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions test/jit/test_class_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,32 @@ def test_function(a: int, b: int) -> 'ClassWithStaticMethod':

self.checkScript(test_function, (1, 2))

def test_classmethod(self):
"""
Test classmethods on class types.
"""
global ClassWithClassMethod

@torch.jit.script
class ClassWithClassMethod:
def __init__(self, a: int):
self.a: int = a

def __eq__(self, other: 'ClassWithClassMethod'):
return self.a == other.a

@classmethod
def create(cls, a: int) -> 'ClassWithClassMethod':
return cls(a)

def test_function(a: int) -> 'ClassWithClassMethod':
x = ClassWithClassMethod(a)
# Support calling classmethod with an instance
# Calling with the class is not supported.
return x.create(a)

self.checkScript(test_function, (1,))

def test_properties(self):
"""
Test that a scripted class can make use of the @property decorator.
Expand Down
15 changes: 13 additions & 2 deletions torch/jit/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,14 @@ def get_jit_class_def(cls, self_name):
and not is_static_fn(cls, m.__name__)
and m.__name__ in cls.__dict__
)

def is_classmethod(fn):
return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls

methods = [get_jit_def(method[1],
method[0],
self_name=self_name) for method in methods]
self_name=self_name,
is_classmethod=is_classmethod(method[1])) for method in methods]

properties = get_class_properties(cls, self_name)

Expand Down Expand Up @@ -217,7 +222,7 @@ def remove_prefix(text, prefix):
return aligned_prefix + aligned_suffix


def get_jit_def(fn, def_name, self_name=None):
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to determine if fn is a classmethod without passing an extra parameter? I think it is:

>>> class A:
...   @classmethod
...   def a(cls):
...     pass
...   def b(self):
...     pass
...   @staticmethod
...   def c(s, t):
...     pass
... 
>>> A.a.__self__
<class '__main__.A'>
>>> A.b.__self__
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'function' object has no attribute '__self__'
>>> A.c.__self__
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'function' object has no attribute '__self__'

So it seems that __self__ is only present on classmethods? So I think we can do that check in get_jit_def directly without the extra parameter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regular methods that are bounded to instances, e.g. A().b in your example will have __self__. Could get_jit_def be called with such methods?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my knowledge, no. Can you try it to see if it's possible (to get rid of the extra function argument)?

Copy link
Contributor Author

@ppwwyyxx ppwwyyxx Jan 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails many tests. I added print(fn) when __self__ is present, and saw cases it is called with bound regular methods:

$python3 test/test_jit.py                                                                                                                                                          
Fail to import hypothesis in common_utils, tests are not derandomized                                                                                                                 
.................................<bound method BasicModule.forward of BasicModule()>                                                                                                  
E<bound method BasicModule.forward of BasicModule()>                                                                                                                                  
E<bound method BasicModule.forward of BasicModule()>                                                                                                                                  
E./private/home/yuxinwu/DL/pytorch/test/jit/test_builtins.py:127: DeprecationWarning: Please use assertEqual instead.                                                                 
  self.assertEquals(py_out, jit_out)                                                                                                                                                  
.<bound method TestBuiltins.test_has_attr.<locals>.Mod.forward of Mod(                                                                                                                
  (mods): ModuleList(                                                                                                                                                                 
    (0): HasA()                                                                                                                                                                       
    (1): HasB()                                                                                                                                                                       
  )                                                                                                                                                                                   
)>                                                                                                                                                                                    
<bound method ModuleList.forward of ModuleList(                                                                                                                                       
  (0): HasA()                                                                                                                                                                         
  (1): HasB()                                                                                                                                                                         
)>                                                                                                                                                                                    
E<bound method TestBuiltins.test_has_attr_invalid_args.<locals>.Mod.forward of Mod(                                                                                                   
  (mod): Linear(in_features=1, out_features=1, bias=True)                                                                                                                             
)>                                                                    

It seems that's because scripted modules are always created from instance, not class - so regular methods are bound.

"""
Build a JIT AST (TreeView) from the given function.

Expand All @@ -244,6 +249,12 @@ def _forward(self):
ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
fn_def = py_ast.body[0]

if is_classmethod:
arg_name = fn_def.args.args[0].arg
# Insert a statement that assigns the first argument to the class
assign_stmt = ast.parse(f"{arg_name} = {self_name}").body[0]
fn_def.body.insert(0, assign_stmt)

# Swap out the function signature and body if it is unused
if should_drop(fn):
unused_fn_def = ast.parse("def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")")
Expand Down