Add option for cpp_extensions to compile standalone executable (#47862)
Summary: Pull Request resolved: #47862

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D25199265

Pulled By: robieta

fbshipit-source-id: eceb04dea60b82eb10434099639fa3afa61000ca
Taylor Robie authored and facebook-github-bot committed Dec 2, 2020
1 parent 27905df commit 07f038a
Showing 4 changed files with 143 additions and 42 deletions.
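In short: torch.utils.cpp_extension.load gains an is_standalone flag that builds the given sources into an executable instead of a Python extension module, and returns the path to the resulting binary. A minimal usage sketch, modeled on the test added below (the source file and target name are illustrative):

    import torch.utils.cpp_extension

    # Build a C++ file with its own main() into a standalone executable.
    # Returns the binary's path (".exe" suffix on Windows, no suffix elsewhere).
    exec_path = torch.utils.cpp_extension.load(
        "my_standalone_app",       # illustrative target name
        ["main.cpp"],              # sources defining main()
        is_python_module=False,    # must be False: the two flags are mutually exclusive
        is_standalone=True,
    )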
1 change: 1 addition & 0 deletions test/test_determination.py
@@ -112,6 +112,7 @@ def test_torch_file(self):
"distributed/test_distributed_fork",
"test_cpp_extensions_aot_ninja",
"test_cpp_extensions_aot_no_ninja",
"test_utils",
"test_determination",
],
)
58 changes: 57 additions & 1 deletion test/test_utils.py
@@ -3,17 +3,20 @@
 import re
 import shutil
 import random
+import subprocess
 import tempfile
+import textwrap
 import unittest
 import torch
 import torch.nn as nn
 import torch.utils.data
 import torch.cuda
 from torch.utils.checkpoint import checkpoint, checkpoint_sequential
+import torch.utils.cpp_extension
 import torch.hub as hub
 from torch.autograd._functions.utils import check_onnx_broadcast
 from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
-from torch.testing._internal.common_utils import load_tests, retry, IS_SANDCASTLE
+from torch.testing._internal.common_utils import load_tests, retry, IS_SANDCASTLE, IS_WINDOWS
 from urllib.error import URLError

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
@@ -662,5 +665,58 @@ def forward(self, x):
         ms(torch.tensor([False], dtype=torch.bool))
 
 
+@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only.")
+class TestStandaloneCPPJIT(TestCase):
+    def test_load_standalone(self):
+        build_dir = tempfile.mkdtemp()
+        try:
+            src_path = os.path.join(build_dir, "main.cpp")
+            src = textwrap.dedent("""\
+                #include <iostream>
+                #include <torch/torch.h>
+                int main() {
+                    auto x = torch::eye(3);
+                    std::cout << x << std::endl;
+                }
+            """)
+            with open(src_path, "wt") as f:
+                f.write(src)
+
+            exec_path = torch.utils.cpp_extension.load(
+                "standalone_load_test",
+                src_path,
+                build_directory=build_dir,
+                is_python_module=False,
+                is_standalone=True,
+            )
+
+            ext = ".exe" if IS_WINDOWS else ""
+            self.assertEqual(
+                exec_path,
+                os.path.join(build_dir, f"standalone_load_test{ext}")
+            )
+
+            for shell in [True, False]:
+                r = subprocess.run(
+                    [exec_path],
+                    shell=shell,
+                    stdout=subprocess.PIPE,
+                )
+                self.assertEqual(r.returncode, 0)
+                self.assertEqual(
+                    # Windows prints "\r\n" for newlines.
+                    textwrap.dedent(r.stdout.decode("utf-8")).replace("\r\n", "\n"),
+                    textwrap.dedent("""\
+                         1  0  0
+                         0  1  0
+                         0  0  1
+                        [ CPUFloatType{3,3} ]
+                    """)
+                )
+
+        finally:
+            shutil.rmtree(build_dir)
+
+
 if __name__ == '__main__':
     run_tests()
6 changes: 5 additions & 1 deletion torch/utils/_cpp_extension_versioner.py
@@ -38,12 +38,16 @@ def bump_version_if_changed(self,
                                 source_files,
                                 build_arguments,
                                 build_directory,
-                                with_cuda):
+                                with_cuda,
+                                is_python_module,
+                                is_standalone):
         hash_value = 0
         hash_value = hash_source_files(hash_value, source_files)
         hash_value = hash_build_arguments(hash_value, build_arguments)
         hash_value = update_hash(hash_value, build_directory)
         hash_value = update_hash(hash_value, with_cuda)
+        hash_value = update_hash(hash_value, is_python_module)
+        hash_value = update_hash(hash_value, is_standalone)
 
         entry = self.entries.get(name)
         if entry is None:
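Both new flags are folded into the versioner's hash, so flipping either one registers as a configuration change and forces a rebuild rather than reusing a cached artifact. A minimal sketch of the idea in plain Python (not the actual ExtensionVersioner API):

    # Sketch: fold every build-relevant input into one key; a changed key
    # means "bump the version and rebuild" even when sources are identical.
    def build_key(source_digest, build_dir, with_cuda, is_python_module, is_standalone):
        return hash((source_digest, build_dir, with_cuda, is_python_module, is_standalone))

    k_module = build_key("abc123", "/tmp/b", False, True, False)
    k_standalone = build_key("abc123", "/tmp/b", False, False, True)
    assert k_module != k_standalone  # same sources, different artifact kind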
120 changes: 80 additions & 40 deletions torch/utils/cpp_extension.py
@@ -23,6 +23,14 @@


 IS_WINDOWS = sys.platform == 'win32'
+LIB_EXT = '.pyd' if IS_WINDOWS else '.so'
+EXEC_EXT = '.exe' if IS_WINDOWS else ''
+SHARED_FLAG = '/DLL' if IS_WINDOWS else '-shared'
+
+_HERE = os.path.abspath(__file__)
+_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
+TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib')
+
 
 
 def _find_cuda_home() -> Optional[str]:
     r'''Finds the CUDA install path.'''
@@ -400,7 +408,7 @@ def unix_cuda_flags(cflags):
     # overriding the option if the user explicitly passed it.
     _ccbin = os.getenv("CC")
     if (
-            _ccbin is not None
+        _ccbin is not None
         and not any([flag.startswith('-ccbin') or flag.startswith('--compiler-bindir') for flag in cflags])
     ):
         cflags.extend(['-ccbin', _ccbin])
@@ -848,9 +856,7 @@ def include_paths(cuda: bool = False) -> List[str]:
     Returns:
         A list of include path strings.
     '''
-    here = os.path.abspath(__file__)
-    torch_path = os.path.dirname(os.path.dirname(here))
-    lib_include = os.path.join(torch_path, 'include')
+    lib_include = os.path.join(_TORCH_PATH, 'include')
     paths = [
         lib_include,
         # Remove this once torch/torch.h is officially no longer supported for C++ extensions.
@@ -886,13 +892,8 @@ def library_paths(cuda: bool = False) -> List[str]:
     Returns:
         A list of library path strings.
     '''
-    paths = []
-
     # We need to link against libtorch.so
-    here = os.path.abspath(__file__)
-    torch_path = os.path.dirname(os.path.dirname(here))
-    lib_path = os.path.join(torch_path, 'lib')
-    paths.append(lib_path)
+    paths = [TORCH_LIB_PATH]
 
     if cuda and IS_HIP_EXTENSION:
         lib_dir = 'lib'
@@ -925,6 +926,7 @@ def load(name,
          verbose=False,
          with_cuda: Optional[bool] = None,
          is_python_module=True,
+         is_standalone=False,
          keep_intermediates=True):
     r'''
     Loads a PyTorch C++ extension just-in-time (JIT).
@@ -979,14 +981,23 @@
             ``.cuh`` in ``sources``. Set it to ``True`` to force CUDA headers
             and libraries to be included.
         is_python_module: If ``True`` (default), imports the produced shared
-            library as a Python module. If ``False``, loads it into the process
-            as a plain dynamic library.
+            library as a Python module. If ``False``, behavior depends on
+            ``is_standalone``.
+        is_standalone: If ``False`` (default), loads the constructed extension
+            into the process as a plain dynamic library. If ``True``, builds a
+            standalone executable.
 
     Returns:
-        If ``is_python_module`` is ``True``, returns the loaded PyTorch
-        extension as a Python module. If ``is_python_module`` is ``False``
-        returns nothing (the shared library is loaded into the process as a side
-        effect).
+        If ``is_python_module`` is ``True``:
+            Returns the loaded PyTorch extension as a Python module.
+
+        If ``is_python_module`` is ``False`` and ``is_standalone`` is ``False``:
+            Returns nothing. (The shared library is loaded into the process as
+            a side effect.)
+
+        If ``is_standalone`` is ``True``:
+            Returns the path to the executable. (On Windows, TORCH_LIB_PATH is
+            added to the PATH environment variable as a side effect.)
 
     Example:
         >>> from torch.utils.cpp_extension import load
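Concretely, the three return shapes described above look like this (a sketch; extension names and sources are invented):

    from torch.utils.cpp_extension import load

    mod = load("py_ext", ["ext.cpp"])              # 1) a Python module object
    load("plain_lib", ["lib.cpp"],
         is_python_module=False)                   # 2) None; library loaded as a side effect
    exe = load("app", ["main.cpp"],
               is_python_module=False,
               is_standalone=True)                 # 3) path to the built executable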
@@ -1007,6 +1018,7 @@
         verbose,
         with_cuda,
         is_python_module,
+        is_standalone,
         keep_intermediates=keep_intermediates)


@@ -1155,6 +1167,7 @@ def load_inline(name,
         verbose,
         with_cuda,
         is_python_module,
+        is_standalone=False,
         keep_intermediates=keep_intermediates)


@@ -1168,7 +1181,11 @@ def _jit_compile(name,
                  verbose: bool,
                  with_cuda: Optional[bool],
                  is_python_module,
+                 is_standalone,
                  keep_intermediates=True) -> None:
+    if is_python_module and is_standalone:
+        raise ValueError("`is_python_module` and `is_standalone` are mutually exclusive.")
+
     if with_cuda is None:
         with_cuda = any(map(_is_cuda_file, sources))
     with_cudnn = any(['cudnn' in f for f in extra_ldflags or []])
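Since is_python_module defaults to True, callers who want a standalone build must pass both flags; the guard above rejects contradictory combinations before any compilation starts. An illustrative sketch:

    import torch.utils.cpp_extension as cpp_ext

    try:
        # is_python_module defaults to True, so this hits the guard immediately.
        cpp_ext.load("bad", ["x.cpp"], is_standalone=True)
    except ValueError as e:
        print(e)  # `is_python_module` and `is_standalone` are mutually exclusive.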
@@ -1178,7 +1195,9 @@
         sources,
         build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
         build_directory=build_directory,
-        with_cuda=with_cuda
+        with_cuda=with_cuda,
+        is_python_module=is_python_module,
+        is_standalone=is_standalone,
     )
     if version > 0:
         if version != old_version and verbose:
@@ -1210,7 +1229,8 @@ def _jit_compile(name,
                 extra_include_paths=extra_include_paths or [],
                 build_directory=build_directory,
                 verbose=verbose,
-                with_cuda=with_cuda)
+                with_cuda=with_cuda,
+                is_standalone=is_standalone)
         finally:
             baton.release()
     else:
@@ -1221,6 +1241,10 @@

     if verbose:
         print(f'Loading extension module {name}...')
+
+    if is_standalone:
+        return _get_exec_path(name, build_directory)
+
     return _import_module_from_library(name, build_directory, is_python_module)


@@ -1275,7 +1299,8 @@ def _write_ninja_file_and_build_library(
         extra_include_paths,
         build_directory: str,
         verbose: bool,
-        with_cuda: Optional[bool]) -> None:
+        with_cuda: Optional[bool],
+        is_standalone: bool = False) -> None:
     verify_ninja_availability()
     if IS_WINDOWS:
         compiler = os.environ.get('CXX', 'cl')
@@ -1287,7 +1312,8 @@
     extra_ldflags = _prepare_ldflags(
         extra_ldflags or [],
         with_cuda,
-        verbose)
+        verbose,
+        is_standalone)
     build_file_path = os.path.join(build_directory, 'build.ninja')
     if verbose:
         print(f'Emitting ninja build file {build_file_path}...')
@@ -1301,7 +1327,8 @@
         extra_cuda_cflags=extra_cuda_cflags or [],
         extra_ldflags=extra_ldflags or [],
         extra_include_paths=extra_include_paths or [],
-        with_cuda=with_cuda)
+        with_cuda=with_cuda,
+        is_standalone=is_standalone)
 
     if verbose:
         print(f'Building extension module {name}...')
@@ -1334,11 +1361,7 @@ def verify_ninja_availability():
         raise RuntimeError("Ninja is required to load C++ extensions")
 
 
-def _prepare_ldflags(extra_ldflags, with_cuda, verbose):
-    here = os.path.abspath(__file__)
-    torch_path = os.path.dirname(os.path.dirname(here))
-    lib_path = os.path.join(torch_path, 'lib')
-
+def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
     if IS_WINDOWS:
         python_path = os.path.dirname(sys.executable)
         python_lib_path = os.path.join(python_path, 'libs')
@@ -1353,19 +1376,25 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose):
         # Related issue: https://github.com/pytorch/pytorch/issues/31611
         extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ')
         extra_ldflags.append('torch.lib')
-        extra_ldflags.append('torch_python.lib')
-        extra_ldflags.append(f'/LIBPATH:{python_lib_path}')
-        extra_ldflags.append(f'/LIBPATH:{lib_path}')
+        extra_ldflags.append(f'/LIBPATH:{TORCH_LIB_PATH}')
+        if not is_standalone:
+            extra_ldflags.append('torch_python.lib')
+            extra_ldflags.append(f'/LIBPATH:{python_lib_path}')
+
     else:
-        extra_ldflags.append(f'-L{lib_path}')
+        extra_ldflags.append(f'-L{TORCH_LIB_PATH}')
         extra_ldflags.append('-lc10')
         if with_cuda:
             extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda')
         extra_ldflags.append('-ltorch_cpu')
         if with_cuda:
             extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda')
         extra_ldflags.append('-ltorch')
-        extra_ldflags.append('-ltorch_python')
+        if not is_standalone:
+            extra_ldflags.append('-ltorch_python')
+
+    if is_standalone:
+        extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}")
 
     if with_cuda:
         if verbose:
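Net effect for a standalone build on the non-Windows path: -ltorch_python is dropped (an executable hosts no Python interpreter) and TORCH_LIB_PATH is embedded as an rpath, so the binary can locate libtorch at runtime without LD_LIBRARY_PATH. A sketch of the resulting flag sets (CPU-only case, path illustrative):

    TORCH_LIB = "/path/to/torch/lib"  # stand-in for TORCH_LIB_PATH

    module_ldflags = [f"-L{TORCH_LIB}", "-lc10", "-ltorch_cpu", "-ltorch",
                      "-ltorch_python"]                    # imported from Python
    standalone_ldflags = [f"-L{TORCH_LIB}", "-lc10", "-ltorch_cpu", "-ltorch",
                          f"-Wl,-rpath,{TORCH_LIB}"]       # self-locating executable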
@@ -1565,6 +1594,17 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) ->
         raise RuntimeError(message) from e
 
 
+def _get_exec_path(module_name, path):
+    if IS_WINDOWS and TORCH_LIB_PATH not in os.getenv('PATH', '').split(';'):
+        torch_lib_in_path = any(
+            os.path.exists(p) and os.path.samefile(p, TORCH_LIB_PATH)
+            for p in os.getenv('PATH', '').split(';')
+        )
+        if not torch_lib_in_path:
+            os.environ['PATH'] = f"{TORCH_LIB_PATH};{os.getenv('PATH', '')}"
+    return os.path.join(path, f'{module_name}{EXEC_EXT}')
+
+
 def _import_module_from_library(module_name, path, is_python_module):
     # https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
     file, path, description = imp.find_module(module_name, [path])
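Windows has no direct rpath counterpart here, so _get_exec_path instead prepends the torch DLL directory to PATH once (the samefile check avoids stacking duplicate entries); anything launched afterwards inherits it. A usage sketch, reusing names from the test above:

    import subprocess

    exe = _get_exec_path("standalone_load_test", build_dir)  # build_dir as in the test
    subprocess.run([exe], check=True)  # torch DLLs resolve via the amended PATH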
@@ -1583,7 +1623,8 @@ def _write_ninja_file_to_build_library(path,
                                        extra_cuda_cflags,
                                        extra_ldflags,
                                        extra_include_paths,
-                                       with_cuda) -> None:
+                                       with_cuda,
+                                       is_standalone) -> None:
     extra_cflags = [flag.strip() for flag in extra_cflags]
     extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
     extra_ldflags = [flag.strip() for flag in extra_ldflags]
@@ -1603,8 +1644,10 @@ def _write_ninja_file_to_build_library(path,
     user_includes += system_includes
     system_includes.clear()
 
-    common_cflags = [f'-DTORCH_EXTENSION_NAME={name}']
-    common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')
+    common_cflags = []
+    if not is_standalone:
+        common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}')
+        common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')

# Note [Pybind11 ABI constants]
#
@@ -1674,19 +1717,16 @@ def object_file_path(source_file: str) -> str:
         return target
 
     objects = [object_file_path(src) for src in sources]
+    ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags
 
-    if IS_WINDOWS:
-        ldflags = ['/DLL'] + extra_ldflags
-    else:
-        ldflags = ['-shared'] + extra_ldflags
     # The darwin linker needs explicit consent to ignore unresolved symbols.
     if sys.platform.startswith('darwin'):
         ldflags.append('-undefined dynamic_lookup')
     elif IS_WINDOWS:
         ldflags = _nt_quote_args(ldflags)
 
-    ext = 'pyd' if IS_WINDOWS else 'so'
-    library_target = f'{name}.{ext}'
+    ext = EXEC_EXT if is_standalone else LIB_EXT
+    library_target = f'{name}{ext}'
 
     _write_ninja_file(
         path=path,
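The ninja target name therefore tracks the build mode via the constants defined at the top of this file. A quick sanity sketch (a re-implementation for illustration, not the real helper):

    def target_name(name, is_standalone, is_windows):
        lib_ext = '.pyd' if is_windows else '.so'    # mirrors LIB_EXT
        exec_ext = '.exe' if is_windows else ''      # mirrors EXEC_EXT
        return name + (exec_ext if is_standalone else lib_ext)

    assert target_name("my_ext", False, False) == "my_ext.so"
    assert target_name("my_ext", False, True) == "my_ext.pyd"
    assert target_name("my_app", True, False) == "my_app"
    assert target_name("my_app", True, True) == "my_app.exe"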
