Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions csrc/cpu/neighbor_sample_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,11 @@ hetero_sample(const vector<node_t> &node_types,
// note that the sampling always needs to have directed=True
// for temporal case
// to_local_src_node is not used for temporal / directed case
const int64_t sample_idx = src_samples.size();
src_samples.push_back(v);
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(src_samples.size() - 1);
rows.push_back(sample_idx);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
Expand All @@ -271,10 +272,11 @@ hetero_sample(const vector<node_t> &node_types,
// force disjoint of computation tree
// note that the sampling always needs to have directed=True
// for temporal case
const int64_t sample_idx = src_samples.size();
src_samples.push_back(v);
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(src_samples.size() - 1);
rows.push_back(sample_idx);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
Expand Down Expand Up @@ -305,10 +307,11 @@ hetero_sample(const vector<node_t> &node_types,
// force disjoint of computation tree
// note that the sampling always needs to have directed=True
// for temporal case
const int64_t sample_idx = src_samples.size();
src_samples.push_back(v);
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(src_samples.size() - 1);
rows.push_back(sample_idx);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
Expand Down Expand Up @@ -431,7 +434,7 @@ hetero_temporal_neighbor_sample_cpu(
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const int64_t num_hops, const bool replace, const bool directed) {
AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling")
AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling");
if (replace) {
// We assume that directed = True for temporal sampling
// The current implementation uses disjoint computation trees
Expand Down