Fixes for PyTorch/XLA functionalization integration #94537
Conversation
```diff
         or isinstance(t, FakeTensor)
     ):
-        if any(
+        if t.device.type != "xla" and any(
```
Do you know why this was needed? (Or what codepath in the functionalization <> XLA integration hit this code?)
Yep, AFAIK: in the dynamo codepath, PyTorch converts the tensor into a fake tensor and then runs the ops on the fake tensor. That process involves making a meta tensor and then turning it into the fake tensor subclass. Now that XLA tensors are functional tensors, the `torch._is_functional_tensor(t)` check here returns true, which results in hitting the `return NotImplemented` code. So we added this device check to bypass it.
More details at pytorch/xla#4414 (comment).
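For reference, a minimal sketch of the resulting guard logic (the helper name below is hypothetical; the real check is part of `MetaConverter` in `torch/_subclasses/meta_utils.py` and covers several other conditions besides the functional-tensor one):

```python
import torch

def should_bail_to_not_implemented(t: torch.Tensor) -> bool:
    # Sketch only: before this PR any functional tensor made the meta
    # converter return NotImplemented. After the fix, XLA tensors are
    # exempted, because with functionalization enabled the functional
    # wrapper is the normal representation of an XLA tensor.
    return t.device.type != "xla" and torch._is_functional_tensor(t)
```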
So, a while ago we talked about the functionalization code paths for dynamo vs. lazy tensor. My take here is that it would be preferable for pytorch/XLA to go through all of this functionalization infra only on the normal, non-dynamo code paths, and for the dynamo integration to do what it did before this PR (not bother wrapping tensors into FunctionalWrappers, etc.).
The main reason I'm leaning this way is because when you're using the torch.compile() / dynamo API, functionalization for XLA is completely redundant - our infra will send a graph to the backend to compile, and you're guaranteed that the graph is already functionalized.
If you end up keeping all of this functionalization logic in both code paths, one thing you might (?) run into is that there will be two levels of functionalization happening. This isn't really something that's tested / supported today. So my take here is that:
(1) If this one check here is enough to get everything working smoothly for dynamo, then this carve-out seems fine to me.
(2) If it's not, and you start hitting other weird functionalization failures, we might want to think about turning XLA functionalization off in the dynamo codepath.
Thanks for the pointers. So far, this one check was enough to get everything working smoothly with dynamo, until recently we hit a new dynamo regression in one of our unit tests (pytorch/xla#4680). I haven't had a chance to debug it in depth, but aside from that we don't seem to be hitting any other weird functionalization failures.
@JackCaoG, what do you think about this? And if we were to take PyTorch/XLA's dynamo integration off the functionalization codepath, would that happen in our dynamo bridge (https://github.com/pytorch/xla/blob/master/torch_xla/core/dynamo_bridge.py#L39), manually updating each tensor to a non-functional tensor?
My take is: it looks like we can keep this approach for now, and then investigate skipping functionalization for dynamo later on.
We can try to disable the functionalization in https://github.com/pytorch/xla/blob/master/torch_xla/core/dynamo_bridge.py#L224. That might be the right thing to do (if there is an easy flag to turn it on and off).
Probably not, otherwise we could just land functionalization behind the flag!
Well, then there is no way for us to do that: the dynamo path's tracing goes through the same lazy tensor tracing logic.
Well, we can add a runtime flag... There are just a few places where we wrap the XLATensor into the FunctionalWrapper. I have been thinking about this for a while, in case we have to land the feature to unblock people...
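A minimal sketch of what such a runtime flag might look like on the Python side of the dynamo bridge (the flag name and helper below are hypothetical, not existing torch_xla APIs; the real switch would more likely live in the C++ code that does the wrapping):

```python
import os

import torch

# Hypothetical opt-out flag, read once at import time.
XLA_DISABLE_FUNCTIONALIZATION = (
    os.environ.get("XLA_DISABLE_FUNCTIONALIZATION", "0") == "1"
)

def maybe_unwrap_for_dynamo(t: torch.Tensor) -> torch.Tensor:
    # In the dynamo path the traced graph is already functional, so the
    # FunctionalTensorWrapper is redundant there and could be peeled off.
    if XLA_DISABLE_FUNCTIONALIZATION and torch._is_functional_tensor(t):
        return torch._from_functional_tensor(t)
    return t
```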
bfc0a1b
to
36d4eb6
Compare
73d8e2f
to
bc57d58
Compare
I'll mark this one ready for review now since all CIs are green.
Summary: This pull request adds an op called `_propagate_xla_data` to help propagate information between the `updated_tensor` created during the in-place-op transform and the original input, so that pytorch/xla can keep its in-place op optimization after adopting functionalization. Test Plan: In XLA: PJRT_DEVICE=CPU python test/test_input_output_aliases.py -v -k test_non_view
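For context, a hedged sketch of the kind of check that test exercises (the real test lives in pytorch/xla's test/test_input_output_aliases.py and additionally inspects XLA aliasing metrics; this only shows the general shape):

```python
import torch
import torch_xla.core.xla_model as xm

def check_inplace_update_survives_functionalization():
    device = xm.xla_device()
    t = torch.ones(2, 2, device=device)
    t.add_(1)       # in-place op; functionalization rewrites it under the hood
    xm.mark_step()  # materialize the graph; input/output aliasing should kick in
    assert torch.allclose(t.cpu(), torch.full((2, 2), 2.0))
```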
This reverts commit a470744.
Thanks! :)
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours).
Thanks, Brian.
Fixes for PyTorch/XLA functionalization integration

Some notable changes include:
- More asserts in `FunctionalTensorWrapper`, so bugs show up more cleanly in cases where we e.g. forget to wrap an output
- Make the *_scatter ops `CompositeExplicitAutogradNonFunctional`, so we get a better error message and XLA doesn't accidentally try to use them
- Fix LTC/XLA codegen in core to handle multi-tensor out= ops with no returns
- Better erroring: allow XLA to use the CPU fallback from core in a way that always errors on view ops, which XLA should no longer see
- Update MetaConverter to exclude XLA tensors when raising NotImplemented
- Add `_propagate_xla_data` op
- Add meta tensor support for some ops

Pull Request resolved: pytorch/pytorch#94537
Approved by: https://github.com/bdhirsh
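For readers who want to see what the functionalized form looks like, here is a small self-contained sketch using the public torch.func.functionalize API on plain CPU tensors (illustrative only, not the XLA integration itself):

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x[0]   # a view of the first row
    y.add_(1)  # in-place mutation through the view
    return x

# Tracing the functionalized program removes the view/mutation pair; the
# base tensor is rebuilt with select_scatter, one of the *_scatter ops
# mentioned above, and the input mutation is written back with copy_.
gm = make_fx(functionalize(f, remove="mutations_and_views"))(torch.randn(2, 2))
print(gm.code)
```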