New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Don't use RTLD_GLOBAL to load _C. #31162
Changes from 14 commits
beb3b40
a0484e6
3231d91
a08fa45
0d0f466
1060202
22a565c
4fc627b
b796cb3
417faae
28c0ef5
25f8d0c
95ebee5
53461ff
edfc4ea
a77da74
4170672
9812a35
70dd543
993c810
20872e7
2081c67
8948282
593b661
d79fe49
2771d72
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
import os | ||
import sys | ||
import platform | ||
import ctypes | ||
from ._utils import _import_dotted_name | ||
from ._utils_internal import get_file_path, prepare_multiprocessing_environment | ||
from .version import __version__ | ||
|
@@ -33,61 +34,44 @@ | |
# Load the extension module | ||
################################################################################ | ||
|
||
# Loading the extension with RTLD_GLOBAL option allows to not link extension | ||
# modules against the _C shared object. Their missing THP symbols will be | ||
# automatically filled by the dynamic loader. | ||
import os as _dl_flags | ||
|
||
# if we have numpy, it *must* be imported before the call to setdlopenflags() | ||
# or there is risk that later c modules will segfault when importing numpy | ||
try: | ||
import numpy as _np | ||
except ImportError: | ||
pass | ||
|
||
if platform.system() == 'Windows': | ||
# first get nvToolsExt PATH | ||
def get_nvToolsExt_path(): | ||
NVTOOLEXT_HOME = _dl_flags.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt') | ||
NVTOOLSEXT_PATH = os.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt') | ||
|
||
if _dl_flags.path.exists(NVTOOLEXT_HOME): | ||
return _dl_flags.path.join(NVTOOLEXT_HOME, 'bin', 'x64') | ||
else: | ||
return '' | ||
if os.path.exists(NVTOOLSEXT_PATH): | ||
nvtoolsext_lib_path = os.path.join(NVTOOLSEXT_PATH, 'bin', 'x64') | ||
else: | ||
nvtoolsext_lib_path = '' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A tip that maybe not directly related to this PR: We missed the path of CUDA here. To make it compatible with Python 3.8, we have to do that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm happy to fix this, but I'm not exactly sure what the suggestion is here. What's the other CUDA path we missed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have a variable that records the version of the cuda used during build? If yes, then the answer is quite simple. For example, for CUDA 9.2, just add the following path. cuda_path = os.path.join(os.environ.get('CUDA_PATH_V9_2', r'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2'), 'bin') |
||
|
||
py_dll_path = _dl_flags.path.join(sys.exec_prefix, 'Library', 'bin') | ||
th_dll_path = _dl_flags.path.join(_dl_flags.path.dirname(__file__), 'lib') | ||
py_dll_path = os.path.join(sys.exec_prefix, 'Library', 'bin') | ||
th_dll_path = os.path.join(os.path.dirname(__file__), 'lib') | ||
|
||
dll_paths = [th_dll_path, py_dll_path, get_nvToolsExt_path(), _dl_flags.environ['PATH']] | ||
dll_paths = [th_dll_path, py_dll_path, nvtoolsext_lib_path, os.environ['PATH']] | ||
|
||
# then add the path to env | ||
_dl_flags.environ['PATH'] = ';'.join(dll_paths) | ||
os.environ['PATH'] = ';'.join(dll_paths) | ||
|
||
else: | ||
# first check if the os package has the required flags | ||
if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'): | ||
try: | ||
# next try if DLFCN exists | ||
import DLFCN as _dl_flags | ||
except ImportError: | ||
# as a last attempt, use compile-time constants | ||
import torch._dl as _dl_flags | ||
|
||
old_flags = sys.getdlopenflags() | ||
sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY) | ||
# See Note [Global dependencies] | ||
def _load_global_deps(): | ||
if platform.system() == 'Windows': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about defining a global variable There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll do this in a follow up |
||
return | ||
|
||
del _dl_flags | ||
lib_name = 'libtorch_global_deps' + ('.dylib' if platform.system() == 'Darwin' else '.so') | ||
here = os.path.abspath(__file__) | ||
lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name) | ||
|
||
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've been going through this out of curiosity and it got me wondering if this doesn't lead to an eventual There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://stackoverflow.com/questions/359498/how-can-i-unload-a-dll-using-ctypes-in-python suggests that it doesn't |
||
|
||
|
||
# See Note [Global dependencies] | ||
_load_global_deps() | ||
|
||
from torch._C import * | ||
|
||
__all__ += [name for name in dir(_C) | ||
if name[0] != '_' and | ||
not name.endswith('Base')] | ||
|
||
if platform.system() != 'Windows': | ||
sys.setdlopenflags(old_flags) | ||
del old_flags | ||
|
||
################################################################################ | ||
# Define basic utilities | ||
################################################################################ | ||
|
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: UBSAN relies on UBSAN ?