single-batch torch.bmm is significantly slower with cuBLAS>12.1.0 #114911

Closed
atalman opened this issue Nov 30, 2023 · 15 comments
Labels
module: cuda · module: third_party · topic: performance · triaged

atalman commented Nov 30, 2023

🐛 Describe the bug

Running a simple torch.einsum is roughly 250x slower than with the previous cuBLAS release. It executes a void gemv2T_kernel_val kernel from cuBLAS, which underutilises the GPU (it runs on a single SM).
This looks like a regression in cuBLAS. torch ships with nvidia-cublas-cu12==12.1.3.1, and downgrading to nvidia-cublas-cu12==12.1.0.26 fixes the issue.
NOTE: The conda builds are not affected, as they install the older version.
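To confirm which cuBLAS wheel is active before running the repro below, here is a minimal sketch, assuming a pip-managed environment (conda and source builds link cuBLAS from the CUDA toolkit and won't have this wheel):

```python
# Minimal sketch: report the installed nvidia-cublas-cu12 wheel version (pip environments only).
from importlib.metadata import PackageNotFoundError, version

try:
    print("nvidia-cublas-cu12:", version("nvidia-cublas-cu12"))
except PackageNotFoundError:
    # Conda/source builds get cuBLAS from the CUDA toolkit instead of this wheel.
    print("nvidia-cublas-cu12 wheel not installed")
```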

import torch
import time
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity
x = torch.randn(2, 128, 65536).cuda()
def forward(x):
    return torch.einsum("s b k, t b k -> ", x, x)
def benchmark_fn(name, fn, *args, **kwargs):
    for _ in range(5):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(100):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    dt = (time.time() - begin)
    dt_us = int(dt * 1000000) / 100
    print(f"{name}:", dt_us, "us")
print("torch: ", torch.__version__)
benchmark_fn("fn", forward, x)
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
#     forward(x)
# prof.export_chrome_trace("einsum_nightly_conda_cu121.pt.trace.json.gz")
Output with nightly from `pip`:
torch:  2.2.0.dev20231128+cu121
fn: 51291.36 us
PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.9.3
    - Built with CuDNN 8.9.2
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.2.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,
> To fix the pip version (which comes with `nvidia-cublas-cu12==12.1.3.1`), just run `pip install nvidia-cublas-cu12==12.1.0.26`
Output with nightly from `conda`:
torch:  2.2.0.dev20231129
fn: 199.1 us
PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.9.2
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.2.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

Versions

2.2.x nightly

cc @ezyang @gchanan @zou3519 @kadeng @ptrblck @malfet

@atalman added the module: cuda and topic: performance labels Nov 30, 2023
@atalman added this to the 2.1.2 milestone Nov 30, 2023

malfet commented Dec 1, 2023

I see similar numbers when comparing 2.1.0+cu121 against 2.0.1+cu117:

$ conda run -n py310-torch200 python einsum-perf.py 
torch:  2.0.1+cu117
fn: 232.83 us

$ conda run -n py310-torch210 python einsum-perf.py 
torch:  2.1.0+cu121
fn: 53856.39 us

Also, I see the perf drop even if I use nvidia-cublas-cu12==12.3.4.1.


malfet commented Dec 1, 2023

Reduced it to a simple bmm perf regression when bmm is used in place of a plain mm/dot:

import torch
import time


def benchmark_fn(name, fn, args, warmup=5, cycles=100):
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(cycles):
        fn(*args)
    torch.cuda.synchronize()
    dt = (time.time() - begin)
    dt_us = int(dt * 1000000) / cycles
    print(f"{name}:", dt_us, "us")


if __name__ == "__main__":
    print("torch: ", torch.__version__, " device: ", torch.cuda.get_device_name(0))
    m, n, k = 1, 1, 65535
    a = torch.rand((m, k), device='cuda')
    b = torch.rand((k, n), device='cuda')

    benchmark_fn("bmm", torch.bmm, (a.unsqueeze(0), b.unsqueeze(0)))
    benchmark_fn("mm", torch.mm, (a, b))


ptrblck commented Dec 1, 2023

@malfet I cannot reproduce the regression with this bmm script on 2.1.1+cu121, while the original einsum code does show the slow execution time. Which version and device are you using for the bmm use case?


malfet commented Dec 1, 2023

@ptrblck A100. But how did you install 2.1.1+cu121? Using big wheels or the ones from PyPI? It will not repro with big wheels AFAIK, as the CUDA 12.1 toolkit ships with the older cuBLAS version.

$ conda run -n py310-torch200 python mm-perf.py 
torch:  2.0.1+cu117  device:  NVIDIA A100-SXM4-40GB
bmm: 14.89 us
mm: 14.92 us

$ conda run -n py310-torch210 python mm-perf.py 
torch:  2.1.0+cu121  device:  NVIDIA A100-SXM4-40GB
bmm: 209.37 us
mm: 13.5 us

@malfet changed the title from "[CUDA 12.1 PIP only] pytorch ships with a version of cublas with regression" to "single-batch torch.bmm is significantly slower with cuBLAS>12.1.0" Dec 1, 2023

atalman commented Dec 1, 2023

We deprecated big wheels with cu121 and made small wheels the default on Nov 22. This is probably why the issue was only found now.

gottbrath commented:

Just a follow-up to close the loop on hardware type: the initial report from Daniel was also on an A100.


malfet commented Dec 1, 2023

@gottbrath I think the HW type is irrelevant; the same perf regression shows up everywhere (A100 and H100).


malfet commented Dec 1, 2023

OK, the following patch brings performance really close for both cases:

diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 38cce45ab6e..872454e9e16 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -382,6 +382,13 @@ const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, co
     }
   }
 
+  // If the batch size is 1, always call addmm_out_cuda_impl
+  if (result.size(0) == 1) {
+    auto result_s = result.squeeze(0);
+    addmm_out_cuda_impl(result_s, self.squeeze(0), batch1.squeeze(0), batch2.squeeze(0), beta, alpha);
+    return result;
+  }
+
   bool transpose_result = false;
   c10::MaybeOwned<Tensor> result_;
   IntArrayRef result_strides = result.strides();

But I'm a bit concerned that there is still a perf difference:

torch:  2.2.0a0+git033d7b6  device:  NVIDIA A100-SXM4-40GB
bmm: 18.08 us
mm: 13.75 us
True

vadimkantorov commented:

I wonder if more such tests could be added to track perf (even if they are not failing now). E.g. many people have found that PyTorch gemm routines perform sub-optimally for an inner dim size of 2/3/4 (used for 3D geometry). Even if these are not optimized, it might be nice to add measurements for them too.


malfet commented Dec 2, 2023

@vadimkantorov do you have a list of matrix sizes you would be interested in benchmarking?


vadimkantorov commented Dec 2, 2023

I'll try to search issues for 2/3/4-inner-dim complaints :)

For the others, I wonder if some einsum formulas could be grep'd from GitHub Search across all repos
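
A sketch of what a shape-tracking benchmark along these lines could look like (the shape list is illustrative and `benchmark_us` simply reuses the timing pattern from the scripts above; this is not an existing PyTorch test):

```python
import time
import torch

def benchmark_us(fn, *args, warmup=5, cycles=100):
    # Same timing pattern as benchmark_fn above; returns microseconds per call.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(cycles):
        fn(*args)
    torch.cuda.synchronize()
    return (time.time() - begin) * 1e6 / cycles

# Illustrative (m, n, k) shapes: the single-batch case from this issue plus
# small inner dims (2/3/4) as used in 3D geometry workloads.
shapes = [(1, 1, 65536), (4096, 4096, 2), (4096, 4096, 3), (4096, 4096, 4)]
for m, n, k in shapes:
    a = torch.rand((m, k), device="cuda")
    b = torch.rand((k, n), device="cuda")
    bmm_us = benchmark_us(torch.bmm, a.unsqueeze(0), b.unsqueeze(0))
    mm_us = benchmark_us(torch.mm, a, b)
    print(f"{m}x{n}x{k}: bmm {bmm_us:.2f} us, mm {mm_us:.2f} us, "
          f"slowdown {100.0 * (bmm_us - mm_us) / mm_us:.2f} %")
```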

@atalman modified the milestones: 2.1.2 → 2.2.0 Dec 4, 2023
atalman commented Dec 4, 2023

@drisspg added the triaged label Dec 4, 2023
malfet added a commit that referenced this issue Dec 5, 2023

malfet commented Dec 7, 2023

Wrote a somewhat more comprehensive benchmark script, which shows a strange perf deviation between CUDA 11.8 and CUDA 12.1. The numbers below were collected on an A100:

|      Shape     | bmm_time (11.8) | mm_time (11.8) | slow down (11.8) % | bmm_time (12.1) | mm_time (12.1) | slow down (12.1) % |
| -------------- | --------------- | -------------- | ------------------ | --------------- | -------------- | ------------------ |
| 1x1x4096       | 12.38 | 11.96 | 3.48 | 15.29 | 11.98 | 27.63 |
| 1x1x8192       | 12.26 | 11.84 | 3.55 | 28.29 | 12.00 | 135.75 |
| 1x1x16384      | 11.81 | 11.66 | 1.29 | 53.85 | 12.08 | 345.66 |
| 1x1x32768      | 12.00 | 11.81 | 1.61 | 105.13 | 12.23 | 759.40 |
| 1x1x65536      | 14.82 | 15.05 | -1.48 | 207.38 | 12.23 | 1595.67 |
| 1x1x131072     | 12.02 | 11.77 | 2.15 | 412.48 | 11.89 | 3370.08 |
| 128x128x128    | 9.47 | 9.69 | -2.24 | 9.84 | 9.35 | 5.24 |
| 256x256x256    | 12.66 | 12.60 | 0.50 | 10.18 | 12.51 | -18.65 |
| 512x512x512    | 27.34 | 27.31 | 0.10 | 27.24 | 27.25 | -0.04 |
| 1024x1024x1024 | 129.59 | 129.48 | 0.08 | 130.57 | 129.47 | 0.85 |
| 2048x2048x2048 | 973.63 | 973.04 | 0.06 | 962.19 | 972.95 | -1.11 |
| 129x127x129    | 9.56 | 8.97 | 6.62 | 10.03 | 9.53 | 5.21 |
| 257x255x257    | 12.85 | 12.78 | 0.52 | 10.87 | 12.32 | -11.75 |
| 513x511x513    | 28.99 | 28.98 | 0.05 | 29.04 | 28.98 | 0.21 |
| 1025x1023x1025 | 137.92 | 137.76 | 0.11 | 134.73 | 137.71 | -2.16 |
| 2049x2047x2049 | 982.34 | 982.32 | 0.00 | 982.23 | 982.27 | -0.00 |
| 4097x3x4097    | 86.94 | 86.91 | 0.03 | 165.83 | 86.70 | 91.28 |
| 8193x3x8193    | 384.38 | 384.54 | -0.04 | 384.41 | 384.49 | -0.02 |
| 16385x3x16385  | 1106.25 | 1107.35 | -0.10 | 1345.39 | 1106.49 | 21.59 |
| 32769x3x32769  | 4736.79 | 4737.19 | -0.01 | 4737.26 | 4737.98 | -0.02 |
| 65537x3x65537  | 17368.65 | 17371.21 | -0.01 | 17372.71 | 17369.23 | 0.02 |
| 4097x5x4097    | 87.50 | 87.49 | 0.01 | 165.92 | 87.28 | 90.09 |
| 8193x5x8193    | 302.27 | 302.29 | -0.00 | 384.93 | 301.81 | 27.54 |
| 16385x5x16385  | 1107.69 | 1107.65 | 0.00 | 1346.61 | 1106.80 | 21.67 |
| 32769x5x32769  | 4743.02 | 4743.13 | -0.00 | 4742.29 | 4742.85 | -0.01 |
| 65537x5x65537  | 17393.08 | 17392.32 | 0.00 | 17385.43 | 17385.87 | -0.00 |
| 4097x7x4097    | 87.58 | 87.60 | -0.02 | 166.02 | 87.27 | 90.24 |
| 8193x7x8193    | 302.42 | 302.45 | -0.01 | 385.23 | 302.38 | 27.40 |
| 16385x7x16385  | 1106.55 | 1107.34 | -0.07 | 1347.56 | 1107.23 | 21.71 |
| 32769x7x32769  | 4746.99 | 4746.58 | 0.01 | 4745.20 | 4744.77 | 0.01 |
| 65537x7x65537  | 17406.08 | 17424.31 | -0.10 | 17458.41 | 17463.22 | -0.03 |

pytorchmergebot pushed a commit that referenced this issue Dec 8, 2023
I.e. it feels reasonable to always call `at::cuda::gemm` rather than `at::cuda::bgemm` when num_batches == 1.
After the change, benchmark results for torch built with CUDA 12, using the [following perf script](https://gist.github.com/malfet/6a17156d7f5663b8b12054a1beff3fe1) on an A100, are as follows:
|      Shape     |  bmm_time |  mm_time  | slow down (%) |
| -------------- | --------- | --------- | ------------- |
|    1x1x4096    |   14.18   |   14.31   |     -0.89     |
|    1x1x8192    |   14.37   |   14.37   |     -0.05     |
|   1x1x16384    |   14.03   |   14.12   |     -0.68     |
|   1x1x32768    |   14.19   |   14.24   |     -0.35     |
|   1x1x65536    |   14.85   |   14.52   |     2.30      |
|   1x1x131072   |   14.03   |   14.07   |     -0.33     |
|  128x128x128   |   11.34   |   11.06   |     2.56      |
|  256x256x256   |   14.85   |   14.40   |     3.15      |
|  512x512x512   |   27.22   |   27.22   |     -0.01     |
| 1024x1024x1024 |  129.66   |  129.50   |     0.12      |
| 2048x2048x2048 |  972.18   |  973.24   |     -0.11     |
|  129x127x129   |   11.21   |   11.25   |     -0.39     |
|  257x255x257   |   14.50   |   14.43   |     0.44      |
|  513x511x513   |   29.01   |   29.01   |     0.01      |
| 1025x1023x1025 |  137.65   |  137.64   |     0.01      |
| 2049x2047x2049 |  982.58   |  982.65   |     -0.01     |
|  4097x3x4097   |   86.65   |   86.64   |     0.01      |
|  8193x3x8193   |  384.02   |  383.96   |     0.02      |
| 16385x3x16385  |  1106.73  |  1107.32  |     -0.05     |
| 32769x3x32769  |  4739.49  |  4739.48  |     0.00      |
| 65537x3x65537  | 17377.78  | 17378.74  |     -0.01     |
|  4097x5x4097   |   87.09   |   87.12   |     -0.03     |
|  8193x5x8193   |  301.38   |  301.36   |     0.01      |
| 16385x5x16385  |  1107.38  |  1108.04  |     -0.06     |
| 32769x5x32769  |  4743.73  |  4744.07  |     -0.01     |
| 65537x5x65537  | 17392.32  | 17395.42  |     -0.02     |
|  4097x7x4097   |   87.17   |   87.19   |     -0.02     |
|  8193x7x8193   |  301.94   |  302.00   |     -0.02     |
| 16385x7x16385  |  1107.17  |  1106.79  |     0.03      |
| 32769x7x32769  |  4747.15  |  4747.13  |     0.00      |
| 65537x7x65537  | 17403.85  | 17405.02  |     -0.01     |

Fixes perf problem reported in #114911
Pull Request resolved: #114992
Approved by: https://github.com/Skylion007, https://github.com/eqy

malfet commented Dec 8, 2023

Removing high-pri as I've landed the change to mitigate the issue

dmenig pushed a commit to dmenig/pytorch that referenced this issue Dec 21, 2023
atalman pushed a commit to atalman/pytorch that referenced this issue Dec 28, 2023
atalman added a commit that referenced this issue Jan 2, 2024

huydhn commented Jan 18, 2024

Confirmed that this is fixed in the upcoming 2.2.0 release:

On 2.2.0:

torch:  2.2.0+cu121
fn: 208.05 us

The regression is there on 2.1.2:

torch:  2.1.2+cu121
fn: 52402.06 us
