SPARK-1215 [MLLIB]: Clustering: Index out of bounds error (2)
Added a check to kMeansPlusPlus initialization in LocalKMeans.scala to handle the case where there are fewer distinct data points than clusters k. Added two related unit tests to KMeansSuite. (Re-submitting the PR after commits got tangled in apache#1407.)

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes apache#1468 from jkbradley/kmeans-fix and squashes the following commits:

4e9bd1e [Joseph K. Bradley] Updated PR per comments from mengxr
6c7a2ec [Joseph K. Bradley] Added check to LocalKMeans.scala: kMeansPlusPlus initialization to handle case with fewer distinct data points than clusters k.  Added two related unit tests to KMeansSuite.
jkbradley authored and conviva-zz committed Sep 4, 2014
1 parent 22c7ce5 commit f24062e
Showing 2 changed files with 33 additions and 1 deletion.
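For context, here is a minimal sketch of the scenario this commit guards against: requesting more clusters than there are distinct input points. It assumes a live SparkContext named sc and the MLlib API of this release; the snippet is illustrative and is not part of the patch.

import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Three copies of the same point, but two clusters requested.
val data = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0, 3.0),
  Vectors.dense(1.0, 2.0, 3.0),
  Vectors.dense(1.0, 2.0, 3.0)))

// Before this fix, kMeansPlusPlus initialization could index points(-1) and fail with
// an index-out-of-bounds error; with the fix, duplicate centers are used instead.
val model = KMeans.train(data, k = 2, maxIterations = 1)
println(model.clusterCenters.mkString(", "))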
LocalKMeans.scala
@@ -59,7 +59,13 @@ private[mllib] object LocalKMeans extends Logging {
         cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j))
         j += 1
       }
-      centers(i) = points(j-1).toDense
+      if (j == 0) {
+        logWarning("kMeansPlusPlus initialization ran out of distinct points for centers." +
+          s" Using duplicate point for center k = $i.")
+        centers(i) = points(0).toDense
+      } else {
+        centers(i) = points(j - 1).toDense
+      }
     }
 
     // Run up to maxIterations iterations of Lloyd's algorithm
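Why j can be 0 in the loop above (a standalone sketch, not part of the patch; it assumes the selection threshold r is drawn in proportion to the total weighted cost, which this hunk does not show): when every remaining point coincides with an already-chosen center, every point cost is 0, so the threshold is 0 and the while loop never advances j.

// Self-contained illustration; the names here are invented for the sketch.
object WhyJStaysZero extends App {
  val weightedCosts = Array(0.0, 0.0, 0.0)                     // every point duplicates an existing center
  val r = scala.util.Random.nextDouble() * weightedCosts.sum   // r == 0.0
  var cumulativeScore = 0.0
  var j = 0
  while (j < weightedCosts.length && cumulativeScore < r) {
    cumulativeScore += weightedCosts(j)
    j += 1
  }
  println(s"j = $j")  // prints "j = 0"; without the new guard, points(j - 1) would index -1
}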
KMeansSuite.scala
@@ -61,6 +61,32 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
     assert(model.clusterCenters.head === center)
   }
 
+  test("no distinct points") {
+    val data = sc.parallelize(
+      Array(
+        Vectors.dense(1.0, 2.0, 3.0),
+        Vectors.dense(1.0, 2.0, 3.0),
+        Vectors.dense(1.0, 2.0, 3.0)),
+      2)
+    val center = Vectors.dense(1.0, 2.0, 3.0)
+
+    // Make sure code runs.
+    var model = KMeans.train(data, k=2, maxIterations=1)
+    assert(model.clusterCenters.size === 2)
+  }
+
+  test("more clusters than points") {
+    val data = sc.parallelize(
+      Array(
+        Vectors.dense(1.0, 2.0, 3.0),
+        Vectors.dense(1.0, 3.0, 4.0)),
+      2)
+
+    // Make sure code runs.
+    var model = KMeans.train(data, k=3, maxIterations=1)
+    assert(model.clusterCenters.size === 3)
+  }
+
   test("single cluster with big dataset") {
     val smallData = Array(
       Vectors.dense(1.0, 2.0, 6.0),
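As a hedged follow-up to the sketch below the diff summary (it reuses that sketch's model value, not anything from these tests): with the fix, the requested number of centers always comes back, but some of them may be identical when the data has fewer distinct points than k.

assert(model.clusterCenters.length == 2)
// Some centers may be duplicates of one another for the all-duplicate input.
println(model.clusterCenters.distinct.length)  // likely 1 here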
