From 58f1d7b6f2632e00451cf4becf2486cbc9650d03 Mon Sep 17 00:00:00 2001
From: Feynman Liang
Date: Tue, 21 Jul 2015 11:16:31 -0700
Subject: [PATCH] Fix from review feedback

---
 .../org/apache/spark/mllib/clustering/LDAOptimizer.scala  | 8 ++++----
 .../org/apache/spark/mllib/clustering/LDASuite.scala      | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index bc20f75133509..b2f3a2538d0b0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -95,7 +95,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
    * Compute bipartite term/doc graph.
    */
   override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
-    val docConcentration = breeze.stats.mean(lda.getDocConcentration.toBreeze)
+    val docConcentration = lda.getDocConcentration(0)
     require({
       lda.getDocConcentration.toArray.forall(_ == docConcentration)
     }, "EMLDAOptimizer currently only supports symmetric document-topic priors")
@@ -349,13 +349,13 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     this.alpha = if (lda.getDocConcentration.size == 1) {
       if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
       else {
-        require(lda.getDocConcentration(0) >= 0, "all entries in alpha must be >=0")
+        require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha")
         Vectors.dense(Array.fill(k)(lda.getDocConcentration(0)))
       }
     } else {
-      require(lda.getDocConcentration.size == k, "alpha must have length k")
+      require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha")
       lda.getDocConcentration.foreachActive { case (_, x) =>
-        require(x >= 0, "all entries in alpha must be >= 0")
+        require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha")
       }
       lda.getDocConcentration
     }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index c6f6d2a73c4b0..da70d9bd7c790 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -148,7 +148,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {

   test("initializing with elements in alpha < 0 fails") {
     intercept[IllegalArgumentException] {
-      val lda = new LDA().setK(2).setAlpha(Vectors.dense(-1, 2, 3, 4))
+      val lda = new LDA().setK(4).setAlpha(Vectors.dense(-1, 2, 3, 4))
       val corpus = sc.parallelize(tinyCorpus, 2)
       lda.run(corpus)
     }