Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions cpu/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,24 @@ enum ReductionType { ADD, MEAN, MIN, MAX };

#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
ReductionType REDUCE = ADD; \
if (reduce == "add") { \
const ReductionType REDUCE = ADD; \
REDUCE = ADD; \
return __VA_ARGS__(); \
} else if (reduce == "mean") { \
const ReductionType REDUCE = MEAN; \
REDUCE = MEAN; \
return __VA_ARGS__(); \
} else if (reduce == "min") { \
const ReductionType REDUCE = MIN; \
REDUCE = MIN; \
return __VA_ARGS__(); \
} else if (reduce == "max") { \
const ReductionType REDUCE = MAX; \
REDUCE = MAX; \
return __VA_ARGS__(); \
} \
}()

template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
template <typename scalar_t> struct Reducer {
static inline scalar_t init(ReductionType REDUCE) {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
Expand All @@ -37,7 +38,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}

static inline void update(scalar_t *val, scalar_t new_val) {
static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val) {
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
Expand All @@ -46,7 +47,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}

static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
Expand All @@ -57,7 +58,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}

static inline void write(scalar_t *address, scalar_t val,
static inline void write(ReductionType REDUCE, scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == ADD) {
*address = val;
Expand Down Expand Up @@ -136,16 +137,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,

offset = (n / (indptr.size(-1) - 1)) * E * K;
for (int k = 0; k < K; k++) {
vals[k] = Reducer<scalar_t, REDUCE>::init();
vals[k] = Reducer<scalar_t>::init(REDUCE);
}
for (int64_t e = row_start; e < row_end; e++) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update(
Reducer<scalar_t>::update(REDUCE,
&vals[k], src_data[offset + e * K + k], &args[k], e);
}
}
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
Expand Down Expand Up @@ -214,13 +215,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
for (int e_2 = 0; e_2 < E_2; e_2++) {

for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update(
Reducer<scalar_t>::update(REDUCE,
&vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
}

if (e_2 == E_2 - 1) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
Reducer<scalar_t>::write(REDUCE,
out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start);
Expand All @@ -231,7 +232,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,

if (idx != next_idx) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
Reducer<scalar_t>::write(REDUCE,
out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start);
Expand Down