Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions csrc/macros.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#ifdef _WIN32
#if defined(torchscatter_EXPORTS)
#define SCATTER_API __declspec(dllexport)
#else
#define SCATTER_API __declspec(dllimport)
#endif
#else
#define SCATTER_API
#endif

#if (defined __cpp_inline_variables) || __cplusplus >= 201703L
#define SCATTER_INLINE_VARIABLE inline
#else
#ifdef _MSC_VER
#define SCATTER_INLINE_VARIABLE __declspec(selectany)
#else
#define SCATTER_INLINE_VARIABLE __attribute__((weak))
#endif
#endif
19 changes: 11 additions & 8 deletions csrc/scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <torch/script.h>

#include "cpu/scatter_cpu.h"
#include "macros.h"
#include "utils.h"

#ifdef WITH_CUDA
Expand Down Expand Up @@ -226,9 +227,10 @@ class ScatterMax : public torch::autograd::Function<ScatterMax> {
}
};

torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
SCATTER_API torch::Tensor
scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
}

Expand All @@ -238,21 +240,22 @@ torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
}

torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
SCATTER_API torch::Tensor
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
}

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
Expand Down
72 changes: 36 additions & 36 deletions csrc/scatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,76 +2,76 @@

#include <torch/extension.h>

#if (defined __cpp_inline_variables) || __cplusplus >= 201703L
#define SCATTER_INLINE_VARIABLE inline
#else
#ifdef _MSC_VER
#define SCATTER_INLINE_VARIABLE __declspec(selectany)
#else
#define SCATTER_INLINE_VARIABLE __attribute__((weak))
#endif
#endif
#include "macros.h"

namespace scatter {
int64_t cuda_version() noexcept;
SCATTER_API int64_t cuda_version() noexcept;

namespace detail {
SCATTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
} // namespace detail
} // namespace scatter

torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
SCATTER_API torch::Tensor
scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);

torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
SCATTER_API torch::Tensor
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);

torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
SCATTER_API torch::Tensor
segment_sum_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);

torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
SCATTER_API torch::Tensor
segment_mean_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);

torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out);
SCATTER_API torch::Tensor
gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out);

torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
SCATTER_API torch::Tensor
segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);

torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
SCATTER_API torch::Tensor
segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);

torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
SCATTER_API torch::Tensor
gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
24 changes: 14 additions & 10 deletions csrc/segment_coo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <torch/script.h>

#include "cpu/segment_coo_cpu.h"
#include "macros.h"
#include "utils.h"

#ifdef WITH_CUDA
Expand Down Expand Up @@ -195,36 +196,39 @@ class GatherCOO : public torch::autograd::Function<GatherCOO> {
}
};

torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
SCATTER_API torch::Tensor
segment_sum_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return SegmentSumCOO::apply(src, index, optional_out, dim_size)[0];
}

torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
SCATTER_API torch::Tensor
segment_mean_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return SegmentMeanCOO::apply(src, index, optional_out, dim_size)[0];
}

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = SegmentMinCOO::apply(src, index, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = SegmentMaxCOO::apply(src, index, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}

torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
SCATTER_API torch::Tensor
gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
return GatherCOO::apply(src, index, optional_out)[0];
}

Expand Down
20 changes: 12 additions & 8 deletions csrc/segment_csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <torch/script.h>

#include "cpu/segment_csr_cpu.h"
#include "macros.h"
#include "utils.h"

#ifdef WITH_CUDA
Expand Down Expand Up @@ -192,32 +193,35 @@ class GatherCSR : public torch::autograd::Function<GatherCSR> {
}
};

torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
SCATTER_API torch::Tensor
segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentSumCSR::apply(src, indptr, optional_out)[0];
}

torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
SCATTER_API torch::Tensor
segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentMeanCSR::apply(src, indptr, optional_out)[0];
}

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMinCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}

std::tuple<torch::Tensor, torch::Tensor>
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMaxCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}

torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
SCATTER_API torch::Tensor
gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return GatherCSR::apply(src, indptr, optional_out)[0];
}

Expand Down
4 changes: 3 additions & 1 deletion csrc/version.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <Python.h>
#include <torch/script.h>
#include "scatter.h"
#include "macros.h"

#ifdef WITH_CUDA
#include <cuda.h>
Expand All @@ -14,7 +16,7 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
#endif

namespace scatter {
int64_t cuda_version() noexcept {
SCATTER_API int64_t cuda_version() noexcept {
#ifdef WITH_CUDA
return CUDA_VERSION;
#else
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def get_extensions():

for main, suffix in product(main_files, suffices):
define_macros = []

if sys.platform == 'win32':
define_macros += [('torchscatter_EXPORTS', None)]

extra_compile_args = {'cxx': ['-O2']}
if not os.name == 'nt': # Not on Windows:
extra_compile_args['cxx'] += ['-Wno-sign-compare']
Expand Down