From 7db04f7b24ef88fc99c3e9e23c7c89cfe5e7e866 Mon Sep 17 00:00:00 2001 From: Gerico Vidanes Date: Wed, 4 May 2022 19:37:52 +0100 Subject: [PATCH 1/6] Mark exported symbols with SCATTER_API so they are available in the DLL on Windows --- csrc/scatter.cpp | 9 +++++---- csrc/scatter.h | 40 +++++++++++++++++++++++++--------------- csrc/segment_coo.cpp | 11 ++++++----- csrc/segment_csr.cpp | 11 ++++++----- csrc/version.cpp | 3 ++- 5 files changed, 44 insertions(+), 30 deletions(-) diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp index 3a418ab3..f78e015e 100644 --- a/csrc/scatter.cpp +++ b/csrc/scatter.cpp @@ -1,6 +1,7 @@ #include #include +#include "scatter.h" #include "cpu/scatter_cpu.h" #include "utils.h" @@ -226,7 +227,7 @@ class ScatterMax : public torch::autograd::Function { } }; -torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, +SCATTER_API torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size) { return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0]; @@ -238,13 +239,13 @@ torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim, return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0]; } -torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, +SCATTER_API torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size) { return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0]; } -std::tuple +SCATTER_API std::tuple scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size) { @@ -252,7 +253,7 @@ scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim, return std::make_tuple(result[0], result[1]); } -std::tuple +SCATTER_API std::tuple scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size) { diff --git a/csrc/scatter.h b/csrc/scatter.h index 2d19fedb..40066f75 100644 --- a/csrc/scatter.h +++ b/csrc/scatter.h @@ -2,6 +2,16 @@ #include +#ifdef _WIN32 +#if defined(torchscatter_EXPORTS) +#define SCATTER_API __declspec(dllexport) +#else +#define SCATTER_API __declspec(dllimport) +#endif +#else +#define SCATTER_API +#endif + #if (defined __cpp_inline_variables) || __cplusplus >= 201703L #define SCATTER_INLINE_VARIABLE inline #else @@ -13,65 +23,65 @@ #endif namespace scatter { -int64_t cuda_version() noexcept; +SCATTER_API int64_t cuda_version() noexcept; namespace detail { SCATTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version(); } // namespace detail } // namespace scatter -torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, +SCATTER_API torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size); -torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, +SCATTER_API torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size); -std::tuple +SCATTER_API std::tuple scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size); -std::tuple +SCATTER_API std::tuple scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size); -torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index, +SCATTER_API torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index, torch::optional optional_out, torch::optional dim_size); -torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index, +SCATTER_API torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index, torch::optional optional_out, torch::optional dim_size); -std::tuple +SCATTER_API std::tuple segment_min_coo(torch::Tensor src, torch::Tensor index, torch::optional optional_out, torch::optional dim_size); -std::tuple +SCATTER_API std::tuple segment_max_coo(torch::Tensor src, torch::Tensor index, torch::optional optional_out, torch::optional dim_size); -torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, +SCATTER_API torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, torch::optional optional_out); -torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr, +SCATTER_API torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr, torch::optional optional_out); -torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr, +SCATTER_API torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr, torch::optional optional_out); -std::tuple +SCATTER_API std::tuple segment_min_csr(torch::Tensor src, torch::Tensor indptr, torch::optional optional_out); -std::tuple +SCATTER_API std::tuple segment_max_csr(torch::Tensor src, torch::Tensor indptr, torch::optional optional_out); -torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, +SCATTER_API torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, torch::optional optional_out); diff --git a/csrc/segment_coo.cpp b/csrc/segment_coo.cpp index 234f3ee4..64ce0500 100644 --- a/csrc/segment_coo.cpp +++ b/csrc/segment_coo.cpp @@ -3,6 +3,7 @@ #include "cpu/segment_coo_cpu.h" #include "utils.h" +#include "scatter.h" #ifdef WITH_CUDA #include "cuda/segment_coo_cuda.h" @@ -195,19 +196,19 @@ class GatherCOO : public torch::autograd::Function { } }; -torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index, +SCATTER_API torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index, torch::optional optional_out, torch::optional dim_size) { return SegmentSumCOO::apply(src, index, optional_out, dim_size)[0]; } -torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index, +SCATTER_API torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index, torch::optional optional_out, torch::optional dim_size) { return SegmentMeanCOO::apply(src, index, optional_out, dim_size)[0]; } -std::tuple +SCATTER_API std::tuple segment_min_coo(torch::Tensor src, torch::Tensor index, torch::optional optional_out, torch::optional dim_size) { @@ -215,7 +216,7 @@ segment_min_coo(torch::Tensor src, torch::Tensor index, return std::make_tuple(result[0], result[1]); } -std::tuple +SCATTER_API std::tuple segment_max_coo(torch::Tensor src, torch::Tensor index, torch::optional optional_out, torch::optional dim_size) { @@ -223,7 +224,7 @@ segment_max_coo(torch::Tensor src, torch::Tensor index, return std::make_tuple(result[0], result[1]); } -torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, +SCATTER_API torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, torch::optional optional_out) { return GatherCOO::apply(src, index, optional_out)[0]; } diff --git a/csrc/segment_csr.cpp b/csrc/segment_csr.cpp index 4b2ad08c..de15611b 100644 --- a/csrc/segment_csr.cpp +++ b/csrc/segment_csr.cpp @@ -3,6 +3,7 @@ #include "cpu/segment_csr_cpu.h" #include "utils.h" +#include "scatter.h" #ifdef WITH_CUDA #include "cuda/segment_csr_cuda.h" @@ -192,31 +193,31 @@ class GatherCSR : public torch::autograd::Function { } }; -torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr, +SCATTER_API torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr, torch::optional optional_out) { return SegmentSumCSR::apply(src, indptr, optional_out)[0]; } -torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr, +SCATTER_API torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr, torch::optional optional_out) { return SegmentMeanCSR::apply(src, indptr, optional_out)[0]; } -std::tuple +SCATTER_API std::tuple segment_min_csr(torch::Tensor src, torch::Tensor indptr, torch::optional optional_out) { auto result = SegmentMinCSR::apply(src, indptr, optional_out); return std::make_tuple(result[0], result[1]); } -std::tuple +SCATTER_API std::tuple segment_max_csr(torch::Tensor src, torch::Tensor indptr, torch::optional optional_out) { auto result = SegmentMaxCSR::apply(src, indptr, optional_out); return std::make_tuple(result[0], result[1]); } -torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, +SCATTER_API torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, torch::optional optional_out) { return GatherCSR::apply(src, indptr, optional_out)[0]; } diff --git a/csrc/version.cpp b/csrc/version.cpp index 64d2b267..8a5bca75 100644 --- a/csrc/version.cpp +++ b/csrc/version.cpp @@ -1,5 +1,6 @@ #include #include +#include "scatter.h" #ifdef WITH_CUDA #include @@ -14,7 +15,7 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; } #endif namespace scatter { -int64_t cuda_version() noexcept { +SCATTER_API int64_t cuda_version() noexcept { #ifdef WITH_CUDA return CUDA_VERSION; #else From c4ec81782b59b498e6acc3d4fa97513b551122a1 Mon Sep 17 00:00:00 2001 From: Gerico Vidanes Date: Wed, 4 May 2022 22:51:41 +0100 Subject: [PATCH 2/6] macro definition for python package installation --- setup.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/setup.py b/setup.py index b6ad06fe..7d80cc1b 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,10 @@ def get_extensions(): for main, suffix in product(main_files, suffices): define_macros = [] + + if sys.platform == 'win32': + define_macros += [('torchscatter_EXPORTS', None)] + extra_compile_args = {'cxx': ['-O2']} if not os.name == 'nt': # Not on Windows: extra_compile_args['cxx'] += ['-Wno-sign-compare'] From b7a9da3554ccd4825a4b84bb3a8e97eadb817a9c Mon Sep 17 00:00:00 2001 From: Gerico Vidanes Date: Sat, 9 Jul 2022 01:49:22 +0100 Subject: [PATCH 3/6] move macros definitions to separate file --- csrc/macros.h | 21 +++++++++++++++++++++ csrc/scatter.cpp | 2 ++ csrc/scatter.h | 20 +------------------- csrc/segment_coo.cpp | 2 +- csrc/segment_csr.cpp | 2 +- csrc/version.cpp | 1 + 6 files changed, 27 insertions(+), 21 deletions(-) create mode 100644 csrc/macros.h diff --git a/csrc/macros.h b/csrc/macros.h new file mode 100644 index 00000000..95baf8ed --- /dev/null +++ b/csrc/macros.h @@ -0,0 +1,21 @@ +#pragma once + +#ifdef _WIN32 +#if defined(torchscatter_EXPORTS) +#define SCATTER_API __declspec(dllexport) +#else +#define SCATTER_API __declspec(dllimport) +#endif +#else +#define SCATTER_API +#endif + +#if (defined __cpp_inline_variables) || __cplusplus >= 201703L +#define SCATTER_INLINE_VARIABLE inline +#else +#ifdef _MSC_VER +#define SCATTER_INLINE_VARIABLE __declspec(selectany) +#else +#define SCATTER_INLINE_VARIABLE __attribute__((weak)) +#endif +#endif \ No newline at end of file diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp index f78e015e..56289e40 100644 --- a/csrc/scatter.cpp +++ b/csrc/scatter.cpp @@ -1,6 +1,8 @@ #include #include +#include "macros.h" + #include "scatter.h" #include "cpu/scatter_cpu.h" #include "utils.h" diff --git a/csrc/scatter.h b/csrc/scatter.h index 40066f75..30dfb018 100644 --- a/csrc/scatter.h +++ b/csrc/scatter.h @@ -2,25 +2,7 @@ #include -#ifdef _WIN32 -#if defined(torchscatter_EXPORTS) -#define SCATTER_API __declspec(dllexport) -#else -#define SCATTER_API __declspec(dllimport) -#endif -#else -#define SCATTER_API -#endif - -#if (defined __cpp_inline_variables) || __cplusplus >= 201703L -#define SCATTER_INLINE_VARIABLE inline -#else -#ifdef _MSC_VER -#define SCATTER_INLINE_VARIABLE __declspec(selectany) -#else -#define SCATTER_INLINE_VARIABLE __attribute__((weak)) -#endif -#endif +#include "macros.h" namespace scatter { SCATTER_API int64_t cuda_version() noexcept; diff --git a/csrc/segment_coo.cpp b/csrc/segment_coo.cpp index 64ce0500..f17ae48d 100644 --- a/csrc/segment_coo.cpp +++ b/csrc/segment_coo.cpp @@ -3,7 +3,7 @@ #include "cpu/segment_coo_cpu.h" #include "utils.h" -#include "scatter.h" +#include "macros.h" #ifdef WITH_CUDA #include "cuda/segment_coo_cuda.h" diff --git a/csrc/segment_csr.cpp b/csrc/segment_csr.cpp index de15611b..efe8016a 100644 --- a/csrc/segment_csr.cpp +++ b/csrc/segment_csr.cpp @@ -3,7 +3,7 @@ #include "cpu/segment_csr_cpu.h" #include "utils.h" -#include "scatter.h" +#include "macros.h" #ifdef WITH_CUDA #include "cuda/segment_csr_cuda.h" diff --git a/csrc/version.cpp b/csrc/version.cpp index 8a5bca75..b7d21510 100644 --- a/csrc/version.cpp +++ b/csrc/version.cpp @@ -1,6 +1,7 @@ #include #include #include "scatter.h" +#include "macros.h" #ifdef WITH_CUDA #include From b1815bda604d779be6ccbc43b9b98420d6246309 Mon Sep 17 00:00:00 2001 From: Gerico Vidanes Date: Sat, 9 Jul 2022 02:23:28 +0100 Subject: [PATCH 4/6] unlink? --- csrc/scatter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp index 56289e40..7195e92a 100644 --- a/csrc/scatter.cpp +++ b/csrc/scatter.cpp @@ -3,7 +3,7 @@ #include "macros.h" -#include "scatter.h" +// #include "scatter.h" #include "cpu/scatter_cpu.h" #include "utils.h" From b3da90d537aeb0e9421f53b15cd3ee5a310bc937 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sat, 9 Jul 2022 10:49:44 +0200 Subject: [PATCH 5/6] update --- csrc/macros.h | 2 +- csrc/scatter.cpp | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/csrc/macros.h b/csrc/macros.h index 95baf8ed..d55e6236 100644 --- a/csrc/macros.h +++ b/csrc/macros.h @@ -18,4 +18,4 @@ #else #define SCATTER_INLINE_VARIABLE __attribute__((weak)) #endif -#endif \ No newline at end of file +#endif diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp index 7195e92a..a71552d0 100644 --- a/csrc/scatter.cpp +++ b/csrc/scatter.cpp @@ -1,10 +1,8 @@ #include #include -#include "macros.h" - -// #include "scatter.h" #include "cpu/scatter_cpu.h" +#include "macros.h" #include "utils.h" #ifdef WITH_CUDA @@ -229,9 +227,10 @@ class ScatterMax : public torch::autograd::Function { } }; -SCATTER_API torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, - torch::optional optional_out, - torch::optional dim_size) { +SCATTER_API torch::Tensor +scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0]; } @@ -241,9 +240,10 @@ torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim, return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0]; } -SCATTER_API torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, - torch::optional optional_out, - torch::optional dim_size) { +SCATTER_API torch::Tensor +scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0]; } From 093ddcbdc9eb035154bd53a37cf10ee0f49d12f7 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sat, 9 Jul 2022 10:51:20 +0200 Subject: [PATCH 6/6] linting --- csrc/scatter.h | 48 ++++++++++++++++++++++++++------------------ csrc/segment_coo.cpp | 21 ++++++++++--------- csrc/segment_csr.cpp | 17 +++++++++------- 3 files changed, 50 insertions(+), 36 deletions(-) diff --git a/csrc/scatter.h b/csrc/scatter.h index 30dfb018..a477038c 100644 --- a/csrc/scatter.h +++ b/csrc/scatter.h @@ -12,13 +12,15 @@ SCATTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version(); } // namespace detail } // namespace scatter -SCATTER_API torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, - torch::optional optional_out, - torch::optional dim_size); +SCATTER_API torch::Tensor +scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size); -SCATTER_API torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, - torch::optional optional_out, - torch::optional dim_size); +SCATTER_API torch::Tensor +scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size); SCATTER_API std::tuple scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim, @@ -30,13 +32,15 @@ scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size); -SCATTER_API torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index, - torch::optional optional_out, - torch::optional dim_size); +SCATTER_API torch::Tensor +segment_sum_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size); -SCATTER_API torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index, - torch::optional optional_out, - torch::optional dim_size); +SCATTER_API torch::Tensor +segment_mean_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size); SCATTER_API std::tuple segment_min_coo(torch::Tensor src, torch::Tensor index, @@ -48,14 +52,17 @@ segment_max_coo(torch::Tensor src, torch::Tensor index, torch::optional optional_out, torch::optional dim_size); -SCATTER_API torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, - torch::optional optional_out); +SCATTER_API torch::Tensor +gather_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out); -SCATTER_API torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional optional_out); +SCATTER_API torch::Tensor +segment_sum_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out); -SCATTER_API torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional optional_out); +SCATTER_API torch::Tensor +segment_mean_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out); SCATTER_API std::tuple segment_min_csr(torch::Tensor src, torch::Tensor indptr, @@ -65,5 +72,6 @@ SCATTER_API std::tuple segment_max_csr(torch::Tensor src, torch::Tensor indptr, torch::optional optional_out); -SCATTER_API torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional optional_out); +SCATTER_API torch::Tensor +gather_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out); diff --git a/csrc/segment_coo.cpp b/csrc/segment_coo.cpp index f17ae48d..6599ab0c 100644 --- a/csrc/segment_coo.cpp +++ b/csrc/segment_coo.cpp @@ -2,8 +2,8 @@ #include #include "cpu/segment_coo_cpu.h" -#include "utils.h" #include "macros.h" +#include "utils.h" #ifdef WITH_CUDA #include "cuda/segment_coo_cuda.h" @@ -196,15 +196,17 @@ class GatherCOO : public torch::autograd::Function { } }; -SCATTER_API torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index, - torch::optional optional_out, - torch::optional dim_size) { +SCATTER_API torch::Tensor +segment_sum_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size) { return SegmentSumCOO::apply(src, index, optional_out, dim_size)[0]; } -SCATTER_API torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index, - torch::optional optional_out, - torch::optional dim_size) { +SCATTER_API torch::Tensor +segment_mean_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size) { return SegmentMeanCOO::apply(src, index, optional_out, dim_size)[0]; } @@ -224,8 +226,9 @@ segment_max_coo(torch::Tensor src, torch::Tensor index, return std::make_tuple(result[0], result[1]); } -SCATTER_API torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, - torch::optional optional_out) { +SCATTER_API torch::Tensor +gather_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out) { return GatherCOO::apply(src, index, optional_out)[0]; } diff --git a/csrc/segment_csr.cpp b/csrc/segment_csr.cpp index efe8016a..969dad7e 100644 --- a/csrc/segment_csr.cpp +++ b/csrc/segment_csr.cpp @@ -2,8 +2,8 @@ #include #include "cpu/segment_csr_cpu.h" -#include "utils.h" #include "macros.h" +#include "utils.h" #ifdef WITH_CUDA #include "cuda/segment_csr_cuda.h" @@ -193,13 +193,15 @@ class GatherCSR : public torch::autograd::Function { } }; -SCATTER_API torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional optional_out) { +SCATTER_API torch::Tensor +segment_sum_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { return SegmentSumCSR::apply(src, indptr, optional_out)[0]; } -SCATTER_API torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional optional_out) { +SCATTER_API torch::Tensor +segment_mean_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { return SegmentMeanCSR::apply(src, indptr, optional_out)[0]; } @@ -217,8 +219,9 @@ segment_max_csr(torch::Tensor src, torch::Tensor indptr, return std::make_tuple(result[0], result[1]); } -SCATTER_API torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional optional_out) { +SCATTER_API torch::Tensor +gather_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { return GatherCSR::apply(src, indptr, optional_out)[0]; }