diff --git a/csrc/cpu/reducer.h b/csrc/cpu/reducer.h
index a07033aa..b2267cdb 100644
--- a/csrc/cpu/reducer.h
+++ b/csrc/cpu/reducer.h
@@ -72,7 +72,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
       *address = val;
     else if (REDUCE == MEAN)
-      *address = val / (count > 0 ? count : (scalar_t)1);
+      *address = val / (scalar_t)(count > 0 ? count : 1);
     else if (REDUCE == MIN || REDUCE == MAX) {
       if (count > 0) {
         *address = val;
diff --git a/csrc/cpu/scatter_cpu.cpp b/csrc/cpu/scatter_cpu.cpp
index 21e5f327..77a43969 100644
--- a/csrc/cpu/scatter_cpu.cpp
+++ b/csrc/cpu/scatter_cpu.cpp
@@ -57,7 +57,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
   auto N = out.size(dim);
 
   auto index_info = getTensorInfo<int64_t>(index);
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
 
diff --git a/csrc/cpu/segment_coo_cpu.cpp b/csrc/cpu/segment_coo_cpu.cpp
index bc23d3b9..8497301f 100644
--- a/csrc/cpu/segment_coo_cpu.cpp
+++ b/csrc/cpu/segment_coo_cpu.cpp
@@ -69,7 +69,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
   auto index_info = getTensorInfo<int64_t>(index);
   auto stride = index_info.strides[index_info.dims - 1];
   std::vector<int64_t> args(K);
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
     scalar_t *count_data = nullptr;
@@ -130,7 +130,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
         out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
 
       if (REDUCE == MEAN)
-        arg_out.value().clamp_(1);
+        arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
+                                     (scalar_t)1);
     });
   });
 
@@ -177,7 +178,7 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
 
   auto index_info = getTensorInfo<int64_t>(index);
   auto stride = index_info.strides[index_info.dims - 1];
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
 
diff --git a/csrc/cpu/segment_csr_cpu.cpp b/csrc/cpu/segment_csr_cpu.cpp
index ad258cec..a826192c 100644
--- a/csrc/cpu/segment_csr_cpu.cpp
+++ b/csrc/cpu/segment_csr_cpu.cpp
@@ -57,7 +57,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
   auto indptr_info = getTensorInfo<int64_t>(indptr);
   auto stride = indptr_info.strides[indptr_info.dims - 1];
   std::vector<int64_t> args(K);
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
 
@@ -135,7 +135,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
 
   auto indptr_info = getTensorInfo<int64_t>(indptr);
   auto stride = indptr_info.strides[indptr_info.dims - 1];
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
 
diff --git a/csrc/cuda/atomics.cuh b/csrc/cuda/atomics.cuh
index 32427eac..8a7c4724 100644
--- a/csrc/cuda/atomics.cuh
+++ b/csrc/cuda/atomics.cuh
@@ -68,6 +68,25 @@
                                                                               \
   template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl;   \
                                                                               \
+  template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 2> {    \
+    inline __device__ void operator()(scalar *address, scalar val) {          \
+      unsigned int *address_as_ui =                                           \
+          (unsigned int *)((char *)address - ((size_t)address & 2));          \
+      unsigned int old = *address_as_ui;                                      \
+      unsigned int assumed;                                                   \
+                                                                              \
+      do {                                                                    \
+        assumed = old;                                                        \
+        at::Half hsum;                                                        \
+        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);          \
+        hsum = OP(hsum, val);                                                 \
+        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)           \
+                                  : (old & 0xffff0000) | hsum.x;              \
+        old = atomicCAS(address_as_ui, assumed, old);                         \
+      } while (assumed != old);                                               \
+    }                                                                         \
+  };                                                                          \
+                                                                              \
   template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> {    \
     inline __device__ void operator()(scalar *address, scalar val) {          \
       int *address_as_i = (int *)address;                                     \
@@ -116,6 +135,15 @@ static inline __device__ void atomAdd(int32_t *address, int32_t val) {
 static inline __device__ void atomAdd(int64_t *address, int64_t val) {
   AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
 }
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000)
+static inline __device__ void atomAdd(at::Half *address, at::Half val) {
+  AtomicAddDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+}
+#else
+static inline __device__ void atomAdd(at::Half *address, at::Half val) {
+  atomicAdd(reinterpret_cast<__half *>(address), val);
+}
+#endif
 static inline __device__ void atomAdd(float *address, float val) {
   atomicAdd(address, val);
 }
@@ -150,6 +178,9 @@ static inline __device__ void atomMul(int64_t *address, int64_t val) {
 static inline __device__ void atomMul(float *address, float val) {
   AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
 }
+static inline __device__ void atomMul(at::Half *address, at::Half val) {
+  AtomicMulDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+}
 static inline __device__ void atomMul(double *address, double val) {
   AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
 }
@@ -172,6 +203,9 @@ static inline __device__ void atomDiv(int32_t *address, int32_t val) {
 static inline __device__ void atomDiv(int64_t *address, int64_t val) {
   AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
 }
+static inline __device__ void atomDiv(at::Half *address, at::Half val) {
+  AtomicDivDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+}
 static inline __device__ void atomDiv(float *address, float val) {
   AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
 }
@@ -197,6 +231,9 @@ static inline __device__ void atomMax(int32_t *address, int32_t val) {
 static inline __device__ void atomMax(int64_t *address, int64_t val) {
   AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
 }
+static inline __device__ void atomMax(at::Half *address, at::Half val) {
+  AtomicMaxDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+}
 static inline __device__ void atomMax(float *address, float val) {
   AtomicMaxDecimalImpl<float, sizeof(float)>()(address, val);
 }
@@ -222,6 +259,9 @@ static inline __device__ void atomMin(int32_t *address, int32_t val) {
 static inline __device__ void atomMin(int64_t *address, int64_t val) {
   AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
 }
+static inline __device__ void atomMin(at::Half *address, at::Half val) {
+  AtomicMinDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+}
 static inline __device__ void atomMin(float *address, float val) {
   AtomicMinDecimalImpl<float, sizeof(float)>()(address, val);
 }
diff --git a/csrc/cuda/reducer.cuh b/csrc/cuda/reducer.cuh
index 8b318958..8c851d20 100644
--- a/csrc/cuda/reducer.cuh
+++ b/csrc/cuda/reducer.cuh
@@ -89,7 +89,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
       *address = val;
     else if (REDUCE == MEAN)
-      *address = val / (count > 0 ? count : (scalar_t)1);
+      *address = val / (scalar_t)(count > 0 ? count : 1);
     else if (REDUCE == MIN || REDUCE == MAX) {
       if (count > 0) {
         *address = val;
diff --git a/csrc/cuda/scatter_cuda.cu b/csrc/cuda/scatter_cuda.cu
index 4ee92684..bcf5d7eb 100644
--- a/csrc/cuda/scatter_cuda.cu
+++ b/csrc/cuda/scatter_cuda.cu
@@ -111,7 +111,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
 
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
 
diff --git a/csrc/cuda/segment_coo_cuda.cu b/csrc/cuda/segment_coo_cuda.cu
index 51c1458e..a95cccd7 100644
--- a/csrc/cuda/segment_coo_cuda.cu
+++ b/csrc/cuda/segment_coo_cuda.cu
@@ -3,6 +3,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <type_traits>
 
 #include "reducer.cuh"
 #include "utils.cuh"
@@ -25,6 +26,10 @@ segment_coo_kernel(const scalar_t *src_data,
   int lane_idx = row_idx & (32 - 1);
   int D = index_info.sizes[index_info.dims - 1];
 
+  using cuda_scalar_t =
+      typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
+                                scalar_t>::type;
+
   if (row_idx < E) {
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
@@ -36,7 +41,7 @@ segment_coo_kernel(const scalar_t *src_data,
 #pragma unroll
     for (int i = 1; i < 32; i *= 2) {
       // Parallel reduction inside a single warp.
-      tmp = __shfl_up_sync(FULL_MASK, val, i);
+      tmp = __shfl_up_sync(FULL_MASK, (cuda_scalar_t)val, i);
       next_idx = __shfl_up_sync(FULL_MASK, idx, i);
       if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
         assert(idx >= next_idx);
@@ -214,7 +219,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
 
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
 
@@ -266,14 +271,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
         segment_coo_kernel<scalar_t, SUM, false>
             <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                    count_data, E, N);
-        arg_out.value().clamp_(1);
+        arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
+                                     (scalar_t)1);
         auto count = arg_out.value();
         for (int i = dim + 1; i < out.dim(); i++)
           count = count.unsqueeze(-1);
         if (out.is_floating_point())
-          out.true_divide_(count);
+          out.div_(count);
         else
-          out.floor_divide_(count);
+          out.div_(count, "floor");
       }
     });
   });
@@ -364,7 +370,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
 
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
 
diff --git a/csrc/cuda/segment_csr_cuda.cu b/csrc/cuda/segment_csr_cuda.cu
index a22e99de..dafc98bb 100644
--- a/csrc/cuda/segment_csr_cuda.cu
+++ b/csrc/cuda/segment_csr_cuda.cu
@@ -26,6 +26,10 @@ segment_csr_kernel(const scalar_t *src_data,
   int row_idx = thread_idx / TB;
   int lane_idx = thread_idx & (TB - 1);
 
+  using cuda_scalar_t =
+      typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
+                                scalar_t>::type;
+
   if (row_idx < N) {
     int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
     int64_t row_start = __ldg(indptr_info.data + offset);
@@ -48,7 +52,8 @@ segment_csr_kernel(const scalar_t *src_data,
         if (REDUCE == MIN || REDUCE == MAX)
           arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
         Reducer<scalar_t, REDUCE>::update(
-            &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
+            &val, __shfl_down_sync(FULL_MASK, (cuda_scalar_t)val, i), &arg,
+            arg_tmp);
       }
 
       if (lane_idx == 0) {
@@ -147,7 +152,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
 
   auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
 
@@ -264,7 +269,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
 
   auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
 
diff --git a/csrc/scatter.cpp b/csrc/scatter.cpp
index 33ae35ad..ed3c3f97 100644
--- a/csrc/scatter.cpp
+++ b/csrc/scatter.cpp
@@ -127,13 +127,12 @@ class ScatterMean : public torch::autograd::Function<ScatterMean> {
                           old_index.dim() <= dim ? old_index.dim() - 1 : dim,
                           torch::nullopt, out.size(dim), "sum");
     auto count = std::get<0>(result);
-    count.clamp_(1);
+    count.masked_fill_(count < 1, 1);
     count = broadcast(count, out, dim);
-
     if (out.is_floating_point())
-      out.true_divide_(count);
+      out.div_(count);
     else
-      out.floor_divide_(count);
+      out.div_(count, "floor");
 
     ctx->save_for_backward({index, count});
     if (optional_out.has_value())
diff --git a/test/utils.py b/test/utils.py
index 4c3995d2..dcdb19c9 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -2,7 +2,7 @@
 
 reductions = ['sum', 'add', 'mean', 'min', 'max']
 
-dtypes = [torch.float, torch.double, torch.int, torch.long]
+dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
 grad_dtypes = [torch.float, torch.double]
 
 devices = [torch.device('cpu')]
diff --git a/torch_scatter/scatter.py b/torch_scatter/scatter.py
index 87227361..615153ad 100644
--- a/torch_scatter/scatter.py
+++ b/torch_scatter/scatter.py
@@ -50,12 +50,10 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
 
     ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
     count = scatter_sum(ones, index, index_dim, None, dim_size)
-    count.clamp_(1)
+    count[count < 1] = 1
     count = broadcast(count, out, dim)
-    if torch.is_floating_point(out):
-        out.true_divide_(count)
-    else:
-        out.floor_divide_(count)
+    rounding_mode = None if torch.is_floating_point(out) else 'floor'
+    out.div_(count, rounding_mode=rounding_mode)
     return out
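Note on the tensor-op replacements above: `clamp_`, `true_divide_` and `floor_divide_` are swapped for `masked_fill_` (or indexed assignment) and `div_` with an explicit rounding mode, presumably because `clamp_` is not available for `torch.half` on every backend and the old division ops are deprecated in newer PyTorch releases. Below is a minimal stand-alone sketch of the equivalent semantics in plain PyTorch (not part of the patch); it assumes PyTorch >= 1.8 for the `rounding_mode` keyword, and the tensor values are made up for illustration.

    import torch

    # count.clamp_(1)  ->  indexed assignment / masked_fill_:
    count = torch.tensor([0., 2., 3.], dtype=torch.half)
    count[count < 1] = 1                    # same result as count.clamp_(min=1)
    assert count.tolist() == [1.0, 2.0, 3.0]

    # true_divide_ / floor_divide_  ->  div_ with an explicit rounding mode:
    out = torch.tensor([3., 4., 9.], dtype=torch.half)
    out.div_(count)                         # floating point: true division
    assert out.tolist() == [3.0, 2.0, 3.0]

    out_int = torch.tensor([3, 4, 9])
    out_int.div_(torch.tensor([1, 2, 3]), rounding_mode='floor')  # integer path
    assert out_int.tolist() == [3, 2, 3]

With the dispatch macros extended by `at::ScalarType::Half` and the half-precision atomics in place, the same `scatter`/`segment` code paths should then also accept `torch.half` inputs on CPU and CUDA, which is what the added `torch.half` entry in `test/utils.py` exercises.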