
[einsum] Call view instead of sum to remediate MPS regression #87135

Closed
wants to merge 3 commits

Conversation

@janeyx99 janeyx99 (Contributor) commented Oct 17, 2022

Fixes #87010.

It turns out that squeeze is much faster than sum, and view is faster than squeeze, so we should default to view whenever possible.

Benchmarking results show that, on MPS, the following code now takes 29.89ms instead of the current 1466ms, almost a 50x speedup:

```
q = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
k = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
torch.einsum('b i d, b j d -> b i j', q, k).max().item()
```

And a regular einsum will now take 0.506ms instead of 2.76ms:

```
q = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
k = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
torch.einsum('b i d, b j d -> b i j', q, k)
```

Special thanks to @soulitzer for helping me experiment + figure out how to squash the remaining 5x regression due to squeeze being slower than view!!
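For readers curious why the swap is safe: by the time einsum finishes its pairwise contractions, every dimension that was summed out is left with size 1, so dropping it with `view` gives identical results to `sum` without dispatching a reduction kernel. A minimal sketch (my illustration, not part of the diff):

```
import torch

x = torch.rand(16, 4096, 4096, 1)  # trailing size-1 dim, as left behind after contraction

summed = x.sum(dim=-1)           # launches a reduction kernel
viewed = x.view(16, 4096, 4096)  # only reinterprets metadata, no kernel launched

assert torch.equal(summed, viewed)
```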

@pytorch-bot pytorch-bot bot commented Oct 17, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87135

Note: Links to docs will display an error until the docs builds have been completed.

❗ 2 Active SEVs

There are 2 currently active SEVs. If your PR is affected, please view them below:

❌ 2 Failures

As of commit 5292c2c:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla linux-foundation-easycla bot commented Oct 17, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: janeyx99 / name: Jane (Yuan) Xu (a5f333d)

@janeyx99 (Contributor Author)

/easycla

2 similar comments
@janeyx99 (Contributor Author)

/easycla

@janeyx99 (Contributor Author)

/easycla

@janeyx99 (Contributor Author)

FYI @Birch-san

@Birch-san

brilliant!! thanks @janeyx99 for pursuing this.
if this is further patched with a mitigation for the change to path kwargs: would that bring us back to 1.12.1 perf, or would it be faster? I assume 1.12.1 was made without knowledge of this squeeze optimization?

@janeyx99 (Contributor Author)

@Birch-san Unfortunately this does not get us back to 1.12.1 perf just yet (see disclaimer); there is still a ~4-5x regression, since this code was introduced to allow for flexible ordering of contractions.
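(A hedged aside, not from the thread: the flexible contraction ordering referred to here is presumably the opt_einsum integration shipping in 1.13, which is exposed through `torch.backends.opt_einsum`. A sketch of those knobs, assuming the opt_einsum package is installed:)

```
import torch

print(torch.backends.opt_einsum.is_available())  # whether opt_einsum can be used
print(torch.backends.opt_einsum.enabled)         # True by default
torch.backends.opt_einsum.strategy = "greedy"    # contraction-path strategy: "auto", "greedy", or "optimal"
```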

@Birch-san

yes, understood 🙂 and good progress nonetheless.
okay, so whatever happens from here has to work within the constraints of keeping support for flexible ordering? so a code path to use the old algorithm isn't possible, but something else might be possible?

@janeyx99 (Contributor Author)

I am attempting to implement such a code path, but doing it elegantly is the challenge 😛

@Birch-san

🙇

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 18, 2022
@janeyx99 (Contributor Author)

@pytorchbot merge -r

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased replace-sum-with-squeeze onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout replace-sum-with-squeeze && git pull --rebase)

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 9 additional jobs have failed; the first few are: trunk, trunk / android-emulator-build-test / build-and-test, trunk / ios-12-5-1-x86-64 / build, trunk / macos-12-py3-x86-64 / build, trunk / macos-12-py3-x86-64-lite-interpreter / build

Details for Dev Infra team. Raised by workflow job.

@janeyx99 (Contributor Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

@janeyx99 janeyx99 changed the title [einsum] Call squeeze instead of sum to remediate MPS regression [einsum] Call view instead of sum to remediate MPS regression Oct 18, 2022
@janeyx99 (Contributor Author)

@Birch-san this PR should now get us back to 1.12.1 perf! Feel free to verify! Once this lands, I will be trying to get it into our release candidate.

@malfet malfet added this to the 1.13.0 milestone Oct 18, 2022
@malfet malfet (Contributor) commented Oct 18, 2022

squeeze is indeed much faster than sum, as it is essentially a no-op (it just creates a view of the tensor), but it is not equivalent to sum for any dimension whose size is not 1
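A quick illustration of that caveat (mine, not from the thread):

```
import torch

x = torch.rand(3, 1, 5)
# When the reduced dim has size 1, sum and squeeze agree:
assert torch.equal(x.sum(dim=1), x.squeeze(1))

y = torch.rand(3, 2, 5)
# For any other size, squeeze is a no-op while sum actually reduces:
print(y.sum(dim=1).shape)    # torch.Size([3, 5])
print(y.squeeze(1).shape)    # torch.Size([3, 2, 5])
```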

@Birch-san

outstanding! thanks so much.

by 1.12.1 perf, which side of #85297 (comment) are we talking? a .clone() was introduced post-1.12.1 for correctness in some situations, which regressed einsum perf by 54% (overall 5~6% slowdown in image generation). presumably that's still there?

either way, this certainly sounds like it fixes the path kwargs regression.

and regarding the release candidate: any idea whether the einsum correctness fix for #85224 will make it in? from @pcuenca's testing it sounds like it's not included? huggingface/diffusers#372 (comment)

From aten/src/ATen/native/Linear.cpp:

```
std::vector<int64_t> sum_dims(perm_index - out_num_dim);
std::iota(sum_dims.begin(), sum_dims.end(), out_num_dim);
ops[0] = ops[0].sum(sum_dims);
if (num_ops > 1) {
```
Collaborator

This `if` is for the special case where there is a single input: all the code above was a no-op, so we need to actually do the reduction here.
For all the other cases (num_ops > 1), we are guaranteed from the code above that all reduced dimensions are now of size 1 and can thus just be viewed the way we want?
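(A rough illustration of the two cases, mine rather than the reviewer's:)

```
import torch

x = torch.rand(5, 7)
y = torch.rand(3, 7)

# Single operand: nothing earlier has contracted anything, so the reduction
# over j really has to be a sum here.
assert torch.allclose(torch.einsum('ij->i', x), x.sum(dim=1))

# More than one operand: j is already reduced inside the matmul, so the only
# cleanup left is dropping size-1 dims, which a view can do.
assert torch.allclose(torch.einsum('ij,kj->ik', x, y), x @ y.T)
```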

Contributor Author

Precisely

Collaborator
A comment would be great for future readers.

Contributor Author
oh shoot i just saw this, will add in another pr

@janeyx99 (Contributor Author)

@Birch-san I wasn't aware of the other perf regression, but it would be interesting to do that benchmark.

With regards to the correctness, I believe #85689 is already in https://hud.pytorch.org/hud/pytorch/pytorch/release%2F1.13/8?per_page=50

@albanD albanD (Collaborator) left a comment

SGTM

@janeyx99 (Contributor Author)

Ah, the test_proxy_tensor failure is a real one. @albanD, might you have a workaround idea?

```
======================================================================
ERROR: test_make_fx_symbolic_exhaustive_einsum_cpu_float32 (__main__.TestProxyTensorOpInfoCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/janeyx/pytorch/torch/testing/_internal/common_device_type.py", line 391, in instantiated_test
    raise rte
  File "/Users/janeyx/pytorch/torch/testing/_internal/common_device_type.py", line 378, in instantiated_test
    result = test(self, **param_kwargs)
  File "/Users/janeyx/pytorch/torch/testing/_internal/common_device_type.py", line 824, in test_wrapper
    return test(*args, **kwargs)
  File "/Users/janeyx/pytorch/test/test_proxy_tensor.py", line 1361, in test_make_fx_symbolic_exhaustive
    _test_make_fx_helper(self, device, dtype, op, "symbolic")
  File "/Users/janeyx/pytorch/test/test_proxy_tensor.py", line 1332, in _test_make_fx_helper
    new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs)
  File "/Users/janeyx/pytorch/torch/fx/experimental/proxy_tensor.py", line 663, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/Users/janeyx/pytorch/torch/fx/experimental/proxy_tensor.py", line 413, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/Users/janeyx/pytorch/torch/fx/_symbolic_trace.py", line 739, in trace
    (self.create_arg(fn(*args)),),
  File "/Users/janeyx/pytorch/torch/fx/_symbolic_trace.py", line 614, in flatten_fn
    tree_out = root_fn(*tree_args)
  File "/Users/janeyx/pytorch/torch/fx/experimental/proxy_tensor.py", line 427, in wrapped
    out = f(*tensors)
  File "/Users/janeyx/pytorch/test/test_proxy_tensor.py", line 1322, in f
    return op.op(*args, **kwargs)
  File "/Users/janeyx/pytorch/torch/testing/_internal/common_methods_invocations.py", line 13219, in <lambda>
    op=lambda tensors, equation: torch.einsum(equation, tensors),
  File "/Users/janeyx/pytorch/torch/functional.py", line 373, in einsum
    return einsum(equation, *_operands)
  File "/Users/janeyx/pytorch/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides
```
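(A minimal repro sketch of what this test exercises, adapted by me from the traceback above; the toy shapes and function are my own, not from the test suite. Under symbolic tracing, shapes flow through the graph as SymInts, and any code path that asks for concrete sizes() raises the RuntimeError shown.)

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(q, k):
    return torch.einsum('b i d, b j d -> b i j', q, k)

# tracing_mode="symbolic" is the mode used by test_make_fx_symbolic_exhaustive_einsum
traced = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3, 4), torch.rand(2, 5, 4))
print(traced.graph)
```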

@albanD albanD (Collaborator) left a comment

Comments for the symint support.

Review comments on aten/src/ATen/native/Linear.cpp (resolved; some outdated).
@pcuenca pcuenca commented Oct 18, 2022

This looks absolutely amazing @janeyx99, thanks a lot for the effort! Regarding non-determinism, I'll run some tests to try to understand the scope better.

@janeyx99 (Contributor Author)

@pytorchbot merge -f "preexisting failures"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions

Hey @janeyx99.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

janeyx99 added a commit that referenced this pull request Oct 19, 2022
Fixes #87010.

Pull Request resolved: #87135
Approved by: https://github.com/soulitzer, https://github.com/malfet, https://github.com/albanD
atalman pushed a commit that referenced this pull request Oct 19, 2022
…path is None (#87261)

* [einsum] keep the promise that we contract left to right (#87199)

We promise that if path is not defined, we will go left to right. The previous code did not keep that promise, as we pushed combined ops to the back of the list. For most use cases this is fine (einsum with 3 or fewer inputs), but we should do what we say.

Test plan:
Added a print statement to print the sizes of ops we're contracting to see if the order is fixed. Code run:
```
import torch
a = torch.rand(1)
b = torch.rand(2)
c = torch.rand(3)
d = torch.rand(4)
torch.einsum('a,b,c,d->abcd', a,b,c,d)
```

BEFORE: it does a+b, then c+d, then (a+b)+(c+d), which is right, but it's not the order specified by the user.
```
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 1, 1]and b: [1, 2, 1, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 3, 1]and b: [1, 1, 1, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 1, 1]and b: [1, 1, 3, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
```

WITH THIS CHANGE: it actually goes left to right: a+b, then a+b+c, then a+b+c+d
```
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 1, 1]and b: [1, 2, 1, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 1, 1]and b: [1, 1, 3, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
/Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 3, 1]and b: [1, 1, 1, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
```
Pull Request resolved: #87199
Approved by: https://github.com/soulitzer

* [einsum] Call view instead of sum to remediate MPS regression (#87135)

Fixes #87010.

Pull Request resolved: #87135
Approved by: https://github.com/soulitzer, https://github.com/malfet, https://github.com/albanD
pytorchmergebot pushed a commit that referenced this pull request Oct 24, 2022
Tiny followup from #87135 (comment)

and another typo i noticed while doing the autograd lab
Pull Request resolved: #87264
Approved by: https://github.com/soulitzer
sgrigory pushed a commit to sgrigory/pytorch that referenced this pull request Oct 28, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
@github-actions github-actions bot deleted the replace-sum-with-squeeze branch April 19, 2024 01:51
Labels
ciflow/trunk (Trigger trunk jobs on your pull request), Merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MPS] einsum 42x slower since 1.13.0.dev20220925 — on (16, 4096, 40)*(16, 40, 4096) matmul
7 participants