From cd84568bf4e21e782c032a43953652d717573180 Mon Sep 17 00:00:00 2001 From: Koch Date: Wed, 15 Jan 2020 15:08:48 +0100 Subject: [PATCH] fix: fix errors regarding Reducer functionalities in segment.cpp --- cpu/segment.cpp | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/cpu/segment.cpp b/cpu/segment.cpp index f7c3767c..a33c5346 100644 --- a/cpu/segment.cpp +++ b/cpu/segment.cpp @@ -11,23 +11,24 @@ enum ReductionType { ADD, MEAN, MIN, MAX }; #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ [&] { \ + ReductionType REDUCE = ADD; \ if (reduce == "add") { \ - const ReductionType REDUCE = ADD; \ + REDUCE = ADD; \ return __VA_ARGS__(); \ } else if (reduce == "mean") { \ - const ReductionType REDUCE = MEAN; \ + REDUCE = MEAN; \ return __VA_ARGS__(); \ } else if (reduce == "min") { \ - const ReductionType REDUCE = MIN; \ + REDUCE = MIN; \ return __VA_ARGS__(); \ } else if (reduce == "max") { \ - const ReductionType REDUCE = MAX; \ + REDUCE = MAX; \ return __VA_ARGS__(); \ } \ }() -template struct Reducer { - static inline scalar_t init() { +template struct Reducer { + static inline scalar_t init(ReductionType REDUCE) { if (REDUCE == MIN) { return std::numeric_limits::max(); } else if (REDUCE == MAX) { @@ -37,7 +38,7 @@ template struct Reducer { } } - static inline void update(scalar_t *val, scalar_t new_val) { + static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val) { if (REDUCE == ADD || REDUCE == MEAN) { *val = *val + new_val; } else if ((REDUCE == MIN && new_val < *val) || @@ -46,7 +47,7 @@ template struct Reducer { } } - static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg, + static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val, int64_t *arg, int64_t new_arg) { if (REDUCE == ADD || REDUCE == MEAN) { *val = *val + new_val; @@ -57,7 +58,7 @@ template struct Reducer { } } - static inline void write(scalar_t *address, scalar_t val, + static inline void write(ReductionType REDUCE, scalar_t *address, scalar_t val, int64_t *arg_address, int64_t arg, int count) { if (REDUCE == ADD) { *address = val; @@ -136,16 +137,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional out_opt, offset = (n / (indptr.size(-1) - 1)) * E * K; for (int k = 0; k < K; k++) { - vals[k] = Reducer::init(); + vals[k] = Reducer::init(REDUCE); } for (int64_t e = row_start; e < row_end; e++) { for (int k = 0; k < K; k++) { - Reducer::update( + Reducer::update(REDUCE, &vals[k], src_data[offset + e * K + k], &args[k], e); } } for (int k = 0; k < K; k++) { - Reducer::write(out_data + n * K + k, vals[k], + Reducer::write(REDUCE, out_data + n * K + k, vals[k], arg_out_data + n * K + k, args[k], row_end - row_start); } @@ -214,13 +215,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out, for (int e_2 = 0; e_2 < E_2; e_2++) { for (int k = 0; k < K; k++) { - Reducer::update( + Reducer::update(REDUCE, &vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2); } if (e_2 == E_2 - 1) { for (int k = 0; k < K; k++) { - Reducer::write( + Reducer::write(REDUCE, out_data + e_1 * N * K + idx * K + k, vals[k], arg_out_data + e_1 * N * K + idx * K + k, args[k], e_2 + 1 - row_start); @@ -231,7 +232,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out, if (idx != next_idx) { for (int k = 0; k < K; k++) { - Reducer::write( + Reducer::write(REDUCE, out_data + e_1 * N * K + idx * K + k, vals[k], arg_out_data + e_1 * N * K + idx * K + k, args[k], e_2 + 1 - row_start);