From 01baad70f44fa12ad37a743d5d0fba861d89f149 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Wed, 11 Mar 2015 15:44:22 -0700 Subject: [PATCH] made fixes from code review --- docs/mllib-naive-bayes.md | 4 ++-- .../mllib/classification/NaiveBayes.scala | 22 ++++++++----------- .../classification/NaiveBayesSuite.scala | 14 +++--------- 3 files changed, 14 insertions(+), 26 deletions(-) diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 9e54fccad577a..d481eabe563bc 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -15,12 +15,12 @@ and use it for prediction. MLlib supports [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). -Which are typically used for [document classification] +These models are typically used for [document classification] (http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes). -Feature values must be nonnegative.The model type is selected with on optional parameter +Feature values must be nonnegative. The model type is selected with an optional parameter "Multinomial" or "Bernoulli" with "Multinomial" as the default. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 61085225d01c8..c7cb86fa30f0b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -49,15 +49,15 @@ class NaiveBayesModel private[mllib] ( val modelType: String) extends ClassificationModel with Serializable with Saveable { - def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = + private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = this(labels, pi, theta, NaiveBayes.Multinomial.toString) private val brzPi = new BDV[Double](pi) private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t - // Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0 - // this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application - // of this condition in predict function + // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. + // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra + // application of this condition (in predict function). private val (brzNegTheta, brzNegThetaSum) = NaiveBayes.ModelType.fromString(modelType) match { case NaiveBayes.Multinomial => (None, None) case NaiveBayes.Bernoulli => @@ -186,8 +186,6 @@ class NaiveBayes private ( private var lambda: Double, private var modelType: NaiveBayes.ModelType) extends Serializable with Logging { - def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) - def this() = this(1.0, NaiveBayes.Multinomial) /** Set the smoothing parameter. Default: 1.0. */ @@ -202,6 +200,7 @@ class NaiveBayes private ( this } + def getModelType(): NaiveBayes.ModelType = this.modelType /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. @@ -301,10 +300,9 @@ object NaiveBayes { * @param lambda The smoothing parameter */ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { - new NaiveBayes(lambda).run(input) + new NaiveBayes(lambda, NaiveBayes.Multinomial).run(input) } - /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * @@ -327,11 +325,7 @@ object NaiveBayes { new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input) } - - /** - * Model types supported in Naive Bayes: - * multinomial and Bernoulli currently supported - */ + /** Provides static methods for using ModelType. */ sealed abstract class ModelType object MODELTYPE { @@ -348,10 +342,12 @@ object NaiveBayes { final val ModelType = MODELTYPE + /** Constant for specifying ModelType parameter: multinomial model */ final val Multinomial: ModelType = new ModelType { override def toString: String = ModelType.MULTINOMIAL_STRING } + /** Constant for specifying ModelType parameter: bernoulli model */ final val Bernoulli: ModelType = new ModelType { override def toString: String = ModelType.BERNOULLI_STRING } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index acc5b35c1bdb6..7ce9be4e3cdd4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -58,7 +58,7 @@ object NaiveBayesSuite { for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = dataModel match { - case NaiveBayes.Bernoulli => Array.tabulate[Double] (D) {j => + case NaiveBayes.Bernoulli => Array.tabulate[Double] (D) { j => if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 } case NaiveBayes.Multinomial => @@ -118,11 +118,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { ).map(_.map(math.log)) val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, - theta, - nPoints, - 42, - NaiveBayes.Multinomial) + pi, theta, nPoints, 42, NaiveBayes.Multinomial) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -130,11 +126,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( - pi, - theta, - nPoints, - 17, - NaiveBayes.Multinomial) + pi, theta, nPoints, 17, NaiveBayes.Multinomial) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD.