-
Notifications
You must be signed in to change notification settings - Fork 488
/
AdaBoostClassifier.java
220 lines (197 loc) · 6.8 KB
/
AdaBoostClassifier.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
package func;
import dist.*;
import dist.Distribution;
import dist.DiscreteDistribution;
import shared.DataSet;
import shared.DataSetDescription;
import shared.Instance;
/**
 * A boosted ensemble of classifiers.
 * <p>
 * {@link #estimate(DataSet)} trains up to {@code size} base classifiers one
 * after another on the same data set, re-weighting the instances after each
 * round, and {@link #value(Instance)} classifies by a weighted vote of the
 * trained base classifiers.
 * <p>
 * The ensemble holds no classifiers until {@link #estimate(DataSet)} has been
 * called; {@link #value(Instance)} and {@link #distributionFor(Instance)} must
 * not be used before that.
 *
 * @author Andrew Guillory gtg008g@mail.gatech.edu
 * @version 1.0
 */
public class AdaBoostClassifier extends AbstractConditionalDistribution implements FunctionApproximater {
    /**
     * The supplier of classifiers to use
     */
    private final FunctionApproximaterSupplier classifierSupplier;
    /**
     * The classifiers themselves; {@code null} until the ensemble is trained.
     * Slots after the last trained classifier stay {@code null}.
     */
    private FunctionApproximater[] classifiers;
    /**
     * The vote weights for each of the classifiers; {@code null} until the
     * ensemble is trained
     */
    private double[] weights;
    /**
     * The range of the class (number of discrete label values)
     */
    private int classRange;
    /**
     * The size of the ensemble (maximum number of boosting rounds)
     */
    private int size;

    /**
     * Create a new ensemble given a {@link FunctionApproximaterSupplier} of classifiers.
     * @param size the size of the ensemble
     * @param classifierSupplier the {@link FunctionApproximaterSupplier} that supplies classifiers to use
     */
    public AdaBoostClassifier(int size, FunctionApproximaterSupplier classifierSupplier) {
        this.size = size;
        this.classifierSupplier = classifierSupplier;
    }

    /**
     * Create a new ensemble with a given classifier type.
     * A fresh base classifier is created for every boosting round through the
     * class's public no-argument constructor.
     * @param size the size of the ensemble
     * @param classifier the classifier class to use; it must have a public
     *        no-argument constructor and be a {@link FunctionApproximater}
     */
    public AdaBoostClassifier(int size, final Class<?> classifier) {
        this(size, new FunctionApproximaterSupplier() {
            @Override public FunctionApproximater get() {
                try {
                    return (FunctionApproximater) classifier.getConstructor().newInstance();
                } catch (Exception e) {
                    // keep the underlying failure as the cause so it can be diagnosed
                    throw new UnsupportedOperationException("Could not create " + classifier, e);
                }
            }
        });
    }

    /**
     * Create a new decision stump ensemble
     * @param size the size of the ensemble
     */
    public AdaBoostClassifier(int size) {
        this(size, DecisionStumpClassifier.class);
    }

    /**
     * Create a new default ensemble: 100 decision stumps
     */
    public AdaBoostClassifier() {
        this(100);
    }

