[SPARK-5757][MLLIB] replace SQL JSON usage in model import/export by json4s

This PR detaches MLlib model import/export code from SQL's JSON support, and hence unblocks apache#4544. yhuai

Author: Xiangrui Meng <meng@databricks.com>

Closes apache#4555 from mengxr/SPARK-5757 and squashes the following commits:

b0415e8 [Xiangrui Meng] replace SQL JSON usage by json4s
mengxr committed Feb 12, 2015
1 parent 466b1f6 commit 99bd500
Showing 15 changed files with 92 additions and 127 deletions.
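
The change applies the same two json4s idioms in every file below: on save, model metadata is built with the JsonDSL `~` operator and written out as a single compact JSON line; on load, that line is parsed back into a JValue and queried with `\` plus `extract`. A minimal, self-contained sketch of the round trip (the object name and field values are illustrative, not taken from the diff):

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object MetadataRoundTrip extends App {
  // Save side: build the metadata record and render it as one compact line.
  // In the PR this string is then written out with
  // sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath).
  val metadata: String = compact(render(
    ("class" -> "com.example.DemoModel") ~ ("version" -> "1.0") ~
      ("numFeatures" -> 4) ~ ("numClasses" -> 2)))

  // Load side: parse the line back and extract typed fields.
  implicit val formats = DefaultFormats
  val parsed: JValue = parse(metadata)
  val numFeatures = (parsed \ "numFeatures").extract[Int]
  val numClasses = (parsed \ "numClasses").extract[Int]
  assert((numFeatures, numClasses) == (4, 2))
}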

mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala

@@ -17,12 +17,12 @@
 
 package org.apache.spark.mllib.classification
 
+import org.json4s.{DefaultFormats, JValue}
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.util.Loader
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
 
 /**
  * :: Experimental ::
@@ -60,16 +60,10 @@ private[mllib] object ClassificationModel {
 
   /**
    * Helper method for loading GLM classification model metadata.
-   *
-   * @param modelClass String name for model class (used for error messages)
-   * @return (numFeatures, numClasses)
    */
-  def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = {
-    metadata.select("numFeatures", "numClasses").take(1)(0) match {
-      case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses)
-      case _ => throw new Exception(s"$modelClass unable to load" +
-        s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}")
-    }
+  def getNumFeaturesClasses(metadata: JValue): (Int, Int) = {
+    implicit val formats = DefaultFormats
+    ((metadata \ "numFeatures").extract[Int], (metadata \ "numClasses").extract[Int])
   }
 
 }

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

@@ -173,8 +173,7 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
     val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val (numFeatures, numClasses) =
-          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
         val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
         // numFeatures, numClasses, weights are checked in model initialization
         val model =

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

@@ -18,15 +18,16 @@
 package org.apache.spark.mllib.classification
 
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
 
-import org.apache.spark.{SparkContext, SparkException, Logging}
+import org.apache.spark.{Logging, SparkContext, SparkException}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
 
-
 /**
  * Model for Naive Bayes Classifiers.
  *
@@ -78,7 +79,7 @@ class NaiveBayesModel private[mllib] (
 
 object NaiveBayesModel extends Loader[NaiveBayesModel] {
 
-  import Loader._
+  import org.apache.spark.mllib.util.Loader._
 
   private object SaveLoadV1_0 {
 
@@ -95,10 +96,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
       import sqlContext.implicits._
 
       // Create JSON metadata.
-      val metadataRDD =
-        sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1)
-          .toDataFrame("class", "version", "numFeatures", "numClasses")
-      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+          ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
 
       // Create Parquet data.
       val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
@@ -126,8 +127,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
-        val (numFeatures, numClasses) =
-          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
         val model = SaveLoadV1_0.load(sc, path)
         assert(model.pi.size == numClasses,
           s"NaiveBayesModel.load expected $numClasses classes," +

mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala

@@ -23,10 +23,9 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
+import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
 import org.apache.spark.rdd.RDD
 
-
 /**
  * Model for Support Vector Machines (SVMs).
 *
@@ -97,8 +96,7 @@ object SVMModel extends Loader[SVMModel] {
     val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val (numFeatures, numClasses) =
-          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
         val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
         val model = new SVMModel(data.weights, data.intercept)
         assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" +

mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala

@@ -17,10 +17,13 @@
 
 package org.apache.spark.mllib.classification.impl
 
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.Loader
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{Row, SQLContext}
 
 /**
  * Helper class for import/export of GLM classification models.
@@ -52,16 +55,14 @@ private[classification] object GLMClassificationModel {
       import sqlContext.implicits._
 
       // Create JSON metadata.
-      val metadataRDD =
-        sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1)
-          .toDataFrame("class", "version", "numFeatures", "numClasses")
-      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~
+          ("numFeatures" -> numFeatures) ~ ("numClasses" -> numClasses)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       // Create Parquet data.
       val data = Data(weights, intercept, threshold)
-      val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
-      // TODO: repartition with 1 partition after SPARK-5532 gets fixed
-      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+      sc.parallelize(Seq(data), 1).saveAsParquetFile(Loader.dataPath(path))
     }
 
     /**

mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

@@ -22,6 +22,9 @@ import java.lang.{Integer => JavaInteger}
 
 import org.apache.hadoop.fs.Path
 import org.jblas.DoubleMatrix
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
@@ -153,7 +156,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
   import org.apache.spark.mllib.util.Loader._
 
   override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
-    val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
+    val (loadedClassName, formatVersion, _) = loadMetadata(sc, path)
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, formatVersion) match {
       case (className, "1.0") if className == classNameV1_0 =>
@@ -181,19 +184,20 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
       val sc = model.userFeatures.sparkContext
       val sqlContext = new SQLContext(sc)
       import sqlContext.implicits._
-      val metadata = (thisClassName, thisFormatVersion, model.rank)
-      val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
-      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
       model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
       model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
     }
 
     def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+      implicit val formats = DefaultFormats
       val sqlContext = new SQLContext(sc)
       val (className, formatVersion, metadata) = loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
-      val rank = metadata.select("rank").first().getInt(0)
+      val rank = (metadata \ "rank").extract[Int]
       val userFeatures = sqlContext.parquetFile(userPath(path))
         .map { case Row(id: Int, features: Seq[Double]) =>
           (id, features.toArray)

mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala

@@ -58,7 +58,7 @@ object LassoModel extends Loader[LassoModel] {
     val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val numFeatures = RegressionModel.getNumFeatures(metadata)
         val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
         new LassoModel(data.weights, data.intercept)
       case _ => throw new Exception(

mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala

@@ -58,7 +58,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] {
     val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val numFeatures = RegressionModel.getNumFeatures(metadata)
         val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
         new LinearRegressionModel(data.weights, data.intercept)
       case _ => throw new Exception(

mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala

@@ -17,12 +17,12 @@
 
 package org.apache.spark.mllib.regression
 
+import org.json4s.{DefaultFormats, JValue}
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.util.Loader
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
 
 @Experimental
 trait RegressionModel extends Serializable {
@@ -55,16 +55,10 @@ private[mllib] object RegressionModel {
 
   /**
    * Helper method for loading GLM regression model metadata.
-   *
-   * @param modelClass String name for model class (used for error messages)
-   * @return numFeatures
    */
-  def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = {
-    metadata.select("numFeatures").take(1)(0) match {
-      case Row(nFeatures: Int) => nFeatures
-      case _ => throw new Exception(s"$modelClass unable to load" +
-        s" numFeatures from metadata: ${Loader.metadataPath(path)}")
-    }
+  def getNumFeatures(metadata: JValue): Int = {
+    implicit val formats = DefaultFormats
+    (metadata \ "numFeatures").extract[Int]
   }
 
 }

mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala

@@ -59,7 +59,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
     val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val numFeatures = RegressionModel.getNumFeatures(metadata)
         val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
         new RidgeRegressionModel(data.weights, data.intercept)
       case _ => throw new Exception(

mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala

@@ -17,6 +17,9 @@
 
 package org.apache.spark.mllib.regression.impl
 
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.Loader
@@ -48,10 +51,10 @@ private[regression] object GLMRegressionModel {
       import sqlContext.implicits._
 
       // Create JSON metadata.
-      val metadataRDD =
-        sc.parallelize(Seq((modelClass, thisFormatVersion, weights.size)), 1)
-          .toDataFrame("class", "version", "numFeatures")
-      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~
+          ("numFeatures" -> weights.size)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       // Create Parquet data.
       val data = Data(weights, intercept)

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

@@ -17,28 +17,24 @@
 
 package org.apache.spark.mllib.tree
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
-
+import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.impl._
-import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
 import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
-import org.apache.spark.SparkContext._
 
-
 /**
  * :: Experimental ::

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

@@ -19,6 +19,10 @@ package org.apache.spark.mllib.tree.model
 
 import scala.collection.mutable
 
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
@@ -184,10 +188,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
      import sqlContext.implicits._
 
       // Create JSON metadata.
-      val metadataRDD = sc.parallelize(
-        Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)), 1)
-        .toDataFrame("class", "version", "algo", "numNodes")
-      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+          ("algo" -> model.algo.toString) ~ ("numNodes" -> model.numNodes)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       // Create Parquet data.
       val nodes = model.topNode.subtreeIterator.toSeq
@@ -269,20 +273,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
   }
 
   override def load(sc: SparkContext, path: String): DecisionTreeModel = {
+    implicit val formats = DefaultFormats
     val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
-    val (algo: String, numNodes: Int) = try {
-      val algo_numNodes = metadata.select("algo", "numNodes").collect()
-      assert(algo_numNodes.length == 1)
-      algo_numNodes(0) match {
-        case Row(a: String, n: Int) => (a, n)
-      }
-    } catch {
-      // Catch both Error and Exception since the checks above can throw either.
-      case e: Throwable =>
-        throw new Exception(
-          s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}."
-          + s" Error message: ${e.getMessage}")
-    }
+    val algo = (metadata \ "algo").extract[String]
+    val numNodes = (metadata \ "numNodes").extract[Int]
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
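
For reference, the load paths above all receive their JValue from Loader.loadMetadata, which after this change only needs to read the one-line metadata file and parse it with json4s. A condensed sketch of that shape (an assumption for illustration, not the verbatim MLlib helper):

import org.json4s.{DefaultFormats, JValue}
import org.json4s.jackson.JsonMethods.parse

// Returns (class name, format version, full metadata) so callers can route
// on (className, version) and extract model-specific fields themselves.
def parseMetadata(firstLine: String): (String, String, JValue) = {
  implicit val formats = DefaultFormats
  val metadata = parse(firstLine)
  val clazz = (metadata \ "class").extract[String]
  val version = (metadata \ "version").extract[String]
  (clazz, version, metadata)
}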