New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Don't use RTLD_GLOBAL to load _C. #31162
Changes from all commits
beb3b40
a0484e6
3231d91
a08fa45
0d0f466
1060202
22a565c
4fc627b
b796cb3
417faae
28c0ef5
25f8d0c
95ebee5
53461ff
edfc4ea
a77da74
4170672
9812a35
70dd543
993c810
20872e7
2081c67
8948282
593b661
d79fe49
2771d72
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,8 +13,10 @@ | |
import os | ||
import sys | ||
import platform | ||
import ctypes | ||
from ._utils import _import_dotted_name | ||
from ._utils_internal import get_file_path, prepare_multiprocessing_environment | ||
from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \ | ||
USE_RTLD_GLOBAL_WITH_LIBTORCH | ||
from .version import __version__ | ||
from ._six import string_classes as _string_classes | ||
|
||
|
@@ -33,61 +35,81 @@ | |
# Load the extension module | ||
################################################################################ | ||
|
||
# Loading the extension with RTLD_GLOBAL option allows to not link extension | ||
# modules against the _C shared object. Their missing THP symbols will be | ||
# automatically filled by the dynamic loader. | ||
import os as _dl_flags | ||
|
||
# if we have numpy, it *must* be imported before the call to setdlopenflags() | ||
# or there is risk that later c modules will segfault when importing numpy | ||
try: | ||
import numpy as _np | ||
except ImportError: | ||
pass | ||
|
||
if platform.system() == 'Windows': | ||
# first get nvToolsExt PATH | ||
def get_nvToolsExt_path(): | ||
NVTOOLEXT_HOME = _dl_flags.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt') | ||
NVTOOLSEXT_PATH = os.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt') | ||
|
||
if _dl_flags.path.exists(NVTOOLEXT_HOME): | ||
return _dl_flags.path.join(NVTOOLEXT_HOME, 'bin', 'x64') | ||
else: | ||
return '' | ||
if os.path.exists(NVTOOLSEXT_PATH): | ||
nvtoolsext_lib_path = os.path.join(NVTOOLSEXT_PATH, 'bin', 'x64') | ||
else: | ||
nvtoolsext_lib_path = '' | ||
|
||
py_dll_path = _dl_flags.path.join(sys.exec_prefix, 'Library', 'bin') | ||
th_dll_path = _dl_flags.path.join(_dl_flags.path.dirname(__file__), 'lib') | ||
py_dll_path = os.path.join(sys.exec_prefix, 'Library', 'bin') | ||
th_dll_path = os.path.join(os.path.dirname(__file__), 'lib') | ||
|
||
dll_paths = [th_dll_path, py_dll_path, get_nvToolsExt_path(), _dl_flags.environ['PATH']] | ||
dll_paths = [th_dll_path, py_dll_path, nvtoolsext_lib_path, os.environ['PATH']] | ||
|
||
# then add the path to env | ||
_dl_flags.environ['PATH'] = ';'.join(dll_paths) | ||
os.environ['PATH'] = ';'.join(dll_paths) | ||
|
||
else: | ||
# first check if the os package has the required flags | ||
|
||
# See Note [Global dependencies] | ||
def _load_global_deps(): | ||
if platform.system() == 'Windows': | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What about defining a global variable? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'll do this in a follow up |
||
return | ||
|
||
lib_name = 'libtorch_global_deps' + ('.dylib' if platform.system() == 'Darwin' else '.so') | ||
here = os.path.abspath(__file__) | ||
lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name) | ||
|
||
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've been going through this out of curiosity and it got me wondering if this doesn't lead to an eventual leak [original text appears truncated here]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. https://stackoverflow.com/questions/359498/how-can-i-unload-a-dll-using-ctypes-in-python suggests that it doesn't |
||
|
||
|
||
if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \ | ||
platform.system() != 'Windows': | ||
# Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a | ||
# few circumstances: | ||
# | ||
# 1. You're in a build environment (e.g., fbcode) where | ||
# libtorch_global_deps is not available, but you still need | ||
# to get mkl to link in with RTLD_GLOBAL or it will just | ||
# not work. | ||
# | ||
# 2. You're trying to run PyTorch under UBSAN and you need | ||
# to ensure that only one copy of libtorch is loaded, so | ||
# vptr checks work properly | ||
# | ||
# If you're using this setting, you must verify that all the libraries | ||
# you load consistently use the same libstdc++, or you may have | ||
# mysterious segfaults. | ||
# | ||
import os as _dl_flags | ||
if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'): | ||
try: | ||
# next try if DLFCN exists | ||
import DLFCN as _dl_flags | ||
except ImportError: | ||
# as a last attempt, use compile-time constants | ||
import torch._dl as _dl_flags | ||
|
||
old_flags = sys.getdlopenflags() | ||
sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY) | ||
from torch._C import * | ||
sys.setdlopenflags(old_flags) | ||
del old_flags | ||
del _dl_flags | ||
|
||
del _dl_flags | ||
|
||
from torch._C import * | ||
else: | ||
# Easy way. You want this most of the time, because it will prevent | ||
# C++ symbols from libtorch clobbering C++ symbols from other | ||
# libraries, leading to mysterious segfaults. | ||
# | ||
# See Note [Global dependencies] | ||
_load_global_deps() | ||
from torch._C import * | ||
|
||
__all__ += [name for name in dir(_C) | ||
if name[0] != '_' and | ||
not name.endswith('Base')] | ||
|
||
if platform.system() != 'Windows': | ||
sys.setdlopenflags(old_flags) | ||
del old_flags | ||
|
||
################################################################################ | ||
# Define basic utilities | ||
################################################################################ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,3 +54,4 @@ def get_source_lines_and_file(obj): | |
|
||
TEST_MASTER_ADDR = '127.0.0.1' | ||
TEST_MASTER_PORT = 29500 | ||
USE_RTLD_GLOBAL_WITH_LIBTORCH = False | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we have this constant if it's always false? Is this so that you can patch it in fbcode? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep, fbcode shenanigans
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1 @@ | ||
vptr:libtorch.so | ||
vptr:libtorch_python.so | ||
vptr:libcaffe2.so |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A tip that is maybe not directly related to this PR: we missed the CUDA path here. To make it compatible with Python 3.8, we have to add it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy to fix this, but I'm not exactly sure what the suggestion is here. What's the other CUDA path we missed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have a variable that records the version of the cuda used during build? If yes, then the answer is quite simple. For example, for CUDA 9.2, just add the following path.