[SPARK-5989] Model save/load for LDA
MechCoder committed Jun 23, 2015
1 parent 41ab285 commit 2782326
Showing 3 changed files with 269 additions and 5 deletions.
10 changes: 9 additions & 1 deletion docs/mllib-clustering.md
@@ -425,7 +425,7 @@ to the algorithm. We then output the topics, represented as probability distribu
<div data-lang="scala" markdown="1">

{% highlight scala %}
import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel}
import org.apache.spark.mllib.linalg.Vectors

// Load and parse the data
@@ -445,6 +445,11 @@ for (topic <- Range(0, 3)) {
for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); }
println()
}

// Save and load model.
ldaModel.save(sc, "myLDAModel")
val sameModel = DistributedLDAModel.load(sc, "myLDAModel")

{% endhighlight %}
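
If only the topics matrix is needed, the distributed model can be collapsed to a
`LocalLDAModel` first and persisted in the same way. A minimal sketch, assuming
`ldaModel` is in fact a `DistributedLDAModel` (the path is illustrative):

{% highlight scala %}
import org.apache.spark.mllib.clustering.LocalLDAModel

// Collapse to a topics-only model, then round-trip it through save/load.
val localModel = ldaModel.asInstanceOf[DistributedLDAModel].toLocal
localModel.save(sc, "myLocalLDAModel")
val sameLocalModel = LocalLDAModel.load(sc, "myLocalLDAModel")
{% endhighlight %}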
</div>

@@ -504,6 +509,9 @@ public class JavaLDAExample {
}
System.out.println();
}

ldaModel.save(sc.sc(), "myLDAModel");
DistributedLDAModel sameModel = DistributedLDAModel.load(sc.sc(), "myLDAModel");
}
}
{% endhighlight %}
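
Note that the Java snippet assumes `org.apache.spark.mllib.clustering.DistributedLDAModel` is imported in the full example, alongside the other MLlib classes.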
224 changes: 220 additions & 4 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -17,15 +17,25 @@

package org.apache.spark.mllib.clustering

import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV}

import org.apache.hadoop.fs.Path

import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph}
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector}
import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.util.BoundedPriorityQueue


/**
* :: Experimental ::
*
@@ -35,7 +45,7 @@ import org.apache.spark.util.BoundedPriorityQueue
* including local and distributed data structures.
*/
@Experimental
abstract class LDAModel private[clustering] extends Saveable {

/** Number of topics */
def k: Int
@@ -176,6 +186,11 @@ class LocalLDAModel private[clustering] (
}.toArray
}

override protected def formatVersion = "1.0"

override def save(sc: SparkContext, path: String): Unit = {
LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
}
// TODO
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???

@@ -184,6 +199,82 @@ class LocalLDAModel private[clustering] (

}

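/**
 * :: Experimental ::
 *
 * Companion object providing save/load support for [[LocalLDAModel]].
 */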
@Experimental
object LocalLDAModel extends Loader[LocalLDAModel] {

private object SaveLoadV1_0 {

val formatVersionV1_0 = "1.0"

val classNameV1_0 = "org.apache.spark.mllib.clustering.LocalLDAModel"

// Store the term distribution of each topic as a Row in data.
case class Data(termDistributions: Vector)

def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val k = topicsMatrix.numCols
val metadata = compact(render
(("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix

val termDistributions = Range(0, k).map { topicInd =>
Data(Vectors.dense(topicsDenseMatrix(::, topicInd).toArray))
}.toSeq
sc.parallelize(termDistributions, 1).toDF().write.parquet(Loader.dataPath(path))
}
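
// A sketch of the on-disk layout this produces, assuming Loader.metadataPath
// resolves to "<path>/metadata" and Loader.dataPath to "<path>/data":
//   <path>/metadata  -- JSON with class, version, k and vocabSize
//   <path>/data      -- Parquet with one Data(termDistributions) row per topic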

def load(sc: SparkContext, path: String): Matrix = {

val dataPath = Loader.dataPath(path)
val sqlContext = new SQLContext(sc)
val dataFrame = sqlContext.read.parquet(dataPath)

Loader.checkSchema[Data](dataFrame.schema)
val termDistributions = dataFrame.collect()
val vocabSize = termDistributions(0)(0).asInstanceOf[Vector].size
val k = termDistributions.size

val brzTopics = BDM.zeros[Double](vocabSize, k)
termDistributions.zipWithIndex.foreach { case (Row(vec: Vector), ind: Int) =>
brzTopics(::, ind) := vec.toBreeze
}
Matrices.fromBreeze(brzTopics)
}
}

override def load(sc: SparkContext, path: String): LocalLDAModel = {

val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
implicit val formats = DefaultFormats
val expectedK = (metadata \ "k").extract[Int]
val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
val classNameV1_0 = SaveLoadV1_0.classNameV1_0

(loadedClassName, loadedVersion) match {
case (className, "1.0") if className == classNameV1_0 => {
val topicsMatrix = SaveLoadV1_0.load(sc, path)
require(expectedK == topicsMatrix.numCols,
s"LocalLDAModel requires $expectedK topics, got $topicsMatrix.numCols topics")
require(expectedVocabSize == topicsMatrix.numRows,
s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
s"but got $topicsMatrix.numRows")
new LocalLDAModel(topicsMatrix)
}
case _ => throw new Exception(
s"LocalLDAModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $loadedVersion). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
}
}

/**
* :: Experimental ::
*
@@ -354,4 +445,129 @@ class DistributedLDAModel private (
// TODO:
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???

override protected def formatVersion = "1.0"

override def save(sc: SparkContext, path: String): Unit = {
DistributedLDAModel.SaveLoadV1_0.save(
sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
iterationTimes)
}
}


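/**
 * :: Experimental ::
 *
 * Companion object providing save/load support for [[DistributedLDAModel]].
 */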
@Experimental
object DistributedLDAModel extends Loader[DistributedLDAModel] {


object SaveLoadV1_0 {

val formatVersionV1_0 = "1.0"

val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"

// Store the weight of each topic separately in a row.
case class Data(globalTopicTotals: Double)

// Store each term and document vertex with an id and the topicWeights.
case class VertexData(id: Long, topicWeights: Vector)

// Store each edge with the source id, destination id and tokenCounts.
case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double)

def save(
sc: SparkContext,
path: String,
graph: Graph[LDA.TopicCounts, LDA.TokenCount],
globalTopicTotals: LDA.TopicCounts,
k: Int,
vocabSize: Int,
docConcentration: Double,
topicConcentration: Double,
iterationTimes: Array[Double]): Unit = {

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val metadata = compact(render
(("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~
("topicConcentration" -> topicConcentration) ~
("iterationTimes" -> iterationTimes.toSeq)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toString
sc.parallelize(globalTopicTotals.toArray.toSeq.map(w => Data(w)), 1).toDF()
.write.parquet(newPath)

val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toString
graph.vertices.map { case (ind, vertex) =>
VertexData(ind, Vectors.fromBreeze(vertex))
}.toDF().write.parquet(verticesPath)

val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toString
graph.edges.map { case Edge(srcId, dstId, prop) =>
EdgeData(srcId, dstId, prop)
}.toDF().write.parquet(edgesPath)
}
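
// A sketch of the resulting layout, assuming the same Loader path conventions
// as for LocalLDAModel ("<path>/metadata" and "<path>/data"):
//   <path>/metadata                -- JSON with k, vocabSize, docConcentration,
//                                     topicConcentration and iterationTimes
//   <path>/data/globalTopicTotals  -- Parquet: one Data(weight) row per topic
//   <path>/data/topicCounts        -- Parquet: VertexData rows for the vertices
//   <path>/data/tokenCounts        -- Parquet: EdgeData rows for the edges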

def load(
sc: SparkContext,
path: String): (Graph[LDA.TopicCounts, LDA.TokenCount], LDA.TopicCounts) = {

val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toString
val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toString
val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toString
val sqlContext = new SQLContext(sc)
val dataFrame = sqlContext.read.parquet(dataPath)
val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)

Loader.checkSchema[Data](dataFrame.schema)
Loader.checkSchema[VertexData](vertexDataFrame.schema)
Loader.checkSchema[EdgeData](edgeDataFrame.schema)
val globalTopicTotals: LDA.TopicCounts = BDV(dataFrame.collect().map(i => i.getDouble(0)))
val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map {
case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector)
}

val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map {
case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop)
}

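// Graph.apply joins the vertex and edge RDDs; vertices referenced only by an
// edge would get a default (null) attribute, so the saved vertex data must
// cover every term and document id.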
val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)

(graph, globalTopicTotals)
}

}

override def load(sc: SparkContext, path: String): DistributedLDAModel = {

val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
implicit val formats = DefaultFormats
val expectedK = (metadata \ "k").extract[Int]
val vocabSize = (metadata \ "vocabSize").extract[Int]
val docConcentration = (metadata \ "docConcentration").extract[Double]
val topicConcentration = (metadata \ "topicConcentration").extract[Double]
val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
val classNameV1_0 = SaveLoadV1_0.classNameV1_0

(loadedClassName, loadedVersion) match {
case (className, "1.0") if className == classNameV1_0 => {
val (graph, globalTopicTotals) = DistributedLDAModel.SaveLoadV1_0.load(sc, path)
val model = new DistributedLDAModel(
graph, globalTopicTotals, expectedK, vocabSize, docConcentration, topicConcentration,
iterationTimes.toArray)
require(expectedK == globalTopicTotals.length,
s"LocalLDAModel requires $expectedK topics, got $globalTopicTotals.length topics")
model
}

case _ => throw new Exception(
s"DistributedLDAModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)")
}
}


}
40 changes: 40 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseMatrix => BDM}

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
@@ -213,6 +214,45 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
}

test("model save/load") {
// Test for LocalLDAModel.
val localModel = new LocalLDAModel(tinyTopics)
val tempDir1 = Utils.createTempDir()
val path1 = tempDir1.toURI.toString

// Test for DistributedLDAModel.
val k = 3
val topicSmoothing = 1.2
val termSmoothing = 1.2
val lda = new LDA()
lda.setK(k)
.setDocConcentration(topicSmoothing)
.setTopicConcentration(termSmoothing)
.setMaxIterations(5)
.setSeed(12345)
val corpus = sc.parallelize(tinyCorpus, 2)
val distributedModel: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
val tempDir2 = Utils.createTempDir()
val path2 = tempDir2.toURI.toString

try {
localModel.save(sc, path1)
distributedModel.save(sc, path2)
val sameLocalModel = LocalLDAModel.load(sc, path1)
assert(sameLocalModel.topicsMatrix === localModel.topicsMatrix)
assert(sameLocalModel.k === localModel.k)
assert(sameLocalModel.vocabSize === localModel.vocabSize)

val sameDistributedModel = DistributedLDAModel.load(sc, path2)
assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
assert(distributedModel.k === sameDistributedModel.k)
assert(distributedModel.vocabSize === sameDistributedModel.vocabSize)
} finally {
Utils.deleteRecursively(tempDir1)
Utils.deleteRecursively(tempDir2)
}
}

}

private[clustering] object LDASuite {
