PyTorch / libtorch executables fail when built against libcuda stub library #35418
Comments
Thank you for reporting! Building on a machine without a GPU is the typical way CI and binaries are built, so it is strange that this does not work.
It might be an easy problem to miss in CI / test. I can only find 4 test binaries ( The test binaries tolerate the CUDA failure caused by the stub library, and run to successful completion with just an info message that the CUDA tests are skipped.
The
Will work correctly despite the So I think the most visible symptom of this will be for binaries built against In general, though, the
Yeah, this seems like a major problem with the way we're doing builds right now.
When I build on a system with no cuda drivers but with the cuda toolkit installed, I see a difference between
where the binaries change the order (only showing one, all the files in
I wonder why the order changes. If I understand correctly, these executables should find the real drivers before the fake stub because of the search order.
I don't have an explanation for the ordering difference there, but one observation... When installed from So I think the 2nd case, with the main
Summary: Closes pytorchgh-35418. PR pytorchgh-16414 added [the `CMAKE_INSTALL_RPATH_USE_LINK_PATH` directive](https://github.com/pytorch/pytorch/pull/16414/files#diff-dcf5891602b4162c36c2125c806639c5R16), which is non-standard and causes CMake to write an `RPATH` entry for libraries outside the current build. Removing it leaves an RPATH entry for `$ORIGIN` but removes the entries for things like `/usr/local/cuda-10.2/lib64/stubs:/usr/local/cuda-10.2/lib64` for `libcaffe2_nvrtc.so` on Linux.
The added test fails before this PR and passes after. It is equivalent to checking `objdump -p torch/lib/libcaffe2_nvrtc.so | grep RPATH` for an external path to the directory where CUDA "lives".
I am not sure whether it solves the `rpath/libc++.1.dylib` problem for `_C.cpython-37m-darwin.so` on macOS in issue pytorchgh-36941.
Pull Request resolved: pytorch#37737
Differential Revision: D22068657
Pulled By: ezyang
fbshipit-source-id: b04c529572a94363855f1e4dd3e93c9db3c85657
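The `objdump -p ... | grep RPATH` check mentioned in the commit message can be sketched in Python roughly as follows. This is a minimal illustration, not the test that was actually added in the PR: the helper names `rpath_entries` and `external_entries` are hypothetical, and `SAMPLE` is hand-written text in the shape `objdump -p` prints, using the paths quoted above.

```python
def rpath_entries(objdump_output):
    """Collect RPATH/RUNPATH directories from `objdump -p` text."""
    entries = []
    for line in objdump_output.splitlines():
        fields = line.split(None, 1)  # tag, then the rest of the line
        if len(fields) == 2 and fields[0] in ("RPATH", "RUNPATH"):
            entries.extend(p for p in fields[1].strip().split(":") if p)
    return entries


def external_entries(entries):
    """Keep only entries that are not relative to the object itself."""
    return [p for p in entries if not p.startswith(("$ORIGIN", "${ORIGIN}"))]


# Hand-written sample (an assumption, not captured from a real build).
SAMPLE = """\
Dynamic Section:
  NEEDED               libcuda.so.1
  RPATH                $ORIGIN:/usr/local/cuda-10.2/lib64/stubs:/usr/local/cuda-10.2/lib64
"""

if __name__ == "__main__":
    for path in external_entries(rpath_entries(SAMPLE)):
        print("external RPATH entry:", path)
```

On a real build, the text passed to `rpath_entries` would be the output of `objdump -p` on the library in question; an empty result from `external_entries` is the outcome the PR was aiming for.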
Reopening because #37737 was reverted.
In #37737 I tried to remove RPATH linking outside of the pytorch directory. This broke conda, since it depends on RPATH linking to conda-env/lib, where all the mkl, cuda, and other DLLs live. Perhaps there is a way to be more selective when telling CMake which directories to insert into RPATH: LINK_PATH is too broad.
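To make the "more selective" idea concrete, here is a small Python sketch of one possible selection rule: keep external link directories (so an environment's `lib` directory still ends up in RPATH) but drop any directory that has a `stubs` path component. This only illustrates the rule and is not CMake code or an existing PyTorch helper; `select_rpath_dirs` and the `/opt/conda-env/lib` path are hypothetical.

```python
from pathlib import PurePosixPath


def select_rpath_dirs(link_dirs):
    """Return link directories suitable for RPATH, skipping stub directories."""
    kept = []
    for d in link_dirs:
        if "stubs" in PurePosixPath(d).parts:
            continue  # stub libraries are for linking only, never for loading
        if d not in kept:  # preserve order, drop duplicates
            kept.append(d)
    return kept


if __name__ == "__main__":
    dirs = [
        "/opt/conda-env/lib",  # hypothetical environment lib directory
        "/usr/local/cuda-10.2/lib64/stubs",
        "/usr/local/cuda-10.2/lib64",
    ]
    print(":".join(select_rpath_dirs(dirs)))
```

Whether an equivalent filter can be expressed cleanly in the CMake setup is exactly the open question in this thread.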
Actually, this RPATH linking also affects me in #40829, since it means that the shm manager is not relocatable. I think conda is supposed to be able to identify situations like this and automatically update the rpath in executables when it installs conda packages; cf. https://docs.conda.io/projects/conda-build/en/latest/resources/define-metadata.html#detect-binary-files-with-prefix So maybe we just need to figure out how to make this play ball.
🐛 Bug
PyTorch's cmake setup for CUDA includes a couple of `stubs/` directories in the search path for `libcuda.so`:
https://github.com/pytorch/pytorch/blob/v1.4.0/cmake/public/cuda.cmake#L179
Searching for libcuda in the stubs/ directories is needed in various cases, for example when:
However, PyTorch's build setup tells cmake to add any non-PyTorch-local libraries to built objects' `RPATH`:
https://github.com/pytorch/pytorch/blob/v1.4.0/cmake/Dependencies.cmake#L14
So the few objects that link directly to libcuda (e.g. `lib/libcaffe2_nvrtc.so` and `bin/test_dist_autograd`) may end up with the libcuda stubs/ directory included in RPATH.
During loading, RPATH has priority over normal system search paths and even `LD_LIBRARY_PATH`. So these objects will prefer the libcuda stub, even on hosts where libcuda is present in the normal location.
The same cmake setup is re-used for libtorch, so the same will occur for libtorch apps as well:
https://discuss.pytorch.org/t/torch-is-available-returns-false/73753
The libcuda stub is not a functional CUDA runtime, so PyTorch and libtorch objects affected by this won't be able to use CUDA even if libcuda is installed in the normal system search path.
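The effect of the search order can be illustrated with a simplified Python model of how the glibc loader picks a directory. It is a sketch under stated assumptions: it follows the order documented in ld.so(8) (DT_RPATH when no DT_RUNPATH is present, then `LD_LIBRARY_PATH`, then DT_RUNPATH, then system directories) and ignores the ld.so cache and `$ORIGIN` expansion. The function `resolve`, the `FILES` set, and the `/usr/lib/x86_64-linux-gnu` location are illustrative assumptions.

```python
def resolve(libname, rpath, ld_library_path, system_dirs, exists, runpath=()):
    """Return the first path where `libname` is found, or None."""
    search = []
    if not runpath:  # DT_RPATH is ignored when DT_RUNPATH is present
        search.extend(rpath)
    search.extend(ld_library_path)
    search.extend(runpath)
    search.extend(system_dirs)
    for directory in search:
        candidate = directory.rstrip("/") + "/" + libname
        if exists(candidate):
            return candidate
    return None


# Pretend filesystem: a stub in the toolkit and a real driver library.
STUB = "/usr/local/cuda-10.2/lib64/stubs/libcuda.so"
REAL = "/usr/lib/x86_64-linux-gnu/libcuda.so"  # assumed system location
FILES = {STUB, REAL}
SYSTEM_DIRS = ["/usr/lib/x86_64-linux-gnu"]

if __name__ == "__main__":
    with_stub_rpath = resolve(
        "libcuda.so",
        rpath=["/usr/local/cuda-10.2/lib64/stubs"],
        ld_library_path=["/usr/lib/x86_64-linux-gnu"],
        system_dirs=SYSTEM_DIRS,
        exists=FILES.__contains__,
    )
    without = resolve("libcuda.so", [], [], SYSTEM_DIRS, FILES.__contains__)
    print("stubs dir in RPATH ->", with_stub_rpath)
    print("no external RPATH  ->", without)
```

In this model the stub wins whenever its directory is in RPATH, even with `LD_LIBRARY_PATH` pointing at the real library, which matches the behavior described above.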
This change would prevent a libcuda stubs directory from being added to objects' RPATH, but doesn't seem like a great solution. It abuses a cmake internal mechanism to trick it into avoiding adding the stubs directory.
I hope someone more familiar with cmake and the general pytorch build setup might have a better idea.
To Reproduce
Steps to reproduce the behavior:
objdump -p <install_dir>/lib/libcaffe2_nvrtc.so | grep RPATH
Expected behavior
PyTorch and libtorch objects should build successfully against libcuda stub, and then run successfully against "real" libcuda (at least when libcuda is present in a default system search path).
Environment
How you installed PyTorch (`conda`, `pip`, source): source
Build command: `python setup.py install`
Python version: Python 3.6.9 :: Anaconda, Inc.
cc @ezyang @gchanan @zou3519 @seemethere