32 changes: 17 additions & 15 deletions csrc/cpu/segment_csr_cpu.cpp
@@ -20,7 +20,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
     sizes[i] = src.size(i);
   indptr = indptr.expand(sizes);

-  auto dim = indptr.dim() - 1;
+  const auto dim = indptr.dim() - 1;

   src = src.contiguous();

@@ -50,38 +50,40 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
     return std::make_tuple(out, arg_out);
   }

-  auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
-  auto K = out.numel() / N;
-  auto E = src.size(dim);
+  const auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
+  const auto K = out.numel() / N;
+  const auto E = src.size(dim);

-  auto indptr_info = getTensorInfo<int64_t>(indptr);
-  auto stride = indptr_info.strides[indptr_info.dims - 1];
+  const auto indptr_info = getTensorInfo<int64_t>(indptr);
+  const auto stride = indptr_info.strides[indptr_info.dims - 1];
   std::vector<int64_t> args(K);
   AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
-    auto src_data = src.data_ptr<scalar_t>();
+    const auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();

     std::vector<scalar_t> vals(K);
-    int64_t row_start, row_end;
     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
       for (auto n = 0; n < N; n++) {
-        auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
-        row_start = indptr_info.data[offset];
-        row_end = indptr_info.data[offset + stride];
+        const auto offset1 = IndexPtrToOffset<int64_t>::get(n, indptr_info);
+        const auto row_start = indptr_info.data[offset1];
+        const auto row_end = indptr_info.data[offset1 + stride];

-        offset = (n / (indptr.size(-1) - 1)) * E * K;
+        const auto offset2 = (n / (indptr.size(-1) - 1)) * E * K;
         for (auto k = 0; k < K; k++)
           vals[k] = Reducer<scalar_t, REDUCE>::init();

         for (auto e = row_start; e < row_end; e++)
           for (auto k = 0; k < K; k++)
             Reducer<scalar_t, REDUCE>::update(
-                &vals[k], src_data[offset + e * K + k], &args[k], e);
+                &vals[k], src_data[offset2 + e * K + k], &args[k], e);

-        for (auto k = 0; k < K; k++)
-          Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
+        if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
Review comment (Owner): It looks like we skip writing to the destination here for any other reduction (which is likely unintended - see test failures). A sketch of the unconditional write this refers to follows the diff below.
+          for (auto k = 0; k < K; k++){
+            Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
                                            arg_out_data + n * K + k, args[k],
                                            row_end - row_start);
+          }
+        }
       }
     });
   });
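For reference, here is a minimal sketch of the unconditional write the review comment above asks for: reverting to the removed loop so that reductions other than min/max (e.g. sum, mean) also reach the output buffer. This is not the PR's code; it simply restores the removed lines, and whether arg_out_data needs separate handling for non-min/max reductions is not visible in this diff, so the call is kept exactly as in the original loop.

```cpp
// Sketch only, assuming the names already in scope in segment_csr_cpu
// (out_data, vals, args, arg_out_data, row_start, row_end, K, n):
// write the reduced value for every reduction type, as the removed
// lines above did, instead of only for MIN/MAX.
for (auto k = 0; k < K; k++) {
  Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
                                   arg_out_data + n * K + k, args[k],
                                   row_end - row_start);
}
```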