diff --git a/knn/knn.go b/knn/knn.go index e16ed74e..80bc9719 100644 --- a/knn/knn.go +++ b/knn/knn.go @@ -44,27 +44,21 @@ func (KNN *KNNClassifier) PredictOne(vector []float64) string { convertedVector := util.FloatsToMatrix(vector) // Check what distance function we are using + var distanceFunc pairwiseMetrics.PairwiseDistanceFunc switch KNN.DistanceFunc { case "euclidean": - { - euclidean := pairwiseMetrics.NewEuclidean() - for i := 0; i < rows; i++ { - row := KNN.TrainingData.GetRowVectorWithoutClass(i) - rowMat := util.FloatsToMatrix(row) - distance := euclidean.Distance(rowMat, convertedVector) - rownumbers[i] = distance - } - } + distanceFunc = pairwiseMetrics.NewEuclidean() case "manhattan": - { - manhattan := pairwiseMetrics.NewEuclidean() - for i := 0; i < rows; i++ { - row := KNN.TrainingData.GetRowVectorWithoutClass(i) - rowMat := util.FloatsToMatrix(row) - distance := manhattan.Distance(rowMat, convertedVector) - rownumbers[i] = distance - } - } + distanceFunc = pairwiseMetrics.NewManhattan() + default: + panic("unsupported distance function") + } + + for i := 0; i < rows; i++ { + row := KNN.TrainingData.GetRowVectorWithoutClass(i) + rowMat := util.FloatsToMatrix(row) + distance := distanceFunc.Distance(rowMat, convertedVector) + rownumbers[i] = distance } sorted := util.SortIntMap(rownumbers) @@ -125,27 +119,21 @@ func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 { labels := make([]float64, 0) // Check what distance function we are using + var distanceFunc pairwiseMetrics.PairwiseDistanceFunc switch KNN.DistanceFunc { case "euclidean": - { - euclidean := pairwiseMetrics.NewEuclidean() - for i := 0; i < rows; i++ { - row := KNN.Data.RowView(i) - rowMat := util.FloatsToMatrix(row) - distance := euclidean.Distance(rowMat, vector) - rownumbers[i] = distance - } - } + distanceFunc = pairwiseMetrics.NewEuclidean() case "manhattan": - { - manhattan := pairwiseMetrics.NewEuclidean() - for i := 0; i < rows; i++ { - row := KNN.Data.RowView(i) - rowMat := util.FloatsToMatrix(row) - distance := manhattan.Distance(rowMat, vector) - rownumbers[i] = distance - } - } + distanceFunc = pairwiseMetrics.NewManhattan() + default: + panic("unsupported distance function") + } + + for i := 0; i < rows; i++ { + row := KNN.Data.RowView(i) + rowMat := util.FloatsToMatrix(row) + distance := distanceFunc.Distance(rowMat, vector) + rownumbers[i] = distance } sorted := util.SortIntMap(rownumbers) diff --git a/metrics/pairwise/pairwise.go b/metrics/pairwise/pairwise.go index 6c6a5dc1..3a88e679 100644 --- a/metrics/pairwise/pairwise.go +++ b/metrics/pairwise/pairwise.go @@ -1,2 +1,10 @@ // Package pairwise implements utilities to evaluate pairwise distances or inner product (via kernel). package pairwise + +import ( + "github.com/gonum/matrix/mat64" +) + +type PairwiseDistanceFunc interface { + Distance(vectorX *mat64.Dense, vectorY *mat64.Dense) float64 +}