    /**
     * Build the ensemble.
     * <p>
     * Note that this overwrites the weight of every instance in the given data
     * set, and installs an induced description on the data set when it has none.
     * @param instances the instances to train with
     */
    public void estimate(DataSet instances) {
        classifiers = new FunctionApproximater[size];
        weights = new double[size];
        final int count = instances.size();
        // initialize the weights of the instances uniformly
        for (int i = 0; i < count; i++) {
            instances.get(i).setWeight(1.0 / count);
        }
        // getting some info
        if (instances.getDescription() == null) {
            DataSetDescription desc = new DataSetDescription();
            desc.induceFrom(instances);
            instances.setDescription(desc);
        }
        classRange = instances.getDescription().getLabelDescription().getDiscreteRange();
        // outcome of the newest classifier on each instance, computed once per
        // round and shared by the error computation and the re-weighting
        boolean[] misclassified = new boolean[count];
        for (int i = 0; i < classifiers.length; i++) {
            // make a new classifier
            classifiers[i] = classifierSupplier.get();
            classifiers[i].estimate(instances);
            // find the weighted error for the newest classifier
            double error = 0;
            for (int j = 0; j < count; j++) {
                Instance instance = instances.get(j);
                misclassified[j] = classifiers[i].value(instance).getDiscrete()
                    != instance.getLabel().getDiscrete();
                if (misclassified[j]) {
                    error += instance.getWeight();
                }
            }
            double beta = error / (1 - error);
            // set the weight of the classifier; this is +Infinity for a
            // perfect classifier (error == 0)
            weights[i] = Math.log(1 / beta);
            // NOTE(review): only an error of exactly .5 discards the round; a
            // classifier with error above .5 is kept with a negative vote weight.
            if (error == .5) {
                // the classifier didn't do any good
                classifiers[i] = null;
                break;
            } else if (error == 0) {
                // nothing left to boost on
                break;
            }
            // scale down the correctly classified instances
            // and calculate the sum of the weights
            double weightSum = 0;
            for (int j = 0; j < count; j++) {
                Instance instance = instances.get(j);
                if (!misclassified[j]) {
                    instance.setWeight(instance.getWeight() * beta);
                }
                weightSum += instance.getWeight();
            }
            // normalize the weights
            for (int j = 0; j < count; j++) {
                Instance instance = instances.get(j);
                instance.setWeight(instance.getWeight() / weightSum);
            }
        }
    }

    /**
     * Get the classification for an instance.
     * Each trained classifier adds its weight to the class it predicts; the
     * class with the largest total wins, ties going to the lowest class index.
     * @param data the data to classify
     * @return an instance holding the winning class index
     */
    public Instance value(Instance data) {
        double[] votes = new double[classRange];
        // trained classifiers are stored contiguously; stop at the first empty slot
        for (int i = 0; i < classifiers.length && classifiers[i] != null; i++) {
            votes[classifiers[i].value(data).getDiscrete()] +=
                weights[i];
        }
        int classification = 0;
        for (int i = 1; i < votes.length; i++) {
            if (votes[i] > votes[classification]) {
                classification = i;
            }
        }
        return new Instance(classification);
    }

    /**
     * Get the class distribution for an instance.
     * The distribution puts probability 1 on the class chosen by
     * {@link #value(Instance)} and 0 on every other class.
     * @param data the data to classify
     * @return the class distribution
     */
    public Distribution distributionFor(Instance data) {
        Instance v = value(data);
        double[] p = new double[classRange];
        p[v.getDiscrete()] = 1;
        return new DiscreteDistribution(p);
    }

    /**
     * Get the size of the ensemble
     * @return the size of the ensemble
     */
    public int getSize() {
        return size;
    }

    /**
     * Set the size of the ensemble.
     * The new size is used the next time {@link #estimate(DataSet)} is called.
     * @param i the new size
     */
    public void setSize(int i) {
        size = i;
    }

    /**
     * Get the classifiers
     * @return the internal classifier array (not a copy), or {@code null} if
     *         the ensemble has not been trained
     */
    public FunctionApproximater[] getClassifiers() {
        return classifiers;
    }

    /**
     * Get the weights of the classifiers
     * @return the internal weight array (not a copy), or {@code null} if the
     *         ensemble has not been trained
     */
    public double[] getWeights() {
        return weights;
    }

    /**
     * Describe each trained classifier together with its vote weight.
     * @return the description, or an empty string if the ensemble has not
     *         been trained
     * @see java.lang.Object#toString()
     */
    @Override
    public String toString() {
        if (classifiers == null) {
            return "";
        }
        StringBuilder ret = new StringBuilder();
        for (int i = 0; i < classifiers.length && classifiers[i] != null; i++) {
            ret.append("weight ").append(weights[i]).append("\n");
            ret.append(classifiers[i]).append("\n\n");
        }
        return ret.toString();
    }
}