
Conversation

wonjoo-wj
Collaborator

@wonjoo-wj wonjoo-wj commented Feb 9, 2023

Fixes for PyTorch/XLA functionalization integration


Some notable changes include:

  • More asserts in FunctionalTensorWrapper, so bugs show up more cleanly in cases where we e.g. forget to wrap an output
  • Make the *_scatter ops CompositeExplicitAutogradNonFunctional, so we get a better error message and XLA doesn't accidentally try to use them (see the short functionalization sketch after this list)
  • Fix LTC/XLA codegen in core to handle multi-tensor out= ops with no returns
  • Better erroring: allow XLA to use the CPU fallback from core in a way that always errors on view ops, which XLA should no longer see.
  • Update MetaConverter to exclude XLA tensors in raising NotImplemented…
  • Add _propagate_xla_data op
  • Add meta tensor support for some ops
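As a quick illustration of the behavior described above (not code from this PR), the public `torch.func.functionalize` transform shows how an in-place op through a view gets rewritten into out-of-place ops plus a `*_scatter` call, which is why a backend like XLA should no longer see raw view or mutation ops:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize

def f(x):
    tmp = x.clone()
    row = tmp[0]   # a view into an intermediate
    row.add_(1)    # in-place mutation through the view
    return tmp

# Tracing the functionalized version shows clone / select_copy / add /
# select_scatter style ops, with no views or mutations left in the graph.
gm = make_fx(functionalize(f, remove="mutations_and_views"))(torch.zeros(2, 2))
print(gm.code)
```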

@pytorch-bot

pytorch-bot bot commented Feb 9, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94537

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 18b6dca:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

(Review thread on the MetaConverter diff, where the check was changed from `if any(` to `if t.device.type != "xla" and any(`.)
Contributor

Do you know why this was needed? (Or what codepath in the functionalization <> XLA integration hit this code?)

Collaborator Author

Yep, AFAIK, in the dynamo codepath PyTorch converts the tensor into a fake tensor and then runs the ops on that fake tensor. That process involves making a meta tensor and then turning it into the fake-tensor subclass.

And now that XLA tensors are functional tensors, the `torch._is_functional_tensor(t)` check here returns true, which results in hitting the `return NotImplemented` path. So we added this device check to bypass it.

More details at pytorch/xla#4414 (comment).
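For readers following along, here is a minimal sketch of the guard being discussed (a hypothetical helper, not the verbatim MetaConverter source; the other predicates inside the `any(...)` are elided):

```python
import torch

def _should_bail_out_of_meta_conversion(t: torch.Tensor) -> bool:
    # XLA tensors are functional-tensor wrappers after this integration, so the
    # device check exempts them from the functional-tensor predicate that would
    # otherwise make MetaConverter return NotImplemented during fake-tensor
    # conversion.
    return t.device.type != "xla" and any(
        [
            torch._is_functional_tensor(t),
            # ...other unsupported-tensor predicates elided in this sketch...
        ]
    )
```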

Contributor

So, a while ago we talked about the functionalization code paths for dynamo vs. lazy tensor. I think my take here is that it would be preferable for pytorch/XLA to go through all of this functionalization infra only when using the normal, non-dynamo code paths, and in the dynamo integration for XLA to do what it used to do before this PR (not bother wrapping tensors into FunctionalWrappers, etc.).

The main reason I'm leaning this way is that when you're using the torch.compile() / dynamo API, functionalization for XLA is completely redundant: our infra will send a graph to the backend to compile, and you're guaranteed that the graph is already functionalized.

If you end up keeping all of this functionalization logic in both code paths, one thing you might (?) run into is that there will be two levels of functionalization happening. This isn't really something that's tested / supported today. So my take here is that:

(1) If this one check here is enough to get everything working smoothly for dynamo, then this carve-out seems fine to me.

(2) If it's not, and you start hitting other weird functionalization failures, we might want to think about turning XLA functionalization off in the dynamo codepath.
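A small sketch of that point (not pytorch/xla code; the inspect_compiler / inspect_backend names are made up, while aot_autograd and make_boxed_func are the existing wrappers in core): when a custom backend is plugged in through AOTAutograd, the graph it receives has already been functionalized, so any backend-side functionalization would be a second, redundant pass.

```python
import torch
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

def inspect_compiler(gm: torch.fx.GraphModule, example_inputs):
    print(gm.code)  # ATen-level graph; the in-place add below shows up as a functional add
    return make_boxed_func(gm.forward)

inspect_backend = aot_autograd(fw_compiler=inspect_compiler)

@torch.compile(backend=inspect_backend)
def f(x):
    y = x.clone()
    y.add_(1)       # mutation is functionalized away before the compiler sees it
    return y * 2

f(torch.ones(3))
```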

Collaborator Author

Thanks for the pointers. So far, this one check was enough to get everything working smoothly with dynamo, until we recently got a new dynamo regression in one of our unit tests (pytorch/xla#4680). I haven't had a chance to debug it in depth yet, but aside from that, it seems like we're not hitting any other weird functionalization failures.

@JackCaoG, what do you think about this? And if we were to take PyTorch/XLA's dynamo path off the functionalization codepath, would that happen in our dynamo bridge (https://github.com/pytorch/xla/blob/master/torch_xla/core/dynamo_bridge.py#L39), by manually updating each tensor to a non-functional tensor?
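For context, a purely hypothetical sketch of what "manually update each tensor to a non-functional tensor" could look like, using private helpers that already exist in core (this is not code from dynamo_bridge.py, and the helper name is made up):

```python
import torch

def unwrap_if_functional(t: torch.Tensor) -> torch.Tensor:
    if torch._is_functional_tensor(t):
        torch._sync(t)                            # flush any pending view/mutation updates
        return torch._from_functional_tensor(t)   # unwrap back to the inner (XLA) tensor
    return t
```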

Collaborator

My take is: it looks like we can keep this approach for now, and then investigate skipping functionalization for dynamo later on.

Collaborator

We can try to disable the functionalization in https://github.com/pytorch/xla/blob/master/torch_xla/core/dynamo_bridge.py#L224. That might be the right thing to do (if there is an easy flag to turn it on and off).

Collaborator

Probably not, otherwise we could just land functionalization with the flag!

Collaborator

Well, then there is no way for us to do that. The dynamo path's tracing will go through the same lazy tracing logic.

Collaborator

Well, we could add a runtime flag... There are just a few places where we wrap the XLATensor into the FunctionalWrapper. I have been thinking about this for a while, in case we have to land the feature to unblock people...
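A purely hypothetical sketch of that runtime-flag idea (neither the flag name nor the helper exists as written here; it only illustrates gating the wrapping so the dynamo path could opt out of functionalization):

```python
import os
import torch

# Hypothetical flag name, shown only to illustrate the proposal above.
_DISABLE_FUNCTIONALIZATION = os.environ.get("XLA_DISABLE_FUNCTIONALIZATION", "0") == "1"

def maybe_wrap_functional(t: torch.Tensor) -> torch.Tensor:
    if _DISABLE_FUNCTIONALIZATION or torch._is_functional_tensor(t):
        return t
    return torch._to_functional_tensor(t)
```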

@alanwaketan alanwaketan force-pushed the functionalization branch 2 times, most recently from bfc0a1b to 36d4eb6 Compare February 21, 2023 21:13
@wonjoo-wj
Collaborator Author

I'll mark this one ready for review now since all CIs are green.

@wonjoo-wj wonjoo-wj marked this pull request as ready for review February 24, 2023 10:03
@wonjoo-wj wonjoo-wj force-pushed the functionalization branch 2 times, most recently from e152beb to 54a76e3 Compare February 25, 2023 00:57
@ezyang ezyang added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 27, 2023
@alanwaketan alanwaketan requested a review from a team as a code owner February 27, 2023 20:31
alanwaketan and others added 14 commits March 1, 2023 23:39
Summary:
This pull request adds an op called _propagate_xla_data to help propagate
information between the updated_tensor created during the in-place-ops
transform and the original input, so that pytorch/xla can keep its in-place-op
optimization after adopting functionalization.

Test Plan:
In XLA: PJRT_DEVICE=CPU python test/test_input_output_aliases.py -v -k test_non_view
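A hedged illustration of how the new op is meant to be used (the exact schema lives in native_functions.yaml; the call shape below is approximate and the wrapper function is made up):

```python
import torch

def propagate(original_input: torch.Tensor, updated_tensor: torch.Tensor) -> None:
    # For most backends this is effectively a no-op; for XLA it links the
    # functionalization-created updated tensor back to the original input so
    # the in-place-op aliasing optimization survives functionalization.
    torch.ops.aten._propagate_xla_data(original_input, updated_tensor)
```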
Contributor

@bdhirsh bdhirsh left a comment

Thanks! :)

@bdhirsh
Contributor

bdhirsh commented Mar 2, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 2, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@alanwaketan
Collaborator

Thanks, Brian.

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
Fixes for PyTorch/XLA functionalization integration
Pull Request resolved: pytorch/pytorch#94537
Approved by: https://github.com/bdhirsh

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
Fixes for PyTorch/XLA functionalization integration
Pull Request resolved: pytorch/pytorch#94537
Approved by: https://github.com/bdhirsh

ydwu4 added a commit that referenced this pull request Mar 11, 2023
Fixes for PyTorch/XLA functionalization integration
Approved by: https://github.com/bdhirsh

ydwu4 added a commit to ydwu4/pytorch that referenced this pull request Mar 13, 2023
Fixes for PyTorch/XLA functionalization integration
Approved by: https://github.com/bdhirsh
@github-actions github-actions bot deleted the functionalization branch August 31, 2024 02:00