Skip to content

Loading…

Knn refactor #54

Merged
merged 3 commits into from

3 participants

@njern
Collaborator

I noticed that whether the user picked "euclidean" or "manhattan", both switch statement were still calling - pairwiseMetrics.NewEuclidean()

Fixed the bug and refactored the code to make it a bit shorter at the same time.

@sjwhitworth
Owner

Thanks! Looks good. However, Travis doesn't seem to like the build for some reason - can you investigate?

@njern
Collaborator

I need to set up a pre-commit hook to run go test. I'd gone and done "just a small change" between the time I ran the tests and when I committed and of course it broke things :)

@lazywei
Collaborator

Looks good to me!
LGTM

@sjwhitworth sjwhitworth merged commit e6c28ef into sjwhitworth:master

1 check passed

Details continuous-integration/travis-ci The Travis CI build passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Commits on Jul 20, 2014
  1. @njern
  2. @njern

    When the user picks manhattan distance, we should actually use the Ma…

    njern committed
    …nhattan distance function. Also slight refactor to make the code more DRY.
  3. @njern
Showing with 32 additions and 36 deletions.
  1. +24 −36 knn/knn.go
  2. +8 −0 metrics/pairwise/pairwise.go
View
60 knn/knn.go
@@ -44,27 +44,21 @@ func (KNN *KNNClassifier) PredictOne(vector []float64) string {
convertedVector := util.FloatsToMatrix(vector)
// Check what distance function we are using
+ var distanceFunc pairwiseMetrics.PairwiseDistanceFunc
switch KNN.DistanceFunc {
case "euclidean":
- {
- euclidean := pairwiseMetrics.NewEuclidean()
- for i := 0; i < rows; i++ {
- row := KNN.TrainingData.GetRowVectorWithoutClass(i)
- rowMat := util.FloatsToMatrix(row)
- distance := euclidean.Distance(rowMat, convertedVector)
- rownumbers[i] = distance
- }
- }
+ distanceFunc = pairwiseMetrics.NewEuclidean()
case "manhattan":
- {
- manhattan := pairwiseMetrics.NewEuclidean()
- for i := 0; i < rows; i++ {
- row := KNN.TrainingData.GetRowVectorWithoutClass(i)
- rowMat := util.FloatsToMatrix(row)
- distance := manhattan.Distance(rowMat, convertedVector)
- rownumbers[i] = distance
- }
- }
+ distanceFunc = pairwiseMetrics.NewManhattan()
+ default:
+ panic("unsupported distance function")
+ }
+
+ for i := 0; i < rows; i++ {
+ row := KNN.TrainingData.GetRowVectorWithoutClass(i)
+ rowMat := util.FloatsToMatrix(row)
+ distance := distanceFunc.Distance(rowMat, convertedVector)
+ rownumbers[i] = distance
}
sorted := util.SortIntMap(rownumbers)
@@ -125,27 +119,21 @@ func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
labels := make([]float64, 0)
// Check what distance function we are using
+ var distanceFunc pairwiseMetrics.PairwiseDistanceFunc
switch KNN.DistanceFunc {
case "euclidean":
- {
- euclidean := pairwiseMetrics.NewEuclidean()
- for i := 0; i < rows; i++ {
- row := KNN.Data.RowView(i)
- rowMat := util.FloatsToMatrix(row)
- distance := euclidean.Distance(rowMat, vector)
- rownumbers[i] = distance
- }
- }
+ distanceFunc = pairwiseMetrics.NewEuclidean()
case "manhattan":
- {
- manhattan := pairwiseMetrics.NewEuclidean()
- for i := 0; i < rows; i++ {
- row := KNN.Data.RowView(i)
- rowMat := util.FloatsToMatrix(row)
- distance := manhattan.Distance(rowMat, vector)
- rownumbers[i] = distance
- }
- }
+ distanceFunc = pairwiseMetrics.NewManhattan()
+ default:
+ panic("unsupported distance function")
+ }
+
+ for i := 0; i < rows; i++ {
+ row := KNN.Data.RowView(i)
+ rowMat := util.FloatsToMatrix(row)
+ distance := distanceFunc.Distance(rowMat, vector)
+ rownumbers[i] = distance
}
sorted := util.SortIntMap(rownumbers)
View
8 metrics/pairwise/pairwise.go
@@ -1,2 +1,10 @@
// Package pairwise implements utilities to evaluate pairwise distances or inner product (via kernel).
package pairwise
+
+import (
+ "github.com/gonum/matrix/mat64"
+)
+
+type PairwiseDistanceFunc interface {
+ Distance(vectorX *mat64.Dense, vectorY *mat64.Dense) float64
+}
Something went wrong with that request. Please try again.