-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathExample8.java
101 lines (74 loc) · 3.36 KB
/
Example8.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package applications.ml;
import datastructs.maths.DenseMatrixSet;
import datastructs.maths.RowBuilder;
import datastructs.maths.Vector;
import datastructs.utils.RowType;
import maths.functions.distances.EuclideanVectorCalculator;
import ml.classifiers.ThreadedKNNClassifier;
import parallel.partitioners.MatrixRowPartitionPolicy;
import parallel.partitioners.RangePartitioner;
import tech.tablesaw.api.Table;
import tech.tablesaw.columns.Column;
import utils.ClassificationVoter;
import utils.Pair;
import utils.PairBuilder;
import utils.TableDataSetLoader;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import static java.util.concurrent.Executors.newFixedThreadPool;
/** Category: Machine Learning
 * ID: Example8
 * Description: Classification with vanilla ParallelKNN algorithm
 * Taken From:
 * Details:
 * Loads the Iris CSV file, converts the "species" column into integer class
 * indices (Iris-setosa = 0, Iris-versicolor = 1, Iris-virginica = 2), stores
 * the remaining columns in a row-partitioned {@link DenseMatrixSet}, trains a
 * {@link ThreadedKNNClassifier} backed by a fixed thread pool and prints the
 * class index predicted for a single query point.
 */
public class Example8 {

    /** Location of the Iris data set, relative to the working directory. */
    private static final String DATA_FILE = "src/main/resources/datasets/iris_data.csv";

    /** Name of the column that holds the class label of every row. */
    private static final String LABEL_COLUMN = "species";

    /** Third argument handed to RangePartitioner.partition (presumably the number of row partitions). */
    private static final int NUM_PARTITIONS = 4;

    /** Number of worker threads in the pool handed to the classifier. */
    private static final int NUM_THREADS = 4;

    /**
     * Loads the Iris data set and splits it into a feature matrix and a label list.
     *
     * @return a pair whose first element is the data set built from every
     *         column except "species" (with a row partition policy installed)
     *         and whose second element holds one class index per row
     * @throws IOException declared by the original signature; presumably raised
     *         by the loader when the CSV file cannot be read
     * @throws IllegalArgumentException if a row carries a species name other
     *         than the three known Iris classes
     */
    public static Pair<DenseMatrixSet<Double>, List<Integer>> createDataSet() throws IOException, IllegalArgumentException {

        // load the data
        Table dataSetTable = TableDataSetLoader.loadDataSet(new File(DATA_FILE));

        // translate every species name into its integer class index
        List<Integer> labels = new ArrayList<>();
        Column species = dataSetTable.column(LABEL_COLUMN);

        for (int i = 0; i < species.size(); i++) {
            labels.add(labelToIndex((String) species.get(i)));
        }

        // drop the label column so that only the feature columns reach the matrix
        Table reducedDataSet = dataSetTable.removeColumns(LABEL_COLUMN).first(dataSetTable.rowCount());

        // diamond instead of the raw type: avoids the unchecked conversion warning
        DenseMatrixSet<Double> dataSet = new DenseMatrixSet<>(RowType.Type.DOUBLE_VECTOR, new RowBuilder());
        dataSet.initializeFrom(reducedDataSet);

        // partition the data set
        List<List<Integer>> partitions = RangePartitioner.partition(0, dataSet.m(), NUM_PARTITIONS);
        MatrixRowPartitionPolicy partitionPolicy = new MatrixRowPartitionPolicy(partitions);
        dataSet.setPartitionPolicy(partitionPolicy);

        return PairBuilder.makePair(dataSet, labels);
    }

    /**
     * Trains the threaded KNN classifier on the Iris data and classifies one point.
     *
     * @param args command line arguments (unused)
     * @throws IOException propagated from {@link #createDataSet()}
     * @throws IllegalArgumentException propagated from {@link #createDataSet()}
     */
    public static void main(String[] args) throws IOException, IllegalArgumentException {

        Pair<DenseMatrixSet<Double>, List<Integer>> data = Example8.createDataSet();
        ExecutorService executorService = newFixedThreadPool(NUM_THREADS);

        // The pool threads are non-daemon: unless shutdown() is reached on every
        // path, a failure while training or predicting would keep the JVM alive.
        try {

            System.out.println("Number of rows: " + data.first.m());
            System.out.println("Number of labels: " + data.second.size());

            // NOTE(review): 3 is presumably the number of neighbours; the meaning
            // of the boolean flag is not visible from this file - confirm against
            // the ThreadedKNNClassifier constructor.
            ThreadedKNNClassifier<Double, DenseMatrixSet<Double>, EuclideanVectorCalculator<Double>,
                    ClassificationVoter> classifier = new ThreadedKNNClassifier<>(3, false, executorService);

            classifier.setDistanceCalculator(new EuclideanVectorCalculator<Double>());
            classifier.setMajorityVoter(new ClassificationVoter());
            classifier.train(data.first, data.second);

            Vector point = new Vector(5.9, 3.0, 5.1, 1.8);
            Integer classIdx = classifier.predict(point);
            System.out.println("Point " + point + " has class index " + classIdx);
        }
        finally {
            executorService.shutdown();
        }
    }

    /** Maps an Iris species name to its integer class index; rejects any other name. */
    private static int labelToIndex(String label) {

        switch (label) {
            case "Iris-setosa":
                return 0;
            case "Iris-versicolor":
                return 1;
            case "Iris-virginica":
                return 2;
            default:
                // report the offending value so a malformed CSV row is easy to locate
                throw new IllegalArgumentException("Unknown class: " + label);
        }
    }
}