From ddd545ad8e3e57d69c4eb1c23ea209dc76f9e372 Mon Sep 17 00:00:00 2001 From: Jacob Zhong Date: Fri, 24 Apr 2020 11:46:01 -0400 Subject: [PATCH 1/9] compile with half --- csrc/cpu/reducer.h | 2 +- csrc/cpu/scatter_cpu.cpp | 2 +- csrc/cpu/segment_coo_cpu.cpp | 4 ++-- csrc/cpu/segment_csr_cpu.cpp | 4 ++-- csrc/cuda/atomics.cuh | 40 +++++++++++++++++++++++++++++++++++ csrc/cuda/reducer.cuh | 2 +- csrc/cuda/scatter_cuda.cu | 2 +- csrc/cuda/segment_coo_cuda.cu | 10 +++++---- csrc/cuda/segment_csr_cuda.cu | 7 +++--- 9 files changed, 58 insertions(+), 15 deletions(-) diff --git a/csrc/cpu/reducer.h b/csrc/cpu/reducer.h index a07033aa..b2267cdb 100644 --- a/csrc/cpu/reducer.h +++ b/csrc/cpu/reducer.h @@ -72,7 +72,7 @@ template struct Reducer { if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV) *address = val; else if (REDUCE == MEAN) - *address = val / (count > 0 ? count : (scalar_t)1); + *address = val / (scalar_t)(count > 0 ? count : 1); else if (REDUCE == MIN || REDUCE == MAX) { if (count > 0) { *address = val; diff --git a/csrc/cpu/scatter_cpu.cpp b/csrc/cpu/scatter_cpu.cpp index 5f9da470..59ee1311 100644 --- a/csrc/cpu/scatter_cpu.cpp +++ b/csrc/cpu/scatter_cpu.cpp @@ -54,7 +54,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, auto N = out.size(dim); auto index_info = getTensorInfo(index); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] { + AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "scatter", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); diff --git a/csrc/cpu/segment_coo_cpu.cpp b/csrc/cpu/segment_coo_cpu.cpp index c59afd2b..27640c76 100644 --- a/csrc/cpu/segment_coo_cpu.cpp +++ b/csrc/cpu/segment_coo_cpu.cpp @@ -63,7 +63,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, auto index_info = getTensorInfo(index); auto stride = index_info.strides[index_info.dims - 1]; std::vector args(K); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "segment_coo", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); scalar_t *count_data = nullptr; @@ -168,7 +168,7 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index, auto index_info = getTensorInfo(index); auto stride = index_info.strides[index_info.dims - 1]; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "gather_coo", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); diff --git a/csrc/cpu/segment_csr_cpu.cpp b/csrc/cpu/segment_csr_cpu.cpp index 6dca23f6..68d70531 100644 --- a/csrc/cpu/segment_csr_cpu.cpp +++ b/csrc/cpu/segment_csr_cpu.cpp @@ -54,7 +54,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, auto indptr_info = getTensorInfo(indptr); auto stride = indptr_info.strides[indptr_info.dims - 1]; std::vector args(K); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "segment_csr", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); @@ -129,7 +129,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, auto indptr_info = getTensorInfo(indptr); auto stride = indptr_info.strides[indptr_info.dims - 1]; - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "gather_csr", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); 
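Note: the CPU-side hunks in this patch all follow the same one-line pattern -- the existing AT_DISPATCH_ALL_TYPES (or AT_DISPATCH_FLOATING_TYPES) dispatch is replaced by the *_AND variant with at::ScalarType::Half as the extra type, so the enclosed lambda is also instantiated with scalar_t = at::Half. A minimal sketch of that dispatch pattern is shown below; the function name fill_ones and the tensor are illustrative only and do not appear in the patch.

    // Minimal sketch (illustrative, not part of the patch): the *_AND dispatch
    // variants take an extra ScalarType, so the lambda below is also compiled
    // with scalar_t = at::Half in addition to the standard types.
    #include <torch/extension.h>
    #include <ATen/Dispatch.h>

    void fill_ones(torch::Tensor t) {
      AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, t.scalar_type(), "fill_ones", [&] {
        auto data = t.data_ptr<scalar_t>();  // typed pointer, works for at::Half too
        for (int64_t i = 0; i < t.numel(); i++)
          data[i] = static_cast<scalar_t>(1);
      });
    }

Inside the lambda, scalar_t is bound in turn to each dispatched C++ type, so the same reducer code now also compiles for half precision. The CUDA files below wrap their kernel launches in the same macro.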
diff --git a/csrc/cuda/atomics.cuh b/csrc/cuda/atomics.cuh index 32427eac..f5913388 100644 --- a/csrc/cuda/atomics.cuh +++ b/csrc/cuda/atomics.cuh @@ -68,6 +68,25 @@ \ template struct Atomic##NAME##DecimalImpl; \ \ + template struct Atomic##NAME##DecimalImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + unsigned int * address_as_ui = \ + (unsigned int *) ((char *)address - ((size_t)address & 2)); \ + unsigned int old = *address_as_ui; \ + unsigned int assumed; \ + \ + do { \ + assumed = old; \ + at::Half hsum; \ + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); \ + hsum = OP(hsum, val); \ + old = (size_t)address & 2 ? (old & 0xffff) | \ + (hsum.x << 16) : (old & 0xffff0000) | hsum.x; \ + old = atomicCAS(address_as_ui, assumed, old); \ + } while (assumed != old); \ + } \ + }; \ + \ template struct Atomic##NAME##DecimalImpl { \ inline __device__ void operator()(scalar *address, scalar val) { \ int *address_as_i = (int *)address; \ @@ -116,6 +135,15 @@ static inline __device__ void atomAdd(int32_t *address, int32_t val) { static inline __device__ void atomAdd(int64_t *address, int64_t val) { AtomicAddIntegerImpl()(address, val); } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000) +static inline __device__ void atomAdd(at::Half *address, at::Half val) { + AtomicAddDecimalImpl()(address, val); +} +#else +static inline __device__ void atomAdd(at::Half *address, at::Half val) { + atomicAdd(reinterpret_cast<__half*>(address), val); +} +#endif static inline __device__ void atomAdd(float *address, float val) { atomicAdd(address, val); } @@ -150,6 +178,9 @@ static inline __device__ void atomMul(int64_t *address, int64_t val) { static inline __device__ void atomMul(float *address, float val) { AtomicMulDecimalImpl()(address, val); } +static inline __device__ void atomMul(at::Half *address, at::Half val) { + AtomicMulDecimalImpl()(address, val); +} static inline __device__ void atomMul(double *address, double val) { AtomicMulDecimalImpl()(address, val); } @@ -172,6 +203,9 @@ static inline __device__ void atomDiv(int32_t *address, int32_t val) { static inline __device__ void atomDiv(int64_t *address, int64_t val) { AtomicDivIntegerImpl()(address, val); } +static inline __device__ void atomDiv(at::Half *address, at::Half val) { + AtomicDivDecimalImpl()(address, val); +} static inline __device__ void atomDiv(float *address, float val) { AtomicDivDecimalImpl()(address, val); } @@ -197,6 +231,9 @@ static inline __device__ void atomMax(int32_t *address, int32_t val) { static inline __device__ void atomMax(int64_t *address, int64_t val) { AtomicMaxIntegerImpl()(address, val); } +static inline __device__ void atomMax(at::Half *address, at::Half val) { + AtomicMaxDecimalImpl()(address, val); +} static inline __device__ void atomMax(float *address, float val) { AtomicMaxDecimalImpl()(address, val); } @@ -222,6 +259,9 @@ static inline __device__ void atomMin(int32_t *address, int32_t val) { static inline __device__ void atomMin(int64_t *address, int64_t val) { AtomicMinIntegerImpl()(address, val); } +static inline __device__ void atomMin(at::Half *address, at::Half val) { + AtomicMinDecimalImpl()(address, val); +} static inline __device__ void atomMin(float *address, float val) { AtomicMinDecimalImpl()(address, val); } diff --git a/csrc/cuda/reducer.cuh b/csrc/cuda/reducer.cuh index 8b318958..8c851d20 100644 --- a/csrc/cuda/reducer.cuh +++ b/csrc/cuda/reducer.cuh @@ -89,7 +89,7 @@ template struct Reducer { if (REDUCE == SUM || REDUCE == MUL || 
REDUCE == DIV) *address = val; else if (REDUCE == MEAN) - *address = val / (count > 0 ? count : (scalar_t)1); + *address = val / (scalar_t)(count > 0 ? count : 1); else if (REDUCE == MIN || REDUCE == MAX) { if (count > 0) { *address = val; diff --git a/csrc/cuda/scatter_cuda.cu b/csrc/cuda/scatter_cuda.cu index 1f2191c8..05565e20 100644 --- a/csrc/cuda/scatter_cuda.cu +++ b/csrc/cuda/scatter_cuda.cu @@ -111,7 +111,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, auto index_info = at::cuda::detail::getTensorInfo(index); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "scatter", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); diff --git a/csrc/cuda/segment_coo_cuda.cu b/csrc/cuda/segment_coo_cuda.cu index 83851696..f6c6eaf7 100644 --- a/csrc/cuda/segment_coo_cuda.cu +++ b/csrc/cuda/segment_coo_cuda.cu @@ -1,5 +1,6 @@ #include "segment_coo_cuda.h" +#include #include #include #include @@ -34,9 +35,10 @@ segment_coo_kernel(const scalar_t *src_data, scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp; #pragma unroll - for (int i = 1; i < 32; i *= 2) { + for (int i = 1; i < 32; i *= 2) { // use left shift? // Parallel reduction inside a single warp. - tmp = __shfl_up_sync(FULL_MASK, val, i); + using float_t = typename std::conditional::value, __half, scalar_t>::type; + tmp = __shfl_up_sync(FULL_MASK, (float_t)val, i); next_idx = __shfl_up_sync(FULL_MASK, idx, i); if (lane_idx >= i && row_idx / D == (row_idx - i) / D) { assert(idx >= next_idx); @@ -215,7 +217,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, auto index_info = at::cuda::detail::getTensorInfo(index); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "segment_coo_kernel", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); @@ -362,7 +364,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, auto index_info = at::cuda::detail::getTensorInfo(index); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "gather_coo_kernel", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); diff --git a/csrc/cuda/segment_csr_cuda.cu b/csrc/cuda/segment_csr_cuda.cu index d08bdffd..57b23dfb 100644 --- a/csrc/cuda/segment_csr_cuda.cu +++ b/csrc/cuda/segment_csr_cuda.cu @@ -47,8 +47,9 @@ segment_csr_kernel(const scalar_t *src_data, // Parallel reduction inside a single warp. 
if (REDUCE == MIN || REDUCE == MAX) arg_tmp = __shfl_down_sync(FULL_MASK, arg, i); + using float_t = typename std::conditional::value, __half, scalar_t>::type; Reducer::update( - &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp); + &val, __shfl_down_sync(FULL_MASK, (float_t)val, i), &arg, arg_tmp); } if (lane_idx == 0) { @@ -144,7 +145,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, auto indptr_info = at::cuda::detail::getTensorInfo(indptr); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "segment_csr_kernel", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); @@ -260,7 +261,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, auto indptr_info = at::cuda::detail::getTensorInfo(indptr); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "gather_csr_kernel", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); From 5a493fdcc479446654483a13da7d44f5dff29c33 Mon Sep 17 00:00:00 2001 From: Jacob Zhong Date: Tue, 9 Jun 2020 21:43:49 -0400 Subject: [PATCH 2/9] Fix --- csrc/cpu/scatter_cpu.cpp | 2 +- test/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/cpu/scatter_cpu.cpp b/csrc/cpu/scatter_cpu.cpp index 59ee1311..50b82076 100644 --- a/csrc/cpu/scatter_cpu.cpp +++ b/csrc/cpu/scatter_cpu.cpp @@ -54,7 +54,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, auto N = out.size(dim); auto index_info = getTensorInfo(index); - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "scatter", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "scatter", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); diff --git a/test/utils.py b/test/utils.py index 4decd25c..b6b5cc0a 100644 --- a/test/utils.py +++ b/test/utils.py @@ -2,7 +2,7 @@ reductions = ['sum', 'add', 'mean', 'min', 'max'] -dtypes = [torch.float, torch.double, torch.int, torch.long] +dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long] grad_dtypes = [torch.float, torch.double] devices = [torch.device('cpu')] From 8ae11d40714723f32b39105df2c64968976d58b2 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 19 Jul 2021 10:54:35 +0200 Subject: [PATCH 3/9] fix --- csrc/cpu/segment_coo_cpu.cpp | 7 +- csrc/cuda/segment_coo_cuda.cu | 21 +++-- csrc/cuda/segment_csr_cuda.cu | 63 ++++++++------- csrc/scatter.cpp | 6 +- test/test_segment.py | 144 +++++++++++++++++----------------- torch_scatter/scatter.py | 24 +++--- torch_scatter/segment_coo.py | 16 ++-- 7 files changed, 146 insertions(+), 135 deletions(-) diff --git a/csrc/cpu/segment_coo_cpu.cpp b/csrc/cpu/segment_coo_cpu.cpp index ce0ab79f..a16c56a2 100644 --- a/csrc/cpu/segment_coo_cpu.cpp +++ b/csrc/cpu/segment_coo_cpu.cpp @@ -66,7 +66,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, auto index_info = getTensorInfo(index); auto stride = index_info.strides[index_info.dims - 1]; std::vector args(K); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "segment_coo", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); scalar_t *count_data = nullptr; @@ -127,7 +127,8 @@ segment_coo_cpu(torch::Tensor src, 
torch::Tensor index, out.masked_fill_(out == Reducer::init(), (scalar_t)0); if (REDUCE == MEAN) - arg_out.value().clamp_(1); + arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1, + (scalar_t)1); }); }); @@ -174,7 +175,7 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index, auto index_info = getTensorInfo(index); auto stride = index_info.strides[index_info.dims - 1]; - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "gather_coo", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); diff --git a/csrc/cuda/segment_coo_cuda.cu b/csrc/cuda/segment_coo_cuda.cu index f6c6eaf7..e2657d60 100644 --- a/csrc/cuda/segment_coo_cuda.cu +++ b/csrc/cuda/segment_coo_cuda.cu @@ -1,9 +1,9 @@ #include "segment_coo_cuda.h" -#include #include #include #include +#include #include "reducer.cuh" #include "utils.cuh" @@ -35,9 +35,11 @@ segment_coo_kernel(const scalar_t *src_data, scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp; #pragma unroll - for (int i = 1; i < 32; i *= 2) { // use left shift? + for (int i = 1; i < 32; i *= 2) { // Parallel reduction inside a single warp. - using float_t = typename std::conditional::value, __half, scalar_t>::type; + using float_t = + typename std::conditional::value, + __half, scalar_t>::type; tmp = __shfl_up_sync(FULL_MASK, (float_t)val, i); next_idx = __shfl_up_sync(FULL_MASK, idx, i); if (lane_idx >= i && row_idx / D == (row_idx - i) / D) { @@ -217,7 +219,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, auto index_info = at::cuda::detail::getTensorInfo(index); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "segment_coo_kernel", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); @@ -269,11 +271,16 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, segment_coo_kernel <<>>(nullptr, index_info, count_data, E, N); - arg_out.value().clamp_(1); + arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1, + (scalar_t)1); auto count = arg_out.value(); for (int i = dim + 1; i < out.dim(); i++) count = count.unsqueeze(-1); - out.div_(count); + if (torch::is_floating_point(out)) { + out.div_(count); + } else { + out.div_(count, "floor"); + } } }); }); @@ -364,7 +371,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, auto index_info = at::cuda::detail::getTensorInfo(index); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "gather_coo_kernel", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); diff --git a/csrc/cuda/segment_csr_cuda.cu b/csrc/cuda/segment_csr_cuda.cu index c5b4ee71..0c83286d 100644 --- a/csrc/cuda/segment_csr_cuda.cu +++ b/csrc/cuda/segment_csr_cuda.cu @@ -47,7 +47,9 @@ segment_csr_kernel(const scalar_t *src_data, // Parallel reduction inside a single warp. 
if (REDUCE == MIN || REDUCE == MAX) arg_tmp = __shfl_down_sync(FULL_MASK, arg, i); - using float_t = typename std::conditional::value, __half, scalar_t>::type; + using float_t = + typename std::conditional::value, + __half, scalar_t>::type; Reducer::update( &val, __shfl_down_sync(FULL_MASK, (float_t)val, i), &arg, arg_tmp); } @@ -148,22 +150,23 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, auto indptr_info = at::cuda::detail::getTensorInfo(indptr); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "segment_csr_kernel", [&] { - auto src_data = src.data_ptr(); - auto out_data = out.data_ptr(); - - AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { - if (K == 1) { - segment_csr_kernel - <<>>( - src_data, indptr_info, out_data, arg_out_data, N, E); - } else { - segment_csr_broadcast_kernel - <<>>( - src_data, indptr_info, out_data, arg_out_data, N, K, E); - } - }); - }); + AT_DISPATCH_ALL_TYPES_AND( + at::ScalarType::Half, src.scalar_type(), "segment_csr_kernel", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { + if (K == 1) { + segment_csr_kernel + <<>>( + src_data, indptr_info, out_data, arg_out_data, N, E); + } else { + segment_csr_broadcast_kernel + <<>>( + src_data, indptr_info, out_data, arg_out_data, N, K, E); + } + }); + }); return std::make_tuple(out, arg_out); } @@ -267,18 +270,20 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, auto indptr_info = at::cuda::detail::getTensorInfo(indptr); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "gather_csr_kernel", [&] { - auto src_data = src.data_ptr(); - auto out_data = out.data_ptr(); - - if (K == 1) - gather_csr_kernel<<>>( - src_data, indptr_info, out_data, N, E); - else - gather_csr_broadcast_kernel - <<>>(src_data, indptr_info, - out_data, N, K, E); - }); + AT_DISPATCH_ALL_TYPES_AND( + at::ScalarType::Half, src.scalar_type(), "gather_csr_kernel", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + if (K == 1) + gather_csr_kernel + <<>>(src_data, indptr_info, + out_data, N, E); + else + gather_csr_broadcast_kernel + <<>>(src_data, indptr_info, + out_data, N, K, E); + }); return out; } diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp index 25576c4c..f570e4d3 100644 --- a/csrc/scatter.cpp +++ b/csrc/scatter.cpp @@ -91,13 +91,13 @@ class ScatterMean : public torch::autograd::Function { old_index.dim() <= dim ? 
old_index.dim() - 1 : dim, torch::nullopt, out.size(dim), "sum"); auto count = std::get<0>(result); - count.clamp_(1); + count.clamp_(1); // TODO count = broadcast(count, out, dim); if (out.is_floating_point()) - out.true_divide_(count); + out.true_divide_(count); else - out.floor_divide_(count); + out.floor_divide_(count); ctx->save_for_backward({index, count}); if (optional_out.has_value()) diff --git a/test/test_segment.py b/test/test_segment.py index a5c28785..cc541fd5 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -82,6 +82,8 @@ }, ] +devices = ['cuda'] + @pytest.mark.parametrize('test,reduce,dtype,device', product(tests, reductions, dtypes, devices)) @@ -106,75 +108,73 @@ def test_forward(test, reduce, dtype, device): assert torch.all(out == expected) -@pytest.mark.parametrize('test,reduce,device', - product(tests, reductions, devices)) -def test_backward(test, reduce, device): - src = tensor(test['src'], torch.double, device) - src.requires_grad_() - index = tensor(test['index'], torch.long, device) - indptr = tensor(test['indptr'], torch.long, device) - - assert gradcheck(torch_scatter.segment_csr, (src, indptr, None, reduce)) - assert gradcheck(torch_scatter.segment_coo, - (src, index, None, None, reduce)) - - -@pytest.mark.parametrize('test,reduce,dtype,device', - product(tests, reductions, dtypes, devices)) -def test_out(test, reduce, dtype, device): - src = tensor(test['src'], dtype, device) - index = tensor(test['index'], torch.long, device) - indptr = tensor(test['indptr'], torch.long, device) - expected = tensor(test[reduce], dtype, device) - - out = torch.full_like(expected, -2) - - getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr, out) - assert torch.all(out == expected) - - out.fill_(-2) - - getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index, out) - - if reduce == 'sum' or reduce == 'add': - expected = expected - 2 - elif reduce == 'mean': - expected = out # We can not really test this here. 
- elif reduce == 'min': - expected = expected.fill_(-2) - elif reduce == 'max': - expected[expected == 0] = -2 - else: - raise ValueError - - assert torch.all(out == expected) - - -@pytest.mark.parametrize('test,reduce,dtype,device', - product(tests, reductions, dtypes, devices)) -def test_non_contiguous(test, reduce, dtype, device): - src = tensor(test['src'], dtype, device) - index = tensor(test['index'], torch.long, device) - indptr = tensor(test['indptr'], torch.long, device) - expected = tensor(test[reduce], dtype, device) - - if src.dim() > 1: - src = src.transpose(0, 1).contiguous().transpose(0, 1) - if index.dim() > 1: - index = index.transpose(0, 1).contiguous().transpose(0, 1) - if indptr.dim() > 1: - indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1) - - out = getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr) - if isinstance(out, tuple): - out, arg_out = out - arg_expected = tensor(test['arg_' + reduce], torch.long, device) - assert torch.all(arg_out == arg_expected) - assert torch.all(out == expected) - - out = getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index) - if isinstance(out, tuple): - out, arg_out = out - arg_expected = tensor(test['arg_' + reduce], torch.long, device) - assert torch.all(arg_out == arg_expected) - assert torch.all(out == expected) +# @pytest.mark.parametrize('test,reduce,device', +# product(tests, reductions, devices)) +# def test_backward(test, reduce, device): +# src = tensor(test['src'], torch.double, device) +# src.requires_grad_() +# index = tensor(test['index'], torch.long, device) +# indptr = tensor(test['indptr'], torch.long, device) + +# assert gradcheck(torch_scatter.segment_csr, (src, indptr, None, reduce)) +# assert gradcheck(torch_scatter.segment_coo, +# (src, index, None, None, reduce)) + +# @pytest.mark.parametrize('test,reduce,dtype,device', +# product(tests, reductions, dtypes, devices)) +# def test_out(test, reduce, dtype, device): +# src = tensor(test['src'], dtype, device) +# index = tensor(test['index'], torch.long, device) +# indptr = tensor(test['indptr'], torch.long, device) +# expected = tensor(test[reduce], dtype, device) + +# out = torch.full_like(expected, -2) + +# getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr, out) +# assert torch.all(out == expected) + +# out.fill_(-2) + +# getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index, out) + +# if reduce == 'sum' or reduce == 'add': +# expected = expected - 2 +# elif reduce == 'mean': +# expected = out # We can not really test this here. 
+# elif reduce == 'min': +# expected = expected.fill_(-2) +# elif reduce == 'max': +# expected[expected == 0] = -2 +# else: +# raise ValueError + +# assert torch.all(out == expected) + +# @pytest.mark.parametrize('test,reduce,dtype,device', +# product(tests, reductions, dtypes, devices)) +# def test_non_contiguous(test, reduce, dtype, device): +# src = tensor(test['src'], dtype, device) +# index = tensor(test['index'], torch.long, device) +# indptr = tensor(test['indptr'], torch.long, device) +# expected = tensor(test[reduce], dtype, device) + +# if src.dim() > 1: +# src = src.transpose(0, 1).contiguous().transpose(0, 1) +# if index.dim() > 1: +# index = index.transpose(0, 1).contiguous().transpose(0, 1) +# if indptr.dim() > 1: +# indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1) + +# out = getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr) +# if isinstance(out, tuple): +# out, arg_out = out +# arg_expected = tensor(test['arg_' + reduce], torch.long, device) +# assert torch.all(arg_out == arg_expected) +# assert torch.all(out == expected) + +# out = getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index) +# if isinstance(out, tuple): +# out, arg_out = out +# arg_expected = tensor(test['arg_' + reduce], torch.long, device) +# assert torch.all(arg_out == arg_expected) +# assert torch.all(out == expected) diff --git a/torch_scatter/scatter.py b/torch_scatter/scatter.py index 2aa38f5a..54ed0a80 100644 --- a/torch_scatter/scatter.py +++ b/torch_scatter/scatter.py @@ -47,28 +47,26 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) count = scatter_sum(ones, index, index_dim, None, dim_size) - count.clamp_(1) + count[count < 1] = 1 count = broadcast(count, out, dim) - if torch.is_floating_point(out): - out.true_divide_(count) - else: - out.floor_divide_(count) + rounding_mode = None if torch.is_floating_point(out) else 'floor' + out.div_(count, rounding_mode=rounding_mode) return out @torch.jit.script -def scatter_min(src: torch.Tensor, index: torch.Tensor, dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: +def scatter_min( + src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size) @torch.jit.script -def scatter_max(src: torch.Tensor, index: torch.Tensor, dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: +def scatter_max( + src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size) diff --git a/torch_scatter/segment_coo.py b/torch_scatter/segment_coo.py index ff0b7a3d..533518e1 100644 --- a/torch_scatter/segment_coo.py +++ b/torch_scatter/segment_coo.py @@ -25,18 +25,18 @@ def segment_mean_coo(src: torch.Tensor, index: torch.Tensor, @torch.jit.script -def segment_min_coo(src: torch.Tensor, index: torch.Tensor, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: +def segment_min_coo( + src: torch.Tensor, index: torch.Tensor, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> 
Tuple[torch.Tensor, torch.Tensor]: return torch.ops.torch_scatter.segment_min_coo(src, index, out, dim_size) @torch.jit.script -def segment_max_coo(src: torch.Tensor, index: torch.Tensor, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: +def segment_max_coo( + src: torch.Tensor, index: torch.Tensor, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size) From 3a064e2ad8fb15eb9d458004241f91026756e8a4 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 19 Jul 2021 10:56:16 +0200 Subject: [PATCH 4/9] rename --- csrc/cpu/scatter_cpu.cpp | 2 +- csrc/cpu/segment_csr_cpu.cpp | 4 ++-- csrc/cuda/scatter_cuda.cu | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/cpu/scatter_cpu.cpp b/csrc/cpu/scatter_cpu.cpp index a183572b..77a43969 100644 --- a/csrc/cpu/scatter_cpu.cpp +++ b/csrc/cpu/scatter_cpu.cpp @@ -57,7 +57,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, auto N = out.size(dim); auto index_info = getTensorInfo(index); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "scatter", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); diff --git a/csrc/cpu/segment_csr_cpu.cpp b/csrc/cpu/segment_csr_cpu.cpp index 9af599cf..a826192c 100644 --- a/csrc/cpu/segment_csr_cpu.cpp +++ b/csrc/cpu/segment_csr_cpu.cpp @@ -57,7 +57,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, auto indptr_info = getTensorInfo(indptr); auto stride = indptr_info.strides[indptr_info.dims - 1]; std::vector args(K); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "segment_csr", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); @@ -135,7 +135,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, auto indptr_info = getTensorInfo(indptr); auto stride = indptr_info.strides[indptr_info.dims - 1]; - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "gather_csr", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); diff --git a/csrc/cuda/scatter_cuda.cu b/csrc/cuda/scatter_cuda.cu index 6f0cd247..a594b615 100644 --- a/csrc/cuda/scatter_cuda.cu +++ b/csrc/cuda/scatter_cuda.cu @@ -114,7 +114,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, auto index_info = at::cuda::detail::getTensorInfo(index); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "scatter", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); From dacd4a1e4423e2e31eb44d1e4d8d28f2a5920dbe Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 19 Jul 2021 11:01:57 +0200 Subject: [PATCH 5/9] update --- csrc/cuda/segment_csr_cuda.cu | 59 +++++++------- csrc/scatter.cpp | 2 +- test/test_segment.py | 144 +++++++++++++++++----------------- 3 files changed, 101 insertions(+), 104 deletions(-) diff --git a/csrc/cuda/segment_csr_cuda.cu b/csrc/cuda/segment_csr_cuda.cu index 0c83286d..825a9b84 100644 --- a/csrc/cuda/segment_csr_cuda.cu +++ b/csrc/cuda/segment_csr_cuda.cu @@ -150,23 +150,22 @@ 
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, auto indptr_info = at::cuda::detail::getTensorInfo(indptr); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, src.scalar_type(), "segment_csr_kernel", [&] { - auto src_data = src.data_ptr(); - auto out_data = out.data_ptr(); - - AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { - if (K == 1) { - segment_csr_kernel - <<>>( - src_data, indptr_info, out_data, arg_out_data, N, E); - } else { - segment_csr_broadcast_kernel - <<>>( - src_data, indptr_info, out_data, arg_out_data, N, K, E); - } - }); - }); + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { + if (K == 1) { + segment_csr_kernel + <<>>( + src_data, indptr_info, out_data, arg_out_data, N, E); + } else { + segment_csr_broadcast_kernel + <<>>( + src_data, indptr_info, out_data, arg_out_data, N, K, E); + } + }); + }); return std::make_tuple(out, arg_out); } @@ -270,20 +269,18 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, auto indptr_info = at::cuda::detail::getTensorInfo(indptr); auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, src.scalar_type(), "gather_csr_kernel", [&] { - auto src_data = src.data_ptr(); - auto out_data = out.data_ptr(); - - if (K == 1) - gather_csr_kernel - <<>>(src_data, indptr_info, - out_data, N, E); - else - gather_csr_broadcast_kernel - <<>>(src_data, indptr_info, - out_data, N, K, E); - }); + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { + auto src_data = src.data_ptr(); + auto out_data = out.data_ptr(); + + if (K == 1) + gather_csr_kernel<<>>( + src_data, indptr_info, out_data, N, E); + else + gather_csr_broadcast_kernel + <<>>(src_data, indptr_info, + out_data, N, K, E); + }); return out; } diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp index f570e4d3..30346da2 100644 --- a/csrc/scatter.cpp +++ b/csrc/scatter.cpp @@ -91,7 +91,7 @@ class ScatterMean : public torch::autograd::Function { old_index.dim() <= dim ? 
old_index.dim() - 1 : dim, torch::nullopt, out.size(dim), "sum"); auto count = std::get<0>(result); - count.clamp_(1); // TODO + count.masked_fill_(count.value() < 1, 1); count = broadcast(count, out, dim); if (out.is_floating_point()) diff --git a/test/test_segment.py b/test/test_segment.py index cc541fd5..a5c28785 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -82,8 +82,6 @@ }, ] -devices = ['cuda'] - @pytest.mark.parametrize('test,reduce,dtype,device', product(tests, reductions, dtypes, devices)) @@ -108,73 +106,75 @@ def test_forward(test, reduce, dtype, device): assert torch.all(out == expected) -# @pytest.mark.parametrize('test,reduce,device', -# product(tests, reductions, devices)) -# def test_backward(test, reduce, device): -# src = tensor(test['src'], torch.double, device) -# src.requires_grad_() -# index = tensor(test['index'], torch.long, device) -# indptr = tensor(test['indptr'], torch.long, device) - -# assert gradcheck(torch_scatter.segment_csr, (src, indptr, None, reduce)) -# assert gradcheck(torch_scatter.segment_coo, -# (src, index, None, None, reduce)) - -# @pytest.mark.parametrize('test,reduce,dtype,device', -# product(tests, reductions, dtypes, devices)) -# def test_out(test, reduce, dtype, device): -# src = tensor(test['src'], dtype, device) -# index = tensor(test['index'], torch.long, device) -# indptr = tensor(test['indptr'], torch.long, device) -# expected = tensor(test[reduce], dtype, device) - -# out = torch.full_like(expected, -2) - -# getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr, out) -# assert torch.all(out == expected) - -# out.fill_(-2) - -# getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index, out) - -# if reduce == 'sum' or reduce == 'add': -# expected = expected - 2 -# elif reduce == 'mean': -# expected = out # We can not really test this here. 
-# elif reduce == 'min': -# expected = expected.fill_(-2) -# elif reduce == 'max': -# expected[expected == 0] = -2 -# else: -# raise ValueError - -# assert torch.all(out == expected) - -# @pytest.mark.parametrize('test,reduce,dtype,device', -# product(tests, reductions, dtypes, devices)) -# def test_non_contiguous(test, reduce, dtype, device): -# src = tensor(test['src'], dtype, device) -# index = tensor(test['index'], torch.long, device) -# indptr = tensor(test['indptr'], torch.long, device) -# expected = tensor(test[reduce], dtype, device) - -# if src.dim() > 1: -# src = src.transpose(0, 1).contiguous().transpose(0, 1) -# if index.dim() > 1: -# index = index.transpose(0, 1).contiguous().transpose(0, 1) -# if indptr.dim() > 1: -# indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1) - -# out = getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr) -# if isinstance(out, tuple): -# out, arg_out = out -# arg_expected = tensor(test['arg_' + reduce], torch.long, device) -# assert torch.all(arg_out == arg_expected) -# assert torch.all(out == expected) - -# out = getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index) -# if isinstance(out, tuple): -# out, arg_out = out -# arg_expected = tensor(test['arg_' + reduce], torch.long, device) -# assert torch.all(arg_out == arg_expected) -# assert torch.all(out == expected) +@pytest.mark.parametrize('test,reduce,device', + product(tests, reductions, devices)) +def test_backward(test, reduce, device): + src = tensor(test['src'], torch.double, device) + src.requires_grad_() + index = tensor(test['index'], torch.long, device) + indptr = tensor(test['indptr'], torch.long, device) + + assert gradcheck(torch_scatter.segment_csr, (src, indptr, None, reduce)) + assert gradcheck(torch_scatter.segment_coo, + (src, index, None, None, reduce)) + + +@pytest.mark.parametrize('test,reduce,dtype,device', + product(tests, reductions, dtypes, devices)) +def test_out(test, reduce, dtype, device): + src = tensor(test['src'], dtype, device) + index = tensor(test['index'], torch.long, device) + indptr = tensor(test['indptr'], torch.long, device) + expected = tensor(test[reduce], dtype, device) + + out = torch.full_like(expected, -2) + + getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr, out) + assert torch.all(out == expected) + + out.fill_(-2) + + getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index, out) + + if reduce == 'sum' or reduce == 'add': + expected = expected - 2 + elif reduce == 'mean': + expected = out # We can not really test this here. 
+ elif reduce == 'min': + expected = expected.fill_(-2) + elif reduce == 'max': + expected[expected == 0] = -2 + else: + raise ValueError + + assert torch.all(out == expected) + + +@pytest.mark.parametrize('test,reduce,dtype,device', + product(tests, reductions, dtypes, devices)) +def test_non_contiguous(test, reduce, dtype, device): + src = tensor(test['src'], dtype, device) + index = tensor(test['index'], torch.long, device) + indptr = tensor(test['indptr'], torch.long, device) + expected = tensor(test[reduce], dtype, device) + + if src.dim() > 1: + src = src.transpose(0, 1).contiguous().transpose(0, 1) + if index.dim() > 1: + index = index.transpose(0, 1).contiguous().transpose(0, 1) + if indptr.dim() > 1: + indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1) + + out = getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr) + if isinstance(out, tuple): + out, arg_out = out + arg_expected = tensor(test['arg_' + reduce], torch.long, device) + assert torch.all(arg_out == arg_expected) + assert torch.all(out == expected) + + out = getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index) + if isinstance(out, tuple): + out, arg_out = out + arg_expected = tensor(test['arg_' + reduce], torch.long, device) + assert torch.all(arg_out == arg_expected) + assert torch.all(out == expected) From 0cd6dfe843faea066acbaa70f715f8ccfeb331b3 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 19 Jul 2021 11:05:09 +0200 Subject: [PATCH 6/9] update --- csrc/cuda/segment_coo_cuda.cu | 9 ++++----- torch_scatter/scatter.py | 2 -- torch_scatter/segment_coo.py | 2 -- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/csrc/cuda/segment_coo_cuda.cu b/csrc/cuda/segment_coo_cuda.cu index e2657d60..71d6d636 100644 --- a/csrc/cuda/segment_coo_cuda.cu +++ b/csrc/cuda/segment_coo_cuda.cu @@ -276,11 +276,10 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, auto count = arg_out.value(); for (int i = dim + 1; i < out.dim(); i++) count = count.unsqueeze(-1); - if (torch::is_floating_point(out)) { - out.div_(count); - } else { - out.div_(count, "floor"); - } + if (out.is_floating_point()) + out.true_divide_(count); + else + out.floor_divide_(count); } }); }); diff --git a/torch_scatter/scatter.py b/torch_scatter/scatter.py index 54ed0a80..c23b19fb 100644 --- a/torch_scatter/scatter.py +++ b/torch_scatter/scatter.py @@ -54,7 +54,6 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, return out -@torch.jit.script def scatter_min( src: torch.Tensor, index: torch.Tensor, dim: int = -1, out: Optional[torch.Tensor] = None, @@ -62,7 +61,6 @@ def scatter_min( return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size) -@torch.jit.script def scatter_max( src: torch.Tensor, index: torch.Tensor, dim: int = -1, out: Optional[torch.Tensor] = None, diff --git a/torch_scatter/segment_coo.py b/torch_scatter/segment_coo.py index 533518e1..af049eb6 100644 --- a/torch_scatter/segment_coo.py +++ b/torch_scatter/segment_coo.py @@ -24,7 +24,6 @@ def segment_mean_coo(src: torch.Tensor, index: torch.Tensor, return torch.ops.torch_scatter.segment_mean_coo(src, index, out, dim_size) -@torch.jit.script def segment_min_coo( src: torch.Tensor, index: torch.Tensor, out: Optional[torch.Tensor] = None, @@ -32,7 +31,6 @@ def segment_min_coo( return torch.ops.torch_scatter.segment_min_coo(src, index, out, dim_size) -@torch.jit.script def segment_max_coo( src: torch.Tensor, index: torch.Tensor, out: Optional[torch.Tensor] = None, From b6a3d24fecd975906ec392677f1a75bc9e989fe6 Mon 
Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 19 Jul 2021 11:17:50 +0200 Subject: [PATCH 7/9] update --- csrc/cuda/segment_coo_cuda.cu | 4 ++-- csrc/scatter.cpp | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/csrc/cuda/segment_coo_cuda.cu b/csrc/cuda/segment_coo_cuda.cu index 71d6d636..9254a8f8 100644 --- a/csrc/cuda/segment_coo_cuda.cu +++ b/csrc/cuda/segment_coo_cuda.cu @@ -277,9 +277,9 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, for (int i = dim + 1; i < out.dim(); i++) count = count.unsqueeze(-1); if (out.is_floating_point()) - out.true_divide_(count); + out.div_(count); else - out.floor_divide_(count); + out.div_(count, "floor"); } }); }); diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp index 30346da2..c104f365 100644 --- a/csrc/scatter.cpp +++ b/csrc/scatter.cpp @@ -91,13 +91,12 @@ class ScatterMean : public torch::autograd::Function { old_index.dim() <= dim ? old_index.dim() - 1 : dim, torch::nullopt, out.size(dim), "sum"); auto count = std::get<0>(result); - count.masked_fill_(count.value() < 1, 1); + count.masked_fill_(count < 1, 1); count = broadcast(count, out, dim); - if (out.is_floating_point()) - out.true_divide_(count); + out.div_(count); else - out.floor_divide_(count); + out.div_(count, "floor"); ctx->save_for_backward({index, count}); if (optional_out.has_value()) From 5ce84f4c982770879d2bd84f725dcaa71a08df3f Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 19 Jul 2021 12:18:01 +0200 Subject: [PATCH 8/9] update --- csrc/cuda/atomics.cuh | 10 +++++----- csrc/cuda/segment_coo_cuda.cu | 12 ++++++++---- csrc/cuda/segment_csr_cuda.cu | 10 ++++++---- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/csrc/cuda/atomics.cuh b/csrc/cuda/atomics.cuh index f5913388..8a7c4724 100644 --- a/csrc/cuda/atomics.cuh +++ b/csrc/cuda/atomics.cuh @@ -70,8 +70,8 @@ \ template struct Atomic##NAME##DecimalImpl { \ inline __device__ void operator()(scalar *address, scalar val) { \ - unsigned int * address_as_ui = \ - (unsigned int *) ((char *)address - ((size_t)address & 2)); \ + unsigned int *address_as_ui = \ + (unsigned int *)((char *)address - ((size_t)address & 2)); \ unsigned int old = *address_as_ui; \ unsigned int assumed; \ \ @@ -80,8 +80,8 @@ at::Half hsum; \ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); \ hsum = OP(hsum, val); \ - old = (size_t)address & 2 ? (old & 0xffff) | \ - (hsum.x << 16) : (old & 0xffff0000) | hsum.x; \ + old = (size_t)address & 2 ? 
(old & 0xffff) | (hsum.x << 16) \ + : (old & 0xffff0000) | hsum.x; \ old = atomicCAS(address_as_ui, assumed, old); \ } while (assumed != old); \ } \ @@ -141,7 +141,7 @@ static inline __device__ void atomAdd(at::Half *address, at::Half val) { } #else static inline __device__ void atomAdd(at::Half *address, at::Half val) { - atomicAdd(reinterpret_cast<__half*>(address), val); + atomicAdd(reinterpret_cast<__half *>(address), val); } #endif static inline __device__ void atomAdd(float *address, float val) { diff --git a/csrc/cuda/segment_coo_cuda.cu b/csrc/cuda/segment_coo_cuda.cu index 9254a8f8..00f3e299 100644 --- a/csrc/cuda/segment_coo_cuda.cu +++ b/csrc/cuda/segment_coo_cuda.cu @@ -26,6 +26,10 @@ segment_coo_kernel(const scalar_t *src_data, int lane_idx = row_idx & (32 - 1); int D = index_info.sizes[index_info.dims - 1]; + using cuda_scalar_t = + typename std::conditional::value, __half, + scalar_t>::type; + if (row_idx < E) { int offset = at::cuda::detail::IndexToOffset::get( row_idx, index_info); @@ -37,10 +41,7 @@ segment_coo_kernel(const scalar_t *src_data, #pragma unroll for (int i = 1; i < 32; i *= 2) { // Parallel reduction inside a single warp. - using float_t = - typename std::conditional::value, - __half, scalar_t>::type; - tmp = __shfl_up_sync(FULL_MASK, (float_t)val, i); + tmp = __shfl_up_sync(FULL_MASK, (cuda_scalar_t)val, i); next_idx = __shfl_up_sync(FULL_MASK, idx, i); if (lane_idx >= i && row_idx / D == (row_idx - i) / D) { assert(idx >= next_idx); @@ -220,6 +221,9 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, auto index_info = at::cuda::detail::getTensorInfo(index); auto stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { + // if (std::is_same::value) + // scalar_t = typename __half; + auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); diff --git a/csrc/cuda/segment_csr_cuda.cu b/csrc/cuda/segment_csr_cuda.cu index 825a9b84..15cbdd96 100644 --- a/csrc/cuda/segment_csr_cuda.cu +++ b/csrc/cuda/segment_csr_cuda.cu @@ -26,6 +26,10 @@ segment_csr_kernel(const scalar_t *src_data, int row_idx = thread_idx / TB; int lane_idx = thread_idx & (TB - 1); + using cuda_scalar_t = + typename std::conditional::value, __half, + scalar_t>::type; + if (row_idx < N) { int offset = IndexPtrToOffset::get(row_idx, indptr_info); int64_t row_start = __ldg(indptr_info.data + offset); @@ -47,11 +51,9 @@ segment_csr_kernel(const scalar_t *src_data, // Parallel reduction inside a single warp. 
if (REDUCE == MIN || REDUCE == MAX) arg_tmp = __shfl_down_sync(FULL_MASK, arg, i); - using float_t = - typename std::conditional::value, - __half, scalar_t>::type; Reducer::update( - &val, __shfl_down_sync(FULL_MASK, (float_t)val, i), &arg, arg_tmp); + &val, __shfl_down_sync(FULL_MASK, (cuda_scalar_t)val, i), &arg, + arg_tmp); } if (lane_idx == 0) { From 81afe36df6edcd80ca499eb5d1acdf0101df4607 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 19 Jul 2021 12:19:00 +0200 Subject: [PATCH 9/9] typo --- csrc/cuda/segment_coo_cuda.cu | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/cuda/segment_coo_cuda.cu b/csrc/cuda/segment_coo_cuda.cu index 00f3e299..56aab4b3 100644 --- a/csrc/cuda/segment_coo_cuda.cu +++ b/csrc/cuda/segment_coo_cuda.cu @@ -221,9 +221,6 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, auto index_info = at::cuda::detail::getTensorInfo(index); auto stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { - // if (std::is_same::value) - // scalar_t = typename __half; - auto src_data = src.data_ptr(); auto out_data = out.data_ptr();
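Note on the CUDA side: warp shuffles and atomics need special handling for half precision. __shfl_up_sync / __shfl_down_sync have no overload for at::Half, which is why the kernels above cast the shuffled value through the cuda_scalar_t / std::conditional alias to __half. For the atomics, csrc/cuda/atomics.cuh adds a 2-byte specialization of Atomic##NAME##DecimalImpl: CUDA has no 16-bit atomicCAS, so the update is retried on the aligned 32-bit word containing the half value, replacing only the relevant 16 bits; atomAdd(at::Half*, at::Half) bypasses this and calls the native atomicAdd(__half*, __half) on sm_70+ with CUDA 10 or newer. A standalone sketch of the CAS loop follows; the function and kernel names are illustrative, not the library's API.

    // Standalone sketch (illustrative names, not the library's API): CAS-based
    // atomic add on __half, mirroring the 16-bit-within-32-bit trick used in
    // csrc/cuda/atomics.cuh for architectures without native half atomicAdd.
    #include <cuda_fp16.h>
    #include <cstdio>

    __device__ void atomic_add_half_sketch(__half *address, __half val) {
      // Address of the aligned 32-bit word that contains the target 16-bit value.
      unsigned int *base =
          (unsigned int *)((char *)address - ((size_t)address & 2));
      unsigned int old = *base, assumed;
      do {
        assumed = old;
        // Pull the current 16-bit payload out of the upper or lower half-word.
        unsigned short raw = ((size_t)address & 2) ? (unsigned short)(old >> 16)
                                                   : (unsigned short)(old & 0xffff);
        float sum = __half2float(__ushort_as_half(raw)) + __half2float(val);
        unsigned short upd = __half_as_ushort(__float2half(sum));
        // Splice the updated 16 bits back into the 32-bit word and try to commit.
        unsigned int word = ((size_t)address & 2)
                                ? ((old & 0x0000ffffu) | ((unsigned int)upd << 16))
                                : ((old & 0xffff0000u) | (unsigned int)upd);
        old = atomicCAS(base, assumed, word);
      } while (assumed != old);  // Retry if another thread changed the word.
    }

    __global__ void sum_kernel(const __half *in, __half *out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n)
        atomic_add_half_sketch(out, in[i]);
    }

    int main() {
      const int n = 256;
      __half h_in[n];
      for (int i = 0; i < n; i++)
        h_in[i] = __float2half(0.5f);
      __half *d_in, *d_out;
      cudaMalloc(&d_in, n * sizeof(__half));
      cudaMalloc(&d_out, sizeof(__half));
      cudaMemcpy(d_in, h_in, n * sizeof(__half), cudaMemcpyHostToDevice);
      cudaMemset(d_out, 0, sizeof(__half));
      sum_kernel<<<(n + 127) / 128, 128>>>(d_in, d_out, n);
      __half h_out;
      cudaMemcpy(&h_out, d_out, sizeof(__half), cudaMemcpyDeviceToHost);
      printf("sum = %f\n", __half2float(h_out));  // expect 128.0
      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }

With the kernels compiled for Half and torch.half added to the dtypes list in test/utils.py, half-precision inputs should take the same scatter/segment code paths as float ones.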