
[jit] jit._drop fun modifier to allow in jit class non-jit decl funs #93012

Closed


IvanKobzarev (Contributor) commented Jan 25, 2023

@torch.jit.unused and @torch.jit.ignore do not make it possible to keep, in a torch.jit-scripted class, a member function whose declaration is not scriptable (e.g. its return type).

Adding FunctionModifier _DROP to fully skip such functions during scripting while keeping them in the code of the scripted class.

E.g. it can be used for:

```
@torch.jit._drop
def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
    # torch.fx classes are not scriptable
    return tracer.create_node(
        "call_function",
        CFX,
        args=(tracer.create_arg(self.features),),
        kwargs={},
    )

def __iter__(self) -> Iterator[torch.Tensor]:
    return iter(self.a)
```
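
A fuller, self-contained usage sketch (the class below is made up for illustration; the real test lives in test/jit/test_types.py):

```
import torch
from typing import Iterator

class Wrapper:
    def __init__(self, a: torch.Tensor) -> None:
        self.a = a

    @torch.jit._drop
    def __iter__(self) -> Iterator[torch.Tensor]:
        # Iterator[...] is not a scriptable annotation; _drop keeps the method
        # on the Python class but skips it entirely during scripting.
        return iter([self.a])

    def double(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

torch.jit.script(Wrapper)  # compiles; without _drop this would fail on __iter__
```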

Testing:
Added test case in test/jit/test_types.py with non-scriptable type annotations (fx.* classes) that fails before fix and passes after.

```
python test/test_jit.py
```

Stack from ghstack (oldest at bottom):

Differential Revision: D42774830

pytorch-bot commented Jan 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/93012

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 62038df:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the `release notes: jit` label Jan 25, 2023
IvanKobzarev added a commit that referenced this pull request Jan 25, 2023
ghstack-source-id: 562c225193c41dd49519c41d2824d2bef03b01d7
Pull Request resolved: #93012
IvanKobzarev added a commit that referenced this pull request Jan 26, 2023
ghstack-source-id: 46e9c01f89d9aacc590255e7cab9c94547033da7
Pull Request resolved: #93012

@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

`torch.jit.unused` is supposed to mark methods to be excluded from jit scripting.

jit creates a stub with the same number of arguments that throws an exception if a client calls the `jit.unused` method.

That stub still goes through scripting, and if its annotations/type decorations reference non-scriptable classes, jit scripting fails to compile the class.

The return type is preserved so that calls to the unused method can still be compiled.

The same happens for `jit.ignore`, which means non-scriptable types cannot be kept as return types of jit.ignore methods either.

e.g.
```
@torch.jit.ignore
def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
    # torch.fx classes are not scriptable
    return tracer.create_node(
        "call_function",
        CFX,
        args=(tracer.create_arg(self.features),),
        kwargs={},
    )
```

```
FunctionDef(
name='__fx_create_arg__', 
args=arguments(
	posonlyargs=[], 
	args=[
		arg(arg='self', annotation=Name(id='Any', ctx=Load()), type_comment=None),
		arg(arg='tracer', annotation=Name(id='Any', ctx=Load()), type_comment=None)
	],
	vararg=None,
	kwonlyargs=[],
	kw_defaults=[],
	kwarg=None,
	defaults=[]
),
body=[Raise(exc=Call(func=Name(id='RuntimeError', ctx=Load()), args=[Constant(value='Cannot call unused methods', kind=None)], keywords=[]), cause=None)], 
decorator_list=[
	Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='jit', ctx=Load()), attr='unused', ctx=Load())
],
returns=Attribute(value=Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='fx', ctx=Load()), attr='node', ctx=Load()), attr='Argument', ctx=Load()), 
type_comment=None
)
```

Fix:
- Completely skip jit.ignore methods in class definitions
- Drop return types for jit.ignore (currently the return type is kept for both jit.unused and jit.ignore)




Testing:
Added test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before fix and passes after.

```
python test/test_jit.py
```



Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830)

[ghstack-poisoned]
IvanKobzarev added a commit that referenced this pull request Jan 26, 2023
ghstack-source-id: b71b6127043480ec3fdb35fe5569580a4d284da5
Pull Request resolved: #93012
IvanKobzarev changed the title from "[jit] jit.ignore funcs ignore non-jittable decl annotations" to "[jit] jit.ignore funcs ignore non-jit return type annotations" on Jan 26, 2023
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

IvanKobzarev added a commit that referenced this pull request Jan 26, 2023
ghstack-source-id: 47bf3f40fdb96f92f8603d1db90d8919f6b74349
Pull Request resolved: #93012
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

IvanKobzarev added a commit that referenced this pull request Jan 26, 2023
ghstack-source-id: 943c2d5aab69d43bef525bfbf872f9f6773a84e4
Pull Request resolved: #93012
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

IvanKobzarev added a commit that referenced this pull request Jan 26, 2023
ghstack-source-id: 0ad8d16474cb1d726c14269fa8ee7a78a3b5082c
Pull Request resolved: #93012
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98 left a comment

@IvanKobzarev had some questions here:

  • don't we need to keep return type annotations for jit.ignored functions? e.g. if I have

```
@torch.jit.ignore
def inner(x) -> List[int]:
    return make_network_call(x)

def outer(x):
    w = inner(x)
    return torch.tensor(w)
```

then I think I would need the annotations on inner to know that the return type is a list.

Also - In your test case here, I would say that you should use torch.jit.unused instead of torch.jit.ignore. Does jit.unused work here?

i.e. if we're returning a type that's not supported by jit, then we can never call it from a jit-scripted function; so it's safe to mark it as jit.unused.

Is this accurate or am I missing something?


IvanKobzarev commented Jan 27, 2023

@davidberard98

In the current state of the code we do not drop return types for either jit.ignore or jit.unused, and my test case fails for both.

First I tried dropping the return type for jit.unused and found that it breaks test/jit/test_class_type.py, as that test needs the return type to be jit-compiled:

```
======================================================================
ERROR: test_unused_method (jit.test_class_type.TestClassType)
Test unused methods on scripted classes.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/ivankobzarev/github/pytorch/test/jit/test_class_type.py", line 980, in test_unused_method
    class Unused(object):
  File "/home/ivankobzarev/github/pytorch/torch/jit/_script.py", line 1323, in script
    _compile_and_register_class(obj, _rcb, qualified_name)
  File "/home/ivankobzarev/github/pytorch/torch/jit/_recursive.py", line 51, in _compile_and_register_class
    script_class = torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb)
RuntimeError:
Return value was annotated as having type int but is actually of type NoneType:
  File "/home/ivankobzarev/github/pytorch/test/jit/test_class_type.py", line 995
        def uses_unused(self) -> int:
            return self.unused(y="hi", x=3)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
----------------------------------------------------------------------
```
If I drop the return type only for jit.ignore, all the tests pass.

To my understanding:

jit.unused allows jit-compiling callers of those methods (the call raises at runtime).

A model with calls to jit.ignore methods cannot be jit-compiled:
https://github.com/pytorch/pytorch/blob/master/torch/_jit_internal.py#L648
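
(For reference, the jit.unused contract is that call sites compile but raise if actually reached at runtime; the sketch below follows the usual pattern from the torch.jit.unused documentation and is not code from this PR.)

```
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self, use_memory_efficient: bool):
        super().__init__()
        self.use_memory_efficient = use_memory_efficient

    @torch.jit.unused
    def memory_efficient(self, x):
        # arbitrary non-scriptable body; never compiled
        import pdb
        pdb.set_trace()
        return x + 10

    def forward(self, x):
        if self.use_memory_efficient:
            return self.memory_efficient(x)  # compiles as a stub that raises
        else:
            return x + 10

m = torch.jit.script(MyModule(use_memory_efficient=False))
m(torch.rand(100))  # fine

m = torch.jit.script(MyModule(use_memory_efficient=True))
# m(torch.rand(100))  # would raise at runtime (the stub body is the Raise seen in the AST dump above)
```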

@davidberard98

hmm, interesting - I'm not sure why this is the case, but I still think that if we're going to make this change we should make it for jit.unused, not jit.ignore.

We do need return types for jit.ignore, because jit.ignore-d functions can be called from jit-ed functions. And they still exist in this PR, e.g. if you try to call an ignored function that returns an Iterable, you'll see a failure.

Meanwhile jit.unused calls made from jit-ed functions will always error, so I think it's reasonable to ignore return types from them. Do you think you could see if it's possible to do this instead?

cc @qihqi for opinions or if you have any additional context here

@davidberard98

also an additional question - what's the use case that requires using jit.unused / jit.ignore on a function returning a non-scriptable type? Typically, if you have a non-forward method that you don't plan on calling, you can just leave it unannotated and jit will not compile it unless it's marked as @torch.jit.export, or it's called by a function/method that is scripted.

IvanKobzarev added a commit that referenced this pull request Jan 30, 2023
ghstack-source-id: a07cf48dbb80f890f8de16cfbb060666da68e24f
Pull Request resolved: #93012
@davidberard98 left a comment

ok so... I think the problem is that jit.ignore actually isn't implemented for non-module classes. e.g. this fails, because it's trying to compile the jit.ignored function.

```
@torch.jit.script
class MyClass:
    def __init__(self):
        self.val = 5.0

    def forward(self, x):
        return x * self.val

    @torch.jit.ignore
    def unused(self, x) -> int:
        return next(g)

    def use_unused(self, x) -> int:
        return 4 * self.unused(x)
```

Your original solution (i.e. the original one I reviewed) would work, but I think behavior would be somewhat unexpected because @torch.jit.ignore would make the function invisible, instead of leaving it as an uncompiled python fn (and this behavior would be different between objects & modules). I think it's preferable (although still not optimal) to leave jit.ignore unimplemented in this case, instead of providing differing implementations for objects & modules.

I went down a rabbit hole of trying to figure out how to implement ignore properly for non-module objects (seems not too hard?) but I realize now that it's not really necessary for your use case :)

so... instead I think maybe the best option is to add some annotation that completely skips a class method. This is the default in modules, but for non-module classes torchscript will attempt to script all the methods.

... let me ask around first to make sure this is a reasonable option.
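
To illustrate the module vs. plain-class difference described above, here is a rough sketch (an assumed example, not code from this PR or its tests):

```
import torch
import torch.nn as nn

class M(nn.Module):
    # Not called from forward and not @torch.jit.export-ed, so module scripting
    # simply skips it, even though 'object' is not a scriptable return type.
    def helper(self) -> object:
        return object()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

torch.jit.script(M())  # fine: only forward is compiled

class C:
    def __init__(self) -> None:
        self.val = 1

    # For a plain class, torch.jit.script(C) tries to compile *every* method,
    # so this one needs @torch.jit._drop (added by this PR) to be skipped.
    def helper(self) -> object:
        return object()

# torch.jit.script(C)  # without _drop on helper, this fails to compile
```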

IvanKobzarev added a commit that referenced this pull request Jan 31, 2023
ghstack-source-id: 0c705bd43a6a09f2c823a2288dd00238f8b46e11
Pull Request resolved: #93012
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

IvanKobzarev changed the title from "[jit] jit.ignore funcs ignore non-jit return type annotations" to "[jit] jit._drop fun modifier to allow in jit class non-jit decl funs" on Jan 31, 2023
@davidberard98 left a comment

Mostly looks good! Left some comments - I think it would be good to specify that it's for non-module classes and then only attempt to implement it for that use case.

(approve to unblock)

```
@@ -762,12 +768,21 @@ def should_drop(fn) -> bool:
     attr = get_torchscript_modifier(fn)
     if attr is None:
         return False
-    return attr is FunctionModifiers.UNUSED
+    return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP
```

I also don't think this needs to be changed, but could be wrong

Comment on lines +285 to +288

```
if _is_drop_fn(fn):
    # Dropping potentially unsupported return type annotation for jit._drop
    fn_def.returns = None
    fn_def.type_comment = None
```

I don't think this is needed, because we won't call get_jit_def on a method if it's marked with _drop.

```
return (
    mod is FunctionModifiers.UNUSED
    or mod is FunctionModifiers.IGNORE
    or mod is FunctionModifiers._DROP
```

also here, I think it's safe to leave as it was before (although less confident of this one, check tests to be sure)

AFAICT, is_ignored_fn is mostly used in module scripting. _DROP is probably only going to be used for non-module classes (because modules already drop methods by default, except for the forward method. And supporting this on the forward method might be a bit extra work, because a lot of things assume that modules have a forward method). So, I suspect we don't need to add this.

```
@@ -740,6 +741,11 @@ def decorator(fn):
     return decorator


 def _drop(fn):
```

nit: could you add some extra docs here? e.g. why we need it (classes try to compile all methods by default, and unused/ignore don't work because they have return types), use case (non-module classes only)
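
One possible shape for such documentation, as a sketch only: the decorator body below mirrors the pattern of the other modifier helpers in torch/_jit_internal.py and is an assumption, not the exact code that landed.

```
from torch._jit_internal import FunctionModifiers

def _drop(fn):
    """Skip this method entirely when TorchScript compiles a non-module class.

    Plain classes try to compile every method by default, and unused/ignore
    still generate a stub whose signature is compiled, so they fail when the
    annotations reference non-scriptable types (e.g. torch.fx classes).
    _DROP leaves the method as plain Python and never inspects its signature.
    """
    # assumed implementation: tag the function with the _DROP modifier
    fn._torchscript_modifier = FunctionModifiers._DROP
    return fn
```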

```
kwargs={},
)

torch.jit.script(CFX)
```

can you add an export + import test here? example.

I did notice that you can still call __fx_create_arg__ on the scripted class. Might be good to document this (or disable it if you know how, but I'm not exactly sure right now what the best way to do this is)
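
A sketch of what such an export + import round trip could look like (the wrapper module here is a placeholder, not the code in test/jit/test_types.py; in the real test it would exercise the class that carries the _drop-ed method):

```
import io
import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

scripted = torch.jit.script(Wrapper())

buf = io.BytesIO()
torch.jit.save(scripted, buf)   # export
buf.seek(0)
reloaded = torch.jit.load(buf)  # import

assert torch.equal(reloaded(torch.zeros(3)), scripted(torch.zeros(3)))
```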

@facebook-github-bot

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

pytorch-bot added the `ciflow/trunk` label Feb 1, 2023
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


facebook-github-bot deleted the gh/ivankobzarev/119/head branch June 8, 2023 17:20