diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index dcaa3784be874..2b300439ce298 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -425,7 +425,7 @@ to the algorithm. We then output the topics, represented as probability distribu
 {% highlight scala %}
-import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel}
 import org.apache.spark.mllib.linalg.Vectors
 
 // Load and parse the data
@@ -445,6 +445,11 @@ for (topic <- Range(0, 3)) {
   for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); }
   println()
 }
+
+// Save and load model.
+ldaModel.save(sc, "myLDAModel")
+val sameModel = DistributedLDAModel.load(sc, "myLDAModel")
+
 {% endhighlight %}
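(Editor's aside, not part of the patch: once a model has been saved as above, it can be reloaded in a later session and queried like any other LDA model. Below is a minimal sketch assuming the same "myLDAModel" path from the example and an active SparkContext `sc`.)

{% highlight scala %}
import org.apache.spark.mllib.clustering.DistributedLDAModel

// Reload the model that the example above saved (path is assumed).
val reloaded = DistributedLDAModel.load(sc, "myLDAModel")

// Print the 5 top-weighted terms of each topic from the reloaded model.
reloaded.describeTopics(maxTermsPerTopic = 5).zipWithIndex.foreach {
  case ((termIndices, termWeights), topic) =>
    println(s"Topic $topic: " +
      termIndices.zip(termWeights).map { case (t, w) => s"$t -> $w" }.mkString(", "))
}
{% endhighlight %}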
@@ -504,6 +509,9 @@ public class JavaLDAExample {
       }
       System.out.println();
     }
+
+    ldaModel.save(sc.sc(), "myLDAModel");
+    DistributedLDAModel sameModel = DistributedLDAModel.load(sc.sc(), "myLDAModel");
   }
 }
 {% endhighlight %}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 974b26924dfb8..f6d3a76876912 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -17,15 +17,25 @@
 
 package org.apache.spark.mllib.clustering
 
-import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
+import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV}
+import org.apache.hadoop.fs.Path
+
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaPairRDD
-import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
-import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
+import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph}
+import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector}
+import org.apache.spark.mllib.util.{Saveable, Loader}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.util.BoundedPriorityQueue
 
+
 /**
  * :: Experimental ::
  *
@@ -35,7 +45,7 @@ import org.apache.spark.util.BoundedPriorityQueue
  * including local and distributed data structures.
  */
 @Experimental
-abstract class LDAModel private[clustering] {
+abstract class LDAModel private[clustering] extends Saveable {
 
   /** Number of topics */
   def k: Int
@@ -176,6 +186,11 @@ class LocalLDAModel private[clustering] (
     }.toArray
   }
 
+  override protected def formatVersion = "1.0"
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
+  }
 
   // TODO
   // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
@@ -184,6 +199,82 @@ class LocalLDAModel private[clustering] (
 
 }
 
+@Experimental
+object LocalLDAModel extends Loader[LocalLDAModel] {
+
+  private object SaveLoadV1_0 {
+
+    val formatVersionV1_0 = "1.0"
+
+    val classNameV1_0 = "org.apache.spark.mllib.clustering.LocalLDAModel"
+
+    // Store the distribution of terms of each topic as a Row in data.
+    case class Data(termDistributions: Vector)
+
+    def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+
+      val k = topicsMatrix.numCols
+      val metadata = compact(render
+        (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
+          ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
+
+      val termDistributions = Range(0, k).map { topicInd =>
+        Data(Vectors.dense(topicsDenseMatrix(::, topicInd).toArray))
+      }.toSeq
+      sc.parallelize(termDistributions, 1).toDF().write.parquet(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): Matrix = {
+      val dataPath = Loader.dataPath(path)
+      val sqlContext = new SQLContext(sc)
+      val dataFrame = sqlContext.read.parquet(dataPath)
+
+      Loader.checkSchema[Data](dataFrame.schema)
+      val termDistributions = dataFrame.collect()
+      val vocabSize = termDistributions(0)(0).asInstanceOf[Vector].size
+      val k = termDistributions.size
+
+      val brzTopics = BDM.zeros[Double](vocabSize, k)
+      termDistributions.zipWithIndex.foreach { case (Row(vec: Vector), ind: Int) =>
+        brzTopics(::, ind) := vec.toBreeze
+      }
+      Matrices.fromBreeze(brzTopics)
+    }
+  }
+
+  override def load(sc: SparkContext, path: String): LocalLDAModel = {
+    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
+    implicit val formats = DefaultFormats
+    val expectedK = (metadata \ "k").extract[Int]
+    val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
+    val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+
+    (loadedClassName, loadedVersion) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        val topicsMatrix = SaveLoadV1_0.load(sc, path)
+        require(expectedK == topicsMatrix.numCols,
+          s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
+        require(expectedVocabSize == topicsMatrix.numRows,
+          s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
+            s"but got ${topicsMatrix.numRows}")
+        new LocalLDAModel(topicsMatrix)
+
+      case _ => throw new Exception(
+        s"LocalLDAModel.load did not recognize model with (className, format version):" +
+        s"($loadedClassName, $loadedVersion). Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
+}
+
 /**
  * :: Experimental ::
  *
@@ -354,4 +445,129 @@ class DistributedLDAModel private (
 
   // TODO:
   // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
+
+  override protected def formatVersion = "1.0"
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    DistributedLDAModel.SaveLoadV1_0.save(
+      sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
+      iterationTimes)
+  }
+}
+
+
+@Experimental
+object DistributedLDAModel extends Loader[DistributedLDAModel] {
+
+  object SaveLoadV1_0 {
+
+    val formatVersionV1_0 = "1.0"
+
+    val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
+
+    // Store the weight of each topic separately in a row.
+    case class Data(globalTopicTotals: Double)
+
+    // Store each term and document vertex with an id and the topicWeights.
+    case class VertexData(id: Long, topicWeights: Vector)
+
+    // Store each edge with the source id, destination id and tokenCounts.
+    case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double)
+
+    def save(
+        sc: SparkContext,
+        path: String,
+        graph: Graph[LDA.TopicCounts, LDA.TokenCount],
+        globalTopicTotals: LDA.TopicCounts,
+        k: Int,
+        vocabSize: Int,
+        docConcentration: Double,
+        topicConcentration: Double,
+        iterationTimes: Array[Double]): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+
+      val metadata = compact(render
+        (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
+          ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~
+          ("topicConcentration" -> topicConcentration) ~
+          ("iterationTimes" -> iterationTimes.toSeq)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toString
+      sc.parallelize(globalTopicTotals.toArray.toSeq.map(w => Data(w)), 1).toDF()
+        .write.parquet(newPath)
+
+      val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toString
+      graph.vertices.map { case (ind, vertex) =>
+        VertexData(ind, Vectors.fromBreeze(vertex))
+      }.toDF().write.parquet(verticesPath)
+
+      val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toString
+      graph.edges.map { case Edge(srcId, dstId, prop) =>
+        EdgeData(srcId, dstId, prop)
+      }.toDF().write.parquet(edgesPath)
+    }
+
+    def load(
+        sc: SparkContext,
+        path: String): (Graph[LDA.TopicCounts, LDA.TokenCount], LDA.TopicCounts) = {
+      val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toString
+      val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toString
+      val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toString
+      val sqlContext = new SQLContext(sc)
+      val dataFrame = sqlContext.read.parquet(dataPath)
+      val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
+      val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)
+
+      Loader.checkSchema[Data](dataFrame.schema)
+      Loader.checkSchema[VertexData](vertexDataFrame.schema)
+      Loader.checkSchema[EdgeData](edgeDataFrame.schema)
+      val globalTopicTotals: LDA.TopicCounts = BDV(dataFrame.collect().map(i => i.getDouble(0)))
+      val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map {
+        case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector)
+      }
+
+      val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map {
+        case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop)
+      }
+
+      val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
+
+      (graph, globalTopicTotals)
+    }
+  }
+
+  override def load(sc: SparkContext, path: String): DistributedLDAModel = {
+    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
+    implicit val formats = DefaultFormats
+    val expectedK = (metadata \ "k").extract[Int]
+    val vocabSize = (metadata \ "vocabSize").extract[Int]
+    val docConcentration = (metadata \ "docConcentration").extract[Double]
+    val topicConcentration = (metadata \ "topicConcentration").extract[Double]
+    val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
+    val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+
+    (loadedClassName, loadedVersion) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        val (graph, globalTopicTotals) = DistributedLDAModel.SaveLoadV1_0.load(sc, path)
+        val model = new DistributedLDAModel(
+          graph, globalTopicTotals, expectedK, vocabSize, docConcentration, topicConcentration,
+          iterationTimes.toArray)
+        require(expectedK == globalTopicTotals.length,
s"LocalLDAModel requires $expectedK topics, got $globalTopicTotals.length topics") + model + } + + case _ => throw new Exception( + s"DistributedLDAModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") + } + } + + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 406affa25539d..db8e0ea2e3efd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite +import org.apache.spark.util.Utils import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -213,6 +214,45 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("model save/load") { + // Test for LocalLDAModel. + val localModel = new LocalLDAModel(tinyTopics) + val tempDir1 = Utils.createTempDir() + val path1 = tempDir1.toURI.toString + + // Test for DistributedLDAModel. + val k = 3 + val topicSmoothing = 1.2 + val termSmoothing = 1.2 + val lda = new LDA() + lda.setK(k) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345) + val corpus = sc.parallelize(tinyCorpus, 2) + val distributedModel: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] + val tempDir2 = Utils.createTempDir() + val path2 = tempDir2.toURI.toString + + try { + localModel.save(sc, path1) + distributedModel.save(sc, path2) + val samelocalModel = LocalLDAModel.load(sc, path1) + assert(samelocalModel.topicsMatrix === localModel.topicsMatrix) + assert(samelocalModel.k === localModel.k) + assert(samelocalModel.vocabSize === localModel.vocabSize) + + val sameDistributedModel = DistributedLDAModel.load(sc, path2) + assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix) + assert(distributedModel.k === sameDistributedModel.k) + assert(distributedModel.vocabSize === sameDistributedModel.vocabSize) + } finally { + Utils.deleteRecursively(tempDir1) + Utils.deleteRecursively(tempDir2) + } + } + } private[clustering] object LDASuite {