Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 52 additions & 35 deletions csrc/cpu/neighbor_sample_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,17 +238,24 @@ hetero_sample(const vector<node_t> &node_types,
if (temporal) {
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
}
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) {
// force disjoint of computation tree
// note that the sampling always needs to have directed=True
// for temporal case
// to_local_src_node is not used for temporal / directed case
src_samples.push_back(v);
if (temporal)
src_root_time.push_back(dst_time);
}
if (directed) {
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(res.first->second);
rows.push_back(src_samples.size() - 1);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
}
} else if (replace) {
Expand All @@ -261,17 +268,23 @@ hetero_sample(const vector<node_t> &node_types,
// TODO Infinity loop if no neighbor satisfies time constraint:
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
}
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) {
// force disjoint of computation tree
// note that the sampling always needs to have directed=True
// for temporal case
src_samples.push_back(v);
if (temporal)
src_root_time.push_back(dst_time);
}
if (directed) {
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(res.first->second);
rows.push_back(src_samples.size() - 1);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
num_neighbors += 1;
}
Expand All @@ -289,17 +302,23 @@ hetero_sample(const vector<node_t> &node_types,
if (temporal) {
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
}
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) {
// force disjoint of computation tree
// note that the sampling always needs to have directed=True
// for temporal case
src_samples.push_back(v);
if (temporal)
src_root_time.push_back(dst_time);
}
if (directed) {
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(res.first->second);
rows.push_back(src_samples.size() - 1);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
}
}
Expand Down Expand Up @@ -412,21 +431,19 @@ hetero_temporal_neighbor_sample_cpu(
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const int64_t num_hops, const bool replace, const bool directed) {

if (replace && directed) {
AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling")
if (replace) {
// We assume that directed = True for temporal sampling
// The current implementation uses disjoint computation trees
// to tackle the case of the same node sampled having different
// root time constraint.
// In future, we could extend to directed = False case,
// allowing additional edges within each computation tree.
return hetero_sample<true, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
} else if (replace && !directed) {
return hetero_sample<true, false, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
} else if (!replace && directed) {
return hetero_sample<false, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
} else {
return hetero_sample<false, false, true>(
return hetero_sample<false, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
}
Expand Down