-
Notifications
You must be signed in to change notification settings - Fork 140
/
sampler_cpu.cpp
46 lines (39 loc) · 1.57 KB
/
sampler_cpu.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include "sampler_cpu.h"
#include "utils.h"
/// Samples, without replacement, a subset of the entries of every requested
/// row of a compressed-row structure described by `rowptr`.
///
/// For each row index `r` in `start`, the candidate entries are the positions
/// `rowptr[r], ..., rowptr[r + 1] - 1`. The number of positions drawn per row
/// is
///   - `count`                          if `count >= 1`,
///   - `ceil(factor * num_candidates)`  otherwise,
/// and is always capped at the number of candidates in that row.
///
/// @param start   1-D int64 tensor of row indices. Indices are NOT
///                range-checked; each must satisfy
///                `0 <= r < rowptr.numel() - 1`.
/// @param rowptr  1-D int64 tensor of row offsets.
/// @param count   Absolute number of entries to sample per row; values `< 1`
///                select the `factor`-based mode.
/// @param factor  Fraction of each row's entries to sample when `count < 1`.
/// @return        1-D tensor (same options as `start`) with the sampled
///                positions of all rows concatenated in the order of `start`.
///                The order of positions within one row is unspecified.
///
/// Both tensors are read through raw `int64_t` pointers, so they must be
/// int64 tensors residing in host memory (`data_ptr<int64_t>()` rejects other
/// dtypes).
torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
                                   int64_t count, double factor) {
  // Rows whose requested sample is below this fraction of the row size are
  // sampled by rejection into a set; denser requests use a full permutation.
  constexpr double kRejectionSamplingRatio = 0.7;

  // Raw-pointer indexing below assumes densely packed storage; without this,
  // strided views (e.g. slices with a step) would be read incorrectly.
  start = start.contiguous();
  rowptr = rowptr.contiguous();

  const int64_t *start_data = start.data_ptr<int64_t>();
  const int64_t *rowptr_data = rowptr.data_ptr<int64_t>();
  const int64_t num_rows = start.size(0);

  std::vector<int64_t> e_ids;
  for (int64_t i = 0; i < num_rows; ++i) {
    const int64_t row = start_data[i];
    const int64_t row_start = rowptr_data[row];
    const int64_t row_end = rowptr_data[row + 1];
    const int64_t num_neighbors = row_end - row_start;

    // `double` keeps the neighbor count exact up to 2^53 (a `float` would
    // already round above 2^24).
    int64_t size = count;
    if (count < 1)
      size = static_cast<int64_t>(
          std::ceil(factor * static_cast<double>(num_neighbors)));
    if (size > num_neighbors)
      size = num_neighbors;

    if (static_cast<double>(size) <
        kRejectionSamplingRatio * static_cast<double>(num_neighbors)) {
      // Sparse request: keep drawing uniform offsets until `size` distinct
      // ones were collected. Candidates come from torch's generator instead
      // of `rand()`, whose range (RAND_MAX may be as small as 32767) cannot
      // cover large rows and which ignores the torch seed.
      // Note: `size > 0` implies `num_neighbors > 0` in this branch, so the
      // upper bound passed to `randint` is always valid.
      std::unordered_set<int64_t> picked;
      while (static_cast<int64_t>(picked.size()) < size) {
        const int64_t missing = size - static_cast<int64_t>(picked.size());
        const auto draws =
            torch::randint(num_neighbors, {missing}, start.options());
        const int64_t *draws_data = draws.data_ptr<int64_t>();
        for (int64_t j = 0; j < missing; ++j)
          picked.insert(draws_data[j] + row_start);
      }
      e_ids.insert(e_ids.end(), picked.begin(), picked.end());
    } else {
      // Dense request: shuffle all offsets once and keep the first `size`.
      const auto perm = torch::randperm(num_neighbors, start.options());
      const int64_t *perm_data = perm.data_ptr<int64_t>();
      for (int64_t j = 0; j < size; ++j)
        e_ids.push_back(perm_data[j] + row_start);
    }
  }

  // `from_blob` only wraps the vector's memory; `clone` produces a tensor that
  // owns its data and therefore outlives `e_ids`.
  const int64_t length = static_cast<int64_t>(e_ids.size());
  return torch::from_blob(e_ids.data(), {length}, start.options()).clone();
}