Commit 91aadfe
Added Java-friendly run method to LDA.
Added Java test suite for LDA.
Changed LDAModel.describeTopics to return Java-friendly type.
jkbradley committed Feb 2, 2015
1 parent b75472d commit 91aadfe
Showing 6 changed files with 198 additions and 57 deletions.
@@ -148,13 +148,13 @@ object LDAExample {
 
     // Print the topics, showing the top-weighted terms for each topic.
     val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
-    val topics = topicIndices.map { topic =>
-      topic.map { case (weight, term) => (weight, vocabArray(term.toInt)) }
+    val topics = topicIndices.map { case (terms, termWeights) =>
+      terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) }
     }
     println(s"${params.k} topics:")
     topics.zipWithIndex.foreach { case (topic, i) =>
       println(s"TOPIC $i")
-      topic.foreach { case (weight, term) =>
+      topic.foreach { case (term, weight) =>
         println(s"$term\t$weight")
       }
       println()
@@ -167,13 +167,13 @@ object LDATiming {
 
     // Print the topics, showing the top-weighted terms for each topic.
     val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
-    val topics = topicIndices.map { topic =>
-      topic.map { case (weight, term) => (weight, vocabArray(term.toInt))}
+    val topics = topicIndices.map { case (terms, termWeights) =>
+      terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) }
     }
     println(s"$k topics:")
     topics.zipWithIndex.foreach { case (topic, i) =>
       println(s"TOPIC $i")
-      topic.foreach { case (weight, term) =>
+      topic.foreach { case (term, weight) =>
         println(s"$term\t$weight")
       }
       println()
24 changes: 15 additions & 9 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -23,6 +23,7 @@ import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy}
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
@@ -233,10 +234,15 @@ class LDA private (
     state.graphCheckpointer.deleteAllCheckpoints()
     new DistributedLDAModel(state, iterationTimes)
   }
+
+  /** Java-friendly version of [[run()]] */
+  def run(documents: JavaRDD[(java.lang.Long, Vector)]): DistributedLDAModel = {
+    run(documents.rdd.map(id_counts => (id_counts._1.asInstanceOf[Long], id_counts._2)))
+  }
 }
 
 
-object LDA {
+private[clustering] object LDA {
 
   /*
     DEVELOPERS NOTE:
@@ -291,18 +297,18 @@ object LDA {
    * Vector over topics (length k) of token counts.
    * The meaning of these counts can vary, and it may or may not be normalized to be a distribution.
    */
-  private[clustering] type TopicCounts = BDV[Double]
+  type TopicCounts = BDV[Double]
 
-  private[clustering] type TokenCount = Double
+  type TokenCount = Double
 
   /** Term vertex IDs are {-1, -2, ..., -vocabSize} */
-  private[clustering] def term2index(term: Int): Long = -(1 + term.toLong)
+  def term2index(term: Int): Long = -(1 + term.toLong)
 
-  private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt
+  def index2term(termIndex: Long): Int = -(1 + termIndex).toInt
 
-  private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0
+  def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0
 
-  private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
+  def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
 
   /**
    * State for EM algorithm: data + parameter graph, plus algorithm parameters.
@@ -314,7 +320,7 @@
    * @param docConcentration "alpha"
    * @param topicConcentration "beta" or "eta"
    */
-  private[clustering] class EMOptimizer(
+  class EMOptimizer(
       var graph: Graph[TopicCounts, TokenCount],
       val k: Int,
       val vocabSize: Int,
@@ -374,7 +380,7 @@
    *
    * Note: This executes an action on the graph RDDs.
    */
-  private[clustering] var globalTopicTotals: TopicCounts = computeGlobalTopicTotals()
+  var globalTopicTotals: TopicCounts = computeGlobalTopicTotals()
 
   private def computeGlobalTopicTotals(): TopicCounts = {
     val numTopics = k
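The hunk above adds the Java-friendly run overload, which unwraps the JavaRDD and reboxes java.lang.Long document IDs into Scala Longs before delegating to the Scala run. A minimal sketch of how a Java caller would use it, assuming local mode and a toy two-document corpus (the corpus construction and class name here are illustrative, not part of this commit):

import java.util.Arrays;

import scala.Tuple2;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.DistributedLDAModel;
import org.apache.spark.mllib.clustering.LDA;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public class LDAJavaExample {
  public static void main(String[] args) {
    JavaSparkContext jsc = new JavaSparkContext("local", "LDAJavaExample");

    // Each document is (document ID, vector of term counts over the vocabulary).
    JavaRDD<Tuple2<Long, Vector>> corpus = jsc.parallelize(Arrays.asList(
        new Tuple2<Long, Vector>(0L, Vectors.dense(1.0, 2.0, 0.0)),
        new Tuple2<Long, Vector>(1L, Vectors.dense(0.0, 1.0, 3.0))));

    // The new overload takes the JavaRDD directly; no manual .rdd() plumbing.
    DistributedLDAModel model = new LDA().setK(2).setMaxIterations(10).run(corpus);
    System.out.println("Trained LDA with k = " + model.k());
    jsc.stop();
  }
}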
@@ -60,21 +60,21 @@ abstract class LDAModel private[clustering] {
    *
    * @param maxTermsPerTopic Maximum number of terms to collect for each topic.
    * @return Array over topics, where each element is a set of top terms represented
-   *         as (term weight in topic, term index).
+   *         as (term index, term weight in topic).
    *         Each topic's terms are sorted in order of decreasing weight.
    */
-  def describeTopics(maxTermsPerTopic: Int): Array[Array[(Double, Int)]]
+  def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])]
 
   /**
    * Return the topics described by weighted terms.
    *
    * WARNING: If vocabSize and k are large, this can return a large object!
    *
    * @return Array over topics, where each element is a set of top terms represented
-   *         as (term weight in topic, term index).
+   *         as (term index, term weight in topic).
    *         Each topic's terms are sorted in order of decreasing weight.
    */
-  def describeTopics(): Array[Array[(Double, Int)]] = describeTopics(vocabSize)
+  def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize)
 
   /* TODO (once LDA can be trained with Strings or given a dictionary)
    * Return the topics described by weighted terms.
@@ -89,11 +89,11 @@ abstract class LDAModel private[clustering] {
    *
    * @param maxTermsPerTopic Maximum number of terms to collect for each topic.
    * @return Array over topics, where each element is a set of top terms represented
-   *         as (term weight in topic, term), where "term" is either the actual term text
+   *         as (term, term weight in topic), where "term" is either the actual term text
    *         (if available) or the term index.
    *         Each topic's terms are sorted in order of decreasing weight.
    */
-  //def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[Array[(Double, String)]]
+  //def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[(Array[Double], Array[String])]
 
   /* TODO (once LDA can be trained with Strings or given a dictionary)
    * Return the topics described by weighted terms.
@@ -105,11 +105,11 @@ abstract class LDAModel private[clustering] {
    *
    * WARNING: If vocabSize and k are large, this can return a large object!
    *
    * @return Array over topics, where each element is a set of top terms represented
-   *         as (term weight in topic, term), where "term" is either the actual term text
+   *         as (term, term weight in topic), where "term" is either the actual term text
    *         (if available) or the term index.
    *         Each topic's terms are sorted in order of decreasing weight.
    */
-  //def describeTopicsAsStrings(): Array[Array[(Double, String)]] =
+  //def describeTopicsAsStrings(): Array[(Array[Double], Array[String])] =
   //  describeTopicsAsStrings(vocabSize)
 
   /* TODO
@@ -172,11 +172,13 @@ class LocalLDAModel private[clustering] (
 
   override def topicsMatrix: Matrix = topics
 
-  override def describeTopics(maxTermsPerTopic: Int): Array[Array[(Double, Int)]] = {
+  override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
     val brzTopics = topics.toBreeze.toDenseMatrix
     Range(0, k).map { topicIndex =>
       val topic = normalize(brzTopics(::, topicIndex), 1.0)
-      topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic)
+      val (termWeights, terms) =
+        topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip
+      (terms.toArray, termWeights.toArray)
     }.toArray
   }
 
@@ -248,29 +250,35 @@ class DistributedLDAModel private (
     Matrices.fromBreeze(brzTopics)
   }
 
-  override def describeTopics(maxTermsPerTopic: Int): Array[Array[(Double, Int)]] = {
+  override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
     val numTopics = k
     // Note: N_k is not needed to find the top terms, but it is needed to normalize weights
     //       to a distribution over terms.
     val N_k: TopicCounts = globalTopicTotals
-    graph.vertices.filter(isTermVertex)
-      .mapPartitions { termVertices =>
-        // For this partition, collect the most common terms for each topic in queues:
-        //  queues(topic) = queue of (term weight, term index).
-        // Term weights are N_{wk} / N_k.
-        val queues = Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic))
-        for ((termId, n_wk) <- termVertices) {
-          var topic = 0
-          while (topic < numTopics) {
-            queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt))
-            topic += 1
-          }
-        }
-        Iterator(queues)
-      }.reduce { (q1, q2) =>
-        q1.zip(q2).foreach { case (a, b) => a ++= b}
-        q1
-      }.map(_.toArray.sortBy(-_._1))
+    val topicsInQueues: Array[BoundedPriorityQueue[(Double, Int)]] =
+      graph.vertices.filter(isTermVertex)
+        .mapPartitions { termVertices =>
+          // For this partition, collect the most common terms for each topic in queues:
+          //  queues(topic) = queue of (term weight, term index).
+          // Term weights are N_{wk} / N_k.
+          val queues =
+            Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic))
+          for ((termId, n_wk) <- termVertices) {
+            var topic = 0
+            while (topic < numTopics) {
+              queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt))
+              topic += 1
+            }
+          }
+          Iterator(queues)
+        }.reduce { (q1, q2) =>
+          q1.zip(q2).foreach { case (a, b) => a ++= b}
+          q1
+        }
+    topicsInQueues.map { q =>
+      val (termWeights, terms) = q.toArray.sortBy(-_._1).unzip
+      (terms.toArray, termWeights.toArray)
+    }
   }
 
   // TODO
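The reworked describeTopics() above keeps one BoundedPriorityQueue per topic inside mapPartitions and merges the per-partition queues with reduce, so only numTopics small queues cross the network instead of every term's weight. A standalone sketch of that same bounded top-k pattern over plain (score, id) pairs; the class and variable names are invented for illustration, and the lambda-based mapPartitions signature assumes the Spark 2.x+ Java API rather than the 1.x API this commit targets:

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

import scala.Tuple2;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class TopKByPartition {
  static final int K = 3;  // plays the role of maxTermsPerTopic

  // Serializable comparator so the queues survive being shipped by Spark;
  // the head of each queue is the smallest score seen so far.
  static class ByScore implements Comparator<Tuple2<Double, Integer>>, Serializable {
    public int compare(Tuple2<Double, Integer> a, Tuple2<Double, Integer> b) {
      return Double.compare(a._1(), b._1());
    }
  }

  public static void main(String[] args) {
    JavaSparkContext jsc = new JavaSparkContext("local", "TopKByPartition");
    List<Tuple2<Double, Integer>> data = new ArrayList<Tuple2<Double, Integer>>();
    for (int i = 0; i < 100; i++) {
      data.add(new Tuple2<Double, Integer>(Math.random(), i));
    }
    JavaRDD<Tuple2<Double, Integer>> scores = jsc.parallelize(data, 4);

    PriorityQueue<Tuple2<Double, Integer>> topK = scores
      .mapPartitions(it -> {                  // one bounded queue per partition
        PriorityQueue<Tuple2<Double, Integer>> q =
          new PriorityQueue<Tuple2<Double, Integer>>(K, new ByScore());
        while (it.hasNext()) {
          q.add(it.next());
          if (q.size() > K) q.poll();         // evict the current minimum
        }
        return Collections.singletonList(q).iterator();
      })
      .reduce((q1, q2) -> {                   // merge the small per-partition queues
        for (Tuple2<Double, Integer> t : q2) {
          q1.add(t);
          if (q1.size() > K) q1.poll();
        }
        return q1;
      });

    System.out.println("top-" + K + ": " + topK);
    jsc.stop();
  }
}

The correctness argument is the usual one for distributed top-k: the global top K elements are necessarily within the union of the per-partition top K sets, so the merge loses nothing.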
@@ -0,0 +1,118 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.clustering;

import java.io.Serializable;
import java.util.ArrayList;

import scala.Tuple2;

import org.junit.After;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertArrayEquals;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;


public class JavaLDASuite implements Serializable {
  private transient JavaSparkContext sc;

  @Before
  public void setUp() {
    sc = new JavaSparkContext("local", "JavaLDA");
    tinyCorpus = new ArrayList<Tuple2<Long, Vector>>();
    for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) {
      tinyCorpus.add(new Tuple2<Long, Vector>((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(),
          LDASuite$.MODULE$.tinyCorpus()[i]._2()));
    }
    corpus = sc.parallelize(tinyCorpus, 2);
  }

  @After
  public void tearDown() {
    sc.stop();
    sc = null;
  }

  @Test
  public void localLDAModel() {
    LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics());

    // Check: basic parameters
    assertEquals(model.k(), tinyK);
    assertEquals(model.vocabSize(), tinyVocabSize);
    assertEquals(model.topicsMatrix(), tinyTopics);

    // Check: describeTopics() with all terms
    Tuple2<int[], double[]>[] fullTopicSummary = model.describeTopics();
    assertEquals(fullTopicSummary.length, tinyK);
    for (int i = 0; i < fullTopicSummary.length; i++) {
      assertArrayEquals(fullTopicSummary[i]._1(), tinyTopicDescription[i]._1());
      assertArrayEquals(fullTopicSummary[i]._2(), tinyTopicDescription[i]._2(), 1e-5);
    }
  }

  @Test
  public void distributedLDAModel() {
    int k = 3;
    double topicSmoothing = 1.2;
    double termSmoothing = 1.2;

    // Train a model
    LDA lda = new LDA();
    lda.setK(k)
      .setDocConcentration(topicSmoothing)
      .setTopicConcentration(termSmoothing)
      .setMaxIterations(5)
      .setSeed(12345);

    DistributedLDAModel model = lda.run(corpus);

    // Check: basic parameters
    LocalLDAModel localModel = model.toLocal();
    assertEquals(model.k(), k);
    assertEquals(localModel.k(), k);
    assertEquals(model.vocabSize(), tinyVocabSize);
    assertEquals(localModel.vocabSize(), tinyVocabSize);
    assertEquals(model.topicsMatrix(), localModel.topicsMatrix());

    // Check: topic summaries
    Tuple2<int[], double[]>[] roundedTopicSummary = model.describeTopics();
    assertEquals(roundedTopicSummary.length, k);
    Tuple2<int[], double[]>[] roundedLocalTopicSummary = localModel.describeTopics();
    assertEquals(roundedLocalTopicSummary.length, k);

    // Check: log probabilities
    assert(model.logLikelihood() < 0.0);
    assert(model.logPrior() < 0.0);
  }

  private static int tinyK = LDASuite$.MODULE$.tinyK();
  private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize();
  private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
  private static Tuple2<int[], double[]>[] tinyTopicDescription =
      LDASuite$.MODULE$.tinyTopicDescription();
  private ArrayList<Tuple2<Long, Vector>> tinyCorpus;
  JavaRDD<Tuple2<Long, Vector>> corpus;

}
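Since describeTopics() now returns parallel primitive arrays, a Java caller such as the suite above can unpack topic summaries without touching Scala collection types. A hedged sketch of a consuming helper; TopicPrinter, printTopics, and vocab are hypothetical names, and only the Tuple2<int[], double[]>[] shape comes from this commit:

import scala.Tuple2;
import org.apache.spark.mllib.clustering.LDAModel;

public class TopicPrinter {
  // Hypothetical helper: print each topic's top terms using the new
  // Java-friendly describeTopics() shape. `vocab` maps term indices to
  // term strings and is assumed to be built elsewhere.
  static void printTopics(LDAModel model, String[] vocab, int maxTermsPerTopic) {
    Tuple2<int[], double[]>[] topics = model.describeTopics(maxTermsPerTopic);
    for (int t = 0; t < topics.length; t++) {
      int[] terms = topics[t]._1();       // term indices, sorted by decreasing weight
      double[] weights = topics[t]._2();  // weights aligned with `terms`
      System.out.println("TOPIC " + t);
      for (int j = 0; j < terms.length; j++) {
        System.out.println(vocab[terms[j]] + "\t" + weights[j]);
      }
    }
  }
}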
