Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 14)
set(TORCHSCATTER_VERSION 2.0.9)

option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_PYTHON "Link to Python when building" ON)

if(WITH_CUDA)
enable_language(CUDA)
Expand All @@ -12,7 +13,10 @@ if(WITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
endif()

find_package(Python3 COMPONENTS Development)
if (WITH_PYTHON)
add_definitions(-DWITH_PYTHON)
find_package(Python3 COMPONENTS Development)
endif()
find_package(Torch REQUIRED)

file(GLOB HEADERS csrc/*.h)
Expand All @@ -22,7 +26,10 @@ if(WITH_CUDA)
endif()

add_library(${PROJECT_NAME} SHARED ${OPERATOR_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} Python3::Python)
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
if (WITH_PYTHON)
target_link_libraries(${PROJECT_NAME} PRIVATE Python3::Python)
endif()
set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchScatter)

target_include_directories(${PROJECT_NAME} INTERFACE
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/index_info.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <torch/extension.h>
#include "../extensions.h"

#define MAX_TENSORINFO_DIMS 25

Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/scatter_cpu.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <torch/extension.h>
#include "../extensions.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/segment_coo_cpu.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <torch/extension.h>
#include "../extensions.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/segment_csr_cpu.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <torch/extension.h>
#include "../extensions.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/utils.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <torch/extension.h>
#include "../extensions.h"

#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
2 changes: 1 addition & 1 deletion csrc/cuda/scatter_cuda.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <torch/extension.h>
#include "../extensions.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/segment_coo_cuda.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <torch/extension.h>
#include "../extensions.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/segment_csr_cuda.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <torch/extension.h>
#include "../extensions.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/utils.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <torch/extension.h>
#include "../extensions.h"

#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
Expand Down
2 changes: 2 additions & 0 deletions csrc/extensions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#include "macros.h"
#include <torch/torch.h>
5 changes: 5 additions & 0 deletions csrc/scatter.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#ifdef WITH_PYTHON
#include <Python.h>
#endif

#include <torch/script.h>

#include "cpu/scatter_cpu.h"
Expand All @@ -10,12 +13,14 @@
#endif

#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__scatter_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__scatter_cpu(void) { return NULL; }
#endif
#endif
#endif

torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
if (src.dim() == 1)
Expand Down
4 changes: 1 addition & 3 deletions csrc/scatter.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#pragma once

#include <torch/extension.h>

#include "macros.h"
#include "extensions.h"

namespace scatter {
SCATTER_API int64_t cuda_version() noexcept;
Expand Down
5 changes: 5 additions & 0 deletions csrc/segment_coo.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#ifdef WITH_PYTHON
#include <Python.h>
#endif

#include <torch/script.h>

#include "cpu/segment_coo_cpu.h"
Expand All @@ -10,12 +13,14 @@
#endif

#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__segment_coo_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__segment_coo_cpu(void) { return NULL; }
#endif
#endif
#endif

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_fw(torch::Tensor src, torch::Tensor index,
Expand Down
5 changes: 5 additions & 0 deletions csrc/segment_csr.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#ifdef WITH_PYTHON
#include <Python.h>
#endif

#include <torch/script.h>

#include "cpu/segment_csr_cpu.h"
Expand All @@ -10,12 +13,14 @@
#endif

#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__segment_csr_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__segment_csr_cpu(void) { return NULL; }
#endif
#endif
#endif

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_fw(torch::Tensor src, torch::Tensor indptr,
Expand Down
5 changes: 5 additions & 0 deletions csrc/version.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#ifdef WITH_PYTHON
#include <Python.h>
#endif

#include <torch/script.h>
#include "scatter.h"
#include "macros.h"
Expand All @@ -8,12 +11,14 @@
#endif

#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
#endif
#endif
#endif

namespace scatter {
SCATTER_API int64_t cuda_version() noexcept {
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_extensions():
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))

for main, suffix in product(main_files, suffices):
define_macros = []
define_macros = [('WITH_PYTHON', None)]

if sys.platform == 'win32':
define_macros += [('torchscatter_EXPORTS', None)]
Expand Down