Workaround for CuDNN-8.7+ load bug #98644

Closed · wants to merge 1 commit

Conversation

malfet (Contributor) commented Apr 7, 2023

Preload cudnn_cnn_infer and consume dlerror to prevent a spurious call to abort() from libcudnn.so.8 if libnvrtc.so is missing on the system.

Fixes #97041
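
For context, a minimal sketch of the idea behind the workaround (illustrative only, not the exact code of this PR; the actual change is in aten/src/ATen/native/cudnn/Conv_v8.cpp, as the warning in the test output below shows): dlopen the cudnn_cnn_infer sub-library before cuDNN tries to lazy-load it, then call dlerror() so that no pending error is left for libcudnn.so.8 to abort() on.

// Minimal sketch, assuming only the standard dlfcn API; the function name and logging are illustrative.
#include <dlfcn.h>
#include <iostream>

static void preload_cudnn_cnn_infer() {
  // Map the sub-library that libcudnn.so.8 would otherwise lazy-load on the first convolution.
  void* handle = dlopen("libcudnn_cnn_infer.so.8", RTLD_LAZY);
  if (!handle) {
    // The load can legitimately fail, e.g. when libnvrtc.so is absent;
    // report it instead of letting cuDNN abort() later.
    const char* err = dlerror();
    std::cerr << "cudnn_cnn_infer preload failed: "
              << (err ? err : "unknown error") << std::endl;
  }
  // dlerror() returns the last error and clears the thread-local error state;
  // this is the "consume dlerror" part of the workaround.
  dlerror();
}

Whether the preload itself succeeds matters less than leaving no stale dlerror() state behind for cuDNN's own loader to trip over.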

pytorch-bot bot commented Apr 7, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98644

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e75d472:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

malfet added this to the 2.0.1 milestone Apr 7, 2023
malfet added the topic: bug fixes, release notes: cudnn, and ciflow/binaries_wheel labels Apr 7, 2023
malfet requested a review from ngimel April 9, 2023 17:59
ptrblck (Collaborator) commented Apr 10, 2023

I've tried to verify the fix using the pip wheels built at https://gha-artifacts.s3.amazonaws.com/pytorch/pytorch/4642010532/linux-bionic-cuda11.8-py3.10-gcc7-sm86/artifacts.zip, inside nvidia/cuda:12.0.1-cudnn8-runtime-ubuntu22.04, which reproduces the originally reported error, using:

python3 -c "import torch; print(torch.__config__.show());conv=torch.nn.Conv2d(3,3,3).cuda(); out=conv(torch.rand(1, 3, 24, 24, device='cuda'))"

This now fails with:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 228, in <module>
    _load_global_deps()
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 187, in _load_global_deps
    raise err
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 168, in _load_global_deps
    ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
  File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: libmpi_cxx.so.20: cannot open shared object file: No such file or directory

Indeed, MPI is linked as a dependency:

ldd /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_global_deps.so 
	linux-vdso.so.1 (0x00007ffe9bfa5000)
	libmpi_cxx.so.20 => not found
	libmpi.so.20 => not found
	libmkl_intel_lp64.so.1 => not found
	libmkl_gnu_thread.so.1 => not found
	libmkl_core.so.1 => not found
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f58330df000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007f5832d19000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f58330da000)
	libcufft.so.10 => not found
	libcurand.so.10 => /usr/local/cuda/lib64/libcurand.so.10 (0x00007f582c800000)
	libcublas.so.11 => not found
	libcublasLt.so.11 => not found
	libcudart.so.11.0 => not found
	libnvToolsExt.so.1 => /usr/local/cuda/lib64/libnvToolsExt.so.1 (0x00007f582c400000)
	libgomp.so.1 => /lib/x86_64-linux-gnu/libgomp.so.1 (0x00007f583308e000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f582c1d8000)
	/lib64/ld-linux-x86-64.so.2 (0x00007f58330ec000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007f5833087000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007f5833067000)
	libstdc++.so.6 => /lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007f582bfae000)

However, given that the binaries are only ~267 MB and other libs are missing, I think my workflow to verify the fix might be wrong, or I'm using the wrong artifact binaries.

malfet (Contributor, Author) commented Apr 10, 2023

Downloaded the build artifact with gh run download 4642013900 --name "manywheel-py3_10-cuda11_8", copied it into an ubuntu:22.04 container, and installed it as follows:

# pip install torch-2.1.0.dev20230407+cu118-cp310-cp310-linux_x86_64.whl --index-url https://download.pytorch.org/whl/nightly/cu118
Looking in indexes: https://download.pytorch.org/whl/nightly/cu118
Processing ./torch-2.1.0.dev20230407+cu118-cp310-cp310-linux_x86_64.whl
Collecting typing-extensions
  Downloading https://download.pytorch.org/whl/nightly/typing_extensions-4.4.0-py3-none-any.whl (26 kB)
Collecting filelock
  Downloading https://download.pytorch.org/whl/nightly/filelock-3.9.0-py3-none-any.whl (9.7 kB)
Collecting pytorch-triton==2.1.0+46672772b4
  Downloading https://download.pytorch.org/whl/nightly/pytorch_triton-2.1.0%2B46672772b4-cp310-cp310-linux_x86_64.whl (87.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 87.2/87.2 MB 24.8 MB/s eta 0:00:00
Collecting jinja2
  Downloading https://download.pytorch.org/whl/nightly/Jinja2-3.1.2-py3-none-any.whl (133 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 133.1/133.1 KB 25.8 MB/s eta 0:00:00
Collecting sympy
  Downloading https://download.pytorch.org/whl/nightly/sympy-1.11.1-py3-none-any.whl (6.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.5/6.5 MB 100.7 MB/s eta 0:00:00
Collecting networkx
  Downloading https://download.pytorch.org/whl/nightly/networkx-2.6.3-py3-none-any.whl (1.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.9/1.9 MB 91.3 MB/s eta 0:00:00
Collecting MarkupSafe>=2.0
  Downloading https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (25 kB)
Collecting mpmath>=0.19
  Downloading https://download.pytorch.org/whl/nightly/mpmath-1.2.1-py3-none-any.whl (532 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 532.6/532.6 KB 65.6 MB/s eta 0:00:00
Installing collected packages: mpmath, typing-extensions, sympy, networkx, MarkupSafe, filelock, pytorch-triton, jinja2, torch
Successfully installed MarkupSafe-2.1.2 filelock-3.9.0 jinja2-3.1.2 mpmath-1.2.1 networkx-2.6.3 pytorch-triton-2.1.0+46672772b4 sympy-1.11.1 torch-2.1.0.dev20230407+cu118 typing-extensions-4.4.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

And tested using the standard script:

root@e259dbb7ce8a:~# python3 -c "import torch; print(torch.__config__.show());conv=torch.nn.Conv2d(3,3,3).cuda(); out=conv(torch.rand(1, 3, 24, 24, device='cuda'))"
PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.8
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.7
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py:137: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  self.weight = Parameter(torch.empty(
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py:459: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)
  return F.conv2d(input, weight, bias, self.stride,

ptrblck (Collaborator) commented Apr 10, 2023

Thanks for explaining the gh CLI workflow.
I have also verified this change using the ResNet code snippet from #97041 (comment) and see:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py:459: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)
  return F.conv2d(input, weight, bias, self.stride,

malfet (Contributor, Author) commented Apr 10, 2023

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk label Apr 10, 2023
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

malfet deleted the malfet/workaround-for-cudnn-load-bug branch April 11, 2023 03:16
malfet added a commit that referenced this pull request Apr 18, 2023
Preload `cudnn_cnn_infer` and consume `dlerror` to prevent spurious call to `abort()` from `libcudnn.so.8`, if `libnvrtc.so` is missing on the system.

Fixes #97041

Pull Request resolved: #98644
Approved by: https://github.com/ngimel

(cherry picked from commit c00fd71)
atalman pushed a commit that referenced this pull request Apr 19, 2023
Preload `cudnn_cnn_infer` and consume `dlerror` to prevent spurious call to `abort()` from `libcudnn.so.8`, if `libnvrtc.so` is missing on the system.

Fixes #97041

Pull Request resolved: #98644
Approved by: https://github.com/ngimel

(cherry picked from commit c00fd71)
ZainRizvi pushed a commit that referenced this pull request Apr 19, 2023
Preload `cudnn_cnn_infer` and consume `dlerror` to prevent spurious call to `abort()` from `libcudnn.so.8`, if `libnvrtc.so` is missing on the system.

Fixes #97041

Pull Request resolved: #98644
Approved by: https://github.com/ngimel
Labels: ciflow/binaries_wheel, ciflow/trunk, Merged, merging, release notes: cudnn, topic: bug fixes

Successfully merging this pull request may close these issues.

Convolutions are broken for PyTorch-2.0 CUDA-11.8 wheel builds