From 91aadfe8944e688990678285a62ea6cab4ad05c5 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 30 Jan 2015 12:23:57 -0800 Subject: [PATCH] Added Java-friendly run method to LDA. Added Java test suite for LDA. Changed LDAModel.describeTopics to return Java-friendly type --- .../spark/examples/mllib/LDAExample.scala | 6 +- .../spark/examples/mllib/LDATiming.scala | 6 +- .../apache/spark/mllib/clustering/LDA.scala | 24 ++-- .../spark/mllib/clustering/LDAModel.scala | 62 +++++---- .../spark/mllib/clustering/JavaLDASuite.java | 118 ++++++++++++++++++ .../spark/mllib/clustering/LDASuite.scala | 39 +++--- 6 files changed, 198 insertions(+), 57 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index cf756e0dc05d1..8baf67da71414 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -148,13 +148,13 @@ object LDAExample { // Print the topics, showing the top-weighted terms for each topic. val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) - val topics = topicIndices.map { topic => - topic.map { case (weight, term) => (weight, vocabArray(term.toInt)) } + val topics = topicIndices.map { case (terms, termWeights) => + terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) } } println(s"${params.k} topics:") topics.zipWithIndex.foreach { case (topic, i) => println(s"TOPIC $i") - topic.foreach { case (weight, term) => + topic.foreach { case (term, weight) => println(s"$term\t$weight") } println() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDATiming.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDATiming.scala index 75b17d35463d3..82c524c8fa16b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDATiming.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDATiming.scala @@ -167,13 +167,13 @@ object LDATiming { // Print the topics, showing the top-weighted terms for each topic. 
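For reference, a minimal sketch (not part of the diff) of how the reworked examples consume the new describeTopics output; it assumes a trained ldaModel and the vocabArray: Array[String] defined earlier in LDAExample, which maps term indices to term text:

  // describeTopics now returns, per topic, parallel arrays of term indices and term weights,
  // sorted by decreasing weight.
  val topicIndices: Array[(Array[Int], Array[Double])] = ldaModel.describeTopics(maxTermsPerTopic = 10)
  topicIndices.zipWithIndex.foreach { case ((terms, termWeights), i) =>
    println(s"TOPIC $i")
    terms.zip(termWeights).foreach { case (term, weight) =>
      println(s"${vocabArray(term)}\t$weight")  // term is an index into the vocabulary
    }
    println()
  }

The same model can also be trained from Java through the new run(JavaRDD) overload added below, which the new JavaLDASuite exercises.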
val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) - val topics = topicIndices.map { topic => - topic.map { case (weight, term) => (weight, vocabArray(term.toInt))} + val topics = topicIndices.map { case (terms, termWeights) => + terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) } } println(s"$k topics:") topics.zipWithIndex.foreach { case (topic, i) => println(s"TOPIC $i") - topic.foreach { case (weight, term) => + topic.foreach { case (term, weight) => println(s"$term\t$weight") } println() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 73c3b6d205602..1755b995d57e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -23,6 +23,7 @@ import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy} import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer @@ -233,10 +234,15 @@ class LDA private ( state.graphCheckpointer.deleteAllCheckpoints() new DistributedLDAModel(state, iterationTimes) } + + /** Java-friendly version of [[run()]] */ + def run(documents: JavaRDD[(java.lang.Long, Vector)]): DistributedLDAModel = { + run(documents.rdd.map(id_counts => (id_counts._1.asInstanceOf[Long], id_counts._2))) + } } -object LDA { +private[clustering] object LDA { /* DEVELOPERS NOTE: @@ -291,18 +297,18 @@ object LDA { * Vector over topics (length k) of token counts. * The meaning of these counts can vary, and it may or may not be normalized to be a distribution. */ - private[clustering] type TopicCounts = BDV[Double] + type TopicCounts = BDV[Double] - private[clustering] type TokenCount = Double + type TokenCount = Double /** Term vertex IDs are {-1, -2, ..., -vocabSize} */ - private[clustering] def term2index(term: Int): Long = -(1 + term.toLong) + def term2index(term: Int): Long = -(1 + term.toLong) - private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt + def index2term(termIndex: Long): Int = -(1 + termIndex).toInt - private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 + def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 - private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 + def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 /** * State for EM algorithm: data + parameter graph, plus algorithm parameters. @@ -314,7 +320,7 @@ object LDA { * @param docConcentration "alpha" * @param topicConcentration "beta" or "eta" */ - private[clustering] class EMOptimizer( + class EMOptimizer( var graph: Graph[TopicCounts, TokenCount], val k: Int, val vocabSize: Int, @@ -374,7 +380,7 @@ object LDA { * * Note: This executes an action on the graph RDDs. 
*/ - private[clustering] var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() + var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() private def computeGlobalTopicTotals(): TopicCounts = { val numTopics = k diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 6a1212a114e5c..4552d75acf849 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -60,10 +60,10 @@ abstract class LDAModel private[clustering] { * * @param maxTermsPerTopic Maximum number of terms to collect for each topic. * @return Array over topics, where each element is a set of top terms represented - * as (term weight in topic, term index). + * as (term index, term weight in topic). * Each topic's terms are sorted in order of decreasing weight. */ - def describeTopics(maxTermsPerTopic: Int): Array[Array[(Double, Int)]] + def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] /** * Return the topics described by weighted terms. @@ -71,10 +71,10 @@ abstract class LDAModel private[clustering] { * WARNING: If vocabSize and k are large, this can return a large object! * * @return Array over topics, where each element is a set of top terms represented - * as (term weight in topic, term index). + * as (term index, term weight in topic). * Each topic's terms are sorted in order of decreasing weight. */ - def describeTopics(): Array[Array[(Double, Int)]] = describeTopics(vocabSize) + def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize) /* TODO (once LDA can be trained with Strings or given a dictionary) * Return the topics described by weighted terms. @@ -89,11 +89,11 @@ abstract class LDAModel private[clustering] { * * @param maxTermsPerTopic Maximum number of terms to collect for each topic. * @return Array over topics, where each element is a set of top terms represented - * as (term weight in topic, term), where "term" is either the actual term text + * as (term, term weight in topic), where "term" is either the actual term text * (if available) or the term index. * Each topic's terms are sorted in order of decreasing weight. */ - //def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[Array[(Double, String)]] + //def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[(Array[Double], Array[String])] /* TODO (once LDA can be trained with Strings or given a dictionary) * Return the topics described by weighted terms. @@ -105,11 +105,11 @@ abstract class LDAModel private[clustering] { * WARNING: If vocabSize and k are large, this can return a large object! * * @return Array over topics, where each element is a set of top terms represented - * as (term weight in topic, term), where "term" is either the actual term text + * as (term, term weight in topic), where "term" is either the actual term text * (if available) or the term index. * Each topic's terms are sorted in order of decreasing weight. 
*/ - //def describeTopicsAsStrings(): Array[Array[(Double, String)]] = + //def describeTopicsAsStrings(): Array[(Array[Double], Array[String])] = // describeTopicsAsStrings(vocabSize) /* TODO @@ -172,11 +172,13 @@ class LocalLDAModel private[clustering] ( override def topicsMatrix: Matrix = topics - override def describeTopics(maxTermsPerTopic: Int): Array[Array[(Double, Int)]] = { + override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val brzTopics = topics.toBreeze.toDenseMatrix Range(0, k).map { topicIndex => val topic = normalize(brzTopics(::, topicIndex), 1.0) - topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic) + val (termWeights, terms) = + topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip + (terms.toArray, termWeights.toArray) }.toArray } @@ -248,29 +250,35 @@ class DistributedLDAModel private ( Matrices.fromBreeze(brzTopics) } - override def describeTopics(maxTermsPerTopic: Int): Array[Array[(Double, Int)]] = { + override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val numTopics = k // Note: N_k is not needed to find the top terms, but it is needed to normalize weights // to a distribution over terms. val N_k: TopicCounts = globalTopicTotals - graph.vertices.filter(isTermVertex) - .mapPartitions { termVertices => - // For this partition, collect the most common terms for each topic in queues: - // queues(topic) = queue of (term weight, term index). - // Term weights are N_{wk} / N_k. - val queues = Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic)) - for ((termId, n_wk) <- termVertices) { - var topic = 0 - while (topic < numTopics) { - queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt)) - topic += 1 + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Int)]] = + graph.vertices.filter(isTermVertex) + .mapPartitions { termVertices => + // For this partition, collect the most common terms for each topic in queues: + // queues(topic) = queue of (term weight, term index). + // Term weights are N_{wk} / N_k. + val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic)) + for ((termId, n_wk) <- termVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt)) + topic += 1 + } } + Iterator(queues) + }.reduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b} + q1 } - Iterator(queues) - }.reduce { (q1, q2) => - q1.zip(q2).foreach { case (a, b) => a ++= b} - q1 - }.map(_.toArray.sortBy(-_._1)) + topicsInQueues.map { q => + val (termWeights, terms) = q.toArray.sortBy(-_._1).unzip + (terms.toArray, termWeights.toArray) + } } // TODO diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java new file mode 100644 index 0000000000000..9fc561bb3b332 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+
+import scala.Tuple2;
+
+import org.junit.After;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.Vector;
+
+
+public class JavaLDASuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaLDA");
+    tinyCorpus = new ArrayList<Tuple2<Long, Vector>>();
+    for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) {
+      tinyCorpus.add(new Tuple2<Long, Vector>((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(),
+          LDASuite$.MODULE$.tinyCorpus()[i]._2()));
+    }
+    corpus = sc.parallelize(tinyCorpus, 2);
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void localLDAModel() {
+    LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics());
+
+    // Check: basic parameters
+    assertEquals(model.k(), tinyK);
+    assertEquals(model.vocabSize(), tinyVocabSize);
+    assertEquals(model.topicsMatrix(), tinyTopics);
+
+    // Check: describeTopics() with all terms
+    Tuple2<int[], double[]>[] fullTopicSummary = model.describeTopics();
+    assertEquals(fullTopicSummary.length, tinyK);
+    for (int i = 0; i < fullTopicSummary.length; i++) {
+      assertArrayEquals(fullTopicSummary[i]._1(), tinyTopicDescription[i]._1());
+      assertArrayEquals(fullTopicSummary[i]._2(), tinyTopicDescription[i]._2(), 1e-5);
+    }
+  }
+
+  @Test
+  public void distributedLDAModel() {
+    int k = 3;
+    double topicSmoothing = 1.2;
+    double termSmoothing = 1.2;
+
+    // Train a model
+    LDA lda = new LDA();
+    lda.setK(k)
+      .setDocConcentration(topicSmoothing)
+      .setTopicConcentration(termSmoothing)
+      .setMaxIterations(5)
+      .setSeed(12345);
+
+    DistributedLDAModel model = lda.run(corpus);
+
+    // Check: basic parameters
+    LocalLDAModel localModel = model.toLocal();
+    assertEquals(model.k(), k);
+    assertEquals(localModel.k(), k);
+    assertEquals(model.vocabSize(), tinyVocabSize);
+    assertEquals(localModel.vocabSize(), tinyVocabSize);
+    assertEquals(model.topicsMatrix(), localModel.topicsMatrix());
+
+    // Check: topic summaries
+    Tuple2<int[], double[]>[] roundedTopicSummary = model.describeTopics();
+    assertEquals(roundedTopicSummary.length, k);
+    Tuple2<int[], double[]>[] roundedLocalTopicSummary = localModel.describeTopics();
+    assertEquals(roundedLocalTopicSummary.length, k);
+
+    // Check: log probabilities
+    assert(model.logLikelihood() < 0.0);
+    assert(model.logPrior() < 0.0);
+  }
+
+  private static int tinyK = LDASuite$.MODULE$.tinyK();
+  private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize();
+  private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
+  private static Tuple2<int[], double[]>[] tinyTopicDescription =
+    LDASuite$.MODULE$.tinyTopicDescription();
+  private ArrayList<Tuple2<Long, Vector>> tinyCorpus;
+  JavaRDD<Tuple2<Long, Vector>> corpus;
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index a1aa64dbe3d3a..302d751eb8a94 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -38,15 +38,19 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { // Check: describeTopics() with all terms val fullTopicSummary = model.describeTopics() assert(fullTopicSummary.size === tinyK) - fullTopicSummary.zip(tinyTopicDescription).foreach { case (algSummary, tinySummary) => - assert(algSummary === tinySummary) + fullTopicSummary.zip(tinyTopicDescription).foreach { + case ((algTerms, algTermWeights), (terms, termWeights)) => + assert(algTerms === terms) + assert(algTermWeights === termWeights) } // Check: describeTopics() with some terms val smallNumTerms = 3 val smallTopicSummary = model.describeTopics(maxTermsPerTopic = smallNumTerms) - smallTopicSummary.zip(tinyTopicDescription).foreach { case (algSummary, tinySummary) => - assert(algSummary === tinySummary.slice(0, smallNumTerms)) + smallTopicSummary.zip(tinyTopicDescription).foreach { + case ((algTerms, algTermWeights), (terms, termWeights)) => + assert(algTerms === terms.slice(0, smallNumTerms)) + assert(algTermWeights === termWeights.slice(0, smallNumTerms)) } } @@ -58,10 +62,10 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { // Train a model val lda = new LDA() lda.setK(k) - lda.setDocConcentration(topicSmoothing) - lda.setTopicConcentration(termSmoothing) - lda.setMaxIterations(5) - lda.setSeed(12345) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345) val corpus = sc.parallelize(tinyCorpus, 2) val model: DistributedLDAModel = lda.run(corpus) @@ -76,13 +80,17 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { // Check: topic summaries // The odd decimal formatting and sorting is a hack to do a robust comparison. 
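The rounding hack referred to above can be seen in isolation in this small sketch (values are illustrative, not from the suite): formatting a weight to three decimal places and parsing it back collapses small numerical differences between the distributed and local summaries before they are compared:

  val w1 = 0.1234567
  val w2 = 0.1234999
  // both round to 0.123, so the summaries compare equal despite tiny numerical noise
  assert("%.3f".format(w1).toDouble == "%.3f".format(w2).toDouble)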
- val roundedTopicSummary = model.describeTopics().map { case topic => + val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) => // cut values to 3 digits after the decimal place - topic.map { case (weight, term) => ("%.3f".format(weight).toDouble, term.toInt)} + terms.zip(termWeights).map { case (term, weight) => + ("%.3f".format(weight).toDouble, term.toInt) + } }.sortBy(_.mkString("")) - val roundedLocalTopicSummary = localModel.describeTopics().map { case topic => + val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => // cut values to 3 digits after the decimal place - topic.map { case (weight, term) => ("%.3f".format(weight).toDouble, term.toInt)} + terms.zip(termWeights).map { case (term, weight) => + ("%.3f".format(weight).toDouble, term.toInt) + } }.sortBy(_.mkString("")) roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) => assert(t1 === t2) @@ -117,7 +125,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { } } -private object LDASuite { +private[clustering] object LDASuite { def tinyK: Int = 3 def tinyVocabSize: Int = 5 @@ -128,8 +136,9 @@ private object LDASuite { ) def tinyTopics: Matrix = new DenseMatrix(numRows = tinyVocabSize, numCols = tinyK, values = tinyTopicsAsArray.fold(Array.empty[Double])(_ ++ _)) - def tinyTopicDescription: Array[Array[(Double, Int)]] = tinyTopicsAsArray.map { topic => - topic.zipWithIndex.sortBy(-_._1) + def tinyTopicDescription: Array[(Array[Int], Array[Double])] = tinyTopicsAsArray.map { topic => + val (termWeights, terms) = topic.zipWithIndex.sortBy(-_._1).unzip + (terms.toArray, termWeights.toArray) } def tinyCorpus = Array(