
k-Nearest Neighbors algorithm (k-NN) implemented on Apache Spark. This uses a hybrid spill tree approach to achieve high accuracy and search efficiency. The simplicity of k-NN and its lack of tuning parameters make it a useful baseline model for many machine learning problems.

How to Use

This package is published using sbt-spark-package and linking information can be found at

k-NN can be used for both classification and regression, which are exposed using the new Spark ML API based on DataFrame. Both models accept a weight column so predictions can be optionally weighted.
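To make the optional weighting concrete, here is a minimal plain-Scala sketch of how k found neighbors could be combined into a prediction. It is an illustration under assumptions, not the package's implementation: the `Neighbor` type and the `classify` and `regress` helpers are hypothetical names, and each weight is assumed to be a non-negative number taken from the weight column.

```scala
// Hypothetical helper type for illustration; not part of spark-knn.
final case class Neighbor(label: Double, weight: Double)

object WeightedKnnSketch {
  // Classification: sum the weights per label and return the label with the largest total.
  // Ties are broken by whichever label maxBy encounters first.
  def classify(neighbors: Seq[Neighbor]): Double = {
    require(neighbors.nonEmpty, "need at least one neighbor")
    neighbors
      .groupBy(_.label)
      .map { case (label, ns) => (label, ns.map(_.weight).sum) }
      .maxBy(_._2)
      ._1
  }

  // Regression: weighted average of the neighbors' labels.
  def regress(neighbors: Seq[Neighbor]): Double = {
    val totalWeight = neighbors.map(_.weight).sum
    require(totalWeight > 0.0, "weights must sum to a positive value")
    neighbors.map(n => n.label * n.weight).sum / totalWeight
  }
}
```

With all weights equal to 1.0, `classify` reduces to a plain majority vote and `regress` to a plain mean, which corresponds to the unweighted case.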


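import org.apache.spark.mllib.util.MLUtils

// read in raw labels and features
val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()

val knn = new KNNClassifier()
  .setTopTreeSize(training.count().toInt / 500)

// fit the estimator on the training DataFrame (standard Spark ML fit/transform pattern)
val knnModel = knn.fit(training)

val predicted = knnModel.transform(training)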


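import org.apache.spark.mllib.util.MLUtils

// read in raw labels and features
val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()

val knn = new KNNRegression()
  .setTopTreeSize(training.count().toInt / 500)

// fit the estimator on the training DataFrame (standard Spark ML fit/transform pattern)
val knnModel = knn.fit(training)

val predicted = knnModel.transform(training)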

Furthermore, KNN itself is also exposed for advanced usage; it returns arbitrary columns associated with the neighbors found. For example, this can power the clustering use case described in the referenced Google paper.

When the model is trained, data points are repartitioned and within each partition a search tree is built to support efficient querying. When the model is used for prediction, the prediction vectors are repartitioned, searched, collected and joined back to the search DataFrame. Assuming the training set is much larger, subsequent prediction can be much quicker than training. Overall the algorithm displays an O(m log n) runtime, much better than the naive O(m n) runtime (for n training points, m prediction points and k = 1). See the benchmark section for more details.
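For reference, the naive O(m n) baseline can be sketched in a few lines of plain Scala. This is only an illustration of the brute-force search that the tree-based approach is compared against, not code from this package; `BruteForceKnnSketch` and its methods are hypothetical names, and Euclidean distance is assumed.

```scala
object BruteForceKnnSketch {
  type Point = Array[Double]

  // Squared Euclidean distance; ordering by it matches ordering by Euclidean distance.
  def squaredDistance(a: Point, b: Point): Double = {
    require(a.length == b.length, "dimension mismatch")
    var sum = 0.0
    var i = 0
    while (i < a.length) {
      val d = a(i) - b(i)
      sum += d * d
      i += 1
    }
    sum
  }

  // For each of the m queries, compute the distance to all n training points
  // and keep the indices of the k closest: m * n distance computations in total.
  // Sorting all n distances is used for brevity; a bounded heap would avoid the full sort.
  def search(training: IndexedSeq[Point], queries: Seq[Point], k: Int): Seq[Seq[Int]] =
    queries.map { q =>
      training.indices
        .sortBy(i => squaredDistance(training(i), q))
        .take(k)
    }
}
```

Every query touches every training point here, which is the cost a per-partition search tree is meant to avoid by pruning most of the training set.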

The number of neighbors can be set before or after training. All other parameters must be set before training; they control the number of partitions and the trade-off between accuracy and efficiency of each individual search tree. Please refer to the Scaladoc for more information.

Using the Python interface with spark-submit

To run a Spark script in Python with spark-submit, use:

cd python
python setup.py bdist_egg
cd ..
sbt package

spark-submit --py-files python/dist/pyspark_knn-*.egg --driver-class-path spark-knn-core/target/scala-2.11/spark-knn_*.jar --jars spark-knn-core/target/scala-2.11/spark-knn_*.jar YOUR_SCRIPT


Preliminary benchmark results can be found here.

We have benchmarked our implementation against the MNIST dataset. For the canonical 60k training dataset, our implementation achieves a reasonable cross-validated F1 score of 0.97; the corresponding score for the brute-force exact algorithm is still to be computed.

While the implementation is approximate, it doesn't suffer much even when the dimension is high (as high as the full raw MNIST dimension of 784). This can be a huge advantage over other approximate implementations such as KD-tree and LSH, although further benchmarking is required to confirm it.

The implementation also exhibits sub-linear runtime which can lead to huge savings for large datasets.

Note: the duration in the above plot is the total runtime, so brute force exhibits polynomial runtime while SpillTree shows close to linearithmic runtime.

Finally, the implementation scales horizontally and has been successfully applied to datasets with low hundreds of millions of observations and low hundreds of dimensions. We see no reason why it couldn't scale to billions of observations, as described in the original Google paper.


  • implementation of MetricTree, SpillTree, HybridSpillTree

  • distributed KNN based on HybridSpillTree

  • (weighted) Classifier and Regression on ml API

  • benchmark against Brute Force and LSH based kNN in terms of model and runtime performance

  • benchmark against LSH based kNN

  • refactoring of Tree related code

    NB: currently trees are constructed recursively and contain some duplicated code. The data structure is also questionable.
    However, preliminary empirical testing shows each tree can comfortably handle tens to hundreds of thousands of high-dimensional data points.
  • upgrade ml implementation to use DataSet API (pending Spark 1.6)

    NB: the largest costs of this implementation are the disk I/O of repartitioning and the distance calculation. While there is
    no good way to optimize the distance calculation, with the DataSet API we might be able to drastically reduce the shuffle size during training
    and prediction.
  • explore use of random projection for dimension reduction
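As a rough illustration of the last item, a Gaussian random projection maps each vector to a lower dimension by multiplying it with a random matrix, which approximately preserves pairwise distances. The sketch below is plain Scala and is not part of this package; `RandomProjectionSketch`, its method names, the seed and the target dimension are all assumptions chosen for the example.

```scala
import scala.util.Random

object RandomProjectionSketch {
  // Build a targetDim x sourceDim matrix with i.i.d. Gaussian entries,
  // scaled by 1 / sqrt(targetDim) so squared norms are preserved in expectation.
  def projectionMatrix(sourceDim: Int, targetDim: Int, seed: Long): Array[Array[Double]] = {
    val rng = new Random(seed)
    val scale = 1.0 / math.sqrt(targetDim.toDouble)
    Array.fill(targetDim, sourceDim)(rng.nextGaussian() * scale)
  }

  // Multiply the projection matrix by a vector, yielding a targetDim-dimensional vector.
  def project(matrix: Array[Array[Double]], v: Array[Double]): Array[Double] =
    matrix.map { row =>
      require(row.length == v.length, "dimension mismatch")
      var sum = 0.0
      var i = 0
      while (i < v.length) {
        sum += row(i) * v(i)
        i += 1
      }
      sum
    }
}
```

For example, `RandomProjectionSketch.projectionMatrix(784, 50, 42L)` would reduce raw MNIST vectors to 50 dimensions; the same matrix has to be applied to both training and prediction vectors so that distances remain comparable.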


  • Liu, Ting, et al. "An investigation of practical approximate nearest neighbor algorithms." Advances in neural information processing systems. 2004.

  • Liu, Ting, Charles Rosenberg, and Henry Rowley. "Clustering billions of images with large scale nearest neighbor search." Applications of Computer Vision, 2007. WACV'07. IEEE Workshop on. IEEE, 2007.
