Skip to content

Commit

Permalink
changed NaiveBayesModel modelType parameter back to NaiveBayes.ModelT…
Browse files Browse the repository at this point in the history
…ype, made NaiveBayes.ModelType serializable, fixed getter method in NavieBayes
  • Loading branch information
leahmcguire committed Mar 17, 2015
1 parent 18f3219 commit a22d670
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,19 @@ class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]],
val modelType: String)
val modelType: NaiveBayes.ModelType)
extends ClassificationModel with Serializable with Saveable {

private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
this(labels, pi, theta, NaiveBayes.Multinomial.toString)
this(labels, pi, theta, NaiveBayes.Multinomial)

private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t

// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
// application of this condition (in predict function).
private val (brzNegTheta, brzNegThetaSum) = NaiveBayes.ModelType.fromString(modelType) match {
private val (brzNegTheta, brzNegThetaSum) = modelType match {
case NaiveBayes.Multinomial => (None, None)
case NaiveBayes.Bernoulli =>
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
Expand All @@ -74,7 +74,7 @@ class NaiveBayesModel private[mllib] (
}

override def predict(testData: Vector): Double = {
NaiveBayes.ModelType.fromString(modelType) match {
modelType match {
case NaiveBayes.Multinomial =>
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
case NaiveBayes.Bernoulli =>
Expand All @@ -84,7 +84,7 @@ class NaiveBayesModel private[mllib] (
}

override def save(sc: SparkContext, path: String): Unit = {
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType)
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString)
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
}

Expand Down Expand Up @@ -137,15 +137,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
val modelType = NaiveBayes.ModelType.fromString(data.getString(3)).toString
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
new NaiveBayesModel(labels, pi, theta, modelType)
}
}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
def getModelType(metadata: JValue): String = {
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
implicit val formats = DefaultFormats
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String]).toString
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String])
}
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
Expand Down Expand Up @@ -202,7 +202,7 @@ class NaiveBayes private (
this
}

def getModelType(): NaiveBayes.ModelType = this.modelType
def getModelType: NaiveBayes.ModelType = this.modelType

/**
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
Expand Down Expand Up @@ -266,7 +266,7 @@ class NaiveBayes private (
i += 1
}

new NaiveBayesModel(labels, pi, theta, modelType.toString)
new NaiveBayesModel(labels, pi, theta, modelType)
}
}

Expand Down Expand Up @@ -328,9 +328,9 @@ object NaiveBayes {
}

/** Provides static methods for using ModelType. */
sealed abstract class ModelType
sealed abstract class ModelType extends Serializable

object MODELTYPE {
object MODELTYPE extends Serializable{
final val MULTINOMIAL_STRING = "multinomial"
final val BERNOULLI_STRING = "bernoulli"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ object NaiveBayesSuite {

/** Binary labels, 3 features */
private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8),
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli.toString)
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli)
}

class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
Expand Down

0 comments on commit a22d670

Please sign in to comment.