diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 1b908bc0..189e734f 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -242,10 +242,11 @@ hetero_sample(const vector &node_types, // note that the sampling always needs to have directed=True // for temporal case // to_local_src_node is not used for temporal / directed case + const int64_t sample_idx = src_samples.size(); src_samples.push_back(v); src_root_time.push_back(dst_time); cols.push_back(i); - rows.push_back(src_samples.size() - 1); + rows.push_back(sample_idx); edges.push_back(offset); } else { const auto res = to_local_src_node.insert({v, src_samples.size()}); @@ -271,10 +272,11 @@ hetero_sample(const vector &node_types, // force disjoint of computation tree // note that the sampling always needs to have directed=True // for temporal case + const int64_t sample_idx = src_samples.size(); src_samples.push_back(v); src_root_time.push_back(dst_time); cols.push_back(i); - rows.push_back(src_samples.size() - 1); + rows.push_back(sample_idx); edges.push_back(offset); } else { const auto res = to_local_src_node.insert({v, src_samples.size()}); @@ -305,10 +307,11 @@ hetero_sample(const vector &node_types, // force disjoint of computation tree // note that the sampling always needs to have directed=True // for temporal case + const int64_t sample_idx = src_samples.size(); src_samples.push_back(v); src_root_time.push_back(dst_time); cols.push_back(i); - rows.push_back(src_samples.size() - 1); + rows.push_back(sample_idx); edges.push_back(offset); } else { const auto res = to_local_src_node.insert({v, src_samples.size()}); @@ -431,7 +434,7 @@ hetero_temporal_neighbor_sample_cpu( const c10::Dict> &num_neighbors_dict, const c10::Dict &node_time_dict, const int64_t num_hops, const bool replace, const bool directed) { - AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling") + AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling"); if (replace) { // We assume that directed = True for temporal sampling // The current implementation uses disjoint computation trees