/**
 * Example code using sampling to find KNN.
 *
 */
5+
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <numeric>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "io.h"

using std::cout;
using std::endl;
using std::string;
using std::vector;
17+
18+
/**
 * Squared Euclidean distance between two vectors, ignoring their first two
 * components.
 *
 * Indices 0 and 1 are skipped on both sides; only indices 2 and up contribute
 * to the result.
 *
 * When the vectors differ in length only the indices present in both are
 * compared, so the shorter vector is never read past its end.
 *
 * @param a first vector
 * @param b second vector
 * @return sum of squared differences over indices [2, min(a.size(), b.size()));
 *         0 when either vector has fewer than three elements
 */
float compare_with_id(const std::vector<float>& a, const std::vector<float>& b) {
  const std::size_t limit = std::min(a.size(), b.size());
  float sum = 0.0f;
  // Skip the first 2 dimensions
  for (std::size_t i = 2; i < limit; ++i) {
    const float diff = a[i] - b[i];
    sum += diff * diff;
  }
  return sum;
}
28+
29+
30+ int main (int argc, char **argv) {
31+ string source_path = " dummy-data.bin" ;
32+ string query_path = " dummy-queries.bin" ;
33+ string knn_save_path = " output.bin" ;
34+
35+ // Also accept other path for source data
36+ if (argc > 1 ) {
37+ source_path = string (argv[1 ]);
38+ }
39+
40+ uint32_t num_data_dimensions = 102 ;
41+ float sample_proportion = 0.001 ;
42+
43+ // Read data points
44+ vector <vector<float >> nodes;
45+ ReadBin (source_path, num_data_dimensions, nodes);
46+ cout<<nodes.size ()<<" \n " ;
47+ // Read queries
48+ uint32_t num_query_dimensions = num_data_dimensions + 2 ;
49+ vector <vector<float >> queries;
50+ ReadBin (query_path, num_query_dimensions, queries);
51+
52+ vector <vector<uint32_t >> knn_results; // for saving knn results
53+
54+ uint32_t n = nodes.size ();
55+ uint32_t d = nodes[0 ].size ();
56+ uint32_t nq = queries.size ();
57+ uint32_t sn = uint32_t (sample_proportion * n);
58+
59+ cout<<" # data points: " << n<<" \n " ;
60+ cout<<" # data point dim: " << d<<" \n " ;
61+ cout<<" # queries: " << nq<<" \n " ;
62+
63+ /* * A basic method to compute the KNN results using sampling **/
64+ const int K = 100 ; // To find 100-NN
65+
66+ for (uint i = 0 ; i < nq; i++){
67+ uint32_t query_type = queries[i][0 ];
68+ int32_t v = queries[i][1 ];
69+ float l = queries[i][2 ];
70+ float r = queries[i][3 ];
71+ vector<float > query_vec;
72+
73+ // first push_back 2 zeros for aligning with dataset
74+ query_vec.push_back (0 );
75+ query_vec.push_back (0 );
76+ for (uint j = 4 ; j < queries[i].size (); j++)
77+ query_vec.push_back (queries[i][j]);
78+
79+ vector<uint32_t > knn; // candidate knn
80+
81+ // Handling 4 types of queries
82+ if (query_type == 0 ){ // only ANN
83+ for (uint32_t j = 0 ; j < sn; j++){
84+ knn.push_back (j);
85+ }
86+ }
87+ else if (query_type == 1 ){ // equal + ANN
88+ for (uint32_t j = 0 ; j < sn; j++){
89+ if (nodes[j][0 ] == v){
90+ knn.push_back (j);
91+ }
92+ }
93+ }
94+ else if (query_type == 2 ){ // range + ANN
95+ for (uint32_t j = 0 ; j < sn; j++){
96+ if (nodes[j][1 ] >= l && nodes[j][1 ] <= r)
97+ knn.push_back (j);
98+ }
99+ }
100+ else if (query_type == 3 ){ // equal + range + ANN
101+ for (uint32_t j = 0 ; j < sn; j++){
102+ if (nodes[j][0 ] == v && nodes[j][1 ] >= l && nodes[j][1 ] <= r)
103+ knn.push_back (j);
104+ }
105+ }
106+
107+ // If the number of knn in the sampled data is less than K, then fill the rest with the last few nodes
108+ if (knn.size () < K){
109+ uint32_t s = 1 ;
110+ while (knn.size () < K) {
111+ knn.push_back (n - s);
112+ s = s + 1 ;
113+ }
114+ }
115+
116+ // build another vec to store the distance between knn[i] and query_vec
117+ vector<float > dists;
118+ dists.resize (knn.size ());
119+ for (uint32_t j = 0 ; j < knn.size (); j++)
120+ dists[j] = compare_with_id (nodes[knn[j]], query_vec);
121+
122+ vector<uint32_t > ids;
123+ ids.resize (knn.size ());
124+ std::iota (ids.begin (), ids.end (), 0 );
125+ // sort ids based on dists
126+ std::sort (ids.begin (), ids.end (), [&](uint32_t a, uint32_t b){
127+ return dists[a] < dists[b];
128+ });
129+ vector<uint32_t > knn_sorted;
130+ knn_sorted.resize (K);
131+ for (uint32_t j = 0 ; j < K; j++){
132+ knn_sorted[j] = knn[ids[j]];
133+ }
134+ knn_results.push_back (knn_sorted);
135+ }
136+
137+ // save the results
138+ SaveKNN (knn_results, knn_save_path);
139+ return 0 ;
140+ }
0 commit comments