Skip to content

Commit

Permalink
removed enum type and replaces all modelType parameters with strings
Browse files Browse the repository at this point in the history
  • Loading branch information
leahmcguire committed Mar 28, 2015
1 parent 2224b15 commit acb69af
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

import NaiveBayes.ModelType.{Bernoulli, Multinomial}


/**
* Model for Naive Bayes Classifiers.
Expand All @@ -45,18 +43,17 @@ import NaiveBayes.ModelType.{Bernoulli, Multinomial}
* @param pi log of class priors, whose dimension is C, number of labels
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
* where D is number of features
* @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
* Multinomial or Bernoulli
* @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli"
*/
class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]],
val modelType: NaiveBayes.ModelType)
val modelType: String)
extends ClassificationModel with Serializable with Saveable {

private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
this(labels, pi, theta, Multinomial)
this(labels, pi, theta, "Multinomial")

/** A Java-friendly constructor that takes three Iterable parameters. */
private[mllib] def this(
Expand All @@ -72,8 +69,8 @@ class NaiveBayesModel private[mllib] (
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
// application of this condition (in predict function).
private val (brzNegTheta, brzNegThetaSum) = modelType match {
case Multinomial => (None, None)
case Bernoulli =>
case "Multinomial" => (None, None)
case "Bernoulli" =>
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
(Option(negTheta), Option(brzSum(negTheta, Axis._1)))
case _ =>
Expand All @@ -91,9 +88,9 @@ class NaiveBayesModel private[mllib] (

override def predict(testData: Vector): Double = {
modelType match {
case Multinomial =>
case "Multinomial" =>
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
case Bernoulli =>
case "Bernoulli" =>
labels (brzArgmax (brzPi +
(brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
case _ =>
Expand All @@ -103,7 +100,7 @@ class NaiveBayesModel private[mllib] (
}

override def save(sc: SparkContext, path: String): Unit = {
val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType.toString)
val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
}

Expand Down Expand Up @@ -155,7 +152,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
val modelType = data.getString(3)
new NaiveBayesModel(labels, pi, theta, modelType)
}

Expand Down Expand Up @@ -248,11 +245,11 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {

class NaiveBayes private (
private var lambda: Double,
private var modelType: NaiveBayes.ModelType) extends Serializable with Logging {
private var modelType: String) extends Serializable with Logging {

def this(lambda: Double) = this(lambda, Multinomial)
def this(lambda: Double) = this(lambda, "Multinomial")

def this() = this(1.0, Multinomial)
def this() = this(1.0, "Multinomial")

/** Set the smoothing parameter. Default: 1.0. */
def setLambda(lambda: Double): NaiveBayes = {
Expand All @@ -264,26 +261,21 @@ class NaiveBayes private (
def getLambda: Double = lambda

/**
* Set the model type using a string (case-insensitive).
* Supported options: "multinomial" and "bernoulli".
* (default: multinomial)
*/
def setModelType(modelType: String): NaiveBayes = {
setModelType(NaiveBayes.ModelType.fromString(modelType))
}

/**
* Set the model type.
* Supported options: [[NaiveBayes.ModelType.Bernoulli]], [[NaiveBayes.ModelType.Multinomial]]
* Set the model type using a string (case-sensitive).
* Supported options: "Multinomial" and "Bernoulli".
* (default: Multinomial)
*/
def setModelType(modelType: NaiveBayes.ModelType): NaiveBayes = {
this.modelType = modelType
this
def setModelType(modelType:String): NaiveBayes = {
if (NaiveBayes.supportedModelTypes.contains(modelType)) {
this.modelType = modelType
this
} else {
throw new UnknownError(s"NaiveBayesModel does not support ModelType: $modelType")
}
}

/** Get the model type. */
def getModelType: NaiveBayes.ModelType = this.modelType
def getModelType: String = this.modelType

/**
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
Expand Down Expand Up @@ -336,8 +328,8 @@ class NaiveBayes private (
labels(i) = label
pi(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = modelType match {
case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
case Bernoulli => math.log(n + 2.0 * lambda)
case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
case "Bernoulli" => math.log(n + 2.0 * lambda)
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
Expand All @@ -358,6 +350,10 @@ class NaiveBayes private (
* Top-level methods for calling naive Bayes.
*/
object NaiveBayes {

/* Set of modelTypes that NaiveBayes supports */
private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli")

/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
Expand Down Expand Up @@ -386,7 +382,7 @@ object NaiveBayes {
* @param lambda The smoothing parameter
*/
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
new NaiveBayes(lambda, NaiveBayes.ModelType.Multinomial).run(input)
new NaiveBayes(lambda, "Multinomial").run(input)
}

/**
Expand All @@ -408,42 +404,11 @@ object NaiveBayes {
* multinomial or bernoulli
*/
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
new NaiveBayes(lambda, ModelType.fromString(modelType)).run(input)
}

/** Provides static methods for using ModelType. */
sealed abstract class ModelType extends Serializable

object ModelType extends Serializable {

/**
* Get the model type from a string.
* @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive)
*/
def fromString(modelType: String): ModelType = modelType.toLowerCase match {
case "multinomial" => Multinomial
case "bernoulli" => Bernoulli
case _ =>
throw new IllegalArgumentException(
s"NaiveBayes.ModelType.fromString did not recognize string: $modelType")
}

final val Multinomial: ModelType = {
case object Multinomial extends ModelType with Serializable {
override def toString: String = "multinomial"
}
Multinomial
}

final val Bernoulli: ModelType = {
case object Bernoulli extends ModelType with Serializable {
override def toString: String = "bernoulli"
}
Bernoulli
if (supportedModelTypes.contains(modelType)) {
new NaiveBayes(lambda, modelType).run(input)
} else {
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
}
}

/** Java-friendly accessor for supported ModelType options */
final val modelTypes = ModelType

}
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public Vector call(LabeledPoint v) throws Exception {
@Test
public void testModelTypeSetters() {
NaiveBayes nb = new NaiveBayes()
.setModelType(NaiveBayes.modelTypes().Bernoulli())
.setModelType(NaiveBayes.modelTypes().Multinomial());
.setModelType("Bernoulli")
.setModelType("Multinomial");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import breeze.stats.distributions.{Multinomial => BrzMultinomial}
import org.scalatest.FunSuite

import org.apache.spark.SparkException
import org.apache.spark.mllib.classification.NaiveBayes.ModelType.{Bernoulli, Multinomial}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
Expand All @@ -49,7 +48,7 @@ object NaiveBayesSuite {
theta: Array[Array[Double]], // CXD
nPoints: Int,
seed: Int,
modelType: NaiveBayes.ModelType = Multinomial,
modelType: String = "Multinomial",
sample: Int = 10): Seq[LabeledPoint] = {
val D = theta(0).length
val rnd = new Random(seed)
Expand All @@ -59,10 +58,10 @@ object NaiveBayesSuite {
for (i <- 0 until nPoints) yield {
val y = calcLabel(rnd.nextDouble(), _pi)
val xi = modelType match {
case Bernoulli => Array.tabulate[Double] (D) { j =>
case "Bernoulli" => Array.tabulate[Double] (D) { j =>
if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
}
case Multinomial =>
case "Multinomial" =>
val mult = BrzMultinomial(BDV(_theta(y)))
val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
Expand All @@ -81,12 +80,12 @@ object NaiveBayesSuite {
/** Bernoulli NaiveBayes with binary labels, 3 features */
private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
Bernoulli)
"Bernoulli")

/** Multinomial NaiveBayes with binary labels, 3 features */
private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
Multinomial)
"Multinomial")
}

class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
Expand Down Expand Up @@ -136,15 +135,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
).map(_.map(math.log))

val testData = NaiveBayesSuite.generateNaiveBayesInput(
pi, theta, nPoints, 42, Multinomial)
pi, theta, nPoints, 42, "Multinomial")
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()

val model = NaiveBayes.train(testRDD, 1.0, "multinomial")
val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
validateModelFit(pi, theta, model)

val validationData = NaiveBayesSuite.generateNaiveBayesInput(
pi, theta, nPoints, 17, Multinomial)
pi, theta, nPoints, 17, "Multinomial")
val validationRDD = sc.parallelize(validationData, 2)

// Test prediction on RDD.
Expand All @@ -164,15 +163,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
).map(_.map(math.log))

val testData = NaiveBayesSuite.generateNaiveBayesInput(
pi, theta, nPoints, 45, Bernoulli)
pi, theta, nPoints, 45, "Bernoulli")
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()

val model = NaiveBayes.train(testRDD, 1.0, "bernoulli")
val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
validateModelFit(pi, theta, model)

val validationData = NaiveBayesSuite.generateNaiveBayesInput(
pi, theta, nPoints, 20, Bernoulli)
pi, theta, nPoints, 20, "Bernoulli")
val validationRDD = sc.parallelize(validationData, 2)

// Test prediction on RDD.
Expand Down Expand Up @@ -243,7 +242,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
assert(model.labels === sameModel.labels)
assert(model.pi === sameModel.pi)
assert(model.theta === sameModel.theta)
assert(model.modelType === NaiveBayes.ModelType.Multinomial)
assert(model.modelType === "Multinomial")
} finally {
Utils.deleteRecursively(tempDir)
}
Expand Down

0 comments on commit acb69af

Please sign in to comment.