From cc140547696d84bd639d558feaa79739d19c2864 Mon Sep 17 00:00:00 2001
From: MechCoder
Date: Thu, 16 Jul 2015 03:58:28 +0530
Subject: [PATCH] minor

---
 .../apache/spark/mllib/clustering/LDAModel.scala   | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index b063dc4529434..2436d1c2a6938 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -209,7 +209,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
     val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel"
 
     // Store the distribution of terms of each topic as a Row in data.
-    case class Data(topics: Vector)
+    case class Data(topic: Vector)
 
     def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
@@ -257,7 +257,8 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
     val classNameV1_0 = SaveLoadV1_0.thisClassName
 
     val model = (loadedClassName, loadedVersion) match {
-      case (classNameV1_0, "1.0") => SaveLoadV1_0.load(sc, path)
+      case (className, "1.0") if className == classNameV1_0 =>
+        SaveLoadV1_0.load(sc, path)
       case _ => throw new Exception(
         s"LocalLDAModel.load did not recognize model with (className, format version):" +
         s"($loadedClassName, $loadedVersion). Supported:\n" +
@@ -465,7 +466,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
     val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
 
     // Store the weight of each topic separately in a row.
-    case class Data(globalTopicTotals: Double)
+    case class Data(globalTopicTotals: Vector)
 
     // Store each term and document vertex with an id and the topicWeights.
     case class VertexData(id: Long, topicWeights: Vector)
@@ -495,7 +496,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
-      sc.parallelize(globalTopicTotals.toArray.toSeq.map(w => Data(w)), 1).toDF()
+      sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF()
         .write.parquet(newPath)
 
       val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
@@ -528,7 +529,8 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
       Loader.checkSchema[Data](dataFrame.schema)
       Loader.checkSchema[VertexData](vertexDataFrame.schema)
       Loader.checkSchema[EdgeData](edgeDataFrame.schema)
-      val globalTopicTotals: LDA.TopicCounts = BDV(dataFrame.collect().map(i => i.getDouble(0)))
+      val globalTopicTotals: LDA.TopicCounts =
+        dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector
       val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map {
         case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector)
       }
@@ -557,7 +559,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
     val classNameV1_0 = SaveLoadV1_0.classNameV1_0
 
     val model = (loadedClassName, loadedVersion) match {
-      case (classNameV1_0, "1.0") => {
+      case (className, "1.0") if className == classNameV1_0 => {
         DistributedLDAModel.SaveLoadV1_0.load(
           sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray)
       }
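
Note (not applied by the patch): a minimal standalone Scala sketch of the pattern-matching pitfall that the two `case (classNameV1_0, "1.0")` hunks address. In a Scala pattern, a lowercase identifier such as classNameV1_0 is a fresh variable binding that shadows the outer val rather than a comparison against it, so the old case matched any class name; binding to a new name and comparing in a guard restores the intended check. The object and method names below (PatternGuardDemo, loadBuggy, loadFixed) are illustrative only.

object PatternGuardDemo {
  val classNameV1_0 = "org.apache.spark.mllib.clustering.LocalLDAModel"

  // Before the patch: a lowercase name inside a pattern is a new variable binding,
  // not a reference to the outer val, so this case accepts any class name.
  def loadBuggy(loadedClassName: String, loadedVersion: String): String =
    (loadedClassName, loadedVersion) match {
      case (classNameV1_0, "1.0") => s"recognized: $classNameV1_0"
      case _ => "unrecognized"
    }

  // After the patch: bind to a fresh name and compare against the val in a guard.
  def loadFixed(loadedClassName: String, loadedVersion: String): String =
    (loadedClassName, loadedVersion) match {
      case (className, "1.0") if className == classNameV1_0 => s"recognized: $className"
      case _ => "unrecognized"
    }

  def main(args: Array[String]): Unit = {
    println(loadBuggy("some.other.Model", "1.0"))  // prints "recognized: some.other.Model" -- the bug
    println(loadFixed("some.other.Model", "1.0"))  // prints "unrecognized" -- the intended behavior
  }
}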