minor style fixes
MechCoder committed Jul 16, 2015
1 parent cc14054 commit 49bcdce
Showing 2 changed files with 22 additions and 29 deletions.
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala

@@ -208,12 +208,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
 
     val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel"
 
-    // Store the distribution of terms of each topic as a Row in data.
-    case class Data(topic: Vector)
+    // Store the distribution of terms of each topic and the column index in topicsMatrix
+    // as a Row in data.
+    case class Data(topic: Vector, index: Int)
 
     def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
-
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._
 
       val k = topicsMatrix.numCols
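
Note: new SQLContext(sc) constructs a fresh context on every call, while SQLContext.getOrCreate(sc), available since Spark 1.4, returns the singleton tied to the running SparkContext, so repeated save/load calls share one context. A minimal sketch of the pattern, with a hypothetical helper name not taken from this commit:

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext

    // Hypothetical helper: reuse the per-JVM singleton SQLContext instead
    // of allocating a new context (and catalog/conf) on each invocation.
    def withSqlContext[T](sc: SparkContext)(f: SQLContext => T): T =
      f(SQLContext.getOrCreate(sc))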
@@ -224,32 +224,30 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
 
       val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
       val topics = Range(0, k).map { topicInd =>
-        Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)))
+        Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)), topicInd)
       }.toSeq
       sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
     }
 
     def load(sc: SparkContext, path: String): LocalLDAModel = {
-
       val dataPath = Loader.dataPath(path)
       val sqlContext = SQLContext.getOrCreate(sc)
       val dataFrame = sqlContext.read.parquet(dataPath)
 
       Loader.checkSchema[Data](dataFrame.schema)
       val topics = dataFrame.collect()
-      val vocabSize = topics(0)(0).asInstanceOf[Vector].size
+      val vocabSize = topics(0).getAs[Vector](0).size
       val k = topics.size
 
       val brzTopics = BDM.zeros[Double](vocabSize, k)
-      topics.zipWithIndex.foreach { case (Row(vec: Vector), ind: Int) =>
+      topics.foreach { case Row(vec: Vector, ind: Int) =>
         brzTopics(::, ind) := vec.toBreeze
       }
       new LocalLDAModel(Matrices.fromBreeze(brzTopics))
     }
   }
 
   override def load(sc: SparkContext, path: String): LocalLDAModel = {
-
     val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
     implicit val formats = DefaultFormats
     val expectedK = (metadata \ "k").extract[Int]
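
Note: persisting the column index is what makes the load side order-independent. Rows collected from a Parquet-backed DataFrame carry no ordering guarantee, so the stored index, rather than collect order (previously recovered with zipWithIndex), must place each topic column. A sketch of the rebuild step under that assumption, with names invented for illustration:

    import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.sql.Row

    // Hypothetical: rebuild a vocabSize x k matrix from unordered rows of
    // (topic: Vector, index: Int); the stored index selects the column.
    def rebuildTopics(rows: Array[Row], vocabSize: Int, k: Int): BDM[Double] = {
      val brzTopics = BDM.zeros[Double](vocabSize, k)
      rows.foreach { case Row(vec: Vector, ind: Int) =>
        brzTopics(::, ind) := BDV(vec.toArray)
      }
      brzTopics
    }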
@@ -267,10 +265,10 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
 
     val topicsMatrix = model.topicsMatrix
     require(expectedK == topicsMatrix.numCols,
-      s"LocalLDAModel requires $expectedK topics, got $topicsMatrix.numCols topics")
+      s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
     require(expectedVocabSize == topicsMatrix.numRows,
       s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
-        s"but got $topicsMatrix.numRows")
+        s"but got ${topicsMatrix.numRows}")
     model
   }
 }
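
Note: the ${...} changes above fix a classic interpolator pitfall: in s"... got $topicsMatrix.numCols topics", only the identifier topicsMatrix is interpolated and ".numCols" is emitted as literal text, so the old message printed the entire matrix followed by ".numCols". Braces are required for any member access. A self-contained illustration (not from the commit):

    val xs = Seq(1, 2, 3)
    println(s"size is $xs.size")   // prints: size is List(1, 2, 3).size
    println(s"size is ${xs.size}") // prints: size is 3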
@@ -456,16 +454,15 @@ class DistributedLDAModel private (
 
 
 @Experimental
-object DistributedLDAModel extends Loader[DistributedLDAModel]{
-
+object DistributedLDAModel extends Loader[DistributedLDAModel] {
 
-  object SaveLoadV1_0 {
+  private object SaveLoadV1_0 {
 
     val thisFormatVersion = "1.0"
 
     val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
 
-    // Store the weight of each topic separately in a row.
+    // Store globalTopicTotals as a Vector.
     case class Data(globalTopicTotals: Vector)
 
     // Store each term and document vertex with an id and the topicWeights.
@@ -484,7 +481,6 @@ object DistributedLDAModel extends Loader[DistributedLDAModel]{
         docConcentration: Double,
         topicConcentration: Double,
         iterationTimes: Array[Double]): Unit = {
-
       val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._
 
@@ -517,11 +513,10 @@
         docConcentration: Double,
         topicConcentration: Double,
         iterationTimes: Array[Double]): DistributedLDAModel = {
-
       val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
       val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
       val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val dataFrame = sqlContext.read.parquet(dataPath)
       val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
       val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)
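
Note: the three datasets are stored as sibling Parquet directories under the model's data path, addressed with Hadoop's Path(parent, child) constructor and read back through the same SQLContext. A small sketch of the layout, with an assumed base directory standing in for Loader.dataPath(path):

    import org.apache.hadoop.fs.Path

    val dataDir = new Path("/models/my-lda/data") // assumed base directory
    val globalTopicTotalsPath = new Path(dataDir, "globalTopicTotals").toUri.toString
    val topicCountsPath = new Path(dataDir, "topicCounts").toUri.toString
    val tokenCountsPath = new Path(dataDir, "tokenCounts").toUri.toString
    // e.g. globalTopicTotalsPath == "/models/my-lda/data/globalTopicTotals"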
@@ -540,15 +535,13 @@
       }
       val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
 
-      new DistributedLDAModel(
-        graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
+      new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
         docConcentration, topicConcentration, iterationTimes)
     }
 
   }
 
   override def load(sc: SparkContext, path: String): DistributedLDAModel = {
-
     val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
     implicit val formats = DefaultFormats
     val expectedK = (metadata \ "k").extract[Int]
@@ -569,15 +562,15 @@
     }
 
     require(model.vocabSize == vocabSize,
-      s"DistributedLDAModel requires $vocabSize vocabSize, got $model.vocabSize vocabSize")
+      s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize")
     require(model.docConcentration == docConcentration,
       s"DistributedLDAModel requires $docConcentration docConcentration, " +
-        s"got $model.docConcentration docConcentration")
+        s"got ${model.docConcentration} docConcentration")
     require(model.topicConcentration == topicConcentration,
       s"DistributedLDAModel requires $topicConcentration docConcentration, " +
-        s"got $model.topicConcentration docConcentration")
+        s"got ${model.topicConcentration} docConcentration")
     require(expectedK == model.k,
-      s"DistributedLDAModel requires $expectedK topics, got $model.k topics")
+      s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics")
     model
   }

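Note: making SaveLoadV1_0 private confines the version-1.0 wire format to this loader; external callers go through load, which dispatches on the (className, formatVersion) pair stored in the model's metadata. A simplified sketch of that dispatch with assumed names (the real code delegates to SaveLoadV1_0.load with the full argument list):

    // Hypothetical, simplified version dispatch keyed on saved metadata.
    def dispatch(loadedClassName: String, loadedVersion: String): String = {
      val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
      (loadedClassName, loadedVersion) match {
        case (`classNameV1_0`, "1.0") =>
          "SaveLoadV1_0.load(...)" // placeholder for the real delegation
        case _ => throw new UnsupportedOperationException(
          s"Unrecognized model: ($loadedClassName, $loadedVersion)")
      }
    }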
mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala

@@ -222,12 +222,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
 
     // Test for DistributedLDAModel.
     val k = 3
-    val topicSmoothing = 1.2
-    val termSmoothing = 1.2
+    val docConcentration = 1.2
+    val topicConcentration = 1.5
     val lda = new LDA()
     lda.setK(k)
-      .setDocConcentration(topicSmoothing)
-      .setTopicConcentration(termSmoothing)
+      .setDocConcentration(docConcentration)
+      .setTopicConcentration(topicConcentration)
       .setMaxIterations(5)
       .setSeed(12345)
     val corpus = sc.parallelize(tinyCorpus, 2)
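
Note: the renames align the test's locals with the LDA API they feed: docConcentration is the Dirichlet prior (often written alpha) on per-document topic mixtures, and topicConcentration the prior (beta) on per-topic term distributions. For reference, the builder chain with the test's values:

    import org.apache.spark.mllib.clustering.LDA

    val lda = new LDA()
      .setK(3)                    // number of topics
      .setDocConcentration(1.2)   // alpha: document-topic prior
      .setTopicConcentration(1.5) // beta: topic-term prior
      .setMaxIterations(5)
      .setSeed(12345)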
