From 52f9e9da8cee075ed23cb494d561f7500a00f305 Mon Sep 17 00:00:00 2001 From: RexYing Date: Fri, 1 Jul 2022 22:23:24 +0000 Subject: [PATCH 01/13] disable undirected for temporal sampling --- csrc/cpu/neighbor_sample_cpu.cpp | 57 ++++++++++++++++---------------- csrc/cpu/neighbor_sample_cpu.h | 2 +- csrc/neighbor_sample.cpp | 4 +-- 3 files changed, 31 insertions(+), 32 deletions(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 3aa287c8..3f517aa6 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -238,10 +238,16 @@ hetero_sample(const vector &node_types, continue; } const auto res = to_local_src_node.insert({v, src_samples.size()}); - if (res.second) { + if (temporal) { + // force disjoint of computation tree + // note that the sampling always needs to have directed=True + // for temporal case + // to_local_src_node is not used for temporal / directed case src_samples.push_back(v); - if (temporal) - src_root_time.push_back(dst_time); + src_root_time.push_back(dst_time); + } else { + if (res.second) + src_samples.push_back(v); } if (directed) { cols.push_back(i); @@ -261,15 +267,15 @@ hetero_sample(const vector &node_types, continue; } const auto res = to_local_src_node.insert({v, src_samples.size()}); - if (res.second) { + if (temporal) { + // force disjoint of computation tree + // note that the sampling always needs to have directed=True + // for temporal case src_samples.push_back(v); - if (temporal) - src_root_time.push_back(dst_time); - } - if (directed) { - cols.push_back(i); - rows.push_back(res.first->second); - edges.push_back(offset); + src_root_time.push_back(dst_time); + } else { + if (res.second) + src_samples.push_back(v); } num_neighbors += 1; } @@ -285,14 +291,14 @@ hetero_sample(const vector &node_types, const int64_t offset = col_start + rnd; const int64_t &v = row_data[offset]; if (temporal) { - if (!satisfy_time(node_time_dict, src_node_type, dst_time, v)) - continue; - } - const auto res = to_local_src_node.insert({v, src_samples.size()}); - if (res.second) { + // force disjoint of computation tree + // note that the sampling always needs to have directed=True + // for temporal case src_samples.push_back(v); - if (temporal) - src_root_time.push_back(dst_time); + src_root_time.push_back(dst_time); + } else { + if (res.second) + src_samples.push_back(v); } if (directed) { cols.push_back(i); @@ -409,22 +415,15 @@ hetero_temporal_neighbor_sample_cpu( const c10::Dict &input_node_dict, const c10::Dict> &num_neighbors_dict, const c10::Dict &node_time_dict, - const int64_t num_hops, const bool replace, const bool directed) { + const int64_t num_hops, const bool replace) { - if (replace && directed) { + if (replace) { + // directed = True for temporal sampling return hetero_sample( node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, node_time_dict, num_hops); - } else if (replace && !directed) { - return hetero_sample( - node_types, edge_types, colptr_dict, row_dict, input_node_dict, - num_neighbors_dict, node_time_dict, num_hops); - } else if (!replace && directed) { - return hetero_sample( - node_types, edge_types, colptr_dict, row_dict, input_node_dict, - num_neighbors_dict, node_time_dict, num_hops); } else { - return hetero_sample( + return hetero_sample( node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, node_time_dict, num_hops); } diff --git a/csrc/cpu/neighbor_sample_cpu.h b/csrc/cpu/neighbor_sample_cpu.h index 1759d456..04a9f744 100644 --- a/csrc/cpu/neighbor_sample_cpu.h +++ b/csrc/cpu/neighbor_sample_cpu.h @@ -33,4 +33,4 @@ hetero_temporal_neighbor_sample_cpu( const c10::Dict &input_node_dict, const c10::Dict> &num_neighbors_dict, const c10::Dict &node_time_dict, - const int64_t num_hops, const bool replace, const bool directed); + const int64_t num_hops, const bool replace); diff --git a/csrc/neighbor_sample.cpp b/csrc/neighbor_sample.cpp index d0f9e056..a322d171 100644 --- a/csrc/neighbor_sample.cpp +++ b/csrc/neighbor_sample.cpp @@ -52,10 +52,10 @@ hetero_temporal_neighbor_sample( const c10::Dict &input_node_dict, const c10::Dict> &num_neighbors_dict, const c10::Dict &node_time_dict, - const int64_t num_hops, const bool replace, const bool directed) { + const int64_t num_hops, const bool replace) { return hetero_temporal_neighbor_sample_cpu( node_types, edge_types, colptr_dict, row_dict, input_node_dict, - num_neighbors_dict, node_time_dict, num_hops, replace, directed); + num_neighbors_dict, node_time_dict, num_hops, replace); } static auto registry = From e66b65caefaea9cc804e73f092f4fe133ff97045 Mon Sep 17 00:00:00 2001 From: RexYing Date: Fri, 1 Jul 2022 22:25:42 +0000 Subject: [PATCH 02/13] disjoint sampling for temporal --- csrc/cpu/neighbor_sample_cpu.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 3f517aa6..e7b1bb42 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -277,6 +277,11 @@ hetero_sample(const vector &node_types, if (res.second) src_samples.push_back(v); } + if (directed) { + cols.push_back(i); + rows.push_back(res.first->second); + edges.push_back(offset); + } num_neighbors += 1; } } else { @@ -290,6 +295,11 @@ hetero_sample(const vector &node_types, } const int64_t offset = col_start + rnd; const int64_t &v = row_data[offset]; + if (temporal) { + if (!satisfy_time(node_time_dict, src_node_type, dst_time, v)) + continue; + } + const auto res = to_local_src_node.insert({v, src_samples.size()}); if (temporal) { // force disjoint of computation tree // note that the sampling always needs to have directed=True From 748f8169f1544875e065b973b61ae3b5413d3e4d Mon Sep 17 00:00:00 2001 From: RexYing Date: Fri, 1 Jul 2022 23:00:44 +0000 Subject: [PATCH 03/13] fix repeated node index --- csrc/cpu/neighbor_sample_cpu.cpp | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index e7b1bb42..208c7a49 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -236,16 +236,17 @@ hetero_sample(const vector &node_types, if (temporal) { if (!satisfy_time(node_time_dict, src_node_type, dst_time, v)) continue; - } - const auto res = to_local_src_node.insert({v, src_samples.size()}); - if (temporal) { // force disjoint of computation tree // note that the sampling always needs to have directed=True // for temporal case // to_local_src_node is not used for temporal / directed case src_samples.push_back(v); src_root_time.push_back(dst_time); + cols.push_back(i); + rows.push_back(src_samples.size()); + edges.push_back(offset); } else { + const auto res = to_local_src_node.insert({v, src_samples.size()}); if (res.second) src_samples.push_back(v); } @@ -265,15 +266,16 @@ hetero_sample(const vector &node_types, // TODO Infinity loop if no neighbor satisfies time constraint: if (!satisfy_time(node_time_dict, src_node_type, dst_time, v)) continue; - } - const auto res = to_local_src_node.insert({v, src_samples.size()}); - if (temporal) { // force disjoint of computation tree // note that the sampling always needs to have directed=True // for temporal case src_samples.push_back(v); src_root_time.push_back(dst_time); + cols.push_back(i); + rows.push_back(src_samples.size()); + edges.push_back(offset); } else { + const auto res = to_local_src_node.insert({v, src_samples.size()}); if (res.second) src_samples.push_back(v); } @@ -298,15 +300,16 @@ hetero_sample(const vector &node_types, if (temporal) { if (!satisfy_time(node_time_dict, src_node_type, dst_time, v)) continue; - } - const auto res = to_local_src_node.insert({v, src_samples.size()}); - if (temporal) { // force disjoint of computation tree // note that the sampling always needs to have directed=True // for temporal case src_samples.push_back(v); src_root_time.push_back(dst_time); + cols.push_back(i); + rows.push_back(src_samples.size()); + edges.push_back(offset); } else { + const auto res = to_local_src_node.insert({v, src_samples.size()}); if (res.second) src_samples.push_back(v); } From 628f0f69a908c09d731774ee67171d680673ab38 Mon Sep 17 00:00:00 2001 From: RexYing Date: Fri, 1 Jul 2022 23:12:32 +0000 Subject: [PATCH 04/13] compile fix --- csrc/cpu/neighbor_sample_cpu.cpp | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 208c7a49..cfb57ad1 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -249,11 +249,11 @@ hetero_sample(const vector &node_types, const auto res = to_local_src_node.insert({v, src_samples.size()}); if (res.second) src_samples.push_back(v); - } - if (directed) { - cols.push_back(i); - rows.push_back(res.first->second); - edges.push_back(offset); + if (directed) { + cols.push_back(i); + rows.push_back(res.first->second); + edges.push_back(offset); + } } } } else if (replace) { @@ -278,11 +278,11 @@ hetero_sample(const vector &node_types, const auto res = to_local_src_node.insert({v, src_samples.size()}); if (res.second) src_samples.push_back(v); - } - if (directed) { - cols.push_back(i); - rows.push_back(res.first->second); - edges.push_back(offset); + if (directed) { + cols.push_back(i); + rows.push_back(res.first->second); + edges.push_back(offset); + } } num_neighbors += 1; } @@ -312,11 +312,11 @@ hetero_sample(const vector &node_types, const auto res = to_local_src_node.insert({v, src_samples.size()}); if (res.second) src_samples.push_back(v); - } - if (directed) { - cols.push_back(i); - rows.push_back(res.first->second); - edges.push_back(offset); + if (directed) { + cols.push_back(i); + rows.push_back(res.first->second); + edges.push_back(offset); + } } } } From 7fce422b6c3be1c0fcee2c2d9d3ddd1a2ec633b9 Mon Sep 17 00:00:00 2001 From: Rex Ying Date: Fri, 15 Jul 2022 09:43:03 -0700 Subject: [PATCH 05/13] Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Zecheng Zhang --- csrc/cpu/neighbor_sample_cpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index cfb57ad1..02f5c5b1 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -306,7 +306,7 @@ hetero_sample(const vector &node_types, src_samples.push_back(v); src_root_time.push_back(dst_time); cols.push_back(i); - rows.push_back(src_samples.size()); + rows.push_back(src_samples.size() - 1); edges.push_back(offset); } else { const auto res = to_local_src_node.insert({v, src_samples.size()}); From e0076a60e5e2831d7e82d8837ed121de4836b025 Mon Sep 17 00:00:00 2001 From: Rex Ying Date: Fri, 15 Jul 2022 09:43:10 -0700 Subject: [PATCH 06/13] Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Zecheng Zhang --- csrc/cpu/neighbor_sample_cpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 02f5c5b1..c91fae43 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -272,7 +272,7 @@ hetero_sample(const vector &node_types, src_samples.push_back(v); src_root_time.push_back(dst_time); cols.push_back(i); - rows.push_back(src_samples.size()); + rows.push_back(src_samples.size() - 1); edges.push_back(offset); } else { const auto res = to_local_src_node.insert({v, src_samples.size()}); From 8f203eb8589a064339f9891f2c581da62ebc8afe Mon Sep 17 00:00:00 2001 From: Rex Ying Date: Fri, 15 Jul 2022 09:43:16 -0700 Subject: [PATCH 07/13] Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Zecheng Zhang --- csrc/cpu/neighbor_sample_cpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index c91fae43..1f8cad33 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -243,7 +243,7 @@ hetero_sample(const vector &node_types, src_samples.push_back(v); src_root_time.push_back(dst_time); cols.push_back(i); - rows.push_back(src_samples.size()); + rows.push_back(src_samples.size() - 1); edges.push_back(offset); } else { const auto res = to_local_src_node.insert({v, src_samples.size()}); From e98496f50749e4718214536ca97577a755c62ce5 Mon Sep 17 00:00:00 2001 From: RexYing Date: Fri, 15 Jul 2022 16:52:27 +0000 Subject: [PATCH 08/13] comments on directed to be true --- csrc/cpu/neighbor_sample_cpu.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 7e1515e0..177efff3 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -433,7 +433,12 @@ hetero_temporal_neighbor_sample_cpu( const int64_t num_hops, const bool replace) { if (replace) { - // directed = True for temporal sampling + // We assume that directed = True for temporal sampling + // The current implementatio uses disjoint computation tree + // to tackle the case of the same node sampled having different + // root time constraint. + // In future, we could extend to directed = False case, + // allowing additional edges within each computation tree. return hetero_sample( node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, node_time_dict, num_hops); From 28a577ed740d082e10958d57748f9e32b4bb66f4 Mon Sep 17 00:00:00 2001 From: RexYing Date: Sat, 16 Jul 2022 06:31:56 +0000 Subject: [PATCH 09/13] add directed in API --- csrc/cpu/neighbor_sample_cpu.cpp | 5 +++-- csrc/cpu/neighbor_sample_cpu.h | 2 +- csrc/neighbor_sample.cpp | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 177efff3..2b74d667 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -430,8 +430,9 @@ hetero_temporal_neighbor_sample_cpu( const c10::Dict &input_node_dict, const c10::Dict> &num_neighbors_dict, const c10::Dict &node_time_dict, - const int64_t num_hops, const bool replace) { - + const int64_t num_hops, const bool replace, const bool directed) { + AT_ASSERTM(directed, + "Currently, directed must be true for temporal sampling") if (replace) { // We assume that directed = True for temporal sampling // The current implementatio uses disjoint computation tree diff --git a/csrc/cpu/neighbor_sample_cpu.h b/csrc/cpu/neighbor_sample_cpu.h index 04a9f744..1759d456 100644 --- a/csrc/cpu/neighbor_sample_cpu.h +++ b/csrc/cpu/neighbor_sample_cpu.h @@ -33,4 +33,4 @@ hetero_temporal_neighbor_sample_cpu( const c10::Dict &input_node_dict, const c10::Dict> &num_neighbors_dict, const c10::Dict &node_time_dict, - const int64_t num_hops, const bool replace); + const int64_t num_hops, const bool replace, const bool directed); diff --git a/csrc/neighbor_sample.cpp b/csrc/neighbor_sample.cpp index a322d171..9883de32 100644 --- a/csrc/neighbor_sample.cpp +++ b/csrc/neighbor_sample.cpp @@ -52,7 +52,7 @@ hetero_temporal_neighbor_sample( const c10::Dict &input_node_dict, const c10::Dict> &num_neighbors_dict, const c10::Dict &node_time_dict, - const int64_t num_hops, const bool replace) { + const int64_t num_hops, const bool replace, const bool directed) { return hetero_temporal_neighbor_sample_cpu( node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, node_time_dict, num_hops, replace); From 61e3167c73dcf2c4502484c2704a9ef2b7d441f3 Mon Sep 17 00:00:00 2001 From: RexYing Date: Sat, 16 Jul 2022 06:33:10 +0000 Subject: [PATCH 10/13] comments --- csrc/cpu/neighbor_sample_cpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 2b74d667..f01bad11 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -435,7 +435,7 @@ hetero_temporal_neighbor_sample_cpu( "Currently, directed must be true for temporal sampling") if (replace) { // We assume that directed = True for temporal sampling - // The current implementatio uses disjoint computation tree + // The current implementation uses disjoint computation trees // to tackle the case of the same node sampled having different // root time constraint. // In future, we could extend to directed = False case, From 63b16e0d0d874538fae7b29050fa68c0b519f25d Mon Sep 17 00:00:00 2001 From: RexYing Date: Sat, 16 Jul 2022 06:39:44 +0000 Subject: [PATCH 11/13] minor function signature fix --- csrc/neighbor_sample.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/neighbor_sample.cpp b/csrc/neighbor_sample.cpp index 9883de32..d0f9e056 100644 --- a/csrc/neighbor_sample.cpp +++ b/csrc/neighbor_sample.cpp @@ -55,7 +55,7 @@ hetero_temporal_neighbor_sample( const int64_t num_hops, const bool replace, const bool directed) { return hetero_temporal_neighbor_sample_cpu( node_types, edge_types, colptr_dict, row_dict, input_node_dict, - num_neighbors_dict, node_time_dict, num_hops, replace); + num_neighbors_dict, node_time_dict, num_hops, replace, directed); } static auto registry = From efb97769988f27869b014cd0a4aded24d417d750 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Sat, 16 Jul 2022 08:39:49 +0200 Subject: [PATCH 12/13] Update csrc/cpu/neighbor_sample_cpu.cpp --- csrc/cpu/neighbor_sample_cpu.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index f01bad11..1b908bc0 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -431,8 +431,7 @@ hetero_temporal_neighbor_sample_cpu( const c10::Dict> &num_neighbors_dict, const c10::Dict &node_time_dict, const int64_t num_hops, const bool replace, const bool directed) { - AT_ASSERTM(directed, - "Currently, directed must be true for temporal sampling") + AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling") if (replace) { // We assume that directed = True for temporal sampling // The current implementation uses disjoint computation trees From 9595ea3ccbec7b01771e53b3e10e076084a206a8 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Sat, 16 Jul 2022 08:39:53 +0200 Subject: [PATCH 13/13] Update csrc/neighbor_sample.cpp --- csrc/neighbor_sample.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/neighbor_sample.cpp b/csrc/neighbor_sample.cpp index 9883de32..d0f9e056 100644 --- a/csrc/neighbor_sample.cpp +++ b/csrc/neighbor_sample.cpp @@ -55,7 +55,7 @@ hetero_temporal_neighbor_sample( const int64_t num_hops, const bool replace, const bool directed) { return hetero_temporal_neighbor_sample_cpu( node_types, edge_types, colptr_dict, row_dict, input_node_dict, - num_neighbors_dict, node_time_dict, num_hops, replace); + num_neighbors_dict, node_time_dict, num_hops, replace, directed); } static auto registry =