From 19df6430798cccaf4eeb2dda0bf55f35c8f6dc1f Mon Sep 17 00:00:00 2001 From: Ken Leidal Date: Fri, 11 Dec 2020 15:37:03 -0500 Subject: [PATCH 1/2] build both cpu and gpu binaries so same package can run on both CPU and GPU machines --- setup.py | 68 +++++++++++++++++++++------------------ torch_scatter/__init__.py | 6 ++++ 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/setup.py b/setup.py index 42749208..19a9785f 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME -WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None +WITH_CUDA = CUDA_HOME is not None if os.getenv('FORCE_CUDA', '0') == '1': WITH_CUDA = True if os.getenv('FORCE_CPU', '0') == '1': @@ -17,42 +17,48 @@ def get_extensions(): - Extension = CppExtension - define_macros = [] - extra_compile_args = {'cxx': []} + extensions = [] + for with_cuda, supername in [ + (False, "cpu"), + (True, "gpu"), + ]: + if with_cuda and not WITH_CUDA: + continue + Extension = CppExtension + define_macros = [] + extra_compile_args = {'cxx': []} - if WITH_CUDA: - Extension = CUDAExtension - define_macros += [('WITH_CUDA', None)] - nvcc_flags = os.getenv('NVCC_FLAGS', '') - nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') - nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr'] - extra_compile_args['nvcc'] = nvcc_flags + if with_cuda: + Extension = CUDAExtension + define_macros += [('WITH_CUDA', None)] + nvcc_flags = os.getenv('NVCC_FLAGS', '') + nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') + nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr'] + extra_compile_args['nvcc'] = nvcc_flags - extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc') - main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) - extensions = [] - for main in main_files: - name = main.split(os.sep)[-1][:-4] + extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc') + main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) + for main in main_files: + name = main.split(os.sep)[-1][:-4] - sources = [main] + sources = [main] - path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp') - if osp.exists(path): - sources += [path] + path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp') + if osp.exists(path): + sources += [path] - path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu') - if WITH_CUDA and osp.exists(path): - sources += [path] + path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu') + if with_cuda and osp.exists(path): + sources += [path] - extension = Extension( - 'torch_scatter._' + name, - sources, - include_dirs=[extensions_dir], - define_macros=define_macros, - extra_compile_args=extra_compile_args, - ) - extensions += [extension] + extension = Extension( + 'torch_scatter._%s_%s' % (name, supername), + sources, + include_dirs=[extensions_dir], + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + extensions += [extension] return extensions diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index 21f1fe9e..b61a3b57 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -6,8 +6,14 @@ __version__ = '2.0.5' +if torch.cuda.is_available(): + sublib = "gpu" +else: + sublib = "cpu" + try: for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']: + library = "%s_%s" % (library, sublib) torch.ops.load_library(importlib.machinery.PathFinder().find_spec( library, [osp.dirname(__file__)]).origin) except AttributeError as e: From 8a07c869e0f4df7d2e9abc5c7746662c09fc28c8 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 2 Mar 2021 09:12:36 +0100 Subject: [PATCH 2/2] update --- .travis.yml | 10 ++--- CMakeLists.txt | 2 +- csrc/scatter.cpp | 6 ++- csrc/segment_coo.cpp | 6 ++- csrc/segment_csr.cpp | 6 ++- csrc/version.cpp | 6 ++- script/cuda.sh | 2 +- script/rename_wheel.py | 24 ----------- setup.py | 84 ++++++++++++++++++++++----------------- torch_scatter/__init__.py | 12 ++---- 10 files changed, 79 insertions(+), 79 deletions(-) delete mode 100644 script/rename_wheel.py diff --git a/.travis.yml b/.travis.yml index 6e17dd9b..e9e4e99e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -112,13 +112,13 @@ install: - source script/torch.sh - pip install flake8 - pip install codecov - - pip install .[test] + - travis_wait 30 pip install -e . script: - flake8 . - python setup.py test after_success: - - python setup.py bdist_wheel --dist-dir=dist/torch-${TORCH_VERSION} - - python script/rename_wheel.py ${IDX} + - python setup.py bdist_wheel --dist-dir=dist + - ls -lah dist/ - codecov deploy: provider: s3 @@ -127,8 +127,8 @@ deploy: access_key_id: ${S3_ACCESS_KEY} secret_access_key: ${S3_SECRET_ACCESS_KEY} bucket: pytorch-geometric.com - local_dir: dist/torch-${TORCH_VERSION} - upload_dir: whl/torch-${TORCH_VERSION} + local_dir: dist + upload_dir: whl/torch-${TORCH_VERSION}+${IDX} acl: public_read on: all_branches: true diff --git a/CMakeLists.txt b/CMakeLists.txt index be861817..b50670ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.0) project(torchscatter) set(CMAKE_CXX_STANDARD 14) -set(TORCHSCATTER_VERSION 2.0.5) +set(TORCHSCATTER_VERSION 2.0.6) option(WITH_CUDA "Enable CUDA support" OFF) diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp index a5c0c7e8..4a7ba1ff 100644 --- a/csrc/scatter.cpp +++ b/csrc/scatter.cpp @@ -9,7 +9,11 @@ #endif #ifdef _WIN32 -PyMODINIT_FUNC PyInit__scatter(void) { return NULL; } +#ifdef WITH_CUDA +PyMODINIT_FUNC PyInit__scatter_cuda(void) { return NULL; } +#else +PyMODINIT_FUNC PyInit__scatter_cpu(void) { return NULL; } +#endif #endif torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) { diff --git a/csrc/segment_coo.cpp b/csrc/segment_coo.cpp index 2a06e84c..234f3ee4 100644 --- a/csrc/segment_coo.cpp +++ b/csrc/segment_coo.cpp @@ -9,7 +9,11 @@ #endif #ifdef _WIN32 -PyMODINIT_FUNC PyInit__segment_coo(void) { return NULL; } +#ifdef WITH_CUDA +PyMODINIT_FUNC PyInit__segment_coo_cuda(void) { return NULL; } +#else +PyMODINIT_FUNC PyInit__segment_coo_cpu(void) { return NULL; } +#endif #endif std::tuple> diff --git a/csrc/segment_csr.cpp b/csrc/segment_csr.cpp index 44046ff2..4b2ad08c 100644 --- a/csrc/segment_csr.cpp +++ b/csrc/segment_csr.cpp @@ -9,7 +9,11 @@ #endif #ifdef _WIN32 -PyMODINIT_FUNC PyInit__segment_csr(void) { return NULL; } +#ifdef WITH_CUDA +PyMODINIT_FUNC PyInit__segment_csr_cuda(void) { return NULL; } +#else +PyMODINIT_FUNC PyInit__segment_csr_cpu(void) { return NULL; } +#endif #endif std::tuple> diff --git a/csrc/version.cpp b/csrc/version.cpp index 0bc44861..a003ea81 100644 --- a/csrc/version.cpp +++ b/csrc/version.cpp @@ -6,7 +6,11 @@ #endif #ifdef _WIN32 -PyMODINIT_FUNC PyInit__version(void) { return NULL; } +#ifdef WITH_CUDA +PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; } +#else +PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; } +#endif #endif int64_t cuda_version() { diff --git a/script/cuda.sh b/script/cuda.sh index 55cdb9f5..54c84127 100755 --- a/script/cuda.sh +++ b/script/cuda.sh @@ -69,7 +69,7 @@ if [ "${TRAVIS_OS_NAME}" = "osx" ] && [ "$IDX" = "cpu" ]; then fi if [ "${IDX}" = "cpu" ]; then - export FORCE_CPU=1 + export FORCE_ONLY_CPU=1 else export FORCE_CUDA=1 fi diff --git a/script/rename_wheel.py b/script/rename_wheel.py deleted file mode 100644 index 73411591..00000000 --- a/script/rename_wheel.py +++ /dev/null @@ -1,24 +0,0 @@ -import sys -import os -import os.path as osp -import glob -import shutil - -idx = sys.argv[1] -assert idx in ['cpu', 'cu92', 'cu101', 'cu102', 'cu110'] - -dist_dir = osp.join(osp.dirname(osp.abspath(__file__)), '..', 'dist') -wheels = glob.glob(osp.join('dist', '**', '*.whl'), recursive=True) - -for wheel in wheels: - if idx in wheel: - continue - - paths = wheel.split(osp.sep) - names = paths[-1].split('-') - - name = '-'.join(names[:-4] + ['latest+' + idx] + names[-3:]) - shutil.copyfile(wheel, osp.join(*paths[:-1], name)) - - name = '-'.join(names[:-4] + [names[-4] + '+' + idx] + names[-3:]) - os.rename(wheel, osp.join(*paths[:-1], name)) diff --git a/setup.py b/setup.py index 19a9785f..236d12d5 100644 --- a/setup.py +++ b/setup.py @@ -1,64 +1,76 @@ import os -import os.path as osp +import sys import glob +import os.path as osp +from itertools import product from setuptools import setup, find_packages import torch +from torch.__config__ import parallel_info from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME -WITH_CUDA = CUDA_HOME is not None +WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None +suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu'] if os.getenv('FORCE_CUDA', '0') == '1': - WITH_CUDA = True -if os.getenv('FORCE_CPU', '0') == '1': - WITH_CUDA = False + suffices = ['cuda', 'cpu'] +if os.getenv('FORCE_ONLY_CUDA', '0') == '1': + suffices = ['cuda'] +if os.getenv('FORCE_ONLY_CPU', '0') == '1': + suffices = ['cpu'] BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1' def get_extensions(): extensions = [] - for with_cuda, supername in [ - (False, "cpu"), - (True, "gpu"), - ]: - if with_cuda and not WITH_CUDA: - continue - Extension = CppExtension + + extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc') + main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) + + for main, suffix in product(main_files, suffices): define_macros = [] - extra_compile_args = {'cxx': []} + extra_compile_args = {'cxx': ['-O2']} + extra_link_args = ['-s'] + + info = parallel_info() + if 'backend: OpenMP' in info and 'OpenMP not found' not in info: + extra_compile_args['cxx'] += ['-DAT_PARALLEL_OPENMP'] + if sys.platform == 'win32': + extra_compile_args['cxx'] += ['/openmp'] + else: + extra_compile_args['cxx'] += ['-fopenmp'] + else: + print('Compiling without OpenMP...') - if with_cuda: - Extension = CUDAExtension + if suffix == 'cuda': define_macros += [('WITH_CUDA', None)] nvcc_flags = os.getenv('NVCC_FLAGS', '') nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr'] extra_compile_args['nvcc'] = nvcc_flags - extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc') - main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) - for main in main_files: - name = main.split(os.sep)[-1][:-4] - - sources = [main] + name = main.split(os.sep)[-1][:-4] + sources = [main] - path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp') - if osp.exists(path): - sources += [path] + path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp') + if osp.exists(path): + sources += [path] - path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu') - if with_cuda and osp.exists(path): - sources += [path] + path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu') + if suffix == 'cuda' and osp.exists(path): + sources += [path] - extension = Extension( - 'torch_scatter._%s_%s' % (name, supername), - sources, - include_dirs=[extensions_dir], - define_macros=define_macros, - extra_compile_args=extra_compile_args, - ) - extensions += [extension] + Extension = CppExtension if suffix == 'cpu' else CUDAExtension + extension = Extension( + f'torch_scatter._{name}_{suffix}', + sources, + include_dirs=[extensions_dir], + define_macros=define_macros, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + extensions += [extension] return extensions @@ -69,7 +81,7 @@ def get_extensions(): setup( name='torch_scatter', - version='2.0.5', + version='2.0.6', author='Matthias Fey', author_email='matthias.fey@tu-dortmund.de', url='https://github.com/rusty1s/pytorch_scatter', diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index b61a3b57..149e99f1 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -4,18 +4,14 @@ import torch -__version__ = '2.0.5' +__version__ = '2.0.6' -if torch.cuda.is_available(): - sublib = "gpu" -else: - sublib = "cpu" +suffix = 'cuda' if torch.cuda.is_available() else 'cpu' try: for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']: - library = "%s_%s" % (library, sublib) torch.ops.load_library(importlib.machinery.PathFinder().find_spec( - library, [osp.dirname(__file__)]).origin) + f'{library}_{suffix}', [osp.dirname(__file__)]).origin) except AttributeError as e: if os.getenv('BUILD_DOCS', '0') != '1': raise AttributeError(e) @@ -45,7 +41,7 @@ torch.ops.torch_scatter.segment_max_coo = segment_coo_arg_placeholder torch.ops.torch_scatter.gather_coo = gather_coo_placeholder -if torch.cuda.is_available() and torch.version.cuda: # pragma: no cover +if torch.cuda.is_available(): # pragma: no cover cuda_version = torch.ops.torch_scatter.cuda_version() if cuda_version == -1: