From 89e1c2cb4fdac933186aa6228c8706d30d962040 Mon Sep 17 00:00:00 2001 From: Stefan Ivanov Date: Fri, 14 Feb 2020 00:42:42 +0000 Subject: [PATCH 1/2] Revert "potential windows fix" This reverts commit 0be33ffa3b8849b711a169bfab218302cb9b31c4. --- csrc/cpu/reducer.h | 13 ++++++------- csrc/cpu/scatter_cpu.cpp | 8 ++++---- csrc/cpu/segment_coo_cpu.cpp | 16 ++++++++-------- csrc/cpu/segment_csr_cpu.cpp | 12 ++++++------ 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/csrc/cpu/reducer.h b/csrc/cpu/reducer.h index 507aa408..edb0ec5f 100644 --- a/csrc/cpu/reducer.h +++ b/csrc/cpu/reducer.h @@ -40,8 +40,8 @@ const std::map reduce2REDUCE = { } \ }() -template struct Reducer { - static inline scalar_t init(ReductionType REDUCE) { +template struct Reducer { + static inline scalar_t init() { if (REDUCE == MUL || REDUCE == DIV) return (scalar_t)1; else if (REDUCE == MIN) @@ -52,8 +52,8 @@ template struct Reducer { return (scalar_t)0; } - static inline void update(ReductionType REDUCE, scalar_t *val, - scalar_t new_val, int64_t *arg, int64_t new_arg) { + static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg, + int64_t new_arg) { if (REDUCE == SUM || REDUCE == MEAN) *val = *val + new_val; else if (REDUCE == MUL) @@ -67,9 +67,8 @@ template struct Reducer { } } - static inline void write(ReductionType REDUCE, scalar_t *address, - scalar_t val, int64_t *arg_address, int64_t arg, - int count) { + static inline void write(scalar_t *address, scalar_t val, + int64_t *arg_address, int64_t arg, int count) { if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV) *address = val; else if (REDUCE == MEAN) diff --git a/csrc/cpu/scatter_cpu.cpp b/csrc/cpu/scatter_cpu.cpp index 67516185..5f9da470 100644 --- a/csrc/cpu/scatter_cpu.cpp +++ b/csrc/cpu/scatter_cpu.cpp @@ -61,22 +61,22 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, int64_t i, idx; AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { if (!optional_out.has_value()) - out.fill_(Reducer::init(REDUCE)); + out.fill_(Reducer::init()); for (auto b = 0; b < B; b++) { for (auto e = 0; e < E; e++) { for (auto k = 0; k < K; k++) { i = b * E * K + e * K + k; idx = index_info.data[IndexToOffset::get(i, index_info)]; - Reducer::update( - REDUCE, out_data + b * N * K + idx * K + k, src_data[i], + Reducer::update( + out_data + b * N * K + idx * K + k, src_data[i], arg_out_data + b * N * K + idx * K + k, e); } } } if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) - out.masked_fill_(out == Reducer::init(REDUCE), (scalar_t)0); + out.masked_fill_(out == Reducer::init(), (scalar_t)0); }); }); diff --git a/csrc/cpu/segment_coo_cpu.cpp b/csrc/cpu/segment_coo_cpu.cpp index 90d17752..c59afd2b 100644 --- a/csrc/cpu/segment_coo_cpu.cpp +++ b/csrc/cpu/segment_coo_cpu.cpp @@ -72,7 +72,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, int64_t idx, next_idx, row_start; AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { if (!optional_out.has_value()) - out.fill_(Reducer::init(REDUCE)); + out.fill_(Reducer::init()); if (REDUCE == MEAN) count_data = arg_out.value().data_ptr(); @@ -87,13 +87,13 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, for (auto e = 0; e < E; e++) { for (auto k = 0; k < K; k++) - Reducer::update( - REDUCE, &vals[k], src_data[b * E * K + e * K + k], &args[k], e); + Reducer::update( + &vals[k], src_data[b * E * K + e * K + k], &args[k], e); if (e == E - 1) { for (auto k = 0; k < K; k++) - Reducer::write( - REDUCE, out_data + b * N * K + idx * K + k, vals[k], + Reducer::write( + out_data + b * N * K + idx * K + k, vals[k], arg_out_data + b * N * K + idx * K + k, args[k], e + 1 - row_start); if (REDUCE == MEAN) @@ -104,8 +104,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, if (idx != next_idx) { for (auto k = 0; k < K; k++) { - Reducer::write( - REDUCE, out_data + b * N * K + idx * K + k, vals[k], + Reducer::write( + out_data + b * N * K + idx * K + k, vals[k], arg_out_data + b * N * K + idx * K + k, args[k], e + 1 - row_start); @@ -121,7 +121,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, } } if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) - out.masked_fill_(out == Reducer::init(REDUCE), (scalar_t)0); + out.masked_fill_(out == Reducer::init(), (scalar_t)0); if (REDUCE == MEAN) arg_out.value().clamp_(1); diff --git a/csrc/cpu/segment_csr_cpu.cpp b/csrc/cpu/segment_csr_cpu.cpp index bfffa4ed..6dca23f6 100644 --- a/csrc/cpu/segment_csr_cpu.cpp +++ b/csrc/cpu/segment_csr_cpu.cpp @@ -68,17 +68,17 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, offset = (n / (indptr.size(-1) - 1)) * E * K; for (auto k = 0; k < K; k++) - vals[k] = Reducer::init(REDUCE); + vals[k] = Reducer::init(); for (auto e = row_start; e < row_end; e++) for (auto k = 0; k < K; k++) - Reducer::update( - REDUCE, &vals[k], src_data[offset + e * K + k], &args[k], e); + Reducer::update( + &vals[k], src_data[offset + e * K + k], &args[k], e); for (auto k = 0; k < K; k++) - Reducer::write(REDUCE, out_data + n * K + k, vals[k], - arg_out_data + n * K + k, args[k], - row_end - row_start); + Reducer::write(out_data + n * K + k, vals[k], + arg_out_data + n * K + k, args[k], + row_end - row_start); } }); }); From 3b6c702f1e2dcf87bdf7ee52dbf452955e2fea6a Mon Sep 17 00:00:00 2001 From: Stefan Ivanov Date: Fri, 14 Feb 2020 02:57:34 +0000 Subject: [PATCH 2/2] Fix compile-time constant usage with MSVC --- csrc/cpu/reducer.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/cpu/reducer.h b/csrc/cpu/reducer.h index edb0ec5f..a07033aa 100644 --- a/csrc/cpu/reducer.h +++ b/csrc/cpu/reducer.h @@ -14,27 +14,27 @@ const std::map reduce2REDUCE = { [&] { \ switch (reduce2REDUCE.at(reduce)) { \ case SUM: { \ - const ReductionType REDUCE = SUM; \ + static constexpr ReductionType REDUCE = SUM; \ return __VA_ARGS__(); \ } \ case MEAN: { \ - const ReductionType REDUCE = MEAN; \ + static constexpr ReductionType REDUCE = MEAN; \ return __VA_ARGS__(); \ } \ case MUL: { \ - const ReductionType REDUCE = MUL; \ + static constexpr ReductionType REDUCE = MUL; \ return __VA_ARGS__(); \ } \ case DIV: { \ - const ReductionType REDUCE = DIV; \ + static constexpr ReductionType REDUCE = DIV; \ return __VA_ARGS__(); \ } \ case MIN: { \ - const ReductionType REDUCE = MIN; \ + static constexpr ReductionType REDUCE = MIN; \ return __VA_ARGS__(); \ } \ case MAX: { \ - const ReductionType REDUCE = MAX; \ + static constexpr ReductionType REDUCE = MAX; \ return __VA_ARGS__(); \ } \ } \