libtorch_cuda.so is missing fast kernels from libcudnn_static.a, therefore statically linked cuDNN could be much slower than dynamically linked #50153

Closed
zasdfgbnm opened this issue Jan 6, 2021 · 11 comments
Labels
high priority · module: build (Build system issues) · module: cuda (Related to torch.cuda, and CUDA support in general) · module: cudnn (Related to torch.backends.cudnn, and CuDNN support) · module: performance (Issues related to performance, either of kernel code or framework glue) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@zasdfgbnm (Collaborator) commented Jan 6, 2021

🐛 Bug

libtorch_cuda.so is missing fast kernels from libcudnn_static.a, so a statically linked cuDNN can be much slower than a dynamically linked one.

People at NVIDIA found that the backward pass of the following code is much slower when running with a statically linked cuDNN than with a dynamically linked one:

import torch
from torch import nn
import time
import pandas as pd

n_trials = 100
warmup_iters = 100
torch.backends.cudnn.benchmark = False

convs = [
    # ===      Input dimensions      ===   === Stride === ===           Kernel          ===
    {'n': 8, 'C': 64,  'H': 64,  'W': 64,  'u': 4, 'v': 4, 'K': 64,  'R': 4, 'S': 4, 'G': 1, 'pad': 1, 'bias': False,},
]

for conv in convs:
    model = nn.Conv2d(
        in_channels=conv['C'],
        out_channels=conv['K'],
        kernel_size=(conv['R'], conv['S']),
        stride=(conv['u'], conv['v']),
        padding=conv['pad'],
        groups=conv['G'],
        bias=conv['bias'],
    ).cuda().half()
    x = torch.randn((conv['n'], conv['C'], conv['H'], conv['W']),
                    device='cuda', dtype=torch.float16, requires_grad=True)
    # Select kernels, get y, dy
    for _ in range(warmup_iters):
        y = model.forward(x)
        dy = torch.randn_like(y)
        y.backward(dy)
    # Time forward pass
    torch.cuda.synchronize()
    t_start = time.perf_counter()
    for _ in range(n_trials):
        y = model.forward(x)
    torch.cuda.synchronize()
    t_end = time.perf_counter()
    dt_fwd = (t_end - t_start) / n_trials
    # Time backward pass
    torch.cuda.synchronize()
    dy = torch.randn_like(y)
    torch.cuda.synchronize()
    t_start = time.perf_counter()
    for _ in range(n_trials):
        y.backward(dy, retain_graph=True)
    torch.cuda.synchronize()
    t_end = time.perf_counter()
    dt_bwd = (t_end - t_start) / n_trials
    conv["fwd_fp16"] = int(dt_fwd * 1e6)  # average forward time in microseconds
    conv["bwd_fp16"] = int(dt_bwd * 1e6)  # average backward time in microseconds
df = pd.DataFrame(convs)
print(repr(df))
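
The exact profiling command isn't stated in the report; the "CUDA Kernel Statistics" tables below look like Nsight Systems summary output, and an invocation along the following lines produces tables in that format (conv_repro.py is a placeholder name for the script above):

$ nsys profile --stats=true python conv_repro.py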

The backward pass with statically linked cuDNN is about 4x slower than with the dynamically linked one.

Profiling shows that static cuDNN and dynamic cuDNN are calling different kernels:

static:

   n   C   H   W  u  v   K  R  S  G  pad   bias  fwd_fp16  bwd_fp16
0  8  64  64  64  4  4  64  4  4  1    1  False       182       854

CUDA Kernel Statistics:

 Time(%)  Total Time (ns)  Instances   Average   Minimum  Maximum                                                  Name                                                
 -------  ---------------  ---------  ---------  -------  -------  ----------------------------------------------------------------------------------------------------
    83.3      144,834,124        200  724,170.6  721,456  727,633  void cudnn::detail::dgrad_alg1_engine<__half, 512, 6, 5, 3, 3, 3, false, true>(int, int, int, __hal…
     7.0       12,106,391        200   60,532.0   59,970   61,154  ampere_scudnn_128x32_stridedB_splitK_xregs_large_nn_v1                                              
     2.7        4,714,436        400   11,786.1    6,432   19,360  void cudnn::ops::convertTensor_kernel<__half, float, float, 0>(float, __half const*, float, float*,…
     1.7        2,984,743        398    7,499.4    1,728   13,248  void at::native::vectorized_elementwise_kernel<4, at::native::AddFunctor<c10::Half>, at::detail::Ar…
     1.6        2,752,791        400    6,882.0    2,016   12,097  void nchwToNhwcKernel<__half, __half, float, true, false, (cudnnKernelDataType_t)0>(int, int, int, …
     1.4        2,509,566        200   12,547.8   12,448   12,704  ampere_fp16_s16816cudnn_fp16_64x64_ldg8_relu_f2f_exp_stages_64x5_small_nhwc_tn_v1                   
     0.7        1,290,721        200    6,453.6    6,336    6,688  void cudnn::ops::scalePackedTensor_kernel<__half, float>(cudnnTensor4dStruct, __half*, float)       
     0.7        1,210,042        200    6,050.2    5,728    6,177  void cudnn::ops::convertTensor_kernel<float, __half, float, 0>(float, float const*, float, __half*,…
     0.2          385,991        102    3,784.2    3,680   11,328  _ZN2at6native88_GLOBAL__N__64_tmpxft_00029b37_00000000_8_DistributionNormal_compute_75_cpp1_ii_7d80…
     0.2          358,094        200    1,790.5    1,632    1,952  cask_cudnn::computeOffsetsKernel(cask_cudnn::ComputeOffsetsParams)                                  
     0.2          338,760        200    1,693.8    1,632    1,761  void nhwcToNchwKernel<__half, __half, float, true, false, (cudnnKernelDataType_t)0>(int, int, int, …
     0.1          241,859        200    1,209.3    1,151    1,600  cask_cudnn::computeWgradSplitKOffsetsKernel(cask_cudnn::ComputeSplitKOffsetsParams)                 
     0.1          220,261        200    1,101.3    1,056    1,537  cask_cudnn::computeWgradBOffsetsKernel(cask_cudnn::ComputeWgradBOffsetsParams)

dynamic:

   n   C   H   W  u  v   K  R  S  G  pad   bias  fwd_fp16  bwd_fp16
0  8  64  64  64  4  4  64  4  4  1    1  False       161       213

CUDA Kernel Statistics:

 Time(%)  Total Time (ns)  Instances  Average   Minimum  Maximum                                                  Name                                                
 -------  ---------------  ---------  --------  -------  -------  ----------------------------------------------------------------------------------------------------
    24.8        7,749,473        200  38,747.4   37,953   40,033  void xmma_new::gemm::kernel<xmma_new::implicit_gemm::dgrad_indexed::Kernel_traits<xmma_new::Ampere_…
    18.0        5,621,544        800   7,026.9    2,272   12,609  void nchwToNhwcKernel<__half, __half, float, true, false, (cudnnKernelDataType_t)0>(int, int, int, …
    16.1        5,023,863        200  25,119.3   24,832   25,633  void cutlass::Kernel<cutlass_tensorop_f16_s16816fprop_precomputed_f16_128x64_32x6>(cutlass_tensorop…
    12.7        3,963,823        200  19,819.1   19,297   20,224  void xmma_new::gemm::kernel<xmma_new::implicit_gemm::wgrad_indexed::Kernel_traits<xmma_new::Ampere_…
    11.1        3,473,959        398   8,728.5    1,663   17,024  void at::native::vectorized_elementwise_kernel<4, at::native::AddFunctor<c10::Half>, at::detail::Ar…
     9.8        3,068,712        200  15,343.6   15,104   15,585  void foldedNhwcToNchwKernel<__half, __half, float, true, (cudnnKernelDataType_t)0>(int, int, int, i…
     3.5        1,091,794        400   2,729.5    2,400    3,328  void nchwToFoldedNhwcKernel<__half, __half, float, true, (cudnnKernelDataType_t)0>(int, int, int, i…
     2.5          773,168        400   1,932.9    1,824    2,240  void nhwcToNchwKernel<__half, __half, float, true, false, (cudnnKernelDataType_t)0>(int, int, int, …
     1.5          455,047        102   4,461.2    4,320   13,248  _ZN2at6native88_GLOBAL__N__64_tmpxft_00033092_00000000_8_DistributionNormal_compute_75_cpp1_ii_7d80…

I am not 100% sure about the reason, but it seems that PyTorch does not copy the fast kernels from libcudnn_static.a into libtorch_cuda.so, so cuDNN has to fall back to a slower kernel:

$ cuobjdump -symbols /usr/lib/libcudnn_cnn_infer.so | grep cutlass_tensorop_f16_s16816fprop_precomputed_f16_128x64_32x6
STT_FUNC         STB_GLOBAL STO_ENTRY      _ZN7cutlass6KernelI60cutlass_tensorop_f16_s16816fprop_precomputed_f16_128x64_32x6EEvNT_6ParamsE
STT_FUNC         STB_GLOBAL STO_ENTRY      _ZN7cutlass6KernelI60cutlass_tensorop_f16_s16816fprop_precomputed_f16_128x64_32x6EEvNT_6ParamsE
$ cuobjdump -symbols /usr/lib/libcudnn_static.a | grep cutlass_tensorop_f16_s16816fprop_precomputed_f16_128x64_32x6
STT_FUNC         STB_GLOBAL STO_ENTRY      _ZN7cutlass6KernelI60cutlass_tensorop_f16_s16816fprop_precomputed_f16_128x64_32x6EEvNT_6ParamsE
STT_FUNC         STB_GLOBAL STO_ENTRY      _ZN7cutlass6KernelI60cutlass_tensorop_f16_s16816fprop_precomputed_f16_128x64_32x6EEvNT_6ParamsE
$ cuobjdump -symbols ~/.local/lib/python3.9/site-packages/torch/lib/libtorch_cuda.so | grep cutlass
<no output>
$ cuobjdump -symbols ~/.local/lib/python3.9/site-packages/torch/lib/libtorch_cuda.so | grep dgrad_alg1_engine | head
STT_FUNC         STB_LOCAL  STO_ENTRY      _ZN5cudnn6detail17dgrad_alg1_engineI6__halfLi512ELi6ELi5ELi3ELi3ELi3ELb1ELb1EEEviiiPKT_iS5_iPS3_18kernel_grad_paramsyifi
STT_FUNC         STB_LOCAL  STO_ENTRY      _ZN5cudnn6detail17dgrad_alg1_engineI6__halfLi512ELi6ELi5ELi3ELi3ELi3ELb1ELb0EEEviiiPKT_iS5_iPS3_18kernel_grad_paramsyifi
STT_FUNC         STB_LOCAL  STO_ENTRY      _ZN5cudnn6detail17dgrad_alg1_engineI6__halfLi512ELi6ELi5ELi3ELi3ELi3ELb0ELb1EEEviiiPKT_iS5_iPS3_18kernel_grad_paramsyifi
STT_FUNC         STB_LOCAL  STO_ENTRY      _ZN5cudnn6detail17dgrad_alg1_engineI6__halfLi512ELi6ELi5ELi3ELi3ELi3ELb0ELb0EEEviiiPKT_iS5_iPS3_18kernel_grad_paramsyifi
STT_FUNC         STB_LOCAL  STO_ENTRY      _ZN5cudnn6detail17dgrad_alg1_engineI6__halfLi128ELi5ELi5ELi3ELi3ELi3ELb1ELb1EEEviiiPKT_iS5_iPS3_18kernel_grad_paramsyifi
STT_FUNC         STB_LOCAL  STO_ENTRY      _ZN5cudnn6detail17dgrad_alg1_engineI6__halfLi128ELi5ELi5ELi3ELi3ELi3ELb1ELb0EEEviiiPKT_iS5_iPS3_18kernel_grad_paramsyifi
STT_FUNC         STB_LOCAL  STO_ENTRY      _ZN5cudnn6detail17dgrad_alg1_engineI6__halfLi128ELi5ELi5ELi3ELi3ELi3ELb0ELb1EEEviiiPKT_iS5_iPS3_18kernel_grad_paramsyifi
STT_FUNC         STB_LOCAL  STO_ENTRY      _ZN5cudnn6detail17dgrad_alg1_engineI6__halfLi128ELi5ELi5ELi3ELi3ELi3ELb0ELb0EEEviiiPKT_iS5_iPS3_18kernel_grad_paramsyifi
STT_FUNC         STB_LOCAL  STO_ENTRY      _ZN5cudnn6detail17dgrad_alg1_engineI6__halfLi128ELi6ELi8ELi3ELi3ELi5ELb1ELb1EEEviiiPKT_iS5_iPS3_18kernel_grad_paramsyifi
STT_FUNC         STB_LOCAL  STO_ENTRY      _ZN5cudnn6detail17dgrad_alg1_engineI6__halfLi128ELi6ELi8ELi3ELi3ELi5ELb1ELb0EEEviiiPKT_iS5_iPS3_18kernel_grad_paramsyifi

Environment

Collecting environment information...
PyTorch version: 1.8.0a0+4a6c178
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 10.2.0
Clang version: 11.0.0
CMake version: version 3.19.2

Python version: 3.9 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: GeForce RTX 3090
GPU 1: GeForce RTX 2080 Ti

Nvidia driver version: 455.45.01
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.0.5
/usr/lib/libcudnn_adv_infer.so.8.0.5
/usr/lib/libcudnn_adv_train.so.8.0.5
/usr/lib/libcudnn_cnn_infer.so.8.0.5
/usr/lib/libcudnn_cnn_train.so.8.0.5
/usr/lib/libcudnn_ops_infer.so.8.0.5
/usr/lib/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.4
[pip3] pytorch-sphinx-theme==0.0.24
[pip3] torch==1.8.0a0
[pip3] torchvision==0.8.2
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @Varal7 @malfet @seemethere @walterddr @ngimel @csarofeen @ptrblck @xwang233 @VitalyFedyunin

zasdfgbnm added the module: cuda, module: cudnn, module: performance, and module: build labels Jan 6, 2021
malfet added the triaged label Jan 6, 2021
@malfet (Contributor) commented Jan 6, 2021

@zasdfgbnm Is it only true for fp16 on Ampere, or for other tensor types/architectures as well?

@malfet (Contributor) commented Jan 6, 2021

I do not see cutlass kernels invoked when running on 2080 with shared libcudnn.so.7.5.0:

$ ldd -r /usr/local/lib64/python3.7/site-packages/torch/lib/libtorch_cuda.so|grep cudnn; sudo nvprof python /tmp/cudnn.py 
	libcudnn.so.7 => /lib64/libcudnn.so.7 (0x00007f64ba80c000)
==30579== NVPROF is profiling process 30579, command: python /tmp/cudnn.py
==30579== Profiling application: python /tmp/cudnn.py
==30579== Profiling result:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   22.76%  23.449ms       200  117.25us  116.00us  119.46us  volta_s884cudnn_fp16_256x128_ldg8_dgrad_exp_interior_nhwc_tt_v1
                   16.05%  16.532ms       200  82.658us  81.632us  84.191us  turing_fp16_s1688cudnn_fp16_256x64_ldg8_relu_f2f_exp_small_nhwc_tn_v1
                   15.01%  15.467ms       800  19.334us  2.4960us  29.536us  void nchwToNhwcKernel<__half, __half, float, bool=1, bool=1>(int, int, int, int, __half const *, __half*, float, float)
                   13.71%  14.122ms       200  70.610us  69.504us  73.376us  turing_s1688cudnn_fp16_128x128_ldg8_wgrad_idx_exp_interior_nhwc_nt_v1
                    7.35%  7.5691ms       400  18.922us  3.8080us  34.272us  void nhwcToNchwKernel<float, __half, float, bool=1, bool=1>(int, int, int, int, float const *, __half*, float, float)
                    6.96%  7.1695ms       200  35.847us  34.848us  36.799us  void foldedNchwToNchwKernel<__half, __half, float, bool=1>(int, int, int, int, int, int, int, __half const *, __half*, int, int, int, int, int, int, int, int, int, int, int, float, float, cudnn::reduced_divisor, cudnn::reduced_divisor, cudnn::reduced_divisor, cudnn::reduced_divisor, cudnn::reduced_divisor)
                    6.57%  6.7711ms       597  11.341us  1.3440us  29.792us  void at::native::vectorized_elementwise_kernel<int=4, at::native::AddFunctor<c10::Half>, at::detail::Array<char*, int=3>>(int, c10::Half, at::native::AddFunctor<c10::Half>)
                    2.97%  3.0633ms       200  15.316us  14.240us  16.832us  _ZN2at6native13reduce_kernelILi512ELi1ENS0_8ReduceOpIN3c104HalfENS0_14func_wrapper_tIS4_ZNS0_11sum_functorIS4_fS4_EclERNS_14TensorIteratorEEUlffE_EEjS4_Li4EEEEEvT1_
                    1.61%  1.6631ms       400  4.1570us  3.7440us  4.9600us  void nchwToNhwcKernel<__half, __half, float, bool=1, bool=0>(int, int, int, int, __half const *, __half*, float, float)
                    1.31%  1.3519ms       200  6.7590us  6.6550us  6.9440us  void at::native::unrolled_elementwise_kernel<at::native::AddFunctor<c10::Half>, at::detail::Array<char*, int=3>, OffsetCalculator<int=2, unsigned int>, OffsetCalculator<int=1, unsigned int>, at::native::memory::LoadWithoutCast, at::native::memory::StoreWithoutCast>(int, c10::Half, at::native::AddFunctor<c10::Half>, char*, int=3, at::detail::Array<char*, int=3>, int=2)
                    1.27%  1.3046ms      1608     811ns     672ns  1.6640us  [CUDA memset]
                    1.26%  1.2989ms       200  6.4940us  6.3360us  6.8160us  void nchwToFoldedNchwKernel<__half, __half, float, bool=1>(int, int, int, int, __half const *, __half*, int, int, int, int, int, int, int, int, int, int, int, float, float, cudnn::reduced_divisor, cudnn::reduced_divisor, cudnn::reduced_divisor, cudnn::reduced_divisor, cudnn::reduced_divisor)
                    0.65%  674.08us       200  3.3700us  3.2000us  10.592us  void nhwcToNchwKernel<__half, __half, float, bool=1, bool=0>(int, int, int, int, __half const *, __half*, float, float)
                    0.60%  616.22us       400  1.5400us  1.5030us  1.6960us  cudnn::gemm::computeOffsetsKernel(cudnn::gemm::ComputeOffsetsParams)
                    0.54%  552.80us       200  2.7640us  2.7200us  2.9120us  void nchwAddPaddingKernel<__half, __half, float, bool=1>(int, int, int, int, int, int, int, int, __half const *, __half*, int, int, int, int, int, float, float, cudnn::reduced_divisor, cudnn::reduced_divisor, cudnn::reduced_divisor)
                    0.39%  403.20us       102  3.9520us  3.7120us  22.368us  _ZN2at6native77_GLOBAL__N__53_tmpxft_0000540f_00000000_6_DistributionNormal_cpp1_ii_7d80e40543distribution_elementwise_grid_stride_kernelIfLi4EZNS0_9templates4cuda20normal_and_transformIN3c104HalfEfLm4EPNS_17CUDAGeneratorImplEZZZNS4_13normal_kernelIS9_EEvRNS_6TensorEddT_ENKUlvE_clEvENKUlvE4_clEvEUlfE_EEvRNS_14TensorIteratorET2_T3_EUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIS7_fLi4ES9_SN_SG_EEvSI_SJ_RKSK_T4_EUlifE_EEviNS_15PhiloxCudaStateET1_SJ_
                    0.39%  401.98us       200  2.0090us  1.8880us  2.1440us  cudnn::gemm::computeWgradOffsetsKernel(cudnn::gemm::ComputeOffsetsParams)
                    0.33%  341.86us       200  1.7090us  1.6320us  1.8560us  void scalePackedTensor_kernel<float, float>(cudnnTensor4dStruct, float*, float)
                    0.24%  243.64us       200  1.2180us  1.1840us  1.3120us  cudnn::gemm::computeBOffsetsKernel(cudnn::gemm::ComputeBOffsetsParams)
                    0.03%  29.152us         4  7.2880us     736ns  26.688us  [CUDA memcpy HtoD]
                    0.01%  8.4160us         2  4.2080us  1.5360us  6.8800us  _ZN2at6native27unrolled_elementwise_kernelIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE0_clEvENKUlvE18_clEvEUlN3c104HalfEE_NS_6detail5ArrayIPcLi2EEE23TrivialOffsetCalculatorILi1EjESE_NS0_6memory12LoadWithCastILi1EEENSF_13StoreWithCastEEEviT_T0_T1_T2_T3_T4_
      API calls:   59.82%  2.31092s        17  135.94ms  6.2100us  2.31004s  cudaMalloc
                   37.21%  1.43739s        16  89.837ms  1.3280us  1.43732s  cudaStreamCreateWithFlags
                    1.61%  62.140ms      5101  12.181us  3.9240us  423.50us  cudaLaunchKernel
                    0.64%  24.742ms      1608  15.386us  3.2920us  68.570us  cudaMemsetAsync
                    0.34%  13.132ms     16583     791ns     197ns  894.64us  cudaGetDevice
                    0.08%  3.2837ms       600  5.4720us  1.8020us  23.735us  cudaFuncGetAttributes
                    0.07%  2.7987ms         5  559.73us  2.5070us  2.7492ms  cudaDeviceSynchronize
                    0.05%  2.0204ms      5703     354ns      77ns  17.001us  cudaGetLastError
                    0.05%  1.8670ms       600  3.1110us     809ns  18.436us  cudaEventRecord
                    0.05%  1.7425ms      2458     708ns     198ns  18.551us  cudaDeviceGetAttribute
                    0.02%  776.88us       600  1.2940us     335ns  2.8390us  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags
                    0.02%  582.60us         2  291.30us  53.778us  528.83us  cudaHostAlloc
                    0.01%  476.64us         3  158.88us  125.31us  177.84us  cuDeviceTotalMem
                    0.01%  397.86us       285  1.3960us      94ns  62.333us  cuDeviceGetAttribute
                    0.01%  293.13us       340     862ns     254ns  3.9730us  cudaFuncSetAttribute
                    0.01%  251.09us         2  125.54us  124.85us  126.24us  cudaGetDeviceProperties
                    0.01%  215.19us         8  26.898us  1.2450us  192.15us  cudaStreamCreateWithPriority
                    0.00%  63.562us         2  31.781us  9.6760us  53.886us  cudaMemcpyAsync
                    0.00%  60.175us         3  20.058us  13.462us  23.445us  cuDeviceGetName
                    0.00%  55.437us        56     989ns     298ns  6.9090us  cudaEventCreateWithFlags
                    0.00%  50.027us         2  25.013us  17.503us  32.524us  cudaMemcpy
                    0.00%  34.373us         2  17.186us  3.8210us  30.552us  cudaStreamSynchronize
                    0.00%  7.3260us         8     915ns     261ns  1.7410us  cudaFree
                    0.00%  3.8780us         2  1.9390us  1.5740us  2.3040us  cuInit
                    0.00%  3.8350us         1  3.8350us  3.8350us  3.8350us  cuDeviceGetPCIBusId
                    0.00%  3.7790us         2  1.8890us  1.2210us  2.5580us  cudaHostGetDevicePointer
                    0.00%  3.0750us         2  1.5370us  1.1440us  1.9310us  cudaDeviceGetStreamPriorityRange
                    0.00%  2.2640us         2  1.1320us     561ns  1.7030us  cudaSetDevice
                    0.00%  1.8610us         6     310ns     106ns     652ns  cudaGetDeviceCount
                    0.00%  1.6420us         5     328ns     116ns     935ns  cuDeviceGetCount
                    0.00%     970ns         4     242ns     125ns     460ns  cuDeviceGet
                    0.00%     554ns         3     184ns     151ns     219ns  cuDeviceGetUuid
                    0.00%     465ns         2     232ns     219ns     246ns  cuDriverGetVersion

@ptrblck (Collaborator) commented Jan 7, 2021

@malfet Thanks for the quick test!
We are currently trying to scope this issue and are investigating with the cuDNN team why these kernels are dropped from the statically built library.

@zasdfgbnm (Collaborator, Author) commented:

Update: static linking against cuDNN should use --whole-archive; otherwise kernels will be missing.
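
For context, --whole-archive makes the linker keep every object file from the archive, instead of only the objects that resolve currently undefined symbols. A minimal sketch of what that looks like on a raw link line (the paths and the rest of the command are illustrative, not PyTorch's actual link command):

$ g++ -shared -o libtorch_cuda.so <torch object files ...> \
      -Wl,--whole-archive /usr/lib/libcudnn_static.a -Wl,--no-whole-archive \
      <other libraries ...>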

@malfet (Contributor) commented Jan 13, 2021

@zasdfgbnm wouldn't that lead to a 50% increase in the binary size? Can we simply reference one symbol that instantiates the cutlass kernels from torch_cuda?

@zasdfgbnm (Collaborator, Author) commented:

@malfet Do you know how to tell cmake to replace its /usr/lib/libcudnn_static.a with --whole-archive /usr/lib/libcudnn_static.a --no-whole-archive only on the ld command for libtorch_cuda.so? I tried a few things and never succeeded. With cuDNN 8.1, the size increase is 400MB. We may or may not want to ship pip wheels built with whole-archive, but it is at least worth adding an environment variable TORCH_CUDNN_WHOLE_ARCHIVE that enables this behavior.

@kehuanfeng commented Mar 11, 2021

@zasdfgbnm @malfet We are facing this issue too.

Is there any progress on how to address this?

Additionally, are the missing CUDA kernels when statically linking cuDNN a result of normal linker behavior, or some kind of fault in NVIDIA's cuDNN?

@rwightman commented:

So, does this issue explain why NGC containers have been consistently faster than the official conda builds for a number of PyTorch versions now? NGC == dynamic link, conda/pip == static with this issue? This has a pretty significant impact if that is the case.

I ran some benchmarks trying to figure out what was happening as I've kept bumping into it with new releases... https://gist.github.com/rwightman/bb59f9e245162cee0e38bd66bd8cd77f

@ptrblck (Collaborator) commented Mar 28, 2021

@rwightman (Mostly) for Turing and Ampere these kernels would be missing in the binaries due to this issue, which would explain the performance difference between the binaries and a source build. Besides that, the cuDNN (and CUDA, cuBLAS, etc.) versions could also differ, which might give additional performance gains or potential regressions.

@malfet (Contributor) commented Jun 9, 2021

The following gist reproduces the problem with static linking. Whole-library linking, or extracting cudnn_static.a and explicitly specifying all of the objects inside it, fixes the problem.
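
For reference, a sketch of the second workaround described above, extracting the archive and passing every object file to the link explicitly (paths are illustrative; the real libtorch_cuda.so link line has many more inputs):

$ mkdir cudnn_objs
$ (cd cudnn_objs && ar x /usr/lib/libcudnn_static.a)   # unpack every object file from the archive
$ g++ -shared -o libtorch_cuda.so <torch object files ...> cudnn_objs/*.o <other libraries ...>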

facebook-github-bot pushed a commit that referenced this issue Jun 9, 2021
Summary:
This is only important for builds where cuDNN is linked statically into libtorch_cpu.
Before this PR PyTorch wheels often accidentally contained several partial copies of cudnn_static library.
Splitting the interface into header only (cudnn-public) and library+headers(cudnn-private) prevents those from happening.
Preliminary step towards enabling optional linking whole cudnn_library to workaround issue reported in #50153

Pull Request resolved: #59721

Reviewed By: ngimel

Differential Revision: D29000967

Pulled By: malfet

fbshipit-source-id: f054df92b265e9494076ab16c247427b39da9336
facebook-github-bot pushed a commit that referenced this issue Jun 10, 2021
Summary:
It is only enabled if USE_STATIC_CUDNN is enabled

Next step after #59721 towards resolving fast kernels stripping reported in #50153

Pull Request resolved: #59744

Reviewed By: seemethere, ngimel

Differential Revision: D29007314

Pulled By: malfet

fbshipit-source-id: 7091e299c0c6cc2a8aa82fbf49312cecf3bb861a
malfet added a commit to malfet/pytorch that referenced this issue Jun 11, 2021
Summary:
Fixes pytorch#50153

Pull Request resolved: pytorch#59802

Reviewed By: driazati, seemethere

Differential Revision: D29033537

Pulled By: malfet

fbshipit-source-id: e816fc71f273ae0b4ba8a0621d5368a2078561a1
malfet added a commit that referenced this issue Jun 11, 2021
* Move cublas dependency after CuDNN (#58287)

Summary:
Library linking order matters during static linking
Not sure whether its a bug or a feature, but if cublas is reference
before CuDNN, it will be partially statically linked into the library,
even if it is not used

Pull Request resolved: #58287

Reviewed By: janeyx99

Differential Revision: D28433165

Pulled By: malfet

fbshipit-source-id: 8dffa0533075126dc383428f838f7d048074205c

* [CMake] Split caffe2::cudnn into public and private (#59721)

Summary:
This is only important for builds where cuDNN is linked statically into libtorch_cpu.
Before this PR PyTorch wheels often accidentally contained several partial copies of cudnn_static library.
Splitting the interface into header only (cudnn-public) and library+headers(cudnn-private) prevents those from happening.
Preliminary step towards enabling optional linking whole cudnn_library to workaround issue reported in #50153

Pull Request resolved: #59721

Reviewed By: ngimel

Differential Revision: D29000967

Pulled By: malfet

fbshipit-source-id: f054df92b265e9494076ab16c247427b39da9336

* Add USE_WHOLE_CUDNN option (#59744)

Summary:
It is only enabled if USE_STATIC_CUDNN is enabled

Next step after #59721 towards resolving fast kernels stripping reported in #50153

Pull Request resolved: #59744

Reviewed By: seemethere, ngimel

Differential Revision: D29007314

Pulled By: malfet

fbshipit-source-id: 7091e299c0c6cc2a8aa82fbf49312cecf3bb861a

* [Binary] Link whole CuDNN for CUDA-11.1 (#59802)

Summary:
Fixes #50153

Pull Request resolved: #59802

Reviewed By: driazati, seemethere

Differential Revision: D29033537

Pulled By: malfet

fbshipit-source-id: e816fc71f273ae0b4ba8a0621d5368a2078561a1
@jramapuram commented:

@malfet Is statically linking still necessary on PyTorch 1.12.1 with CUDA 11.6+ to get full A100 performance? The option to statically link is currently broken, as described in #81692.
