diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 0b51bf59..1b908bc0 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -238,17 +238,24 @@ hetero_sample(const vector &node_types, if (temporal) { if (!satisfy_time(node_time_dict, src_node_type, dst_time, v)) continue; - } - const auto res = to_local_src_node.insert({v, src_samples.size()}); - if (res.second) { + // force disjoint of computation tree + // note that the sampling always needs to have directed=True + // for temporal case + // to_local_src_node is not used for temporal / directed case src_samples.push_back(v); - if (temporal) - src_root_time.push_back(dst_time); - } - if (directed) { + src_root_time.push_back(dst_time); cols.push_back(i); - rows.push_back(res.first->second); + rows.push_back(src_samples.size() - 1); edges.push_back(offset); + } else { + const auto res = to_local_src_node.insert({v, src_samples.size()}); + if (res.second) + src_samples.push_back(v); + if (directed) { + cols.push_back(i); + rows.push_back(res.first->second); + edges.push_back(offset); + } } } } else if (replace) { @@ -261,17 +268,23 @@ hetero_sample(const vector &node_types, // TODO Infinity loop if no neighbor satisfies time constraint: if (!satisfy_time(node_time_dict, src_node_type, dst_time, v)) continue; - } - const auto res = to_local_src_node.insert({v, src_samples.size()}); - if (res.second) { + // force disjoint of computation tree + // note that the sampling always needs to have directed=True + // for temporal case src_samples.push_back(v); - if (temporal) - src_root_time.push_back(dst_time); - } - if (directed) { + src_root_time.push_back(dst_time); cols.push_back(i); - rows.push_back(res.first->second); + rows.push_back(src_samples.size() - 1); edges.push_back(offset); + } else { + const auto res = to_local_src_node.insert({v, src_samples.size()}); + if (res.second) + src_samples.push_back(v); + if (directed) { + cols.push_back(i); + rows.push_back(res.first->second); + edges.push_back(offset); + } } num_neighbors += 1; } @@ -289,17 +302,23 @@ hetero_sample(const vector &node_types, if (temporal) { if (!satisfy_time(node_time_dict, src_node_type, dst_time, v)) continue; - } - const auto res = to_local_src_node.insert({v, src_samples.size()}); - if (res.second) { + // force disjoint of computation tree + // note that the sampling always needs to have directed=True + // for temporal case src_samples.push_back(v); - if (temporal) - src_root_time.push_back(dst_time); - } - if (directed) { + src_root_time.push_back(dst_time); cols.push_back(i); - rows.push_back(res.first->second); + rows.push_back(src_samples.size() - 1); edges.push_back(offset); + } else { + const auto res = to_local_src_node.insert({v, src_samples.size()}); + if (res.second) + src_samples.push_back(v); + if (directed) { + cols.push_back(i); + rows.push_back(res.first->second); + edges.push_back(offset); + } } } } @@ -412,21 +431,19 @@ hetero_temporal_neighbor_sample_cpu( const c10::Dict> &num_neighbors_dict, const c10::Dict &node_time_dict, const int64_t num_hops, const bool replace, const bool directed) { - - if (replace && directed) { + AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling") + if (replace) { + // We assume that directed = True for temporal sampling + // The current implementation uses disjoint computation trees + // to tackle the case of the same node sampled having different + // root time constraint. + // In future, we could extend to directed = False case, + // allowing additional edges within each computation tree. return hetero_sample( node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, node_time_dict, num_hops); - } else if (replace && !directed) { - return hetero_sample( - node_types, edge_types, colptr_dict, row_dict, input_node_dict, - num_neighbors_dict, node_time_dict, num_hops); - } else if (!replace && directed) { - return hetero_sample( - node_types, edge_types, colptr_dict, row_dict, input_node_dict, - num_neighbors_dict, node_time_dict, num_hops); } else { - return hetero_sample( + return hetero_sample( node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, node_time_dict, num_hops); }