diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 2436d1c2a6938..920b57756b625 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -208,12 +208,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] {

     val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel"

-    // Store the distribution of terms of each topic as a Row in data.
-    case class Data(topic: Vector)
+    // Store the distribution of terms of each topic and the column index in topicsMatrix
+    // as a Row in data.
+    case class Data(topic: Vector, index: Int)

     def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
-
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       val k = topicsMatrix.numCols
@@ -224,24 +224,23 @@ object LocalLDAModel extends Loader[LocalLDAModel] {

       val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
       val topics = Range(0, k).map { topicInd =>
-        Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)))
+        Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)), topicInd)
       }.toSeq
       sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
     }

     def load(sc: SparkContext, path: String): LocalLDAModel = {
-
       val dataPath = Loader.dataPath(path)
       val sqlContext = SQLContext.getOrCreate(sc)
       val dataFrame = sqlContext.read.parquet(dataPath)

       Loader.checkSchema[Data](dataFrame.schema)
       val topics = dataFrame.collect()
-      val vocabSize = topics(0)(0).asInstanceOf[Vector].size
+      val vocabSize = topics(0).getAs[Vector](0).size
       val k = topics.size

       val brzTopics = BDM.zeros[Double](vocabSize, k)
-      topics.zipWithIndex.foreach { case (Row(vec: Vector), ind: Int) =>
+      topics.foreach { case Row(vec: Vector, ind: Int) =>
         brzTopics(::, ind) := vec.toBreeze
       }
       new LocalLDAModel(Matrices.fromBreeze(brzTopics))
@@ -249,7 +248,6 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
   }

   override def load(sc: SparkContext, path: String): LocalLDAModel = {
-
     val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
     implicit val formats = DefaultFormats
     val expectedK = (metadata \ "k").extract[Int]
@@ -267,10 +265,10 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
     val topicsMatrix = model.topicsMatrix

     require(expectedK == topicsMatrix.numCols,
-      s"LocalLDAModel requires $expectedK topics, got $topicsMatrix.numCols topics")
+      s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
     require(expectedVocabSize == topicsMatrix.numRows,
       s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
-        s"but got $topicsMatrix.numRows")
+        s"but got ${topicsMatrix.numRows}")
     model
   }
 }
@@ -456,16 +454,15 @@ class DistributedLDAModel private (

 @Experimental
-object DistributedLDAModel extends Loader[DistributedLDAModel]{
-
+object DistributedLDAModel extends Loader[DistributedLDAModel] {

-  object SaveLoadV1_0 {
+  private object SaveLoadV1_0 {

     val thisFormatVersion = "1.0"

     val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"

-    // Store the weight of each topic separately in a row.
+    // Store globalTopicTotals as a Vector.
     case class Data(globalTopicTotals: Vector)

     // Store each term and document vertex with an id and the topicWeights.
@@ -484,7 +481,6 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
         docConcentration: Double,
         topicConcentration: Double,
         iterationTimes: Array[Double]): Unit = {
-
       val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

@@ -517,11 +513,10 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
         docConcentration: Double,
         topicConcentration: Double,
         iterationTimes: Array[Double]): DistributedLDAModel = {
-
       val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
       val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
       val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val dataFrame = sqlContext.read.parquet(dataPath)
       val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
       val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)
@@ -540,15 +535,13 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
       }
       val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)

-      new DistributedLDAModel(
-        graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
+      new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
         docConcentration, topicConcentration, iterationTimes)
     }

   }

   override def load(sc: SparkContext, path: String): DistributedLDAModel = {
-
     val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
     implicit val formats = DefaultFormats
     val expectedK = (metadata \ "k").extract[Int]
@@ -569,15 +562,15 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
     }

     require(model.vocabSize == vocabSize,
-      s"DistributedLDAModel requires $vocabSize vocabSize, got $model.vocabSize vocabSize")
+      s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize")
     require(model.docConcentration == docConcentration,
       s"DistributedLDAModel requires $docConcentration docConcentration, " +
-        s"got $model.docConcentration docConcentration")
+        s"got ${model.docConcentration} docConcentration")
     require(model.topicConcentration == topicConcentration,
       s"DistributedLDAModel requires $topicConcentration topicConcentration, " +
-        s"got $model.topicConcentration topicConcentration")
+        s"got ${model.topicConcentration} topicConcentration")
     require(expectedK == model.k,
-      s"DistributedLDAModel requires $expectedK topics, got $model.k topics")
+      s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics")

     model
   }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 86888f7dea50f..0b851e8cb03d3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -222,12 +222,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {

     // Test for DistributedLDAModel.
     val k = 3
-    val topicSmoothing = 1.2
-    val termSmoothing = 1.2
+    val docConcentration = 1.2
+    val topicConcentration = 1.5
     val lda = new LDA()
     lda.setK(k)
-      .setDocConcentration(topicSmoothing)
-      .setTopicConcentration(termSmoothing)
+      .setDocConcentration(docConcentration)
+      .setTopicConcentration(topicConcentration)
       .setMaxIterations(5)
       .setSeed(12345)
     val corpus = sc.parallelize(tinyCorpus, 2)
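
A note on the central fix above: the old `LocalLDAModel.SaveLoadV1_0.load` rebuilt `topicsMatrix` with `topics.zipWithIndex`, which assumes the Parquet reader returns rows in the order `save` wrote them. Parquet makes no such ordering guarantee, so a reloaded model could end up with its topic columns permuted. Persisting the column index in `Data` and pattern-matching on it makes the reconstruction order-independent. Below is a minimal, self-contained sketch of that pattern in plain Scala; the `Data`, `save`, and `load` names are hypothetical stand-ins, with no Spark involved:

```scala
import scala.util.Random

// Hypothetical stand-in for the persisted row: a topic column plus its index.
case class Data(topic: Array[Double], index: Int)

object IndexedRoundTrip {
  // "Save" each column together with its position in the matrix.
  def save(columns: Array[Array[Double]]): Seq[Data] =
    columns.indices.map(i => Data(columns(i), i))

  // "Load" without assuming row order: place each column by its stored index.
  def load(rows: Seq[Data]): Array[Array[Double]] = {
    val result = new Array[Array[Double]](rows.size)
    rows.foreach { case Data(topic, ind) => result(ind) = topic }
    result
  }

  def main(args: Array[String]): Unit = {
    val topics = Array(Array(0.1, 0.9), Array(0.5, 0.5), Array(0.8, 0.2))
    // Shuffling simulates Parquet handing rows back in arbitrary order.
    val reloaded = load(Random.shuffle(save(topics)))
    assert(reloaded.zip(topics).forall { case (a, b) => a.sameElements(b) })
    println("column order survives an out-of-order read")
  }
}
```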
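The string changes (`$topicsMatrix.numCols` to `${topicsMatrix.numCols}`, `$model.k` to `${model.k}`, and so on) all address the same interpolator pitfall: an unbraced `$` in an `s` interpolator captures only a simple identifier, so the trailing `.field` is emitted as literal text after the object's `toString`. A tiny illustration with a hypothetical `Model` case class:

```scala
object InterpolationDemo {
  case class Model(k: Int)

  def main(args: Array[String]): Unit = {
    val model = Model(3)
    // Unbraced: only `model` is interpolated; ".k" stays literal text.
    println(s"requires 3 topics, got $model.k topics")   // got Model(3).k topics
    // Braced: the whole expression is evaluated.
    println(s"requires 3 topics, got ${model.k} topics") // got 3 topics
  }
}
```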
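Likewise, replacing `new SQLContext(sc)` with `SQLContext.getOrCreate(sc)` stops the save/load paths from constructing a fresh context on every call; `getOrCreate` hands back the existing singleton, creating one only if none exists yet. A rough sketch of the behavior, assuming the Spark 1.x-era API used in this diff, with a local master purely for illustration:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object GetOrCreateDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("GetOrCreateDemo"))
    // Both calls return the same singleton; no second context is built.
    val first = SQLContext.getOrCreate(sc)
    val second = SQLContext.getOrCreate(sc)
    assert(first eq second)
    sc.stop()
  }
}
```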