From b21460d3bb714b54090a136d84a5ca1a8e87c032 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 3 May 2022 14:09:25 -0600 Subject: [PATCH] Fix nullptr math & clean segment_csr_cpu.cpp When I run this code with LLVM-12's undefined behaviour sanitizer enabled, I see: ``` pytorch/torch-scatter/csrc/cpu/segment_csr_cpu.cpp:60:3: runtime error: applying non-zero offset 8 to null pointer #0 0x7fa95da15396 in segment_csr_cpu(at::Tensor, at::Tensor, c10::optional, std::__cxx11::basic_string, std::allocator >)::$_0::operator()() const::'lambda2'()::operator()() const::'lambda'()::operator()() const::'lambda'()::operator()() const pytorch/torch-scatter/csrc/cpu/segment_csr_cpu.cpp:60 ``` This is because on Line 41 we have `int64_t *arg_out_data = nullptr;`. The value of `arg_out_data` is set conditionally, but `arg_out_data` is used unconditionally within the "segment_csr" kernel, so pointer arithmetic is performed on a null pointer, which is undefined behaviour. Adding an if-statement that performs the write through `arg_out_data` only for the min/max reductions (the cases in which it is actually assigned) gates that use. I've also added `const` and de-shadowed variables in a few places to make the code more readable. 
--- csrc/cpu/segment_csr_cpu.cpp | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/csrc/cpu/segment_csr_cpu.cpp b/csrc/cpu/segment_csr_cpu.cpp index a826192c..7f55419c 100644 --- a/csrc/cpu/segment_csr_cpu.cpp +++ b/csrc/cpu/segment_csr_cpu.cpp @@ -20,7 +20,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, sizes[i] = src.size(i); indptr = indptr.expand(sizes); - auto dim = indptr.dim() - 1; + const auto dim = indptr.dim() - 1; src = src.contiguous(); @@ -50,38 +50,40 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, return std::make_tuple(out, arg_out); } - auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); - auto K = out.numel() / N; - auto E = src.size(dim); + const auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); + const auto K = out.numel() / N; + const auto E = src.size(dim); - auto indptr_info = getTensorInfo(indptr); - auto stride = indptr_info.strides[indptr_info.dims - 1]; + const auto indptr_info = getTensorInfo(indptr); + const auto stride = indptr_info.strides[indptr_info.dims - 1]; std::vector args(K); AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { - auto src_data = src.data_ptr(); + const auto src_data = src.data_ptr(); auto out_data = out.data_ptr(); std::vector vals(K); - int64_t row_start, row_end; AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { for (auto n = 0; n < N; n++) { - auto offset = IndexPtrToOffset::get(n, indptr_info); - row_start = indptr_info.data[offset]; - row_end = indptr_info.data[offset + stride]; + const auto offset1 = IndexPtrToOffset::get(n, indptr_info); + const auto row_start = indptr_info.data[offset1]; + const auto row_end = indptr_info.data[offset1 + stride]; - offset = (n / (indptr.size(-1) - 1)) * E * K; + const auto offset2 = (n / (indptr.size(-1) - 1)) * E * K; for (auto k = 0; k < K; k++) vals[k] = Reducer::init(); for (auto e = row_start; e < row_end; e++) for (auto k = 0; k < K; k++) 
Reducer::update( - &vals[k], src_data[offset + e * K + k], &args[k], e); + &vals[k], src_data[offset2 + e * K + k], &args[k], e); - for (auto k = 0; k < K; k++) - Reducer::write(out_data + n * K + k, vals[k], + if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { + for (auto k = 0; k < K; k++){ + Reducer::write(out_data + n * K + k, vals[k], arg_out_data + n * K + k, args[k], row_end - row_start); + } + } } }); });