single-batch torch.bmm is significantly slower with cuBLAS > 12.1.0 #114911
I see similar numbers switching from 2.1.0+cu121 to 2.0.1+cu117:
Also, I see the perf drop even if I use nvidia-cublas-cu12==12.3.4.1.
Reduced it to a simple `bmm` perf regression when it is used in place of a dot product:

```python
import torch
import time

def benchmark_fn(name, fn, args, warmup=5, cycles=100):
    # Warm up to exclude one-time costs (kernel compilation, caches, etc.)
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(cycles):
        fn(*args)
    torch.cuda.synchronize()
    dt = time.time() - begin
    dt_us = int(dt * 1000000) / cycles  # average wall-clock time per call, in us
    print(f"{name}:", dt_us, "us")

if __name__ == "__main__":
    print("torch: ", torch.__version__, " device: ", torch.cuda.get_device_name(0))
    m, n, k = 1, 1, 65535
    a = torch.rand((m, k), device='cuda')
    b = torch.rand((k, n), device='cuda')
    benchmark_fn("bmm", torch.bmm, (a.unsqueeze(0), b.unsqueeze(0)))
    benchmark_fn("mm", torch.mm, (a, b))
```
@malfet I cannot reproduce the regression using 2.1.1+cu121.
@ptrblck A100. But how did you install 2.1.1+cu121? Using big wheels or ones from PyPI? It will not repro with big wheels AFAIK, as the CUDA 12.1 toolkit ships with the older version of cuBLAS.
We deprecated the big wheels with cu121 and made small wheels the default on Nov 22. This is probably why the issue was only found now.
Just a follow-up to close the loop on hardware type: the initial report from Daniel was also on an A100.
@gottbrath I think the HW type is irrelevant; it has the same perf regression everywhere (A100 and H100).
OK, the following patch brings performance really close for both cases:

```diff
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 38cce45ab6e..872454e9e16 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -382,6 +382,13 @@ const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, co
     }
   }
+  // If batch is 1, always call addmm_out_cuda_impl
+  if (result.size(0) == 1) {
+    auto result_s = result.squeeze(0);
+    addmm_out_cuda_impl(result_s, self.squeeze(0), batch1.squeeze(0), batch2.squeeze(0), beta, alpha);
+    return result;
+  }
+
   bool transpose_result = false;
   c10::MaybeOwned<Tensor> result_;
   IntArrayRef result_strides = result.strides();
```

But I'm a bit concerned there are still perf differences:
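For anyone stuck on an affected release, a user-level workaround in the same spirit as the patch is to route the batch==1 case through `torch.mm` manually. This is a sketch; `bmm_maybe_mm` is a hypothetical helper, not PyTorch API:

```python
import torch

def bmm_maybe_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Sketch of a user-level workaround: for a single-batch input, squeeze the
    # batch dim and use mm (which takes the fast gemm path), then restore it.
    if a.size(0) == 1:
        return torch.mm(a.squeeze(0), b.squeeze(0)).unsqueeze(0)
    return torch.bmm(a, b)
```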
I wonder if more such tests could be added to track perf (even if they are not failing now). E.g., many people have found that PyTorch GEMM routines perform sub-optimally for an inner-dim size of 2/3/4 (used for 3D geometry). Even if these are not optimized, it might be nice to add measurements for them too.
@vadimkantorov do you have a list of matrix sizes you would be interested in benchmarking?
I'll try to search issues for the 2/3/4-inner-dim complaints :) For the others, I wonder if some einsum formulas could be grep'd from GitHub Search across all repos.
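A sketch of what such a small-inner-dim sweep could look like, reusing `benchmark_fn` from the repro script earlier in the thread (the shapes are illustrative, not a vetted list):

```python
import torch

# Assumes benchmark_fn from the repro script above is in scope.
# Inner dims of 2/3/4 model the 3D-geometry cases mentioned above.
for d in (2, 3, 4):
    a = torch.rand((4096, d), device="cuda")
    b = torch.rand((d, 4096), device="cuda")
    benchmark_fn(f"mm 4096x{d}x4096", torch.mm, (a, b))
```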
Based on internal post: https://fb.workplace.com/groups/1405155842844877/permalink/7730383983655333/
Wrote a somewhat more comprehensive benchmark script, which shows a strange perf deviation between CUDA 11.8 and CUDA 12.1. The numbers below were run on an A100:
I.e. it feels reasonable to always call `at::cuda::gemm` rather than `at::cuda::bgemm` when num_batches == 1. After the change, benchmarking torch built with CUDA-12 using the [following perf script](https://gist.github.com/malfet/6a17156d7f5663b8b12054a1beff3fe1) on an A100 gives:

| Shape | bmm_time | mm_time | slowdown (%) |
| -------------- | --------- | --------- | ------------- |
| 1x1x4096 | 14.18 | 14.31 | -0.89 |
| 1x1x8192 | 14.37 | 14.37 | -0.05 |
| 1x1x16384 | 14.03 | 14.12 | -0.68 |
| 1x1x32768 | 14.19 | 14.24 | -0.35 |
| 1x1x65536 | 14.85 | 14.52 | 2.30 |
| 1x1x131072 | 14.03 | 14.07 | -0.33 |
| 128x128x128 | 11.34 | 11.06 | 2.56 |
| 256x256x256 | 14.85 | 14.40 | 3.15 |
| 512x512x512 | 27.22 | 27.22 | -0.01 |
| 1024x1024x1024 | 129.66 | 129.50 | 0.12 |
| 2048x2048x2048 | 972.18 | 973.24 | -0.11 |
| 129x127x129 | 11.21 | 11.25 | -0.39 |
| 257x255x257 | 14.50 | 14.43 | 0.44 |
| 513x511x513 | 29.01 | 29.01 | 0.01 |
| 1025x1023x1025 | 137.65 | 137.64 | 0.01 |
| 2049x2047x2049 | 982.58 | 982.65 | -0.01 |
| 4097x3x4097 | 86.65 | 86.64 | 0.01 |
| 8193x3x8193 | 384.02 | 383.96 | 0.02 |
| 16385x3x16385 | 1106.73 | 1107.32 | -0.05 |
| 32769x3x32769 | 4739.49 | 4739.48 | 0.00 |
| 65537x3x65537 | 17377.78 | 17378.74 | -0.01 |
| 4097x5x4097 | 87.09 | 87.12 | -0.03 |
| 8193x5x8193 | 301.38 | 301.36 | 0.01 |
| 16385x5x16385 | 1107.38 | 1108.04 | -0.06 |
| 32769x5x32769 | 4743.73 | 4744.07 | -0.01 |
| 65537x5x65537 | 17392.32 | 17395.42 | -0.02 |
| 4097x7x4097 | 87.17 | 87.19 | -0.02 |
| 8193x7x8193 | 301.94 | 302.00 | -0.02 |
| 16385x7x16385 | 1107.17 | 1106.79 | 0.03 |
| 32769x7x32769 | 4747.15 | 4747.13 | 0.00 |
| 65537x7x65537 | 17403.85 | 17405.02 | -0.01 |

Fixes perf problem reported in #114911

Pull Request resolved: #114992
Approved by: https://github.com/Skylion007, https://github.com/eqy
Removing high-pri as I've landed the change to mitigate the issue.
Confirmed that this is fixed in the upcoming 2.2.0 release. On 2.2.0:

The regression is still there on 2.1.2:
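For reference, checking which build one is on can be done with standard PyTorch attributes (a minimal sketch; the pip-installed cuBLAS version itself is visible via `pip show nvidia-cublas-cu12`):

```python
import torch

print(torch.__version__)               # e.g. 2.1.2+cu121 or 2.2.0
print(torch.version.cuda)              # CUDA version torch was built against
print(torch.cuda.get_device_name(0))   # e.g. NVIDIA A100
```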
🐛 Describe the bug

Running a simple torch.einsum is 250x slower than with the older cuBLAS. It executes a `void gemv2T_kernel_val` kernel from cuBLAS, which underutilises the GPU (it uses a single SM).
This looks like a regression in cuBLAS: torch ships with nvidia-cublas-cu12==12.1.3.1, and downgrading to nvidia-cublas-cu12==12.1.0.26 fixes the issue.
NOTE: The conda builds are not affected, as they install the older version.
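The exact einsum formula is not included above; as an illustration only, any formula that lowers to a single-batch matmul with a long inner dimension (effectively the dot-product shape from the reduction earlier in the thread) should hit the same slow `gemv2T_kernel_val` path. A hypothetical repro sketch:

```python
import torch

# Hypothetical repro in the spirit of the reduced bmm case above:
# a batch-of-1 product with a long inner dimension, written as einsum.
a = torch.rand((1, 1, 65535), device="cuda")
b = torch.rand((1, 65535, 1), device="cuda")
out = torch.einsum("bik,bkj->bij", a, b)  # lowers to a single-batch bmm
```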
Versions
2.2.x nightly
cc @ezyang @gchanan @zou3519 @kadeng @ptrblck @malfet