Torch cond operator, python dispatch, pyoperator #83154
I'm sorry @Chillee
The `no_dispatch` here shouldn't be necessary, as you have already dispelled the ambient proxy mode upon entry to this function.
Yeah, this shouldn't be necessary. You're probably getting hobbled by the `get_isolated_graphmodule` API. That API should have the real tensor output from when it traced through; you just need to get it to disgorge that information so you can use it directly.
Yeah, we can try that! Let me play with it a bit; I don't love this as it is.
This won't tell the difference between a tuple of three Tensors and a list of three Tensors, right?
The right way to do this might be to compare the specs after `tree_flatten`.
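The spec-comparison idea can be sketched with a minimal stand-in for pytree's `TreeSpec`. (In PyTorch the real mechanism is `torch.utils._pytree.tree_flatten`, which returns `(leaves, spec)`, and two specs compare unequal when the container types differ; the helper below is a simplified illustration, not the actual implementation.)

```python
def tree_spec(x):
    # Minimal stand-in for pytree's TreeSpec: record the container
    # types and nesting structure, ignoring leaf values. A tuple of
    # three Tensors and a list of three Tensors have the same leaves
    # but different specs.
    if isinstance(x, (list, tuple)):
        return (type(x).__name__, [tree_spec(c) for c in x])
    if isinstance(x, dict):
        return ("dict", sorted((k, tree_spec(v)) for k, v in x.items()))
    return "leaf"

# tuple vs list of the same leaves: different specs
assert tree_spec((1, 2, 3)) != tree_spec([1, 2, 3])
# same structure, different leaf values: same spec
assert tree_spec([1, (2, 3)]) == tree_spec([4, (5, 6)])
```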
Yeah, agreed.
As a minor improvement, perhaps you would like to support kwargs too?
I kind of didn't want kwargs in this API, but if pressed, I have no qualms about adding them.
If you want people to manually flatten before passing into cond, you should accept operands as a list of tensors, not as varargs. Vararg functions in torch.ops are not a thing.
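The signature shift being suggested can be sketched as follows (a hypothetical eager-mode stand-in to show the calling convention, not the actual `cond` implementation):

```python
def cond(pred, true_fn, false_fn, operands):
    # Operands arrive as one explicit list argument rather than as
    # varargs, since vararg schemas are not supported in torch.ops.
    assert isinstance(operands, (list, tuple)), (
        "cond expects operands as a list, e.g. "
        "cond(p, f, g, [x, y]) rather than cond(p, f, g, x, y)"
    )
    branch = true_fn if pred else false_fn
    return branch(*operands)
```

Callers would then flatten their pytree of inputs themselves and pass the flat list.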
The mode suspension mechanic here is not quite right. Ordinarily, we would call into TorchDispatchMode, which would take care of reapplying the inner mode before the inside of the mode function gets run.
Can you elaborate? I don't think I have the full mental model for modes right.
There is a mode stack. When you call into the handler for a mode, you pop that mode off the stack before you do it, so that internal calls in the handler go to the next mode in the stack.
Got it.
This registration is not quite right. To register to the Python key implies that the implementation works for any Python mode and for any tensor subclass with `__torch_dispatch__` on it. To actually make good on this promise, you must actually call the `__torch_dispatch__` method that was provided in this way. What this implementation does, instead, is call into a hard-coded implementation of the mode that is suitable for ProxyTensor. This is fine, but the Python fallback mustn't imply that it can be used in other situations, which is what it will blindly crash through here. So you should at least assert that the mode stack has exactly one mode in it and it is proxy tensor. Which, speaking of which, means that fake tensor mode is not going to work properly in that case...

The easiest fix may just be to write the generalized version of this logic. That just means recreating the Python fallback key in C++ in PythonFallbackKernel.cpp, but you get to pass functions in the args instead. Then, ProxyTensorMode would be responsible for seeing that cond is being called and handling it properly. You could put your trace_cond implementation directly in ProxyMode, or you could add a little registration system for custom trace handling and then register your trace_cond that way.
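The "little registration system" suggestion might look something like this sketch. All names here (`register_trace_handler`, `proxy_mode_dispatch`) are hypothetical; the point is only that the mode consults a table instead of hard-coding cond:

```python
# Hypothetical registry mapping operator names to trace handlers.
_trace_handlers = {}

def register_trace_handler(op_name):
    def register(fn):
        _trace_handlers[op_name] = fn
        return fn
    return register

def proxy_mode_dispatch(op_name, *args):
    # A ProxyMode-like mode would look up the handler for the operator
    # it sees, rather than baking cond-specific logic into itself.
    handler = _trace_handlers.get(op_name)
    if handler is None:
        raise NotImplementedError(f"no trace handler for {op_name}")
    return handler(*args)

@register_trace_handler("cond")
def trace_cond(pred, true_fn, false_fn, operands):
    # The real trace_cond would trace both branches into subgraphs;
    # picking one branch here just keeps the sketch runnable.
    return (true_fn if pred else false_fn)(*operands)
```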
I like the idea of a registration system. A much earlier version of this PR had this going through `__torch_dispatch__` on the mode, which in turn would call back to this in one version, and had the logic inside Proxy in another.

What does moving it to C++ give us? Just sharing more of the stack with how we dispatch today?
Sorry, I wasn't clear. I meant porting that current C++ logic into Python, the same way you did with the dispatcher.
Ah, yeah, we absolutely should.
A minor improvement that I quite like from the Python operator registration API (we also need to be careful about names, haha; maybe we'll call them Operator and PyOperator) is to do the registrations via decorator. This makes it immediately clear that the registration is for a particular dispatch key.
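Decorator-based registration along those lines might look like this sketch (the class shape and `py_impl` name are illustrative assumptions modeled on the PyOperator idea being discussed, not the code in this PR):

```python
class PyOperator:
    def __init__(self, name):
        self.name = name
        self.table = {}  # dispatch key -> implementation

    def py_impl(self, dispatch_key):
        # The decorator form makes it visually obvious, at the
        # definition site, which dispatch key the function serves.
        def register(fn):
            self.table[dispatch_key] = fn
            return fn
        return register

    def __call__(self, dispatch_key, *args):
        return self.table[dispatch_key](*args)

cond = PyOperator("cond")

@cond.py_impl("CPU")
def cond_dense(pred, true_fn, false_fn, operands):
    return (true_fn if pred else false_fn)(*operands)
```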
Sure, that's reasonable. I kind of prefer this kind of API over decorators, but am open to moving to decorators.
Maybe more of a question for @Chillee @eellison -- do we want to test that this works under the other flavors of make_fx (I'm mostly just thinking fake Tensor tracing; this is probably not going to work with symbolic yet) or do we punt that to the future?