// bfs_sampler.hpp
#pragma once
// Code in this header is responsible for handling progressive
// sampling, so as to answer queries on connection probabilities.
// In particular, this implements limited-depth sampling.
#include "logging.hpp"
#include "prelude.hpp"
#include "rand.hpp"
#include "types.hpp"
#include "counts_cache.hpp"
/**
 * A FIFO queue backed by a ring buffer whose storage is allocated once,
 * at construction time, and never grows afterwards.
 *
 * The queue can hold up to `capacity` elements at the same time. One spare
 * slot is allocated on top of the requested capacity, so that the condition
 * `m_begin == m_end` always means "empty" and never "full". As a consequence
 * a queue built with capacity 0 is valid: it is always empty and every push
 * is rejected.
 *
 * T must be default-constructible (the storage is pre-filled with
 * default-constructed values) and copy-assignable.
 */
template <typename T> class FixedCapacityQueue {
public:
  /** Builds an empty queue able to hold `capacity` elements. */
  FixedCapacityQueue(size_t capacity)
      : m_capacity(capacity), m_begin(0), m_end(0),
        m_storage(std::vector<T>(capacity + 1)) {}

  /** True if there are no elements waiting to be popped. */
  bool empty() const { return m_begin == m_end; }

  /** Discards all the elements. The storage is kept allocated. */
  void clear() {
    m_begin = 0;
    m_end = 0;
  }

  /**
   * Appends `elem` at the back of the queue.
   *
   * Throws std::logic_error if the queue already holds `capacity`
   * elements; in that case the queue is left unchanged.
   */
  void push(T elem) {
    const size_t new_end = next_index(m_end);
    if (new_end == m_begin) {
      throw std::logic_error("Queue capacity exceeded");
    }
    m_storage[m_end] = elem;
    m_end = new_end;
  }

  /**
   * Removes and returns the element at the front of the queue.
   *
   * The queue must not be empty: this is checked only by an assertion.
   */
  T pop() {
    assert(!empty());
    T elem = m_storage[m_begin];
    m_begin = next_index(m_begin);
    return elem;
  }

private:
  /** The slot following `index`, wrapping around the ring. */
  size_t next_index(size_t index) const {
    // The ring has one slot more than the capacity, hence the modulus is
    // never zero.
    return (index + 1) % (m_capacity + 1);
  }

  // Maximum number of elements the queue can hold
  size_t m_capacity;
  // Index of the next element to pop
  size_t m_begin;
  // Index of the slot where the next pushed element is stored
  size_t m_end;
  // Ring buffer with m_capacity + 1 slots
  std::vector<T> m_storage;
};
/**
 * Working state of a single sampling thread: BfsSampler keeps one instance
 * of this struct for each thread it is configured with.
 *
 * The queue and the two vectors are all sized to the number of vertices of
 * the graph given to the constructor; the vectors start zero-filled.
 */
struct BfsSamplerThreadState {
  using distance_vector_t = std::vector<size_t>;
  using edge_sample_t = std::vector<bool>;

  /**
   * Allocates the per-vertex buffers for `graph` and stores a copy of
   * `randgen` as the generator of this thread.
   */
  BfsSamplerThreadState(const ugraph_t &graph, Xorshift1024star randgen)
      : queue(boost::num_vertices(graph)),
        connection_counts(boost::num_vertices(graph)),
        distance_vector(boost::num_vertices(graph)),
        rnd(randgen) {}

  // Queue of vertices, with room for as many entries as the graph has
  // vertices
  FixedCapacityQueue<ugraph_vertex_t> queue;

  // One counter per vertex
  // NOTE(review): presumably the number of samples in which the vertex was
  // reached -- confirm against the implementation file
  std::vector<size_t> connection_counts;

  // One distance per vertex
  std::vector<size_t> distance_vector;

  // Random generators, one for each thread
  Xorshift1024star rnd;
};
// Declaration of the stream insertion operator for the state of a sampling
// thread. The definition is not in this header.
// NOTE(review): the C++ standard makes it undefined behavior to add new
// declarations like this one to namespace std. Since BfsSamplerThreadState
// lives in the global namespace, argument-dependent lookup should also find
// the operator if it were declared there -- consider moving it, together
// with its definition. Also, the state is taken by non-const reference:
// confirm whether printing really needs to modify it.
namespace std {
std::ostream &operator<<(std::ostream &os, BfsSamplerThreadState &tstate);
}
class BfsSampler {
public:
typedef std::vector<bool> edge_sample_t;
BfsSampler(const ugraph_t &graph,
const size_t max_depth,
std::function<size_t(double)> prob_to_samples, uint64_t seed,
size_t num_threads)
: prob_to_samples(prob_to_samples),
m_max_dist(max_depth),
m_samples(std::vector<edge_sample_t>()),
m_thread_states(std::vector<BfsSamplerThreadState>()),
m_used_samples(0)
{
Xorshift1024star rnd(seed);
for (size_t i = 0; i < num_threads; ++i) {
rnd.jump();
m_thread_states.emplace_back(graph, rnd);
}
}
void min_probability(const ugraph_t &graph, probability_t prob);
// Add samples, if needed
void sample_size(const ugraph_t &graph, size_t total_samples);
void log_states() {
for (auto &tstate : m_thread_states) {
LOG_INFO(tstate);
}
}
size_t connection_probabilities(const ugraph_t &graph,
const ugraph_vertex_t from,
std::vector<probability_t> &probabilities);
size_t connection_probabilities(const ugraph_t & graph,
const ugraph_vertex_t from,
const std::vector< ugraph_vertex_t > & targets,
std::vector< probability_t > & probabilities);
size_t connection_probabilities_cache(const ugraph_t & graph,
const ugraph_vertex_t from,
ConnectionCountsCache & cccache,
std::vector< probability_t > & probabilities){
// FIXME specialize
return connection_probabilities(graph, from, probabilities);
}
/** The probability that a given set of nodes is connected */
probability_t
connection_probability(const ugraph_t &graph,
const std::vector<ugraph_vertex_t> &vertices) {
throw std::logic_error("Not implemented");
}
private:
std::function<size_t(double)> prob_to_samples;
size_t m_max_dist;
std::vector<edge_sample_t> m_samples;
std::vector<BfsSamplerThreadState> m_thread_states;
// The minimum connection probability that is estimate reliably
probability_t m_min_probability = 1.0;
size_t m_used_samples;
};