fix issue of baddbmm when out has nan value for beta=0 #96086

Closed
3 changes: 2 additions & 1 deletion aten/src/ATen/native/LinearAlgebra.cpp
@@ -1557,7 +1557,8 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T
            r += s2[k] * m1[k][j];
          }
        } else {
-         r *= beta;
+         // For beta == 0, r's value is ignored, in particular when it is NaN.
+         r = beta == scalar_t(0) ? scalar_t(0) : beta * r;
Collaborator
Not that it matters much, given that this implementation will not be particularly efficient, but you may want to do this up front by calling result.zero_() in this case and putting an if here, so that this part stays vectorized, similar to how it's done for this path:

if (is_bmm_out || (beta.to<c10::complex<double>>() == 0.0)) {
  self_or_result.zero_();
  return;
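
To make the suggestion concrete, here is a rough Python-level sketch of the proposed control flow; the real kernel is C++, and the helper below is illustrative only, not code from PyTorch or this PR:

import torch

def baddbmm_zero_beta_sketch(result, batch1, batch2, alpha=1.0):
    # Clear the output buffer up front so stale values (including NaN) are
    # discarded and the accumulation needs no per-element branch on beta,
    # which is what lets the inner loop stay vectorized.
    result.zero_()
    # Accumulate alpha * (batch1 @ batch2) into the zeroed output.
    result.add_(torch.bmm(batch1, batch2), alpha=alpha)
    return result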

Collaborator Author

I don't think calling result.zero_() will give better performance: on this path m*n*k is small, and with result.zero_() there are two steps (no fusion, plus extra dispatch overhead).

Collaborator Author

I did some performance tests for m=1, n=1, k=300:

1. threads = 20
   • this PR
     batch_size = 1, avg time is 0.006 (ms)
     batch_size = 20, avg time is 0.011 (ms)
     batch_size = 40, avg time is 0.011 (ms)
     batch_size = 80, avg time is 0.012 (ms)
     batch_size = 160, avg time is 0.015 (ms)
   • your proposal
     batch_size = 1, avg time is 0.013 (ms)
     batch_size = 20, avg time is 0.019 (ms)
     batch_size = 40, avg time is 0.018 (ms)
     batch_size = 80, avg time is 0.023 (ms)
     batch_size = 160, avg time is 0.029 (ms)
2. threads = 1
   • this PR
     batch_size = 1, avg time is 0.013 (ms)
     batch_size = 20, avg time is 0.024 (ms)
     batch_size = 40, avg time is 0.036 (ms)
     batch_size = 80, avg time is 0.069 (ms)
     batch_size = 160, avg time is 0.137 (ms)
   • your proposal
     batch_size = 1, avg time is 0.010 (ms)
     batch_size = 20, avg time is 0.024 (ms)
     batch_size = 40, avg time is 0.038 (ms)
     batch_size = 80, avg time is 0.072 (ms)
     batch_size = 160, avg time is 0.142 (ms)

test code:

import time

import torch

num_iter = 600

def fn(a, b, c):
    return torch.baddbmm(a, b, c, beta=0.0)

for batch_size in [1, 20, 40, 80, 160]:
    m = 1
    n = 1
    p = 300
    a = torch.randn((batch_size, n, p))
    b = torch.randn((batch_size, n, m))
    c = torch.randn((batch_size, m, p))

    fwd = 0
    # warmup
    with torch.no_grad():
        for i in range(300):
            y = fn(a, b, c)

    with torch.no_grad():
        t1 = time.time()
        for i in range(num_iter):
            y = fn(a, b, c)
        t2 = time.time()
        fwd = fwd + (t2 - t1)

    avg_time = fwd / num_iter * 1000
    print("batch_size = %d, avg time is %0.3f (ms) fps: %f" % (batch_size, avg_time, batch_size * num_iter / fwd))

          for (const auto k : c10::irange(ks)) {
            r += alpha * s2[k] * m1[k][j];
          }
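
For context on the behavior this hunk fixes: with beta == 0 the additive input is supposed to be ignored, but computing r *= beta turns a NaN entry into NaN (0 * NaN == NaN) instead of 0. Below is a minimal reproduction sketch, not part of the PR; whether this naive CPU kernel is hit depends on build and sizes, but the small shapes mirror the new test.

import torch

mat1 = torch.randn(2, 3, 4)
mat2 = torch.randn(2, 4, 5)

# Additive input full of NaN; with beta=0 its values must be ignored.
inp = torch.full((2, 3, 5), float("nan"))

y = torch.baddbmm(inp, mat1, mat2, beta=0.0)

# Expected to print True with this fix; before the fix the NaNs could
# propagate through the `r *= beta` term and poison the result.
print(torch.allclose(y, torch.bmm(mat1, mat2)))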
15 changes: 15 additions & 0 deletions test/test_linalg.py
@@ -5565,6 +5565,21 @@ def test_addmm_baddbmm_overflow(self, device, dtype):
        self.assertTrue((out == 10000.).all())
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig

    @dtypes(torch.float)
    def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
        for shape in [[3, 2, 2], [2, 20, 20]]:
            mat1, mat2 = [torch.randn(shape, dtype=dtype, device=device) for _ in range(2)]
            inputs = [torch.randn(shape, dtype=dtype, device=device),
                      torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
            outs = [None, torch.randn(shape, dtype=dtype, device=device),
                    torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
            options = itertools.product(inputs, outs)
            for input, out in options:
                y_ref = torch.bmm(mat1, mat2)
                y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out)
                self.assertEqual(y_ref, y)


    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @onlyCUDA
    def test_matmul_45724(self, device):