-
Notifications
You must be signed in to change notification settings - Fork 0
/
knn.cpp
84 lines (65 loc) · 1.97 KB
/
knn.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include "knn.h"

#include <algorithm>
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

#include "utils.h"
/*
 * Construct a classifier that will vote among the n_neighbours nearest
 * training samples. The value is stored as-is; predict()/predict_prob()
 * decide how to treat values that are too large or non-positive.
 */
Knn::Knn(int n_neighbours) : k(n_neighbours) {}
/*
 * Smoke-test hook: returns its argument incremented by one.
 */
int Knn::test(int inp)
{
    const int incremented = inp + 1;
    return incremented;
}
/*
 * Store the training data (k-NN is a lazy learner: fitting just memorises it).
 *
 * @param samples  training feature vectors; must be non-empty and all of the
 *                 same length.
 * @param labels   one class label per sample (labels[i] belongs to samples[i]).
 * @throws const char* if samples is empty, if the sample and label counts
 *         differ, or if the rows have differing dimensionality. On throw the
 *         previously fitted data is left unchanged.
 */
void Knn::fit(std::vector<std::vector<double> > samples,std::vector<int> labels){
    // Validate everything before touching member state.
    if(samples.empty()){
        throw "empty training data";
    }
    if(samples.size()!=labels.size()){
        throw "samples and labels size not match";
    }
    const std::size_t dim = samples[0].size();
    for(std::size_t i=1;i<samples.size();++i){
        if(samples[i].size()!=dim){
            throw "dimensionality size not match";
        }
    }
    // The arguments are already our own by-value copies; swap them in
    // instead of copying a second time.
    X.swap(samples);
    y.swap(labels);
    d = dim;
}
/*
 * Predict the label of one sample by majority vote among its nearest
 * training samples, using eucDist (from utils.h) as the distance.
 *
 * @param sample  feature vector; its length must equal the dimensionality
 *                recorded by fit().
 * @return the label with the most votes among the neighbours. Ties between
 *         labels go to the smallest label; ties in distance go to the
 *         training sample with the lower index.
 * @throws const char* if the dimensionality does not match, or if there is
 *         nobody to vote (no training samples, or k <= 0).
 *
 * If k exceeds the number of training samples, every sample votes.
 */
int Knn::predict(std::vector<double> sample){
    if(sample.size()!=d){
        throw "dimensionality size not match";
    }
    // (distance, index) pairs. Comparing pairs orders equal distances by
    // index, which is the same order a distance-keyed std::multimap gives.
    std::vector<std::pair<double,int> > dists;
    dists.reserve(X.size());
    for(std::size_t i=0;i<X.size();++i){
        const double distance = eucDist(sample,X[i]);
        dists.push_back(std::make_pair(distance,static_cast<int>(i)));
    }
    // Clamp k so we never step past the available samples.
    const std::size_t n_vote =
        k > 0 ? std::min(static_cast<std::size_t>(k), dists.size()) : 0;
    if(n_vote==0){
        throw "no neighbours to vote";
    }
    // Only the n_vote closest need ordering: O(n log k) rather than a full sort.
    const std::vector<std::pair<double,int> >::iterator vote_end =
        dists.begin()+static_cast<std::ptrdiff_t>(n_vote);
    std::partial_sort(dists.begin(),vote_end,dists.end());

    std::map<int,int> votes;//votes<label,cnts>
    for(std::size_t i=0;i<n_vote;++i){
        votes[y[dists[i].second]]++;
    }
    // votes is non-empty here. Iterating in ascending label order with a
    // strict '>' keeps the smallest label among equally-voted ones.
    std::map<int,int>::const_iterator best = votes.begin();
    for(std::map<int,int>::const_iterator it=votes.begin();it!=votes.end();++it){
        if(it->second > best->second){
            best = it;
        }
    }
    return best->first;
}
/*
 * Estimate class probabilities for one sample as the vote fractions among its
 * nearest training samples, using eucDist (from utils.h) as the distance.
 *
 * @param sample  feature vector; its length must equal the dimensionality
 *                recorded by fit().
 * @return one probability per label that received at least one vote, ordered
 *         by ascending label. Labels with zero votes are omitted, so the
 *         vector does not identify which label each entry belongs to.
 *         The entries sum to 1.
 * @throws const char* if the dimensionality does not match, or if there is
 *         nobody to vote (no training samples, or k <= 0).
 *
 * If k exceeds the number of training samples, every sample votes and the
 * fractions are taken over that smaller count.
 */
std::vector<double> Knn::predict_prob(std::vector<double> sample){
    if(sample.size()!=d){
        throw "dimensionality size not match";
    }
    // (distance, index) pairs. Comparing pairs orders equal distances by
    // index, which is the same order a distance-keyed std::multimap gives.
    std::vector<std::pair<double,int> > dists;
    dists.reserve(X.size());
    for(std::size_t i=0;i<X.size();++i){
        const double distance = eucDist(sample,X[i]);
        dists.push_back(std::make_pair(distance,static_cast<int>(i)));
    }
    // Clamp k so we never step past the available samples.
    const std::size_t n_vote =
        k > 0 ? std::min(static_cast<std::size_t>(k), dists.size()) : 0;
    if(n_vote==0){
        throw "no neighbours to vote";
    }
    // Only the n_vote closest need ordering: O(n log k) rather than a full sort.
    const std::vector<std::pair<double,int> >::iterator vote_end =
        dists.begin()+static_cast<std::ptrdiff_t>(n_vote);
    std::partial_sort(dists.begin(),vote_end,dists.end());

    std::map<int,int> votes;//votes<label,cnts>
    for(std::size_t i=0;i<n_vote;++i){
        votes[y[dists[i].second]]++;
    }
    // std::map already iterates in ascending label order, so no re-sorting
    // is needed. Divide by the actual voter count so the result sums to 1.
    std::vector<double> probs;
    probs.reserve(votes.size());
    for(std::map<int,int>::const_iterator it=votes.begin();it!=votes.end();++it){
        probs.push_back(it->second/static_cast<double>(n_vote));
    }
    return probs;
}