diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 23e3d498..9da7edff 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -153,7 +153,7 @@ hetero_sample(const vector &node_types, // Add the input nodes to the output nodes: for (const auto &kv : input_node_dict) { const auto &node_type = kv.key(); - const auto &input_node = kv.value(); + const torch::Tensor &input_node = kv.value(); const auto *input_node_data = input_node.data_ptr(); auto &samples = samples_dict.at(node_type); @@ -180,8 +180,8 @@ hetero_sample(const vector &node_types, auto &src_samples = samples_dict.at(src_node_type); auto &to_local_src_node = to_local_node_dict.at(src_node_type); - const auto *colptr_data = colptr_dict.at(rel_type).data_ptr(); - const auto *row_data = row_dict.at(rel_type).data_ptr(); + const auto *colptr_data = ((torch::Tensor)colptr_dict.at(rel_type)).data_ptr(); + const auto *row_data = ((torch::Tensor)row_dict.at(rel_type)).data_ptr(); auto &rows = rows_dict.at(rel_type); auto &cols = cols_dict.at(rel_type); @@ -261,8 +261,8 @@ hetero_sample(const vector &node_types, const auto &dst_samples = samples_dict.at(dst_node_type); auto &to_local_src_node = to_local_node_dict.at(src_node_type); - const auto *colptr_data = kv.value().data_ptr(); - const auto *row_data = row_dict.at(rel_type).data_ptr(); + const auto *colptr_data = ((torch::Tensor)kv.value()).data_ptr(); + const auto *row_data = ((torch::Tensor)row_dict.at(rel_type)).data_ptr(); auto &rows = rows_dict.at(rel_type); auto &cols = cols_dict.at(rel_type);