Skip to content

Commit

Permalink
Add tests referenced against gensim
Browse files Browse the repository at this point in the history
  • Loading branch information
Feynman Liang committed Jul 21, 2015
1 parent d4284fa commit 72038ff
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class LDA private (
* [[https://github.com/Blei-Lab/onlineldavb]].
*/
def setDocConcentration(docConcentration: Vector): this.type = {
docConcentration.toArray.iterator
this.docConcentration = docConcentration
this
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,56 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
}

test("OnlineLDAOptimizer with asymmetric prior") {
def toydata: Array[(Long, Vector)] = Array(
Vectors.sparse(6, Array(0, 1), Array(1, 1)),
Vectors.sparse(6, Array(1, 2), Array(1, 1)),
Vectors.sparse(6, Array(0, 2), Array(1, 1)),
Vectors.sparse(6, Array(3, 4), Array(1, 1)),
Vectors.sparse(6, Array(3, 5), Array(1, 1)),
Vectors.sparse(6, Array(4, 5), Array(1, 1))
).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }

val docs = sc.parallelize(toydata)
val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
.setGammaShape(1e10)
val lda = new LDA().setK(2)
.setDocConcentration(Vectors.dense(0.00001, 0.1))
.setTopicConcentration(0.01)
.setMaxIterations(100)
.setOptimizer(op)
.setSeed(12345)

val ldaModel = lda.run(docs)
val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
val topics = topicIndices.map { case (terms, termWeights) =>
terms.zip(termWeights)
}

/* Verify results with Python:
import numpy as np
from gensim import models
corpus = [
[(0, 1.0), (1, 1.0)],
[(1, 1.0), (2, 1.0)],
[(0, 1.0), (2, 1.0)],
[(3, 1.0), (4, 1.0)],
[(3, 1.0), (5, 1.0)],
[(4, 1.0), (5, 1.0)]]
np.random.seed(10)
lda = models.ldamodel.LdaModel(
corpus=corpus, alpha=np.array([0.00001, 0.1]), num_topics=2, update_every=0, passes=100)
lda.print_topics()
> ['0.167*0 + 0.167*1 + 0.167*2 + 0.167*3 + 0.167*4 + 0.167*5',
'0.167*0 + 0.167*1 + 0.167*2 + 0.167*4 + 0.167*3 + 0.167*5']
*/
topics.foreach { topic =>
assert(topic.forall { case (_, p) => p ~= 0.167 absTol 0.05 })
}
}

test("model save/load") {
// Test for LocalLDAModel.
val localModel = new LocalLDAModel(tinyTopics)
Expand Down

0 comments on commit 72038ff

Please sign in to comment.