
Commit

minor
MechCoder committed Jul 15, 2015
1 parent 4587d1d commit cc14054
Showing 1 changed file with 8 additions and 6 deletions.
@@ -209,7 +209,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel"

// Store the distribution of terms of each topic as a Row in data.
-    case class Data(topics: Vector)
+    case class Data(topic: Vector)

def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {

@@ -257,7 +257,8 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
val classNameV1_0 = SaveLoadV1_0.thisClassName

val model = (loadedClassName, loadedVersion) match {
-    case (classNameV1_0, "1.0") => SaveLoadV1_0.load(sc, path)
+    case (className, "1.0") if className == classNameV1_0 =>
+      SaveLoadV1_0.load(sc, path)
case _ => throw new Exception(
s"LocalLDAModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $loadedVersion). Supported:\n" +
@@ -465,7 +466,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"

// Store the weight of each topic separately in a row.
-    case class Data(globalTopicTotals: Double)
+    case class Data(globalTopicTotals: Vector)

// Store each term and document vertex with an id and the topicWeights.
case class VertexData(id: Long, topicWeights: Vector)
@@ -495,7 +496,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
-    sc.parallelize(globalTopicTotals.toArray.toSeq.map(w => Data(w)), 1).toDF()
+    sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF()
.write.parquet(newPath)

val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
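
The save path above now writes globalTopicTotals as a single row holding one Vector instead of one Double row per topic. A rough standalone sketch of the new layout, assuming a Spark 1.x SQLContext; TotalsRow and saveTotals are hypothetical names, and since Vectors.fromBreeze is private to Spark the conversion here goes through an Array instead:

object TotalsSaveSketch {
  import breeze.linalg.{DenseVector => BDV}
  import org.apache.spark.SparkContext
  import org.apache.spark.mllib.linalg.{Vector, Vectors}
  import org.apache.spark.sql.SQLContext

  // Illustrative stand-in for the Data case class above: one row, one Vector column.
  case class TotalsRow(globalTopicTotals: Vector)

  def saveTotals(sc: SparkContext, path: String, totals: BDV[Double]): Unit = {
    val sqlContext = new SQLContext(sc)
    // The old layout wrote one Double row per topic; this writes a single Vector row.
    sqlContext.createDataFrame(Seq(TotalsRow(Vectors.dense(totals.toArray))))
      .write.parquet(path)
  }
}
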
@@ -528,7 +529,8 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
Loader.checkSchema[Data](dataFrame.schema)
Loader.checkSchema[VertexData](vertexDataFrame.schema)
Loader.checkSchema[EdgeData](edgeDataFrame.schema)
-    val globalTopicTotals: LDA.TopicCounts = BDV(dataFrame.collect().map(i => i.getDouble(0)))
+    val globalTopicTotals: LDA.TopicCounts =
+      dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector
val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map {
case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector)
}
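
On the load side, the single Vector row is read back with first() rather than collecting k Double rows. A matching sketch under the same assumptions as the save sketch above (illustrative names; Vector.toBreeze is private to Spark, so the round-trip goes through toArray):

object TotalsLoadSketch {
  import breeze.linalg.{DenseVector => BDV}
  import org.apache.spark.SparkContext
  import org.apache.spark.mllib.linalg.Vector
  import org.apache.spark.sql.SQLContext

  def loadTotals(sc: SparkContext, path: String): BDV[Double] = {
    val sqlContext = new SQLContext(sc)
    val df = sqlContext.read.parquet(path)
    // One row with one Vector column; the old layout instead needed
    // BDV(df.collect().map(_.getDouble(0))) over k separate Double rows.
    BDV(df.first().getAs[Vector](0).toArray)
  }
}
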
@@ -557,7 +559,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
val classNameV1_0 = SaveLoadV1_0.classNameV1_0

val model = (loadedClassName, loadedVersion) match {
case (classNameV1_0, "1.0") => {
case (className, "1.0") if className == classNameV1_0 => {
DistributedLDAModel.SaveLoadV1_0.load(
sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray)
}