diff --git a/cpu/gather.cpp b/cpu/gather.cpp index db8aec8d..5c013edf 100644 --- a/cpu/gather.cpp +++ b/cpu/gather.cpp @@ -3,6 +3,8 @@ #include "compat.h" #include "index_info.h" +#include + #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor") at::Tensor gather_csr(at::Tensor src, at::Tensor indptr, @@ -43,7 +45,7 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr, auto src_data = src.DATA_PTR(); auto out_data = out.DATA_PTR(); - scalar_t vals[K]; + std::vector vals(K); int64_t row_start, row_end; for (int n = 0; n < N; n++) { int offset = IndexPtrToOffset::get(n, indptr_info); @@ -104,7 +106,7 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index, auto src_data = src.DATA_PTR(); auto out_data = out.DATA_PTR(); - scalar_t vals[K]; + std::vector vals(K); int64_t idx, next_idx; for (int e_1 = 0; e_1 < E_1; e_1++) { int offset = IndexToOffset::get(e_1 * E_2, index_info); diff --git a/cpu/segment.cpp b/cpu/segment.cpp index dff7beca..f7c3767c 100644 --- a/cpu/segment.cpp +++ b/cpu/segment.cpp @@ -3,6 +3,8 @@ #include "compat.h" #include "index_info.h" +#include + #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor") enum ReductionType { ADD, MEAN, MIN, MAX }; @@ -123,8 +125,9 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional out_opt, auto src_data = src.DATA_PTR(); auto out_data = out.DATA_PTR(); - scalar_t vals[K]; - int64_t row_start, row_end, args[K]; + std::vector vals(K); + int64_t row_start, row_end; + std::vector args(K); AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { for (int n = 0; n < N; n++) { int offset = IndexPtrToOffset::get(n, indptr_info); @@ -195,8 +198,9 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out, auto src_data = src.DATA_PTR(); auto out_data = out.DATA_PTR(); - scalar_t vals[K]; - int64_t idx, next_idx, row_start, args[K]; + std::vector vals(K); + int64_t idx, next_idx, row_start; + std::vector args(K); AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { for (int e_1 = 0; e_1 < E_1; e_1++) { int offset = IndexToOffset::get(e_1 * E_2, index_info);