-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[einsum] keep the promise that we contract left to right #87199
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87199
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 Failures, 6 PendingAs of commit 62069ec: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@pytorchbot merge -f "existing master failures" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Hey @janeyx99. |
We promise that if path is not defined, we would go left to right. The previous code did not keep that promise as we push'd combined ops to the back of the list. For most use cases this is fine (einsum with 3 or fewer inputs), but we should do what we say. Test plan: Added a print statement to print the sizes of ops we're contracting to see if the order is fixed. Code run: ``` import torch a = torch.rand(1) b = torch.rand(2) c = torch.rand(3) d = torch.rand(4) torch.einsum('a,b,c,d->abcd', a,b,c,d) ``` BEFORE--it does a+b, then c+d, then a+b+c+d, which...is right, but it's not the order specified by the user. ``` /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 1, 1]and b: [1, 2, 1, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 3, 1]and b: [1, 1, 1, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 1, 1]and b: [1, 1, 3, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] ``` WITH THIS CHANGE--it actually goes left to right: a+b, a+b+c, a+b+c+d ``` /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 1, 1]and b: [1, 2, 1, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 1, 1]and b: [1, 1, 3, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 3, 1]and b: [1, 1, 1, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] ``` Pull Request resolved: #87199 Approved by: https://github.com/soulitzer
…path is None (#87261) * [einsum] keep the promise that we contract left to right (#87199) We promise that if path is not defined, we would go left to right. The previous code did not keep that promise as we push'd combined ops to the back of the list. For most use cases this is fine (einsum with 3 or fewer inputs), but we should do what we say. Test plan: Added a print statement to print the sizes of ops we're contracting to see if the order is fixed. Code run: ``` import torch a = torch.rand(1) b = torch.rand(2) c = torch.rand(3) d = torch.rand(4) torch.einsum('a,b,c,d->abcd', a,b,c,d) ``` BEFORE--it does a+b, then c+d, then a+b+c+d, which...is right, but it's not the order specified by the user. ``` /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 1, 1]and b: [1, 2, 1, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 3, 1]and b: [1, 1, 1, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 1, 1]and b: [1, 1, 3, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] ``` WITH THIS CHANGE--it actually goes left to right: a+b, a+b+c, a+b+c+d ``` /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 1, 1]and b: [1, 2, 1, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 1, 1]and b: [1, 1, 3, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 3, 1]and b: [1, 1, 1, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] ``` Pull Request resolved: #87199 Approved by: https://github.com/soulitzer * [einsum] Call view instead of sum to remediate MPS regression (#87135) Fixes #87010. It turns out that squeeze is much faster than sum, and view is faster than squeeze, so we should default to that whenever possible. Benchmarking results show that, on MPS, we would be going from the following code taking **29.89ms instead of the current 1466ms, almost a 50x speedup**. ``` q = torch.rand(16, 4096, 40, device='mps', dtype=torch.float) k = torch.rand(16, 4096, 40, device='mps', dtype=torch.float) torch.einsum('b i d, b j d -> b i j', q, k).max().item() ``` And a regular einsum will now take **.506ms instead of 2.76ms.** ``` q = torch.rand(16, 4096, 40, device='mps', dtype=torch.float) k = torch.rand(16, 4096, 40, device='mps', dtype=torch.float) torch.einsum('b i d, b j d -> b i j', q, k) ``` Special thanks to @soulitzer for helping me experiment + figure out how to squash the remaining 5x regression due to squeeze being slower than view!! Pull Request resolved: #87135 Approved by: https://github.com/soulitzer, https://github.com/malfet, https://github.com/albanD
We promise that if path is not defined, we would go left to right. The previous code did not keep that promise as we push'd combined ops to the back of the list. For most use cases this is fine (einsum with 3 or fewer inputs), but we should do what we say.
Test plan:
Added a print statement to print the sizes of ops we're contracting to see if the order is fixed. Code run:
BEFORE--it does a+b, then c+d, then a+b+c+d, which...is right, but it's not the order specified by the user.
WITH THIS CHANGE--it actually goes left to right: a+b, a+b+c, a+b+c+d