[einsum] keep the promise that we contract left to right #87199

Closed

Conversation

@janeyx99 (Contributor) commented Oct 18, 2022

We promise that if path is not defined, we contract left to right. The previous code did not keep that promise, because it pushed each combined op to the back of the operand list. For most use cases (einsum with 3 or fewer inputs) this is fine, but we should do what we say.
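
For intuition, here is a minimal Python sketch (an illustration, not the actual Linear.cpp code) of why appending each combined op to the back of the list changes the contraction order, while folding it back into the front keeps it left to right:

```
ops = ['a', 'b', 'c', 'd']

# Old behavior: contract the first two operands, append the result to the back.
old = list(ops)
while len(old) > 1:
    lhs, rhs = old.pop(0), old.pop(0)
    old.append(f'({lhs}*{rhs})')
print(old[0])  # ((a*b)*(c*d)) -- a+b, then c+d, then the two results

# Fixed behavior: contract the first two operands, put the result back at the front.
new = list(ops)
while len(new) > 1:
    lhs, rhs = new.pop(0), new.pop(0)
    new.insert(0, f'({lhs}*{rhs})')
print(new[0])  # (((a*b)*c)*d) -- strictly left to right
```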

Test plan:
Added a temporary warning printing the sizes of the ops we're contracting, to see whether the order is fixed. Code run:

```
import torch
a = torch.rand(1)
b = torch.rand(2)
c = torch.rand(3)
d = torch.rand(4)
torch.einsum('a,b,c,d->abcd', a, b, c, d)
```

BEFORE: it contracts a+b, then c+d, then (a+b)+(c+d). The result is still correct, but it is not the order specified by the user.

```
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 1, 1]and b: [1, 2, 1, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 3, 1]and b: [1, 1, 1, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 1, 1]and b: [1, 1, 3, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
```

WITH THIS CHANGE: it actually goes left to right: a+b, then (a+b)+c, then (a+b+c)+d.

```
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 1, 1]and b: [1, 2, 1, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 1, 1]and b: [1, 1, 3, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 3, 1]and b: [1, 1, 1, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
```
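
The contraction order affects performance only, not the result. As an extra sanity check (a sketch added here, not part of the original test plan), einsum can be compared against an order-independent broadcasted outer product:

```
import torch

a, b, c, d = (torch.rand(n) for n in (1, 2, 3, 4))

out = torch.einsum('a,b,c,d->abcd', a, b, c, d)
# Broadcasting computes the same outer product without any pairwise contraction order.
ref = a.view(1, 1, 1, 1) * b.view(1, 2, 1, 1) * c.view(1, 1, 3, 1) * d.view(1, 1, 1, 4)
torch.testing.assert_close(out, ref)
```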

@pytorch-bot (bot) commented Oct 18, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87199

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure, 6 Pending

As of commit 62069ec:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@janeyx99 marked this pull request as ready for review October 18, 2022 20:27
@janeyx99 added the ciflow/trunk label (Trigger trunk jobs on your pull request) Oct 18, 2022
@soulitzer (Contributor) left a comment

LGTM

@janeyx99 (Contributor, Author)

@pytorchbot merge -f "existing master failures"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@github-actions (bot) commented

Hey @janeyx99.
You've committed this PR, but it does not have both a 'release notes: ...' and a 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch this PR changes (fx, autograd, distributed, etc.), and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc.).
For changes that are 'topic: not user facing' there is no need for a release notes label.

janeyx99 added a commit that referenced this pull request Oct 19, 2022
Pull Request resolved: #87199
Approved by: https://github.com/soulitzer
atalman pushed a commit that referenced this pull request Oct 19, 2022
…path is None (#87261)

* [einsum] keep the promise that we contract left to right (#87199)

Pull Request resolved: #87199
Approved by: https://github.com/soulitzer

* [einsum] Call view instead of sum to remediate MPS regression (#87135)

Fixes #87010.

It turns out that squeeze is much faster than sum, and view is faster than squeeze, so we should default to view whenever possible.
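
To illustrate (a sketch added here, not part of the original commit message): dropping a size-1 dimension can be done three ways with very different costs, and all three produce identical values:

```
import torch

t = torch.rand(16, 4096, 1, 40)

s_sum = t.sum(2)               # launches a reduction kernel even though dim 2 has size 1
s_squeeze = t.squeeze(2)       # metadata-only, but must handle the general case
s_view = t.view(16, 4096, 40)  # pure metadata change; cheapest when the size is known to be 1

assert torch.equal(s_sum, s_squeeze) and torch.equal(s_squeeze, s_view)
```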

Benchmarking results show that, on MPS, the following code now takes **29.89ms instead of the current 1466ms, almost a 50x speedup**.
```
q = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
k = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
torch.einsum('b i d, b j d -> b i j', q, k).max().item()
```
And a regular einsum will now take **0.506ms instead of 2.76ms**.
```
q = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
k = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
torch.einsum('b i d, b j d -> b i j', q, k)
```
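
For reference, one way to reproduce timings like the ones above (numbers will vary by machine; this assumes an MPS-capable build) is torch.utils.benchmark:

```
import torch
from torch.utils import benchmark

q = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
k = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)

# The trailing .max().item() forces a device sync so the GPU work is actually timed.
t = benchmark.Timer(
    stmt="torch.einsum('b i d, b j d -> b i j', q, k).max().item()",
    globals={'torch': torch, 'q': q, 'k': k},
)
print(t.timeit(50))
```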

Special thanks to @soulitzer for helping me experiment + figure out how to squash the remaining 5x regression due to squeeze being slower than view!!
Pull Request resolved: #87135
Approved by: https://github.com/soulitzer, https://github.com/malfet, https://github.com/albanD
@github-actions bot deleted the einsum-lr2 branch April 19, 2024 01:51