From be95eea962e1fb80d433d4694f9e3ea37c2ed4ac Mon Sep 17 00:00:00 2001 From: Gerico Vidanes Date: Fri, 22 Jul 2022 18:46:39 +0100 Subject: [PATCH] `WITH_PYTHON` conditionals --- CMakeLists.txt | 11 +++++++++-- csrc/cpu/index_info.h | 2 +- csrc/cpu/scatter_cpu.h | 2 +- csrc/cpu/segment_coo_cpu.h | 2 +- csrc/cpu/segment_csr_cpu.h | 2 +- csrc/cpu/utils.h | 2 +- csrc/cuda/scatter_cuda.h | 2 +- csrc/cuda/segment_coo_cuda.h | 2 +- csrc/cuda/segment_csr_cuda.h | 2 +- csrc/cuda/utils.cuh | 2 +- csrc/extensions.h | 2 ++ csrc/scatter.cpp | 5 +++++ csrc/scatter.h | 4 +--- csrc/segment_coo.cpp | 5 +++++ csrc/segment_csr.cpp | 5 +++++ csrc/version.cpp | 5 +++++ setup.py | 2 +- 17 files changed, 42 insertions(+), 15 deletions(-) create mode 100644 csrc/extensions.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c7f9a03..e307f829 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 14) set(TORCHSCATTER_VERSION 2.0.9) option(WITH_CUDA "Enable CUDA support" OFF) +option(WITH_PYTHON "Link to Python when building" ON) if(WITH_CUDA) enable_language(CUDA) @@ -12,7 +13,10 @@ if(WITH_CUDA) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") endif() -find_package(Python3 COMPONENTS Development) +if (WITH_PYTHON) + add_definitions(-DWITH_PYTHON) + find_package(Python3 COMPONENTS Development) +endif() find_package(Torch REQUIRED) file(GLOB HEADERS csrc/*.h) @@ -22,7 +26,10 @@ if(WITH_CUDA) endif() add_library(${PROJECT_NAME} SHARED ${OPERATOR_SOURCES}) -target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} Python3::Python) +target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES}) +if (WITH_PYTHON) + target_link_libraries(${PROJECT_NAME} PRIVATE Python3::Python) +endif() set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchScatter) target_include_directories(${PROJECT_NAME} INTERFACE diff --git a/csrc/cpu/index_info.h b/csrc/cpu/index_info.h index 9709a1de..5e9ed0b4 100644 --- a/csrc/cpu/index_info.h +++ b/csrc/cpu/index_info.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "../extensions.h" #define MAX_TENSORINFO_DIMS 25 diff --git a/csrc/cpu/scatter_cpu.h b/csrc/cpu/scatter_cpu.h index 25122e70..286394c1 100644 --- a/csrc/cpu/scatter_cpu.h +++ b/csrc/cpu/scatter_cpu.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "../extensions.h" std::tuple> scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, diff --git a/csrc/cpu/segment_coo_cpu.h b/csrc/cpu/segment_coo_cpu.h index feb7a827..425cfb8b 100644 --- a/csrc/cpu/segment_coo_cpu.h +++ b/csrc/cpu/segment_coo_cpu.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "../extensions.h" std::tuple> segment_coo_cpu(torch::Tensor src, torch::Tensor index, diff --git a/csrc/cpu/segment_csr_cpu.h b/csrc/cpu/segment_csr_cpu.h index b93d450b..716f6890 100644 --- a/csrc/cpu/segment_csr_cpu.h +++ b/csrc/cpu/segment_csr_cpu.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "../extensions.h" std::tuple> segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, diff --git a/csrc/cpu/utils.h b/csrc/cpu/utils.h index 40dfb344..66ae38bf 100644 --- a/csrc/cpu/utils.h +++ b/csrc/cpu/utils.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "../extensions.h" #define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") diff --git a/csrc/cuda/scatter_cuda.h b/csrc/cuda/scatter_cuda.h index 95c80642..92bdfa8d 100644 --- a/csrc/cuda/scatter_cuda.h +++ b/csrc/cuda/scatter_cuda.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "../extensions.h" std::tuple> scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, diff --git a/csrc/cuda/segment_coo_cuda.h b/csrc/cuda/segment_coo_cuda.h index 68154775..f401faea 100644 --- a/csrc/cuda/segment_coo_cuda.h +++ b/csrc/cuda/segment_coo_cuda.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "../extensions.h" std::tuple> segment_coo_cuda(torch::Tensor src, torch::Tensor index, diff --git a/csrc/cuda/segment_csr_cuda.h b/csrc/cuda/segment_csr_cuda.h index 5f8bd40e..c59f5401 100644 --- a/csrc/cuda/segment_csr_cuda.h +++ b/csrc/cuda/segment_csr_cuda.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "../extensions.h" std::tuple> segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, diff --git a/csrc/cuda/utils.cuh b/csrc/cuda/utils.cuh index ee2e9108..c1aa222d 100644 --- a/csrc/cuda/utils.cuh +++ b/csrc/cuda/utils.cuh @@ -1,6 +1,6 @@ #pragma once -#include +#include "../extensions.h" #define CHECK_CUDA(x) \ AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") diff --git a/csrc/extensions.h b/csrc/extensions.h new file mode 100644 index 00000000..91c4df1a --- /dev/null +++ b/csrc/extensions.h @@ -0,0 +1,2 @@ +#include "macros.h" +#include diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp index a71552d0..e2733d64 100644 --- a/csrc/scatter.cpp +++ b/csrc/scatter.cpp @@ -1,4 +1,7 @@ +#ifdef WITH_PYTHON #include +#endif + #include #include "cpu/scatter_cpu.h" @@ -10,12 +13,14 @@ #endif #ifdef _WIN32 +#ifdef WITH_PYTHON #ifdef WITH_CUDA PyMODINIT_FUNC PyInit__scatter_cuda(void) { return NULL; } #else PyMODINIT_FUNC PyInit__scatter_cpu(void) { return NULL; } #endif #endif +#endif torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) { if (src.dim() == 1) diff --git a/csrc/scatter.h b/csrc/scatter.h index a477038c..64004c1f 100644 --- a/csrc/scatter.h +++ b/csrc/scatter.h @@ -1,8 +1,6 @@ #pragma once -#include - -#include "macros.h" +#include "extensions.h" namespace scatter { SCATTER_API int64_t cuda_version() noexcept; diff --git a/csrc/segment_coo.cpp b/csrc/segment_coo.cpp index 6599ab0c..055dab87 100644 --- a/csrc/segment_coo.cpp +++ b/csrc/segment_coo.cpp @@ -1,4 +1,7 @@ +#ifdef WITH_PYTHON #include +#endif + #include #include "cpu/segment_coo_cpu.h" @@ -10,12 +13,14 @@ #endif #ifdef _WIN32 +#ifdef WITH_PYTHON #ifdef WITH_CUDA PyMODINIT_FUNC PyInit__segment_coo_cuda(void) { return NULL; } #else PyMODINIT_FUNC PyInit__segment_coo_cpu(void) { return NULL; } #endif #endif +#endif std::tuple> segment_coo_fw(torch::Tensor src, torch::Tensor index, diff --git a/csrc/segment_csr.cpp b/csrc/segment_csr.cpp index 969dad7e..b8f366eb 100644 --- a/csrc/segment_csr.cpp +++ b/csrc/segment_csr.cpp @@ -1,4 +1,7 @@ +#ifdef WITH_PYTHON #include +#endif + #include #include "cpu/segment_csr_cpu.h" @@ -10,12 +13,14 @@ #endif #ifdef _WIN32 +#ifdef WITH_PYTHON #ifdef WITH_CUDA PyMODINIT_FUNC PyInit__segment_csr_cuda(void) { return NULL; } #else PyMODINIT_FUNC PyInit__segment_csr_cpu(void) { return NULL; } #endif #endif +#endif std::tuple> segment_csr_fw(torch::Tensor src, torch::Tensor indptr, diff --git a/csrc/version.cpp b/csrc/version.cpp index b7d21510..2388947e 100644 --- a/csrc/version.cpp +++ b/csrc/version.cpp @@ -1,4 +1,7 @@ +#ifdef WITH_PYTHON #include +#endif + #include #include "scatter.h" #include "macros.h" @@ -8,12 +11,14 @@ #endif #ifdef _WIN32 +#ifdef WITH_PYTHON #ifdef WITH_CUDA PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; } #else PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; } #endif #endif +#endif namespace scatter { SCATTER_API int64_t cuda_version() noexcept { diff --git a/setup.py b/setup.py index e728cf18..53ebf86e 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ def get_extensions(): main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) for main, suffix in product(main_files, suffices): - define_macros = [] + define_macros = [('WITH_PYTHON', None)] if sys.platform == 'win32': define_macros += [('torchscatter_EXPORTS', None)]