Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions cpu/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "compat.h"
#include "index_info.h"

#include <vector>

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
Expand Down Expand Up @@ -43,7 +45,7 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();

scalar_t vals[K];
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
for (int n = 0; n < N; n++) {
int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
Expand Down Expand Up @@ -104,7 +106,7 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();

scalar_t vals[K];
std::vector<scalar_t> vals(K);
int64_t idx, next_idx;
for (int e_1 = 0; e_1 < E_1; e_1++) {
int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
Expand Down
12 changes: 8 additions & 4 deletions cpu/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "compat.h"
#include "index_info.h"

#include <vector>

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

enum ReductionType { ADD, MEAN, MIN, MAX };
Expand Down Expand Up @@ -123,8 +125,9 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();

scalar_t vals[K];
int64_t row_start, row_end, args[K];
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
std::vector<int64_t> args(K);
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int n = 0; n < N; n++) {
int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
Expand Down Expand Up @@ -195,8 +198,9 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();

scalar_t vals[K];
int64_t idx, next_idx, row_start, args[K];
std::vector<scalar_t> vals(K);
int64_t idx, next_idx, row_start;
std::vector<int64_t> args(K);
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int e_1 = 0; e_1 < E_1; e_1++) {
int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
Expand Down