Skip to content

Commit

Permalink
Fix from review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Feynman Liang committed Jul 21, 2015
1 parent a6dcf70 commit 58f1d7b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
* Compute bipartite term/doc graph.
*/
override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
val docConcentration = breeze.stats.mean(lda.getDocConcentration.toBreeze)
val docConcentration = lda.getDocConcentration(0)
require({
lda.getDocConcentration.toArray.forall(_ == docConcentration)
}, "EMLDAOptimizer currently only supports symmetric document-topic priors")
Expand Down Expand Up @@ -349,13 +349,13 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
this.alpha = if (lda.getDocConcentration.size == 1) {
if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
else {
require(lda.getDocConcentration(0) >= 0, "all entries in alpha must be >=0")
require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha")
Vectors.dense(Array.fill(k)(lda.getDocConcentration(0)))
}
} else {
require(lda.getDocConcentration.size == k, "alpha must have length k")
require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha")
lda.getDocConcentration.foreachActive { case (_, x) =>
require(x >= 0, "all entries in alpha must be >= 0")
require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha")
}
lda.getDocConcentration
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {

test("initializing with elements in alpha < 0 fails") {
intercept[IllegalArgumentException] {
val lda = new LDA().setK(2).setAlpha(Vectors.dense(-1, 2, 3, 4))
val lda = new LDA().setK(4).setAlpha(Vectors.dense(-1, 2, 3, 4))
val corpus = sc.parallelize(tinyCorpus, 2)
lda.run(corpus)
}
Expand Down

0 comments on commit 58f1d7b

Please sign in to comment.