-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[jit] jit._drop fun modifier to allow in jit class non-jit decl funs #93012
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/93012
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 62038df: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 562c225193c41dd49519c41d2824d2bef03b01d7 Pull Request resolved: #93012
[ghstack-poisoned]
ghstack-source-id: 46e9c01f89d9aacc590255e7cab9c94547033da7 Pull Request resolved: #93012
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
`torch.jit.unused` is supposed to mark methods to be excluded from jit scripting. jit creates stub with the same number of arguments to throw an exception if client call `jit.unused` method. That stub will go through scripting and if it contains annotations/type decorations with non-scriptable classes - jit-scripting will fail compiling those classes. The return type is preserved to keep calls to unused method. The same happens for `jit.ignore`, but that does not allow to keep non-scriptable types as return types for jit.ignore methods. e.g. ``` torch.jit.ignore def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument: # torch.fx classes are not scriptable return tracer.create_node( "call_function", CFX, args=(tracer.create_arg(self.features),), kwargs={}, ) ``` ``` FunctionDef( name='__fx_create_arg__', args=arguments( posonlyargs=[], args=[ arg(arg='self', annotation=Name(id='Any', ctx=Load()), type_comment=None), arg(arg='tracer', annotation=Name(id='Any', ctx=Load()), type_comment=None) ], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] ), body=[Raise(exc=Call(func=Name(id='RuntimeError', ctx=Load()), args=[Constant(value='Cannot call unused methods', kind=None)], keywords=[]), cause None)], decorator_list=[ Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='jit', ctx=Load()), attr='unused', ctx=Load()) ], returns=Attribute(value=Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='fx', ctx=Load()), attr='node', ctx=Load()), attr='Argument', ctx=Load()), type_comment=None ) ``` Fix: - Completely jit.ignore methods for class definitions - Drop return types for jit.ignore where it is kept for jit.unused and jit.ignore Testing: Added test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before fix and passes after. ``` python test/test_jit.py ``` Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830) [ghstack-poisoned]
ghstack-source-id: b71b6127043480ec3fdb35fe5569580a4d284da5 Pull Request resolved: #93012
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…ons" `torch.jit.unused` is supposed to mark methods to be excluded from jit scripting. jit creates stub with the same number of arguments to throw an exception if client call `jit.unused` method. That stub will go through scripting and if it contains annotations/type decorations with non-scriptable classes - jit-scripting will fail compiling those classes. The return type is preserved to keep calls to unused method. The same happens for `jit.ignore`, but that does not allow to keep non-scriptable types as return types for jit.ignore methods. e.g. ``` torch.jit.ignore def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument: # torch.fx classes are not scriptable return tracer.create_node( "call_function", CFX, args=(tracer.create_arg(self.features),), kwargs={}, ) ``` ``` FunctionDef( name='__fx_create_arg__', args=arguments( posonlyargs=[], args=[ arg(arg='self', annotation=Name(id='Any', ctx=Load()), type_comment=None), arg(arg='tracer', annotation=Name(id='Any', ctx=Load()), type_comment=None) ], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] ), body=[Raise(exc=Call(func=Name(id='RuntimeError', ctx=Load()), args=[Constant(value='Cannot call unused methods', kind=None)], keywords=[]), cause None)], decorator_list=[ Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='jit', ctx=Load()), attr='unused', ctx=Load()) ], returns=Attribute(value=Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='fx', ctx=Load()), attr='node', ctx=Load()), attr='Argument', ctx=Load()), type_comment=None ) ``` Fix: - Completely jit.ignore methods for class definitions - Drop return types for jit.ignore where it is kept for jit.unused and jit.ignore Testing: Added test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before fix and passes after. ``` python test/test_jit.py ``` Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830) [ghstack-poisoned]
ghstack-source-id: 47bf3f40fdb96f92f8603d1db90d8919f6b74349 Pull Request resolved: #93012
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…ons" `torch.jit.unused` is supposed to mark methods to be excluded from jit scripting. jit creates stub with the same number of arguments to throw an exception if client call `jit.unused` method. That stub will go through scripting and if it contains annotations/type decorations with non-scriptable classes - jit-scripting will fail compiling those classes. The return type is preserved to keep calls to unused method. The same happens for `jit.ignore`, but that does not allow to keep non-scriptable types as return types for jit.ignore methods. e.g. ``` torch.jit.ignore def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument: # torch.fx classes are not scriptable return tracer.create_node( "call_function", CFX, args=(tracer.create_arg(self.features),), kwargs={}, ) ``` ``` FunctionDef( name='__fx_create_arg__', args=arguments( posonlyargs=[], args=[ arg(arg='self', annotation=Name(id='Any', ctx=Load()), type_comment=None), arg(arg='tracer', annotation=Name(id='Any', ctx=Load()), type_comment=None) ], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] ), body=[Raise(exc=Call(func=Name(id='RuntimeError', ctx=Load()), args=[Constant(value='Cannot call unused methods', kind=None)], keywords=[]), cause None)], decorator_list=[ Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='jit', ctx=Load()), attr='unused', ctx=Load()) ], returns=Attribute(value=Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='fx', ctx=Load()), attr='node', ctx=Load()), attr='Argument', ctx=Load()), type_comment=None ) ``` Fix: - Completely jit.ignore methods for class definitions - Drop return types for jit.ignore where it is kept for jit.unused and jit.ignore Testing: Added test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before fix and passes after. ``` python test/test_jit.py ``` Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830) [ghstack-poisoned]
ghstack-source-id: 943c2d5aab69d43bef525bfbf872f9f6773a84e4 Pull Request resolved: #93012
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…ons" `torch.jit.unused` is supposed to mark methods to be excluded from jit scripting. jit creates stub with the same number of arguments to throw an exception if client call `jit.unused` method. That stub will go through scripting and if it contains annotations/type decorations with non-scriptable classes - jit-scripting will fail compiling those classes. The return type is preserved to keep calls to unused method. The same happens for `jit.ignore`, but that does not allow to keep non-scriptable types as return types for jit.ignore methods. e.g. ``` torch.jit.ignore def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument: # torch.fx classes are not scriptable return tracer.create_node( "call_function", CFX, args=(tracer.create_arg(self.features),), kwargs={}, ) ``` ``` FunctionDef( name='__fx_create_arg__', args=arguments( posonlyargs=[], args=[ arg(arg='self', annotation=Name(id='Any', ctx=Load()), type_comment=None), arg(arg='tracer', annotation=Name(id='Any', ctx=Load()), type_comment=None) ], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] ), body=[Raise(exc=Call(func=Name(id='RuntimeError', ctx=Load()), args=[Constant(value='Cannot call unused methods', kind=None)], keywords=[]), cause None)], decorator_list=[ Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='jit', ctx=Load()), attr='unused', ctx=Load()) ], returns=Attribute(value=Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='fx', ctx=Load()), attr='node', ctx=Load()), attr='Argument', ctx=Load()), type_comment=None ) ``` Fix: - Completely jit.ignore methods for class definitions - Drop return types for jit.ignore where it is kept for jit.unused and jit.ignore Testing: Added test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before fix and passes after. ``` python test/test_jit.py ``` Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830) [ghstack-poisoned]
ghstack-source-id: 0ad8d16474cb1d726c14269fa8ee7a78a3b5082c Pull Request resolved: #93012
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@IvanKobzarev had some questions here:
- don't we need to keep return type annotations for jit.ignored functions? e.g. if I have
@torch.jit.ignore
def inner(x) -> List[int]:
return make_network_call(x)
def outer(x):
w = inner(x)
return torch.tensor(w)
then I think I would need the annotations on inner
to know that the return type is a list.
Also - In your test case here, I would say that you should use torch.jit.unused instead of torch.jit.ignore. Does jit.unused work here?
i.e. if we're returning a type that's not supported by jit, then we can never call it from a jit-scripted function; so it's safe to mark it as jit.unused.
Is this accurate or am I missing something?
In current state of code we do not drop return types for both First I've tried return type dropping for
To my understanding:
model with calls to |
hmm, interesting - I'm not sure why this is the case, but I still think that if we're going to make this change we should make it for jit.unused, not jit.ignore. We do need return types from jit.ignore, because jit.ignore-d can be called from jit-ed functions. And they still exist in this PR, e.g. if you try to call an ignored function that returns an Iterable, you'll see a failure. Meanwhile jit.unused calls made from jit-ed functions will always error, so I think it's reasonable to ignore return types from them. Do you think you could see if it's possible to do this instead? cc @qihqi for opinions or if you have any additional context here |
also an additional question - what's the use case that requires using jit.unused / jit.ignore on a function returning a non-scriptable type? typically if you have a non-forward method that you don't plan on calling, you can just leave it unannotated and jit will not compile it unless it's marked as @torch.jit.export, or if it's called by a function/method that is scripted. |
…ons" `torch.jit.unused` is supposed to mark methods to be excluded from jit scripting. jit creates stub with the same number of arguments to throw an exception if client call `jit.unused` method. That stub will go through scripting and if it contains annotations/type decorations with non-scriptable classes - jit-scripting will fail compiling those classes. The return type is preserved to keep calls to unused method. The same happens for `jit.ignore`, but that does not allow to keep non-scriptable types as return types for jit.ignore methods. e.g. ``` torch.jit.ignore def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument: # torch.fx classes are not scriptable return tracer.create_node( "call_function", CFX, args=(tracer.create_arg(self.features),), kwargs={}, ) ``` ``` FunctionDef( name='__fx_create_arg__', args=arguments( posonlyargs=[], args=[ arg(arg='self', annotation=Name(id='Any', ctx=Load()), type_comment=None), arg(arg='tracer', annotation=Name(id='Any', ctx=Load()), type_comment=None) ], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] ), body=[Raise(exc=Call(func=Name(id='RuntimeError', ctx=Load()), args=[Constant(value='Cannot call unused methods', kind=None)], keywords=[]), cause None)], decorator_list=[ Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='jit', ctx=Load()), attr='unused', ctx=Load()) ], returns=Attribute(value=Attribute(value=Attribute(value=Name(id='torch', ctx=Load()), attr='fx', ctx=Load()), attr='node', ctx=Load()), attr='Argument', ctx=Load()), type_comment=None ) ``` Fix: - Completely jit.ignore methods for class definitions - Drop return types for jit.ignore where it is kept for jit.unused and jit.ignore Testing: Added test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before fix and passes after. ``` python test/test_jit.py ``` Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830) [ghstack-poisoned]
ghstack-source-id: a07cf48dbb80f890f8de16cfbb060666da68e24f Pull Request resolved: #93012
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok so... I think the problem is that jit.ignore actually isn't implemented for non-module classes. e.g. this fails, because it's trying to compile the jit.ignored function.
@torch.jit.script
class MyClass:
def __init__(self):
self.val = 5.0
def forward(self, x):
return x * self.val
@torch.jit.ignore
def unused(self, x) -> int:
return next(g)
def use_unused(self, x) -> int:
return 4 * self.unused(x)
Your original solution (i.e. the original one I reviewed) would work, but I think behavior would be somewhat unexpected because @torch.jit.ignore
would make the function invisible, instead of leaving it as an uncompiled python fn (and this behavior would be different between objects & modules). I think it's preferable (although still not optimal) to leave jit.ignore unimplemented in this case, instead of providing differing implementations for objects & modules.
I went down a rabbit hole of trying to figure out how to implement ignore properly for non-module objects (seems not too hard?) but I realize now that it's not really necessary for your use case :)
so... instead I think maybe the best option is to add some annotation that completely skips a class method. This is the default in modules, but for non-module classes torchscript will attempt to script all the methods.
... let me ask around first to make sure this is a reasonable option.
…ons" `torch.jit.unused` and `torch.jit.ignore` do not allow to keep in class that needs to be torch scripted a member function, that has non scriptable declaration (e.g. return type) Adding FunctionModifier _DROP to allow fully skip those functions from scripting and keep them in the code of the scripted class. E.g. it can be used for: ``` torch.jit._drop def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument: # torch.fx classes are not scriptable return tracer.create_node( "call_function", CFX, args=(tracer.create_arg(self.features),), kwargs={}, ) def __iter__(self) -> Iterator[torch.Tensor]: return iter(self.a) ``` Testing: Added test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before fix and passes after. ``` python test/test_jit.py ``` Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830) [ghstack-poisoned]
ghstack-source-id: 0c705bd43a6a09f2c823a2288dd00238f8b46e11 Pull Request resolved: #93012
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly looks good! left some comments - I think it would be good to specify that it's for non-module classes and then only attempt to implement for that use case.
(approve to unblock)
@@ -762,12 +768,21 @@ def should_drop(fn) -> bool: | |||
attr = get_torchscript_modifier(fn) | |||
if attr is None: | |||
return False | |||
return attr is FunctionModifiers.UNUSED | |||
return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also don't think this needs to be changed, but could be wrong
if _is_drop_fn(fn): | ||
# Dropping potentially unsupported return type annotation for jit._drop | ||
fn_def.returns = None | ||
fn_def.type_comment = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is needed, because we won't call get_jit_def on a method if it's marked with _drop.
return ( | ||
mod is FunctionModifiers.UNUSED | ||
or mod is FunctionModifiers.IGNORE | ||
or mod is FunctionModifiers._DROP |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also here, I think it's safe to leave as it was before (although less confident of this one, check tests to be sure)
AFAICT, is_ignored_fn is mostly used in module scripting. _DROP is probably only going to be used for non-module classes (because modules already drop methods by default, except for the forward method. And supporting this on the forward method might be a bit extra work, because a lot of things assume that modules have a forward method). So, I suspect we don't need to add this.
@@ -740,6 +741,11 @@ def decorator(fn): | |||
return decorator | |||
|
|||
|
|||
def _drop(fn): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could you add some extra docs here? e.g. why we need it (classes try to compile all methods by default, and unused/ignore don't work because they have return types), use case (non-method classes only)
kwargs={}, | ||
) | ||
|
||
torch.jit.script(CFX) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add an export + import test here? example.
I did notice that you can still call __fx_create_arg__
on the scripted class. Might be good to document this (or disable it if you know how, but I'm not exactly sure right now what the best way to do this is)
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged) |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
@torch.jit.unused
and@torch.jit.ignore
do not allow to keep in torch scripted class member function, that has non scriptable declaration (e.g. return type)Adding FunctionModifier _DROP to allow fully skip those functions from scripting and keep them in the code of the scripted class.
E.g. it can be used for:
Testing:
Added test case in
test/jit/test_types.py
with non-scriptable type annotations (fx.* classes) that fails before fix and passes after.Stack from ghstack (oldest at bottom):
Differential Revision: D42774830