
PyTorch / libtorch executables fail when built against libcuda stub library #35418

Open

hartb opened this issue Mar 25, 2020 · 8 comments · May be fixed by #122318

Labels
high priority · module: binaries (Anything related to official binaries that we release to users) · module: build (Build system issues) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@hartb
Contributor

hartb commented Mar 25, 2020

🐛 Bug

PyTorch's cmake setup for CUDA includes a couple of stubs/ directories in the search path for libcuda.so:

https://github.com/pytorch/pytorch/blob/v1.4.0/cmake/public/cuda.cmake#L179
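
Paraphrasing that setup (a sketch, not the verbatim v1.4.0 source; the exact suffix list may differ):

# Sketch of the lookup in cmake/public/cuda.cmake (paraphrased, not verbatim).
# The stubs/ suffixes let the link succeed on GPU-less build hosts where only
# the CUDA Toolkit is installed.
find_library(CUDA_CUDA_LIB cuda
    PATHS ${CUDA_TOOLKIT_ROOT_DIR}
    PATH_SUFFIXES lib lib64 lib/x64 lib64/stubs lib/x64/stubs)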

Searching for libcuda in the stubs/ directories is needed in various cases, for example when:

  • the build host is GPU-less and has the CUDA Toolkit but not the GPU driver installed, or
  • the build toolchain doesn't search the system paths that would normally house libcuda (likely when the Anaconda toolchain is used to build PyTorch).

However, PyTorch's build setup tells cmake to add any non-PyTorch-local libraries to built objects' RPATH:

https://github.com/pytorch/pytorch/blob/v1.4.0/cmake/Dependencies.cmake#L14
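
The directive in question is CMAKE_INSTALL_RPATH_USE_LINK_PATH (named in the commit message quoted later in this thread); in essence:

# From cmake/Dependencies.cmake (see link above), in essence: copy every
# directory that appears on the link line into installed objects' RPATH.
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)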

So the few objects that link directly to libcuda (e.g. lib/libcaffe2_nvrtc.so and bin/test_dist_autograd) may end up with the libcuda stubs/ directory included in RPATH.

During loading, RPATH takes priority over the normal system search paths and even LD_LIBRARY_PATH. So these objects will prefer the libcuda stub even on hosts where the real libcuda is present in the normal location.

The same cmake setup is re-used for libtorch, so the same will occur for libtorch apps as well:

https://discuss.pytorch.org/t/torch-is-available-returns-false/73753

The libcuda stub is not a functional CUDA runtime, so PyTorch and libtorch objects affected by this won't be able to use CUDA even if libcuda is installed in the normal system search path.

The change below would prevent a libcuda stubs directory from being added to objects' RPATH, but it doesn't seem like a great solution: it abuses a cmake-internal mechanism to trick cmake into skipping the stubs directory.

diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake
index a5c50b90df..3a89c9e8ad 100644
--- a/cmake/public/cuda.cmake
+++ b/cmake/public/cuda.cmake
@@ -187,6 +187,20 @@ find_library(CUDA_NVRTC_LIB nvrtc
     PATHS ${CUDA_TOOLKIT_ROOT_DIR}
     PATH_SUFFIXES lib lib64 lib/x64)

+# Configuration elsewhere (CMAKE_INSTALL_RPATH_USE_LINK_PATH) arranges that any
+# library used during linking will be added to the objects' RPATHs. That's
+# never correct for a libcuda stub library (the stub library is suitable for
+# linking, but not for runtime, and if present in RPATH will be prioritized
+# ahead of system search path). If libcuda was found in the "stubs" directory,
+# abuse CMAKE_PLATFORM_IMPLICIT_LINK_DIRECTORIES to instruct cmake NOT to
+# include it in RPATH. This leaves us without a known path to libcuda in the
+# objects, but there's no help for that if the only libcuda visible at build
+# time is the stub.
+if(CUDA_CUDA_LIB MATCHES "stubs")
+  get_filename_component(LIBCUDA_STUB_DIR ${CUDA_CUDA_LIB} DIRECTORY)
+  list(APPEND CMAKE_PLATFORM_IMPLICIT_LINK_DIRECTORIES ${LIBCUDA_STUB_DIR})
+endif()
+
 # Create new style imported libraries.
 # Several of these libraries have a hardcoded path if CAFFE2_STATIC_LINK_CUDA
 # is set. This path is where sane CUDA installations have their static

I hope someone more familiar with cmake and the general pytorch build setup might have a better idea.

To Reproduce

Steps to reproduce the behavior:

  1. Build PyTorch on a system with the CUDA Toolkit but no GPU driver installed
  2. objdump -p <install_dir>/lib/libcaffe2_nvrtc.so | grep RPATH

Expected behavior

PyTorch and libtorch objects should build successfully against libcuda stub, and then run successfully against "real" libcuda (at least when libcuda is present in a default system search path).

Environment

  • PyTorch Version (e.g., 1.0): 1.4.0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): source
  • Build command you used (if compiling from source): python setup.py install
  • Python version: Python 3.6.9 :: Anaconda, Inc.
  • CUDA/cuDNN version: 10.2 / 7.6.5
  • GPU models and configuration: V100
  • Any other relevant information: Using Anaconda toolchain to build

cc @ezyang @gchanan @zou3519 @seemethere

@ngimel ngimel added the module: build Build system issues label Mar 25, 2020
@ngimel
Collaborator

ngimel commented Mar 25, 2020

Thank you for reporting! Building on a machine without a GPU is how CI and the official binaries are typically built, so it is strange that this is not working.
cc @seemethere

@ngimel ngimel added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 25, 2020
@hartb
Contributor Author

hartb commented Mar 26, 2020

It might be an easy problem to miss in CI / test. I can only find 4 test binaries (test_api, test_dist_autograd, test_jit, and torch_shm_manager) and one library (libcaffe2_nvrtc.so) that end up with the stubs/ path hard-coded.

The test binaries tolerate the CUDA failure caused by the stub library, and run to successful completion with just an info message that the CUDA tests are skipped. test_api, for example:

CUDA not available. Disabling CUDA and MultiCUDA tests
Note: Google Test filter = *-*_CUDA:*_MultiCUDA
[==========] Running 744 tests from 31 test cases.
...
[==========] 744 tests from 31 test cases ran. (56442 ms total)
[  PASSED  ] 744 tests.

The libcaffe2_nvrtc.so library case is more interesting... That library is only ever dynamically loaded (i.e. dlopen()'d, not "normally" dynamically linked) by libtorch.so. But libtorch depends on libcudart, which dynamically loads libcuda itself. So I think anything that would use libcaffe2_nvrtc will, in practice, end up with libcuda already loaded before libcaffe2_nvrtc's RPATH has any force. So stuff like:

import torch

x = torch.randn(1, device='cuda:0')
print(torch._C._cuda_hasPrimaryContext(0))

will work correctly despite the stubs/ RPATH in libcaffe2_nvrtc.

So I think the most visible symptom of this will be for binaries built against libtorch, where the bad RPATH entry is incorporated into the new binary itself and so will be in force before libtorch is loaded.

In general, though, the libcuda in stubs/ isn't functional for runtime, so I think the risk of ever having stubs/ in RPATH outweighs the benefits.
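
A hedged sketch of a possible downstream workaround (untested; the target name and main.cpp are placeholders): linking the libtorch app with new-style dtags emits RUNPATH rather than RPATH, and RUNPATH is consulted after LD_LIBRARY_PATH, so a user can point LD_LIBRARY_PATH at the real libcuda:

# Hypothetical downstream workaround (untested sketch; "app" and main.cpp
# are placeholders). --enable-new-dtags makes the linker emit RUNPATH
# instead of RPATH; RUNPATH is searched *after* LD_LIBRARY_PATH, so a user
# can override a baked-in stubs/ entry with e.g. LD_LIBRARY_PATH=/usr/lib64.
cmake_minimum_required(VERSION 3.13)
project(libtorch_app CXX)

find_package(Torch REQUIRED)

add_executable(app main.cpp)
target_link_libraries(app PRIVATE "${TORCH_LIBRARIES}")
target_link_options(app PRIVATE "LINKER:--enable-new-dtags")

This doesn't remove the stubs/ entry; without LD_LIBRARY_PATH set, the stub still shadows the default system paths.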

@ngimel ngimel added the module: binaries Anything related to official binaries that we release to users label Mar 26, 2020
@ezyang
Contributor

ezyang commented Mar 26, 2020

Yeah, this seems like a major problem with the way we're doing builds right now.

@mattip
Collaborator

mattip commented Apr 7, 2020

When I build on a system with no cuda driver but with the cuda toolkit installed, I see a difference between libcaffe2_nvrtc.so and the various binaries. The .so has (my cmake emits RUNPATH, and the cmake install step adds $ORIGIN):

./torch/lib/libcaffe2_nvrtc.so:
        RUNPATH  $ORIGIN:/usr/local/cuda-10.2/lib64/stubs:/usr/local/cuda-10.2/lib64

where the binaries change the order (only showing one; all the files in ./torch/bin are the same):

./torch/bin/test_api:
        RUNPATH  $ORIGIN:/usr/local/cuda-10.2/lib64:/usr/local/cuda-10.2/lib64/stubs

I wonder why the order changes. If I understand correctly, these executables should find the real driver before the fake stub one because of the search order.

@hartb
Contributor Author

hartb commented Apr 7, 2020

I don't have an explanation for the ordering difference there, but one observation...

When installed from rpm (at least), the CUDA Toolkit libraries will live in /usr/local/cuda-xxx, but the CUDA driver interface library (libcuda.so) will only be found in the stubs/ directory and in /usr/lib64/.

So I think the 2nd case, with the main cuda-xxx/lib64 directory ahead of stubs/, doesn't actually help with libcuda.so.

xwang233 pushed a commit to xwang233/pytorch that referenced this issue Jun 20, 2020

Summary:
Closes pytorchgh-35418.

PR pytorchgh-16414 added [the `CMAKE_INSTALL_RPATH_USE_LINK_PATH` directive](https://github.com/pytorch/pytorch/pull/16414/files#diff-dcf5891602b4162c36c2125c806639c5R16), which is non-standard and will cause CMake to write an `RPATH` entry for libraries outside the current build. Removing it leaves an RPATH entry for `$ORIGIN` but removes the entries for things like `/usr/local/cuda-10.2/lib64/stubs:/usr/local/cuda-10.2/lib64` for `libcaffe2_nvrtc.so` on linux.

The added test fails before this PR and passes after. It is equivalent to checking `objdump -p torch/lib/libcaffe2_nvrtc.so | grep RPATH` for an external path to the directory where cuda "lives".

I am not sure if it solves the `rpath/libc++.1.dylib` problem for `_C.cpython-37m-darwin.so` on macOS in issue pytorchgh-36941.
Pull Request resolved: pytorch#37737

Differential Revision: D22068657

Pulled By: ezyang

fbshipit-source-id: b04c529572a94363855f1e4dd3e93c9db3c85657
@zou3519
Contributor

zou3519 commented Jun 30, 2020

Reopening because #37737 was reverted

@zou3519 zou3519 reopened this Jun 30, 2020
@mattip
Collaborator

mattip commented Jun 30, 2020

In #37737 I tried to remove RPATH linking outside of the pytorch directory. This broke conda builds, since they depend on RPATH linking to conda-env/lib, where all the mkl, cuda, and other DLLs live. Perhaps there is a way to be more selective when telling CMake which directories to insert into RPATH: LINK_PATH is too broad. One possible shape for that is sketched below.
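
A minimal hedged sketch of a more selective approach (assuming a conda environment is detectable via CONDA_PREFIX; not what the tree does today):

# Sketch: instead of CMAKE_INSTALL_RPATH_USE_LINK_PATH, name the install
# RPATH entries explicitly so a stubs/ directory can never leak in.
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
set(CMAKE_INSTALL_RPATH "$ORIGIN")
if(DEFINED ENV{CONDA_PREFIX})
  # Keep conda builds working: mkl, cuda, and friends live in conda-env/lib.
  list(APPEND CMAKE_INSTALL_RPATH "$ENV{CONDA_PREFIX}/lib")
endif()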

@ezyang
Contributor

ezyang commented Jul 1, 2020

Actually, this RPATH linking also affects me in #40829, since it means that the shm manager is not relocatable.

I think conda is supposed to be able to identify situations like this and automatically update the rpath in executables when it installs conda packages; cf. https://docs.conda.io/projects/conda-build/en/latest/resources/define-metadata.html#detect-binary-files-with-prefix. So maybe we just need to figure out how to make this play ball.
