From 26a9e988ab78d075097fc4645883696d4d41ca9c Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 29 Jan 2020 12:01:44 +0100 Subject: [PATCH 01/12] tracebale segment csr --- cpu/gather.cpp | 144 --------------------- cpu/reducer.h | 68 ++++++++++ cpu/segment.cpp | 256 -------------------------------------- cpu/segment_coo.cpp | 7 ++ cpu/segment_coo_impl.h | 182 +++++++++++++++++++++++++++ cpu/segment_csr.cpp | 41 ++++++ cpu/segment_csr_impl.h | 146 ++++++++++++++++++++++ cpu/utils.h | 6 + test/test_jit.py | 31 +++++ torch_scatter/__init__.py | 8 +- 10 files changed, 485 insertions(+), 404 deletions(-) delete mode 100644 cpu/gather.cpp create mode 100644 cpu/reducer.h delete mode 100644 cpu/segment.cpp create mode 100644 cpu/segment_coo.cpp create mode 100644 cpu/segment_coo_impl.h create mode 100644 cpu/segment_csr.cpp create mode 100644 cpu/segment_csr_impl.h create mode 100644 cpu/utils.h create mode 100644 test/test_jit.py diff --git a/cpu/gather.cpp b/cpu/gather.cpp deleted file mode 100644 index a8abdc62..00000000 --- a/cpu/gather.cpp +++ /dev/null @@ -1,144 +0,0 @@ -#include - -#include "compat.h" -#include "index_info.h" - -#include - -#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") - -torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional out_opt) { - CHECK_CPU(src); - CHECK_CPU(indptr); - if (out_opt.has_value()) - CHECK_CPU(out_opt.value()); - - AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch"); - for (int i = 0; i < indptr.dim() - 1; i++) - AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch"); - - src = src.contiguous(); - auto gather_dim = indptr.dim() - 1; - AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1, - "Input mismatch"); - - torch::Tensor out; - if (out_opt.has_value()) { - out = out_opt.value().contiguous(); - for (int i = 0; i < out.dim(); i++) - if (i != gather_dim) - AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch"); - } else { - auto sizes = src.sizes().vec(); - sizes[gather_dim] = *indptr.flatten()[-1].DATA_PTR(); - out = torch::empty(sizes, src.options()); - } - - auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1)); - auto K = src.numel() / N; - auto E = out.size(gather_dim); - - auto indptr_info = getTensorInfo(indptr); - auto stride = indptr_info.strides[indptr_info.dims - 1]; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - std::vector vals(K); - int64_t row_start, row_end; - for (int n = 0; n < N; n++) { - int offset = IndexPtrToOffset::get(n, indptr_info); - row_start = indptr_info.data[offset]; - row_end = indptr_info.data[offset + stride]; - - for (int k = 0; k < K; k++) { - vals[k] = src_data[n * K + k]; - } - - offset = (n / (indptr.size(-1) - 1)) * E * K; - for (int64_t e = row_start; e < row_end; e++) { - for (int k = 0; k < K; k++) { - out_data[offset + e * K + k] = vals[k]; - } - } - } - }); - - return out; -} - -torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, - torch::optional out_opt) { - CHECK_CPU(src); - CHECK_CPU(index); - if (out_opt.has_value()) - CHECK_CPU(out_opt.value()); - - AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch"); - for (int i = 0; i < index.dim() - 1; i++) - AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch"); - - src = src.contiguous(); - auto gather_dim = index.dim() - 1; - - torch::Tensor out; - if (out_opt.has_value()) { - out = out_opt.value().contiguous(); - for (int i = 0; i < index.dim(); i++) - 
AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch"); - for (int i = index.dim() + 1; i < src.dim(); i++) - AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch"); - } else { - auto sizes = src.sizes().vec(); - sizes[gather_dim] = index.size(gather_dim); - out = torch::empty(sizes, src.options()); - } - - auto E_1 = index.numel() / out.size(gather_dim); - auto E_2 = index.size(gather_dim); - auto K = out.numel() / index.numel(); - auto N = src.size(gather_dim); - - auto index_info = getTensorInfo(index); - auto stride = index_info.strides[index_info.dims - 1]; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - std::vector vals(K); - int64_t idx, next_idx; - for (int e_1 = 0; e_1 < E_1; e_1++) { - int offset = IndexToOffset::get(e_1 * E_2, index_info); - idx = index_info.data[offset]; - - for (int k = 0; k < K; k++) { - vals[k] = src_data[e_1 * N * K + idx * K + k]; - } - - for (int e_2 = 0; e_2 < E_2; e_2++) { - for (int k = 0; k < K; k++) { - out_data[e_1 * E_2 * K + e_2 * K + k] = vals[k]; - } - - if (e_2 < E_2 - 1) { - next_idx = index_info.data[offset + (e_2 + 1) * stride]; - assert(idx <= next_idx); - - if (idx != next_idx) { - idx = next_idx; - for (int k = 0; k < K; k++) { - vals[k] = src_data[e_1 * N * K + idx * K + k]; - } - } - } - } - } - }); - - return out; -} - -static auto registry = - torch::RegisterOperators("torch_scatter_cpu::gather_csr", &gather_csr) - .op("torch_scatter_cpu::gather_coo", &gather_coo); diff --git a/cpu/reducer.h b/cpu/reducer.h new file mode 100644 index 00000000..91d730ee --- /dev/null +++ b/cpu/reducer.h @@ -0,0 +1,68 @@ +#pragma once + +#include + +enum ReductionType { SUM, MEAN, MIN, MAX }; + +const std::map reduce2REDUCE = { + {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX}, +}; + +#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ + [&] { \ + switch (reduce2REDUCE.at(reduce)) { \ + case SUM: { \ + const ReductionType REDUCE = SUM; \ + return __VA_ARGS__(); \ + } \ + case MEAN: { \ + const ReductionType REDUCE = MEAN; \ + return __VA_ARGS__(); \ + } \ + case MIN: { \ + const ReductionType REDUCE = MIN; \ + return __VA_ARGS__(); \ + } \ + case MAX: { \ + const ReductionType REDUCE = MAX; \ + return __VA_ARGS__(); \ + } \ + } \ + }() + +template struct Reducer { + static inline scalar_t init() { + if (REDUCE == MIN) + return std::numeric_limits::max(); + else if (REDUCE == MAX) + return std::numeric_limits::lowest(); + else + return (scalar_t)0; + } + + static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg, + int64_t new_arg) { + if (REDUCE == SUM || REDUCE == MEAN) + *val = *val + new_val; + else if ((REDUCE == MIN && new_val < *val) || + (REDUCE == MAX && new_val > *val)) { + *val = new_val; + *arg = new_arg; + } + } + + static inline void write(scalar_t *address, scalar_t val, + int64_t *arg_address, int64_t arg, int count) { + if (REDUCE == SUM) + *address = val; + else if (REDUCE == MEAN) + *address = val / (count > 0 ? 
count : (scalar_t)1); + else if (REDUCE == MIN || REDUCE == MAX) { + if (count > 0) { + *address = val; + *arg_address = arg; + } else + *address = (scalar_t)0; + } + } +}; diff --git a/cpu/segment.cpp b/cpu/segment.cpp deleted file mode 100644 index f67bef07..00000000 --- a/cpu/segment.cpp +++ /dev/null @@ -1,256 +0,0 @@ -#include - -#include "compat.h" -#include "index_info.h" - -#include - -#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") - -enum ReductionType { SUM, MEAN, MIN, MAX }; - -const std::map reduce2REDUCE = { - {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX}, -}; - -#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ - [&] { \ - switch (reduce2REDUCE.at(reduce)) { \ - case SUM: { \ - const ReductionType REDUCE = SUM; \ - return __VA_ARGS__(); \ - } \ - case MEAN: { \ - const ReductionType REDUCE = MEAN; \ - return __VA_ARGS__(); \ - } \ - case MIN: { \ - const ReductionType REDUCE = MIN; \ - return __VA_ARGS__(); \ - } \ - case MAX: { \ - const ReductionType REDUCE = MAX; \ - return __VA_ARGS__(); \ - } \ - } \ - }() - -template struct Reducer { - static inline scalar_t init() { - if (REDUCE == MIN) { - return std::numeric_limits::max(); - } else if (REDUCE == MAX) { - return std::numeric_limits::lowest(); - } else { - return (scalar_t)0; - } - } - - static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg, - int64_t new_arg) { - if (REDUCE == SUM || REDUCE == MEAN) { - *val = *val + new_val; - } else if ((REDUCE == MIN && new_val < *val) || - (REDUCE == MAX && new_val > *val)) { - *val = new_val; - *arg = new_arg; - } - } - - static inline void write(scalar_t *address, scalar_t val, - int64_t *arg_address, int64_t arg, int count) { - if (REDUCE == SUM) { - *address = val; - } else if (REDUCE == MEAN) { - *address = val / (count > 0 ? count : (scalar_t)1); - } else if (REDUCE == MIN || REDUCE == MAX) { - if (count > 0) { - *address = val; - *arg_address = arg; - } else { - *address = (scalar_t)0; - } - } - } -}; - -std::tuple> -segment_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional out_opt, std::string reduce) { - CHECK_CPU(src); - CHECK_CPU(indptr); - if (out_opt.has_value()) - CHECK_CPU(out_opt.value()); - - AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch"); - - // Broadcasting `indptr` via `expand`. 
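For orientation, a minimal pure-PyTorch sketch of the segment-CSR reduction that `segment_csr` implements (an illustrative reference, not the extension code; the helper name and the restriction to a 1-D `indptr` without leading batch dimensions are assumptions):

    import torch

    def segment_csr_reference(src, indptr, reduce="sum"):
        # Reduce `src` along dim 0 over the half-open segments
        # [indptr[i], indptr[i + 1]); empty segments yield 0.
        outs = []
        for i in range(indptr.numel() - 1):
            seg = src[indptr[i]:indptr[i + 1]]
            if seg.size(0) == 0:
                outs.append(torch.zeros_like(src[0]))
            elif reduce in ("sum", "add"):
                outs.append(seg.sum(dim=0))
            elif reduce == "mean":
                outs.append(seg.mean(dim=0))
            elif reduce == "min":
                outs.append(seg.min(dim=0).values)
            else:  # "max"
                outs.append(seg.max(dim=0).values)
        return torch.stack(outs, dim=0)

    src = torch.randn(8, 4)
    indptr = torch.tensor([0, 2, 4, 6, 8])
    out = segment_csr_reference(src, indptr, "max")  # shape [4, 4]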
- auto sizes = indptr.sizes().vec(); - for (int i = 0; i < indptr.dim() - 1; i++) { - sizes[i] = src.size(i); - } - indptr = indptr.expand(sizes); - - src = src.contiguous(); - auto reduce_dim = indptr.dim() - 1; - - torch::Tensor out; - if (out_opt.has_value()) { - out = out_opt.value().contiguous(); - for (int i = 0; i < out.dim(); i++) - if (i != reduce_dim) - AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch"); - AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1, - "Input mismatch"); - } else { - sizes = src.sizes().vec(); - sizes[reduce_dim] = indptr.size(reduce_dim) - 1; - out = torch::empty(sizes, src.options()); - } - - torch::optional arg_out = torch::nullopt; - int64_t *arg_out_data = nullptr; - if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { - arg_out = torch::full_like(out, src.size(reduce_dim), indptr.options()); - arg_out_data = arg_out.value().DATA_PTR(); - } - - auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1)); - auto K = out.numel() / N; - auto E = src.size(reduce_dim); - - auto indptr_info = getTensorInfo(indptr); - auto stride = indptr_info.strides[indptr_info.dims - 1]; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - std::vector vals(K); - int64_t row_start, row_end; - std::vector args(K); - AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { - for (int n = 0; n < N; n++) { - int offset = IndexPtrToOffset::get(n, indptr_info); - row_start = indptr_info.data[offset]; - row_end = indptr_info.data[offset + stride]; - - offset = (n / (indptr.size(-1) - 1)) * E * K; - for (int k = 0; k < K; k++) { - vals[k] = Reducer::init(); - } - for (int64_t e = row_start; e < row_end; e++) { - for (int k = 0; k < K; k++) { - Reducer::update( - &vals[k], src_data[offset + e * K + k], &args[k], e); - } - } - for (int k = 0; k < K; k++) { - Reducer::write(out_data + n * K + k, vals[k], - arg_out_data + n * K + k, args[k], - row_end - row_start); - } - } - }); - }); - - return std::make_tuple(out, arg_out); -} - -std::tuple> -segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out, - std::string reduce) { - CHECK_CPU(src); - CHECK_CPU(index); - CHECK_CPU(out); - - AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch"); - - // Broadcasting `index` via `expand`. 
- auto sizes = index.sizes().vec(); - for (int i = 0; i < index.dim(); i++) { - sizes[i] = src.size(i); - } - index = index.expand(sizes); - - src = src.contiguous(); - out = out.contiguous(); - auto reduce_dim = index.dim() - 1; - - for (int i = 0; i < out.dim(); i++) - if (i != reduce_dim) - AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch"); - - torch::optional arg_out = torch::nullopt; - int64_t *arg_out_data = nullptr; - if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { - arg_out = torch::full_like(out, src.size(reduce_dim), index.options()); - arg_out_data = arg_out.value().DATA_PTR(); - } - - auto E_1 = index.numel() / src.size(reduce_dim); - auto E_2 = src.size(reduce_dim); - auto K = src.numel() / index.numel(); - auto N = out.size(reduce_dim); - - auto index_info = getTensorInfo(index); - auto stride = index_info.strides[index_info.dims - 1]; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - std::vector vals(K); - int64_t idx, next_idx, row_start; - std::vector args(K); - AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { - for (int e_1 = 0; e_1 < E_1; e_1++) { - int offset = IndexToOffset::get(e_1 * E_2, index_info); - idx = index_info.data[offset]; - - for (int k = 0; k < K; k++) { - vals[k] = out_data[e_1 * N * K + k]; - } - - row_start = 0; - for (int e_2 = 0; e_2 < E_2; e_2++) { - - for (int k = 0; k < K; k++) { - Reducer::update( - &vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2); - } - - if (e_2 == E_2 - 1) { - for (int k = 0; k < K; k++) { - Reducer::write( - out_data + e_1 * N * K + idx * K + k, vals[k], - arg_out_data + e_1 * N * K + idx * K + k, args[k], - e_2 + 1 - row_start); - } - } else { - next_idx = index_info.data[offset + (e_2 + 1) * stride]; - assert(idx <= next_idx); - - if (idx != next_idx) { - for (int k = 0; k < K; k++) { - Reducer::write( - out_data + e_1 * N * K + idx * K + k, vals[k], - arg_out_data + e_1 * N * K + idx * K + k, args[k], - e_2 + 1 - row_start); - - vals[k] = out_data[e_1 * N * K + next_idx * K + k]; - } - row_start = e_2 + 1; - } - - idx = next_idx; - } - } - } - }); - }); - - return std::make_tuple(out, arg_out); -} - -static auto registry = - torch::RegisterOperators("torch_scatter_cpu::segment_csr", &segment_csr) - .op("torch_scatter_cpu::segment_coo", &segment_coo); diff --git a/cpu/segment_coo.cpp b/cpu/segment_coo.cpp new file mode 100644 index 00000000..9d10938f --- /dev/null +++ b/cpu/segment_coo.cpp @@ -0,0 +1,7 @@ +#include + +#include "segment_coo_impl.h" + +static auto registry = + torch::RegisterOperators("torch_scatter_cpu::segment_coo", &segment_coo) + .op("torch_scatter_cpu::gather_coo", &gather_coo); diff --git a/cpu/segment_coo_impl.h b/cpu/segment_coo_impl.h new file mode 100644 index 00000000..147134b3 --- /dev/null +++ b/cpu/segment_coo_impl.h @@ -0,0 +1,182 @@ +#pragma once + +#include + +#include "compat.h" +#include "index_info.h" +#include "reducer.h" +#include "utils.h" + +std::tuple> +segment_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, std::string reduce) { + CHECK_CPU(src); + CHECK_CPU(index); + if (optional_out.has_value()) + CHECK_CPU(optional_out.value()); + + CHECK_INPUT(src.dim() >= index.dim()); + + // Broadcasting `index` via `expand`. 
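As a reference for the loop below, a pure-PyTorch sketch of the COO segment sum (illustrative only; the helper name and the 1-D, ascending-sorted `index` are assumptions, matching the `idx <= next_idx` assertion in the implementation):

    import torch

    def segment_coo_sum_reference(src, index, dim_size):
        # `index[e]` names the output row that src[e] is reduced into;
        # rows that never occur in `index` keep the initial value (0 for sum).
        out = torch.zeros((dim_size,) + src.shape[1:], dtype=src.dtype)
        for e in range(index.numel()):
            out[index[e]] += src[e]
        return out

    src = torch.randn(6, 3)
    index = torch.tensor([0, 0, 1, 1, 1, 3])
    out = segment_coo_sum_reference(src, index, dim_size=4)  # row 2 stays zero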
+ auto sizes = index.sizes().vec(); + for (int i = 0; i < index.dim(); i++) + sizes[i] = src.size(i); + index = index.expand(sizes); + + auto dim = index.dim() - 1; + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (int i = 0; i < out.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + } else { + sizes = src.sizes().vec(); + sizes[dim] = *index.max().DATA_PTR(); + out = torch::empty(sizes, src.options()); + } + + torch::optional arg_out = torch::nullopt; + int64_t *arg_out_data = nullptr; + if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { + arg_out = torch::full_like(out, src.size(dim), index.options()); + arg_out_data = arg_out.value().DATA_PTR(); + } + + auto B = index.numel() / src.size(dim); + auto E = src.size(dim); + auto K = src.numel() / index.numel(); + auto N = out.size(dim); + + auto index_info = getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + std::vector args(K); + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] { + auto src_data = src.DATA_PTR(); + auto out_data = out.DATA_PTR(); + + std::vector vals(K); + int64_t idx, next_idx, row_start; + AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { + if (!optional_out.has_value()) + out.fill_(Reducer::init()); + + for (auto b = 0; b < B; b++) { + auto offset = IndexToOffset::get(b * E, index_info); + idx = index_info.data[offset]; + + for (auto k = 0; k < K; k++) + vals[k] = out_data[b * N * K + k]; + + row_start = 0; + for (auto e = 0; e < E; e++) { + + for (auto k = 0; k < K; k++) + Reducer::update( + &vals[k], src_data[b * E * K + e * K + k], &args[k], e); + + if (e == E - 1) { + for (auto k = 0; k < K; k++) + Reducer::write( + out_data + b * N * K + idx * K + k, vals[k], + arg_out_data + b * N * K + idx * K + k, args[k], + e + 1 - row_start); + } else { + next_idx = index_info.data[offset + (e + 1) * stride]; + assert(idx <= next_idx); + + if (idx != next_idx) { + for (auto k = 0; k < K; k++) { + Reducer::write( + out_data + b * N * K + idx * K + k, vals[k], + arg_out_data + b * N * K + idx * K + k, args[k], + e + 1 - row_start); + + vals[k] = out_data[b * N * K + next_idx * K + k]; + } + row_start = e + 1; + } + + idx = next_idx; + } + } + } + if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) { + out.masked_fill_(out == Reducer::init(), (scalar_t)0); + } + }); + }); + + return std::make_tuple(out, arg_out); +} + +torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out) { + CHECK_CPU(src); + CHECK_CPU(index); + if (optional_out.has_value()) + CHECK_CPU(optional_out.value()); + + CHECK_INPUT(src.dim() >= index.dim()); + for (auto i = 0; i < index.dim() - 1; i++) + CHECK_INPUT(src.size(i) == index.size(i)); + + auto dim = index.dim() - 1; + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (auto i = 0; i < src.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + } else { + auto sizes = src.sizes().vec(); + sizes[dim] = index.size(dim); + out = torch::empty(sizes, src.options()); + } + + auto B = index.numel() / out.size(dim); + auto E = index.size(dim); + auto K = out.numel() / index.numel(); + auto N = src.size(dim); + + auto index_info = getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] { + auto src_data = src.DATA_PTR(); 
+ auto out_data = out.DATA_PTR(); + + std::vector vals(K); + int64_t idx, next_idx; + for (auto b = 0; b < B; b++) { + auto offset = IndexToOffset::get(b * E, index_info); + idx = index_info.data[offset]; + + for (auto k = 0; k < K; k++) + vals[k] = src_data[b * N * K + idx * K + k]; + + for (auto e = 0; e < E; e++) { + for (auto k = 0; k < K; k++) + out_data[b * E * K + e * K + k] = vals[k]; + + if (e < E - 1) { + next_idx = index_info.data[offset + (e + 1) * stride]; + CHECK_INPUT(idx < E && idx <= next_idx); + + if (idx != next_idx) { + idx = next_idx; + for (auto k = 0; k < K; k++) + vals[k] = src_data[b * N * K + idx * K + k]; + } + } + } + } + }); + + return out; +} diff --git a/cpu/segment_csr.cpp b/cpu/segment_csr.cpp new file mode 100644 index 00000000..c39dd829 --- /dev/null +++ b/cpu/segment_csr.cpp @@ -0,0 +1,41 @@ +#include + +#include "segment_csr_impl.h" + +using torch::autograd::AutogradContext; +using torch::autograd::Variable; +using torch::autograd::variable_list; + +class SegmentSumCSR : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable indptr, + torch::optional optional_out) { + ctx->saved_data["src_shape"] = src.sizes(); + auto result = segment_csr(src, indptr, optional_out, "sum"); + auto out = std::get<0>(result); + ctx->save_for_backward({indptr}); + return {out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto indptr = saved[0]; + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + auto grad_in = torch::empty(src_shape, grad_out.options()); + gather_csr(grad_out, indptr, grad_in); + + return {grad_in, Variable(), Variable()}; + } +}; + +torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { + return SegmentSumCSR::apply(src, indptr, optional_out)[0]; +} + +static auto registry = + torch::RegisterOperators("torch_scatter_cpu::segment_csr", &segment_csr) + .op("torch_scatter_cpu::gather_csr", &gather_csr) + .op("torch_scatter_cpu::segment_sum_csr", &segment_sum_csr); diff --git a/cpu/segment_csr_impl.h b/cpu/segment_csr_impl.h new file mode 100644 index 00000000..8823c654 --- /dev/null +++ b/cpu/segment_csr_impl.h @@ -0,0 +1,146 @@ +#pragma once + +#include + +#include "compat.h" +#include "index_info.h" +#include "reducer.h" +#include "utils.h" + +std::tuple> +segment_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out, std::string reduce) { + CHECK_CPU(src); + CHECK_CPU(indptr); + if (optional_out.has_value()) + CHECK_CPU(optional_out.value()); + + CHECK_INPUT(src.dim() >= indptr.dim()); + + auto sizes = indptr.sizes().vec(); + for (auto i = 0; i < indptr.dim() - 1; i++) + sizes[i] = src.size(i); + indptr = indptr.expand(sizes); + + auto dim = indptr.dim() - 1; + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (int i = 0; i < out.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1); + } else { + sizes = src.sizes().vec(); + sizes[dim] = indptr.size(dim) - 1; + out = torch::empty(sizes, src.options()); + } + + torch::optional arg_out = torch::nullopt; + int64_t *arg_out_data = nullptr; + if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { + arg_out = torch::full(out.sizes(), src.size(dim), indptr.options()); + 
arg_out_data = arg_out.value().DATA_PTR(); + } + + auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); + auto K = out.numel() / N; + auto E = src.size(dim); + + auto indptr_info = getTensorInfo(indptr); + auto stride = indptr_info.strides[indptr_info.dims - 1]; + std::vector args(K); + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] { + auto src_data = src.DATA_PTR(); + auto out_data = out.DATA_PTR(); + + std::vector vals(K); + int64_t row_start, row_end; + AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { + for (auto n = 0; n < N; n++) { + auto offset = IndexPtrToOffset::get(n, indptr_info); + row_start = indptr_info.data[offset]; + row_end = indptr_info.data[offset + stride]; + + offset = (n / (indptr.size(-1) - 1)) * E * K; + for (auto k = 0; k < K; k++) + vals[k] = Reducer::init(); + + for (auto e = row_start; e < row_end; e++) { + CHECK_INPUT(e < E); + for (auto k = 0; k < K; k++) + Reducer::update( + &vals[k], src_data[offset + e * K + k], &args[k], e); + } + + for (auto k = 0; k < K; k++) + Reducer::write(out_data + n * K + k, vals[k], + arg_out_data + n * K + k, args[k], + row_end - row_start); + } + }); + }); + + return std::make_tuple(out, arg_out); +} + +torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { + CHECK_CPU(src); + CHECK_CPU(indptr); + if (optional_out.has_value()) + CHECK_CPU(optional_out.value()); + + CHECK_INPUT(src.dim() >= indptr.dim()); + for (auto i = 0; i < indptr.dim() - 1; i++) + CHECK_INPUT(src.size(i) == indptr.size(i)); + + auto dim = indptr.dim() - 1; + CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1); + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (auto i = 0; i < out.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + } else { + auto sizes = src.sizes().vec(); + sizes[dim] = *indptr.flatten()[-1].DATA_PTR(); + out = torch::empty(sizes, src.options()); + } + + auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); + auto K = src.numel() / N; + auto E = out.size(dim); + + auto indptr_info = getTensorInfo(indptr); + auto stride = indptr_info.strides[indptr_info.dims - 1]; + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] { + auto src_data = src.DATA_PTR(); + auto out_data = out.DATA_PTR(); + + std::vector vals(K); + int64_t row_start, row_end; + for (int n = 0; n < N; n++) { + auto offset = IndexPtrToOffset::get(n, indptr_info); + row_start = indptr_info.data[offset]; + row_end = indptr_info.data[offset + stride]; + + for (auto k = 0; k < K; k++) + vals[k] = src_data[n * K + k]; + + offset = (n / (indptr.size(-1) - 1)) * E * K; + for (auto e = row_start; e < row_end; e++) + for (auto k = 0; k < K; k++) + out_data[offset + e * K + k] = vals[k]; + } + }); + + return out; +} diff --git a/cpu/utils.h b/cpu/utils.h new file mode 100644 index 00000000..40dfb344 --- /dev/null +++ b/cpu/utils.h @@ -0,0 +1,6 @@ +#pragma once + +#include + +#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") +#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") diff --git a/test/test_jit.py b/test/test_jit.py new file mode 100644 index 00000000..74f9d815 --- /dev/null +++ b/test/test_jit.py @@ -0,0 +1,31 @@ +from typing import Optional + +import torch +import torch_scatter + + +@torch.jit.script +def segment_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None, reduce: str = "sum"): + return torch.ops.torch_scatter_cpu.segment_sum_csr(src, 
indptr, out) + + +def test_jit(): + # op = torch.ops.torch_scatter_cpu.segment_sum_csr + + src = torch.randn(8, 4) + src.requires_grad_() + indptr = torch.tensor([0, 2, 4, 6, 8]) + + out = segment_csr(src, indptr) + print(out) + + print(src.grad) + out.backward(torch.randn_like(out)) + print(src.grad) + + # op = torch.ops.torch_scatter_cpu.segment_csr + # out = op(src, indptr, None, "sum") + # print(out) + + # traced_cell = torch.jit.script(op) diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index 7a025d63..74f0648d 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -16,13 +16,13 @@ import torch_scatter.composite torch.ops.load_library('torch_scatter/scatter_cpu.so') -torch.ops.load_library('torch_scatter/segment_cpu.so') -torch.ops.load_library('torch_scatter/gather_cpu.so') +torch.ops.load_library('torch_scatter/segment_csr_cpu.so') +torch.ops.load_library('torch_scatter/segment_coo_cpu.so') try: torch.ops.load_library('torch_scatter/scatter_cuda.so') - torch.ops.load_library('torch_scatter/segment_cuda.so') - torch.ops.load_library('torch_scatter/gather_cuda.so') + # torch.ops.load_library('torch_scatter/segment_csr_cuda.so') + # torch.ops.load_library('torch_scatter/segment_coo_cuda.so') except OSError as e: if torch.cuda.is_available(): raise e From 0c887ffc590da3c7ff2b170943cc76d69d8fedfe Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 29 Jan 2020 14:22:12 +0100 Subject: [PATCH 02/12] segment/gather csr done --- LICENSE | 2 +- csrc/cpu/index_info.h | 63 ++++++++ csrc/cpu/reducer.h | 84 +++++++++++ csrc/cpu/segment_csr_cpu.cpp | 147 ++++++++++++++++++ csrc/cpu/segment_csr_cpu.h | 11 ++ csrc/cpu/utils.h | 6 + csrc/cuda/atomics.cuh | 230 +++++++++++++++++++++++++++++ csrc/cuda/index_info.cuh | 19 +++ csrc/cuda/reducer.cuh | 114 ++++++++++++++ csrc/cuda/segment_csr_cuda.cu | 270 ++++++++++++++++++++++++++++++++++ csrc/cuda/segment_csr_cuda.h | 11 ++ csrc/cuda/utils.cuh | 7 + csrc/segment_csr.cpp | 218 +++++++++++++++++++++++++++ setup.py | 100 ++++++------- test/test_segment.py | 54 +++---- torch_scatter/__init__.py | 50 ++----- torch_scatter/segment_csr.py | 59 ++++++++ 17 files changed, 1327 insertions(+), 118 deletions(-) create mode 100644 csrc/cpu/index_info.h create mode 100644 csrc/cpu/reducer.h create mode 100644 csrc/cpu/segment_csr_cpu.cpp create mode 100644 csrc/cpu/segment_csr_cpu.h create mode 100644 csrc/cpu/utils.h create mode 100644 csrc/cuda/atomics.cuh create mode 100644 csrc/cuda/index_info.cuh create mode 100644 csrc/cuda/reducer.cuh create mode 100644 csrc/cuda/segment_csr_cuda.cu create mode 100644 csrc/cuda/segment_csr_cuda.h create mode 100644 csrc/cuda/utils.cuh create mode 100644 csrc/segment_csr.cpp create mode 100644 torch_scatter/segment_csr.py diff --git a/LICENSE b/LICENSE index c4318cc1..9ca2096e 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2019 Matthias Fey +Copyright (c) 2020 Matthias Fey Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/csrc/cpu/index_info.h b/csrc/cpu/index_info.h new file mode 100644 index 00000000..9709a1de --- /dev/null +++ b/csrc/cpu/index_info.h @@ -0,0 +1,63 @@ +#pragma once + +#include + +#define MAX_TENSORINFO_DIMS 25 + +template struct TensorInfo { + TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS], + int st[MAX_TENSORINFO_DIMS]) { + data = p; + dims = dim; + AT_ASSERT(dims < MAX_TENSORINFO_DIMS); + + for (int i = 0; i < dim; ++i) { + 
sizes[i] = sz[i]; + strides[i] = st[i]; + } + } + + scalar_t *data; + int dims; + int sizes[MAX_TENSORINFO_DIMS]; + int strides[MAX_TENSORINFO_DIMS]; +}; + +template +TensorInfo getTensorInfo(const torch::Tensor &tensor) { + int sizes[MAX_TENSORINFO_DIMS]; + int strides[MAX_TENSORINFO_DIMS]; + + int dims = tensor.dim(); + for (int i = 0; i < dims; ++i) { + sizes[i] = tensor.size(i); + strides[i] = tensor.stride(i); + } + + return TensorInfo(tensor.data_ptr(), dims, sizes, + strides); +} + +template struct IndexToOffset { + static inline int get(int idx, const TensorInfo &info) { + int offset = 0; + for (int i = info.dims - 1; i >= 0; --i) { + offset += (idx % info.sizes[i]) * info.strides[i]; + idx /= info.sizes[i]; + } + return offset; + } +}; + +template struct IndexPtrToOffset { + static inline int get(int idx, const TensorInfo &info) { + int offset = idx % (info.sizes[info.dims - 1] - 1); + offset *= info.strides[info.dims - 1]; + idx /= info.sizes[info.dims - 1] - 1; + for (int i = info.dims - 2; i >= 0; --i) { + offset += (idx % info.sizes[i]) * info.strides[i]; + idx /= info.sizes[i]; + } + return offset; + } +}; diff --git a/csrc/cpu/reducer.h b/csrc/cpu/reducer.h new file mode 100644 index 00000000..edb0ec5f --- /dev/null +++ b/csrc/cpu/reducer.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include + +enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX }; + +const std::map reduce2REDUCE = { + {"sum", SUM}, {"mean", MEAN}, {"mul", MUL}, + {"div", DIV}, {"min", MIN}, {"max", MAX}, +}; + +#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ + [&] { \ + switch (reduce2REDUCE.at(reduce)) { \ + case SUM: { \ + const ReductionType REDUCE = SUM; \ + return __VA_ARGS__(); \ + } \ + case MEAN: { \ + const ReductionType REDUCE = MEAN; \ + return __VA_ARGS__(); \ + } \ + case MUL: { \ + const ReductionType REDUCE = MUL; \ + return __VA_ARGS__(); \ + } \ + case DIV: { \ + const ReductionType REDUCE = DIV; \ + return __VA_ARGS__(); \ + } \ + case MIN: { \ + const ReductionType REDUCE = MIN; \ + return __VA_ARGS__(); \ + } \ + case MAX: { \ + const ReductionType REDUCE = MAX; \ + return __VA_ARGS__(); \ + } \ + } \ + }() + +template struct Reducer { + static inline scalar_t init() { + if (REDUCE == MUL || REDUCE == DIV) + return (scalar_t)1; + else if (REDUCE == MIN) + return std::numeric_limits::max(); + else if (REDUCE == MAX) + return std::numeric_limits::lowest(); + else + return (scalar_t)0; + } + + static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg, + int64_t new_arg) { + if (REDUCE == SUM || REDUCE == MEAN) + *val = *val + new_val; + else if (REDUCE == MUL) + *val = *val * new_val; + else if (REDUCE == DIV) + *val = *val / new_val; + else if ((REDUCE == MIN && new_val < *val) || + (REDUCE == MAX && new_val > *val)) { + *val = new_val; + *arg = new_arg; + } + } + + static inline void write(scalar_t *address, scalar_t val, + int64_t *arg_address, int64_t arg, int count) { + if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV) + *address = val; + else if (REDUCE == MEAN) + *address = val / (count > 0 ? 
count : (scalar_t)1); + else if (REDUCE == MIN || REDUCE == MAX) { + if (count > 0) { + *address = val; + *arg_address = arg; + } else + *address = (scalar_t)0; + } + } +}; diff --git a/csrc/cpu/segment_csr_cpu.cpp b/csrc/cpu/segment_csr_cpu.cpp new file mode 100644 index 00000000..00790128 --- /dev/null +++ b/csrc/cpu/segment_csr_cpu.cpp @@ -0,0 +1,147 @@ +#include "segment_csr_cpu.h" + +#include "index_info.h" +#include "reducer.h" +#include "utils.h" + +std::tuple> +segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out, + std::string reduce) { + CHECK_CPU(src); + CHECK_CPU(indptr); + if (optional_out.has_value()) + CHECK_CPU(optional_out.value()); + + CHECK_INPUT(src.dim() >= indptr.dim()); + + auto sizes = indptr.sizes().vec(); + for (auto i = 0; i < indptr.dim() - 1; i++) + sizes[i] = src.size(i); + indptr = indptr.expand(sizes); + + auto dim = indptr.dim() - 1; + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (int i = 0; i < out.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1); + } else { + sizes = src.sizes().vec(); + sizes[dim] = indptr.size(dim) - 1; + out = torch::empty(sizes, src.options()); + } + + torch::optional arg_out = torch::nullopt; + int64_t *arg_out_data = nullptr; + if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { + arg_out = torch::full(out.sizes(), src.size(dim), indptr.options()); + arg_out_data = arg_out.value().data_ptr(); + } + + auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); + auto K = out.numel() / N; + auto E = src.size(dim); + + auto indptr_info = getTensorInfo(indptr); + auto stride = indptr_info.strides[indptr_info.dims - 1]; + std::vector args(K); + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + std::vector vals(K); + int64_t row_start, row_end; + AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { + for (auto n = 0; n < N; n++) { + auto offset = IndexPtrToOffset::get(n, indptr_info); + row_start = indptr_info.data[offset]; + row_end = indptr_info.data[offset + stride]; + + offset = (n / (indptr.size(-1) - 1)) * E * K; + for (auto k = 0; k < K; k++) + vals[k] = Reducer::init(); + + for (auto e = row_start; e < row_end; e++) { + CHECK_INPUT(e < E); + for (auto k = 0; k < K; k++) + Reducer::update( + &vals[k], src_data[offset + e * K + k], &args[k], e); + } + + for (auto k = 0; k < K; k++) + Reducer::write(out_data + n * K + k, vals[k], + arg_out_data + n * K + k, args[k], + row_end - row_start); + } + }); + }); + + return std::make_tuple(out, arg_out); +} + +torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { + CHECK_CPU(src); + CHECK_CPU(indptr); + if (optional_out.has_value()) + CHECK_CPU(optional_out.value()); + + CHECK_INPUT(src.dim() >= indptr.dim()); + + auto sizes = indptr.sizes().vec(); + for (auto i = 0; i < indptr.dim() - 1; i++) + sizes[i] = src.size(i); + indptr = indptr.expand(sizes); + + auto dim = indptr.dim() - 1; + CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1); + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (auto i = 0; i < out.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + } else { + auto sizes = src.sizes().vec(); + sizes[dim] = *indptr.flatten()[-1].data_ptr(); + 
out = torch::empty(sizes, src.options()); + } + + auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); + auto K = src.numel() / N; + auto E = out.size(dim); + + auto indptr_info = getTensorInfo(indptr); + auto stride = indptr_info.strides[indptr_info.dims - 1]; + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + std::vector vals(K); + int64_t row_start, row_end; + for (int n = 0; n < N; n++) { + auto offset = IndexPtrToOffset::get(n, indptr_info); + row_start = indptr_info.data[offset]; + row_end = indptr_info.data[offset + stride]; + + for (auto k = 0; k < K; k++) + vals[k] = src_data[n * K + k]; + + offset = (n / (indptr.size(-1) - 1)) * E * K; + for (auto e = row_start; e < row_end; e++) + for (auto k = 0; k < K; k++) + out_data[offset + e * K + k] = vals[k]; + } + }); + + return out; +} diff --git a/csrc/cpu/segment_csr_cpu.h b/csrc/cpu/segment_csr_cpu.h new file mode 100644 index 00000000..b93d450b --- /dev/null +++ b/csrc/cpu/segment_csr_cpu.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +std::tuple> +segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out, + std::string reduce); + +torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out); diff --git a/csrc/cpu/utils.h b/csrc/cpu/utils.h new file mode 100644 index 00000000..40dfb344 --- /dev/null +++ b/csrc/cpu/utils.h @@ -0,0 +1,6 @@ +#pragma once + +#include + +#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") +#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") diff --git a/csrc/cuda/atomics.cuh b/csrc/cuda/atomics.cuh new file mode 100644 index 00000000..32427eac --- /dev/null +++ b/csrc/cuda/atomics.cuh @@ -0,0 +1,230 @@ +#pragma once + +#define ATOMIC(NAME) \ + template struct Atomic##NAME##IntegerImpl; \ + \ + template struct Atomic##NAME##IntegerImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + uint32_t *address_as_ui = (uint32_t *)(address - ((size_t)address & 3)); \ + uint32_t old = *address_as_ui; \ + uint32_t shift = ((size_t)address & 3) * 8; \ + uint32_t sum; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + sum = OP(val, scalar((old >> shift) & 0xff)); \ + old = (old & ~(0x000000ff << shift)) | (sum << shift); \ + old = atomicCAS(address_as_ui, assumed, old); \ + } while (assumed != old); \ + } \ + }; \ + \ + template struct Atomic##NAME##IntegerImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + uint32_t *address_as_ui = \ + (uint32_t *)((char *)address - ((size_t)address & 2)); \ + uint32_t old = *address_as_ui; \ + uint32_t sum; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + sum = OP(val, (size_t)address & 2 ? scalar(old >> 16) \ + : scalar(old & 0xffff)); \ + newval = (size_t)address & 2 ? 
(old & 0xffff) | (sum << 16) \ + : (old & 0xffff0000) | sum; \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ + }; \ + \ + template struct Atomic##NAME##IntegerImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + uint32_t *address_as_ui = (uint32_t *)address; \ + uint32_t old = *address_as_ui; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old = atomicCAS(address_as_ui, assumed, OP(val, (scalar)old)); \ + } while (assumed != old); \ + } \ + }; \ + \ + template struct Atomic##NAME##IntegerImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + unsigned long long *address_as_ull = (unsigned long long *)address; \ + unsigned long long old = *address_as_ull; \ + unsigned long long assumed; \ + \ + do { \ + assumed = old; \ + old = atomicCAS(address_as_ull, assumed, OP(val, (scalar)old)); \ + } while (assumed != old); \ + } \ + }; \ + \ + template struct Atomic##NAME##DecimalImpl; \ + \ + template struct Atomic##NAME##DecimalImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + int *address_as_i = (int *)address; \ + int old = *address_as_i; \ + int assumed; \ + \ + do { \ + assumed = old; \ + old = atomicCAS(address_as_i, assumed, \ + __float_as_int(OP(val, __int_as_float(assumed)))); \ + } while (assumed != old); \ + } \ + }; \ + \ + template struct Atomic##NAME##DecimalImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + unsigned long long int *address_as_ull = \ + (unsigned long long int *)address; \ + unsigned long long int old = *address_as_ull; \ + unsigned long long int assumed; \ + \ + do { \ + assumed = old; \ + old = atomicCAS( \ + address_as_ull, assumed, \ + __double_as_longlong(OP(val, __longlong_as_double(assumed)))); \ + } while (assumed != old); \ + } \ + }; + +#define OP(X, Y) Y + X +ATOMIC(Add) +#undef OP +static inline __device__ void atomAdd(uint8_t *address, uint8_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(int8_t *address, int8_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(int16_t *address, int16_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(int32_t *address, int32_t val) { + atomicAdd(address, val); +} +static inline __device__ void atomAdd(int64_t *address, int64_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(float *address, float val) { + atomicAdd(address, val); +} +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000) +static inline __device__ void atomAdd(double *address, double val) { + AtomicAddDecimalImpl()(address, val); +} +#else +static inline __device__ void atomAdd(double *address, double val) { + atomicAdd(address, val); +} +#endif + +#define OP(X, Y) Y *X +ATOMIC(Mul) +#undef OP +static inline __device__ void atomMul(uint8_t *address, uint8_t val) { + AtomicMulIntegerImpl()(address, val); +} +static inline __device__ void atomMul(int8_t *address, int8_t val) { + AtomicMulIntegerImpl()(address, val); +} +static inline __device__ void atomMul(int16_t *address, int16_t val) { + AtomicMulIntegerImpl()(address, val); +} +static inline __device__ void atomMul(int32_t *address, int32_t val) { + AtomicMulIntegerImpl()(address, val); +} +static inline __device__ void atomMul(int64_t *address, int64_t val) { + AtomicMulIntegerImpl()(address, val); +} +static inline __device__ void atomMul(float *address, float 
val) { + AtomicMulDecimalImpl()(address, val); +} +static inline __device__ void atomMul(double *address, double val) { + AtomicMulDecimalImpl()(address, val); +} + +#define OP(X, Y) Y / X +ATOMIC(Div) +#undef OP +static inline __device__ void atomDiv(uint8_t *address, uint8_t val) { + AtomicDivIntegerImpl()(address, val); +} +static inline __device__ void atomDiv(int8_t *address, int8_t val) { + AtomicDivIntegerImpl()(address, val); +} +static inline __device__ void atomDiv(int16_t *address, int16_t val) { + AtomicDivIntegerImpl()(address, val); +} +static inline __device__ void atomDiv(int32_t *address, int32_t val) { + AtomicDivIntegerImpl()(address, val); +} +static inline __device__ void atomDiv(int64_t *address, int64_t val) { + AtomicDivIntegerImpl()(address, val); +} +static inline __device__ void atomDiv(float *address, float val) { + AtomicDivDecimalImpl()(address, val); +} +static inline __device__ void atomDiv(double *address, double val) { + AtomicDivDecimalImpl()(address, val); +} + +#define OP(X, Y) max(Y, X) +ATOMIC(Max) +#undef OP +static inline __device__ void atomMax(uint8_t *address, uint8_t val) { + AtomicMaxIntegerImpl()(address, val); +} +static inline __device__ void atomMax(int8_t *address, int8_t val) { + AtomicMaxIntegerImpl()(address, val); +} +static inline __device__ void atomMax(int16_t *address, int16_t val) { + AtomicMaxIntegerImpl()(address, val); +} +static inline __device__ void atomMax(int32_t *address, int32_t val) { + atomicMax(address, val); +} +static inline __device__ void atomMax(int64_t *address, int64_t val) { + AtomicMaxIntegerImpl()(address, val); +} +static inline __device__ void atomMax(float *address, float val) { + AtomicMaxDecimalImpl()(address, val); +} +static inline __device__ void atomMax(double *address, double val) { + AtomicMaxDecimalImpl()(address, val); +} + +#define OP(X, Y) min(Y, X) +ATOMIC(Min) +#undef OP +static inline __device__ void atomMin(uint8_t *address, uint8_t val) { + AtomicMinIntegerImpl()(address, val); +} +static inline __device__ void atomMin(int8_t *address, int8_t val) { + AtomicMinIntegerImpl()(address, val); +} +static inline __device__ void atomMin(int16_t *address, int16_t val) { + AtomicMinIntegerImpl()(address, val); +} +static inline __device__ void atomMin(int32_t *address, int32_t val) { + atomicMin(address, val); +} +static inline __device__ void atomMin(int64_t *address, int64_t val) { + AtomicMinIntegerImpl()(address, val); +} +static inline __device__ void atomMin(float *address, float val) { + AtomicMinDecimalImpl()(address, val); +} +static inline __device__ void atomMin(double *address, double val) { + AtomicMinDecimalImpl()(address, val); +} diff --git a/csrc/cuda/index_info.cuh b/csrc/cuda/index_info.cuh new file mode 100644 index 00000000..a5af0f7e --- /dev/null +++ b/csrc/cuda/index_info.cuh @@ -0,0 +1,19 @@ +#pragma once + +#include + +// We need our own `IndexToOffset` implementation since we do not want to +// access the last element of the `indexptr`. 
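A small Python mirror of the offset computation in `IndexPtrToOffset` below (the function name here is made up for illustration): each `indptr` row holds `sizes[-1]` pointers but only `sizes[-1] - 1` segments, so the modulus along the last dimension is taken over `sizes[-1] - 1` and the final pointer of a row is never addressed.

    def indexptr_to_offset(idx, sizes, strides):
        # `idx` enumerates segments, not pointers, hence the `- 1`.
        offset = (idx % (sizes[-1] - 1)) * strides[-1]
        idx //= sizes[-1] - 1
        for i in range(len(sizes) - 2, -1, -1):
            offset += (idx % sizes[i]) * strides[i]
            idx //= sizes[i]
        return offset

    # indptr of shape [2, 5] (two batches, four segments each), contiguous
    # strides [5, 1]: segment 6 (batch 1, segment 2) starts at flat offset 7.
    assert indexptr_to_offset(6, sizes=[2, 5], strides=[5, 1]) == 7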
+template struct IndexPtrToOffset { + static inline __host__ __device__ int + get(int idx, const at::cuda::detail::TensorInfo &info) { + int offset = idx % (info.sizes[info.dims - 1] - 1); + offset *= info.strides[info.dims - 1]; + idx /= info.sizes[info.dims - 1] - 1; + for (int i = info.dims - 2; i >= 0; --i) { + offset += (idx % info.sizes[i]) * info.strides[i]; + idx /= info.sizes[i]; + } + return offset; + } +}; diff --git a/csrc/cuda/reducer.cuh b/csrc/cuda/reducer.cuh new file mode 100644 index 00000000..1e126e6a --- /dev/null +++ b/csrc/cuda/reducer.cuh @@ -0,0 +1,114 @@ +#pragma once + +#include +#include + +#include "atomics.cuh" + +enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX }; + +const std::map reduce2REDUCE = { + {"sum", SUM}, {"mean", MEAN}, {"mul", MUL}, + {"div", DIV}, {"min", MIN}, {"max", MAX}, +}; + +#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ + [&] { \ + switch (reduce2REDUCE.at(reduce)) { \ + case SUM: { \ + const ReductionType REDUCE = SUM; \ + return __VA_ARGS__(); \ + } \ + case MEAN: { \ + const ReductionType REDUCE = MEAN; \ + return __VA_ARGS__(); \ + } \ + case MUL: { \ + const ReductionType REDUCE = MUL; \ + return __VA_ARGS__(); \ + } \ + case DIV: { \ + const ReductionType REDUCE = DIV; \ + return __VA_ARGS__(); \ + } \ + case MIN: { \ + const ReductionType REDUCE = MIN; \ + return __VA_ARGS__(); \ + } \ + case MAX: { \ + const ReductionType REDUCE = MAX; \ + return __VA_ARGS__(); \ + } \ + } \ + }() + +template struct Reducer { + static inline __host__ __device__ scalar_t init() { + if (REDUCE == MUL || REDUCE == DIV) + return (scalar_t)1; + else if (REDUCE == MIN) + return std::numeric_limits::max(); + else if (REDUCE == MAX) + return std::numeric_limits::lowest(); + else + return (scalar_t)0; + } + + static inline __host__ __device__ void update(scalar_t *val, + scalar_t new_val) { + if (REDUCE == SUM || REDUCE == MEAN) + *val = *val + new_val; + else if (REDUCE == MUL) + *val = *val * new_val; + else if (REDUCE == DIV) + *val = *val / new_val; + else if ((REDUCE == MIN && new_val < *val) || + (REDUCE == MAX && new_val > *val)) { + *val = new_val; + } + } + + static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val, + int64_t *arg, int64_t new_arg) { + if (REDUCE == SUM || REDUCE == MEAN) + *val = *val + new_val; + else if (REDUCE == MUL) + *val = *val * new_val; + else if (REDUCE == DIV) + *val = *val / new_val; + else if ((REDUCE == MIN && new_val < *val) || + (REDUCE == MAX && new_val > *val)) { + *val = new_val; + *arg = new_arg; + } + } + + static inline __host__ __device__ void write(scalar_t *address, scalar_t val, + int64_t *arg_address, + int64_t arg, int count) { + if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV) + *address = val; + else if (REDUCE == MEAN) + *address = val / (count > 0 ? 
count : (scalar_t)1); + else if (REDUCE == MIN || REDUCE == MAX) { + if (count > 0) { + *address = val; + *arg_address = arg; + } else + *address = (scalar_t)0; + } + } + + static inline __device__ void atomic_write(scalar_t *address, scalar_t val) { + if (REDUCE == SUM || REDUCE == MEAN) + atomAdd(address, val); + else if (REDUCE == MUL) + atomMul(address, val); + else if (REDUCE == DIV) + atomDiv(address, val); + else if (REDUCE == MIN && val < *address) + atomMin(address, val); + else if (REDUCE == MAX && val > *address) + atomMax(address, val); + } +}; diff --git a/csrc/cuda/segment_csr_cuda.cu b/csrc/cuda/segment_csr_cuda.cu new file mode 100644 index 00000000..1bf6cf5f --- /dev/null +++ b/csrc/cuda/segment_csr_cuda.cu @@ -0,0 +1,270 @@ +#include "segment_csr_cuda.h" + +#include +#include +#include + +#include "index_info.cuh" +#include "reducer.cuh" +#include "utils.cuh" + +#define THREADS 256 +#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS +#define FULL_MASK 0xffffffff + +template +__global__ void +segment_csr_kernel(const scalar_t *src_data, + const at::cuda::detail::TensorInfo indptr_info, + scalar_t *out_data, int64_t *arg_out_data, size_t N, + size_t E) { + + // Each warp processes exactly `32/TB` rows and aggregates all row values + // via a parallel reduction. + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / TB; + int lane_idx = thread_idx & (TB - 1); + + if (row_idx < N) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + int64_t row_start = __ldg(indptr_info.data + offset); + int64_t row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + + scalar_t val = Reducer::init(); + int64_t arg, arg_tmp; + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; + for (int64_t src_idx = row_start + lane_idx; src_idx < row_end; + src_idx += TB) { + Reducer::update(&val, src_data[offset + src_idx], &arg, + src_idx); + } + +#pragma unroll + for (int i = TB / 2; i > 0; i /= 2) { + // Parallel reduction inside a single warp. + if (REDUCE == MIN || REDUCE == MAX) + arg_tmp = __shfl_down_sync(FULL_MASK, arg, i); + Reducer::update( + &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp); + } + + if (lane_idx == 0) { + Reducer::write(out_data + row_idx, val, + arg_out_data + row_idx, arg, + row_end - row_start); + } + } +} + +template +__global__ void segment_csr_broadcast_kernel( + const scalar_t *src_data, + const at::cuda::detail::TensorInfo indptr_info, + scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) { + + // Each thread processes exactly one row. It turned out that is more + // efficient than using shared memory due to avoiding synchronization + // barriers. 
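A rough Python mirror of the (row, lane) work split used by `segment_csr_broadcast_kernel` (sum reduction only, 1-D `indptr`; the helper name and shapes are assumptions, not part of the patch):

    import torch

    def segment_sum_csr_flat(src, indptr, K):
        # One "thread" per (row_idx, lane_idx) pair; lane_idx walks the K
        # trailing features, row_idx walks the N = len(indptr) - 1 segments.
        N = indptr.numel() - 1
        src_flat = src.reshape(-1)
        out = torch.zeros(N * K, dtype=src.dtype)
        for thread_idx in range(N * K):
            row_idx, lane_idx = thread_idx // K, thread_idx % K
            row_start, row_end = int(indptr[row_idx]), int(indptr[row_idx + 1])
            idx = torch.arange(row_start, row_end) * K + lane_idx
            out[thread_idx] = src_flat[idx].sum()
        return out.view(N, K)

    src = torch.randn(8, 4)
    indptr = torch.tensor([0, 2, 4, 6, 8])
    assert torch.allclose(segment_sum_csr_flat(src, indptr, K=4),
                          torch.stack([src[s:e].sum(0)
                                       for s, e in zip(indptr[:-1], indptr[1:])]))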
+ + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int lane_idx = thread_idx % K; + + if (thread_idx < N * K) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + int64_t row_start = __ldg(indptr_info.data + offset); + int64_t row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + + scalar_t val = Reducer::init(); + int64_t arg; + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; + for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) { + Reducer::update( + &val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx); + } + + Reducer::write(out_data + thread_idx, val, + arg_out_data + thread_idx, arg, + row_end - row_start); + } +} + +std::tuple> +segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out, + std::string reduce) { + CHECK_CUDA(src); + CHECK_CUDA(indptr); + if (optional_out.has_value()) + CHECK_CUDA(optional_out.value()); + cudaSetDevice(src.get_device()); + + CHECK_INPUT(src.dim() >= indptr.dim()); + + auto sizes = indptr.sizes().vec(); + for (auto i = 0; i < indptr.dim() - 1; i++) + sizes[i] = src.size(i); + indptr = indptr.expand(sizes); + + auto dim = indptr.dim() - 1; + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (int i = 0; i < out.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1); + } else { + sizes = src.sizes().vec(); + sizes[dim] = indptr.size(dim) - 1; + out = torch::empty(sizes, src.options()); + } + + torch::optional arg_out = torch::nullopt; + int64_t *arg_out_data = nullptr; + if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { + arg_out = torch::full(out.sizes(), src.size(dim), indptr.options()); + arg_out_data = arg_out.value().data_ptr(); + } + + auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); + auto K = out.numel() / N; + auto E = src.size(dim); + + auto indptr_info = at::cuda::detail::getTensorInfo(indptr); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { + if (K == 1) { + segment_csr_kernel + <<>>( + src_data, indptr_info, out_data, arg_out_data, N, E); + } else { + segment_csr_broadcast_kernel + <<>>( + src_data, indptr_info, out_data, arg_out_data, N, K, E); + } + }); + }); + + return std::make_tuple(out, arg_out); +} + +template +__global__ void +gather_csr_kernel(const scalar_t *src_data, + const at::cuda::detail::TensorInfo indptr_info, + scalar_t *out_data, size_t N, size_t E) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / TB; + int lane_idx = thread_idx % TB; + + if (row_idx < N) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + int row_start = __ldg(indptr_info.data + offset); + int row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + scalar_t val = __ldg(src_data + row_idx); + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; + for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) { + out_data[offset + out_idx] = val; // "Mostly" coalesced. 
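For reference, the gather kernels here perform the inverse expansion of segment_csr; in pure PyTorch (illustrative, 1-D `indptr`) this is just a `repeat_interleave` over the segment lengths:

    import torch

    src = torch.randn(4, 3)                 # one row per segment
    indptr = torch.tensor([0, 2, 2, 5, 6])  # segment lengths 2, 0, 3, 1
    out = src.repeat_interleave(indptr[1:] - indptr[:-1], dim=0)
    # out has shape [6, 3]: row i of `src` appears indptr[i+1] - indptr[i] times.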
+ } + } +} + +template +__global__ void gather_csr_broadcast_kernel( + const scalar_t *src_data, + const at::cuda::detail::TensorInfo indptr_info, + scalar_t *out_data, size_t N, size_t K, size_t E) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int lane_idx = thread_idx % K; + + if (thread_idx < N * K) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + int row_start = __ldg(indptr_info.data + offset); + int row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + + scalar_t val = src_data[thread_idx]; // Coalesced. + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; + for (int out_idx = row_start; out_idx < row_end; out_idx++) { + out_data[offset + K * out_idx + lane_idx] = val; // "Mostly" coalesced. + } + } +} + +torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { + CHECK_CUDA(src); + CHECK_CUDA(indptr); + if (optional_out.has_value()) + CHECK_CUDA(optional_out.value()); + cudaSetDevice(src.get_device()); + + CHECK_INPUT(src.dim() >= indptr.dim()); + + auto sizes = indptr.sizes().vec(); + for (auto i = 0; i < indptr.dim() - 1; i++) + sizes[i] = src.size(i); + indptr = indptr.expand(sizes); + + auto dim = indptr.dim() - 1; + CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1); + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (auto i = 0; i < out.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + } else { + auto d_gather_size = indptr.flatten()[-1].data_ptr(); + auto h_gather_size = (int64_t *)malloc(sizeof(int64_t)); + cudaMemcpy(h_gather_size, d_gather_size, sizeof(int64_t), + cudaMemcpyDeviceToHost); + + auto sizes = src.sizes().vec(); + sizes[dim] = *h_gather_size; + out = torch::empty(sizes, src.options()); + } + + auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); + auto K = src.numel() / N; + auto E = out.size(dim); + + auto indptr_info = at::cuda::detail::getTensorInfo(indptr); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + if (K == 1) + gather_csr_kernel<<>>( + src_data, indptr_info, out_data, N, E); + else + gather_csr_broadcast_kernel + <<>>(src_data, indptr_info, + out_data, N, K, E); + }); + + return out; +} diff --git a/csrc/cuda/segment_csr_cuda.h b/csrc/cuda/segment_csr_cuda.h new file mode 100644 index 00000000..5f8bd40e --- /dev/null +++ b/csrc/cuda/segment_csr_cuda.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +std::tuple> +segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out, + std::string reduce); + +torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out); diff --git a/csrc/cuda/utils.cuh b/csrc/cuda/utils.cuh new file mode 100644 index 00000000..000ee0fe --- /dev/null +++ b/csrc/cuda/utils.cuh @@ -0,0 +1,7 @@ +#pragma once + +#include + +#define CHECK_CUDA(x) \ + AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") +#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") diff --git a/csrc/segment_csr.cpp b/csrc/segment_csr.cpp new file mode 100644 index 00000000..45eeecf5 --- /dev/null +++ b/csrc/segment_csr.cpp @@ -0,0 +1,218 @@ +#include + +#include "cpu/segment_csr_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/segment_csr_cuda.h" +#endif + 
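A self-contained pure-PyTorch check of the gradient wiring used by the autograd functions defined below (an illustrative reference computation, not the registered ops): the backward of a CSR segment sum is a CSR gather of `grad_out`.

    import torch

    src = torch.randn(8, 4, requires_grad=True)
    indptr = torch.tensor([0, 2, 4, 6, 8])
    counts = indptr[1:] - indptr[:-1]

    # Reference forward: sum over each CSR segment.
    segments = torch.split(src, counts.tolist(), dim=0)
    out = torch.stack([s.sum(dim=0) for s in segments], dim=0)

    grad_out = torch.randn_like(out)
    out.backward(grad_out)

    # SegmentSumCSR::backward computes gather_csr(grad_out, indptr), i.e. each
    # gradient row repeated over its segment.
    expected = grad_out.repeat_interleave(counts, dim=0)
    assert torch.allclose(src.grad, expected)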
+std::tuple> +segment_csr_fw(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out, + std::string reduce) { + if (src.device().is_cuda()) { +#ifdef WITH_CUDA + return segment_csr_cuda(src, indptr, optional_out, reduce); +#else + AT_ERROR("Not compiled with CUDA support"); +#endif + } else { + return segment_csr_cpu(src, indptr, optional_out, reduce); + } +} + +torch::Tensor gather_csr_fw(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { + if (src.device().is_cuda()) { +#ifdef WITH_CUDA + return gather_csr_cuda(src, indptr, optional_out); +#else + AT_ERROR("Not compiled with CUDA support"); +#endif + } else { + return gather_csr_cpu(src, indptr, optional_out); + } +} + +using torch::autograd::AutogradContext; +using torch::autograd::Variable; +using torch::autograd::variable_list; + +class SegmentSumCSR : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable indptr, + torch::optional optional_out) { + ctx->saved_data["src_shape"] = src.sizes(); + auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "sum")); + ctx->save_for_backward({indptr}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto indptr = saved[0]; + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + auto grad_in = torch::empty(src_shape, grad_out.options()); + gather_csr_fw(grad_out, indptr, grad_in); + return {grad_in, Variable(), Variable()}; + } +}; + +class SegmentMeanCSR : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable indptr, + torch::optional optional_out) { + ctx->saved_data["src_shape"] = src.sizes(); + auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "mean")); + ctx->save_for_backward({indptr}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto indptr = saved[0]; + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + auto grad_in = torch::empty(src_shape, grad_out.options()); + gather_csr_fw(grad_out, indptr, grad_in); + auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1); + auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1); + auto count = (indptr2 - indptr1).to(grad_in.options()); + count = gather_csr_fw(count, indptr, torch::nullopt); + for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++) + count = count.unsqueeze(-1); + grad_in.div_(count); + return {grad_in, Variable(), Variable()}; + } +}; + +class SegmentMinCSR : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable indptr, + torch::optional optional_out) { + ctx->saved_data["src_shape"] = src.sizes(); + auto result = segment_csr_fw(src, indptr, optional_out, "min"); + auto out = std::get<0>(result); + auto arg_out = std::get<1>(result).value(); + ctx->save_for_backward({indptr, arg_out}); + ctx->mark_non_differentiable({arg_out}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out, arg_out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = 
grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto indptr = saved[0]; + auto arg_out = saved[1]; + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + src_shape[indptr.dim() - 1] += 1; + auto grad_in = torch::zeros(src_shape, grad_out.options()); + grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out); + grad_in = + grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1); + return {grad_in, Variable(), Variable()}; + } +}; + +class SegmentMaxCSR : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable indptr, + torch::optional optional_out) { + ctx->saved_data["src_shape"] = src.sizes(); + auto result = segment_csr_fw(src, indptr, optional_out, "max"); + auto out = std::get<0>(result); + auto arg_out = std::get<1>(result).value(); + ctx->save_for_backward({indptr, arg_out}); + ctx->mark_non_differentiable({arg_out}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out, arg_out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto indptr = saved[0]; + auto arg_out = saved[1]; + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + src_shape[indptr.dim() - 1] += 1; + auto grad_in = torch::zeros(src_shape, grad_out.options()); + grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out); + grad_in = + grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1); + return {grad_in, Variable(), Variable()}; + } +}; + +class GatherCSR : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable indptr, + torch::optional optional_out) { + ctx->saved_data["src_shape"] = src.sizes(); + auto out = gather_csr_fw(src, indptr, optional_out); + ctx->save_for_backward({indptr}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto indptr = saved[0]; + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + + auto grad_in = torch::empty(src_shape, grad_out.options()); + segment_csr_fw(grad_out, indptr, grad_in, "sum"); + return {grad_in, Variable(), Variable()}; + } +}; + +torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { + return SegmentSumCSR::apply(src, indptr, optional_out)[0]; +} + +torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { + return SegmentMeanCSR::apply(src, indptr, optional_out)[0]; +} + +std::tuple +segment_min_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { + auto result = SegmentMinCSR::apply(src, indptr, optional_out); + return std::make_tuple(result[0], result[1]); +} + +std::tuple +segment_max_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { + auto result = SegmentMaxCSR::apply(src, indptr, optional_out); + return std::make_tuple(result[0], result[1]); +} + +torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, + torch::optional optional_out) { + return GatherCSR::apply(src, indptr, optional_out)[0]; +} + +static auto registry = + torch::RegisterOperators() + .op("torch_scatter::segment_sum_csr", &segment_sum_csr) + .op("torch_scatter::segment_mean_csr", 
&segment_mean_csr) + .op("torch_scatter::segment_min_csr", &segment_min_csr) + .op("torch_scatter::segment_max_csr", &segment_max_csr) + .op("torch_scatter::gather_csr", &gather_csr); diff --git a/setup.py b/setup.py index 31a30d8a..3267df18 100644 --- a/setup.py +++ b/setup.py @@ -1,59 +1,57 @@ -import platform +import os import os.path as osp -from glob import glob +import sys +import glob from setuptools import setup, find_packages -from sys import argv import torch from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME -# Windows users: Edit both of these to contain your VS include path, i.e.: -# cxx_extra_compile_args = ['-I{VISUAL_STUDIO_DIR}\\include'] -# nvcc_extra_compile_args = [..., '-I{VISUAL_STUDIO_DIR}\\include'] -cxx_extra_compile_args = [] -nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr'] -# Windows users: Edit both of these to contain your VS library path, i.e.: -# cxx_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}'] -# nvcc_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}'] -cxx_extra_link_args = [] -nvcc_extra_link_args = [] +def get_extensions(): + this_dir = osp.dirname(osp.abspath(__file__)) + extensions_dir = osp.join(this_dir, 'csrc') -if platform.system() != 'Windows': - cxx_extra_compile_args += ['-Wno-unused-variable'] -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): - cxx_extra_compile_args += ['-DVERSION_GE_1_3'] - nvcc_extra_compile_args += ['-DVERSION_GE_1_3'] -cmdclass = { - 'build_ext': BuildExtension.with_options(no_python_abi_suffix=True) -} + main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) + cpu_files = glob.glob(osp.join(extensions_dir, 'cpu', '*.cpp')) + cuda_files = glob.glob(osp.join(extensions_dir, 'cuda', '*.cu')) -ext_modules = [] -exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))] -ext_modules += [ - CppExtension(f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'], - extra_compile_args=cxx_extra_compile_args, - extra_link_args=cxx_extra_link_args) for ext in exts -] + Extension = CppExtension + sources = main_files + cpu_files -if CUDA_HOME is not None and '--cpu' not in argv: - exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cuda', '*.cpp'))] - ext_modules += [ - CUDAExtension( - f'torch_scatter.{ext}_cuda', - [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'], extra_compile_args={ - 'cxx': cxx_extra_compile_args, - 'nvcc': nvcc_extra_compile_args, - }, extra_link_args=nvcc_extra_link_args) for ext in exts + define_macros = [] + extra_compile_args = {'cxx': [], 'nvcc': []} + # Windows users: Edit both of these to contain your VS include path, i.e.: + # extra_compile_args['cxx'] += ['-I{VISUAL_STUDIO_DIR}\\include'] + # extra_compile_args['nvcc'] += ['-I{VISUAL_STUDIO_DIR}\\include'] + + if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv( + 'FORCE_CUDA', '0') == '1': + + Extension = CUDAExtension + sources += cuda_files + define_macros += [('WITH_CUDA', None)] + + nvcc_flags = os.getenv('NVCC_FLAGS', '') + nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') + nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr'] + extra_compile_args['cxx'] += ['-O0'] + extra_compile_args['nvcc'] += nvcc_flags + + if sys.platform == 'win32': + extra_compile_args['cxx'] += ['/MP'] + + return [ + Extension( + 'torch_scatter._C', + sources, + 
include_dirs=[extensions_dir], + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) ] -if '--cpu' in argv: - argv.remove('--cpu') -__version__ = '1.5.0' -url = 'https://github.com/rusty1s/pytorch_scatter' install_requires = [] setup_requires = ['pytest-runner'] @@ -61,17 +59,19 @@ setup( name='torch_scatter', - version=__version__, - description='PyTorch Extension Library of Optimized Scatter Operations', + version='1.5.0', author='Matthias Fey', author_email='matthias.fey@tu-dortmund.de', - url=url, - download_url='{}/archive/{}.tar.gz'.format(url, __version__), - keywords=['pytorch', 'scatter', 'segment'], + url='https://github.com/rusty1s/pytorch_scatter', + description='PyTorch Extension Library of Optimized Scatter Operations', + keywords=['pytorch', 'scatter', 'segment', 'gather'], + license='MIT', install_requires=install_requires, setup_requires=setup_requires, tests_require=tests_require, - ext_modules=ext_modules, - cmdclass=cmdclass, + ext_modules=get_extensions(), + cmdclass={ + 'build_ext': BuildExtension.with_options(no_python_abi_suffix=True) + }, packages=find_packages(), ) diff --git a/test/test_segment.py b/test/test_segment.py index ca613950..7f68f62b 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -3,7 +3,7 @@ import pytest import torch from torch.autograd import gradcheck -from torch_scatter import segment_coo, segment_csr +from torch_scatter import segment_csr from .utils import tensor, dtypes, devices @@ -88,12 +88,12 @@ def test_forward(test, reduce, dtype, device): indptr = tensor(test['indptr'], torch.long, device) expected = tensor(test[reduce], dtype, device) - out = segment_coo(src, index, reduce=reduce) - if isinstance(out, tuple): - out, arg_out = out - arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) - assert torch.all(arg_out == arg_expected) - assert torch.all(out == expected) + # out = segment_coo(src, index, reduce=reduce) + # if isinstance(out, tuple): + # out, arg_out = out + # arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) + # assert torch.all(arg_out == arg_expected) + # assert torch.all(out == expected) out = segment_csr(src, indptr, reduce=reduce) if isinstance(out, tuple): @@ -111,7 +111,7 @@ def test_backward(test, reduce, device): index = tensor(test['index'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device) - assert gradcheck(segment_coo, (src, index, None, None, reduce)) + # assert gradcheck(segment_coo, (src, index, None, None, reduce)) assert gradcheck(segment_csr, (src, indptr, None, reduce)) @@ -130,22 +130,22 @@ def test_segment_out(test, reduce, dtype, device): segment_csr(src, indptr, out, reduce=reduce) assert torch.all(out == expected) - out.fill_(-2) + # out.fill_(-2) - segment_coo(src, index, out, reduce=reduce) + # segment_coo(src, index, out, reduce=reduce) - if reduce == 'sum': - expected = expected - 2 - elif reduce == 'mean': - expected = out # We can not really test this here. - elif reduce == 'min': - expected = expected.fill_(-2) - elif reduce == 'max': - expected[expected == 0] = -2 - else: - raise ValueError + # if reduce == 'sum': + # expected = expected - 2 + # elif reduce == 'mean': + # expected = out # We can not really test this here. 
+ # elif reduce == 'min': + # expected = expected.fill_(-2) + # elif reduce == 'max': + # expected[expected == 0] = -2 + # else: + # raise ValueError - assert torch.all(out == expected) + # assert torch.all(out == expected) @pytest.mark.parametrize('test,reduce,dtype,device', @@ -163,12 +163,12 @@ def test_non_contiguous_segment(test, reduce, dtype, device): if indptr.dim() > 1: indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1) - out = segment_coo(src, index, reduce=reduce) - if isinstance(out, tuple): - out, arg_out = out - arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) - assert torch.all(arg_out == arg_expected) - assert torch.all(out == expected) + # out = segment_coo(src, index, reduce=reduce) + # if isinstance(out, tuple): + # out, arg_out = out + # arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) + # assert torch.all(arg_out == arg_expected) + # assert torch.all(out == expected) out = segment_csr(src, indptr, reduce=reduce) if isinstance(out, tuple): diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index 74f0648d..9f3e38de 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -1,47 +1,17 @@ -import torch +from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr, + segment_min_csr, segment_max_csr, segment_csr, + gather_csr) -from .add import scatter_add -from .sub import scatter_sub -from .mul import scatter_mul -from .div import scatter_div -from .mean import scatter_mean -from .std import scatter_std -from .max import scatter_max -from .min import scatter_min -from .logsumexp import scatter_logsumexp - -from .segment import segment_coo, segment_csr -from .gather import gather_coo, gather_csr - -import torch_scatter.composite - -torch.ops.load_library('torch_scatter/scatter_cpu.so') -torch.ops.load_library('torch_scatter/segment_csr_cpu.so') -torch.ops.load_library('torch_scatter/segment_coo_cpu.so') - -try: - torch.ops.load_library('torch_scatter/scatter_cuda.so') - # torch.ops.load_library('torch_scatter/segment_csr_cuda.so') - # torch.ops.load_library('torch_scatter/segment_coo_cuda.so') -except OSError as e: - if torch.cuda.is_available(): - raise e - -__version__ = '1.4.0' +__version__ = '1.5.0' __all__ = [ - 'scatter_add', - 'scatter_sub', - 'scatter_mul', - 'scatter_div', - 'scatter_mean', - 'scatter_std', - 'scatter_max', - 'scatter_min', - 'scatter_logsumexp', - 'segment_coo', + 'segment_sum_csr', + 'segment_add_csr', + 'segment_mean_csr', + 'segment_min_csr', + 'segment_max_csr', + 'segment_max_csr', 'segment_csr', - 'gather_coo', 'gather_csr', 'torch_scatter', '__version__', diff --git a/torch_scatter/segment_csr.py b/torch_scatter/segment_csr.py new file mode 100644 index 00000000..88a4faca --- /dev/null +++ b/torch_scatter/segment_csr.py @@ -0,0 +1,59 @@ +from typing import Optional, Tuple + +import torch + +torch.ops.load_library('torch_scatter/_C.so') + + +@torch.jit.script +def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None) -> torch.Tensor: + return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out) + + +@torch.jit.script +def segment_add_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None) -> torch.Tensor: + return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out) + + +@torch.jit.script +def segment_mean_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None) -> torch.Tensor: + return torch.ops.torch_scatter.segment_mean_csr(src, indptr, out) + + 
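+# The "min" and "max" variants also return the index of the reduced value
+# along the segmented dimension (the `arg_out` tensor).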
+@torch.jit.script +def segment_min_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.segment_min_csr(src, indptr, out) + + +@torch.jit.script +def segment_max_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.segment_max_csr(src, indptr, out) + + +@torch.jit.script +def segment_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None, + reduce: str = "sum") -> torch.Tensor: + if reduce == 'sum' or reduce == 'add': + return segment_sum_csr(src, indptr, out) + elif reduce == 'mean': + return segment_mean_csr(src, indptr, out) + elif reduce == 'min': + return segment_min_csr(src, indptr, out)[0] + elif reduce == 'max': + return segment_max_csr(src, indptr, out)[0] + else: + raise ValueError + + +@torch.jit.script +def gather_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None) -> torch.Tensor: + return torch.ops.torch_scatter.gather_csr(src, indptr, out) From bb87ec65409e332110ed0ee3c7085699644c559e Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 29 Jan 2020 16:06:55 +0100 Subject: [PATCH 03/12] coo cpu implementation --- cpu/segment_coo_impl.h | 2 +- csrc/cpu/segment_coo_cpu.cpp | 198 +++++++++++++++++++++++++++++ csrc/cpu/segment_coo_cpu.h | 11 ++ csrc/cpu/segment_csr_cpu.cpp | 4 +- csrc/cuda/segment_coo_cuda.cu | 13 ++ csrc/cuda/segment_coo_cuda.h | 11 ++ csrc/segment_coo.cpp | 227 ++++++++++++++++++++++++++++++++++ test/test_gather.py | 18 +-- test/test_segment.py | 67 +++++----- torch_scatter/__init__.py | 17 ++- torch_scatter/segment_coo.py | 63 ++++++++++ torch_scatter/segment_csr.py | 2 - 12 files changed, 585 insertions(+), 48 deletions(-) create mode 100644 csrc/cpu/segment_coo_cpu.cpp create mode 100644 csrc/cpu/segment_coo_cpu.h create mode 100644 csrc/cuda/segment_coo_cuda.cu create mode 100644 csrc/cuda/segment_coo_cuda.h create mode 100644 csrc/segment_coo.cpp create mode 100644 torch_scatter/segment_coo.py diff --git a/cpu/segment_coo_impl.h b/cpu/segment_coo_impl.h index 147134b3..99ea81d9 100644 --- a/cpu/segment_coo_impl.h +++ b/cpu/segment_coo_impl.h @@ -166,7 +166,7 @@ torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, if (e < E - 1) { next_idx = index_info.data[offset + (e + 1) * stride]; - CHECK_INPUT(idx < E && idx <= next_idx); + CHECK_INPUT(idx <= next_idx); if (idx != next_idx) { idx = next_idx; diff --git a/csrc/cpu/segment_coo_cpu.cpp b/csrc/cpu/segment_coo_cpu.cpp new file mode 100644 index 00000000..a04ff3bc --- /dev/null +++ b/csrc/cpu/segment_coo_cpu.cpp @@ -0,0 +1,198 @@ +#include "segment_coo_cpu.h" + +#include "index_info.h" +#include "reducer.h" +#include "utils.h" + +std::tuple> +segment_coo_cpu(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size, std::string reduce) { + CHECK_CPU(src); + CHECK_CPU(index); + if (optional_out.has_value()) + CHECK_CPU(optional_out.value()); + + CHECK_INPUT(src.dim() >= index.dim()); + + auto sizes = index.sizes().vec(); + for (int i = 0; i < index.dim(); i++) + sizes[i] = src.size(i); + index = index.expand(sizes); + + auto dim = index.dim() - 1; + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (int i = 0; i < out.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + } else { + sizes = 
src.sizes().vec(); + if (dim_size.has_value()) + sizes[dim] = dim_size.value(); + else + sizes[dim] = 1 + *index.max().data_ptr(); + out = torch::empty(sizes, src.options()); + } + + torch::optional arg_out = torch::nullopt; + int64_t *arg_out_data = nullptr; + if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { + arg_out = torch::full_like(out, src.size(dim), index.options()); + arg_out_data = arg_out.value().data_ptr(); + } + + torch::optional count = torch::nullopt; + if (reduce2REDUCE.at(reduce) == MEAN) { + auto sizes = index.sizes().vec(); + sizes[dim] = out.size(dim); + count = torch::zeros(sizes, out.options()); + } + + auto B = index.numel() / src.size(dim); + auto E = src.size(dim); + auto K = src.numel() / index.numel(); + auto N = out.size(dim); + + auto index_info = getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + std::vector args(K); + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + scalar_t *count_data = nullptr; + + std::vector vals(K); + int64_t idx, next_idx, row_start; + AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { + if (!optional_out.has_value()) + out.fill_(Reducer::init()); + if (REDUCE == MEAN) + count_data = count.value().data_ptr(); + + for (auto b = 0; b < B; b++) { + auto offset = IndexToOffset::get(b * E, index_info); + idx = index_info.data[offset]; + + for (auto k = 0; k < K; k++) + vals[k] = out_data[b * N * K + k]; + + row_start = 0; + for (auto e = 0; e < E; e++) { + + for (auto k = 0; k < K; k++) + Reducer::update( + &vals[k], src_data[b * E * K + e * K + k], &args[k], e); + + if (e == E - 1) { + for (auto k = 0; k < K; k++) + Reducer::write( + out_data + b * N * K + idx * K + k, vals[k], + arg_out_data + b * N * K + idx * K + k, args[k], + e + 1 - row_start); + if (REDUCE == MEAN) + count_data[b * N + idx] = (scalar_t)(e + 1 - row_start); + } else { + next_idx = index_info.data[offset + (e + 1) * stride]; + assert(idx <= next_idx); + + if (idx != next_idx) { + for (auto k = 0; k < K; k++) { + Reducer::write( + out_data + b * N * K + idx * K + k, vals[k], + arg_out_data + b * N * K + idx * K + k, args[k], + e + 1 - row_start); + + vals[k] = out_data[b * N * K + next_idx * K + k]; + } + if (REDUCE == MEAN) + count_data[b * N + idx] = (scalar_t)(e + 1 - row_start); + row_start = e + 1; + } + + idx = next_idx; + } + } + } + if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) + out.masked_fill_(out == Reducer::init(), (scalar_t)0); + + if (REDUCE == MEAN) + arg_out = count; + }); + }); + + return std::make_tuple(out, arg_out); +} + +torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index, + torch::optional optional_out) { + CHECK_CPU(src); + CHECK_CPU(index); + if (optional_out.has_value()) + CHECK_CPU(optional_out.value()); + + CHECK_INPUT(src.dim() >= index.dim()); + for (auto i = 0; i < index.dim() - 1; i++) + CHECK_INPUT(src.size(i) == index.size(i)); + + auto dim = index.dim() - 1; + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (auto i = 0; i < src.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + } else { + auto sizes = src.sizes().vec(); + sizes[dim] = index.size(dim); + out = torch::empty(sizes, src.options()); + } + + auto B = index.numel() / out.size(dim); + auto E = index.size(dim); + auto K = out.numel() / index.numel(); + auto N = src.size(dim); + + auto index_info = 
getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + std::vector vals(K); + int64_t idx, next_idx; + for (auto b = 0; b < B; b++) { + auto offset = IndexToOffset::get(b * E, index_info); + idx = index_info.data[offset]; + + for (auto k = 0; k < K; k++) + vals[k] = src_data[b * N * K + idx * K + k]; + + for (auto e = 0; e < E; e++) { + for (auto k = 0; k < K; k++) + out_data[b * E * K + e * K + k] = vals[k]; + + if (e < E - 1) { + next_idx = index_info.data[offset + (e + 1) * stride]; + CHECK_INPUT(idx <= next_idx); + + if (idx != next_idx) { + idx = next_idx; + for (auto k = 0; k < K; k++) + vals[k] = src_data[b * N * K + idx * K + k]; + } + } + } + } + }); + + return out; +} diff --git a/csrc/cpu/segment_coo_cpu.h b/csrc/cpu/segment_coo_cpu.h new file mode 100644 index 00000000..feb7a827 --- /dev/null +++ b/csrc/cpu/segment_coo_cpu.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +std::tuple> +segment_coo_cpu(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size, std::string reduce); + +torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index, + torch::optional optional_out); diff --git a/csrc/cpu/segment_csr_cpu.cpp b/csrc/cpu/segment_csr_cpu.cpp index 00790128..efc6955c 100644 --- a/csrc/cpu/segment_csr_cpu.cpp +++ b/csrc/cpu/segment_csr_cpu.cpp @@ -67,12 +67,10 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, for (auto k = 0; k < K; k++) vals[k] = Reducer::init(); - for (auto e = row_start; e < row_end; e++) { - CHECK_INPUT(e < E); + for (auto e = row_start; e < row_end; e++) for (auto k = 0; k < K; k++) Reducer::update( &vals[k], src_data[offset + e * K + k], &args[k], e); - } for (auto k = 0; k < K; k++) Reducer::write(out_data + n * K + k, vals[k], diff --git a/csrc/cuda/segment_coo_cuda.cu b/csrc/cuda/segment_coo_cuda.cu new file mode 100644 index 00000000..a72b363b --- /dev/null +++ b/csrc/cuda/segment_coo_cuda.cu @@ -0,0 +1,13 @@ +#include "segment_coo_cuda.h" + +std::tuple> +segment_coo_cuda(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size, std::string reduce) { + return std::make_tuple(src, optional_out); +} + +torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, + torch::optional optional_out) { + return src; +} diff --git a/csrc/cuda/segment_coo_cuda.h b/csrc/cuda/segment_coo_cuda.h new file mode 100644 index 00000000..68154775 --- /dev/null +++ b/csrc/cuda/segment_coo_cuda.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +std::tuple> +segment_coo_cuda(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size, std::string reduce); + +torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, + torch::optional optional_out); diff --git a/csrc/segment_coo.cpp b/csrc/segment_coo.cpp new file mode 100644 index 00000000..40352ff8 --- /dev/null +++ b/csrc/segment_coo.cpp @@ -0,0 +1,227 @@ +#include + +#include "cpu/segment_coo_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/segment_coo_cuda.h" +#endif + +std::tuple> +segment_coo_fw(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size, std::string reduce) { + if (src.device().is_cuda()) { +#ifdef WITH_CUDA + return segment_coo_cuda(src, index, optional_out, dim_size, reduce); +#else + AT_ERROR("Not compiled with CUDA support"); +#endif + } else { + return 
segment_coo_cpu(src, index, optional_out, dim_size, reduce); + } +} + +torch::Tensor gather_coo_fw(torch::Tensor src, torch::Tensor index, + torch::optional optional_out) { + if (src.device().is_cuda()) { +#ifdef WITH_CUDA + return gather_coo_cuda(src, index, optional_out); +#else + AT_ERROR("Not compiled with CUDA support"); +#endif + } else { + return gather_coo_cpu(src, index, optional_out); + } +} + +using torch::autograd::AutogradContext; +using torch::autograd::Variable; +using torch::autograd::variable_list; + +class SegmentSumCOO : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable index, + torch::optional optional_out, + torch::optional dim_size) { + ctx->saved_data["src_shape"] = src.sizes(); + auto result = segment_coo_fw(src, index, optional_out, dim_size, "sum"); + auto out = std::get<0>(result); + ctx->save_for_backward({index}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto index = saved[0]; + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + auto grad_in = torch::empty(src_shape, grad_out.options()); + gather_coo_fw(grad_out, index, grad_in); + return {grad_in, Variable(), Variable(), Variable()}; + } +}; + +class SegmentMeanCOO : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable index, + torch::optional optional_out, + torch::optional dim_size) { + ctx->saved_data["src_shape"] = src.sizes(); + auto result = segment_coo_fw(src, index, optional_out, dim_size, "mean"); + auto out = std::get<0>(result); + auto count = std::get<1>(result).value(); + ctx->save_for_backward({index, count}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto index = saved[0]; + auto count = saved[1]; + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + auto grad_in = torch::empty(src_shape, grad_out.options()); + gather_coo_fw(grad_out, index, grad_in); + count = gather_coo_fw(count, index, torch::nullopt); + for (auto i = 0; i < grad_out.dim() - index.dim(); i++) + count = count.unsqueeze(-1); + grad_in.div_(count); + return {grad_in, Variable(), Variable(), Variable()}; + } +}; + +class SegmentMinCOO : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable index, + torch::optional optional_out, + torch::optional dim_size) { + ctx->saved_data["src_shape"] = src.sizes(); + auto result = segment_coo_fw(src, index, optional_out, dim_size, "min"); + auto out = std::get<0>(result); + auto arg_out = std::get<1>(result).value(); + ctx->save_for_backward({index, arg_out}); + ctx->mark_non_differentiable({arg_out}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out, arg_out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto index = saved[0]; + auto arg_out = saved[1]; + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + src_shape[index.dim() - 1] += 1; + auto grad_in = torch::zeros(src_shape, 
grad_out.options()); + grad_in.scatter_(index.dim() - 1, arg_out, grad_out); + grad_in = + grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1); + return {grad_in, Variable(), Variable(), Variable()}; + } +}; + +class SegmentMaxCOO : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable index, + torch::optional optional_out, + torch::optional dim_size) { + ctx->saved_data["src_shape"] = src.sizes(); + auto result = segment_coo_fw(src, index, optional_out, dim_size, "max"); + auto out = std::get<0>(result); + auto arg_out = std::get<1>(result).value(); + ctx->save_for_backward({index, arg_out}); + ctx->mark_non_differentiable({arg_out}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out, arg_out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto index = saved[0]; + auto arg_out = saved[1]; + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + src_shape[index.dim() - 1] += 1; + auto grad_in = torch::zeros(src_shape, grad_out.options()); + grad_in.scatter_(index.dim() - 1, arg_out, grad_out); + grad_in = + grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1); + return {grad_in, Variable(), Variable(), Variable()}; + } +}; + +class GatherCOO : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable index, + torch::optional optional_out) { + ctx->saved_data["src_shape"] = src.sizes(); + auto out = gather_coo_fw(src, index, optional_out); + ctx->save_for_backward({index}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto index = saved[0]; + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + + auto grad_in = torch::zeros(src_shape, grad_out.options()); + segment_coo_fw(grad_out, index, grad_in, torch::nullopt, "sum"); + return {grad_in, Variable(), Variable()}; + } +}; + +torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size) { + return SegmentSumCOO::apply(src, index, optional_out, dim_size)[0]; +} + +torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size) { + return SegmentMeanCOO::apply(src, index, optional_out, dim_size)[0]; +} + +std::tuple +segment_min_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size) { + auto result = SegmentMinCOO::apply(src, index, optional_out, dim_size); + return std::make_tuple(result[0], result[1]); +} + +std::tuple +segment_max_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out, + torch::optional dim_size) { + auto result = SegmentMaxCOO::apply(src, index, optional_out, dim_size); + return std::make_tuple(result[0], result[1]); +} + +torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, + torch::optional optional_out) { + return GatherCOO::apply(src, index, optional_out)[0]; +} + +static auto registry = + torch::RegisterOperators() + .op("torch_scatter::segment_sum_coo", &segment_sum_coo) + .op("torch_scatter::segment_mean_coo", &segment_mean_coo) + .op("torch_scatter::segment_min_coo", 
&segment_min_coo) + .op("torch_scatter::segment_max_coo", &segment_max_coo) + .op("torch_scatter::gather_coo", &gather_coo); diff --git a/test/test_gather.py b/test/test_gather.py index bf92e1ef..3a5ade55 100644 --- a/test/test_gather.py +++ b/test/test_gather.py @@ -3,10 +3,12 @@ import pytest import torch from torch.autograd import gradcheck -from torch_scatter import gather_coo, gather_csr +from torch_scatter import gather_csr, gather_coo from .utils import tensor, dtypes, devices +devices = ['cpu'] + tests = [ { 'src': [1, 2, 3, 4], @@ -54,10 +56,10 @@ def test_forward(test, dtype, device): indptr = tensor(test['indptr'], torch.long, device) expected = tensor(test['expected'], dtype, device) - out = gather_coo(src, index) + out = gather_csr(src, indptr) assert torch.all(out == expected) - out = gather_csr(src, indptr) + out = gather_coo(src, index) assert torch.all(out == expected) @@ -68,8 +70,8 @@ def test_backward(test, device): index = tensor(test['index'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device) - assert gradcheck(gather_coo, (src, index, None)) is True assert gradcheck(gather_csr, (src, indptr, None)) is True + assert gradcheck(gather_coo, (src, index, None)) is True @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) @@ -83,12 +85,12 @@ def test_gather_out(test, dtype, device): size[index.dim() - 1] = index.size(-1) out = src.new_full(size, -2) - gather_coo(src, index, out) + gather_csr(src, indptr, out) assert torch.all(out == expected) out.fill_(-2) - gather_csr(src, indptr, out) + gather_coo(src, index, out) assert torch.all(out == expected) @@ -106,8 +108,8 @@ def test_non_contiguous_segment(test, dtype, device): if indptr.dim() > 1: indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1) - out = gather_coo(src, index) + out = gather_csr(src, indptr) assert torch.all(out == expected) - out = gather_csr(src, indptr) + out = gather_coo(src, index) assert torch.all(out == expected) diff --git a/test/test_segment.py b/test/test_segment.py index 7f68f62b..22cc0959 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -3,12 +3,12 @@ import pytest import torch from torch.autograd import gradcheck -from torch_scatter import segment_csr +import torch_scatter from .utils import tensor, dtypes, devices reductions = ['sum', 'mean', 'min', 'max'] -grad_reductions = ['sum', 'mean'] +devices = ['cpu'] tests = [ { @@ -88,14 +88,14 @@ def test_forward(test, reduce, dtype, device): indptr = tensor(test['indptr'], torch.long, device) expected = tensor(test[reduce], dtype, device) - # out = segment_coo(src, index, reduce=reduce) - # if isinstance(out, tuple): - # out, arg_out = out - # arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) - # assert torch.all(arg_out == arg_expected) - # assert torch.all(out == expected) + out = getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr) + if isinstance(out, tuple): + out, arg_out = out + arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) + assert torch.all(arg_out == arg_expected) + assert torch.all(out == expected) - out = segment_csr(src, indptr, reduce=reduce) + out = getattr(torch_scatter, f'segment_{reduce}_coo')(src, index) if isinstance(out, tuple): out, arg_out = out arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) @@ -104,15 +104,16 @@ def test_forward(test, reduce, dtype, device): @pytest.mark.parametrize('test,reduce,device', - product(tests, grad_reductions, devices)) + product(tests, reductions, devices)) def 
test_backward(test, reduce, device): src = tensor(test['src'], torch.double, device) src.requires_grad_() index = tensor(test['index'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device) - # assert gradcheck(segment_coo, (src, index, None, None, reduce)) - assert gradcheck(segment_csr, (src, indptr, None, reduce)) + assert gradcheck(torch_scatter.segment_csr, (src, indptr, None, reduce)) + assert gradcheck(torch_scatter.segment_coo, + (src, index, None, None, reduce)) @pytest.mark.parametrize('test,reduce,dtype,device', @@ -127,25 +128,25 @@ def test_segment_out(test, reduce, dtype, device): size[indptr.dim() - 1] = indptr.size(-1) - 1 out = src.new_full(size, -2) - segment_csr(src, indptr, out, reduce=reduce) + getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr, out) assert torch.all(out == expected) - # out.fill_(-2) + out.fill_(-2) - # segment_coo(src, index, out, reduce=reduce) + getattr(torch_scatter, f'segment_{reduce}_coo')(src, index, out) - # if reduce == 'sum': - # expected = expected - 2 - # elif reduce == 'mean': - # expected = out # We can not really test this here. - # elif reduce == 'min': - # expected = expected.fill_(-2) - # elif reduce == 'max': - # expected[expected == 0] = -2 - # else: - # raise ValueError + if reduce == 'sum': + expected = expected - 2 + elif reduce == 'mean': + expected = out # We can not really test this here. + elif reduce == 'min': + expected = expected.fill_(-2) + elif reduce == 'max': + expected[expected == 0] = -2 + else: + raise ValueError - # assert torch.all(out == expected) + assert torch.all(out == expected) @pytest.mark.parametrize('test,reduce,dtype,device', @@ -163,14 +164,14 @@ def test_non_contiguous_segment(test, reduce, dtype, device): if indptr.dim() > 1: indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1) - # out = segment_coo(src, index, reduce=reduce) - # if isinstance(out, tuple): - # out, arg_out = out - # arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) - # assert torch.all(arg_out == arg_expected) - # assert torch.all(out == expected) + out = getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr) + if isinstance(out, tuple): + out, arg_out = out + arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) + assert torch.all(arg_out == arg_expected) + assert torch.all(out == expected) - out = segment_csr(src, indptr, reduce=reduce) + out = getattr(torch_scatter, f'segment_{reduce}_coo')(src, index) if isinstance(out, tuple): out, arg_out = out arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index 9f3e38de..9fcc124c 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -1,6 +1,13 @@ +import torch + +torch.ops.load_library('torch_scatter/_C.so') + from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr, segment_min_csr, segment_max_csr, segment_csr, - gather_csr) + gather_csr) # noqa +from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo, + segment_min_coo, segment_max_coo, segment_coo, + gather_coo) # noqa __version__ = '1.5.0' @@ -13,6 +20,14 @@ 'segment_max_csr', 'segment_csr', 'gather_csr', + 'segment_sum_coo', + 'segment_add_coo', + 'segment_mean_coo', + 'segment_min_coo', + 'segment_max_coo', + 'segment_max_coo', + 'segment_coo', + 'gather_coo', 'torch_scatter', '__version__', ] diff --git a/torch_scatter/segment_coo.py b/torch_scatter/segment_coo.py new file mode 100644 index 00000000..bbb7d00c --- /dev/null 
+++ b/torch_scatter/segment_coo.py @@ -0,0 +1,63 @@ +from typing import Optional, Tuple + +import torch + + +@torch.jit.script +def segment_sum_coo(src: torch.Tensor, index: torch.Tensor, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size) + + +@torch.jit.script +def segment_add_coo(src: torch.Tensor, index: torch.Tensor, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size) + + +@torch.jit.script +def segment_mean_coo(src: torch.Tensor, index: torch.Tensor, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return torch.ops.torch_scatter.segment_mean_coo(src, index, out, dim_size) + + +@torch.jit.script +def segment_min_coo(src: torch.Tensor, index: torch.Tensor, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.segment_min_coo(src, index, out, dim_size) + + +@torch.jit.script +def segment_max_coo(src: torch.Tensor, index: torch.Tensor, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size) + + +@torch.jit.script +def segment_coo(src: torch.Tensor, index: torch.Tensor, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + reduce: str = "sum") -> torch.Tensor: + if reduce == 'sum' or reduce == 'add': + return segment_sum_coo(src, index, out, dim_size) + elif reduce == 'mean': + return segment_mean_coo(src, index, out, dim_size) + elif reduce == 'min': + return segment_min_coo(src, index, out, dim_size)[0] + elif reduce == 'max': + return segment_max_coo(src, index, out, dim_size)[0] + else: + raise ValueError + + +@torch.jit.script +def gather_coo(src: torch.Tensor, index: torch.Tensor, + out: Optional[torch.Tensor] = None) -> torch.Tensor: + return torch.ops.torch_scatter.gather_coo(src, index, out) diff --git a/torch_scatter/segment_csr.py b/torch_scatter/segment_csr.py index 88a4faca..5ae84e79 100644 --- a/torch_scatter/segment_csr.py +++ b/torch_scatter/segment_csr.py @@ -2,8 +2,6 @@ import torch -torch.ops.load_library('torch_scatter/_C.so') - @torch.jit.script def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor, From d0f5005f47bdc69b0dfc4e00874c5cf1500d9100 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 29 Jan 2020 16:10:02 +0100 Subject: [PATCH 04/12] fix link --- torch_scatter/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index 9fcc124c..9e959cdc 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -1,6 +1,8 @@ +import os.path as osp + import torch -torch.ops.load_library('torch_scatter/_C.so') +torch.ops.load_library(osp.join(osp.dirname(osp.abspath(__file__)), '_C.so')) from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr, segment_min_csr, segment_max_csr, segment_csr, From 64772d759e0da7bac78f43c898a942e366402f41 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 30 Jan 2020 09:57:20 +0100 Subject: [PATCH 05/12] segment coo done --- csrc/cuda/segment_coo_cuda.cu | 360 +++++++++++++++++++++++++++++++++- csrc/cuda/segment_csr_cuda.cu | 10 +- setup.py | 40 ++-- test/test_gather.py | 2 - test/test_segment.py | 1 - 
torch_scatter/__init__.py | 12 +- torch_scatter/segment_coo.py | 4 + torch_scatter/segment_csr.py | 4 + 8 files changed, 394 insertions(+), 39 deletions(-) diff --git a/csrc/cuda/segment_coo_cuda.cu b/csrc/cuda/segment_coo_cuda.cu index a72b363b..2ceed10b 100644 --- a/csrc/cuda/segment_coo_cuda.cu +++ b/csrc/cuda/segment_coo_cuda.cu @@ -1,13 +1,369 @@ #include "segment_coo_cuda.h" +#include +#include +#include + +#include "reducer.cuh" +#include "utils.cuh" + +#define THREADS 256 +#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS +#define FULL_MASK 0xffffffff + +template +__global__ void +segment_coo_kernel(const scalar_t *src_data, + const at::cuda::detail::TensorInfo index_info, + scalar_t *out_data, size_t E, size_t N) { + + // Each thread processes exactly one entry. Within a warp, we perform a + // parallel reduction across equal indices, and write the intermediate + // result via atomics. + + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + int lane_idx = row_idx & (32 - 1); + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E) { + int offset = at::cuda::detail::IndexToOffset::get( + row_idx, index_info); + int64_t idx = index_info.data[offset], next_idx; + int out_idx = (row_idx / D) * N + idx; + + scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp; + +#pragma unroll + for (int i = 1; i < 32; i *= 2) { + // Parallel reduction inside a single warp. + tmp = __shfl_up_sync(FULL_MASK, val, i); + next_idx = __shfl_up_sync(FULL_MASK, idx, i); + if (lane_idx >= i && row_idx / D == (row_idx - i) / D) { + assert(idx >= next_idx); + if (idx == next_idx) + Reducer::update(&val, tmp); + } + } + + next_idx = __shfl_down_sync(FULL_MASK, idx, 1); + if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D || + idx != next_idx) + Reducer::atomic_write(out_data + out_idx, val); + } +} + +template +__global__ void segment_coo_arg_kernel( + const scalar_t *src_data, + const at::cuda::detail::TensorInfo index_info, + scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) { + + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E) { + int offset = at::cuda::detail::IndexToOffset::get( + row_idx, index_info); + int64_t idx = index_info.data[offset]; + int out_idx = (row_idx / D) * N + idx; + + scalar_t val = __ldg(out_data + out_idx); + if (src_data[row_idx] == val) + arg_out_data[out_idx] = row_idx % D; + } +} + +template +__global__ void segment_coo_broadcast_kernel( + const scalar_t *src_data, + const at::cuda::detail::TensorInfo index_info, + scalar_t *out_data, size_t E, size_t K, size_t N) { + + // Each thread processes a single column and `TB` index entries. Coalesced + // read and write is performed in column-major order. The intermediate + // results are written via atomics. 
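+  // Runs of equal indices are reduced locally in registers first; the partial
+  // result is only flushed to global memory (atomically) when the index
+  // changes or the thread's TB-entry window ends.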
+ + int D = index_info.sizes[index_info.dims - 1]; + int E_1 = E / D; + int E_2 = D + TB - (D % TB); + + int row_idx = blockIdx.x * blockDim.y + threadIdx.y; + int col_idx = blockIdx.y * blockDim.x + threadIdx.x; + + int dim_start = (row_idx * TB) / E_2; + int row_start = (row_idx * TB) % E_2; + + if (dim_start < E_1 && col_idx < K) { + + int offset = at::cuda::detail::IndexToOffset::get( + dim_start * D + row_start, index_info); + int idx1 = __ldg(index_info.data + offset), idx2; + + scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx]; + +#pragma unroll + for (int i = 1; i < TB; i++) { + if (row_start + i >= D) + break; + + idx2 = __ldg(index_info.data + offset + + i * index_info.strides[index_info.dims - 1]); + assert(idx1 <= idx2); + if (idx1 == idx2) { + Reducer::update( + &val, src_data[K * (dim_start * D + row_start + i) + col_idx]); + } else { + Reducer::atomic_write( + out_data + (dim_start * N + idx1) * K + col_idx, val); + val = src_data[K * (dim_start * D + row_start + i) + col_idx]; + } + + idx1 = idx2; + } + + Reducer::atomic_write( + out_data + (dim_start * N + idx1) * K + col_idx, val); + } +} + +template +__global__ void segment_coo_arg_broadcast_kernel( + const scalar_t *src_data, + const at::cuda::detail::TensorInfo index_info, + scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int col_idx = thread_idx % K; + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E && col_idx < K) { + int offset = at::cuda::detail::IndexToOffset::get( + row_idx, index_info); + int idx = __ldg(index_info.data + offset); + int out_idx = ((row_idx / D) * N + idx) * K + col_idx; + + scalar_t val = __ldg(out_data + out_idx); + if (src_data[thread_idx] == val) + arg_out_data[out_idx] = row_idx % D; + } +} + std::tuple> segment_coo_cuda(torch::Tensor src, torch::Tensor index, torch::optional optional_out, torch::optional dim_size, std::string reduce) { - return std::make_tuple(src, optional_out); + CHECK_CUDA(src); + CHECK_CUDA(index); + if (optional_out.has_value()) + CHECK_CUDA(optional_out.value()); + cudaSetDevice(src.get_device()); + + CHECK_INPUT(src.dim() >= index.dim()); + + auto sizes = index.sizes().vec(); + for (int i = 0; i < index.dim(); i++) { + sizes[i] = src.size(i); + } + index = index.expand(sizes); + + auto dim = index.dim() - 1; + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (int i = 0; i < out.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + } else { + sizes = src.sizes().vec(); + if (dim_size.has_value()) + sizes[dim] = dim_size.value(); + else { + auto d_size = index.max().data_ptr(); + auto h_size = (int64_t *)malloc(sizeof(int64_t)); + cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost); + sizes[dim] = 1 + *h_size; + } + out = torch::zeros(sizes, src.options()); + } + + torch::optional arg_out = torch::nullopt; + int64_t *arg_out_data = nullptr; + if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { + arg_out = torch::full_like(out, src.size(dim), index.options()); + arg_out_data = arg_out.value().data_ptr(); + } + + auto E = index.numel(); + auto E_2 = index.size(dim); + auto E_1 = index.numel() / E_2; + auto K = src.numel() / E; + auto N = out.size(dim); + auto avg_len = (float)E_2 / (float)N; + + auto index_info = at::cuda::detail::getTensorInfo(index); + auto stream = 
at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { + if (!optional_out.has_value()) + out.fill_(Reducer::init()); + + if (K == 1) + segment_coo_kernel + <<>>(src_data, index_info, + out_data, E, N); + else if (avg_len <= 8) + segment_coo_broadcast_kernel + <<>>(src_data, index_info, out_data, E, K, + N); + else if (avg_len <= 16) + segment_coo_broadcast_kernel + <<>>(src_data, index_info, out_data, E, K, + N); + else if (avg_len <= 32) + segment_coo_broadcast_kernel + <<>>(src_data, index_info, out_data, E, K, + N); + else + segment_coo_broadcast_kernel + <<>>(src_data, index_info, out_data, E, K, + N); + + if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) + out.masked_fill_(out == Reducer::init(), (scalar_t)0); + + if (REDUCE == MIN || REDUCE == MAX) { + if (K == 1) + segment_coo_arg_kernel + <<>>( + src_data, index_info, out_data, arg_out_data, E, N); + else + segment_coo_arg_broadcast_kernel + <<>>( + src_data, index_info, out_data, arg_out_data, E, K, N); + } + + if (REDUCE == MEAN) { + auto sizes = index.sizes().vec(); + sizes[dim] = out.size(dim); + auto count = torch::zeros(sizes, out.options()); + auto count_data = count.data_ptr(); + segment_coo_kernel + <<>>(nullptr, index_info, + count_data, E, N); + arg_out = count; + for (int i = dim + 1; i < out.dim(); i++) + count = count.unsqueeze(-1); + out.div_(count.clamp_(1)); + } + }); + }); + + return std::make_tuple(out, arg_out); +} + +template +__global__ void +gather_coo_kernel(const scalar_t *src_data, + const at::cuda::detail::TensorInfo index_info, + scalar_t *out_data, size_t E, size_t N) { + + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (row_idx < E) { + int offset = at::cuda::detail::IndexToOffset::get( + row_idx, index_info); + int row = index_info.data[offset]; + + offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N; + scalar_t val = __ldg(src_data + offset + row); + + out_data[row_idx] = val; + } +} + +template +__global__ void gather_coo_broadcast_kernel( + const scalar_t *src_data, + const at::cuda::detail::TensorInfo index_info, + scalar_t *out_data, size_t E, size_t K, size_t N) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int col_idx = thread_idx % K; + + if (thread_idx < E * K) { + int offset = at::cuda::detail::IndexToOffset::get( + row_idx, index_info); + int row = index_info.data[offset]; + + offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K; + scalar_t val = __ldg(src_data + offset + K * row + col_idx); + + out_data[thread_idx] = val; + } } torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, torch::optional optional_out) { - return src; + CHECK_CUDA(src); + CHECK_CUDA(index); + if (optional_out.has_value()) + CHECK_CUDA(optional_out.value()); + cudaSetDevice(src.get_device()); + + CHECK_INPUT(src.dim() >= index.dim()); + + auto sizes = index.sizes().vec(); + for (auto i = 0; i < index.dim() - 1; i++) + sizes[i] = src.size(i); + index = index.expand(sizes); + + auto dim = index.dim() - 1; + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (auto i = 0; i < src.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + CHECK_INPUT(index.size(dim) == out.size(dim)); + } else { + auto sizes = src.sizes().vec(); + sizes[dim] = 
index.size(dim); + out = torch::empty(sizes, src.options()); + } + + auto E = index.numel(); + auto K = out.numel() / E; + auto N = src.size(dim); + + auto index_info = at::cuda::detail::getTensorInfo(index); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + if (K == 1) + gather_coo_kernel<<>>( + src_data, index_info, out_data, E, N); + else + gather_coo_broadcast_kernel + <<>>(src_data, index_info, + out_data, E, K, N); + }); + + return out; } diff --git a/csrc/cuda/segment_csr_cuda.cu b/csrc/cuda/segment_csr_cuda.cu index 1bf6cf5f..f22fd764 100644 --- a/csrc/cuda/segment_csr_cuda.cu +++ b/csrc/cuda/segment_csr_cuda.cu @@ -237,13 +237,11 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, if (i != dim) CHECK_INPUT(src.size(i) == out.size(i)); } else { - auto d_gather_size = indptr.flatten()[-1].data_ptr(); - auto h_gather_size = (int64_t *)malloc(sizeof(int64_t)); - cudaMemcpy(h_gather_size, d_gather_size, sizeof(int64_t), - cudaMemcpyDeviceToHost); - + auto d_size = indptr.flatten()[-1].data_ptr(); + auto h_size = (int64_t *)malloc(sizeof(int64_t)); + cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost); auto sizes = src.sizes().vec(); - sizes[dim] = *h_gather_size; + sizes[dim] = *h_size; out = torch::empty(sizes, src.options()); } diff --git a/setup.py b/setup.py index 3267df18..d803ad16 100644 --- a/setup.py +++ b/setup.py @@ -8,31 +8,22 @@ from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME +WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None +WITH_CUDA = WITH_CUDA or os.getenv('FORCE_CUDA', '0') == '1' -def get_extensions(): - this_dir = osp.dirname(osp.abspath(__file__)) - extensions_dir = osp.join(this_dir, 'csrc') - - main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) - cpu_files = glob.glob(osp.join(extensions_dir, 'cpu', '*.cpp')) - cuda_files = glob.glob(osp.join(extensions_dir, 'cuda', '*.cu')) +def get_extensions(): Extension = CppExtension - sources = main_files + cpu_files - define_macros = [] extra_compile_args = {'cxx': [], 'nvcc': []} + # Windows users: Edit both of these to contain your VS include path, i.e.: # extra_compile_args['cxx'] += ['-I{VISUAL_STUDIO_DIR}\\include'] # extra_compile_args['nvcc'] += ['-I{VISUAL_STUDIO_DIR}\\include'] - if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv( - 'FORCE_CUDA', '0') == '1': - + if WITH_CUDA: Extension = CUDAExtension - sources += cuda_files define_macros += [('WITH_CUDA', None)] - nvcc_flags = os.getenv('NVCC_FLAGS', '') nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr'] @@ -42,15 +33,26 @@ def get_extensions(): if sys.platform == 'win32': extra_compile_args['cxx'] += ['/MP'] - return [ - Extension( - 'torch_scatter._C', + extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc') + main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) + extensions = [] + for main in main_files: + name = main.split(os.sep)[-1][:-4] + + sources = [main, osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')] + if WITH_CUDA: + sources += [osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')] + + extension = Extension( + f'torch_scatter._{name}', sources, include_dirs=[extensions_dir], define_macros=define_macros, extra_compile_args=extra_compile_args, ) - ] + extensions += 
[extension] + + return extensions install_requires = [] @@ -59,7 +61,7 @@ def get_extensions(): setup( name='torch_scatter', - version='1.5.0', + version='2.0.0', author='Matthias Fey', author_email='matthias.fey@tu-dortmund.de', url='https://github.com/rusty1s/pytorch_scatter', diff --git a/test/test_gather.py b/test/test_gather.py index 3a5ade55..5875af3f 100644 --- a/test/test_gather.py +++ b/test/test_gather.py @@ -7,8 +7,6 @@ from .utils import tensor, dtypes, devices -devices = ['cpu'] - tests = [ { 'src': [1, 2, 3, 4], diff --git a/test/test_segment.py b/test/test_segment.py index 22cc0959..1576f944 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -8,7 +8,6 @@ from .utils import tensor, dtypes, devices reductions = ['sum', 'mean', 'min', 'max'] -devices = ['cpu'] tests = [ { diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index 9e959cdc..71b2fc96 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -1,17 +1,11 @@ -import os.path as osp - -import torch - -torch.ops.load_library(osp.join(osp.dirname(osp.abspath(__file__)), '_C.so')) - from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr, segment_min_csr, segment_max_csr, segment_csr, - gather_csr) # noqa + gather_csr) from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo, segment_min_coo, segment_max_coo, segment_coo, - gather_coo) # noqa + gather_coo) -__version__ = '1.5.0' +__version__ = '2.0.0' __all__ = [ 'segment_sum_csr', diff --git a/torch_scatter/segment_coo.py b/torch_scatter/segment_coo.py index bbb7d00c..8e33d825 100644 --- a/torch_scatter/segment_coo.py +++ b/torch_scatter/segment_coo.py @@ -1,7 +1,11 @@ +import os.path as osp from typing import Optional, Tuple import torch +torch.ops.load_library( + osp.join(osp.dirname(osp.abspath(__file__)), '_segment_coo.so')) + @torch.jit.script def segment_sum_coo(src: torch.Tensor, index: torch.Tensor, diff --git a/torch_scatter/segment_csr.py b/torch_scatter/segment_csr.py index 5ae84e79..ac00183e 100644 --- a/torch_scatter/segment_csr.py +++ b/torch_scatter/segment_csr.py @@ -1,7 +1,11 @@ +import os.path as osp from typing import Optional, Tuple import torch +torch.ops.load_library( + osp.join(osp.dirname(osp.abspath(__file__)), '_segment_csr.so')) + @torch.jit.script def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor, From 5e2d0f1fc579b0249d3d9646b093762ecbebdd30 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 30 Jan 2020 12:07:11 +0100 Subject: [PATCH 06/12] scatter cpu: --- README.md | 4 +- csrc/cpu/scatter_cpu.cpp | 82 ++++++++++++++ csrc/cpu/scatter_cpu.h | 8 ++ csrc/cpu/segment_coo_cpu.cpp | 4 +- csrc/cpu/segment_csr_cpu.cpp | 4 +- csrc/cuda/reducer.cuh | 4 +- csrc/cuda/scatter_cuda.cu | 8 ++ csrc/cuda/scatter_cuda.h | 8 ++ csrc/scatter.cpp | 213 +++++++++++++++++++++++++++++++++++ test/test_gather.py | 4 +- test/test_scatter.py | 162 ++++++++++++++++++++++++++ test/test_segment.py | 18 +-- torch_scatter/__init__.py | 10 +- torch_scatter/scatter.py | 60 ++++++++++ 14 files changed, 570 insertions(+), 19 deletions(-) create mode 100644 csrc/cpu/scatter_cpu.cpp create mode 100644 csrc/cpu/scatter_cpu.h create mode 100644 csrc/cuda/scatter_cuda.cu create mode 100644 csrc/cuda/scatter_cuda.h create mode 100644 csrc/scatter.cpp create mode 100644 test/test_scatter.py create mode 100644 torch_scatter/scatter.py diff --git a/README.md b/README.md index 47b8f151..49b55f2c 100644 --- a/README.md +++ b/README.md @@ -45,11 +45,11 @@ All included operations are broadcastable, 
work on varying data types, and are i ## Installation -Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*: +Ensure that at least PyTorch 1.3.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*: ``` $ python -c "import torch; print(torch.__version__)" ->>> 1.1.0 +>>> 1.3.0 $ echo $PATH >>> /usr/local/cuda/bin:... diff --git a/csrc/cpu/scatter_cpu.cpp b/csrc/cpu/scatter_cpu.cpp new file mode 100644 index 00000000..f145143d --- /dev/null +++ b/csrc/cpu/scatter_cpu.cpp @@ -0,0 +1,82 @@ +#include "scatter_cpu.h" + +#include "index_info.h" +#include "reducer.h" +#include "utils.h" + +std::tuple> +scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size, std::string reduce) { + CHECK_CPU(src); + CHECK_CPU(index); + if (optional_out.has_value()) + CHECK_CPU(optional_out.value()); + + CHECK_INPUT(src.dim() == index.dim()); + for (auto i = 0; i < index.dim() - 1; i++) + CHECK_INPUT(src.size(i) >= index.size(i)); + + if (dim < 0) + dim = src.dim() + dim; + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (auto i = 0; i < out.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + } else { + auto sizes = src.sizes().vec(); + if (dim_size.has_value()) + sizes[dim] = dim_size.value(); + else + sizes[dim] = 1 + *index.max().data_ptr(); + out = torch::empty(sizes, src.options()); + } + + torch::optional arg_out = torch::nullopt; + int64_t *arg_out_data = nullptr; + if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { + arg_out = torch::full_like(out, src.size(dim), index.options()); + arg_out_data = arg_out.value().data_ptr(); + } + + auto B = 1; + for (auto i = 0; i < dim; i++) + B *= src.size(i); + auto E = src.size(dim); + auto K = src.numel() / (B * E); + auto N = out.size(dim); + + auto index_info = getTensorInfo(index); + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + int64_t i, idx; + AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { + if (!optional_out.has_value()) + out.fill_(Reducer::init()); + + for (auto b = 0; b < B; b++) { + for (auto e = 0; e < E; e++) { + for (auto k = 0; k < K; k++) { + i = b * E * K + e * K + k; + idx = index_info.data[IndexToOffset::get(i, index_info)]; + Reducer::update( + out_data + b * N * K + idx * K + k, src_data[i], + arg_out_data + b * N * K + idx * K + k, e); + } + } + } + + if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) + out.masked_fill_(out == Reducer::init(), (scalar_t)0); + }); + }); + + return std::make_tuple(out, arg_out); +} diff --git a/csrc/cpu/scatter_cpu.h b/csrc/cpu/scatter_cpu.h new file mode 100644 index 00000000..25122e70 --- /dev/null +++ b/csrc/cpu/scatter_cpu.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +std::tuple> +scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size, std::string reduce); diff --git a/csrc/cpu/segment_coo_cpu.cpp b/csrc/cpu/segment_coo_cpu.cpp index a04ff3bc..cb84c28d 100644 --- a/csrc/cpu/segment_coo_cpu.cpp +++ b/csrc/cpu/segment_coo_cpu.cpp @@ -16,7 +16,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, CHECK_INPUT(src.dim() >= index.dim()); auto sizes = index.sizes().vec(); - for (int i = 0; i < 
index.dim(); i++) + for (auto i = 0; i < index.dim(); i++) sizes[i] = src.size(i); index = index.expand(sizes); @@ -27,7 +27,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, torch::Tensor out; if (optional_out.has_value()) { out = optional_out.value().contiguous(); - for (int i = 0; i < out.dim(); i++) + for (auto i = 0; i < out.dim(); i++) if (i != dim) CHECK_INPUT(src.size(i) == out.size(i)); } else { diff --git a/csrc/cpu/segment_csr_cpu.cpp b/csrc/cpu/segment_csr_cpu.cpp index efc6955c..726df0ec 100644 --- a/csrc/cpu/segment_csr_cpu.cpp +++ b/csrc/cpu/segment_csr_cpu.cpp @@ -27,7 +27,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, torch::Tensor out; if (optional_out.has_value()) { out = optional_out.value().contiguous(); - for (int i = 0; i < out.dim(); i++) + for (auto i = 0; i < out.dim(); i++) if (i != dim) CHECK_INPUT(src.size(i) == out.size(i)); CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1); @@ -126,7 +126,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, std::vector vals(K); int64_t row_start, row_end; - for (int n = 0; n < N; n++) { + for (auto n = 0; n < N; n++) { auto offset = IndexPtrToOffset::get(n, indptr_info); row_start = indptr_info.data[offset]; row_end = indptr_info.data[offset + stride]; diff --git a/csrc/cuda/reducer.cuh b/csrc/cuda/reducer.cuh index 1e126e6a..8b318958 100644 --- a/csrc/cuda/reducer.cuh +++ b/csrc/cuda/reducer.cuh @@ -106,9 +106,9 @@ template struct Reducer { atomMul(address, val); else if (REDUCE == DIV) atomDiv(address, val); - else if (REDUCE == MIN && val < *address) + else if (REDUCE == MIN) atomMin(address, val); - else if (REDUCE == MAX && val > *address) + else if (REDUCE == MAX) atomMax(address, val); } }; diff --git a/csrc/cuda/scatter_cuda.cu b/csrc/cuda/scatter_cuda.cu new file mode 100644 index 00000000..76b9f8c1 --- /dev/null +++ b/csrc/cuda/scatter_cuda.cu @@ -0,0 +1,8 @@ +#include "scatter_cuda.h" + +std::tuple> +scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size, std::string reduce) { + return std::make_tuple(src, optional_out); +} diff --git a/csrc/cuda/scatter_cuda.h b/csrc/cuda/scatter_cuda.h new file mode 100644 index 00000000..95c80642 --- /dev/null +++ b/csrc/cuda/scatter_cuda.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +std::tuple> +scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size, std::string reduce); diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp new file mode 100644 index 00000000..016bf8fc --- /dev/null +++ b/csrc/scatter.cpp @@ -0,0 +1,213 @@ +#include + +#include "cpu/scatter_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/scatter_cuda.h" +#endif + +torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) { + if (dim < 0) + dim = other.dim() + dim; + if (src.dim() == 1) + for (auto i = 0; i < dim; i++) + src = src.unsqueeze(0); + for (auto i = src.dim(); i < other.dim(); i++) + src = src.unsqueeze(-1); + src = src.expand(other.sizes().vec()); + return src; +} + +std::tuple> +scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size, std::string reduce) { + if (src.device().is_cuda()) { +#ifdef WITH_CUDA + return scatter_cuda(src, index, dim, optional_out, dim_size, reduce); +#else + AT_ERROR("Not compiled with CUDA support"); +#endif + } else { + return scatter_cpu(src, index, dim, optional_out, dim_size, reduce); + } +} +using 
torch::autograd::AutogradContext; +using torch::autograd::Variable; +using torch::autograd::variable_list; + +class ScatterSum : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { + ctx->saved_data["dim"] = dim; + ctx->saved_data["src_shape"] = src.sizes(); + index = broadcast(index, src, dim); + auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum"); + auto out = std::get<0>(result); + ctx->save_for_backward({index}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto index = saved[0]; + auto dim = ctx->saved_data["dim"].toInt(); + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + auto grad_in = torch::gather(grad_out, dim, index, false); + return {grad_in, Variable(), Variable(), Variable(), Variable()}; + } +}; + +class ScatterMean : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { + ctx->saved_data["dim"] = dim; + ctx->saved_data["src_shape"] = src.sizes(); + + auto old_index = index; + + index = broadcast(index, src, dim); + auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum"); + auto out = std::get<0>(result); + + auto ones = torch::ones(old_index.sizes(), src.options()); + result = scatter_fw(ones, old_index, + old_index.dim() <= dim ? old_index.dim() - 1 : dim, + torch::nullopt, out.size(dim), "sum"); + auto count = std::get<0>(result); + count.clamp_(1); + count = broadcast(count, out, dim); + out.div_(count); + + ctx->save_for_backward({index, count}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto index = saved[0]; + auto count = saved[1]; + auto dim = ctx->saved_data["dim"].toInt(); + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + count = torch::gather(count, dim, index, false); + auto grad_in = torch::gather(grad_out, dim, index, false); + grad_in.div_(count); + return {grad_in, Variable(), Variable(), Variable(), Variable()}; + } +}; + +class ScatterMin : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { + ctx->saved_data["dim"] = dim; + ctx->saved_data["src_shape"] = src.sizes(); + + index = broadcast(index, src, dim); + auto result = scatter_fw(src, index, dim, optional_out, dim_size, "min"); + auto out = std::get<0>(result); + auto arg_out = std::get<1>(result).value(); + ctx->save_for_backward({index, arg_out}); + ctx->mark_non_differentiable({arg_out}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out, arg_out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto index = saved[0]; + auto arg_out = saved[1]; + auto dim = ctx->saved_data["dim"].toInt(); + auto src_shape = 
ctx->saved_data["src_shape"].toIntVector(); + src_shape[dim] += 1; + auto grad_in = torch::zeros(src_shape, grad_out.options()); + grad_in.scatter_(dim, arg_out, grad_out); + grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1); + return {grad_in, Variable(), Variable(), Variable(), Variable()}; + } +}; + +class ScatterMax : public torch::autograd::Function { +public: + static variable_list forward(AutogradContext *ctx, Variable src, + Variable index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { + ctx->saved_data["dim"] = dim; + ctx->saved_data["src_shape"] = src.sizes(); + + index = broadcast(index, src, dim); + auto result = scatter_fw(src, index, dim, optional_out, dim_size, "max"); + auto out = std::get<0>(result); + auto arg_out = std::get<1>(result).value(); + ctx->save_for_backward({index, arg_out}); + ctx->mark_non_differentiable({arg_out}); + if (optional_out.has_value()) + ctx->mark_dirty({optional_out.value()}); + return {out, arg_out}; + } + + static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { + auto grad_out = grad_outs[0]; + auto saved = ctx->get_saved_variables(); + auto index = saved[0]; + auto arg_out = saved[1]; + auto dim = ctx->saved_data["dim"].toInt(); + auto src_shape = ctx->saved_data["src_shape"].toIntVector(); + src_shape[dim] += 1; + auto grad_in = torch::zeros(src_shape, grad_out.options()); + grad_in.scatter_(dim, arg_out, grad_out); + grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1); + return {grad_in, Variable(), Variable(), Variable(), Variable()}; + } +}; + +torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { + return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0]; +} + +torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { + return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0]; +} + +std::tuple +scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { + auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size); + return std::make_tuple(result[0], result[1]); +} + +std::tuple +scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim, + torch::optional optional_out, + torch::optional dim_size) { + auto result = ScatterMax::apply(src, index, dim, optional_out, dim_size); + return std::make_tuple(result[0], result[1]); +} + +static auto registry = torch::RegisterOperators() + .op("torch_scatter::scatter_sum", &scatter_sum) + .op("torch_scatter::scatter_mean", &scatter_mean) + .op("torch_scatter::scatter_min", &scatter_min) + .op("torch_scatter::scatter_max", &scatter_max); diff --git a/test/test_gather.py b/test/test_gather.py index 5875af3f..8d0d100f 100644 --- a/test/test_gather.py +++ b/test/test_gather.py @@ -73,7 +73,7 @@ def test_backward(test, device): @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) -def test_gather_out(test, dtype, device): +def test_out(test, dtype, device): src = tensor(test['src'], dtype, device) index = tensor(test['index'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device) @@ -93,7 +93,7 @@ def test_gather_out(test, dtype, device): @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) -def test_non_contiguous_segment(test, dtype, device): +def test_non_contiguous(test, dtype, device): src = 
tensor(test['src'], dtype, device) index = tensor(test['index'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device) diff --git a/test/test_scatter.py b/test/test_scatter.py new file mode 100644 index 00000000..aeaa3a66 --- /dev/null +++ b/test/test_scatter.py @@ -0,0 +1,162 @@ +from itertools import product + +import pytest +import torch +from torch.autograd import gradcheck +import torch_scatter + +from .utils import tensor, dtypes, devices + +devices = ['cpu'] + +reductions = ['sum', 'add', 'mean', 'min', 'max'] + +tests = [ + { + 'src': [1, 3, 2, 4, 5, 6], + 'index': [0, 1, 0, 1, 1, 3], + 'dim': 0, + 'sum': [3, 12, 0, 6], + 'add': [3, 12, 0, 6], + 'mean': [1.5, 4, 0, 6], + 'min': [1, 3, 0, 6], + 'arg_min': [0, 1, 6, 5], + 'max': [2, 5, 0, 6], + 'arg_max': [2, 4, 6, 5], + }, + { + 'src': [[1, 2], [5, 6], [3, 4], [7, 8], [9, 10], [11, 12]], + 'index': [0, 1, 0, 1, 1, 3], + 'dim': 0, + 'sum': [[4, 6], [21, 24], [0, 0], [11, 12]], + 'add': [[4, 6], [21, 24], [0, 0], [11, 12]], + 'mean': [[2, 3], [7, 8], [0, 0], [11, 12]], + 'min': [[1, 2], [5, 6], [0, 0], [11, 12]], + 'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]], + 'max': [[3, 4], [9, 10], [0, 0], [11, 12]], + 'arg_max': [[2, 2], [4, 4], [6, 6], [5, 5]], + }, + { + 'src': [[1, 5, 3, 7, 9, 11], [2, 4, 8, 6, 10, 12]], + 'index': [[0, 1, 0, 1, 1, 3], [0, 0, 1, 0, 1, 2]], + 'dim': 1, + 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]], + 'add': [[4, 21, 0, 11], [12, 18, 12, 0]], + 'mean': [[2, 7, 0, 11], [4, 9, 12, 0]], + 'min': [[1, 5, 0, 11], [2, 8, 12, 0]], + 'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]], + 'max': [[3, 9, 0, 11], [6, 10, 12, 0]], + 'arg_max': [[2, 4, 6, 5], [3, 4, 5, 6]], + }, + { + 'src': [[[1, 2], [5, 6], [3, 4]], [[10, 11], [7, 9], [12, 13]]], + 'index': [[0, 1, 0], [2, 0, 2]], + 'dim': 1, + 'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], + 'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], + 'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]], + 'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]], + 'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]], + 'max': [[[3, 4], [5, 6], [0, 0]], [[7, 9], [0, 0], [12, 13]]], + 'arg_max': [[[2, 2], [1, 1], [3, 3]], [[1, 1], [3, 3], [2, 2]]], + }, + { + 'src': [[1, 3], [2, 4]], + 'index': [[0, 0], [0, 0]], + 'dim': 1, + 'sum': [[4], [6]], + 'add': [[4], [6]], + 'mean': [[2], [3]], + 'min': [[1], [2]], + 'arg_min': [[0], [0]], + 'max': [[3], [4]], + 'arg_max': [[1], [1]], + }, + { + 'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]], + 'index': [[0, 0], [0, 0]], + 'dim': 1, + 'sum': [[[4, 4]], [[6, 6]]], + 'add': [[[4, 4]], [[6, 6]]], + 'mean': [[[2, 2]], [[3, 3]]], + 'min': [[[1, 1]], [[2, 2]]], + 'arg_min': [[[0, 0]], [[0, 0]]], + 'max': [[[3, 3]], [[4, 4]]], + 'arg_max': [[[1, 1]], [[1, 1]]], + }, +] + + +@pytest.mark.parametrize('test,reduce,dtype,device', + product(tests, reductions, dtypes, devices)) +def test_forward(test, reduce, dtype, device): + src = tensor(test['src'], dtype, device) + index = tensor(test['index'], torch.long, device) + dim = test['dim'] + expected = tensor(test[reduce], dtype, device) + + out = getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim) + if isinstance(out, tuple): + out, arg_out = out + arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) + assert torch.all(arg_out == arg_expected) + assert torch.all(out == expected) + + +@pytest.mark.parametrize('test,reduce,device', + product(tests, reductions, devices)) +def test_backward(test, reduce, device): + src = 
tensor(test['src'], torch.double, device) + src.requires_grad_() + index = tensor(test['index'], torch.long, device) + dim = test['dim'] + + assert gradcheck(torch_scatter.scatter, + (src, index, dim, None, None, reduce)) + + +@pytest.mark.parametrize('test,reduce,dtype,device', + product(tests, reductions, dtypes, devices)) +def test_out(test, reduce, dtype, device): + src = tensor(test['src'], dtype, device) + index = tensor(test['index'], torch.long, device) + dim = test['dim'] + expected = tensor(test[reduce], dtype, device) + + out = torch.full_like(expected, -2) + + getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim, out) + + if reduce == 'sum' or reduce == 'add': + expected = expected - 2 + elif reduce == 'mean': + expected = out # We can not really test this here. + elif reduce == 'min': + expected = expected.fill_(-2) + elif reduce == 'max': + expected[expected == 0] = -2 + else: + raise ValueError + + assert torch.all(out == expected) + + +@pytest.mark.parametrize('test,reduce,dtype,device', + product(tests, reductions, dtypes, devices)) +def test_non_contiguous(test, reduce, dtype, device): + src = tensor(test['src'], dtype, device) + index = tensor(test['index'], torch.long, device) + dim = test['dim'] + expected = tensor(test[reduce], dtype, device) + + if src.dim() > 1: + src = src.transpose(0, 1).contiguous().transpose(0, 1) + if index.dim() > 1: + index = index.transpose(0, 1).contiguous().transpose(0, 1) + + out = getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim) + if isinstance(out, tuple): + out, arg_out = out + arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) + assert torch.all(arg_out == arg_expected) + assert torch.all(out == expected) diff --git a/test/test_segment.py b/test/test_segment.py index 1576f944..e93cac3f 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -7,7 +7,7 @@ from .utils import tensor, dtypes, devices -reductions = ['sum', 'mean', 'min', 'max'] +reductions = ['sum', 'add', 'mean', 'min', 'max'] tests = [ { @@ -15,6 +15,7 @@ 'index': [0, 0, 1, 1, 1, 3], 'indptr': [0, 2, 5, 5, 6], 'sum': [3, 12, 0, 6], + 'add': [3, 12, 0, 6], 'mean': [1.5, 4, 0, 6], 'min': [1, 3, 0, 6], 'arg_min': [0, 2, 6, 5], @@ -26,6 +27,7 @@ 'index': [0, 0, 1, 1, 1, 3], 'indptr': [0, 2, 5, 5, 6], 'sum': [[4, 6], [21, 24], [0, 0], [11, 12]], + 'add': [[4, 6], [21, 24], [0, 0], [11, 12]], 'mean': [[2, 3], [7, 8], [0, 0], [11, 12]], 'min': [[1, 2], [5, 6], [0, 0], [11, 12]], 'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]], @@ -37,6 +39,7 @@ 'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]], 'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]], 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]], + 'add': [[4, 21, 0, 11], [12, 18, 12, 0]], 'mean': [[2, 7, 0, 11], [4, 9, 12, 0]], 'min': [[1, 5, 0, 11], [2, 8, 12, 0]], 'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]], @@ -48,6 +51,7 @@ 'index': [[0, 0, 1], [0, 2, 2]], 'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]], 'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], + 'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], 'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]], 'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]], 'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]], @@ -59,6 +63,7 @@ 'index': [[0, 0], [0, 0]], 'indptr': [[0, 2], [0, 2]], 'sum': [[4], [6]], + 'add': [[4], [6]], 'mean': [[2], [3]], 'min': [[1], [2]], 'arg_min': [[0], [0]], @@ -70,6 +75,7 @@ 'index': [[0, 0], [0, 0]], 'indptr': [[0, 2], [0, 2]], 'sum': [[[4, 4]], [[6, 6]]], + 'add': [[[4, 4]], [[6, 
6]]], 'mean': [[[2, 2]], [[3, 3]]], 'min': [[[1, 1]], [[2, 2]]], 'arg_min': [[[0, 0]], [[0, 0]]], @@ -117,15 +123,13 @@ def test_backward(test, reduce, device): @pytest.mark.parametrize('test,reduce,dtype,device', product(tests, reductions, dtypes, devices)) -def test_segment_out(test, reduce, dtype, device): +def test_out(test, reduce, dtype, device): src = tensor(test['src'], dtype, device) index = tensor(test['index'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device) expected = tensor(test[reduce], dtype, device) - size = list(src.size()) - size[indptr.dim() - 1] = indptr.size(-1) - 1 - out = src.new_full(size, -2) + out = torch.full_like(expected, -2) getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr, out) assert torch.all(out == expected) @@ -134,7 +138,7 @@ def test_segment_out(test, reduce, dtype, device): getattr(torch_scatter, f'segment_{reduce}_coo')(src, index, out) - if reduce == 'sum': + if reduce == 'sum' or reduce == 'add': expected = expected - 2 elif reduce == 'mean': expected = out # We can not really test this here. @@ -150,7 +154,7 @@ def test_segment_out(test, reduce, dtype, device): @pytest.mark.parametrize('test,reduce,dtype,device', product(tests, reductions, dtypes, devices)) -def test_non_contiguous_segment(test, reduce, dtype, device): +def test_non_contiguous(test, reduce, dtype, device): src = tensor(test['src'], dtype, device) index = tensor(test['index'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device) diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index 71b2fc96..29e53596 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -1,3 +1,5 @@ +from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min, + scatter_max, scatter) from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr, segment_min_csr, segment_max_csr, segment_csr, gather_csr) @@ -8,12 +10,17 @@ __version__ = '2.0.0' __all__ = [ + 'scatter_sum', + 'scatter_add', + 'scatter_mean', + 'scatter_min', + 'scatter_max', + 'scatter', 'segment_sum_csr', 'segment_add_csr', 'segment_mean_csr', 'segment_min_csr', 'segment_max_csr', - 'segment_max_csr', 'segment_csr', 'gather_csr', 'segment_sum_coo', @@ -21,7 +28,6 @@ 'segment_mean_coo', 'segment_min_coo', 'segment_max_coo', - 'segment_max_coo', 'segment_coo', 'gather_coo', 'torch_scatter', diff --git a/torch_scatter/scatter.py b/torch_scatter/scatter.py new file mode 100644 index 00000000..4e6c4e4f --- /dev/null +++ b/torch_scatter/scatter.py @@ -0,0 +1,60 @@ +import os.path as osp +from typing import Optional, Tuple + +import torch + +torch.ops.load_library( + osp.join(osp.dirname(osp.abspath(__file__)), '_scatter.so')) + + +@torch.jit.script +def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return torch.ops.torch_scatter.scatter_sum(src, index, dim, out, dim_size) + + +@torch.jit.script +def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return torch.ops.torch_scatter.scatter_sum(src, index, dim, out, dim_size) + + +@torch.jit.script +def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return torch.ops.torch_scatter.scatter_mean(src, index, dim, out, dim_size) + + +@torch.jit.script +def 
scatter_min(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size) + + +@torch.jit.script +def scatter_max(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size) + + +@torch.jit.script +def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, + reduce: str = "sum") -> torch.Tensor: + if reduce == 'sum' or reduce == 'add': + return scatter_sum(src, index, dim, out, dim_size) + elif reduce == 'mean': + return scatter_mean(src, index, dim, out, dim_size) + elif reduce == 'min': + return scatter_min(src, index, dim, out, dim_size)[0] + elif reduce == 'max': + return scatter_max(src, index, dim, out, dim_size)[0] + else: + raise ValueError From 5be6d63a2c13e51e287e819fde895c4ae9c333cb Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 30 Jan 2020 12:38:43 +0100 Subject: [PATCH 07/12] scatter kernel done --- cpu/compat.h | 5 - cpu/dim_apply.h | 120 --------- cpu/index_info.h | 65 ----- cpu/reducer.h | 68 ------ cpu/scatter.cpp | 84 ------- cpu/segment_coo.cpp | 7 - cpu/segment_coo_impl.h | 182 -------------- cpu/segment_csr.cpp | 41 ---- cpu/segment_csr_impl.h | 146 ----------- cpu/utils.h | 6 - csrc/cuda/scatter_cuda.cu | 130 +++++++++- cuda/atomics.cuh | 230 ------------------ cuda/compat.cuh | 5 - cuda/gather.cpp | 31 --- cuda/gather_kernel.cu | 202 ---------------- cuda/index.cuh | 108 --------- cuda/indptr.cuh | 20 -- cuda/scatter.cpp | 55 ----- cuda/scatter_kernel.cu | 163 ------------- cuda/segment.cpp | 34 --- cuda/segment_kernel.cu | 494 -------------------------------------- test/test_jit.py | 31 --- test/test_scatter.py | 2 - 23 files changed, 129 insertions(+), 2100 deletions(-) delete mode 100644 cpu/compat.h delete mode 100644 cpu/dim_apply.h delete mode 100644 cpu/index_info.h delete mode 100644 cpu/reducer.h delete mode 100644 cpu/scatter.cpp delete mode 100644 cpu/segment_coo.cpp delete mode 100644 cpu/segment_coo_impl.h delete mode 100644 cpu/segment_csr.cpp delete mode 100644 cpu/segment_csr_impl.h delete mode 100644 cpu/utils.h delete mode 100644 cuda/atomics.cuh delete mode 100644 cuda/compat.cuh delete mode 100644 cuda/gather.cpp delete mode 100644 cuda/gather_kernel.cu delete mode 100644 cuda/index.cuh delete mode 100644 cuda/indptr.cuh delete mode 100644 cuda/scatter.cpp delete mode 100644 cuda/scatter_kernel.cu delete mode 100644 cuda/segment.cpp delete mode 100644 cuda/segment_kernel.cu delete mode 100644 test/test_jit.py diff --git a/cpu/compat.h b/cpu/compat.h deleted file mode 100644 index 1be09913..00000000 --- a/cpu/compat.h +++ /dev/null @@ -1,5 +0,0 @@ -#ifdef VERSION_GE_1_3 -#define DATA_PTR data_ptr -#else -#define DATA_PTR data -#endif diff --git a/cpu/dim_apply.h b/cpu/dim_apply.h deleted file mode 100644 index e7a720d8..00000000 --- a/cpu/dim_apply.h +++ /dev/null @@ -1,120 +0,0 @@ -#pragma once - -#include - -#include "compat.h" - -#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \ - [&] { \ - TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR(); \ - auto TENSOR1##_size = TENSOR1.size(DIM); \ - auto TENSOR1##_stride = TENSOR1.stride(DIM); \ - \ - TYPE2 *TENSOR2##_data = 
TENSOR2.DATA_PTR(); \ - auto TENSOR2##_size = TENSOR2.size(DIM); \ - auto TENSOR2##_stride = TENSOR2.stride(DIM); \ - \ - TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR(); \ - auto TENSOR3##_size = TENSOR3.size(DIM); \ - auto TENSOR3##_stride = TENSOR3.stride(DIM); \ - \ - auto dims = TENSOR1.dim(); \ - auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \ - auto counter = zeros.DATA_PTR(); \ - bool has_finished = false; \ - \ - while (!has_finished) { \ - CODE; \ - if (dims == 1) \ - break; \ - \ - for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \ - if (cur_dim == DIM) { \ - if (cur_dim == dims - 1) { \ - has_finished = true; \ - break; \ - } \ - continue; \ - } \ - \ - counter[cur_dim]++; \ - TENSOR1##_data += TENSOR1.stride(cur_dim); \ - TENSOR2##_data += TENSOR2.stride(cur_dim); \ - TENSOR3##_data += TENSOR3.stride(cur_dim); \ - \ - if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \ - if (cur_dim == dims - 1) { \ - has_finished = true; \ - break; \ - } else { \ - TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \ - TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \ - TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \ - counter[cur_dim] = 0; \ - } \ - } else \ - break; \ - } \ - } \ - }() - -#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \ - TENSOR4, DIM, CODE) \ - [&] { \ - TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR(); \ - auto TENSOR1##_size = TENSOR1.size(DIM); \ - auto TENSOR1##_stride = TENSOR1.stride(DIM); \ - \ - TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR(); \ - auto TENSOR2##_size = TENSOR2.size(DIM); \ - auto TENSOR2##_stride = TENSOR2.stride(DIM); \ - \ - TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR(); \ - auto TENSOR3##_size = TENSOR3.size(DIM); \ - auto TENSOR3##_stride = TENSOR3.stride(DIM); \ - \ - TYPE4 *TENSOR4##_data = TENSOR4.DATA_PTR(); \ - auto TENSOR4##_size = TENSOR4.size(DIM); \ - auto TENSOR4##_stride = TENSOR4.stride(DIM); \ - \ - auto dims = TENSOR1.dim(); \ - auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \ - auto counter = zeros.DATA_PTR(); \ - bool has_finished = false; \ - \ - while (!has_finished) { \ - CODE; \ - if (dims == 1) \ - break; \ - \ - for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \ - if (cur_dim == DIM) { \ - if (cur_dim == dims - 1) { \ - has_finished = true; \ - break; \ - } \ - continue; \ - } \ - \ - counter[cur_dim]++; \ - TENSOR1##_data += TENSOR1.stride(cur_dim); \ - TENSOR2##_data += TENSOR2.stride(cur_dim); \ - TENSOR3##_data += TENSOR3.stride(cur_dim); \ - TENSOR4##_data += TENSOR4.stride(cur_dim); \ - \ - if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \ - if (cur_dim == dims - 1) { \ - has_finished = true; \ - break; \ - } else { \ - TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \ - TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \ - TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \ - TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim); \ - counter[cur_dim] = 0; \ - } \ - } else \ - break; \ - } \ - } \ - }() diff --git a/cpu/index_info.h b/cpu/index_info.h deleted file mode 100644 index 06362ce0..00000000 --- a/cpu/index_info.h +++ /dev/null @@ -1,65 +0,0 @@ -#pragma once - -#include - -#include "compat.h" - -#define MAX_TENSORINFO_DIMS 25 - -template struct TensorInfo { - TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS], - int st[MAX_TENSORINFO_DIMS]) { - data = p; - dims = dim; - AT_ASSERT(dims < MAX_TENSORINFO_DIMS); - - for (int i = 0; i < dim; ++i) { 
- sizes[i] = sz[i]; - strides[i] = st[i]; - } - } - - scalar_t *data; - int dims; - int sizes[MAX_TENSORINFO_DIMS]; - int strides[MAX_TENSORINFO_DIMS]; -}; - -template -TensorInfo getTensorInfo(const torch::Tensor &tensor) { - int sizes[MAX_TENSORINFO_DIMS]; - int strides[MAX_TENSORINFO_DIMS]; - - int dims = tensor.dim(); - for (int i = 0; i < dims; ++i) { - sizes[i] = tensor.size(i); - strides[i] = tensor.stride(i); - } - - return TensorInfo(tensor.DATA_PTR(), dims, sizes, - strides); -} - -template struct IndexToOffset { - static inline int get(int idx, const TensorInfo &info) { - int offset = 0; - for (int i = info.dims - 1; i >= 0; --i) { - offset += (idx % info.sizes[i]) * info.strides[i]; - idx /= info.sizes[i]; - } - return offset; - } -}; - -template struct IndexPtrToOffset { - static inline int get(int idx, const TensorInfo &info) { - int offset = idx % (info.sizes[info.dims - 1] - 1); - offset *= info.strides[info.dims - 1]; - idx /= info.sizes[info.dims - 1] - 1; - for (int i = info.dims - 2; i >= 0; --i) { - offset += (idx % info.sizes[i]) * info.strides[i]; - idx /= info.sizes[i]; - } - return offset; - } -}; diff --git a/cpu/reducer.h b/cpu/reducer.h deleted file mode 100644 index 91d730ee..00000000 --- a/cpu/reducer.h +++ /dev/null @@ -1,68 +0,0 @@ -#pragma once - -#include - -enum ReductionType { SUM, MEAN, MIN, MAX }; - -const std::map reduce2REDUCE = { - {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX}, -}; - -#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ - [&] { \ - switch (reduce2REDUCE.at(reduce)) { \ - case SUM: { \ - const ReductionType REDUCE = SUM; \ - return __VA_ARGS__(); \ - } \ - case MEAN: { \ - const ReductionType REDUCE = MEAN; \ - return __VA_ARGS__(); \ - } \ - case MIN: { \ - const ReductionType REDUCE = MIN; \ - return __VA_ARGS__(); \ - } \ - case MAX: { \ - const ReductionType REDUCE = MAX; \ - return __VA_ARGS__(); \ - } \ - } \ - }() - -template struct Reducer { - static inline scalar_t init() { - if (REDUCE == MIN) - return std::numeric_limits::max(); - else if (REDUCE == MAX) - return std::numeric_limits::lowest(); - else - return (scalar_t)0; - } - - static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg, - int64_t new_arg) { - if (REDUCE == SUM || REDUCE == MEAN) - *val = *val + new_val; - else if ((REDUCE == MIN && new_val < *val) || - (REDUCE == MAX && new_val > *val)) { - *val = new_val; - *arg = new_arg; - } - } - - static inline void write(scalar_t *address, scalar_t val, - int64_t *arg_address, int64_t arg, int count) { - if (REDUCE == SUM) - *address = val; - else if (REDUCE == MEAN) - *address = val / (count > 0 ? 
count : (scalar_t)1); - else if (REDUCE == MIN || REDUCE == MAX) { - if (count > 0) { - *address = val; - *arg_address = arg; - } else - *address = (scalar_t)0; - } - } -}; diff --git a/cpu/scatter.cpp b/cpu/scatter.cpp deleted file mode 100644 index a2fdc250..00000000 --- a/cpu/scatter.cpp +++ /dev/null @@ -1,84 +0,0 @@ -#include - -#include "dim_apply.h" - -#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") - -void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out, - int64_t dim) { - CHECK_CPU(src); - CHECK_CPU(index); - CHECK_CPU(out); - int64_t elems_per_row = index.size(dim), i, idx; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul", [&] { - DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, { - for (i = 0; i < elems_per_row; i++) { - idx = index_data[i * index_stride]; - out_data[idx * out_stride] *= src_data[i * src_stride]; - } - }); - }); -} - -void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out, - int64_t dim) { - CHECK_CPU(src); - CHECK_CPU(index); - CHECK_CPU(out); - int64_t elems_per_row = index.size(dim), i, idx; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div", [&] { - DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, { - for (i = 0; i < elems_per_row; i++) { - idx = index_data[i * index_stride]; - out_data[idx * out_stride] /= src_data[i * src_stride]; - } - }); - }); -} - -void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out, - torch::Tensor arg, int64_t dim) { - CHECK_CPU(src); - CHECK_CPU(index); - CHECK_CPU(out); - int64_t elems_per_row = index.size(dim), i, idx; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max", [&] { - DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim, - { - for (i = 0; i < elems_per_row; i++) { - idx = index_data[i * index_stride]; - if (src_data[i * src_stride] >= out_data[idx * out_stride]) { - out_data[idx * out_stride] = src_data[i * src_stride]; - arg_data[idx * arg_stride] = i; - } - } - }); - }); -} - -void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out, - torch::Tensor arg, int64_t dim) { - CHECK_CPU(src); - CHECK_CPU(index); - CHECK_CPU(out); - CHECK_CPU(arg); - int64_t elems_per_row = index.size(dim), i, idx; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min", [&] { - DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim, - { - for (i = 0; i < elems_per_row; i++) { - idx = index_data[i * index_stride]; - if (src_data[i * src_stride] <= out_data[idx * out_stride]) { - out_data[idx * out_stride] = src_data[i * src_stride]; - arg_data[idx * arg_stride] = i; - } - } - }); - }); -} - -static auto registry = - torch::RegisterOperators("torch_scatter_cpu::scatter_mul", &scatter_mul) - .op("torch_scatter_cpu::scatter_div", &scatter_div) - .op("torch_scatter_cpu::scatter_max", &scatter_max) - .op("torch_scatter_cpu::scatter_min", &scatter_min); diff --git a/cpu/segment_coo.cpp b/cpu/segment_coo.cpp deleted file mode 100644 index 9d10938f..00000000 --- a/cpu/segment_coo.cpp +++ /dev/null @@ -1,7 +0,0 @@ -#include - -#include "segment_coo_impl.h" - -static auto registry = - torch::RegisterOperators("torch_scatter_cpu::segment_coo", &segment_coo) - .op("torch_scatter_cpu::gather_coo", &gather_coo); diff --git a/cpu/segment_coo_impl.h b/cpu/segment_coo_impl.h deleted file mode 100644 index 99ea81d9..00000000 --- a/cpu/segment_coo_impl.h +++ /dev/null @@ -1,182 +0,0 @@ -#pragma once - -#include - -#include "compat.h" -#include "index_info.h" 
-#include "reducer.h" -#include "utils.h" - -std::tuple> -segment_coo(torch::Tensor src, torch::Tensor index, - torch::optional optional_out, std::string reduce) { - CHECK_CPU(src); - CHECK_CPU(index); - if (optional_out.has_value()) - CHECK_CPU(optional_out.value()); - - CHECK_INPUT(src.dim() >= index.dim()); - - // Broadcasting `index` via `expand`. - auto sizes = index.sizes().vec(); - for (int i = 0; i < index.dim(); i++) - sizes[i] = src.size(i); - index = index.expand(sizes); - - auto dim = index.dim() - 1; - - src = src.contiguous(); - - torch::Tensor out; - if (optional_out.has_value()) { - out = optional_out.value().contiguous(); - for (int i = 0; i < out.dim(); i++) - if (i != dim) - CHECK_INPUT(src.size(i) == out.size(i)); - } else { - sizes = src.sizes().vec(); - sizes[dim] = *index.max().DATA_PTR(); - out = torch::empty(sizes, src.options()); - } - - torch::optional arg_out = torch::nullopt; - int64_t *arg_out_data = nullptr; - if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { - arg_out = torch::full_like(out, src.size(dim), index.options()); - arg_out_data = arg_out.value().DATA_PTR(); - } - - auto B = index.numel() / src.size(dim); - auto E = src.size(dim); - auto K = src.numel() / index.numel(); - auto N = out.size(dim); - - auto index_info = getTensorInfo(index); - auto stride = index_info.strides[index_info.dims - 1]; - std::vector args(K); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - std::vector vals(K); - int64_t idx, next_idx, row_start; - AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { - if (!optional_out.has_value()) - out.fill_(Reducer::init()); - - for (auto b = 0; b < B; b++) { - auto offset = IndexToOffset::get(b * E, index_info); - idx = index_info.data[offset]; - - for (auto k = 0; k < K; k++) - vals[k] = out_data[b * N * K + k]; - - row_start = 0; - for (auto e = 0; e < E; e++) { - - for (auto k = 0; k < K; k++) - Reducer::update( - &vals[k], src_data[b * E * K + e * K + k], &args[k], e); - - if (e == E - 1) { - for (auto k = 0; k < K; k++) - Reducer::write( - out_data + b * N * K + idx * K + k, vals[k], - arg_out_data + b * N * K + idx * K + k, args[k], - e + 1 - row_start); - } else { - next_idx = index_info.data[offset + (e + 1) * stride]; - assert(idx <= next_idx); - - if (idx != next_idx) { - for (auto k = 0; k < K; k++) { - Reducer::write( - out_data + b * N * K + idx * K + k, vals[k], - arg_out_data + b * N * K + idx * K + k, args[k], - e + 1 - row_start); - - vals[k] = out_data[b * N * K + next_idx * K + k]; - } - row_start = e + 1; - } - - idx = next_idx; - } - } - } - if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) { - out.masked_fill_(out == Reducer::init(), (scalar_t)0); - } - }); - }); - - return std::make_tuple(out, arg_out); -} - -torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, - torch::optional optional_out) { - CHECK_CPU(src); - CHECK_CPU(index); - if (optional_out.has_value()) - CHECK_CPU(optional_out.value()); - - CHECK_INPUT(src.dim() >= index.dim()); - for (auto i = 0; i < index.dim() - 1; i++) - CHECK_INPUT(src.size(i) == index.size(i)); - - auto dim = index.dim() - 1; - - src = src.contiguous(); - - torch::Tensor out; - if (optional_out.has_value()) { - out = optional_out.value().contiguous(); - for (auto i = 0; i < src.dim(); i++) - if (i != dim) - CHECK_INPUT(src.size(i) == out.size(i)); - } else { - auto sizes = src.sizes().vec(); - sizes[dim] = index.size(dim); - out = torch::empty(sizes, 
src.options()); - } - - auto B = index.numel() / out.size(dim); - auto E = index.size(dim); - auto K = out.numel() / index.numel(); - auto N = src.size(dim); - - auto index_info = getTensorInfo(index); - auto stride = index_info.strides[index_info.dims - 1]; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - std::vector vals(K); - int64_t idx, next_idx; - for (auto b = 0; b < B; b++) { - auto offset = IndexToOffset::get(b * E, index_info); - idx = index_info.data[offset]; - - for (auto k = 0; k < K; k++) - vals[k] = src_data[b * N * K + idx * K + k]; - - for (auto e = 0; e < E; e++) { - for (auto k = 0; k < K; k++) - out_data[b * E * K + e * K + k] = vals[k]; - - if (e < E - 1) { - next_idx = index_info.data[offset + (e + 1) * stride]; - CHECK_INPUT(idx <= next_idx); - - if (idx != next_idx) { - idx = next_idx; - for (auto k = 0; k < K; k++) - vals[k] = src_data[b * N * K + idx * K + k]; - } - } - } - } - }); - - return out; -} diff --git a/cpu/segment_csr.cpp b/cpu/segment_csr.cpp deleted file mode 100644 index c39dd829..00000000 --- a/cpu/segment_csr.cpp +++ /dev/null @@ -1,41 +0,0 @@ -#include - -#include "segment_csr_impl.h" - -using torch::autograd::AutogradContext; -using torch::autograd::Variable; -using torch::autograd::variable_list; - -class SegmentSumCSR : public torch::autograd::Function { -public: - static variable_list forward(AutogradContext *ctx, Variable src, - Variable indptr, - torch::optional optional_out) { - ctx->saved_data["src_shape"] = src.sizes(); - auto result = segment_csr(src, indptr, optional_out, "sum"); - auto out = std::get<0>(result); - ctx->save_for_backward({indptr}); - return {out}; - } - - static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { - auto grad_out = grad_outs[0]; - auto saved = ctx->get_saved_variables(); - auto indptr = saved[0]; - auto src_shape = ctx->saved_data["src_shape"].toIntVector(); - auto grad_in = torch::empty(src_shape, grad_out.options()); - gather_csr(grad_out, indptr, grad_in); - - return {grad_in, Variable(), Variable()}; - } -}; - -torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional optional_out) { - return SegmentSumCSR::apply(src, indptr, optional_out)[0]; -} - -static auto registry = - torch::RegisterOperators("torch_scatter_cpu::segment_csr", &segment_csr) - .op("torch_scatter_cpu::gather_csr", &gather_csr) - .op("torch_scatter_cpu::segment_sum_csr", &segment_sum_csr); diff --git a/cpu/segment_csr_impl.h b/cpu/segment_csr_impl.h deleted file mode 100644 index 8823c654..00000000 --- a/cpu/segment_csr_impl.h +++ /dev/null @@ -1,146 +0,0 @@ -#pragma once - -#include - -#include "compat.h" -#include "index_info.h" -#include "reducer.h" -#include "utils.h" - -std::tuple> -segment_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional optional_out, std::string reduce) { - CHECK_CPU(src); - CHECK_CPU(indptr); - if (optional_out.has_value()) - CHECK_CPU(optional_out.value()); - - CHECK_INPUT(src.dim() >= indptr.dim()); - - auto sizes = indptr.sizes().vec(); - for (auto i = 0; i < indptr.dim() - 1; i++) - sizes[i] = src.size(i); - indptr = indptr.expand(sizes); - - auto dim = indptr.dim() - 1; - - src = src.contiguous(); - - torch::Tensor out; - if (optional_out.has_value()) { - out = optional_out.value().contiguous(); - for (int i = 0; i < out.dim(); i++) - if (i != dim) - CHECK_INPUT(src.size(i) == out.size(i)); - CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1); - } else { 
- sizes = src.sizes().vec(); - sizes[dim] = indptr.size(dim) - 1; - out = torch::empty(sizes, src.options()); - } - - torch::optional arg_out = torch::nullopt; - int64_t *arg_out_data = nullptr; - if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { - arg_out = torch::full(out.sizes(), src.size(dim), indptr.options()); - arg_out_data = arg_out.value().DATA_PTR(); - } - - auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); - auto K = out.numel() / N; - auto E = src.size(dim); - - auto indptr_info = getTensorInfo(indptr); - auto stride = indptr_info.strides[indptr_info.dims - 1]; - std::vector args(K); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - std::vector vals(K); - int64_t row_start, row_end; - AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { - for (auto n = 0; n < N; n++) { - auto offset = IndexPtrToOffset::get(n, indptr_info); - row_start = indptr_info.data[offset]; - row_end = indptr_info.data[offset + stride]; - - offset = (n / (indptr.size(-1) - 1)) * E * K; - for (auto k = 0; k < K; k++) - vals[k] = Reducer::init(); - - for (auto e = row_start; e < row_end; e++) { - CHECK_INPUT(e < E); - for (auto k = 0; k < K; k++) - Reducer::update( - &vals[k], src_data[offset + e * K + k], &args[k], e); - } - - for (auto k = 0; k < K; k++) - Reducer::write(out_data + n * K + k, vals[k], - arg_out_data + n * K + k, args[k], - row_end - row_start); - } - }); - }); - - return std::make_tuple(out, arg_out); -} - -torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional optional_out) { - CHECK_CPU(src); - CHECK_CPU(indptr); - if (optional_out.has_value()) - CHECK_CPU(optional_out.value()); - - CHECK_INPUT(src.dim() >= indptr.dim()); - for (auto i = 0; i < indptr.dim() - 1; i++) - CHECK_INPUT(src.size(i) == indptr.size(i)); - - auto dim = indptr.dim() - 1; - CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1); - - src = src.contiguous(); - - torch::Tensor out; - if (optional_out.has_value()) { - out = optional_out.value().contiguous(); - for (auto i = 0; i < out.dim(); i++) - if (i != dim) - CHECK_INPUT(src.size(i) == out.size(i)); - } else { - auto sizes = src.sizes().vec(); - sizes[dim] = *indptr.flatten()[-1].DATA_PTR(); - out = torch::empty(sizes, src.options()); - } - - auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); - auto K = src.numel() / N; - auto E = out.size(dim); - - auto indptr_info = getTensorInfo(indptr); - auto stride = indptr_info.strides[indptr_info.dims - 1]; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - std::vector vals(K); - int64_t row_start, row_end; - for (int n = 0; n < N; n++) { - auto offset = IndexPtrToOffset::get(n, indptr_info); - row_start = indptr_info.data[offset]; - row_end = indptr_info.data[offset + stride]; - - for (auto k = 0; k < K; k++) - vals[k] = src_data[n * K + k]; - - offset = (n / (indptr.size(-1) - 1)) * E * K; - for (auto e = row_start; e < row_end; e++) - for (auto k = 0; k < K; k++) - out_data[offset + e * K + k] = vals[k]; - } - }); - - return out; -} diff --git a/cpu/utils.h b/cpu/utils.h deleted file mode 100644 index 40dfb344..00000000 --- a/cpu/utils.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -#include - -#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") -#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") diff --git a/csrc/cuda/scatter_cuda.cu b/csrc/cuda/scatter_cuda.cu 
index 76b9f8c1..7a0dc175 100644 --- a/csrc/cuda/scatter_cuda.cu +++ b/csrc/cuda/scatter_cuda.cu @@ -1,8 +1,136 @@ #include "scatter_cuda.h" +#include +#include +#include + +#include "reducer.cuh" +#include "utils.cuh" + +#define THREADS 1024 +#define BLOCKS(N) (N + THREADS - 1) / THREADS + +template +__global__ void +scatter_kernel(const scalar_t *src_data, + const at::cuda::detail::TensorInfo index_info, + scalar_t *out_data, int E, int K, int N, int numel) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + + int b = thread_idx / (E * K); + int k = thread_idx % K; + + if (thread_idx < numel) { + int offset = at::cuda::detail::IndexToOffset::get( + thread_idx, index_info); + int64_t idx = index_info.data[offset]; + + Reducer::atomic_write(out_data + b * N * K + idx * K + k, + src_data[thread_idx]); + } +} + +template +__global__ void +scatter_arg_kernel(const scalar_t *src_data, + const at::cuda::detail::TensorInfo index_info, + const scalar_t *out_data, int64_t *arg_out_data, int E, + int K, int N, int numel) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + + int b = thread_idx / (E * K); + int e = (thread_idx / K) % E; + int k = thread_idx % K; + + if (thread_idx < numel) { + int offset = at::cuda::detail::IndexToOffset::get( + thread_idx, index_info); + int64_t idx = index_info.data[offset]; + + if (src_data[thread_idx] == out_data[b * N * K + idx * K + k]) { + arg_out_data[b * N * K + idx * K + k] = e; + } + } +} + std::tuple> scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size, std::string reduce) { - return std::make_tuple(src, optional_out); + CHECK_CUDA(src); + CHECK_CUDA(index); + if (optional_out.has_value()) + CHECK_CUDA(optional_out.value()); + cudaSetDevice(src.get_device()); + + CHECK_INPUT(src.dim() == index.dim()); + for (auto i = 0; i < index.dim() - 1; i++) + CHECK_INPUT(src.size(i) >= index.size(i)); + + if (dim < 0) + dim = src.dim() + dim; + + src = src.contiguous(); + + torch::Tensor out; + if (optional_out.has_value()) { + out = optional_out.value().contiguous(); + for (auto i = 0; i < out.dim(); i++) + if (i != dim) + CHECK_INPUT(src.size(i) == out.size(i)); + } else { + auto sizes = src.sizes().vec(); + if (dim_size.has_value()) + sizes[dim] = dim_size.value(); + else { + auto d_size = index.max().data_ptr(); + auto h_size = (int64_t *)malloc(sizeof(int64_t)); + cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost); + sizes[dim] = 1 + *h_size; + } + out = torch::empty(sizes, src.options()); + } + + torch::optional arg_out = torch::nullopt; + int64_t *arg_out_data = nullptr; + if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { + arg_out = torch::full_like(out, src.size(dim), index.options()); + arg_out_data = arg_out.value().data_ptr(); + } + + auto B = 1; + for (auto i = 0; i < dim; i++) + B *= src.size(i); + auto E = src.size(dim); + auto K = src.numel() / (B * E); + auto N = out.size(dim); + + auto index_info = at::cuda::detail::getTensorInfo(index); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { + if (!optional_out.has_value()) + out.fill_(Reducer::init()); + + scatter_kernel + <<>>( + src_data, index_info, out_data, E, K, N, src.numel()); + + if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) + out.masked_fill_(out == Reducer::init(), 
(scalar_t)0); + + if (REDUCE == MIN || REDUCE == MAX) + scatter_arg_kernel + <<>>( + src_data, index_info, out_data, arg_out_data, E, K, N, + src.numel()); + }); + }); + + return std::make_tuple(out, arg_out); } diff --git a/cuda/atomics.cuh b/cuda/atomics.cuh deleted file mode 100644 index 32427eac..00000000 --- a/cuda/atomics.cuh +++ /dev/null @@ -1,230 +0,0 @@ -#pragma once - -#define ATOMIC(NAME) \ - template struct Atomic##NAME##IntegerImpl; \ - \ - template struct Atomic##NAME##IntegerImpl { \ - inline __device__ void operator()(scalar *address, scalar val) { \ - uint32_t *address_as_ui = (uint32_t *)(address - ((size_t)address & 3)); \ - uint32_t old = *address_as_ui; \ - uint32_t shift = ((size_t)address & 3) * 8; \ - uint32_t sum; \ - uint32_t assumed; \ - \ - do { \ - assumed = old; \ - sum = OP(val, scalar((old >> shift) & 0xff)); \ - old = (old & ~(0x000000ff << shift)) | (sum << shift); \ - old = atomicCAS(address_as_ui, assumed, old); \ - } while (assumed != old); \ - } \ - }; \ - \ - template struct Atomic##NAME##IntegerImpl { \ - inline __device__ void operator()(scalar *address, scalar val) { \ - uint32_t *address_as_ui = \ - (uint32_t *)((char *)address - ((size_t)address & 2)); \ - uint32_t old = *address_as_ui; \ - uint32_t sum; \ - uint32_t newval; \ - uint32_t assumed; \ - \ - do { \ - assumed = old; \ - sum = OP(val, (size_t)address & 2 ? scalar(old >> 16) \ - : scalar(old & 0xffff)); \ - newval = (size_t)address & 2 ? (old & 0xffff) | (sum << 16) \ - : (old & 0xffff0000) | sum; \ - old = atomicCAS(address_as_ui, assumed, newval); \ - } while (assumed != old); \ - } \ - }; \ - \ - template struct Atomic##NAME##IntegerImpl { \ - inline __device__ void operator()(scalar *address, scalar val) { \ - uint32_t *address_as_ui = (uint32_t *)address; \ - uint32_t old = *address_as_ui; \ - uint32_t assumed; \ - \ - do { \ - assumed = old; \ - old = atomicCAS(address_as_ui, assumed, OP(val, (scalar)old)); \ - } while (assumed != old); \ - } \ - }; \ - \ - template struct Atomic##NAME##IntegerImpl { \ - inline __device__ void operator()(scalar *address, scalar val) { \ - unsigned long long *address_as_ull = (unsigned long long *)address; \ - unsigned long long old = *address_as_ull; \ - unsigned long long assumed; \ - \ - do { \ - assumed = old; \ - old = atomicCAS(address_as_ull, assumed, OP(val, (scalar)old)); \ - } while (assumed != old); \ - } \ - }; \ - \ - template struct Atomic##NAME##DecimalImpl; \ - \ - template struct Atomic##NAME##DecimalImpl { \ - inline __device__ void operator()(scalar *address, scalar val) { \ - int *address_as_i = (int *)address; \ - int old = *address_as_i; \ - int assumed; \ - \ - do { \ - assumed = old; \ - old = atomicCAS(address_as_i, assumed, \ - __float_as_int(OP(val, __int_as_float(assumed)))); \ - } while (assumed != old); \ - } \ - }; \ - \ - template struct Atomic##NAME##DecimalImpl { \ - inline __device__ void operator()(scalar *address, scalar val) { \ - unsigned long long int *address_as_ull = \ - (unsigned long long int *)address; \ - unsigned long long int old = *address_as_ull; \ - unsigned long long int assumed; \ - \ - do { \ - assumed = old; \ - old = atomicCAS( \ - address_as_ull, assumed, \ - __double_as_longlong(OP(val, __longlong_as_double(assumed)))); \ - } while (assumed != old); \ - } \ - }; - -#define OP(X, Y) Y + X -ATOMIC(Add) -#undef OP -static inline __device__ void atomAdd(uint8_t *address, uint8_t val) { - AtomicAddIntegerImpl()(address, val); -} -static inline __device__ void atomAdd(int8_t *address, int8_t val) 
{ - AtomicAddIntegerImpl()(address, val); -} -static inline __device__ void atomAdd(int16_t *address, int16_t val) { - AtomicAddIntegerImpl()(address, val); -} -static inline __device__ void atomAdd(int32_t *address, int32_t val) { - atomicAdd(address, val); -} -static inline __device__ void atomAdd(int64_t *address, int64_t val) { - AtomicAddIntegerImpl()(address, val); -} -static inline __device__ void atomAdd(float *address, float val) { - atomicAdd(address, val); -} -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000) -static inline __device__ void atomAdd(double *address, double val) { - AtomicAddDecimalImpl()(address, val); -} -#else -static inline __device__ void atomAdd(double *address, double val) { - atomicAdd(address, val); -} -#endif - -#define OP(X, Y) Y *X -ATOMIC(Mul) -#undef OP -static inline __device__ void atomMul(uint8_t *address, uint8_t val) { - AtomicMulIntegerImpl()(address, val); -} -static inline __device__ void atomMul(int8_t *address, int8_t val) { - AtomicMulIntegerImpl()(address, val); -} -static inline __device__ void atomMul(int16_t *address, int16_t val) { - AtomicMulIntegerImpl()(address, val); -} -static inline __device__ void atomMul(int32_t *address, int32_t val) { - AtomicMulIntegerImpl()(address, val); -} -static inline __device__ void atomMul(int64_t *address, int64_t val) { - AtomicMulIntegerImpl()(address, val); -} -static inline __device__ void atomMul(float *address, float val) { - AtomicMulDecimalImpl()(address, val); -} -static inline __device__ void atomMul(double *address, double val) { - AtomicMulDecimalImpl()(address, val); -} - -#define OP(X, Y) Y / X -ATOMIC(Div) -#undef OP -static inline __device__ void atomDiv(uint8_t *address, uint8_t val) { - AtomicDivIntegerImpl()(address, val); -} -static inline __device__ void atomDiv(int8_t *address, int8_t val) { - AtomicDivIntegerImpl()(address, val); -} -static inline __device__ void atomDiv(int16_t *address, int16_t val) { - AtomicDivIntegerImpl()(address, val); -} -static inline __device__ void atomDiv(int32_t *address, int32_t val) { - AtomicDivIntegerImpl()(address, val); -} -static inline __device__ void atomDiv(int64_t *address, int64_t val) { - AtomicDivIntegerImpl()(address, val); -} -static inline __device__ void atomDiv(float *address, float val) { - AtomicDivDecimalImpl()(address, val); -} -static inline __device__ void atomDiv(double *address, double val) { - AtomicDivDecimalImpl()(address, val); -} - -#define OP(X, Y) max(Y, X) -ATOMIC(Max) -#undef OP -static inline __device__ void atomMax(uint8_t *address, uint8_t val) { - AtomicMaxIntegerImpl()(address, val); -} -static inline __device__ void atomMax(int8_t *address, int8_t val) { - AtomicMaxIntegerImpl()(address, val); -} -static inline __device__ void atomMax(int16_t *address, int16_t val) { - AtomicMaxIntegerImpl()(address, val); -} -static inline __device__ void atomMax(int32_t *address, int32_t val) { - atomicMax(address, val); -} -static inline __device__ void atomMax(int64_t *address, int64_t val) { - AtomicMaxIntegerImpl()(address, val); -} -static inline __device__ void atomMax(float *address, float val) { - AtomicMaxDecimalImpl()(address, val); -} -static inline __device__ void atomMax(double *address, double val) { - AtomicMaxDecimalImpl()(address, val); -} - -#define OP(X, Y) min(Y, X) -ATOMIC(Min) -#undef OP -static inline __device__ void atomMin(uint8_t *address, uint8_t val) { - AtomicMinIntegerImpl()(address, val); -} -static inline __device__ void atomMin(int8_t *address, int8_t val) { - 
AtomicMinIntegerImpl()(address, val); -} -static inline __device__ void atomMin(int16_t *address, int16_t val) { - AtomicMinIntegerImpl()(address, val); -} -static inline __device__ void atomMin(int32_t *address, int32_t val) { - atomicMin(address, val); -} -static inline __device__ void atomMin(int64_t *address, int64_t val) { - AtomicMinIntegerImpl()(address, val); -} -static inline __device__ void atomMin(float *address, float val) { - AtomicMinDecimalImpl()(address, val); -} -static inline __device__ void atomMin(double *address, double val) { - AtomicMinDecimalImpl()(address, val); -} diff --git a/cuda/compat.cuh b/cuda/compat.cuh deleted file mode 100644 index 1be09913..00000000 --- a/cuda/compat.cuh +++ /dev/null @@ -1,5 +0,0 @@ -#ifdef VERSION_GE_1_3 -#define DATA_PTR data_ptr -#else -#define DATA_PTR data -#endif diff --git a/cuda/gather.cpp b/cuda/gather.cpp deleted file mode 100644 index e80f70ec..00000000 --- a/cuda/gather.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include - -#define CHECK_CUDA(x) \ - AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") - -torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, - torch::optional out_opt); -torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, - torch::optional out_opt); - -torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional out_opt) { - CHECK_CUDA(src); - CHECK_CUDA(indptr); - if (out_opt.has_value()) - CHECK_CUDA(out_opt.value()); - return gather_csr_cuda(src, indptr, out_opt); -} - -torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, - torch::optional out_opt) { - CHECK_CUDA(src); - CHECK_CUDA(index); - if (out_opt.has_value()) - CHECK_CUDA(out_opt.value()); - return gather_coo_cuda(src, index, out_opt); -} - -static auto registry = - torch::RegisterOperators("torch_scatter_cuda::gather_csr", &gather_csr) - .op("torch_scatter_cuda::gather_coo", &gather_coo); diff --git a/cuda/gather_kernel.cu b/cuda/gather_kernel.cu deleted file mode 100644 index 68dff864..00000000 --- a/cuda/gather_kernel.cu +++ /dev/null @@ -1,202 +0,0 @@ -#include -#include -#include -#include - -#include "compat.cuh" -#include "indptr.cuh" - -#define THREADS 256 -#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS - -template -__global__ void -gather_csr_kernel(const scalar_t *src_data, - const at::cuda::detail::TensorInfo indptr_info, - scalar_t *out_data, size_t N, size_t E) { - - int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - int row_idx = thread_idx / TB; - int lane_idx = thread_idx % TB; - - if (row_idx < N) { - int offset = IndexPtrToOffset::get(row_idx, indptr_info); - int row_start = __ldg(indptr_info.data + offset); - int row_end = __ldg(indptr_info.data + offset + - indptr_info.strides[indptr_info.dims - 1]); - scalar_t val = __ldg(src_data + row_idx); - - offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; - for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) { - out_data[offset + out_idx] = val; // "Mostly" coalesced. 
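gather_csr is the inverse expansion of segment_csr: every source value is written once per position of its CSR segment, which is what the loop above does in parallel. For the 1-D case this reduces to a repeat_interleave; a minimal sketch with hypothetical data (illustrative only, not taken from the patch):

    import torch

    src = torch.tensor([10., 20., 30., 40.])     # one value per segment
    indptr = torch.tensor([0, 2, 5, 5, 6])       # CSR segment boundaries
    lengths = indptr[1:] - indptr[:-1]           # tensor([2, 3, 0, 1])
    print(src.repeat_interleave(lengths))        # tensor([10., 10., 20., 20., 20., 40.])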
- } - } -} - -template -__global__ void gather_csr_broadcast_kernel( - const scalar_t *src_data, - const at::cuda::detail::TensorInfo indptr_info, - scalar_t *out_data, size_t N, size_t K, size_t E) { - - int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - int row_idx = thread_idx / K; - int lane_idx = thread_idx % K; - - if (thread_idx < N * K) { - int offset = IndexPtrToOffset::get(row_idx, indptr_info); - int row_start = __ldg(indptr_info.data + offset); - int row_end = __ldg(indptr_info.data + offset + - indptr_info.strides[indptr_info.dims - 1]); - - scalar_t val = src_data[thread_idx]; // Coalesced. - - offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; - for (int out_idx = row_start; out_idx < row_end; out_idx++) { - out_data[offset + K * out_idx + lane_idx] = val; // "Mostly" coalesced. - } - } -} - -torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, - torch::optional out_opt) { - - cudaSetDevice(src.get_device()); - AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch"); - for (int i = 0; i < indptr.dim() - 1; i++) - AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch"); - - src = src.contiguous(); - auto gather_dim = indptr.dim() - 1; - AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1, - "Input mismatch"); - - torch::Tensor out; - if (out_opt.has_value()) { - out = out_opt.value().contiguous(); - for (int i = 0; i < out.dim(); i++) - if (i != gather_dim) - AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch"); - } else { - auto d_gather_size = indptr.flatten()[-1].DATA_PTR(); - auto h_gather_size = (int64_t *)malloc(sizeof(int64_t)); - cudaMemcpy(h_gather_size, d_gather_size, sizeof(int64_t), - cudaMemcpyDeviceToHost); - - auto sizes = src.sizes().vec(); - sizes[gather_dim] = *h_gather_size; - out = at::empty(sizes, src.options()); - } - - auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1)); - auto K = src.numel() / N; - auto E = out.size(gather_dim); - - auto indptr_info = at::cuda::detail::getTensorInfo(indptr); - auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - if (K == 1) { - gather_csr_kernel<<>>( - src_data, indptr_info, out_data, N, E); - } else { - gather_csr_broadcast_kernel - <<>>(src_data, indptr_info, - out_data, N, K, E); - } - }); - - return out; -} - -template -__global__ void -gather_coo_kernel(const scalar_t *src_data, - const at::cuda::detail::TensorInfo index_info, - scalar_t *out_data, size_t E, size_t N) { - - int row_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (row_idx < E) { - int offset = at::cuda::detail::IndexToOffset::get( - row_idx, index_info); - int row = index_info.data[offset]; - - offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N; - scalar_t val = __ldg(src_data + offset + row); - - out_data[row_idx] = val; - } -} - -template -__global__ void gather_coo_broadcast_kernel( - const scalar_t *src_data, - const at::cuda::detail::TensorInfo index_info, - scalar_t *out_data, size_t E, size_t K, size_t N) { - - int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - int row_idx = thread_idx / K; - int col_idx = thread_idx % K; - - if (thread_idx < E * K) { - int offset = at::cuda::detail::IndexToOffset::get( - row_idx, index_info); - int row = index_info.data[offset]; - - offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K; - scalar_t val = __ldg(src_data + offset + K * row + col_idx); - - 
out_data[thread_idx] = val; - } -} - -torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, - torch::optional out_opt) { - - cudaSetDevice(src.get_device()); - - AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch"); - for (int i = 0; i < index.dim() - 1; i++) - AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch"); - - src = src.contiguous(); - auto gather_dim = index.dim() - 1; - - torch::Tensor out; - if (out_opt.has_value()) { - out = out_opt.value().contiguous(); - for (int i = 0; i < index.dim(); i++) - AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch"); - for (int i = index.dim() + 1; i < src.dim(); i++) - AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch"); - } else { - auto sizes = src.sizes().vec(); - sizes[gather_dim] = index.size(gather_dim); - out = torch::empty(sizes, src.options()); - } - - auto E = index.numel(); - auto K = out.numel() / E; - auto N = src.size(gather_dim); - - auto index_info = at::cuda::detail::getTensorInfo(index); - auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - if (K == 1) { - gather_coo_kernel<<>>( - src_data, index_info, out_data, E, N); - } else { - gather_coo_broadcast_kernel - <<>>(src_data, index_info, - out_data, E, K, N); - } - }); - - return out; -} diff --git a/cuda/index.cuh b/cuda/index.cuh deleted file mode 100644 index dab1493a..00000000 --- a/cuda/index.cuh +++ /dev/null @@ -1,108 +0,0 @@ -#pragma once - -#include -#include - -template -struct IndexToScatterOffsets3 { - static __device__ void - compute(int64_t i, const int64_t dim, - const at::cuda::detail::TensorInfo &index, - int64_t *indexOffset, - const at::cuda::detail::TensorInfo &t1, - int64_t *t1Offset, - const at::cuda::detail::TensorInfo &t2, - int64_t *t2Offset) { - for (int64_t d = Dims - 1; d >= 0; d--) { - int64_t curDimIndex = i % index.sizes[d]; - *indexOffset += curDimIndex * index.strides[d]; - *t1Offset += curDimIndex * t1.strides[d]; - if (d != dim) { - *t2Offset += curDimIndex * t2.strides[d]; - } - i /= index.sizes[d]; - } - int64_t indexValue = index.data[*indexOffset]; - *t2Offset += indexValue * t2.strides[dim]; - } -}; - -template -struct IndexToScatterOffsets3 { - static __device__ void - compute(int64_t i, const int64_t dim, - const at::cuda::detail::TensorInfo &index, - int64_t *indexOffset, - const at::cuda::detail::TensorInfo &t1, - int64_t *t1Offset, - const at::cuda::detail::TensorInfo &t2, - int64_t *t2Offset) { - for (int64_t d = index.dims - 1; d >= 0; d--) { - int64_t curDimIndex = i % index.sizes[d]; - *indexOffset += curDimIndex * index.strides[d]; - *t1Offset += curDimIndex * t1.strides[d]; - if (d != dim) { - *t2Offset += curDimIndex * t2.strides[d]; - } - i /= index.sizes[d]; - } - int64_t indexValue = index.data[*indexOffset]; - *t2Offset += indexValue * t2.strides[dim]; - } -}; - -template -struct IndexToScatterOffsets4 { - static __device__ void - compute(int64_t i, const int64_t dim, - const at::cuda::detail::TensorInfo &index, - int64_t *indexOffset, - const at::cuda::detail::TensorInfo &t1, - int64_t *t1Offset, - const at::cuda::detail::TensorInfo &t2, - int64_t *t2Offset, - const at::cuda::detail::TensorInfo &t3, - int64_t *t3Offset) { - for (int64_t d = Dims - 1; d >= 0; d--) { - int64_t curDimIndex = i % index.sizes[d]; - *indexOffset += curDimIndex * index.strides[d]; - *t1Offset += curDimIndex * t1.strides[d]; - if (d != dim) { - *t2Offset += curDimIndex * 
t2.strides[d]; - *t3Offset += curDimIndex * t3.strides[d]; - } - i /= index.sizes[d]; - } - int64_t indexValue = index.data[*indexOffset]; - *t2Offset += indexValue * t2.strides[dim]; - *t3Offset += indexValue * t3.strides[dim]; - } -}; - -template -struct IndexToScatterOffsets4 { - static __device__ void - compute(int64_t i, const int64_t dim, - const at::cuda::detail::TensorInfo &index, - int64_t *indexOffset, - const at::cuda::detail::TensorInfo &t1, - int64_t *t1Offset, - const at::cuda::detail::TensorInfo &t2, - int64_t *t2Offset, - const at::cuda::detail::TensorInfo &t3, - int64_t *t3Offset) { - for (int64_t d = index.dims - 1; d >= 0; d--) { - int64_t curDimIndex = i % index.sizes[d]; - *indexOffset += curDimIndex * index.strides[d]; - *t1Offset += curDimIndex * t1.strides[d]; - if (d != dim) { - *t2Offset += curDimIndex * t2.strides[d]; - *t3Offset += curDimIndex * t3.strides[d]; - } - i /= index.sizes[d]; - } - int64_t indexValue = index.data[*indexOffset]; - *t2Offset += indexValue * t2.strides[dim]; - *t3Offset += indexValue * t3.strides[dim]; - } -}; diff --git a/cuda/indptr.cuh b/cuda/indptr.cuh deleted file mode 100644 index a0860995..00000000 --- a/cuda/indptr.cuh +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include -#include - -// We need our own `IndexToOffset` implementation since we do not want to -// access the last element of the `indexptr`. -template struct IndexPtrToOffset { - static inline __host__ __device__ int - get(int idx, const at::cuda::detail::TensorInfo &info) { - int offset = idx % (info.sizes[info.dims - 1] - 1); - offset *= info.strides[info.dims - 1]; - idx /= info.sizes[info.dims - 1] - 1; - for (int i = info.dims - 2; i >= 0; --i) { - offset += (idx % info.sizes[i]) * info.strides[i]; - idx /= info.sizes[i]; - } - return offset; - } -}; diff --git a/cuda/scatter.cpp b/cuda/scatter.cpp deleted file mode 100644 index 652bbacf..00000000 --- a/cuda/scatter.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include - -#define CHECK_CUDA(x) \ - AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") - -void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out, - int64_t dim); -void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out, - int64_t dim); -void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out, - torch::Tensor arg, int64_t dim); -void scatter_min_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out, - torch::Tensor arg, int64_t dim); -void index_backward_cuda(torch::Tensor grad, torch::Tensor index, - torch::Tensor arg, torch::Tensor out, int64_t dim); - -void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out, - int64_t dim) { - CHECK_CUDA(src); - CHECK_CUDA(index); - CHECK_CUDA(out); - scatter_mul_cuda(src, index, out, dim); -} - -void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out, - int64_t dim) { - CHECK_CUDA(src); - CHECK_CUDA(index); - CHECK_CUDA(out); - scatter_div_cuda(src, index, out, dim); -} - -void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out, - torch::Tensor arg, int64_t dim) { - CHECK_CUDA(src); - CHECK_CUDA(index); - CHECK_CUDA(out); - CHECK_CUDA(arg); - scatter_max_cuda(src, index, out, arg, dim); -} - -void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out, - torch::Tensor arg, int64_t dim) { - CHECK_CUDA(src); - CHECK_CUDA(index); - CHECK_CUDA(out); - CHECK_CUDA(arg); - scatter_min_cuda(src, index, out, arg, dim); -} - -static auto registry = - 
torch::RegisterOperators("torch_scatter_cuda::scatter_mul", &scatter_mul) - .op("torch_scatter_cuda::scatter_div", &scatter_div) - .op("torch_scatter_cuda::scatter_max", &scatter_max) - .op("torch_scatter_cuda::scatter_min", &scatter_min); diff --git a/cuda/scatter_kernel.cu b/cuda/scatter_kernel.cu deleted file mode 100644 index 1122d330..00000000 --- a/cuda/scatter_kernel.cu +++ /dev/null @@ -1,163 +0,0 @@ -#include -#include -#include -#include - -#include "atomics.cuh" -#include "index.cuh" - -#define THREADS 1024 -#define BLOCKS(N) (N + THREADS - 1) / THREADS - -#define KERNEL_RUN(NAME, DIMS, N, ...) \ - [&] { \ - auto stream = at::cuda::getCurrentCUDAStream(); \ - switch (DIMS) { \ - case 1: \ - NAME<<>>(__VA_ARGS__, N); \ - break; \ - case 2: \ - NAME<<>>(__VA_ARGS__, N); \ - break; \ - case 3: \ - NAME<<>>(__VA_ARGS__, N); \ - break; \ - default: \ - NAME<<>>(__VA_ARGS__, N); \ - } \ - }() - -template -__global__ void -scatter_mul_kernel(at::cuda::detail::TensorInfo src, - at::cuda::detail::TensorInfo index, - at::cuda::detail::TensorInfo out, - int64_t dim, size_t numel) { - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.x; - for (ptrdiff_t i = idx; i < numel; i += stride) { - int64_t srcOffset = 0, indexOffset = 0, outOffset = 0; - IndexToScatterOffsets3::compute( - i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset); - atomMul(&out.data[outOffset], src.data[srcOffset]); - } -} - -void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out, - int64_t dim) { - cudaSetDevice(src.get_device()); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] { - KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(), - at::cuda::detail::getTensorInfo(src), - at::cuda::detail::getTensorInfo(index), - at::cuda::detail::getTensorInfo(out), dim); - }); -} - -template -__global__ void -scatter_div_kernel(at::cuda::detail::TensorInfo src, - at::cuda::detail::TensorInfo index, - at::cuda::detail::TensorInfo out, - int64_t dim, size_t numel) { - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.x; - for (ptrdiff_t i = idx; i < numel; i += stride) { - int64_t srcOffset = 0, indexOffset = 0, outOffset = 0; - IndexToScatterOffsets3::compute( - i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset); - atomDiv(&out.data[outOffset], src.data[srcOffset]); - } -} - -void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out, - int64_t dim) { - cudaSetDevice(src.get_device()); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] { - KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(), - at::cuda::detail::getTensorInfo(src), - at::cuda::detail::getTensorInfo(index), - at::cuda::detail::getTensorInfo(out), dim); - }); -} - -template -__global__ void arg_kernel(at::cuda::detail::TensorInfo src, - at::cuda::detail::TensorInfo index, - at::cuda::detail::TensorInfo out, - at::cuda::detail::TensorInfo arg, - int64_t dim, size_t numel) { - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.x; - for (ptrdiff_t i = idx; i < numel; i += stride) { - int64_t srcOffset = 0, indexOffset = 0, outOffset = 0, argOffset = 0; - IndexToScatterOffsets4::compute( - i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset, arg, - &argOffset); - if (src.data[srcOffset] == out.data[outOffset]) { - arg.data[argOffset] = (srcOffset / src.strides[dim]) % src.sizes[dim]; - } - } 
-} - -template -__global__ void -scatter_max_kernel(at::cuda::detail::TensorInfo src, - at::cuda::detail::TensorInfo index, - at::cuda::detail::TensorInfo out, - int64_t dim, size_t numel) { - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.x; - for (ptrdiff_t i = idx; i < numel; i += stride) { - int64_t srcOffset = 0, indexOffset = 0, outOffset = 0; - IndexToScatterOffsets3::compute( - i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset); - atomMax(&out.data[outOffset], src.data[srcOffset]); - } -} - -void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out, - torch::Tensor arg, int64_t dim) { - cudaSetDevice(src.get_device()); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] { - auto src_info = at::cuda::detail::getTensorInfo(src); - auto index_info = at::cuda::detail::getTensorInfo(index); - auto out_info = at::cuda::detail::getTensorInfo(out); - KERNEL_RUN(scatter_max_kernel, index.dim(), index.numel(), src_info, - index_info, out_info, dim); - KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info, - out_info, at::cuda::detail::getTensorInfo(arg), - dim); - }); -} - -template -__global__ void -scatter_min_kernel(at::cuda::detail::TensorInfo src, - at::cuda::detail::TensorInfo index, - at::cuda::detail::TensorInfo out, - int64_t dim, size_t numel) { - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.x; - for (ptrdiff_t i = idx; i < numel; i += stride) { - int64_t srcOffset = 0, indexOffset = 0, outOffset = 0; - IndexToScatterOffsets3::compute( - i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset); - atomMin(&out.data[outOffset], src.data[srcOffset]); - } -} - -void scatter_min_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out, - torch::Tensor arg, int64_t dim) { - cudaSetDevice(src.get_device()); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] { - auto src_info = at::cuda::detail::getTensorInfo(src); - auto index_info = at::cuda::detail::getTensorInfo(index); - auto out_info = at::cuda::detail::getTensorInfo(out); - KERNEL_RUN(scatter_min_kernel, index.dim(), index.numel(), src_info, - index_info, out_info, dim); - KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info, - out_info, at::cuda::detail::getTensorInfo(arg), - dim); - }); -} diff --git a/cuda/segment.cpp b/cuda/segment.cpp deleted file mode 100644 index 883807b9..00000000 --- a/cuda/segment.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include - -#define CHECK_CUDA(x) \ - AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") - -std::tuple> -segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, - torch::optional out_opt, std::string reduce); -std::tuple> -segment_coo_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out, - std::string reduce); - -std::tuple> -segment_csr(torch::Tensor src, torch::Tensor indptr, - torch::optional out_opt, std::string reduce) { - CHECK_CUDA(src); - CHECK_CUDA(indptr); - if (out_opt.has_value()) - CHECK_CUDA(out_opt.value()); - return segment_csr_cuda(src, indptr, out_opt, reduce); -} - -std::tuple> -segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out, - std::string reduce) { - CHECK_CUDA(src); - CHECK_CUDA(index); - CHECK_CUDA(out); - return segment_coo_cuda(src, index, out, reduce); -} - -static auto registry = - torch::RegisterOperators("torch_scatter_cuda::segment_csr", &segment_csr) - .op("torch_scatter_cuda::segment_coo", 
&segment_coo); diff --git a/cuda/segment_kernel.cu b/cuda/segment_kernel.cu deleted file mode 100644 index f9e0f694..00000000 --- a/cuda/segment_kernel.cu +++ /dev/null @@ -1,494 +0,0 @@ -#include -#include -#include -#include - -#include "atomics.cuh" -#include "compat.cuh" -#include "indptr.cuh" - -#define THREADS 256 -#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS -#define FULL_MASK 0xffffffff - -enum ReductionType { SUM, MEAN, MIN, MAX }; - -const std::map reduce2REDUCE = { - {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX}, -}; - -#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ - [&] { \ - switch (reduce2REDUCE.at(reduce)) { \ - case SUM: { \ - const ReductionType REDUCE = SUM; \ - return __VA_ARGS__(); \ - } \ - case MEAN: { \ - const ReductionType REDUCE = MEAN; \ - return __VA_ARGS__(); \ - } \ - case MIN: { \ - const ReductionType REDUCE = MIN; \ - return __VA_ARGS__(); \ - } \ - case MAX: { \ - const ReductionType REDUCE = MAX; \ - return __VA_ARGS__(); \ - } \ - } \ - }() - -template struct Reducer { - static inline __host__ __device__ scalar_t init() { - if (REDUCE == MIN) { - return std::numeric_limits::max(); - } else if (REDUCE == MAX) { - return std::numeric_limits::lowest(); - } else { - return (scalar_t)0; - } - } - - static inline __host__ __device__ void update(scalar_t *val, - scalar_t new_val) { - if (REDUCE == SUM || REDUCE == MEAN) { - *val = *val + new_val; - } else if ((REDUCE == MIN && new_val < *val) || - (REDUCE == MAX && new_val > *val)) { - *val = new_val; - } - } - - static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val, - int64_t *arg, int64_t new_arg) { - if (REDUCE == SUM || REDUCE == MEAN) { - *val = *val + new_val; - } else if ((REDUCE == MIN && new_val < *val) || - (REDUCE == MAX && new_val > *val)) { - *val = new_val; - *arg = new_arg; - } - } - - static inline __host__ __device__ void write(scalar_t *address, scalar_t val, - int64_t *arg_address, - int64_t arg, int count) { - if (REDUCE == SUM) { - *address = val; - } else if (REDUCE == MEAN) { - *address = val / (scalar_t)max(count, 1); - } else if (REDUCE == MIN || REDUCE == MAX) { - if (count > 0) { - *address = val; - *arg_address = arg; - } else { - *address = (scalar_t)0; - } - } - } - - static inline __device__ void atomic_write(scalar_t *address, scalar_t val) { - if (REDUCE == SUM || REDUCE == MEAN) { - atomAdd(address, val); - } else if (REDUCE == MIN && val < *address) { - atomMin(address, val); - } else if (REDUCE == MAX && val > *address) { - atomMax(address, val); - } - } -}; - -template -__global__ void -segment_csr_kernel(const scalar_t *src_data, - const at::cuda::detail::TensorInfo indptr_info, - scalar_t *out_data, int64_t *arg_out_data, size_t N, - size_t E) { - - // Each warp processes exactly `32/TB` rows and aggregates all row values - // via a parallel reduction. 
- - int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - int row_idx = thread_idx / TB; - int lane_idx = thread_idx & (TB - 1); - - if (row_idx < N) { - int offset = IndexPtrToOffset::get(row_idx, indptr_info); - int64_t row_start = __ldg(indptr_info.data + offset); - int64_t row_end = __ldg(indptr_info.data + offset + - indptr_info.strides[indptr_info.dims - 1]); - - scalar_t val = Reducer::init(); - int64_t arg, arg_tmp; - - offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; - for (int64_t src_idx = row_start + lane_idx; src_idx < row_end; - src_idx += TB) { - Reducer::update(&val, src_data[offset + src_idx], &arg, - src_idx); - } - -#pragma unroll - for (int i = TB / 2; i > 0; i /= 2) { - // Parallel reduction inside a single warp. - if (REDUCE == MIN || REDUCE == MAX) - arg_tmp = __shfl_down_sync(FULL_MASK, arg, i); - Reducer::update( - &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp); - } - - if (lane_idx == 0) { - Reducer::write(out_data + row_idx, val, - arg_out_data + row_idx, arg, - row_end - row_start); - } - } -} - -template -__global__ void segment_csr_broadcast_kernel( - const scalar_t *src_data, - const at::cuda::detail::TensorInfo indptr_info, - scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) { - - // Each thread processes exactly one row. It turned out that is more - // efficient than using shared memory due to avoiding synchronization - // barriers. - - int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - int row_idx = thread_idx / K; - int lane_idx = thread_idx % K; - - if (thread_idx < N * K) { - int offset = IndexPtrToOffset::get(row_idx, indptr_info); - int64_t row_start = __ldg(indptr_info.data + offset); - int64_t row_end = __ldg(indptr_info.data + offset + - indptr_info.strides[indptr_info.dims - 1]); - - scalar_t val = Reducer::init(); - int64_t arg; - - offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; - for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) { - Reducer::update( - &val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx); - } - - Reducer::write(out_data + thread_idx, val, - arg_out_data + thread_idx, arg, - row_end - row_start); - } -} - -std::tuple> -segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, - torch::optional out_opt, std::string reduce) { - - cudaSetDevice(src.get_device()); - - AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch"); - - // Broadcasting `indptr` via `expand`. 
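The expand below only stretches the leading (batch) dimensions of indptr to match src, so a single pointer vector can be shared across all leading indices without copying; e.g., with hypothetical shapes:

    import torch

    src = torch.randn(4, 3, 8)                   # reduced along its last dimension
    indptr = torch.tensor([[[0, 2, 5, 8]]])      # shape (1, 1, 4)
    indptr = indptr.expand(4, 3, -1)             # view of shape (4, 3, 4), no data is copied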
- auto sizes = indptr.sizes().vec(); - for (int i = 0; i < indptr.dim() - 1; i++) { - sizes[i] = src.size(i); - } - indptr = indptr.expand(sizes); - - src = src.contiguous(); - auto reduce_dim = indptr.dim() - 1; - - torch::Tensor out; - if (out_opt.has_value()) { - out = out_opt.value().contiguous(); - for (int i = 0; i < out.dim(); i++) - if (i != reduce_dim) - AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch"); - AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1, - "Input mismatch"); - } else { - sizes = src.sizes().vec(); - sizes[reduce_dim] = indptr.size(reduce_dim) - 1; - out = torch::empty(sizes, src.options()); - } - - torch::optional arg_out = torch::nullopt; - int64_t *arg_out_data = nullptr; - if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { - arg_out = torch::full_like(out, src.size(reduce_dim), indptr.options()); - arg_out_data = arg_out.value().DATA_PTR(); - } - - auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1)); - auto K = out.numel() / N; - auto E = src.size(reduce_dim); - - auto indptr_info = at::cuda::detail::getTensorInfo(indptr); - auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { - if (K == 1) { - segment_csr_kernel - <<>>( - src_data, indptr_info, out_data, arg_out_data, N, E); - } else { - segment_csr_broadcast_kernel - <<>>( - src_data, indptr_info, out_data, arg_out_data, N, K, E); - } - }); - }); - - return std::make_tuple(out, arg_out); -} - -template -__global__ void -segment_coo_kernel(const scalar_t *src_data, - const at::cuda::detail::TensorInfo index_info, - scalar_t *out_data, size_t E, size_t N) { - - // Each thread processes exactly one entry. Within a warp, we perform a - // parallel reduction across equal indices, and write the intermediate - // result via atomics. - - int row_idx = blockIdx.x * blockDim.x + threadIdx.x; - int lane_idx = row_idx & (32 - 1); - int D = index_info.sizes[index_info.dims - 1]; - - if (row_idx < E) { - int offset = at::cuda::detail::IndexToOffset::get( - row_idx, index_info); - int64_t idx = index_info.data[offset], next_idx; - int out_idx = (row_idx / D) * N + idx; - - scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp; - -#pragma unroll - for (int i = 1; i < 32; i *= 2) { - // Parallel reduction inside a single warp. 
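The shuffle loop that follows performs a segmented reduction inside one warp: in round i, each lane pulls the running value of the lane i positions below it and merges it only if both lanes carry the same index. A sequential model of the sum case, shrunk to 8 lanes with hypothetical data (the kernel's additional row-boundary check is omitted here):

    import torch

    idx = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])           # sorted per-lane indices
    val = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])   # per-lane values
    for i in [1, 2, 4]:                                     # log2(#lanes) rounds of __shfl_up_sync
        up_val = torch.cat([torch.zeros(i), val[:-i]])
        up_idx = torch.cat([torch.full((i,), -1, dtype=idx.dtype), idx[:-i]])
        merge = (torch.arange(val.numel()) >= i) & (idx == up_idx)
        val = torch.where(merge, val + up_val, val)
    print(val[[2, 4, 7]])   # last lane of each run now holds its total: 6., 9. and 21.

Only the last lane of each run then issues the atomic write, which is what the __shfl_down_sync check after this loop selects.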
- tmp = __shfl_up_sync(FULL_MASK, val, i); - next_idx = __shfl_up_sync(FULL_MASK, idx, i); - if (lane_idx >= i && row_idx / D == (row_idx - i) / D) { - assert(idx >= next_idx); - if (idx == next_idx) - Reducer::update(&val, tmp); - } - } - - next_idx = __shfl_down_sync(FULL_MASK, idx, 1); - if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D || - idx != next_idx) - Reducer::atomic_write(out_data + out_idx, val); - } -} - -template -__global__ void segment_coo_arg_kernel( - const scalar_t *src_data, - const at::cuda::detail::TensorInfo index_info, - scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) { - - int row_idx = blockIdx.x * blockDim.x + threadIdx.x; - int D = index_info.sizes[index_info.dims - 1]; - - if (row_idx < E) { - int offset = at::cuda::detail::IndexToOffset::get( - row_idx, index_info); - int64_t idx = index_info.data[offset]; - int out_idx = (row_idx / D) * N + idx; - - scalar_t val = __ldg(out_data + out_idx); - if (src_data[row_idx] == val) - arg_out_data[out_idx] = row_idx % D; - } -} - -template -__global__ void segment_coo_broadcast_kernel( - const scalar_t *src_data, - const at::cuda::detail::TensorInfo index_info, - scalar_t *out_data, size_t E, size_t K, size_t N) { - - // Each thread processes a single column and `TB` index entries. Coalesced - // read and write is performed in column-major order. The intermediate - // results are written via atomics. - - int D = index_info.sizes[index_info.dims - 1]; - int E_1 = E / D; - int E_2 = D + TB - (D % TB); - - int row_idx = blockIdx.x * blockDim.y + threadIdx.y; - int col_idx = blockIdx.y * blockDim.x + threadIdx.x; - - int dim_start = (row_idx * TB) / E_2; - int row_start = (row_idx * TB) % E_2; - - if (dim_start < E_1 && col_idx < K) { - - int offset = at::cuda::detail::IndexToOffset::get( - dim_start * D + row_start, index_info); - int idx1 = __ldg(index_info.data + offset), idx2; - - scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx]; - -#pragma unroll - for (int i = 1; i < TB; i++) { - if (row_start + i >= D) - break; - - idx2 = __ldg(index_info.data + offset + - i * index_info.strides[index_info.dims - 1]); - assert(idx1 <= idx2); - if (idx1 == idx2) { - Reducer::update( - &val, src_data[K * (dim_start * D + row_start + i) + col_idx]); - } else { - Reducer::atomic_write( - out_data + (dim_start * N + idx1) * K + col_idx, val); - val = src_data[K * (dim_start * D + row_start + i) + col_idx]; - } - - idx1 = idx2; - } - - Reducer::atomic_write( - out_data + (dim_start * N + idx1) * K + col_idx, val); - } -} - -template -__global__ void segment_coo_arg_broadcast_kernel( - const scalar_t *src_data, - const at::cuda::detail::TensorInfo index_info, - scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) { - - int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - int row_idx = thread_idx / K; - int col_idx = thread_idx % K; - int D = index_info.sizes[index_info.dims - 1]; - - if (row_idx < E && col_idx < K) { - int offset = at::cuda::detail::IndexToOffset::get( - row_idx, index_info); - int idx = __ldg(index_info.data + offset); - int out_idx = ((row_idx / D) * N + idx) * K + col_idx; - - scalar_t val = __ldg(out_data + out_idx); - if (src_data[thread_idx] == val) - arg_out_data[out_idx] = row_idx % D; - } -} - -std::tuple> -segment_coo_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out, - std::string reduce) { - - cudaSetDevice(src.get_device()); - - AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch"); - - // Broadcasting `index` via `expand`. 
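index receives the same leading-dimension broadcasting below. End to end, the operation this wrapper dispatches can be stated for the 1-D, sorted-index case as follows; names and data are illustrative, and the sample values coincide with the segment tests added later in this series:

    import torch

    def segment_coo_reference(src, index, dim_size, reduce="sum"):
        # index is assumed sorted; rows that receive no value keep their initial
        # value (zero here), matching the expected test values for empty segments.
        out = torch.zeros(dim_size, dtype=src.dtype)
        for i in range(dim_size):
            seg = src[index == i]
            if seg.numel() == 0:
                continue
            out[i] = {"sum": seg.sum, "add": seg.sum, "mean": seg.mean,
                      "min": seg.min, "max": seg.max}[reduce]()
        return out

    src = torch.tensor([1., 2., 3., 4., 5., 6.])
    index = torch.tensor([0, 0, 1, 1, 1, 3])
    print(segment_coo_reference(src, index, 4, "mean"))   # tensor([1.5000, 4.0000, 0.0000, 6.0000])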
- auto sizes = index.sizes().vec(); - for (int i = 0; i < index.dim(); i++) { - sizes[i] = src.size(i); - } - index = index.expand(sizes); - - src = src.contiguous(); - out = out.contiguous(); - auto reduce_dim = index.dim() - 1; - - for (int i = 0; i < out.dim(); i++) - if (i != reduce_dim) - AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch"); - - torch::optional arg_out = torch::nullopt; - int64_t *arg_out_data = nullptr; - if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { - arg_out = torch::full_like(out, src.size(reduce_dim), index.options()); - arg_out_data = arg_out.value().DATA_PTR(); - } - - auto E = index.numel(); - auto E_2 = index.size(reduce_dim); - auto E_1 = index.numel() / E_2; - auto K = src.numel() / E; - auto N = out.size(reduce_dim); - auto avg_len = (float)E_2 / (float)N; - - auto index_info = at::cuda::detail::getTensorInfo(index); - auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] { - auto src_data = src.DATA_PTR(); - auto out_data = out.DATA_PTR(); - - AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { - if (K == 1) { - segment_coo_kernel - <<>>(src_data, index_info, - out_data, E, N); - } else if (avg_len <= 8) { - segment_coo_broadcast_kernel - <<>>(src_data, index_info, out_data, E, K, - N); - } else if (avg_len <= 16) { - segment_coo_broadcast_kernel - <<>>(src_data, index_info, out_data, E, K, - N); - } else if (avg_len <= 32) { - segment_coo_broadcast_kernel - <<>>(src_data, index_info, out_data, E, K, - N); - } else { - segment_coo_broadcast_kernel - <<>>(src_data, index_info, out_data, E, K, - N); - } - - if (REDUCE == MIN || REDUCE == MAX) { - if (K == 1) { - segment_coo_arg_kernel - <<>>( - src_data, index_info, out_data, arg_out_data, E, N); - } else { - segment_coo_arg_broadcast_kernel - <<>>( - src_data, index_info, out_data, arg_out_data, E, K, N); - } - } - }); - }); - - if (reduce2REDUCE.at(reduce) == MEAN) { - auto sizes = index.sizes().vec(); - sizes[reduce_dim] = out.size(reduce_dim); - auto count = torch::zeros(sizes, out.options()); - - AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] { - auto count_data = count.DATA_PTR(); - segment_coo_kernel - <<>>(nullptr, index_info, - count_data, E, N); - }); - - count.clamp_(1); - arg_out = count; - - for (int i = reduce_dim + 1; i < out.dim(); i++) { - count = count.unsqueeze(-1); - } - - out.div_(count); - } - - return std::make_tuple(out, arg_out); -} diff --git a/test/test_jit.py b/test/test_jit.py deleted file mode 100644 index 74f9d815..00000000 --- a/test/test_jit.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Optional - -import torch -import torch_scatter - - -@torch.jit.script -def segment_csr(src: torch.Tensor, indptr: torch.Tensor, - out: Optional[torch.Tensor] = None, reduce: str = "sum"): - return torch.ops.torch_scatter_cpu.segment_sum_csr(src, indptr, out) - - -def test_jit(): - # op = torch.ops.torch_scatter_cpu.segment_sum_csr - - src = torch.randn(8, 4) - src.requires_grad_() - indptr = torch.tensor([0, 2, 4, 6, 8]) - - out = segment_csr(src, indptr) - print(out) - - print(src.grad) - out.backward(torch.randn_like(out)) - print(src.grad) - - # op = torch.ops.torch_scatter_cpu.segment_csr - # out = op(src, indptr, None, "sum") - # print(out) - - # traced_cell = torch.jit.script(op) diff --git a/test/test_scatter.py b/test/test_scatter.py index aeaa3a66..1168f0ad 100644 --- a/test/test_scatter.py +++ b/test/test_scatter.py @@ -7,8 +7,6 @@ from .utils import tensor, dtypes, devices 
-devices = ['cpu'] - reductions = ['sum', 'add', 'mean', 'min', 'max'] tests = [ From 7fd9091c4c432100a702cf011a25c45d69bed287 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 30 Jan 2020 15:09:29 +0100 Subject: [PATCH 08/12] update code and tests --- .coveragerc | 7 +- test/composite/test_logsumexp.py | 18 ++ test/composite/test_softmax.py | 53 ++--- test/composite/test_std.py | 12 + test/test_backward.py | 54 ----- test/test_broadcasting.py | 10 - test/test_forward.py | 127 ----------- test/test_logsumexp.py | 24 -- test/test_max_min.py | 22 -- test/test_multi_gpu.py | 41 +++- test/test_scatter.py | 4 +- test/test_segment.py | 4 +- test/test_std.py | 30 --- test/utils.py | 2 + torch_scatter/__init__.py | 6 + torch_scatter/add.py | 75 ------- torch_scatter/composite/__init__.py | 4 + torch_scatter/composite/logsumexp.py | 40 ++++ torch_scatter/composite/softmax.py | 85 ++----- torch_scatter/composite/std.py | 41 ++++ torch_scatter/composite/utils.py | 14 ++ torch_scatter/div.py | 94 -------- torch_scatter/gather.py | 67 ------ torch_scatter/helpers.py | 15 -- torch_scatter/logsumexp.py | 54 ----- torch_scatter/max.py | 109 --------- torch_scatter/mean.py | 70 ------ torch_scatter/min.py | 112 ---------- torch_scatter/mul.py | 93 -------- torch_scatter/scatter.py | 69 ++++++ torch_scatter/segment.py | 319 --------------------------- torch_scatter/segment_coo.py | 88 ++++++++ torch_scatter/segment_csr.py | 67 ++++++ torch_scatter/std.py | 65 ------ torch_scatter/sub.py | 64 ------ torch_scatter/utils/__init__.py | 0 torch_scatter/utils/gen.py | 54 ----- 37 files changed, 438 insertions(+), 1575 deletions(-) create mode 100644 test/composite/test_logsumexp.py create mode 100644 test/composite/test_std.py delete mode 100644 test/test_backward.py delete mode 100644 test/test_forward.py delete mode 100644 test/test_logsumexp.py delete mode 100644 test/test_max_min.py delete mode 100644 test/test_std.py delete mode 100644 torch_scatter/add.py create mode 100644 torch_scatter/composite/logsumexp.py create mode 100644 torch_scatter/composite/std.py create mode 100644 torch_scatter/composite/utils.py delete mode 100644 torch_scatter/div.py delete mode 100644 torch_scatter/gather.py delete mode 100644 torch_scatter/helpers.py delete mode 100644 torch_scatter/logsumexp.py delete mode 100644 torch_scatter/max.py delete mode 100644 torch_scatter/mean.py delete mode 100644 torch_scatter/min.py delete mode 100644 torch_scatter/mul.py delete mode 100644 torch_scatter/segment.py delete mode 100644 torch_scatter/std.py delete mode 100644 torch_scatter/sub.py delete mode 100644 torch_scatter/utils/__init__.py delete mode 100644 torch_scatter/utils/gen.py diff --git a/.coveragerc b/.coveragerc index b93d9733..b1b88d2e 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,10 +3,5 @@ source=torch_scatter [report] exclude_lines = pragma: no cover - cuda - forward - backward - apply + torch.jit.script raise - min_value - max_value diff --git a/test/composite/test_logsumexp.py b/test/composite/test_logsumexp.py new file mode 100644 index 00000000..fd97fa04 --- /dev/null +++ b/test/composite/test_logsumexp.py @@ -0,0 +1,18 @@ +import torch +from torch_scatter import scatter_logsumexp + + +def test_logsumexp(): + src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, -100]) + index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4]) + + out = scatter_logsumexp(src, index) + + out0 = torch.logsumexp(torch.tensor([0.5, 0.5]), dim=-1) + out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2]), dim=-1) + out2 = torch.logsumexp(torch.tensor(7, 
dtype=torch.float), dim=-1) + out3 = torch.logsumexp(torch.tensor([], dtype=torch.float), dim=-1) + out4 = torch.tensor(-1, dtype=torch.float) + + expected = torch.stack([out0, out1, out2, out3, out4], dim=0) + assert torch.allclose(out, expected) diff --git a/test/composite/test_softmax.py b/test/composite/test_softmax.py index c13ab97d..25f1c3f1 100644 --- a/test/composite/test_softmax.py +++ b/test/composite/test_softmax.py @@ -1,57 +1,38 @@ -from itertools import product - -import pytest import torch -from torch_scatter.composite import scatter_log_softmax, scatter_softmax - -from test.utils import devices, tensor, grad_dtypes +from torch_scatter import scatter_log_softmax, scatter_softmax -@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) -def test_softmax(dtype, device): - src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) - index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) +def test_softmax(): + src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')]) + index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4]) out = scatter_softmax(src, index) - out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1) - out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1) - out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1) - out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype), - dim=-1) + out0 = torch.softmax(torch.tensor([0.2, 0.2]), dim=-1) + out1 = torch.softmax(torch.tensor([0, -2.1, 3.2]), dim=-1) + out2 = torch.softmax(torch.tensor([7], dtype=torch.float), dim=-1) + out4 = torch.softmax(torch.tensor([-1, float('-inf')]), dim=-1) expected = torch.stack([ out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] - ], dim=0).to(device) + ], dim=0) assert torch.allclose(out, expected) -@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) -def test_softmax_broadcasting(dtype, device): - src = torch.randn(10, 5, dtype=dtype, device=device) - index = tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device) - - out = scatter_softmax(src, index, dim=0).view(5, 2, 5) - out = out.sum(dim=1) - assert torch.allclose(out, torch.ones_like(out)) - - -@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) -def test_log_softmax(dtype, device): - src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) - index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) +def test_log_softmax(): + src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')]) + index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4]) out = scatter_log_softmax(src, index) - out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1) - out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1) - out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1) - out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype), - dim=-1) + out0 = torch.log_softmax(torch.tensor([0.2, 0.2]), dim=-1) + out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2]), dim=-1) + out2 = torch.log_softmax(torch.tensor([7], dtype=torch.float), dim=-1) + out4 = torch.log_softmax(torch.tensor([-1, float('-inf')]), dim=-1) expected = torch.stack([ out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] - ], dim=0).to(device) + ], dim=0) assert torch.allclose(out, expected) diff --git a/test/composite/test_std.py b/test/composite/test_std.py new file mode 100644 index 00000000..3b3c7ad5 --- /dev/null +++ b/test/composite/test_std.py @@ 
-0,0 +1,12 @@ +import torch +from torch_scatter import scatter_std + + +def test_std(): + src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=torch.float) + index = torch.tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.long) + + out = scatter_std(src, index, dim=-1, unbiased=True) + std = src.std(dim=-1, unbiased=True)[0] + expected = torch.tensor([[std, 0], [0, std]]) + assert torch.allclose(out, expected) diff --git a/test/test_backward.py b/test/test_backward.py deleted file mode 100644 index d9e03fa8..00000000 --- a/test/test_backward.py +++ /dev/null @@ -1,54 +0,0 @@ -from itertools import product - -import pytest -import torch -from torch.autograd import gradcheck -import torch_scatter - -from .utils import grad_dtypes as dtypes, devices, tensor - -funcs = ['add', 'sub', 'mul', 'div', 'mean'] -indices = [2, 0, 1, 1, 0] - - -@pytest.mark.parametrize('func,device', product(funcs, devices)) -def test_backward(func, device): - index = torch.tensor(indices, dtype=torch.long, device=device) - src = torch.rand((index.size(0), 2), dtype=torch.double, device=device) - src.requires_grad_() - - op = getattr(torch_scatter, 'scatter_{}'.format(func)) - data = (src, index, 0) - assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True - - -tests = [{ - 'name': 'max', - 'src': [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]], - 'index': [2, 0, 1, 1, 0], - 'dim': 0, - 'fill_value': 0, - 'grad': [[4, 4], [8, 8], [6, 6]], - 'expected': [[6, 6], [0, 0], [0, 0], [8, 8], [4, 4]], -}, { - 'name': 'min', - 'src': [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]], - 'index': [2, 0, 1, 1, 0], - 'dim': 0, - 'fill_value': 3, - 'grad': [[4, 4], [8, 8], [6, 6]], - 'expected': [[6, 6], [4, 4], [8, 8], [0, 0], [0, 0]], -}] - - -@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) -def test_arg_backward(test, dtype, device): - src = tensor(test['src'], dtype, device) - src.requires_grad_() - index = tensor(test['index'], torch.long, device) - grad = tensor(test['grad'], dtype, device) - - op = getattr(torch_scatter, 'scatter_{}'.format(test['name'])) - out, _ = op(src, index, test['dim'], fill_value=test['fill_value']) - out.backward(grad) - assert src.grad.tolist() == test['expected'] diff --git a/test/test_broadcasting.py b/test/test_broadcasting.py index 2b68e4a7..efe36103 100644 --- a/test/test_broadcasting.py +++ b/test/test_broadcasting.py @@ -14,16 +14,6 @@ def test_broadcasting(device): out = scatter_add(src, index, dim=2, dim_size=H) assert out.size() == (B, C, H, W) - src = torch.randn((B, 1, H, W), device=device) - index = torch.randint(0, H, (B, C, H, W)).to(device, torch.long) - out = scatter_add(src, index, dim=2, dim_size=H) - assert out.size() == (B, C, H, W) - - src = torch.randn((B, 1, H, W), device=device) - index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long) - out = scatter_add(src, index, dim=2, dim_size=H) - assert out.size() == (B, 1, H, W) - src = torch.randn((B, C, H, W), device=device) index = torch.randint(0, H, (H, )).to(device, torch.long) out = scatter_add(src, index, dim=2, dim_size=H) diff --git a/test/test_forward.py b/test/test_forward.py deleted file mode 100644 index 03feab1f..00000000 --- a/test/test_forward.py +++ /dev/null @@ -1,127 +0,0 @@ -from itertools import product - -import pytest -import torch -import torch_scatter - -from .utils import dtypes, devices, tensor - -tests = [{ - 'name': 'add', - 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], - 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], - 'dim': -1, - 'fill_value': 0, - 'expected': [[0, 0, 
4, 3, 3, 0], [2, 4, 4, 0, 0, 0]], -}, { - 'name': 'add', - 'src': [[5, 2], [2, 5], [4, 3], [1, 3]], - 'index': [0, 1, 1, 0], - 'dim': 0, - 'fill_value': 0, - 'expected': [[6, 5], [6, 8]], -}, { - 'name': 'sub', - 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], - 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], - 'dim': -1, - 'fill_value': 9, - 'expected': [[9, 9, 5, 6, 6, 9], [7, 5, 5, 9, 9, 9]], -}, { - 'name': 'sub', - 'src': [[5, 2], [2, 2], [4, 2], [1, 3]], - 'index': [0, 1, 1, 0], - 'dim': 0, - 'fill_value': 9, - 'expected': [[3, 4], [3, 5]], -}, { - 'name': 'mul', - 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], - 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], - 'dim': -1, - 'fill_value': 1, - 'expected': [[1, 1, 4, 3, 2, 0], [0, 4, 3, 1, 1, 1]], -}, { - 'name': 'mul', - 'src': [[5, 2], [2, 5], [4, 3], [1, 3]], - 'index': [0, 1, 1, 0], - 'dim': 0, - 'fill_value': 1, - 'expected': [[5, 6], [8, 15]], -}, { - 'name': 'div', - 'src': [[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]], - 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], - 'dim': -1, - 'fill_value': 1, - 'expected': [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]], -}, { - 'name': 'div', - 'src': [[4, 2], [2, 1], [4, 2], [1, 2]], - 'index': [0, 1, 1, 0], - 'dim': 0, - 'fill_value': 1, - 'expected': [[0.25, 0.25], [0.125, 0.5]], -}, { - 'name': 'mean', - 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], - 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], - 'dim': -1, - 'fill_value': 0, - 'expected': [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]], -}, { - 'name': 'mean', - 'src': [[5, 2], [2, 5], [4, 3], [1, 3]], - 'index': [0, 1, 1, 0], - 'dim': 0, - 'fill_value': 0, - 'expected': [[3, 2.5], [3, 4]], -}, { - 'name': 'max', - 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], - 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], - 'dim': -1, - 'fill_value': 0, - 'expected': [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]], - 'expected_arg': [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]], -}, { - 'name': 'max', - 'src': [[5, 2], [2, 5], [4, 3], [1, 3]], - 'index': [0, 1, 1, 0], - 'dim': 0, - 'fill_value': 0, - 'expected': [[5, 3], [4, 5]], - 'expected_arg': [[0, 3], [2, 1]], -}, { - 'name': 'min', - 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], - 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], - 'dim': -1, - 'fill_value': 9, - 'expected': [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]], - 'expected_arg': [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]], -}, { - 'name': 'min', - 'src': [[5, 2], [2, 5], [4, 3], [1, 3]], - 'index': [0, 1, 1, 0], - 'dim': 0, - 'fill_value': 9, - 'expected': [[1, 2], [2, 3]], - 'expected_arg': [[3, 0], [1, 2]], -}] - - -@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) -def test_forward(test, dtype, device): - src = tensor(test['src'], dtype, device) - index = tensor(test['index'], torch.long, device) - expected = tensor(test['expected'], dtype, device) - - op = getattr(torch_scatter, 'scatter_{}'.format(test['name'])) - out = op(src, index, test['dim'], fill_value=test['fill_value']) - - if isinstance(out, tuple): - assert out[0].tolist() == expected.tolist() - assert out[1].tolist() == test['expected_arg'] - else: - assert out.tolist() == expected.tolist() diff --git a/test/test_logsumexp.py b/test/test_logsumexp.py deleted file mode 100644 index 1a03adac..00000000 --- a/test/test_logsumexp.py +++ /dev/null @@ -1,24 +0,0 @@ -from itertools import product - -import torch -import pytest -from torch_scatter import scatter_logsumexp - -from .utils import devices, tensor, grad_dtypes - - -@pytest.mark.parametrize('dtype,device', product(grad_dtypes, 
devices)) -def test_logsumexp(dtype, device): - src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) - index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) - - out = scatter_logsumexp(src, index) - - out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1) - out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1) - out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1) - out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype) - out4 = torch.tensor(-1, dtype=dtype) - - expected = torch.stack([out0, out1, out2, out3, out4], dim=0).to(device) - assert torch.allclose(out, expected) diff --git a/test/test_max_min.py b/test/test_max_min.py deleted file mode 100644 index 6086f4a1..00000000 --- a/test/test_max_min.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -from torch_scatter import scatter_max, scatter_min - - -def test_max_fill_value(): - src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) - index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) - - out, _ = scatter_max(src, index) - - v = torch.finfo(torch.float).min - assert out.tolist() == [[v, v, 4, 3, 2, 0], [2, 4, 3, v, v, v]] - - -def test_min_fill_value(): - src = torch.Tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]]) - index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) - - out, _ = scatter_min(src, index) - - v = torch.finfo(torch.float).max - assert out.tolist() == [[v, v, -4, -3, -2, 0], [-2, -4, -3, v, v, v]] diff --git a/test/test_multi_gpu.py b/test/test_multi_gpu.py index dfb54655..cdaf893e 100644 --- a/test/test_multi_gpu.py +++ b/test/test_multi_gpu.py @@ -1,12 +1,43 @@ +from itertools import product + import pytest import torch -from torch_scatter import scatter_max +import torch_scatter + +from .utils import reductions, tensor, dtypes + +tests = [ + { + 'src': [1, 2, 3, 4, 5, 6], + 'index': [0, 0, 1, 1, 1, 3], + 'indptr': [0, 2, 5, 5, 6], + 'dim': 0, + 'sum': [3, 12, 0, 6], + 'add': [3, 12, 0, 6], + 'mean': [1.5, 4, 0, 6], + 'min': [1, 3, 0, 6], + 'max': [2, 5, 0, 6], + }, +] @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available') @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUS') -def test_multi_gpu(): +@pytest.mark.parametrize('test,reduce,dtype', product(tests, reductions, + dtypes)) +def test_forward(test, reduce, dtype): device = torch.device('cuda:1') - src = torch.tensor([2.0, 3.0, 4.0, 5.0], device=device) - index = torch.tensor([0, 0, 1, 1], device=device) - assert scatter_max(src, index)[0].tolist() == [3, 5] + src = tensor(test['src'], dtype, device) + index = tensor(test['index'], torch.long, device) + indptr = tensor(test['indptr'], torch.long, device) + dim = test['dim'] + expected = tensor(test[reduce], dtype, device) + + out = torch_scatter.scatter(src, index, dim, reduce=reduce) + assert torch.all(out == expected) + + out = torch_scatter.segment_coo(src, index, reduce=reduce) + assert torch.all(out == expected) + + out = torch_scatter.segment_csr(src, indptr, reduce=reduce) + assert torch.all(out == expected) diff --git a/test/test_scatter.py b/test/test_scatter.py index 1168f0ad..8e43844b 100644 --- a/test/test_scatter.py +++ b/test/test_scatter.py @@ -5,9 +5,7 @@ from torch.autograd import gradcheck import torch_scatter -from .utils import tensor, dtypes, devices - -reductions = ['sum', 'add', 'mean', 'min', 'max'] +from .utils import reductions, tensor, dtypes, devices tests = [ { diff --git a/test/test_segment.py b/test/test_segment.py index e93cac3f..7b6bb39e 100644 --- 
a/test/test_segment.py +++ b/test/test_segment.py @@ -5,9 +5,7 @@ from torch.autograd import gradcheck import torch_scatter -from .utils import tensor, dtypes, devices - -reductions = ['sum', 'add', 'mean', 'min', 'max'] +from .utils import reductions, tensor, dtypes, devices tests = [ { diff --git a/test/test_std.py b/test/test_std.py deleted file mode 100644 index 95eff085..00000000 --- a/test/test_std.py +++ /dev/null @@ -1,30 +0,0 @@ -from itertools import product - -import pytest -import torch -from torch_scatter import scatter_std - -from .utils import grad_dtypes as dtypes, devices, tensor - -biases = [True, False] - - -@pytest.mark.parametrize('dtype,device,bias', product(dtypes, devices, biases)) -def test_std(dtype, device, bias): - src = tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype, device) - index = tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], torch.long, device) - - out = scatter_std(src, index, dim=-1, unbiased=bias) - std = src.std(dim=-1, unbiased=bias)[0].item() - expected = tensor([[std, 0], [0, std]], dtype, device) - assert torch.allclose(out, expected) - - -@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) -def test_empty_std(dtype, device): - out = torch.zeros(1, 5, dtype=dtype, device=device) - src = tensor([], dtype, device).view(0, 5) - index = tensor([], torch.long, device).view(0, 5) - - out = scatter_std(src, index, dim=0, out=out) - assert out.tolist() == [[0, 0, 0, 0, 0]] diff --git a/test/utils.py b/test/utils.py index 9bb241bb..1eb352b2 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1,5 +1,7 @@ import torch +reductions = ['sum', 'add', 'mean', 'min', 'max'] + dtypes = [torch.float, torch.double, torch.int, torch.long] grad_dtypes = [torch.float, torch.double] diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index 29e53596..0856e84d 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -6,6 +6,8 @@ from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo, segment_min_coo, segment_max_coo, segment_coo, gather_coo) +from .composite import (scatter_std, scatter_logsumexp, scatter_softmax, + scatter_log_softmax) __version__ = '2.0.0' @@ -30,6 +32,10 @@ 'segment_max_coo', 'segment_coo', 'gather_coo', + 'scatter_std', + 'scatter_logsumexp', + 'scatter_softmax', + 'scatter_log_softmax', 'torch_scatter', '__version__', ] diff --git a/torch_scatter/add.py b/torch_scatter/add.py deleted file mode 100644 index 51a0e3db..00000000 --- a/torch_scatter/add.py +++ /dev/null @@ -1,75 +0,0 @@ -from torch_scatter.utils.gen import gen - - -def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0): - r""" - | - - .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ - master/docs/source/_figures/add.svg?sanitize=true - :align: center - :width: 400px - - | - - Sums all values from the :attr:`src` tensor into :attr:`out` at the indices - specified in the :attr:`index` tensor along a given axis :attr:`dim`. For - each value in :attr:`src`, its output index is specified by its index in - :attr:`src` for dimensions outside of :attr:`dim` and by the - corresponding value in :attr:`index` for dimension :attr:`dim`. If - multiple indices reference the same location, their **contributions add**. - - Formally, if :attr:`src` and :attr:`index` are n-dimensional tensors with - size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and - :attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with - size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. 
Moreover, the - values of :attr:`index` must be between `0` and `out.size(dim) - 1`. - Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions - do not match. - - For one-dimensional tensors, the operation computes - - .. math:: - \mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j - - where :math:`\sum_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. (default: :obj:`0`) - - :rtype: :class:`Tensor` - - .. testsetup:: - - import torch - - .. testcode:: - - from torch_scatter import scatter_add - - src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) - index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) - out = src.new_zeros((2, 6)) - - out = scatter_add(src, index, out=out) - - print(out) - - .. testoutput:: - - tensor([[0., 0., 4., 3., 3., 0.], - [2., 4., 4., 0., 0., 0.]]) - """ - src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) - return out.scatter_add_(dim, index, src) diff --git a/torch_scatter/composite/__init__.py b/torch_scatter/composite/__init__.py index 74cdb7db..2acc5081 100644 --- a/torch_scatter/composite/__init__.py +++ b/torch_scatter/composite/__init__.py @@ -1,6 +1,10 @@ +from .std import scatter_std +from .logsumexp import scatter_logsumexp from .softmax import scatter_log_softmax, scatter_softmax __all__ = [ + 'scatter_std', + 'scatter_logsumexp', 'scatter_softmax', 'scatter_log_softmax', ] diff --git a/torch_scatter/composite/logsumexp.py b/torch_scatter/composite/logsumexp.py new file mode 100644 index 00000000..35d01460 --- /dev/null +++ b/torch_scatter/composite/logsumexp.py @@ -0,0 +1,40 @@ +from typing import Optional + +import torch +from torch_scatter import scatter_sum, scatter_max + +from .utils import broadcast + + +@torch.jit.script +def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + eps: float = 1e-12) -> torch.Tensor: + if not torch.is_floating_point(src): + raise ValueError('`scatter_logsumexp` can only be computed over ' + 'tensors with floating point data types.') + + index = broadcast(index, src, dim) + + if out is not None: + dim_size = out.size(dim) + else: + if dim_size is None: + dim_size = int(index.max().item() + 1) + + size = src.size() + size[dim] = dim_size + max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype, + device=src.device) + scatter_max(src, index, dim, max_value_per_index, dim_size)[0] + max_per_src_element = max_value_per_index.gather(dim, index) + recentered_scores = src - max_per_src_element + + if out is not None: + out = out.sub_(max_per_src_element).exp_() + + sum_per_index = scatter_sum(recentered_scores.exp_(), index, dim, out, + dim_size) + + return sum_per_index.add_(eps).log_().add_(max_value_per_index) diff --git a/torch_scatter/composite/softmax.py b/torch_scatter/composite/softmax.py index 7963b7da..0a05139f 100644 --- 
a/torch_scatter/composite/softmax.py +++ b/torch_scatter/composite/softmax.py @@ -1,89 +1,46 @@ import torch -from torch_scatter import scatter_add, scatter_max -from torch_scatter.utils.gen import broadcast +from torch_scatter import scatter_sum, scatter_max +from .utils import broadcast -def scatter_softmax(src, index, dim=-1, eps=1e-12): - r""" - Softmax operation over all values in :attr:`src` tensor that share indices - specified in the :attr:`index` tensor along a given axis :attr:`dim`. - For one-dimensional tensors, the operation computes - - .. math:: - \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i = - \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)} - - where :math:`\sum_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - eps (float, optional): Small value to ensure numerical stability. - (default: :obj:`1e-12`) - - :rtype: :class:`Tensor` - """ +@torch.jit.script +def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + eps: float = 1e-12) -> torch.Tensor: if not torch.is_floating_point(src): raise ValueError('`scatter_softmax` can only be computed over tensors ' 'with floating point data types.') - src, index = broadcast(src, index, dim) - max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0) + index = broadcast(index, src, dim) + + max_value_per_index = scatter_max(src, index, dim=dim)[0] max_per_src_element = max_value_per_index.gather(dim, index) recentered_scores = src - max_per_src_element - recentered_scores_exp = recentered_scores.exp() - - sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim) - normalizing_constants = (sum_per_index + eps).gather(dim, index) + recentered_scores_exp = recentered_scores.exp_() - return recentered_scores_exp / normalizing_constants + sum_per_index = scatter_sum(recentered_scores_exp, index, dim) + normalizing_constants = sum_per_index.add_(eps).gather(dim, index) + return recentered_scores_exp.div_(normalizing_constants) -def scatter_log_softmax(src, index, dim=-1, eps=1e-12): - r""" - Log-softmax operation over all values in :attr:`src` tensor that share - indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`. - For one-dimensional tensors, the operation computes - - .. math:: - \mathrm{out}_i = {\textrm{log_softmax}(\mathrm{src})}_i = - \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)} - \right) - - where :math:`\sum_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - eps (float, optional): Small value to ensure numerical stability. 
- (default: :obj:`1e-12`) - - :rtype: :class:`Tensor` - """ +@torch.jit.script +def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + eps: float = 1e-12) -> torch.Tensor: if not torch.is_floating_point(src): raise ValueError('`scatter_log_softmax` can only be computed over ' 'tensors with floating point data types.') - src, index = broadcast(src, index, dim) - max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0) + index = broadcast(index, src, dim) + + max_value_per_index = scatter_max(src, index, dim=dim)[0] max_per_src_element = max_value_per_index.gather(dim, index) recentered_scores = src - max_per_src_element - sum_per_index = scatter_add(src=recentered_scores.exp(), index=index, - dim=dim) - - normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index) + sum_per_index = scatter_sum(recentered_scores.exp(), index, dim) + normalizing_constants = sum_per_index.add_(eps).log_().gather(dim, index) - return recentered_scores - normalizing_constants + return recentered_scores.sub_(normalizing_constants) diff --git a/torch_scatter/composite/std.py b/torch_scatter/composite/std.py new file mode 100644 index 00000000..c646b222 --- /dev/null +++ b/torch_scatter/composite/std.py @@ -0,0 +1,41 @@ +from typing import Optional + +import torch +from torch_scatter import scatter_sum + +from .utils import broadcast + + +@torch.jit.script +def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + unbiased: bool = True) -> torch.Tensor: + + if out is not None: + dim_size = out.size(dim) + + if dim < 0: + dim = src.dim() + dim + + count_dim = dim + if index.dim() <= dim: + count_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, count_dim, dim_size=dim_size) + + index = broadcast(index, src, dim) + tmp = scatter_sum(src, index, dim, dim_size=dim_size) + count = broadcast(count, tmp, dim).clamp_(1) + mean = tmp.div_(count) + + var = (src - mean.gather(dim, index)) + var = var * var + out = scatter_sum(var, index, dim, out, dim_size) + + if unbiased: + count.sub_(1).clamp_(1) + out.div_(count).sqrt_() + + return out diff --git a/torch_scatter/composite/utils.py b/torch_scatter/composite/utils.py new file mode 100644 index 00000000..de8879b2 --- /dev/null +++ b/torch_scatter/composite/utils.py @@ -0,0 +1,14 @@ +import torch + + +@torch.jit.script +def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src diff --git a/torch_scatter/div.py b/torch_scatter/div.py deleted file mode 100644 index b0897f6c..00000000 --- a/torch_scatter/div.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch -from torch_scatter.utils.gen import gen - - -class ScatterDiv(torch.autograd.Function): - @staticmethod - def forward(ctx, out, src, index, dim): - if src.is_cuda: - torch.ops.torch_scatter_cuda.scatter_div(src, index, out, dim) - else: - torch.ops.torch_scatter_cpu.scatter_div(src, index, out, dim) - - ctx.mark_dirty(out) - ctx.save_for_backward(out, src, index) - ctx.dim = dim - - return out - - @staticmethod - def backward(ctx, grad_out): - out, src, index = ctx.saved_tensors - - grad_src = None - if ctx.needs_input_grad[1]: - grad_src = -(out * grad_out).gather(ctx.dim, index) / 
src - - return None, grad_src, None, None - - -def scatter_div(src, index, dim=-1, out=None, dim_size=None, fill_value=1): - r""" - | - - .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ - master/docs/source/_figures/div.svg?sanitize=true - :align: center - :width: 400px - - | - - Divides all values from the :attr:`src` tensor into :attr:`out` at the - indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`.If multiple indices reference the same location, their - **contributions divide** (`cf.` :meth:`~torch_scatter.scatter_add`). - - For one-dimensional tensors, the operation computes - - .. math:: - \mathrm{out}_i = \mathrm{out}_i \cdot \prod_j - \frac{1}{\mathrm{src}_j} - - where :math:`\prod_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. (default: :obj:`1`) - - :rtype: :class:`Tensor` - - .. testsetup:: - - import torch - - .. testcode:: - - from torch_scatter import scatter_div - - src = torch.Tensor([[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]]).float() - index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) - out = src.new_ones((2, 6)) - - out = scatter_div(src, index, out=out) - - print(out) - - .. 
testoutput:: - - tensor([[1.0000, 1.0000, 0.2500, 0.5000, 0.5000, 1.0000], - [0.5000, 0.2500, 0.5000, 1.0000, 1.0000, 1.0000]]) - """ - src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) - if src.size(dim) == 0: # pragma: no cover - return out - return ScatterDiv.apply(out, src, index, dim) diff --git a/torch_scatter/gather.py b/torch_scatter/gather.py deleted file mode 100644 index 0dce91b2..00000000 --- a/torch_scatter/gather.py +++ /dev/null @@ -1,67 +0,0 @@ -import torch - - -class GatherCOO(torch.autograd.Function): - @staticmethod - def forward(ctx, src, index, out): - if out is not None: - ctx.mark_dirty(out) - ctx.src_size = list(src.size()) - ctx.save_for_backward(index) - - if src.is_cuda: - return torch.ops.torch_scatter_cuda.gather_coo(src, index, out) - else: - return torch.ops.torch_scatter_cpu.gather_coo(src, index, out) - - @staticmethod - def backward(ctx, grad_out): - (index, ), src_size = ctx.saved_tensors, ctx.src_size - - grad_src = None - if ctx.needs_input_grad[0]: - if grad_out.is_cuda: - grad_src, _ = torch.ops.torch_scatter_cuda.segment_coo( - grad_out, index, grad_out.new_zeros(src_size), 'sum') - else: - grad_src, _ = torch.ops.torch_scatter_cpu.segment_coo( - grad_out, index, grad_out.new_zeros(src_size), 'sum') - - return grad_src, None, None - - -class GatherCSR(torch.autograd.Function): - @staticmethod - def forward(ctx, src, indptr, out): - if out is not None: - ctx.mark_dirty(out) - ctx.src_size = list(src.size()) - ctx.save_for_backward(indptr) - - if src.is_cuda: - return torch.ops.torch_scatter_cuda.gather_csr(src, indptr, out) - else: - return torch.ops.torch_scatter_cpu.gather_csr(src, indptr, out) - - @staticmethod - def backward(ctx, grad_out): - (indptr, ), src_size = ctx.saved_tensors, ctx.src_size - - grad_src = None - if ctx.needs_input_grad[0]: - if grad_out.is_cuda: - grad_src, _ = torch.ops.torch_scatter_cuda.segment_csr( - grad_out, indptr, grad_out.new_empty(src_size), 'sum') - else: - grad_src, _ = torch.ops.torch_scatter_cpu.segment_csr( - grad_out, indptr, grad_out.new_empty(src_size), 'sum') - - return grad_src, None, None - - -def gather_coo(src, index, out=None): - return GatherCOO.apply(src, index, out) - - -def gather_csr(src, indptr, out=None): - return GatherCSR.apply(src, indptr, out) diff --git a/torch_scatter/helpers.py b/torch_scatter/helpers.py deleted file mode 100644 index dbb6f2a9..00000000 --- a/torch_scatter/helpers.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch - - -def min_value(dtype): # pragma: no cover - try: - return torch.finfo(dtype).min - except TypeError: - return torch.iinfo(dtype).min - - -def max_value(dtype): # pragma: no cover - try: - return torch.finfo(dtype).max - except TypeError: - return torch.iinfo(dtype).max diff --git a/torch_scatter/logsumexp.py b/torch_scatter/logsumexp.py deleted file mode 100644 index 16e9d182..00000000 --- a/torch_scatter/logsumexp.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch - -from . import scatter_add, scatter_max - - -def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, - fill_value=None, eps=1e-12): - r"""Fills :attr:`out` with the log of summed exponentials of all values - from the :attr:`src` tensor at the indices specified in the :attr:`index` - tensor along a given axis :attr:`dim`. - If multiple indices reference the same location, their - **exponential contributions add** - (`cf.` :meth:`~torch_scatter.scatter_add`). - - For one-dimensional tensors, the operation computes - - .. 
math:: - \mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j - \exp(\mathrm{src}_j) \right) - - where :math:`\sum_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. (default: :obj:`None`) - eps (float, optional): Small value to ensure numerical stability. - (default: :obj:`1e-12`) - - :rtype: :class:`Tensor` - """ - if not torch.is_floating_point(src): - raise ValueError('`scatter_logsumexp` can only be computed over ' - 'tensors with floating point data types.') - - max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size, - fill_value) - max_per_src_element = max_value_per_index.gather(dim, index) - recentered_scores = src - max_per_src_element - out = (out - max_per_src_element).exp() if out is not None else None - - sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out, - dim_size, fill_value=0) - - return torch.log(sum_per_index + eps) + max_value_per_index diff --git a/torch_scatter/max.py b/torch_scatter/max.py deleted file mode 100644 index 83a7f436..00000000 --- a/torch_scatter/max.py +++ /dev/null @@ -1,109 +0,0 @@ -import torch -from torch_scatter.utils.gen import gen - - -class ScatterMax(torch.autograd.Function): - @staticmethod - def forward(ctx, out, src, index, dim): - arg = index.new_full(out.size(), -1) - - if src.is_cuda: - torch.ops.torch_scatter_cuda.scatter_max(src, index, out, arg, dim) - else: - torch.ops.torch_scatter_cpu.scatter_max(src, index, out, arg, dim) - - ctx.mark_dirty(out) - ctx.dim = dim - ctx.save_for_backward(index, arg) - - return out, arg - - @staticmethod - def backward(ctx, grad_out, grad_arg): - index, arg = ctx.saved_tensors - - grad_src = None - if ctx.needs_input_grad[1]: - size = list(index.size()) - size[ctx.dim] += 1 - grad_src = grad_out.new_zeros(size) - grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out) - grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim)) - - return None, grad_src, None, None - - -def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None): - r""" - | - - .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ - master/docs/source/_figures/max.svg?sanitize=true - :align: center - :width: 400px - - | - - Maximizes all values from the :attr:`src` tensor into :attr:`out` at the - indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`.If multiple indices reference the same location, their - **contributions maximize** (`cf.` :meth:`~torch_scatter.scatter_add`). - The second return tensor contains index location in :attr:`src` of each - maximum value (known as argmax). - - For one-dimensional tensors, the operation computes - - .. math:: - \mathrm{out}_i = \max(\mathrm{out}_i, \max_j(\mathrm{src}_j)) - - where :math:`\max_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. 
- dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. If set to :obj:`None`, - the output tensor is filled with the smallest possible value of - :obj:`src.dtype`. (default: :obj:`None`) - - :rtype: (:class:`Tensor`, :class:`LongTensor`) - - .. testsetup:: - - import torch - - .. testcode:: - - from torch_scatter import scatter_max - - src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) - index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) - out = src.new_zeros((2, 6)) - - out, argmax = scatter_max(src, index, out=out) - - print(out) - print(argmax) - - .. testoutput:: - - tensor([[0., 0., 4., 3., 2., 0.], - [2., 4., 3., 0., 0., 0.]]) - tensor([[-1, -1, 3, 4, 0, 1], - [ 1, 4, 3, -1, -1, -1]]) - """ - if fill_value is None: - op = torch.finfo if torch.is_floating_point(src) else torch.iinfo - fill_value = op(src.dtype).min - src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) - if src.size(dim) == 0: # pragma: no cover - return out, index.new_full(out.size(), -1) - return ScatterMax.apply(out, src, index, dim) diff --git a/torch_scatter/mean.py b/torch_scatter/mean.py deleted file mode 100644 index 418ee2ba..00000000 --- a/torch_scatter/mean.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch - -from torch_scatter import scatter_add - - -def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0): - r""" - | - - .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ - master/docs/source/_figures/mean.svg?sanitize=true - :align: center - :width: 400px - - | - - Averages all values from the :attr:`src` tensor into :attr:`out` at the - indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`.If multiple indices reference the same location, their - **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`). - - For one-dimensional tensors, the operation computes - - .. math:: - \mathrm{out}_i = \mathrm{out}_i + \frac{1}{N_i} \cdot - \sum_j \mathrm{src}_j - - where :math:`\sum_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. :math:`N_i` indicates the number of indices - referencing :math:`i`. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. (default: :obj:`0`) - - :rtype: :class:`Tensor` - - .. testsetup:: - - import torch - - .. testcode:: - - from torch_scatter import scatter_mean - - src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) - index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) - out = src.new_zeros((2, 6)) - - out = scatter_mean(src, index, out=out) - - print(out) - - .. 
testoutput:: - - tensor([[0.0000, 0.0000, 4.0000, 3.0000, 1.5000, 0.0000], - [1.0000, 4.0000, 2.0000, 0.0000, 0.0000, 0.0000]]) - """ - out = scatter_add(src, index, dim, out, dim_size, fill_value) - count = scatter_add(torch.ones_like(src), index, dim, None, out.size(dim)) - return out / count.clamp(min=1) diff --git a/torch_scatter/min.py b/torch_scatter/min.py deleted file mode 100644 index 6092a427..00000000 --- a/torch_scatter/min.py +++ /dev/null @@ -1,112 +0,0 @@ -import torch - -from torch_scatter.utils.gen import gen - - -class ScatterMin(torch.autograd.Function): - @staticmethod - def forward(ctx, out, src, index, dim): - arg = index.new_full(out.size(), -1) - - if src.is_cuda: - torch.ops.torch_scatter_cuda.scatter_min(src, index, out, arg, dim) - else: - torch.ops.torch_scatter_cpu.scatter_min(src, index, out, arg, dim) - - ctx.mark_dirty(out) - ctx.dim = dim - ctx.save_for_backward(index, arg) - - return out, arg - - @staticmethod - def backward(ctx, grad_out, grad_arg): - index, arg = ctx.saved_tensors - - grad_src = None - if ctx.needs_input_grad[1]: - size = list(index.size()) - size[ctx.dim] += 1 - grad_src = grad_out.new_zeros(size) - grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out) - grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim)) - - return None, grad_src, None, None - - -def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=None): - r""" - | - - .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ - master/docs/source/_figures/min.svg?sanitize=true - :align: center - :width: 400px - - | - - Minimizes all values from the :attr:`src` tensor into :attr:`out` at the - indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`.If multiple indices reference the same location, their - **contributions minimize** (`cf.` :meth:`~torch_scatter.scatter_add`). - The second return tensor contains index location in :attr:`src` of each - minimum value (known as argmin). - - For one-dimensional tensors, the operation computes - - .. math:: - \mathrm{out}_i = \min(\mathrm{out}_i, \min_j(\mathrm{src}_j)) - - where :math:`\min_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. If set to :obj:`None`, - the output tensor is filled with the greatest possible value of - :obj:`src.dtype`. (default: :obj:`None`) - - :rtype: (:class:`Tensor`, :class:`LongTensor`) - - .. testsetup:: - - import torch - - .. testcode:: - - from torch_scatter import scatter_min - - src = torch.Tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]]) - index = torch.tensor([[ 4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) - out = src.new_zeros((2, 6)) - - out, argmin = scatter_min(src, index, out=out) - - print(out) - print(argmin) - - .. 
testoutput:: - - tensor([[ 0., 0., -4., -3., -2., 0.], - [-2., -4., -3., 0., 0., 0.]]) - tensor([[-1, -1, 3, 4, 0, 1], - [ 1, 4, 3, -1, -1, -1]]) - """ - if fill_value is None: - op = torch.finfo if torch.is_floating_point(src) else torch.iinfo - fill_value = op(src.dtype).max - src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) - if src.size(dim) == 0: # pragma: no cover - return out, index.new_full(out.size(), -1) - return ScatterMin.apply(out, src, index, dim) diff --git a/torch_scatter/mul.py b/torch_scatter/mul.py deleted file mode 100644 index 807c8334..00000000 --- a/torch_scatter/mul.py +++ /dev/null @@ -1,93 +0,0 @@ -import torch -from torch_scatter.utils.gen import gen - - -class ScatterMul(torch.autograd.Function): - @staticmethod - def forward(ctx, out, src, index, dim): - if src.is_cuda: - torch.ops.torch_scatter_cuda.scatter_mul(src, index, out, dim) - else: - torch.ops.torch_scatter_cpu.scatter_mul(src, index, out, dim) - - ctx.mark_dirty(out) - ctx.save_for_backward(out, src, index) - ctx.dim = dim - - return out - - @staticmethod - def backward(ctx, grad_out): - out, src, index = ctx.saved_tensors - - grad_src = None - if ctx.needs_input_grad[1]: - grad_src = (grad_out * out).gather(ctx.dim, index) / src - - return None, grad_src, None, None - - -def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1): - r""" - | - - .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ - master/docs/source/_figures/mul.svg?sanitize=true - :align: center - :width: 400px - - | - - Multiplies all values from the :attr:`src` tensor into :attr:`out` at the - indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`.If multiple indices reference the same location, their - **contributions multiply** (`cf.` :meth:`~torch_scatter.scatter_add`). - - For one-dimensional tensors, the operation computes - - .. math:: - \mathrm{out}_i = \mathrm{out}_i \cdot \prod_j \mathrm{src}_j - - where :math:`\prod_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. (default: :obj:`1`) - - :rtype: :class:`Tensor` - - .. testsetup:: - - import torch - - .. testcode:: - - from torch_scatter import scatter_mul - - src = torch.Tensor([[2, 0, 3, 4, 3], [2, 3, 4, 2, 4]]) - index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) - out = src.new_ones((2, 6)) - - out = scatter_mul(src, index, out=out) - - print(out) - - .. 
testoutput:: - - tensor([[1., 1., 4., 3., 6., 0.], - [6., 4., 8., 1., 1., 1.]]) - """ - src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) - if src.size(dim) == 0: # pragma: no cover - return out - return ScatterMul.apply(out, src, index, dim) diff --git a/torch_scatter/scatter.py b/torch_scatter/scatter.py index 4e6c4e4f..aa315347 100644 --- a/torch_scatter/scatter.py +++ b/torch_scatter/scatter.py @@ -48,6 +48,75 @@ def scatter_max(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, reduce: str = "sum") -> torch.Tensor: + r""" + | + + .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ + master/docs/source/_figures/add.svg?sanitize=true + :align: center + :width: 400px + + | + + Sums all values from the :attr:`src` tensor into :attr:`out` at the indices + specified in the :attr:`index` tensor along a given axis :attr:`dim`. For + each value in :attr:`src`, its output index is specified by its index in + :attr:`src` for dimensions outside of :attr:`dim` and by the + corresponding value in :attr:`index` for dimension :attr:`dim`. If + multiple indices reference the same location, their **contributions add**. + + Formally, if :attr:`src` and :attr:`index` are n-dimensional tensors with + size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and + :attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with + size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the + values of :attr:`index` must be between `0` and `out.size(dim) - 1`. + Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions + do not match. + + For one-dimensional tensors, the operation computes + + .. math:: + \mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j + + where :math:`\sum_j` is over :math:`j` such that + :math:`\mathrm{index}_j = i`. + + Args: + src (Tensor): The source tensor. + index (LongTensor): The indices of elements to scatter. + dim (int, optional): The axis along which to index. + (default: :obj:`-1`) + out (Tensor, optional): The destination tensor. (default: :obj:`None`) + dim_size (int, optional): If :attr:`out` is not given, automatically + create output with size :attr:`dim_size` at dimension :attr:`dim`. + If :attr:`dim_size` is not given, a minimal sized output tensor is + returned. (default: :obj:`None`) + fill_value (int, optional): If :attr:`out` is not given, automatically + fill output tensor with :attr:`fill_value`. (default: :obj:`0`) + + :rtype: :class:`Tensor` + + .. testsetup:: + + import torch + + .. testcode:: + + from torch_scatter import scatter_add + + src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) + index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) + out = src.new_zeros((2, 6)) + + out = scatter_add(src, index, out=out) + + print(out) + + .. 
testoutput:: + + tensor([[0., 0., 4., 3., 3., 0.], + [2., 4., 4., 0., 0., 0.]]) + """ if reduce == 'sum' or reduce == 'add': return scatter_sum(src, index, dim, out, dim_size) elif reduce == 'mean': diff --git a/torch_scatter/segment.py b/torch_scatter/segment.py deleted file mode 100644 index 65c91b95..00000000 --- a/torch_scatter/segment.py +++ /dev/null @@ -1,319 +0,0 @@ -import torch -from torch_scatter.helpers import min_value, max_value - - -class SegmentCOO(torch.autograd.Function): - @staticmethod - def forward(ctx, src, index, out, dim_size, reduce): - assert reduce in ['sum', 'add', 'mean', 'min', 'max'] - if out is not None: - ctx.mark_dirty(out) - ctx.reduce = reduce - ctx.src_size = list(src.size()) - - fill_value = 0 - if out is None: - dim_size = index.max().item() + 1 if dim_size is None else dim_size - size = list(src.size()) - size[index.dim() - 1] = dim_size - - if reduce == 'min': - fill_value = max_value(src.dtype) - elif reduce == 'max': - fill_value = min_value(src.dtype) - - out = src.new_full(size, fill_value) - - if src.is_cuda: - out, arg_out = torch.ops.torch_scatter_cuda.segment_coo( - src, index, out, reduce) - else: - out, arg_out = torch.ops.torch_scatter_cpu.segment_coo( - src, index, out, reduce) - - if fill_value != 0: - out.masked_fill_(out == fill_value, 0) - - ctx.save_for_backward(index, arg_out) - - if reduce == 'min' or reduce == 'max': - return out, arg_out - else: - return out - - @staticmethod - def backward(ctx, grad_out, *args): - (index, arg_out), src_size = ctx.saved_tensors, ctx.src_size - - grad_src = None - if ctx.needs_input_grad[0]: - if ctx.reduce == 'sum' or ctx.reduce == 'add': - if grad_out.is_cuda: - grad_src = torch.ops.torch_scatter_cuda.gather_coo( - grad_out, index, grad_out.new_empty(src_size)) - else: - grad_src = torch.ops.torch_scatter_cpu.gather_coo( - grad_out, index, grad_out.new_empty(src_size)) - - elif ctx.reduce == 'mean': - if grad_out.is_cuda: - grad_src = torch.ops.torch_scatter_cuda.gather_coo( - grad_out, index, grad_out.new_empty(src_size)) - else: - grad_src = torch.ops.torch_scatter_cpu.gather_coo( - grad_out, index, grad_out.new_empty(src_size)) - - count = arg_out # Gets pre-computed on GPU but not on CPU. 
- if count is None: - size = list(index.size()) - size[-1] = grad_out.size(index.dim() - 1) - count = torch.ops.torch_scatter_cpu.segment_coo( - torch.ones_like(index, dtype=grad_out.dtype), index, - grad_out.new_zeros(size), 'sum')[0].clamp_(min=1) - - if grad_out.is_cuda: - count = torch.ops.torch_scatter_cuda.gather_coo( - count, index, count.new_empty(src_size[:index.dim()])) - else: - count = torch.ops.torch_scatter_cpu.gather_coo( - count, index, count.new_empty(src_size[:index.dim()])) - for _ in range(grad_out.dim() - index.dim()): - count = count.unsqueeze(-1) - grad_src.div_(count) - - elif ctx.reduce == 'min' or ctx.reduce == 'max': - src_size[index.dim() - 1] += 1 - grad_src = grad_out.new_zeros(src_size).scatter_( - index.dim() - 1, arg_out, grad_out) - grad_src = grad_src.narrow(index.dim() - 1, 0, - src_size[index.dim() - 1] - 1) - - return grad_src, None, None, None, None - - -class SegmentCSR(torch.autograd.Function): - @staticmethod - def forward(ctx, src, indptr, out, reduce): - assert reduce in ['sum', 'add', 'mean', 'min', 'max'] - - if out is not None: - ctx.mark_dirty(out) - ctx.reduce = reduce - ctx.src_size = list(src.size()) - - if src.is_cuda: - out, arg_out = torch.ops.torch_scatter_cuda.segment_csr( - src, indptr, out, reduce) - else: - out, arg_out = torch.ops.torch_scatter_cpu.segment_csr( - src, indptr, out, reduce) - - ctx.save_for_backward(indptr, arg_out) - return out if arg_out is None else (out, arg_out) - - @staticmethod - def backward(ctx, grad_out, *args): - (indptr, arg_out), src_size = ctx.saved_tensors, ctx.src_size - - grad_src = None - if ctx.needs_input_grad[0]: - if ctx.reduce == 'sum' or ctx.reduce == 'add': - if grad_out.is_cuda: - grad_src = torch.ops.torch_scatter_cuda.gather_csr( - grad_out, indptr, grad_out.new_empty(src_size)) - else: - grad_src = torch.ops.torch_scatter_cpu.gather_csr( - grad_out, indptr, grad_out.new_empty(src_size)) - - elif ctx.reduce == 'mean': - if grad_out.is_cuda: - grad_src = torch.ops.torch_scatter_cuda.gather_csr( - grad_out, indptr, grad_out.new_empty(src_size)) - else: - grad_src = torch.ops.torch_scatter_cpu.gather_csr( - grad_out, indptr, grad_out.new_empty(src_size)) - indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1) - indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1) - count = (indptr2 - indptr1).to(grad_src.dtype) - if grad_out.is_cuda: - count = torch.ops.torch_scatter_cuda.gather_csr( - count, indptr, - count.new_empty(src_size[:indptr.dim()])) - else: - count = torch.ops.torch_scatter_cpu.gather_csr( - count, indptr, - count.new_empty(src_size[:indptr.dim()])) - for _ in range(grad_out.dim() - indptr.dim()): - count = count.unsqueeze(-1) - grad_src.div_(count) - elif ctx.reduce == 'min' or ctx.reduce == 'max': - src_size[indptr.dim() - 1] += 1 - grad_src = grad_out.new_zeros(src_size).scatter_( - indptr.dim() - 1, arg_out, grad_out) - grad_src = grad_src.narrow(indptr.dim() - 1, 0, - src_size[indptr.dim() - 1] - 1) - - return grad_src, None, None, None - - -def segment_coo(src, index, out=None, dim_size=None, reduce="sum"): - r""" - | - - .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ - master/docs/source/_figures/segment_coo.svg?sanitize=true - :align: center - :width: 400px - - | - - Reduces all values from the :attr:`src` tensor into :attr:`out` at the - indices specified in the :attr:`index` tensor along the last dimension of - :attr:`index`. 
- For each value in :attr:`src`, its output index is specified by its index - in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the - corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`. - The applied reduction is defined via the :attr:`reduce` argument. - - Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and - :math:`m`-dimensional tensors with - size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and - :math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be an - :math:`n`-dimensional tensor with size - :math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`. - Moreover, the values of :attr:`index` must be between :math:`0` and - :math:`y - 1` in ascending order. - The :attr:`index` tensor supports broadcasting in case its dimensions do - not match with :attr:`src`. - For one-dimensional tensors with :obj:`reduce="sum"`, the operation - computes - - .. math:: - \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j - - where :math:`\sum_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - - In contrast to :meth:`scatter`, this method expects values in :attr:`index` - **to be sorted** along dimension :obj:`index.dim() - 1`. - Due to the use of sorted indices, :meth:`segment_coo` is usually faster - than the more general :meth:`scatter` operation. - - For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a - second tensor representing the :obj:`argmin` and :obj:`argmax`, - respectively. - - .. note:: - - This operation is implemented via atomic operations on the GPU and is - therefore **non-deterministic** since the order of parallel operations - to the same value is undetermined. - For floating-point variables, this results in a source of variance in - the result. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The sorted indices of elements to segment. - The number of dimensions of :attr:`index` needs to be less than or - equal to :attr:`src`. - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension - :obj:`index.dim() - 1`. - If :attr:`dim_size` is not given, a minimal sized output tensor - according to :obj:`index.max() + 1` is returned. - (default: :obj:`None`) - reduce (string, optional): The reduce operation (:obj:`"sum"`, - :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). - (default: :obj:`"sum"`) - - :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)* - - .. code-block:: python - - from torch_scatter import segment_coo - - src = torch.randn(10, 6, 64) - index = torch.tensor([0, 0, 1, 1, 1, 2]) - index = index.view(1, -1) # Broadcasting in the first and last dim. - - out = segment_coo(src, index, reduce="sum") - - print(out.size()) - - .. code-block:: - - torch.Size([10, 3, 64]) - """ - return SegmentCOO.apply(src, index, out, dim_size, reduce) - - -def segment_csr(src, indptr, out=None, reduce="sum"): - r""" - Reduces all values from the :attr:`src` tensor into :attr:`out` within the - ranges specified in the :attr:`indptr` tensor along the last dimension of - :attr:`indptr`. - For each value in :attr:`src`, its output index is specified by its index - in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by the - corresponding range index in :attr:`indptr` for dimension - :obj:`indptr.dim() - 1`. - The applied reduction is defined via the :attr:`reduce` argument. 
- - Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and - :math:`m`-dimensional tensors with - size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and - :math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an - :math:`n`-dimensional tensor with size - :math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`. - Moreover, the values of :attr:`indptr` must be between :math:`0` and - :math:`x_m` in ascending order. - The :attr:`indptr` tensor supports broadcasting in case its dimensions do - not match with :attr:`src`. - For one-dimensional tensors with :obj:`reduce="sum"`, the operation - computes - - .. math:: - \mathrm{out}_i = - \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+i]}~\mathrm{src}_j. - - Due to the use of index pointers, :meth:`segment_csr` is the fastest - method to apply for grouped reductions. - - For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a - second tensor representing the :obj:`argmin` and :obj:`argmax`, - respectively. - - .. note:: - - In contrast to :meth:`scatter()` and :meth:`segment_coo`, this - operation is **fully-deterministic**. - - Args: - src (Tensor): The source tensor. - indptr (LongTensor): The index pointers between elements to segment. - The number of dimensions of :attr:`index` needs to be less than or - equal to :attr:`src`. - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - reduce (string, optional): The reduce operation (:obj:`"sum"`, - :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). - (default: :obj:`"sum"`) - - :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)* - - .. code-block:: python - - from torch_scatter import segment_csr - - src = torch.randn(10, 6, 64) - indptr = torch.tensor([0, 2, 5, 6]) - indptr = indptr.view(1, -1) # Broadcasting in the first and last dim. - - out = segment_csr(src, indptr, reduce="sum") - - print(out.size()) - - .. code-block:: - - torch.Size([10, 3, 64]) - """ - return SegmentCSR.apply(src, indptr, out, reduce) diff --git a/torch_scatter/segment_coo.py b/torch_scatter/segment_coo.py index 8e33d825..5fbd089c 100644 --- a/torch_scatter/segment_coo.py +++ b/torch_scatter/segment_coo.py @@ -49,6 +49,94 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor, out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, reduce: str = "sum") -> torch.Tensor: + r""" + | + + .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ + master/docs/source/_figures/segment_coo.svg?sanitize=true + :align: center + :width: 400px + + | + + Reduces all values from the :attr:`src` tensor into :attr:`out` at the + indices specified in the :attr:`index` tensor along the last dimension of + :attr:`index`. + For each value in :attr:`src`, its output index is specified by its index + in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the + corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`. + The applied reduction is defined via the :attr:`reduce` argument. + + Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and + :math:`m`-dimensional tensors with + size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and + :math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be an + :math:`n`-dimensional tensor with size + :math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`. + Moreover, the values of :attr:`index` must be between :math:`0` and + :math:`y - 1` in ascending order. 
+ The :attr:`index` tensor supports broadcasting in case its dimensions do + not match with :attr:`src`. + For one-dimensional tensors with :obj:`reduce="sum"`, the operation + computes + + .. math:: + \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j + + where :math:`\sum_j` is over :math:`j` such that + :math:`\mathrm{index}_j = i`. + + In contrast to :meth:`scatter`, this method expects values in :attr:`index` + **to be sorted** along dimension :obj:`index.dim() - 1`. + Due to the use of sorted indices, :meth:`segment_coo` is usually faster + than the more general :meth:`scatter` operation. + + For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a + second tensor representing the :obj:`argmin` and :obj:`argmax`, + respectively. + + .. note:: + + This operation is implemented via atomic operations on the GPU and is + therefore **non-deterministic** since the order of parallel operations + to the same value is undetermined. + For floating-point variables, this results in a source of variance in + the result. + + Args: + src (Tensor): The source tensor. + index (LongTensor): The sorted indices of elements to segment. + The number of dimensions of :attr:`index` needs to be less than or + equal to :attr:`src`. + out (Tensor, optional): The destination tensor. (default: :obj:`None`) + dim_size (int, optional): If :attr:`out` is not given, automatically + create output with size :attr:`dim_size` at dimension + :obj:`index.dim() - 1`. + If :attr:`dim_size` is not given, a minimal sized output tensor + according to :obj:`index.max() + 1` is returned. + (default: :obj:`None`) + reduce (string, optional): The reduce operation (:obj:`"sum"`, + :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). + (default: :obj:`"sum"`) + + :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)* + + .. code-block:: python + + from torch_scatter import segment_coo + + src = torch.randn(10, 6, 64) + index = torch.tensor([0, 0, 1, 1, 1, 2]) + index = index.view(1, -1) # Broadcasting in the first and last dim. + + out = segment_coo(src, index, reduce="sum") + + print(out.size()) + + .. code-block:: + + torch.Size([10, 3, 64]) + """ if reduce == 'sum' or reduce == 'add': return segment_sum_coo(src, index, out, dim_size) elif reduce == 'mean': diff --git a/torch_scatter/segment_csr.py b/torch_scatter/segment_csr.py index ac00183e..861ac998 100644 --- a/torch_scatter/segment_csr.py +++ b/torch_scatter/segment_csr.py @@ -43,6 +43,73 @@ def segment_max_csr(src: torch.Tensor, indptr: torch.Tensor, def segment_csr(src: torch.Tensor, indptr: torch.Tensor, out: Optional[torch.Tensor] = None, reduce: str = "sum") -> torch.Tensor: + r""" + Reduces all values from the :attr:`src` tensor into :attr:`out` within the + ranges specified in the :attr:`indptr` tensor along the last dimension of + :attr:`indptr`. + For each value in :attr:`src`, its output index is specified by its index + in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by the + corresponding range index in :attr:`indptr` for dimension + :obj:`indptr.dim() - 1`. + The applied reduction is defined via the :attr:`reduce` argument. + + Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and + :math:`m`-dimensional tensors with + size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and + :math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an + :math:`n`-dimensional tensor with size + :math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`. 
+ Moreover, the values of :attr:`indptr` must be between :math:`0` and + :math:`x_m` in ascending order. + The :attr:`indptr` tensor supports broadcasting in case its dimensions do + not match with :attr:`src`. + For one-dimensional tensors with :obj:`reduce="sum"`, the operation + computes + + .. math:: + \mathrm{out}_i = + \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+i]}~\mathrm{src}_j. + + Due to the use of index pointers, :meth:`segment_csr` is the fastest + method to apply for grouped reductions. + + For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a + second tensor representing the :obj:`argmin` and :obj:`argmax`, + respectively. + + .. note:: + + In contrast to :meth:`scatter()` and :meth:`segment_coo`, this + operation is **fully-deterministic**. + + Args: + src (Tensor): The source tensor. + indptr (LongTensor): The index pointers between elements to segment. + The number of dimensions of :attr:`index` needs to be less than or + equal to :attr:`src`. + out (Tensor, optional): The destination tensor. (default: :obj:`None`) + reduce (string, optional): The reduce operation (:obj:`"sum"`, + :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). + (default: :obj:`"sum"`) + + :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)* + + .. code-block:: python + + from torch_scatter import segment_csr + + src = torch.randn(10, 6, 64) + indptr = torch.tensor([0, 2, 5, 6]) + indptr = indptr.view(1, -1) # Broadcasting in the first and last dim. + + out = segment_csr(src, indptr, reduce="sum") + + print(out.size()) + + .. code-block:: + + torch.Size([10, 3, 64]) + """ if reduce == 'sum' or reduce == 'add': return segment_sum_csr(src, indptr, out) elif reduce == 'mean': diff --git a/torch_scatter/std.py b/torch_scatter/std.py deleted file mode 100644 index 500c5dcc..00000000 --- a/torch_scatter/std.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch - -from torch_scatter import scatter_add -from torch_scatter.utils.gen import gen - - -def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True): - r""" - | - - .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ - master/docs/source/_figures/std.svg?sanitize=true - :align: center - :width: 400px - - | - - Computes the standard-deviation from all values from the :attr:`src` tensor - into :attr:`out` at the indices specified in the :attr:`index` tensor along - a given axis :attr:`dim` (`cf.` :meth:`~torch_scatter.scatter_add`). - - For one-dimensional tensors, the operation computes - - .. math:: - \mathrm{out}_i = \sqrt{\frac{\sum_j {\left( x_j - \overline{x}_i - \right)}^2}{N_i - 1}} - - where :math:`\sum_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. :math:`N_i` and :math:`\overline{x}_i` - indicate the number of indices referencing :math:`i` and their mean value, - respectively. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - unbiased (bool, optional): If set to :obj:`False`, then the standard- - deviation will be calculated via the biased estimator. 
- (default: :obj:`True`) - - :rtype: :class:`Tensor` - """ - src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0) - - tmp = None if out is None else out.clone().fill_(0) - tmp = scatter_add(src, index, dim, tmp, dim_size) - - count = None if out is None else out.clone().fill_(0) - count = scatter_add(torch.ones_like(src), index, dim, count, dim_size) - - mean = tmp / count.clamp(min=1) - - var = (src - mean.gather(dim, index)) - var = var * var - out = scatter_add(var, index, dim, out, dim_size) - out = out / (count - 1 if unbiased else count).clamp(min=1) - out = torch.sqrt(out) - - return out diff --git a/torch_scatter/sub.py b/torch_scatter/sub.py deleted file mode 100644 index 8e9af8bb..00000000 --- a/torch_scatter/sub.py +++ /dev/null @@ -1,64 +0,0 @@ -from torch_scatter import scatter_add - - -def scatter_sub(src, index, dim=-1, out=None, dim_size=None, fill_value=0): - r""" - | - - .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ - master/docs/source/_figures/sub.svg?sanitize=true - :align: center - :width: 400px - - | - - Subtracts all values from the :attr:`src` tensor into :attr:`out` at the - indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`.If multiple indices reference the same location, their - **negated contributions add** (`cf.` :meth:`~torch_scatter.scatter_add`). - - For one-dimensional tensors, the operation computes - - .. math:: - \mathrm{out}_i = \mathrm{out}_i - \sum_j \mathrm{src}_j - - where :math:`\sum_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. (default: :obj:`0`) - - :rtype: :class:`Tensor` - - .. testsetup:: - - import torch - - .. testcode:: - - from torch_scatter import scatter_sub - - src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) - index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) - out = src.new_zeros((2, 6)) - - out = scatter_sub(src, index, out=out) - - print(out) - - .. testoutput:: - - tensor([[ 0., 0., -4., -3., -3., 0.], - [-2., -4., -4., 0., 0., 0.]]) - """ - return scatter_add(src.neg(), index, dim, out, dim_size, fill_value) diff --git a/torch_scatter/utils/__init__.py b/torch_scatter/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/torch_scatter/utils/gen.py b/torch_scatter/utils/gen.py deleted file mode 100644 index aa6a6221..00000000 --- a/torch_scatter/utils/gen.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import division - -from itertools import repeat - -import torch - - -def maybe_dim_size(index, dim_size=None): - if dim_size is not None: - return dim_size - dim = index.max().item() + 1 if index.numel() > 0 else 0 - return int(dim) - - -def broadcast(src, index, dim): - dim = range(src.dim())[dim] # Get real dim value. 
- - if index.dim() == 1: - index_size = list(repeat(1, src.dim())) - index_size[dim] = src.size(dim) - if index.numel() > 0: - index = index.view(index_size).expand_as(src) - else: # pragma: no cover - # PyTorch has a bug when view is used on zero-element tensors. - index = src.new_empty(index_size, dtype=torch.long) - - # Broadcasting capabilties: Expand dimensions to match. - if src.dim() != index.dim(): - raise ValueError( - ('Number of dimensions of src and index tensor do not match, ' - 'got {} and {}').format(src.dim(), index.dim())) - - expand_size = [] - for s, i in zip(src.size(), index.size()): - expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)] - - src = src.expand(expand_size) - index = index.expand_as(src) - - return src, index - - -def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0): - src, index = broadcast(src, index, dim) - dim = range(src.dim())[dim] # Get real dim value. - - # Generate output tensor if not given. - if out is None: - out_size = list(src.size()) - dim_size = maybe_dim_size(index, dim_size) - out_size[dim] = dim_size - out = src.new_full(out_size, fill_value) - - return src, out, index, dim From 82838e1de2963606f53da0c91a7865362e57ed9c Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 30 Jan 2020 15:16:35 +0100 Subject: [PATCH 09/12] update readme --- README.md | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 49b55f2c..45f95924 100644 --- a/README.md +++ b/README.md @@ -22,26 +22,22 @@ **[Documentation](https://pytorch-scatter.readthedocs.io)** -This package consists of a small extension library of highly optimized sparse update (scatter) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package. -Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor. +This package consists of a small extension library of highly optimized sparse update (scatter/segment) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package. +Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor. 
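As an illustration of the "group-index" idea described above, here is a minimal sketch using plain PyTorch only; the values and group assignment are made up for illustration:

```python
import torch

# Five source values and a "group-index" assigning each value to a group.
src = torch.tensor([1., 2., 3., 4., 5.])
index = torch.tensor([0, 1, 0, 2, 1])

# A scatter-style "sum" reduction collects all values that share an index:
out = torch.zeros(3).index_add_(0, index, src)
print(out)  # tensor([4., 7., 4.]) -> group 0: 1+3, group 1: 2+5, group 2: 4
```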
The package consists of the following operations: -* [**Scatter Add**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html) -* [**Scatter Sub**](https://pytorch-scatter.readthedocs.io/en/latest/functions/sub.html) -* [**Scatter Mul**](https://pytorch-scatter.readthedocs.io/en/latest/functions/mul.html) -* [**Scatter Div**](https://pytorch-scatter.readthedocs.io/en/latest/functions/div.html) -* [**Scatter Mean**](https://pytorch-scatter.readthedocs.io/en/latest/functions/mean.html) -* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html) -* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html) -* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html) -* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html) +* [**Scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html) +* [**SegmentCOO**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html) +* [**SegmentCSR**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html) In addition, we provide composite functions which make use of `scatter_*` operations under the hood: +* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_std) +* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_logsumexp) * [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax) * [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax) -All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations. +All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable via `@torch.jit.script`. 
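A minimal sketch of what "fully traceable via `@torch.jit.script`" means in practice; the wrapper function, shapes, and values below are illustrative only:

```python
import torch
from torch_scatter import segment_csr

@torch.jit.script
def group_sum(src: torch.Tensor, indptr: torch.Tensor) -> torch.Tensor:
    # segment_csr can be compiled as part of a TorchScript function.
    return segment_csr(src, indptr, reduce="sum")

src = torch.randn(6, 64)
indptr = torch.tensor([0, 2, 5, 6])  # three segments: [0:2], [2:5], [5:6]
print(group_sum(src, indptr).size())  # torch.Size([3, 64])
```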
## Installation @@ -81,17 +77,17 @@ from torch_scatter import scatter_max src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) -out, argmax = scatter_max(src, index, fill_value=0) +out, argmax = scatter_max(src, index, dim=-1) ``` ``` print(out) -tensor([[ 0, 0, 4, 3, 2, 0], - [ 2, 4, 3, 0, 0, 0]]) +tensor([[0, 0, 4, 3, 2, 0], + [2, 4, 3, 0, 0, 0]]) print(argmax) -tensor([[-1, -1, 3, 4, 0, 1] - [ 1, 4, 3, -1, -1, -1]]) +tensor([[5, 5, 3, 4, 0, 1] + [1, 4, 3, 5, 5, 5]]) ``` ## Running tests From 99db5b801d2ab866bf3ef8fa2f9ac1e9bb9588da Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 30 Jan 2020 15:26:08 +0100 Subject: [PATCH 10/12] benchmark fixes --- benchmark/scatter_segment.py | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/benchmark/scatter_segment.py b/benchmark/scatter_segment.py index e5b08c0b..d1fbe80e 100644 --- a/benchmark/scatter_segment.py +++ b/benchmark/scatter_segment.py @@ -7,9 +7,7 @@ import torch from scipy.io import loadmat -import torch_scatter -from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max -from torch_scatter import segment_coo, segment_csr +from torch_scatter import scatter, segment_coo, segment_csr short_rows = [ ('DIMACS10', 'citationCiteseer'), @@ -47,34 +45,30 @@ def correctness(dataset): x = torch.randn((row.size(0), size), device=args.device) x = x.squeeze(-1) if size == 1 else x - out1 = scatter_add(x, row, dim=0, dim_size=dim_size) + out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add') out2 = segment_coo(x, row, dim_size=dim_size, reduce='add') out3 = segment_csr(x, rowptr, reduce='add') assert torch.allclose(out1, out2, atol=1e-4) assert torch.allclose(out1, out3, atol=1e-4) - out1 = scatter_mean(x, row, dim=0, dim_size=dim_size) + out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='mean') out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean') out3 = segment_csr(x, rowptr, reduce='mean') assert torch.allclose(out1, out2, atol=1e-4) assert torch.allclose(out1, out3, atol=1e-4) - x = x.abs_().mul_(-1) - - out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1)) - out2, _ = segment_coo(x, row, reduce='min') - out3, _ = segment_csr(x, rowptr, reduce='min') + out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='min') + out2 = segment_coo(x, row, reduce='min') + out3 = segment_csr(x, rowptr, reduce='min') assert torch.allclose(out1, out2, atol=1e-4) assert torch.allclose(out1, out3, atol=1e-4) - x = x.abs_() - - out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1)) - out2, _ = segment_coo(x, row, reduce='max') - out3, _ = segment_csr(x, rowptr, reduce='max') + out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='max') + out2 = segment_coo(x, row, reduce='max') + out3 = segment_csr(x, rowptr, reduce='max') assert torch.allclose(out1, out2, atol=1e-4) assert torch.allclose(out1, out3, atol=1e-4) @@ -117,17 +111,15 @@ def timing(dataset): mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr() rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long) row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long) - row_perm = row[torch.randperm(row.size(0))] + row2 = row[torch.randperm(row.size(0))] dim_size = rowptr.size(0) - 1 avg_row_len = row.size(0) / dim_size def sca_row(x): - op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}') - return op(x, row, dim=0, dim_size=dim_size) + return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce) def sca_col(x): - op 
= getattr(torch_scatter, f'scatter_{args.scatter_reduce}') - return op(x, row_perm, dim=0, dim_size=dim_size) + return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce) def seg_coo(x): return segment_coo(x, row, reduce=args.reduce) @@ -205,11 +197,10 @@ def dense2(x): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--reduce', type=str, required=True, - choices=['sum', 'mean', 'min', 'max']) + choices=['sum', 'add', 'mean', 'min', 'max']) parser.add_argument('--with_backward', action='store_true') parser.add_argument('--device', type=str, default='cuda') args = parser.parse_args() - args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce iters = 1 if args.device == 'cpu' else 20 sizes = [1, 16, 32, 64, 128, 256, 512] sizes = sizes[:3] if args.device == 'cpu' else sizes From 356d0fe868e69a002224ff4aaca374233a4479b4 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 30 Jan 2020 15:57:31 +0100 Subject: [PATCH 11/12] update doc --- .travis.yml | 1 + README.md | 19 +++--- docs/requirements.txt | 1 + docs/source/composite/softmax.rst | 9 --- docs/source/conf.py | 1 + docs/source/functions/add.rst | 6 +- docs/source/functions/div.rst | 7 --- docs/source/functions/logsumexp.rst | 7 --- docs/source/functions/max.rst | 7 --- docs/source/functions/mean.rst | 7 --- docs/source/functions/min.rst | 6 -- docs/source/functions/mul.rst | 7 --- docs/source/functions/std.rst | 7 --- docs/source/functions/sub.rst | 7 --- docs/source/index.rst | 3 +- setup.py | 5 +- torch_scatter/scatter.py | 91 +++++++++++++++-------------- torch_scatter/segment_coo.py | 36 +++++------- torch_scatter/segment_csr.py | 26 ++++----- 19 files changed, 89 insertions(+), 164 deletions(-) delete mode 100644 docs/source/composite/softmax.rst delete mode 100644 docs/source/functions/div.rst delete mode 100644 docs/source/functions/logsumexp.rst delete mode 100644 docs/source/functions/max.rst delete mode 100644 docs/source/functions/mean.rst delete mode 100644 docs/source/functions/min.rst delete mode 100644 docs/source/functions/mul.rst delete mode 100644 docs/source/functions/std.rst delete mode 100644 docs/source/functions/sub.rst diff --git a/.travis.yml b/.travis.yml index 35b96ba9..1d8f83d6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -39,6 +39,7 @@ install: - pip install codecov - pip install sphinx - pip install sphinx_rtd_theme + - pip install sphinx-autodoc-typehints script: - python -c "import torch; print(torch.__version__)" - pycodestyle . diff --git a/README.md b/README.md index 45f95924..a79852f7 100644 --- a/README.md +++ b/README.md @@ -22,22 +22,17 @@ **[Documentation](https://pytorch-scatter.readthedocs.io)** -This package consists of a small extension library of highly optimized sparse update (scatter/segment) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package. +This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package. Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor. -The package consists of the following operations: +Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements. 
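The relationship between the three variants can be sketched as follows, in the spirit of the correctness check in `benchmark/scatter_segment.py`; the values are made up for illustration:

```python
import torch
from torch_scatter import scatter, segment_coo, segment_csr

src = torch.tensor([10., 20., 30., 40.])
index = torch.tensor([0, 0, 1, 2])    # sorted "group-index"
indptr = torch.tensor([0, 2, 3, 4])   # the same grouping as compressed pointers

out1 = scatter(src, index, dim=0, reduce="sum")   # also works for unsorted indices
out2 = segment_coo(src, index, reduce="sum")      # requires sorted indices
out3 = segment_csr(src, indptr, reduce="sum")     # requires index pointers

assert torch.allclose(out1, out2) and torch.allclose(out1, out3)
print(out1)  # tensor([30., 30., 40.])
```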
-* [**Scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html) -* [**SegmentCOO**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html) -* [**SegmentCSR**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html) +The package consists of the following operations with reduction types `"sum"|"mean"|"min"|"max"`: -In addition, we provide composite functions which make use of `scatter_*` operations under the hood: +* [**scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment.html) based on arbitrary indices +* [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices +* [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers -* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_std) -* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_logsumexp) -* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax) -* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax) - -All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable via `@torch.jit.script`. +All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable. ## Installation diff --git a/docs/requirements.txt b/docs/requirements.txt index b660e11f..3b71a44c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,3 +3,4 @@ numpy torch_nightly sphinx sphinx_rtd_theme +sphinx-autodoc-typehints diff --git a/docs/source/composite/softmax.rst b/docs/source/composite/softmax.rst deleted file mode 100644 index 4f03820d..00000000 --- a/docs/source/composite/softmax.rst +++ /dev/null @@ -1,9 +0,0 @@ -Scatter Softmax -=============== - -.. automodule:: torch_scatter.composite - :noindex: - -.. autofunction:: scatter_softmax - -.. autofunction:: scatter_log_softmax diff --git a/docs/source/conf.py b/docs/source/conf.py index eccdb5f1..0b9693e9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,6 +11,7 @@ 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'sphinx.ext.githubpages', + 'sphinx_autodoc_typehints', ] source_suffix = '.rst' diff --git a/docs/source/functions/add.rst b/docs/source/functions/add.rst index 0cc735b8..a6d03a32 100644 --- a/docs/source/functions/add.rst +++ b/docs/source/functions/add.rst @@ -1,7 +1,7 @@ -Scatter Add -=========== +Scatter +======= .. automodule:: torch_scatter :noindex: -.. autofunction:: scatter_add +.. autofunction:: scatter diff --git a/docs/source/functions/div.rst b/docs/source/functions/div.rst deleted file mode 100644 index 7216d662..00000000 --- a/docs/source/functions/div.rst +++ /dev/null @@ -1,7 +0,0 @@ -Scatter Div -=========== - -.. automodule:: torch_scatter - :noindex: - -.. 
autofunction:: scatter_div diff --git a/docs/source/functions/logsumexp.rst b/docs/source/functions/logsumexp.rst deleted file mode 100644 index 88aea63b..00000000 --- a/docs/source/functions/logsumexp.rst +++ /dev/null @@ -1,7 +0,0 @@ -Scatter LogSumExp -================= - -.. automodule:: torch_scatter - :noindex: - -.. autofunction:: scatter_logsumexp diff --git a/docs/source/functions/max.rst b/docs/source/functions/max.rst deleted file mode 100644 index da3d8b23..00000000 --- a/docs/source/functions/max.rst +++ /dev/null @@ -1,7 +0,0 @@ -Scatter Max -=========== - -.. automodule:: torch_scatter - :noindex: - -.. autofunction:: scatter_max diff --git a/docs/source/functions/mean.rst b/docs/source/functions/mean.rst deleted file mode 100644 index 6d409e7e..00000000 --- a/docs/source/functions/mean.rst +++ /dev/null @@ -1,7 +0,0 @@ -Scatter Mean -============ - -.. automodule:: torch_scatter - :noindex: - -.. autofunction:: scatter_mean diff --git a/docs/source/functions/min.rst b/docs/source/functions/min.rst deleted file mode 100644 index 5af8d292..00000000 --- a/docs/source/functions/min.rst +++ /dev/null @@ -1,6 +0,0 @@ -Scatter Min -=========== - -.. automodule:: torch_scatter - -.. autofunction:: scatter_min diff --git a/docs/source/functions/mul.rst b/docs/source/functions/mul.rst deleted file mode 100644 index 69c14fa4..00000000 --- a/docs/source/functions/mul.rst +++ /dev/null @@ -1,7 +0,0 @@ -Scatter Mul -=========== - -.. automodule:: torch_scatter - :noindex: - -.. autofunction:: scatter_mul diff --git a/docs/source/functions/std.rst b/docs/source/functions/std.rst deleted file mode 100644 index b3510a0b..00000000 --- a/docs/source/functions/std.rst +++ /dev/null @@ -1,7 +0,0 @@ -Scatter Std -=========== - -.. automodule:: torch_scatter - :noindex: - -.. autofunction:: scatter_std diff --git a/docs/source/functions/sub.rst b/docs/source/functions/sub.rst deleted file mode 100644 index 73310ccc..00000000 --- a/docs/source/functions/sub.rst +++ /dev/null @@ -1,7 +0,0 @@ -Scatter Sub -=========== - -.. automodule:: torch_scatter - :noindex: - -.. autofunction:: scatter_sub diff --git a/docs/source/index.rst b/docs/source/index.rst index b061d15d..51551643 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -7,7 +7,7 @@ This package consists of a small extension library of highly optimized sparse up Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor. Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements. -All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations. +All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable. .. 
toctree:: :glob: @@ -15,7 +15,6 @@ All included operations are broadcastable, work on varying data types, and are i :caption: Package reference functions/* - composite/* Indices and tables ================== diff --git a/setup.py b/setup.py index d803ad16..db214fe6 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,10 @@ from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None -WITH_CUDA = WITH_CUDA or os.getenv('FORCE_CUDA', '0') == '1' +if os.getenv('FORCE_CUDA', '0') == '1': + WITH_CUDA = True +if os.getenv('FORCE_NON_CUDA', '0') == '1': + WITH_CUDA = False def get_extensions(): diff --git a/torch_scatter/scatter.py b/torch_scatter/scatter.py index aa315347..73261437 100644 --- a/torch_scatter/scatter.py +++ b/torch_scatter/scatter.py @@ -44,7 +44,6 @@ def scatter_max(src: torch.Tensor, index: torch.Tensor, dim: int = -1, return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size) -@torch.jit.script def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, reduce: str = "sum") -> torch.Tensor: @@ -58,64 +57,68 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, | - Sums all values from the :attr:`src` tensor into :attr:`out` at the indices - specified in the :attr:`index` tensor along a given axis :attr:`dim`. For - each value in :attr:`src`, its output index is specified by its index in - :attr:`src` for dimensions outside of :attr:`dim` and by the - corresponding value in :attr:`index` for dimension :attr:`dim`. If - multiple indices reference the same location, their **contributions add**. - - Formally, if :attr:`src` and :attr:`index` are n-dimensional tensors with - size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and - :attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with - size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the - values of :attr:`index` must be between `0` and `out.size(dim) - 1`. - Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions - do not match. - - For one-dimensional tensors, the operation computes + Reduces all values from the :attr:`src` tensor into :attr:`out` at the + indices specified in the :attr:`index` tensor along a given axis + :attr:`dim`. + For each value in :attr:`src`, its output index is specified by its index + in :attr:`src` for dimensions outside of :attr:`dim` and by the + corresponding value in :attr:`index` for dimension :attr:`dim`. + The applied reduction is defined via the :attr:`reduce` argument. + + Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional + tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` + and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional + tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. + Moreover, the values of :attr:`index` must be between :math:`0` and + :math:`y - 1` in ascending order. + The :attr:`index` tensor supports broadcasting in case its dimensions do + not match with :attr:`src`. + + For one-dimensional tensors with :obj:`reduce="sum"`, the operation + computes .. math:: - \mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j + \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j where :math:`\sum_j` is over :math:`j` such that :math:`\mathrm{index}_j = i`. - Args: - src (Tensor): The source tensor. 
- index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. (default: :obj:`0`) + .. note:: + + This operation is implemented via atomic operations on the GPU and is + therefore **non-deterministic** since the order of parallel operations + to the same value is undetermined. + For floating-point variables, this results in a source of variance in + the result. + + :param src: The source tensor. + :param index: The indices of elements to scatter. + :param dim: The axis along which to index. (default: :obj:`-1`) + :param out: The destination tensor. + :param dim_size: If :attr:`out` is not given, automatically create output + with size :attr:`dim_size` at dimension :attr:`dim`. + If :attr:`dim_size` is not given, a minimal sized output tensor + according to :obj:`index.max() + 1` is returned. + :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`, + :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`) :rtype: :class:`Tensor` - .. testsetup:: - - import torch - - .. testcode:: + .. code-block:: python - from torch_scatter import scatter_add + from torch_scatter import scatter - src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) - index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) - out = src.new_zeros((2, 6)) + src = torch.randn(10, 6, 64) + index = torch.tensor([0, 1, 0, 1, 2, 1]) - out = scatter_add(src, index, out=out) + # Broadcasting in the first and last dim. + out = scatter(src, index, dim=1, reduce="sum") - print(out) + print(out.size()) - .. testoutput:: + .. code-block:: - tensor([[0., 0., 4., 3., 3., 0.], - [2., 4., 4., 0., 0., 0.]]) + torch.Size([10, 3, 64]) """ if reduce == 'sum' or reduce == 'add': return scatter_sum(src, index, dim, out, dim_size) diff --git a/torch_scatter/segment_coo.py b/torch_scatter/segment_coo.py index 5fbd089c..ee8307f2 100644 --- a/torch_scatter/segment_coo.py +++ b/torch_scatter/segment_coo.py @@ -44,7 +44,6 @@ def segment_max_coo(src: torch.Tensor, index: torch.Tensor, return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size) -@torch.jit.script def segment_coo(src: torch.Tensor, index: torch.Tensor, out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, @@ -77,6 +76,7 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor, :math:`y - 1` in ascending order. The :attr:`index` tensor supports broadcasting in case its dimensions do not match with :attr:`src`. + For one-dimensional tensors with :obj:`reduce="sum"`, the operation computes @@ -91,10 +91,6 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor, Due to the use of sorted indices, :meth:`segment_coo` is usually faster than the more general :meth:`scatter` operation. - For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a - second tensor representing the :obj:`argmin` and :obj:`argmax`, - respectively. - .. 
note:: This operation is implemented via atomic operations on the GPU and is @@ -103,23 +99,19 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor, For floating-point variables, this results in a source of variance in the result. - Args: - src (Tensor): The source tensor. - index (LongTensor): The sorted indices of elements to segment. - The number of dimensions of :attr:`index` needs to be less than or - equal to :attr:`src`. - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension - :obj:`index.dim() - 1`. - If :attr:`dim_size` is not given, a minimal sized output tensor - according to :obj:`index.max() + 1` is returned. - (default: :obj:`None`) - reduce (string, optional): The reduce operation (:obj:`"sum"`, - :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). - (default: :obj:`"sum"`) - - :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)* + :param src: The source tensor. + :param index: The sorted indices of elements to segment. + The number of dimensions of :attr:`index` needs to be less than or + equal to :attr:`src`. + :param out: The destination tensor. + :param dim_size: If :attr:`out` is not given, automatically create output + with size :attr:`dim_size` at dimension :obj:`index.dim() - 1`. + If :attr:`dim_size` is not given, a minimal sized output tensor + according to :obj:`index.max() + 1` is returned. + :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`, + :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`) + + :rtype: :class:`Tensor` .. code-block:: python diff --git a/torch_scatter/segment_csr.py b/torch_scatter/segment_csr.py index 861ac998..8f7f2923 100644 --- a/torch_scatter/segment_csr.py +++ b/torch_scatter/segment_csr.py @@ -39,7 +39,6 @@ def segment_max_csr(src: torch.Tensor, indptr: torch.Tensor, return torch.ops.torch_scatter.segment_max_csr(src, indptr, out) -@torch.jit.script def segment_csr(src: torch.Tensor, indptr: torch.Tensor, out: Optional[torch.Tensor] = None, reduce: str = "sum") -> torch.Tensor: @@ -63,6 +62,7 @@ def segment_csr(src: torch.Tensor, indptr: torch.Tensor, :math:`x_m` in ascending order. The :attr:`indptr` tensor supports broadcasting in case its dimensions do not match with :attr:`src`. + For one-dimensional tensors with :obj:`reduce="sum"`, the operation computes @@ -73,26 +73,20 @@ def segment_csr(src: torch.Tensor, indptr: torch.Tensor, Due to the use of index pointers, :meth:`segment_csr` is the fastest method to apply for grouped reductions. - For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a - second tensor representing the :obj:`argmin` and :obj:`argmax`, - respectively. - .. note:: In contrast to :meth:`scatter()` and :meth:`segment_coo`, this operation is **fully-deterministic**. - Args: - src (Tensor): The source tensor. - indptr (LongTensor): The index pointers between elements to segment. - The number of dimensions of :attr:`index` needs to be less than or - equal to :attr:`src`. - out (Tensor, optional): The destination tensor. (default: :obj:`None`) - reduce (string, optional): The reduce operation (:obj:`"sum"`, - :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). - (default: :obj:`"sum"`) - - :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)* + :param src: The source tensor. + :param indptr: The index pointers between elements to segment. + The number of dimensions of :attr:`index` needs to be less than or + equal to :attr:`src`. 
+ :param out: The destination tensor.
+ :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
+     :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
+
+ :rtype: :class:`Tensor`
 
     .. code-block:: python
 
From 02a47c46d990cbc0ca0edef8b57c4d20090ed9de Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Thu, 30 Jan 2020 15:59:29 +0100
Subject: [PATCH 12/12] composite

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index a79852f7..dd47a525 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,8 @@ The package consists of the following operations with reduction types `"sum"|"me
 * [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices
 * [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers
 
+In addition, we provide the following **composite functions** which make use of `scatter_*` operations under the hood: `scatter_std`, `scatter_logsumexp`, `scatter_softmax` and `scatter_log_softmax`.
+
 All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.
 
 ## Installation
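To make the notion of a composite function concrete, here is one way a grouped log-sum-exp could be assembled from the primitive `scatter` reductions. This is only an illustrative sketch under the assumption of a plain "max"-then-"sum" composition, not the package's actual implementation:

```python
import torch
from torch_scatter import scatter

def scatter_logsumexp_sketch(src, index, dim_size):
    # Numerically stable log-sum-exp per group, built from "max" and "sum".
    group_max = scatter(src, index, dim=0, dim_size=dim_size, reduce="max")
    shifted = src - group_max[index]  # subtract each element's group maximum
    sums = scatter(shifted.exp(), index, dim=0, dim_size=dim_size, reduce="sum")
    return group_max + sums.log()

src = torch.randn(6)
index = torch.tensor([0, 0, 1, 1, 2, 2])
print(scatter_logsumexp_sketch(src, index, dim_size=3))
```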