Fix issue of baddbmm when out has NaN value for beta=0 #96086
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/96086
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit df0aa8e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
-          r *= beta;
+          r = beta == scalar_t(0) ? scalar_t(0) : beta * r;
           for (const auto k : c10::irange(ks)) {
```
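For context, here is a minimal Python sketch of the behavior this one-line change addresses (shapes and values are illustrative, not taken from the original report in #96037):

```python
import torch

# With beta=0 the `input` tensor must be ignored entirely, but the old
# kernel computed `r *= beta`, and 0 * NaN = NaN, so NaN values in
# `input` leaked into the result on this non-MKL CPU path.
inp = torch.full((2, 1, 3), float("nan"))  # stands in for an uninitialized `out`
b1 = torch.randn(2, 1, 4)
b2 = torch.randn(2, 4, 3)

out = torch.baddbmm(inp, b1, b2, beta=0.0)
print(torch.isnan(out).any())  # expected after the fix: tensor(False)
```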
Oh, wow, I did not know that `baddbmm` has an in-house implementation...
I was sure matmuls were delegated to MKL or something of the sort.
This bug seems to have been there for the past 5 years; I'm amazed there haven't been reports of random failures in deployed CPU inference systems.
haha, maybe that's why :) everyone who cares about CPU inference made sure to use MKL and bypass this buggy impl :)
Yeah, unless this path exists to bypass MKL when it is not available, I would assume dispatching to `aten::mm` could be a better strategy, unless that would be circular. If that is indeed the case, this could be done as a follow-up?
@nikitaved Never tried it; `torch.mm` was never powerful enough for my purposes (no broadcast, single matrix only). Not for real-world heavy lifting.
According to the docs, that's not `torch.mm`. Probably `torch.matmul`.
I never figured out why there are so many entry points.
Yeah, it is `matmul` indeed. `mm` appears to be a method (`Tensor.mm`) with the aforementioned limitations. Totally agree, seems like this redundancy is more confusing than it is helpful...
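For anyone following along, a quick sketch of the entry points being compared (a summary of documented behavior, not part of the thread):

```python
import torch

a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 5)

torch.bmm(a, b)       # 3-D batched only, no broadcasting
torch.matmul(a, b)    # general entry point: broadcasts, handles 1-D/2-D/batched
torch.mm(a[0], b[0])  # strictly 2-D, single matrix pair, no broadcasting
```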
`baddbmm` is usually dispatched to MKL, but in some cases (small matrices, small batches) it goes to a special implementation.
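A quick way to check whether a given build can take the MKL path at all (the small-shape thresholds that route to the special implementation are internal, so this only tells half the story):

```python
import torch

# If False, CPU baddbmm in this build presumably always takes the in-house
# kernel; if True, MKL is available, though small matrices and small batches
# may still be routed to the special implementation per the comment above.
print(torch.backends.mkl.is_available())
```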
```diff
@@ -1557,7 +1557,7 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T
             r += s2[k] * m1[k][j];
           }
         } else {
-          r *= beta;
+          r = beta == scalar_t(0) ? scalar_t(0) : beta * r;
```
Please leave a comment here.
And not that it matters that much given that this implementation will not be particularly efficient, but you may want to do this up front by calling `result.zero_()` in this case and putting an `if` here, so that this part is vectorized. Similar to how it's done for this path:
pytorch/aten/src/ATen/native/LinearAlgebra.cpp
Lines 1639 to 1641 in 5b2ab0d
```cpp
if (is_bmm_out || (beta.to<c10::complex<double>>() == 0.0)) {
  self_or_result.zero_();
  return;
```
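A Python-level sketch of the semantics behind this proposal (`baddbmm_beta_zero` is a hypothetical helper for illustration, not the actual kernel change, which would live in C++):

```python
import torch

def baddbmm_beta_zero(input, batch1, batch2, *, beta, alpha=1.0):
    # When beta == 0, `input` must not contribute at all (avoiding 0 * NaN),
    # so skip it up front and compute a pure alpha * batch1 @ batch2.
    if beta == 0:
        return torch.bmm(batch1, batch2).mul_(alpha)
    return torch.baddbmm(input, batch1, batch2, beta=beta, alpha=alpha)
```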
I don't think calling `result.zero_()` will get better performance: for this path `m*n*k` is a small number, and if we use `result.zero_()` there are two steps (no fusion, and also extra dispatch overhead).
I did some performance tests for `m=1, n=1, k=300`:
threads = 20:

| batch_size | this PR (ms) | your proposal (ms) |
|-----------:|-------------:|-------------------:|
| 1          | 0.006        | 0.013              |
| 20         | 0.011        | 0.019              |
| 40         | 0.011        | 0.018              |
| 80         | 0.012        | 0.023              |
| 160        | 0.015        | 0.029              |

threads = 1:

| batch_size | this PR (ms) | your proposal (ms) |
|-----------:|-------------:|-------------------:|
| 1          | 0.013        | 0.010              |
| 20         | 0.024        | 0.024              |
| 40         | 0.036        | 0.038              |
| 80         | 0.069        | 0.072              |
| 160        | 0.137        | 0.142              |
test code:

```python
import time

import torch

num_iter = 600

def fn(a, b, c):
    return torch.baddbmm(a, b, c, beta=0.00)

for batch_size in [1, 20, 40, 80, 160]:
    m = 1
    n = 1
    p = 300
    a = torch.randn((batch_size, n, p))
    b = torch.randn((batch_size, n, m))
    c = torch.randn((batch_size, m, p))
    fwd = 0
    # warm-up iterations
    with torch.no_grad():
        for i in range(300):
            y = fn(a, b, c)
    # timed iterations
    with torch.no_grad():
        t1 = time.time()
        for i in range(num_iter):
            y = fn(a, b, c)
        t2 = time.time()
        fwd = fwd + (t2 - t1)
    avg_time = fwd / num_iter * 1000
    print("batch_size = %d, avg time is %0.3f (ms) fps:%f" % (batch_size, avg_time, batch_size * num_iter / fwd))
```
Fix #96037. cc jgong5 mingfeima sanchitintel ashokei jingxu10
Fair enough!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fix pytorch/pytorch#96037. Pull Request resolved: pytorch/pytorch#96086. Approved by: https://github.com/ngimel, https://github.com/lezcano
Stack from ghstack (oldest at bottom):
Fix #96037.
cc @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10