diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index b69af44a35b2a..0d2be054ac72a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -35,6 +35,8 @@ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} +import NaiveBayes.ModelType.{Bernoulli, Multinomial} + /** * Model for Naive Bayes Classifiers. @@ -54,7 +56,7 @@ class NaiveBayesModel private[mllib] ( extends ClassificationModel with Serializable with Saveable { private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = - this(labels, pi, theta, NaiveBayes.Multinomial) + this(labels, pi, theta, Multinomial) /** A Java-friendly constructor that takes three Iterable parameters. */ private[mllib] def this( @@ -70,10 +72,13 @@ class NaiveBayesModel private[mllib] ( // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra // application of this condition (in predict function). private val (brzNegTheta, brzNegThetaSum) = modelType match { - case NaiveBayes.Multinomial => (None, None) - case NaiveBayes.Bernoulli => + case Multinomial => (None, None) + case Bernoulli => val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) (Option(negTheta), Option(brzSum(negTheta, Axis._1))) + case _ => + // This should never happen. 
+ throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") } override def predict(testData: RDD[Vector]): RDD[Double] = { @@ -86,29 +91,32 @@ class NaiveBayesModel private[mllib] ( override def predict(testData: Vector): Double = { modelType match { - case NaiveBayes.Multinomial => + case Multinomial => labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) - case NaiveBayes.Bernoulli => + case Bernoulli => labels (brzArgmax (brzPi + (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") } } override def save(sc: SparkContext, path: String): Unit = { - val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString) - NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) + val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType.toString) + NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) } - override protected def formatVersion: String = "1.0" + override protected def formatVersion: String = "2.0" } object NaiveBayesModel extends Loader[NaiveBayesModel] { import org.apache.spark.mllib.util.Loader._ - private object SaveLoadV1_0 { + private[mllib] object SaveLoadV2_0 { - def thisFormatVersion: String = "1.0" + def thisFormatVersion: String = "2.0" /** Hard-code class name string in case it changes in the future */ def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" @@ -127,8 +135,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Create JSON metadata. 
val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ - ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length) ~ - ("modelType" -> data.modelType))) + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) // Create Parquet data. @@ -151,36 +158,82 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { val modelType = NaiveBayes.ModelType.fromString(data.getString(3)) new NaiveBayesModel(labels, pi, theta, modelType) } + } - override def load(sc: SparkContext, path: String): NaiveBayesModel = { - def getModelType(metadata: JValue): NaiveBayes.ModelType = { - implicit val formats = DefaultFormats - NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String]) + private[mllib] object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" + + /** Model data for model import/export */ + case class Data( + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]]) + + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + dataRDD.saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. 
+ val dataRDD = sqlContext.parquetFile(dataPath(path)) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta").take(1) + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + new NaiveBayesModel(labels, pi, theta) } + } + + override def load(sc: SparkContext, path: String): NaiveBayesModel = { val (loadedClassName, version, metadata) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName - (loadedClassName, version) match { + val classNameV2_0 = SaveLoadV2_0.thisClassName + val (model, numFeatures, numClasses) = (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) val model = SaveLoadV1_0.load(sc, path) - assert(model.pi.size == numClasses, - s"NaiveBayesModel.load expected $numClasses classes," + - s" but class priors vector pi had ${model.pi.size} elements") - assert(model.theta.size == numClasses, - s"NaiveBayesModel.load expected $numClasses classes," + - s" but class conditionals array theta had ${model.theta.size} elements") - assert(model.theta.forall(_.size == numFeatures), - s"NaiveBayesModel.load expected $numFeatures features," + - s" but class conditionals array theta had elements of size:" + - s" ${model.theta.map(_.size).mkString(",")}") - assert(model.modelType == getModelType(metadata)) - model + (model, numFeatures, numClasses) + case (className, "2.0") if className == classNameV2_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val model = SaveLoadV2_0.load(sc, path) + (model, numFeatures, numClasses) case _ => 
throw new Exception( s"NaiveBayesModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $version). Supported:\n" + - s" ($classNameV1_0, 1.0)") + s" ($classNameV1_0, 1.0)\n ($classNameV2_0, 2.0)") } + assert(model.pi.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class priors vector pi had ${model.pi.size} elements") + assert(model.theta.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class conditionals array theta had ${model.theta.size} elements") + assert(model.theta.forall(_.size == numFeatures), + s"NaiveBayesModel.load expected $numFeatures features," + + s" but class conditionals array theta had elements of size:" + + s" ${model.theta.map(_.size).mkString(",")}") + model } } @@ -197,9 +250,9 @@ class NaiveBayes private ( private var lambda: Double, private var modelType: NaiveBayes.ModelType) extends Serializable with Logging { - def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) + def this(lambda: Double) = this(lambda, Multinomial) - def this() = this(1.0, NaiveBayes.Multinomial) + def this() = this(1.0, Multinomial) /** Set the smoothing parameter. Default: 1.0. */ def setLambda(lambda: Double): NaiveBayes = { @@ -210,9 +263,22 @@ class NaiveBayes private ( /** Get the smoothing parameter. */ def getLambda: Double = lambda - /** Set the model type. Default: Multinomial. */ - def setModelType(model: NaiveBayes.ModelType): NaiveBayes = { - this.modelType = model + /** + * Set the model type using a string (case-insensitive). + * Supported options: "multinomial" and "bernoulli". + * (default: multinomial) + */ + def setModelType(modelType: String): NaiveBayes = { + setModelType(NaiveBayes.ModelType.fromString(modelType)) + } + + /** + * Set the model type. 
+ * Supported options: [[NaiveBayes.ModelType.Bernoulli]], [[NaiveBayes.ModelType.Multinomial]] + * (default: Multinomial) + */ + def setModelType(modelType: NaiveBayes.ModelType): NaiveBayes = { + this.modelType = modelType this } @@ -270,8 +336,11 @@ class NaiveBayes private ( labels(i) = label pi(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = modelType match { - case NaiveBayes.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) - case NaiveBayes.Bernoulli => math.log(n + 2.0 * lambda) + case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + case Bernoulli => math.log(n + 2.0 * lambda) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType") } var j = 0 while (j < numFeatures) { @@ -317,7 +386,7 @@ object NaiveBayes { * @param lambda The smoothing parameter */ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { - new NaiveBayes(lambda, NaiveBayes.Multinomial).run(input) + new NaiveBayes(lambda, NaiveBayes.ModelType.Multinomial).run(input) } /** @@ -339,12 +408,45 @@ object NaiveBayes { * multinomial or bernoulli */ def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { - new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input) + new NaiveBayes(lambda, ModelType.fromString(modelType)).run(input) } /** Provides static methods for using ModelType. */ sealed abstract class ModelType extends Serializable + object ModelType extends Serializable { + + /** + * Get the model type from a string. 
+ * @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive) + */ + def fromString(modelType: String): ModelType = modelType match { + case s if "multinomial".equalsIgnoreCase(s) => Multinomial + case s if "bernoulli".equalsIgnoreCase(s) => Bernoulli + case _ => + throw new IllegalArgumentException( + s"NaiveBayes.ModelType.fromString did not recognize string: $modelType") + } + + final val Multinomial: ModelType = { + case object Multinomial extends ModelType with Serializable { + override def toString: String = "multinomial" + } + Multinomial + } + + final val Bernoulli: ModelType = { + case object Bernoulli extends ModelType with Serializable { + override def toString: String = "bernoulli" + } + Bernoulli + } + } + + /** Java-friendly accessor for supported ModelType options */ + final val modelTypes = ModelType + + /* object MODELTYPE extends Serializable{ final val MULTINOMIAL_STRING = "multinomial" final val BERNOULLI_STRING = "bernoulli" @@ -368,6 +470,6 @@ object NaiveBayes { final val Bernoulli: ModelType = new ModelType { override def toString: String = ModelType.BERNOULLI_STRING } - + */ } diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 1c90522a0714a..4d89c06b88c0e 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -17,20 +17,22 @@ package org.apache.spark.mllib.classification; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import 
org.apache.spark.mllib.regression.LabeledPoint; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import java.io.Serializable; -import java.util.Arrays; -import java.util.List; public class JavaNaiveBayesSuite implements Serializable { private transient JavaSparkContext sc; @@ -102,4 +104,11 @@ public Vector call(LabeledPoint v) throws Exception { // Should be able to get the first prediction. predictions.first(); } + + @Test + public void testModelTypeSetters() { + NaiveBayes nb = new NaiveBayes() + .setModelType(NaiveBayes.modelTypes().Bernoulli()) + .setModelType(NaiveBayes.modelTypes().Multinomial()); + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 8b795e411817c..2d87d6893250b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.mllib.classification -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} -import breeze.stats.distributions.Multinomial - import scala.util.Random +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.stats.distributions.{Multinomial => BrzMultinomial} + import org.scalatest.FunSuite import org.apache.spark.SparkException +import org.apache.spark.mllib.classification.NaiveBayes.ModelType.{Bernoulli, Multinomial} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -48,7 +49,7 @@ object NaiveBayesSuite { theta: Array[Array[Double]], // CXD nPoints: Int, seed: Int, - dataModel: NaiveBayes.ModelType = NaiveBayes.Multinomial, + 
modelType: NaiveBayes.ModelType = Multinomial, sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) @@ -57,26 +58,35 @@ object NaiveBayesSuite { for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) - val xi = dataModel match { - case NaiveBayes.Bernoulli => Array.tabulate[Double] (D) { j => + val xi = modelType match { + case Bernoulli => Array.tabulate[Double] (D) { j => if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 } - case NaiveBayes.Multinomial => - val mult = Multinomial(BDV(_theta(y))) + case Multinomial => + val mult = BrzMultinomial(BDV(_theta(y))) val emptyMap = (0 until D).map(x => (x, 0.0)).toMap val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { case (index, reps) => (index, reps.size.toDouble) } counts.toArray.sortBy(_._1).map(_._2) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType") } LabeledPoint(y, Vectors.dense(xi)) } } - /** Binary labels, 3 features */ - private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8), - theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli) + /** Bernoulli NaiveBayes with binary labels, 3 features */ + private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), + Bernoulli) + + /** Multinomial NaiveBayes with binary labels, 3 features */ + private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), + Multinomial) } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { @@ -126,7 +136,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { ).map(_.map(math.log)) val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 42, NaiveBayes.Multinomial) + pi, 
theta, nPoints, 42, Multinomial) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -134,7 +144,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 17, NaiveBayes.Multinomial) + pi, theta, nPoints, 17, Multinomial) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -154,7 +164,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { ).map(_.map(math.log)) val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 45, NaiveBayes.Bernoulli) + pi, theta, nPoints, 45, Bernoulli) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -162,7 +172,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 20, NaiveBayes.Bernoulli) + pi, theta, nPoints, 20, Bernoulli) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -199,19 +209,41 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } - test("model save/load") { - val model = NaiveBayesSuite.binaryModel + test("model save/load: 2.0 to 2.0") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Seq(NaiveBayesSuite.binaryBernoulliModel, NaiveBayesSuite.binaryMultinomialModel).foreach { + model => + // Save model, load it back, and compare. 
+ try { + model.save(sc, path) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + assert(model.modelType === sameModel.modelType) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + + test("model save/load: 1.0 to 2.0") { + val model = NaiveBayesSuite.binaryMultinomialModel val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString - // Save model, load it back, and compare. + // Save model as version 1.0, load it back, and compare. try { - model.save(sc, path) + val data = NaiveBayesModel.SaveLoadV1_0.Data(model.labels, model.pi, model.theta) + NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) val sameModel = NaiveBayesModel.load(sc, path) assert(model.labels === sameModel.labels) assert(model.pi === sameModel.pi) assert(model.theta === sameModel.theta) + assert(sameModel.modelType === NaiveBayes.ModelType.Multinomial) } finally { Utils.deleteRecursively(tempDir) }