From 07f038aa9d75bb0befb7fdf4e22b2875cdb0bb49 Mon Sep 17 00:00:00 2001 From: Taylor Robie Date: Tue, 1 Dec 2020 19:56:13 -0800 Subject: [PATCH] Add option for cpp_extensions to compile standalone executable (#47862) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47862 Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D25199265 Pulled By: robieta fbshipit-source-id: eceb04dea60b82eb10434099639fa3afa61000ca --- test/test_determination.py | 1 + test/test_utils.py | 58 +++++++++++- torch/utils/_cpp_extension_versioner.py | 6 +- torch/utils/cpp_extension.py | 120 ++++++++++++++++-------- 4 files changed, 143 insertions(+), 42 deletions(-) diff --git a/test/test_determination.py b/test/test_determination.py index 0f860cab5101..7e9420285e5a 100644 --- a/test/test_determination.py +++ b/test/test_determination.py @@ -112,6 +112,7 @@ def test_torch_file(self): "distributed/test_distributed_fork", "test_cpp_extensions_aot_ninja", "test_cpp_extensions_aot_no_ninja", + "test_utils", "test_determination", ], ) diff --git a/test/test_utils.py b/test/test_utils.py index 1e6449d3764c..5f1e693ab12f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -3,17 +3,20 @@ import re import shutil import random +import subprocess import tempfile +import textwrap import unittest import torch import torch.nn as nn import torch.utils.data import torch.cuda from torch.utils.checkpoint import checkpoint, checkpoint_sequential +import torch.utils.cpp_extension import torch.hub as hub from torch.autograd._functions.utils import check_onnx_broadcast from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings -from torch.testing._internal.common_utils import load_tests, retry, IS_SANDCASTLE +from torch.testing._internal.common_utils import load_tests, retry, IS_SANDCASTLE, IS_WINDOWS from urllib.error import URLError # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for @@ -662,5 +665,58 @@ def 
forward(self, x): ms(torch.tensor([False], dtype=torch.bool)) +@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only.") +class TestStandaloneCPPJIT(TestCase): + def test_load_standalone(self): + build_dir = tempfile.mkdtemp() + try: + src_path = os.path.join(build_dir, "main.cpp") + src = textwrap.dedent("""\ + #include <iostream> + #include <torch/torch.h> + int main() { + auto x = torch::eye(3); + std::cout << x << std::endl; + } + """) + with open(src_path, "wt") as f: + f.write(src) + + exec_path = torch.utils.cpp_extension.load( + "standalone_load_test", + src_path, + build_directory=build_dir, + is_python_module=False, + is_standalone=True, + ) + + ext = ".exe" if IS_WINDOWS else "" + self.assertEqual( + exec_path, + os.path.join(build_dir, f"standalone_load_test{ext}") + ) + + for shell in [True, False]: + r = subprocess.run( + [exec_path], + shell=shell, + stdout=subprocess.PIPE, + ) + self.assertEqual(r.returncode, 0) + self.assertEqual( + # Windows prints "\r\n" for newlines. + textwrap.dedent(r.stdout.decode("utf-8")).replace("\r\n", "\n"), + textwrap.dedent("""\ + 1 0 0 + 0 1 0 + 0 0 1 + [ CPUFloatType{3,3} ] + """) + ) + + finally: + shutil.rmtree(build_dir) + + if __name__ == '__main__': run_tests() diff --git a/torch/utils/_cpp_extension_versioner.py b/torch/utils/_cpp_extension_versioner.py index cb778ab8923d..958d34ecc71a 100644 --- a/torch/utils/_cpp_extension_versioner.py +++ b/torch/utils/_cpp_extension_versioner.py @@ -38,12 +38,16 @@ def bump_version_if_changed(self, source_files, build_arguments, build_directory, - with_cuda): + with_cuda, + is_python_module, + is_standalone): hash_value = 0 hash_value = hash_source_files(hash_value, source_files) hash_value = hash_build_arguments(hash_value, build_arguments) hash_value = update_hash(hash_value, build_directory) hash_value = update_hash(hash_value, with_cuda) + hash_value = update_hash(hash_value, is_python_module) + hash_value = update_hash(hash_value, is_standalone) entry = self.entries.get(name) if entry is 
None: diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index a2f47744c5f3..993b04ca23d8 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -23,6 +23,14 @@ IS_WINDOWS = sys.platform == 'win32' +LIB_EXT = '.pyd' if IS_WINDOWS else '.so' +EXEC_EXT = '.exe' if IS_WINDOWS else '' +SHARED_FLAG = '/DLL' if IS_WINDOWS else '-shared' + +_HERE = os.path.abspath(__file__) +_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) +TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib') + def _find_cuda_home() -> Optional[str]: r'''Finds the CUDA install path.''' @@ -400,7 +408,7 @@ def unix_cuda_flags(cflags): # overriding the option if the user explicitly passed it. _ccbin = os.getenv("CC") if ( - _ccbin is not None + _ccbin is not None and not any([flag.startswith('-ccbin') or flag.startswith('--compiler-bindir') for flag in cflags]) ): cflags.extend(['-ccbin', _ccbin]) @@ -848,9 +856,7 @@ def include_paths(cuda: bool = False) -> List[str]: Returns: A list of include path strings. ''' - here = os.path.abspath(__file__) - torch_path = os.path.dirname(os.path.dirname(here)) - lib_include = os.path.join(torch_path, 'include') + lib_include = os.path.join(_TORCH_PATH, 'include') paths = [ lib_include, # Remove this once torch/torch.h is officially no longer supported for C++ extensions. @@ -886,13 +892,8 @@ def library_paths(cuda: bool = False) -> List[str]: Returns: A list of library path strings. ''' - paths = [] - # We need to link against libtorch.so - here = os.path.abspath(__file__) - torch_path = os.path.dirname(os.path.dirname(here)) - lib_path = os.path.join(torch_path, 'lib') - paths.append(lib_path) + paths = [TORCH_LIB_PATH] if cuda and IS_HIP_EXTENSION: lib_dir = 'lib' @@ -925,6 +926,7 @@ def load(name, verbose=False, with_cuda: Optional[bool] = None, is_python_module=True, + is_standalone=False, keep_intermediates=True): r''' Loads a PyTorch C++ extension just-in-time (JIT). 
@@ -979,14 +981,23 @@ def load(name, ``.cuh`` in ``sources``. Set it to `True`` to force CUDA headers and libraries to be included. is_python_module: If ``True`` (default), imports the produced shared - library as a Python module. If ``False``, loads it into the process - as a plain dynamic library. + library as a Python module. If ``False``, behavior depends on + ``is_standalone``. + is_standalone: If ``False`` (default), loads the constructed extension + into the process as a plain dynamic library. If ``True``, builds a + standalone executable. Returns: - If ``is_python_module`` is ``True``, returns the loaded PyTorch - extension as a Python module. If ``is_python_module`` is ``False`` - returns nothing (the shared library is loaded into the process as a side - effect). + If ``is_python_module`` is ``True``: + Returns the loaded PyTorch extension as a Python module. + + If ``is_python_module`` is ``False`` and ``is_standalone`` is ``False``: + Returns nothing. (The shared library is loaded into the process as + a side effect.) + + If ``is_standalone`` is ``True``: + Returns the path to the executable. (On Windows, TORCH_LIB_PATH is + added to the PATH environment variable as a side effect.) 
Example: >>> from torch.utils.cpp_extension import load @@ -1007,6 +1018,7 @@ def load(name, verbose, with_cuda, is_python_module, + is_standalone, keep_intermediates=keep_intermediates) @@ -1155,6 +1167,7 @@ def load_inline(name, verbose, with_cuda, is_python_module, + is_standalone=False, keep_intermediates=keep_intermediates) @@ -1168,7 +1181,11 @@ def _jit_compile(name, verbose: bool, with_cuda: Optional[bool], is_python_module, + is_standalone, keep_intermediates=True) -> None: + if is_python_module and is_standalone: + raise ValueError("`is_python_module` and `is_standalone` are mutually exclusive.") + if with_cuda is None: with_cuda = any(map(_is_cuda_file, sources)) with_cudnn = any(['cudnn' in f for f in extra_ldflags or []]) @@ -1178,7 +1195,9 @@ def _jit_compile(name, sources, build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths], build_directory=build_directory, - with_cuda=with_cuda + with_cuda=with_cuda, + is_python_module=is_python_module, + is_standalone=is_standalone, ) if version > 0: if version != old_version and verbose: @@ -1210,7 +1229,8 @@ def _jit_compile(name, extra_include_paths=extra_include_paths or [], build_directory=build_directory, verbose=verbose, - with_cuda=with_cuda) + with_cuda=with_cuda, + is_standalone=is_standalone) finally: baton.release() else: @@ -1221,6 +1241,10 @@ def _jit_compile(name, if verbose: print(f'Loading extension module {name}...') + + if is_standalone: + return _get_exec_path(name, build_directory) + return _import_module_from_library(name, build_directory, is_python_module) @@ -1275,7 +1299,8 @@ def _write_ninja_file_and_build_library( extra_include_paths, build_directory: str, verbose: bool, - with_cuda: Optional[bool]) -> None: + with_cuda: Optional[bool], + is_standalone: bool = False) -> None: verify_ninja_availability() if IS_WINDOWS: compiler = os.environ.get('CXX', 'cl') @@ -1287,7 +1312,8 @@ def _write_ninja_file_and_build_library( extra_ldflags = _prepare_ldflags( 
extra_ldflags or [], with_cuda, - verbose) + verbose, + is_standalone) build_file_path = os.path.join(build_directory, 'build.ninja') if verbose: print(f'Emitting ninja build file {build_file_path}...') @@ -1301,7 +1327,8 @@ def _write_ninja_file_and_build_library( extra_cuda_cflags=extra_cuda_cflags or [], extra_ldflags=extra_ldflags or [], extra_include_paths=extra_include_paths or [], - with_cuda=with_cuda) + with_cuda=with_cuda, + is_standalone=is_standalone) if verbose: print(f'Building extension module {name}...') @@ -1334,11 +1361,7 @@ def verify_ninja_availability(): raise RuntimeError("Ninja is required to load C++ extensions") -def _prepare_ldflags(extra_ldflags, with_cuda, verbose): - here = os.path.abspath(__file__) - torch_path = os.path.dirname(os.path.dirname(here)) - lib_path = os.path.join(torch_path, 'lib') - +def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): if IS_WINDOWS: python_path = os.path.dirname(sys.executable) python_lib_path = os.path.join(python_path, 'libs') @@ -1353,11 +1376,13 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose): # Related issue: https://github.com/pytorch/pytorch/issues/31611 extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ') extra_ldflags.append('torch.lib') - extra_ldflags.append('torch_python.lib') - extra_ldflags.append(f'/LIBPATH:{python_lib_path}') - extra_ldflags.append(f'/LIBPATH:{lib_path}') + extra_ldflags.append(f'/LIBPATH:{TORCH_LIB_PATH}') + if not is_standalone: + extra_ldflags.append('torch_python.lib') + extra_ldflags.append(f'/LIBPATH:{python_lib_path}') + else: - extra_ldflags.append(f'-L{lib_path}') + extra_ldflags.append(f'-L{TORCH_LIB_PATH}') extra_ldflags.append('-lc10') if with_cuda: extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda') @@ -1365,7 +1390,11 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose): if with_cuda: extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda') extra_ldflags.append('-ltorch') - 
extra_ldflags.append('-ltorch_python') + if not is_standalone: + extra_ldflags.append('-ltorch_python') + + if is_standalone: + extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}") if with_cuda: if verbose: @@ -1565,6 +1594,17 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> raise RuntimeError(message) from e +def _get_exec_path(module_name, path): + if IS_WINDOWS and TORCH_LIB_PATH not in os.getenv('PATH', '').split(';'): + torch_lib_in_path = any( + os.path.exists(p) and os.path.samefile(p, TORCH_LIB_PATH) + for p in os.getenv('PATH', '').split(';') + ) + if not torch_lib_in_path: + os.environ['PATH'] = f"{TORCH_LIB_PATH};{os.getenv('PATH', '')}" + return os.path.join(path, f'{module_name}{EXEC_EXT}') + + def _import_module_from_library(module_name, path, is_python_module): # https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path file, path, description = imp.find_module(module_name, [path]) @@ -1583,7 +1623,8 @@ def _write_ninja_file_to_build_library(path, extra_cuda_cflags, extra_ldflags, extra_include_paths, - with_cuda) -> None: + with_cuda, + is_standalone) -> None: extra_cflags = [flag.strip() for flag in extra_cflags] extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags] extra_ldflags = [flag.strip() for flag in extra_ldflags] @@ -1603,8 +1644,10 @@ def _write_ninja_file_to_build_library(path, user_includes += system_includes system_includes.clear() - common_cflags = [f'-DTORCH_EXTENSION_NAME={name}'] - common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H') + common_cflags = [] + if not is_standalone: + common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}') + common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H') # Note [Pybind11 ABI constants] # @@ -1674,19 +1717,16 @@ def object_file_path(source_file: str) -> str: return target objects = [object_file_path(src) for src in sources] + ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags - if IS_WINDOWS: - ldflags = 
['/DLL'] + extra_ldflags - else: - ldflags = ['-shared'] + extra_ldflags # The darwin linker needs explicit consent to ignore unresolved symbols. if sys.platform.startswith('darwin'): ldflags.append('-undefined dynamic_lookup') elif IS_WINDOWS: ldflags = _nt_quote_args(ldflags) - ext = 'pyd' if IS_WINDOWS else 'so' - library_target = f'{name}.{ext}' + ext = EXEC_EXT if is_standalone else LIB_EXT + library_target = f'{name}{ext}' _write_ninja_file( path=path,