From f893955b9cc6ea456fc5845890893c08d8878481 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 16 Jul 2015 21:41:36 -0700 Subject: [PATCH 01/58] [SPARK-8899] [SQL] remove duplicated equals method for Row Author: Wenchen Fan Closes #7291 from cloud-fan/row and squashes the following commits: a11addf [Wenchen Fan] move hashCode back to internal row 2de6180 [Wenchen Fan] making apply() call to get() fbe1b24 [Wenchen Fan] add null check ebdf148 [Wenchen Fan] address comments 25ef087 [Wenchen Fan] remove duplicated equals method for Row --- .../sql/catalyst/expressions/UnsafeRow.java | 5 --- .../main/scala/org/apache/spark/sql/Row.scala | 44 +++++++++++++++++-- .../spark/sql/catalyst/InternalRow.scala | 37 +--------------- .../spark/sql/catalyst/expressions/Cast.scala | 1 - .../sql/catalyst/expressions/Projection.scala | 12 ++--- .../expressions/SpecificMutableRow.scala | 2 +- .../spark/sql/catalyst/expressions/rows.scala | 23 +--------- 7 files changed, 50 insertions(+), 74 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 4b99030d1046f..87294a0e21441 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -246,11 +246,6 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } - @Override - public int size() { - return numFields; - } - /** * Returns the object for column `i`, which should not be primitive type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 0f2fd6a86d177..5f0592dc1d77b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType @@ -151,7 +152,7 @@ trait Row extends Serializable { * StructType -> org.apache.spark.sql.Row * }}} */ - def apply(i: Int): Any + def apply(i: Int): Any = get(i) /** * Returns the value at position i. If the value is null, null is returned. The following @@ -176,10 +177,10 @@ trait Row extends Serializable { * StructType -> org.apache.spark.sql.Row * }}} */ - def get(i: Int): Any = apply(i) + def get(i: Int): Any /** Checks whether the value at position i is null. */ - def isNullAt(i: Int): Boolean = apply(i) == null + def isNullAt(i: Int): Boolean = get(i) == null /** * Returns the value at position i as a primitive boolean. @@ -311,7 +312,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getAs[T](i: Int): T = apply(i).asInstanceOf[T] + def getAs[T](i: Int): T = get(i).asInstanceOf[T] /** * Returns the value of a given fieldName. 
@@ -363,6 +364,41 @@ trait Row extends Serializable { false } + protected def canEqual(other: Any) = + other.isInstanceOf[Row] && !other.isInstanceOf[InternalRow] + + override def equals(o: Any): Boolean = { + if (o == null || !canEqual(o)) return false + + val other = o.asInstanceOf[Row] + if (length != other.length) { + return false + } + + var i = 0 + while (i < length) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = get(i) + val o2 = other.get(i) + if (o1.isInstanceOf[Array[Byte]]) { + // handle equality of Array[Byte] + val b1 = o1.asInstanceOf[Array[Byte]] + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + } else if (o1 != o2) { + return false + } + } + i += 1 + } + return true + } + /* ---------------------- utility methods for Scala ---------------------- */ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 57de0f26a9720..e2fafb88ee43e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -53,41 +53,8 @@ abstract class InternalRow extends Row { // A default implementation to change the return type override def copy(): InternalRow = this - override def apply(i: Int): Any = get(i) - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[Row]) { - return false - } - - val other = o.asInstanceOf[Row] - if (length != other.length) { - return false - } - - var i = 0 - while (i < length) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = apply(i) - val o2 = other.apply(i) - if (o1.isInstanceOf[Array[Byte]]) { - // handle equality of Array[Byte] - val b1 = o1.asInstanceOf[Array[Byte]] - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - } else if (o1 != o2) { - return false - } - } - i += 1 - } - true - } + protected override def canEqual(other: Any) = other.isInstanceOf[InternalRow] // Custom hashCode function that matches the efficient code generated version. 
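  // Each non-null field is hashed as a primitive Int (e.g. Boolean -> 0 or 1,
  // Byte/Short -> toInt) so the interpreted hash stays consistent with the generated code.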
override def hashCode: Int = { @@ -98,7 +65,7 @@ abstract class InternalRow extends Row { if (isNullAt(i)) { 0 } else { - apply(i) match { + get(i) match { case b: Boolean => if (b) 0 else 1 case b: Byte => b.toInt case s: Short => s.toInt diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 83d5b3b76b0a3..65ae87fe6d166 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -56,7 +56,6 @@ object Cast { case (_, DateType) => true case (StringType, IntervalType) => true - case (IntervalType, StringType) => true case (StringType, _: NumericType) => true case (BooleanType, _: NumericType) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 886a486bf5ee0..bf47a6c75b809 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -110,7 +110,7 @@ class JoinedRow extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -204,7 +204,7 @@ class JoinedRow2 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -292,7 +292,7 @@ class JoinedRow3 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -380,7 +380,7 @@ class JoinedRow4 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -468,7 +468,7 @@ class JoinedRow5 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -556,7 +556,7 @@ class JoinedRow6 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index efa24710a5a67..6f291d2c86c1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -219,7 +219,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).isNull = true } - override def apply(i: Int): 
Any = values(i).boxed + override def get(i: Int): Any = values(i).boxed override def isNullAt(i: Int): Boolean = values(i).isNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 094904bbf9c15..d78be5a5958f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -66,7 +66,7 @@ trait ArrayBackedRow { def length: Int = values.length - override def apply(i: Int): Any = values(i) + override def get(i: Int): Any = values(i) def setNullAt(i: Int): Unit = { values(i) = null} @@ -84,27 +84,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBa def this(size: Int) = this(new Array[Any](size)) - // This is used by test or outside - override def equals(o: Any): Boolean = o match { - case other: Row if other.length == length => - var i = 0 - while (i < length) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - val equal = (apply(i), other.apply(i)) match { - case (a: Array[Byte], b: Array[Byte]) => java.util.Arrays.equals(a, b) - case (a, b) => a == b - } - if (!equal) { - return false - } - i += 1 - } - true - case _ => false - } - override def copy(): Row = this } From 322d286bb7773389ed07df96290e427b21c775bd Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 16 Jul 2015 22:26:59 -0700 Subject: [PATCH 02/58] [SPARK-7131] [ML] Copy Decision Tree, Random Forest impl to spark.ml This PR copies the RandomForest implementation from spark.mllib to spark.ml. Note that this includes the DecisionTree implementation, but not the GradientBoostedTrees one (which will come later). I essentially copied a minimal amount of code to spark.ml, removed the use of bins (and only used splits), and modified code only as much as necessary to get it to compile. The spark.ml implementation still uses some spark.mllib classes (privately), which can be moved in future PRs. This refactoring will be helpful in extending the node representation to include more information, such as class probabilities. Specifically: * Copied code from spark.mllib to spark.ml: * mllib.tree.DecisionTree, mllib.tree.RandomForest copied to ml.tree.impl.RandomForest (main implementation) * NodeIdCache (needed to use splits instead of bins) * TreePoint (use splits instead of bins) * Added ml.tree.LearningNode used in RandomForest training (needed vars) * Removed bins from implementation, and only used splits * Small fix in JavaDecisionTreeRegressorSuite CC: mengxr manishamde codedeft chouqin Author: Joseph K. Bradley Closes #7294 from jkbradley/dt-move-impl and squashes the following commits: 48749be [Joseph K. Bradley] cleanups based on code review, mostly style bea9703 [Joseph K. Bradley] scala style fixes. added some scala doc 4e6d2a4 [Joseph K. Bradley] removed unnecessary use of copyValues, setParent for trees 9a4d721 [Joseph K. Bradley] cleanups. removed InfoGainStats from ml, using old one for now. 836e7d4 [Joseph K. Bradley] Fixed test suite failures bd5e063 [Joseph K. Bradley] fixed bucketizing issue 0df3759 [Joseph K. Bradley] Need to remove use of Bucketizer d5224a9 [Joseph K. Bradley] modified tree and forest to use moved impl cc01823 [Joseph K. Bradley] still editing RF to get it to work 19143fb [Joseph K. Bradley] More progress, but not done yet. Rebased with master after 1.4 release. 
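The first patch above centralizes Row equality: comparison is length-checked, null-safe, and field-by-field, with Array[Byte] compared by content via java.util.Arrays.equals rather than by reference. A minimal standalone illustration of that contract (the values are made up; assumes spark-catalyst on the classpath):

    import org.apache.spark.sql.Row

    // Two rows holding content-equal byte arrays are equal under Row.equals,
    // even though Array[Byte] itself uses reference equality.
    val a = Row.fromSeq(Seq(1, null, Array[Byte](1, 2, 3)))
    val b = Row.fromSeq(Seq(1, null, Array[Byte](1, 2, 3)))
    assert(a == b)

    // A null field on one side only makes the rows unequal.
    val c = Row.fromSeq(Seq(1, "x", Array[Byte](1, 2, 3)))
    assert(a != c)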
--- .../DecisionTreeClassifier.scala | 13 +- .../RandomForestClassifier.scala | 16 +- .../ml/regression/DecisionTreeRegressor.scala | 13 +- .../ml/regression/RandomForestRegressor.scala | 15 +- .../scala/org/apache/spark/ml/tree/Node.scala | 129 ++ .../org/apache/spark/ml/tree/Split.scala | 30 +- .../spark/ml/tree/impl/NodeIdCache.scala | 194 +++ .../spark/ml/tree/impl/RandomForest.scala | 1132 +++++++++++++++++ .../apache/spark/ml/tree/impl/TreePoint.scala | 134 ++ .../spark/mllib/tree/impl/BaggedPoint.scala | 10 +- .../mllib/tree/impl/DTStatsAggregator.scala | 2 +- .../tree/impl/DecisionTreeMetadata.scala | 4 +- .../spark/mllib/tree/impl/NodeIdCache.scala | 4 +- .../spark/mllib/tree/impl/TimeTracker.scala | 2 +- .../spark/mllib/tree/impl/TreePoint.scala | 4 +- .../spark/mllib/tree/impurity/Impurity.scala | 4 +- .../tree/model/InformationGainStats.scala | 2 +- .../JavaDecisionTreeRegressorSuite.java | 2 +- 18 files changed, 1678 insertions(+), 32 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 2dc1824964a42..36fe1bd40469c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} +import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD @@ -75,8 +75,9 @@ final class DecisionTreeClassifier(override val uid: String) } val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures, numClasses) - val oldModel = OldDecisionTree.train(oldDataset, strategy) - DecisionTreeClassificationModel.fromOld(oldModel, this, categoricalFeatures) + val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = 0L, parentUID = Some(uid)) + trees.head.asInstanceOf[DecisionTreeClassificationModel] } /** (private[ml]) Create a Strategy instance to use with the old API. */ @@ -112,6 +113,12 @@ final class DecisionTreeClassificationModel private[ml] ( require(rootNode != null, "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + /** + * Construct a decision tree classification model. + * @param rootNode Root node of tree, with other nodes attached. 
+ */ + def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode) + override protected def predict(features: Vector): Double = { rootNode.predict(features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index d3c67494a31e4..490f04c7c7172 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -20,13 +20,13 @@ package org.apache.spark.ml.classification import scala.collection.mutable import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD @@ -93,9 +93,10 @@ final class RandomForestClassifier(override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) - val oldModel = OldRandomForest.trainClassifier( - oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) - RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures) + val trees = + RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) + .map(_.asInstanceOf[DecisionTreeClassificationModel]) + new RandomForestClassificationModel(trees) } override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) @@ -128,6 +129,13 @@ final class RandomForestClassificationModel private[ml] ( require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") + /** + * Construct a random forest classification model, with all trees weighted equally. + * @param trees Component trees + */ + def this(trees: Array[DecisionTreeClassificationModel]) = + this(Identifiable.randomUID("rfc"), trees) + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. 
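With the auxiliary constructors added above, a forest model can be assembled directly from its component trees. A minimal sketch (the buildForest helper and its rootNodes input are hypothetical stand-ins for roots produced elsewhere, e.g. by RandomForest.run; these constructors are only reachable from Spark-internal packages):

    import org.apache.spark.ml.classification.DecisionTreeClassificationModel
    import org.apache.spark.ml.classification.RandomForestClassificationModel
    import org.apache.spark.ml.tree.Node

    // Each tree receives a random "dtc" UID from the new auxiliary constructor;
    // the forest wraps the trees with equal weights under a random "rfc" UID.
    def buildForest(rootNodes: Array[Node]): RandomForestClassificationModel = {
      val trees = rootNodes.map(root => new DecisionTreeClassificationModel(root))
      new RandomForestClassificationModel(trees)
    }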
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index be1f8063d41d8..6f3340c2f02be 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
+import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
@@ -67,8 +67,9 @@ final class DecisionTreeRegressor(override val uid: String)
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val strategy = getOldStrategy(categoricalFeatures)
-    val oldModel = OldDecisionTree.train(oldDataset, strategy)
-    DecisionTreeRegressionModel.fromOld(oldModel, this, categoricalFeatures)
+    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 0L, parentUID = Some(uid))
+    trees.head.asInstanceOf[DecisionTreeRegressionModel]
   }

   /** (private[ml]) Create a Strategy instance to use with the old API. */
@@ -102,6 +103,12 @@ final class DecisionTreeRegressionModel private[ml] (
   require(rootNode != null,
     "DecisionTreeRegressionModel given null rootNode, but it requires a non-null rootNode.")

+  /**
+   * Construct a decision tree regression model.
+   * @param rootNode Root node of tree, with other nodes attached.
+ */ + def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) + override protected def predict(features: Vector): Double = { rootNode.predict(features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 21c59061a02fa..5fd5c7c7bd3fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD @@ -82,9 +82,10 @@ final class RandomForestRegressor(override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) - val oldModel = OldRandomForest.trainRegressor( - oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) - RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures) + val trees = + RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) + .map(_.asInstanceOf[DecisionTreeRegressionModel]) + new RandomForestRegressionModel(trees) } override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra) @@ -115,6 +116,12 @@ final class RandomForestRegressionModel private[ml] ( require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.") + /** + * Construct a random forest regression model, with all trees weighted equally. + * @param trees Component trees + */ + def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees) + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 4242154be14ce..bbc2427ca7d3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -209,3 +209,132 @@ private object InternalNode { } } } + +/** + * Version of a node used in learning. This uses vars so that we can modify nodes as we split the + * tree by adding children, etc. + * + * For now, we use node IDs. These will be kept internal since we hope to remove node IDs + * in the future, or at least change the indexing (so that we can support much deeper trees). 
+ *
+ * This node can either be:
+ *  - a leaf node, with leftChild, rightChild, split set to None, or
+ *  - an internal node, with all values set
+ *
+ * @param id We currently use the same indexing as the old implementation in
+ *           [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
+ * @param predictionStats Predicted label + class probability (for classification).
+ *                        We will later modify this to store aggregate statistics for labels
+ *                        to provide all class probabilities (for classification) and maybe a
+ *                        distribution (for regression).
+ * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree,
+ *               so that we do not need to consider splitting it further.
+ * @param stats Old structure for storing stats about information gain, prediction, etc.
+ *              This is legacy and will be modified in the future.
+ */
+private[tree] class LearningNode(
+    var id: Int,
+    var predictionStats: OldPredict,
+    var impurity: Double,
+    var leftChild: Option[LearningNode],
+    var rightChild: Option[LearningNode],
+    var split: Option[Split],
+    var isLeaf: Boolean,
+    var stats: Option[OldInformationGainStats]) extends Serializable {
+
+  /**
+   * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
+   */
+  def toNode: Node = {
+    if (leftChild.nonEmpty) {
+      assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
+        "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
+      new InternalNode(predictionStats.predict, impurity, stats.get.gain,
+        leftChild.get.toNode, rightChild.get.toNode, split.get)
+    } else {
+      new LeafNode(predictionStats.predict, impurity)
+    }
+  }
+
+}
+
+private[tree] object LearningNode {
+
+  /** Create a node with some of its fields set. */
+  def apply(
+      id: Int,
+      predictionStats: OldPredict,
+      impurity: Double,
+      isLeaf: Boolean): LearningNode = {
+    // Pass isLeaf through (rather than hard-coding false) so that callers such as
+    // findBestSplits actually mark pre-determined leaves.
+    new LearningNode(id, predictionStats, impurity, None, None, None, isLeaf, None)
+  }
+
+  /** Create an empty node with the given node index. Values must be set later on. */
+  def emptyNode(nodeIndex: Int): LearningNode = {
+    new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
+      None, None, None, false, None)
+  }
+
+  // The below indexing methods were copied from spark.mllib.tree.model.Node
+
+  /**
+   * Return the index of the left child of this node.
+   */
+  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
+
+  /**
+   * Return the index of the right child of this node.
+   */
+  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
+
+  /**
+   * Get the parent index of the given node, or 0 if it is the root.
+   */
+  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
+
+  /**
+   * Return the level of a tree which the given node is in.
+   */
+  def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) {
+    throw new IllegalArgumentException("0 is not a valid node index.")
+  } else {
+    java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex))
+  }
+
+  /**
+   * Returns true if this is a left child.
+   * Note: Returns false for the root.
+   */
+  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
+
+  /**
+   * Return the maximum number of nodes which can be in the given level of the tree.
+   * @param level Level of tree (0 = root).
+   */
+  def maxNodesInLevel(level: Int): Int = 1 << level
+
+  /**
+   * Return the index of the first node in the given level.
+   * @param level Level of tree (0 = root).
+   */
+  def startIndexInLevel(level: Int): Int = 1 << level
+
+  /**
+   * Traces down from a root node to get the node with the given node index.
+   * This assumes the node exists.
+   */
+  def getNode(nodeIndex: Int, rootNode: LearningNode): LearningNode = {
+    var tmpNode: LearningNode = rootNode
+    var levelsToGo = indexToLevel(nodeIndex)
+    while (levelsToGo > 0) {
+      if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
+        // leftChild/rightChild are Options, so unwrap with .get;
+        // casting the Option itself would fail at runtime.
+        tmpNode = tmpNode.leftChild.get
+      } else {
+        tmpNode = tmpNode.rightChild.get
+      }
+      levelsToGo -= 1
+    }
+    tmpNode
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 7acdeeee72d23..78199cc2df582 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -34,9 +34,19 @@ sealed trait Split extends Serializable {
   /** Index of feature which this split tests */
   def featureIndex: Int

-  /** Return true (split to left) or false (split to right) */
+  /**
+   * Return true (split to left) or false (split to right).
+   * @param features Vector of features (original values, not binned).
+   */
   private[ml] def shouldGoLeft(features: Vector): Boolean

+  /**
+   * Return true (split to left) or false (split to right).
+   * @param binnedFeature Binned feature value.
+   * @param splits All splits for the given feature.
+   */
+  private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean
+
   /** Convert to old Split format */
   private[tree] def toOld: OldSplit
 }
@@ -94,6 +104,14 @@ final class CategoricalSplit private[ml] (
     }
   }

+  override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
+    if (isLeft) {
+      categories.contains(binnedFeature.toDouble)
+    } else {
+      !categories.contains(binnedFeature.toDouble)
+    }
+  }
+
   override def equals(o: Any): Boolean = {
     o match {
       case other: CategoricalSplit => featureIndex == other.featureIndex &&
@@ -144,6 +162,16 @@ final class ContinuousSplit private[ml] (override val featureIndex: Int, val thr
     features(featureIndex) <= threshold
   }

+  override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
+    if (binnedFeature == splits.length) {
+      // > last split, so split right
+      false
+    } else {
+      val featureValueUpperBound = splits(binnedFeature).asInstanceOf[ContinuousSplit].threshold
+      featureValueUpperBound <= threshold
+    }
+  }
+
   override def equals(o: Any): Boolean = {
     o match {
       case other: ContinuousSplit =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
new file mode 100644
index 0000000000000..488e8e4fb5dcd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import java.io.IOException + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.tree.{LearningNode, Split} +import org.apache.spark.mllib.tree.impl.BaggedPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** + * This is used by the node id cache to find the child id that a data point would belong to. + * @param split Split information. + * @param nodeIndex The current node index of a data point that this will update. + */ +private[tree] case class NodeIndexUpdater(split: Split, nodeIndex: Int) { + + /** + * Determine a child node index based on the feature value and the split. + * @param binnedFeature Binned feature value. + * @param splits Split information to convert the bin indices to approximate feature values. + * @return Child node index to update to. + */ + def updateNodeIndex(binnedFeature: Int, splits: Array[Split]): Int = { + if (split.shouldGoLeft(binnedFeature, splits)) { + LearningNode.leftChildIndex(nodeIndex) + } else { + LearningNode.rightChildIndex(nodeIndex) + } + } +} + +/** + * Each TreePoint belongs to a particular node per tree. + * Each row in the nodeIdsForInstances RDD is an array over trees of the node index + * in each tree. Initially, values should all be 1 for root node. + * The nodeIdsForInstances RDD needs to be updated at each iteration. + * @param nodeIdsForInstances The initial values in the cache + * (should be an Array of all 1's (meaning the root nodes)). + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + */ +private[spark] class NodeIdCache( + var nodeIdsForInstances: RDD[Array[Int]], + val checkpointInterval: Int) extends Logging { + + // Keep a reference to a previous node Ids for instances. + // Because we will keep on re-persisting updated node Ids, + // we want to unpersist the previous RDD. + private var prevNodeIdsForInstances: RDD[Array[Int]] = null + + // To keep track of the past checkpointed RDDs. + private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() + private var rddUpdateCount = 0 + + // Indicates whether we can checkpoint + private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty + + // FileSystem instance for deleting checkpoints as needed + private val fs = FileSystem.get(nodeIdsForInstances.sparkContext.hadoopConfiguration) + + /** + * Update the node index values in the cache. + * This updates the RDD and its lineage. + * TODO: Passing bin information to executors seems unnecessary and costly. + * @param data The RDD of training rows. + * @param nodeIdUpdaters A map of node index updaters. + * The key is the indices of nodes that we want to update. + * @param splits Split information needed to find child node indices. 
+ */ + def updateNodeIndices( + data: RDD[BaggedPoint[TreePoint]], + nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]], + splits: Array[Array[Split]]): Unit = { + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } + + prevNodeIdsForInstances = nodeIdsForInstances + nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, ids) => + var treeId = 0 + while (treeId < nodeIdUpdaters.length) { + val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(ids(treeId), null) + if (nodeIdUpdater != null) { + val featureIndex = nodeIdUpdater.split.featureIndex + val newNodeIndex = nodeIdUpdater.updateNodeIndex( + binnedFeature = point.datum.binnedFeatures(featureIndex), + splits = splits(featureIndex)) + ids(treeId) = newNodeIndex + } + treeId += 1 + } + ids + } + + // Keep on persisting new ones. + nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK) + rddUpdateCount += 1 + + // Handle checkpointing if the directory is not None. + if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) { + // Let's see if we can delete previous checkpoints. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // We can delete the oldest checkpoint iff + // the next checkpoint actually exists in the file system. + if (checkpointQueue(1).getCheckpointFile.isDefined) { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we'll manually delete it here. + try { + fs.delete(new Path(old.getCheckpointFile.get), true) + } catch { + case e: IOException => + logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" + + s" file: ${old.getCheckpointFile.get}") + } + } else { + canDelete = false + } + } + + nodeIdsForInstances.checkpoint() + checkpointQueue.enqueue(nodeIdsForInstances) + } + } + + /** + * Call this after training is finished to delete any remaining checkpoints. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.nonEmpty) { + val old = checkpointQueue.dequeue() + if (old.getCheckpointFile.isDefined) { + try { + fs.delete(new Path(old.getCheckpointFile.get), true) + } catch { + case e: IOException => + logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" + + s" file: ${old.getCheckpointFile.get}") + } + } + } + } + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } +} + +@DeveloperApi +private[spark] object NodeIdCache { + /** + * Initialize the node Id cache with initial node Id values. + * @param data The RDD of training rows. + * @param numTrees The number of trees that we want to create cache for. + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + * @param initVal The initial values in the cache. + * @return A node Id cache containing an RDD of initial root node Indices. 
+ */ + def init( + data: RDD[BaggedPoint[TreePoint]], + numTrees: Int, + checkpointInterval: Int, + initVal: Int = 1): NodeIdCache = { + new NodeIdCache( + data.map(_ => Array.fill[Int](numTrees)(initVal)), + checkpointInterval) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala new file mode 100644 index 0000000000000..15b56bd844bad --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -0,0 +1,1132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import java.io.IOException + +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.Logging +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, + TimeTracker} +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} + + +private[ml] object RandomForest extends Logging { + + /** + * Train a random forest. + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return an unweighted set of trees + */ + def run( + input: RDD[LabeledPoint], + strategy: OldStrategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Long, + parentUID: Option[String] = None): Array[DecisionTreeModel] = { + + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val retaggedInput = input.retag(classOf[LabeledPoint]) + val metadata = + DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) + logDebug("algo = " + strategy.algo) + logDebug("numTrees = " + numTrees) + logDebug("seed = " + seed) + logDebug("maxBins = " + metadata.maxBins) + logDebug("featureSubsetStrategy = " + featureSubsetStrategy) + logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) + logDebug("subsamplingRate = " + strategy.subsamplingRate) + + // Find the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. 
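+    // For a continuous feature, this produces up to (maxBins - 1) candidate thresholds
+    // drawn from a sample of that feature's values; for a categorical feature, the bins
+    // correspond to the categories themselves (or to subsets of categories when unordered).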
+ timer.start("findSplitsBins") + val splits = findSplits(retaggedInput, metadata) + timer.stop("findSplitsBins") + logDebug("numBins: feature: number of bins") + logDebug(Range(0, metadata.numFeatures).map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + }.mkString("\n")) + + // Bin feature values (TreePoint representation). + // Cache input RDD for speedup during multiple passes. + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata) + + val withReplacement = numTrees > 1 + + val baggedInput = BaggedPoint + .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed) + .persist(StorageLevel.MEMORY_AND_DISK) + + // depth of the decision tree + val maxDepth = strategy.maxDepth + require(maxDepth <= 30, + s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") + + // Max memory usage for aggregates + // TODO: Calculate memory usage more precisely. + val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L + logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") + val maxMemoryPerNode = { + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + // Find numFeaturesPerNode largest bins to get an upper bound on memory usage. + Some(metadata.numBins.zipWithIndex.sortBy(- _._1) + .take(metadata.numFeaturesPerNode).map(_._2)) + } else { + None + } + RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + } + require(maxMemoryPerNode <= maxMemoryUsage, + s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," + + " which is too small for the given features." + + s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}") + + timer.stop("init") + + /* + * The main idea here is to perform group-wise training of the decision tree nodes thus + * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup). + * Each data sample is handled by a particular node (or it reaches a leaf and is not used + * in lower levels). + */ + + // Create an RDD of node Id cache. + // At first, all the rows belong to the root nodes (node Id == 1). + val nodeIdCache = if (strategy.useNodeIdCache) { + Some(NodeIdCache.init( + data = baggedInput, + numTrees = numTrees, + checkpointInterval = strategy.checkpointInterval, + initVal = 1)) + } else { + None + } + + // FIFO queue of nodes to train: (treeIndex, node) + val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + + val rng = new Random() + rng.setSeed(seed) + + // Allocate and queue root nodes. + val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1)) + Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) + + while (nodeQueue.nonEmpty) { + // Collect some nodes to split, and choose features for each node (if subsampling). + // Each group of nodes may come from one or multiple trees, and at multiple levels. + val (nodesForGroup, treeToNodeToIndexInfo) = + RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + // Sanity check (should never occur): + assert(nodesForGroup.nonEmpty, + s"RandomForest selected empty nodesForGroup. Error for unknown reason.") + + // Choose node splits, and enqueue new nodes as needed. 
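+      // Each call below makes one pass over baggedInput: per-partition sufficient
+      // statistics are aggregated for every node in the group, merged via reduceByKey,
+      // and only each node's winning split is brought back to the driver.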
+ timer.start("findBestSplits") + RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, + treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache) + timer.stop("findBestSplits") + } + + baggedInput.unpersist() + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + // Delete any remaining checkpoints used for node Id cache. + if (nodeIdCache.nonEmpty) { + try { + nodeIdCache.get.deleteAllCheckpoints() + } catch { + case e: IOException => + logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") + } + } + + parentUID match { + case Some(uid) => + if (strategy.algo == OldAlgo.Classification) { + topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode)) + } else { + topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode)) + } + case None => + if (strategy.algo == OldAlgo.Classification) { + topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode)) + } else { + topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode)) + } + } + } + + /** + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a leaf + * or unsplit node; that node's index is returned. + * + * @param node Node in tree from which to classify the given data point. + * @param binnedFeatures Binned feature vector for data point. + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. + * Note: This is the global node index, i.e., the index used in the tree. + * This index is different from the index used during training a particular + * group of nodes on one call to [[findBestSplits()]]. + */ + private def predictNodeIndex( + node: LearningNode, + binnedFeatures: Array[Int], + splits: Array[Array[Split]]): Int = { + if (node.isLeaf || node.split.isEmpty) { + node.id + } else { + val split = node.split.get + val featureIndex = split.featureIndex + val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex)) + if (node.leftChild.isEmpty) { + // Not yet split. Return index from next layer of nodes to train + if (splitLeft) { + LearningNode.leftChildIndex(node.id) + } else { + LearningNode.rightChildIndex(node.id) + } + } else { + if (splitLeft) { + predictNodeIndex(node.leftChild.get, binnedFeatures, splits) + } else { + predictNodeIndex(node.rightChild.get, binnedFeatures, splits) + } + } + } + } + + /** + * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. + * + * For ordered features, a single bin is updated. + * For unordered features, bins correspond to subsets of categories; either the left or right bin + * for each subset is updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param splits possible splits indexed (numFeatures)(numSplits) + * @param unorderedFeatures Set of indices of unordered features. + * @param instanceWeight Weight (importance) of instance in dataset. 
+ */ + private def mixedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint, + splits: Array[Array[Split]], + unorderedFeatures: Set[Int], + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { + val numFeaturesPerNode = if (featuresForNode.nonEmpty) { + // Use subsampled features + featuresForNode.get.length + } else { + // Use all features + agg.metadata.numFeatures + } + // Iterate over features. + var featureIndexIdx = 0 + while (featureIndexIdx < numFeaturesPerNode) { + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } + if (unorderedFeatures.contains(featureIndex)) { + // Unordered feature + val featureValue = treePoint.binnedFeatures(featureIndex) + val (leftNodeFeatureOffset, rightNodeFeatureOffset) = + agg.getLeftRightFeatureOffsets(featureIndexIdx) + // Update the left or right bin for each split. + val numSplits = agg.metadata.numSplits(featureIndex) + val featureSplits = splits(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplits) { + if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { + agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) + } else { + agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) + } + splitIndex += 1 + } + } else { + // Ordered feature + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight) + } + featureIndexIdx += 1 + } + } + + /** + * Helper for binSeqOp, for regression and for classification with only ordered features. + * + * For each feature, the sufficient statistics of one bin are updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param instanceWeight Weight (importance) of instance in dataset. + */ + private def orderedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint, + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { + val label = treePoint.label + + // Iterate over features. + if (featuresForNode.nonEmpty) { + // Use subsampled features + var featureIndexIdx = 0 + while (featureIndexIdx < featuresForNode.get.length) { + val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) + agg.update(featureIndexIdx, binIndex, label, instanceWeight) + featureIndexIdx += 1 + } + } else { + // Use all features + val numFeatures = agg.metadata.numFeatures + var featureIndex = 0 + while (featureIndex < numFeatures) { + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.update(featureIndex, binIndex, label, instanceWeight) + featureIndex += 1 + } + } + } + + /** + * Given a group of nodes, this finds the best split for each node. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] + * @param metadata Learning and dataset metadata + * @param topNodes Root node for each tree. Used for matching instances with nodes. + * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree + * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, + * where nodeIndexInfo stores the index in the group and the + * feature subsets (if using feature subsets). + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). 
+ * Updated with new non-leaf nodes which are created. + * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where + * each value in the array is the data point's node Id + * for a corresponding tree. This is used to prevent the need + * to pass the entire tree to the executors during + * the node stat aggregation phase. + */ + private[tree] def findBestSplits( + input: RDD[BaggedPoint[TreePoint]], + metadata: DecisionTreeMetadata, + topNodes: Array[LearningNode], + nodesForGroup: Map[Int, Array[LearningNode]], + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], + splits: Array[Array[Split]], + nodeQueue: mutable.Queue[(Int, LearningNode)], + timer: TimeTracker = new TimeTracker, + nodeIdCache: Option[NodeIdCache] = None): Unit = { + + /* + * The high-level descriptions of the best split optimizations are noted here. + * + * *Group-wise training* + * We perform bin calculations for groups of nodes to reduce the number of + * passes over the data. Each iteration requires more computation and storage, + * but saves several iterations over the data. + * + * *Bin-wise computation* + * We use a bin-wise best split computation strategy instead of a straightforward best split + * computation strategy. Instead of analyzing each sample for contribution to the left/right + * child node impurity of every split, we first categorize each feature of a sample into a + * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates + * to calculate information gain for each split. + * + * *Aggregation over partitions* + * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + * the number of splits in advance. Thus, we store the aggregates (at the appropriate + * indices) in a single array for all bins and rely upon the RDD aggregate method to + * drastically reduce the communication overhead. + */ + + // numNodes: Number of nodes in this group + val numNodes = nodesForGroup.values.map(_.length).sum + logDebug("numNodes = " + numNodes) + logDebug("numFeatures = " + metadata.numFeatures) + logDebug("numClasses = " + metadata.numClasses) + logDebug("isMulticlass = " + metadata.isMulticlass) + logDebug("isMulticlassWithCategoricalFeatures = " + + metadata.isMulticlassWithCategoricalFeatures) + logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString) + + /** + * Performs a sequential aggregation over a partition for a particular tree and node. + * + * For each feature, the aggregate sufficient statistics are updated for the relevant + * bins. + * + * @param treeIndex Index of the tree that we want to perform aggregation for. + * @param nodeInfo The node info for the tree node. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics + * for each (node, feature, bin). + * @param baggedPoint Data point being aggregated. 
+ */ + def nodeBinSeqOp( + treeIndex: Int, + nodeInfo: NodeIndexInfo, + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Unit = { + if (nodeInfo != null) { + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val instanceWeight = baggedPoint.subsampleWeights(treeIndex) + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) + } else { + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, + metadata.unorderedFeatures, instanceWeight, featuresForNode) + } + } + } + + /** + * Performs a sequential aggregation over a partition. + * + * Each data point contributes to one node. For each feature, + * the aggregate sufficient statistics are updated for the relevant bins. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + * @return agg + */ + def binSeqOp( + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val nodeIndex = + predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + } + agg + } + + /** + * Do the same thing as binSeqOp, but with nodeIdCache. + */ + def binSeqOpWithNodeIdCache( + agg: Array[DTStatsAggregator], + dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + } + + agg + } + + /** + * Get node index in group --> features indices map, + * which is a short cut to find feature indices for a node given node index in group. + */ + def getNodeToFeatures( + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = { + if (!metadata.subsamplingFeatures) { + None + } else { + val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]() + treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo => + nodeIdToNodeInfo.values.foreach { nodeIndexInfo => + assert(nodeIndexInfo.featureSubset.isDefined) + mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get + } + } + Some(mutableNodeToFeatures.toMap) + } + } + + // array of nodes to train indexed by node index in group + val nodes = new Array[LearningNode](numNodes) + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } + } + + // Calculate best splits for all nodes in the group + timer.start("chooseSplits") + + // In each partition, iterate all instances and compute aggregate stats for each node, + // yield an (nodeIndex, nodeAggregateStats) pair for each node. + // After a `reduceByKey` operation, + // stats of a node will be shuffled to a particular partition and be combined together, + // then best splits for nodes are found there. + // Finally, only best Splits for nodes are collected to driver to construct decision tree. 
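+    // Broadcast the per-node feature subsets once, so each partition can build
+    // correctly sized DTStatsAggregators without shipping the map in every task closure.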
+ val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) + val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) + + val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + } + } else { + input.mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOp(nodeStatsAggregators, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + } + } + + val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map { + case (nodeIndex, aggStats) => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + + // find best split for each node + val (split: Split, stats: InformationGainStats, predict: Predict) = + binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) + (nodeIndex, (split, stats, predict)) + }.collectAsMap() + + timer.stop("chooseSplits") + + val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { + Array.fill[mutable.Map[Int, NodeIndexUpdater]]( + metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) + } else { + null + } + // Iterate over all nodes in this group. + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + val nodeIndex = node.id + val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val (split: Split, stats: InformationGainStats, predict: Predict) = + nodeToBestSplits(aggNodeIndex) + logDebug("best split = " + split) + + // Extract info for this node. Create children if not leaf. 
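+        // A node is finalized as a leaf when its best split yields no positive gain or the
+        // node already sits at maxDepth; children are pre-marked as leaves when they would
+        // sit at maxDepth or are already pure (left/right impurity == 0).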
+ val isLeaf = + (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) + node.predictionStats = predict + node.isLeaf = isLeaf + node.stats = Some(stats) + node.impurity = stats.impurity + logDebug("Node = " + node) + + if (!isLeaf) { + node.split = Some(split) + val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) + val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), + stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), + stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + + if (nodeIdCache.nonEmpty) { + val nodeIndexUpdater = NodeIndexUpdater( + split = split, + nodeIndex = nodeIndex) + nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater) + } + + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.leftChild.get)) + } + if (!rightChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.rightChild.get)) + } + + logDebug("leftChildIndex = " + node.leftChild.get.id + + ", impurity = " + stats.leftImpurity) + logDebug("rightChildIndex = " + node.rightChild.get.id + + ", impurity = " + stats.rightImpurity) + } + } + } + + if (nodeIdCache.nonEmpty) { + // Update the cache if needed. + nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, splits) + } + } + + /** + * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * @param leftImpurityCalculator left node aggregates for this (feature, split) + * @param rightImpurityCalculator right node aggregate for this (feature, split) + * @return information gain and statistics for split + */ + private def calculateGainForSplit( + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator, + metadata: DecisionTreeMetadata, + impurity: Double): InformationGainStats = { + val leftCount = leftImpurityCalculator.count + val rightCount = rightImpurityCalculator.count + + // If left child or right child doesn't satisfy minimum instances per node, + // then this split is invalid, return invalid information gain stats. + if ((leftCount < metadata.minInstancesPerNode) || + (rightCount < metadata.minInstancesPerNode)) { + return InformationGainStats.invalidInformationGainStats + } + + val totalCount = leftCount + rightCount + + val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 + val rightImpurity = rightImpurityCalculator.calculate() + + val leftWeight = leftCount / totalCount.toDouble + val rightWeight = rightCount / totalCount.toDouble + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + + // if information gain doesn't satisfy minimum information gain, + // then this split is invalid, return invalid information gain stats. 
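
For intuition, here is the weighted gain computed above worked through on hypothetical Gini counts (40/40 labels at the parent, split into 30/10 and 10/30 children):

    def gini(pos: Double, neg: Double): Double = {
      val n = pos + neg
      val (p0, p1) = (pos / n, neg / n)
      1.0 - p0 * p0 - p1 * p1
    }
    val parentImpurity = gini(40, 40)  // 0.5
    val leftImpurity = gini(30, 10)    // 0.375
    val rightImpurity = gini(10, 30)   // 0.375
    // gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
    val gain = parentImpurity - (40.0 / 80) * leftImpurity - (40.0 / 80) * rightImpurity  // 0.125
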
+ if (gain < metadata.minInfoGain) { + return InformationGainStats.invalidInformationGainStats + } + + // calculate left and right predict + val leftPredict = calculatePredict(leftImpurityCalculator) + val rightPredict = calculatePredict(rightImpurityCalculator) + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, + leftPredict, rightPredict) + } + + private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { + val predict = impurityCalculator.predict + val prob = impurityCalculator.prob(predict) + new Predict(predict, prob) + } + + /** + * Calculate predict value for current node, given stats of any split. + * Note that this function is called only once for each node. + * @param leftImpurityCalculator left node aggregates for a split + * @param rightImpurityCalculator right node aggregates for a split + * @return predict value and impurity for current node + */ + private def calculatePredictImpurity( + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { + val parentNodeAgg = leftImpurityCalculator.copy + parentNodeAgg.add(rightImpurityCalculator) + val predict = calculatePredict(parentNodeAgg) + val impurity = parentNodeAgg.calculate() + + (predict, impurity) + } + + /** + * Find the best split for a node. + * @param binAggregates Bin statistics. + * @return tuple for best split: (Split, information gain, prediction at node) + */ + private def binsToBestSplit( + binAggregates: DTStatsAggregator, + splits: Array[Array[Split]], + featuresForNode: Option[Array[Int]], + node: LearningNode): (Split, InformationGainStats, Predict) = { + + // Calculate prediction and impurity if current node is top node + val level = LearningNode.indexToLevel(node.id) + var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) { + None + } else { + Some((node.predictionStats, node.impurity)) + } + + // For each (feature, split), calculate the gain, and select the best (feature, split). + val (bestSplit, bestSplitStats) = + Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } + val numSplits = binAggregates.metadata.numSplits(featureIndex) + if (binAggregates.metadata.isContinuous(featureIndex)) { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + splitIndex += 1 + } + // Find best split. 
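
The cumulative-sum trick just described can be traced on toy per-bin counts: after the in-place merges, bin i holds the totals for bins 0..i, so each candidate split costs a single subtraction instead of a rescan:

    val binCounts = Array(5.0, 3.0, 7.0, 5.0)  // one (toy) statistic per bin
    for (i <- 1 until binCounts.length) binCounts(i) += binCounts(i - 1)
    // binCounts is now Array(5.0, 8.0, 15.0, 20.0); the last entry is the node total.
    val leftRightPerSplit = (0 until binCounts.length - 1).map { s =>
      (s, binCounts(s), binCounts.last - binCounts(s))  // (split, left, right)
    }
    // Vector((0,5.0,15.0), (1,8.0,12.0), (2,15.0,5.0))
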
+ val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { case splitIdx => + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + predictionAndImpurity = Some(predictionAndImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) + (splitIdx, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else if (binAggregates.metadata.isUnordered(featureIndex)) { + // Unordered categorical feature + val (leftChildOffset, rightChildOffset) = + binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = + binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + predictionAndImpurity = Some(predictionAndImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) + (splitIndex, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else { + // Ordered categorical feature + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numCategories = binAggregates.metadata.numBins(featureIndex) + + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) + * + * centroidForCategories is a list: (category, centroid) + */ + val centroidForCategories = if (binAggregates.metadata.isMulticlass) { + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + Range(0, numCategories).map { case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.calculate() + } else { + Double.MaxValue + } + (featureValue, centroid) + } + } else { // regression or binary classification + // For categorical variables in regression and binary classification, + // the bins are ordered by the centroid of their corresponding labels. + Range(0, numCategories).map { case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.predict + } else { + Double.MaxValue + } + (featureValue, centroid) + } + } + + logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) + + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) + + logDebug("Sorted centroids for categorical variable = " + + categoriesSortedByCentroid.mkString(",")) + + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. 
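
As a concrete instance of the centroid ordering described above, take an ordered categorical feature in a regression problem with hypothetical per-category label sums and counts; categories are ranked by their mean label, and only contiguous prefixes of that ranking become candidate splits:

    val labelSumsAndCounts = Map(0 -> (9.0, 3L), 1 -> (2.0, 4L), 2 -> (7.0, 2L))
    val centroidForCategories =
      labelSumsAndCounts.map { case (cat, (sum, cnt)) => (cat, sum / cnt) }
    val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
    // List((1,0.5), (0,3.0), (2,3.5)): with K = 3 categories, only K - 1 = 2
    // splits are considered, {1} vs {0,2} and {1,0} vs {2}.
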
+          var splitIndex = 0
+          while (splitIndex < numSplits) {
+            val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+            val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+            binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
+            splitIndex += 1
+          }
+          // lastCategory = index of bin with total aggregates for this (node, feature)
+          val lastCategory = categoriesSortedByCentroid.last._1
+          // Find best split.
+          val (bestFeatureSplitIndex, bestFeatureGainStats) =
+            Range(0, numSplits).map { splitIndex =>
+              val featureValue = categoriesSortedByCentroid(splitIndex)._1
+              val leftChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+              val rightChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+              rightChildStats.subtract(leftChildStats)
+              predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
+              val gainStats = calculateGainForSplit(leftChildStats,
+                rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
+              (splitIndex, gainStats)
+            }.maxBy(_._2.gain)
+          val categoriesForSplit =
+            categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+          val bestFeatureSplit =
+            new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
+          (bestFeatureSplit, bestFeatureGainStats)
+        }
+      }.maxBy(_._2.gain)
+
+    (bestSplit, bestSplitStats, predictionAndImpurity.get._1)
+  }
+
+  /**
+   * Returns splits for decision tree calculation.
+   * Continuous and categorical features are handled differently.
+   *
+   * Continuous features:
+   *   For each feature, there are numBins - 1 possible splits representing the possible binary
+   *   decisions at each node in the tree.
+   *   This finds locations (feature values) for splits using a subsample of the data.
+   *
+   * Categorical features:
+   *   For each feature, there is 1 bin per split.
+   *   Splits are handled in 2 ways:
+   *    (a) "unordered features"
+   *        For multiclass classification with a low-arity feature
+   *        (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
+   *        the feature is split based on subsets of categories.
+   *    (b) "ordered features"
+   *        For regression and binary classification,
+   *        and for multiclass classification with a high-arity feature,
+   *        there is one bin per category.
+   *
+   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param metadata Learning and dataset metadata
+   * @return Splits, an Array of [[org.apache.spark.mllib.tree.model.Split]]
+   *         of size (numFeatures, numSplits).
+   */
+  protected[tree] def findSplits(
+      input: RDD[LabeledPoint],
+      metadata: DecisionTreeMetadata): Array[Array[Split]] = {
+
+    logDebug("isMulticlass = " + metadata.isMulticlass)
+
+    val numFeatures = metadata.numFeatures
+
+    // Sample the input only if there are continuous features.
+    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
+    val sampledInput = if (hasContinuousFeatures) {
+      // Calculate the number of samples for approximate quantile calculation.
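
Plugging hypothetical sizes into the sampling computation that follows shows how small the quantile sample stays relative to a large input:

    val maxBins = 32
    val numExamples = 1000000L
    val requiredSamples = math.max(maxBins * maxBins, 10000)  // 10000
    val fraction =
      if (requiredSamples < numExamples) requiredSamples.toDouble / numExamples else 1.0  // 0.01
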
+      val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+      val fraction = if (requiredSamples < metadata.numExamples) {
+        requiredSamples.toDouble / metadata.numExamples
+      } else {
+        1.0
+      }
+      logDebug("fraction of data used for calculating quantiles = " + fraction)
+      input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
+    } else {
+      new Array[LabeledPoint](0)
+    }
+
+    val splits = new Array[Array[Split]](numFeatures)
+
+    // Find all splits.
+    // Iterate over all features.
+    var featureIndex = 0
+    while (featureIndex < numFeatures) {
+      if (metadata.isContinuous(featureIndex)) {
+        val featureSamples = sampledInput.map(_.features(featureIndex))
+        val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
+
+        val numSplits = featureSplits.length
+        logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
+        splits(featureIndex) = new Array[Split](numSplits)
+
+        var splitIndex = 0
+        while (splitIndex < numSplits) {
+          val threshold = featureSplits(splitIndex)
+          splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold)
+          splitIndex += 1
+        }
+      } else {
+        // Categorical feature
+        if (metadata.isUnordered(featureIndex)) {
+          val numSplits = metadata.numSplits(featureIndex)
+          val featureArity = metadata.featureArity(featureIndex)
+          // TODO: Use an implicit representation mapping each category to a subset of indices.
+          //       I.e., track indices such that we can calculate the set of bins for which
+          //       feature value x splits to the left.
+          // Unordered features
+          // 2^(maxFeatureValue - 1) - 1 combinations
+          splits(featureIndex) = new Array[Split](numSplits)
+          var splitIndex = 0
+          while (splitIndex < numSplits) {
+            val categories: List[Double] =
+              extractMultiClassCategories(splitIndex + 1, featureArity)
+            splits(featureIndex)(splitIndex) =
+              new CategoricalSplit(featureIndex, categories.toArray, featureArity)
+            splitIndex += 1
+          }
+        } else {
+          // Ordered features
+          // Bins correspond to feature values, so we do not need to compute splits or bins
+          // beforehand. Splits are constructed as needed during training.
+          splits(featureIndex) = new Array[Split](0)
+        }
+      }
+      featureIndex += 1
+    }
+    splits
+  }
+
+  /**
+   * Helper method to extract a list of eligible categories given an index. It extracts the
+   * positions of ones in the binary representation of the input. If the binary
+   * representation of a number is 01101 (13), the output list should be (3.0, 2.0, 0.0).
+   * maxFeatureValue depicts the number of rightmost bits that will be tested for ones.
+   */
+  private[tree] def extractMultiClassCategories(
+      input: Int,
+      maxFeatureValue: Int): List[Double] = {
+    var categories = List[Double]()
+    var j = 0
+    var bitShiftedInput = input
+    while (j < maxFeatureValue) {
+      if (bitShiftedInput % 2 != 0) {
+        // update the list of categories
+        categories = j.toDouble :: categories
+      }
+      // Right shift by one
+      bitShiftedInput = bitShiftedInput >> 1
+      j += 1
+    }
+    categories
+  }
+
+  /**
+   * Find splits for a continuous feature.
+   * NOTE: The returned number of splits is based on `featureSamples` and
+   *       could be different from the specified `numSplits`.
+   *       The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+   * @param featureSamples feature values of each sample
+   * @param metadata decision tree metadata
+   *                 NOTE: `metadata.numSplits` will be changed accordingly
+   *                       if there are not enough splits to be found
+   * @param featureIndex feature index to find splits
+   * @return array of splits
+   */
+  private[tree] def findSplitsForContinuousFeature(
+      featureSamples: Array[Double],
+      metadata: DecisionTreeMetadata,
+      featureIndex: Int): Array[Double] = {
+    require(metadata.isContinuous(featureIndex),
+      "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+    val splits = {
+      val numSplits = metadata.numSplits(featureIndex)
+
+      // get count for each distinct value
+      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
+        m + ((x, m.getOrElse(x, 0) + 1))
+      }
+      // sort distinct values
+      val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+      // if the number of possible splits is not enough or just enough, return all possible splits
+      val possibleSplits = valueCounts.length
+      if (possibleSplits <= numSplits) {
+        valueCounts.map(_._1)
+      } else {
+        // stride between splits
+        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+        logDebug("stride = " + stride)
+
+        // iterate `valueCounts` to find splits
+        val splitsBuilder = mutable.ArrayBuilder.make[Double]
+        var index = 1
+        // currentCount: sum of counts of values that have been visited
+        var currentCount = valueCounts(0)._2
+        // targetCount: target value for `currentCount`.
+        // If `currentCount` is the closest value to `targetCount`,
+        // then the current value is a split threshold.
+        // After finding a split threshold, `targetCount` is increased by `stride`.
+        var targetCount = stride
+        while (index < valueCounts.length) {
+          val previousCount = currentCount
+          currentCount += valueCounts(index)._2
+          val previousGap = math.abs(previousCount - targetCount)
+          val currentGap = math.abs(currentCount - targetCount)
+          // If adding the count of the current value to `currentCount`
+          // makes the gap between `currentCount` and `targetCount` smaller,
+          // the previous value is a split threshold.
+          if (previousGap < currentGap) {
+            splitsBuilder += valueCounts(index - 1)._1
+            targetCount += stride
+          }
+          index += 1
+        }
+
+        splitsBuilder.result()
+      }
+    }
+
+    // TODO: Do not fail; just ignore the useless feature.
+    assert(splits.length > 0,
+      s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
+        " Please remove this feature and then try again.")
+    // set number of splits accordingly
+    metadata.setNumSplits(featureIndex, splits.length)
+
+    splits
+  }
+
+  private[tree] class NodeIndexInfo(
+      val nodeIndexInGroup: Int,
+      val featureSubset: Option[Array[Int]]) extends Serializable
+
+  /**
+   * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
+   * This tracks the memory usage for aggregates and stops adding nodes when too much memory
+   * will be needed; this allows an adaptive number of nodes since different nodes may require
+   * different amounts of memory (if featureSubsetStrategy is not "all").
+   *
+   * @param nodeQueue Queue of nodes to split.
+   * @param maxMemoryUsage Bound on size of aggregate statistics.
+   * @return (nodesForGroup, treeToNodeToIndexInfo).
+   *         nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+   *
+   *         treeToNodeToIndexInfo holds indices of selected features for each node:
+   *         treeIndex --> (global) node index --> (node index in group, feature indices).
+ * The (global) node index is the index in the tree; the node index in group is the + * index in [0, numNodesInGroup) of the node in this group. + * The feature indices are None if not subsampling features. + */ + private[tree] def selectNodesToSplit( + nodeQueue: mutable.Queue[(Int, LearningNode)], + maxMemoryUsage: Long, + metadata: DecisionTreeMetadata, + rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = { + // Collect some nodes to split: + // nodesForGroup(treeIndex) = nodes to split + val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]() + val mutableTreeToNodeToIndexInfo = + new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]() + var memUsage: Long = 0L + var numNodesInGroup = 0 + while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) { + val (treeIndex, node) = nodeQueue.head + // Choose subset of features for node (if subsampling). + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + Some(SamplingUtils.reservoirSampleAndCount(Range(0, + metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1) + } else { + None + } + // Check if enough memory remains to add this node to the group. + val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + if (memUsage + nodeMemUsage <= maxMemoryUsage) { + nodeQueue.dequeue() + mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) += + node + mutableTreeToNodeToIndexInfo + .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) + = new NodeIndexInfo(numNodesInGroup, featureSubset) + } + numNodesInGroup += 1 + memUsage += nodeMemUsage + } + // Convert mutable maps to immutable ones. + val nodesForGroup: Map[Int, Array[LearningNode]] = + mutableNodesForGroup.mapValues(_.toArray).toMap + val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap + (nodesForGroup, treeToNodeToIndexInfo) + } + + /** + * Get the number of values to be stored for this node in the bin aggregates. + * @param featureSubset Indices of features which may be split at this node. + * If None, then use all features. + */ + private def aggregateSizeForNode( + metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]): Long = { + val totalBins = if (featureSubset.nonEmpty) { + featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum + } else { + metadata.numBins.map(_.toLong).sum + } + if (metadata.isClassification) { + metadata.numClasses * totalBins + } else { + 3 * totalBins + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala new file mode 100644 index 0000000000000..9fa27e5e1f721 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.ml.tree.{ContinuousSplit, Split}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * Internal representation of LabeledPoint for DecisionTree.
+ * This bins feature values based on a subsample of the data as follows:
+ *  (a) Continuous features are binned into ranges.
+ *  (b) Unordered categorical features are binned based on subsets of feature values.
+ *      "Unordered categorical features" are categorical features with low arity used in
+ *      multiclass classification.
+ *  (c) Ordered categorical features are binned based on feature values.
+ *      "Ordered categorical features" are categorical features with high arity,
+ *      or any categorical feature used in regression or binary classification.
+ *
+ * @param label Label from LabeledPoint
+ * @param binnedFeatures Binned feature values.
+ *                       Same length as LabeledPoint.features, but values are bin indices.
+ */
+private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
+  extends Serializable {
+}
+
+private[spark] object TreePoint {
+
+  /**
+   * Convert an input dataset into its TreePoint representation,
+   * binning feature values in preparation for DecisionTree training.
+   * @param input Input dataset.
+   * @param splits Splits for features, of size (numFeatures, numSplits).
+   * @param metadata Learning and dataset metadata
+   * @return TreePoint dataset representation
+   */
+  def convertToTreeRDD(
+      input: RDD[LabeledPoint],
+      splits: Array[Array[Split]],
+      metadata: DecisionTreeMetadata): RDD[TreePoint] = {
+    // Construct arrays for featureArity for efficiency in the inner loop.
+    val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
+    var featureIndex = 0
+    while (featureIndex < metadata.numFeatures) {
+      featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
+      featureIndex += 1
+    }
+    val thresholds: Array[Array[Double]] = featureArity.zipWithIndex.map { case (arity, idx) =>
+      if (arity == 0) {
+        splits(idx).map(_.asInstanceOf[ContinuousSplit].threshold)
+      } else {
+        Array.empty[Double]
+      }
+    }
+    input.map { x =>
+      TreePoint.labeledPointToTreePoint(x, thresholds, featureArity)
+    }
+  }
+
+  /**
+   * Convert one LabeledPoint into its TreePoint representation.
+   * @param thresholds For each feature, split thresholds for continuous features,
+   *                   empty for categorical features.
+   * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
+   *                     for categorical features.
+ */ + private def labeledPointToTreePoint( + labeledPoint: LabeledPoint, + thresholds: Array[Array[Double]], + featureArity: Array[Int]): TreePoint = { + val numFeatures = labeledPoint.features.size + val arr = new Array[Int](numFeatures) + var featureIndex = 0 + while (featureIndex < numFeatures) { + arr(featureIndex) = + findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex)) + featureIndex += 1 + } + new TreePoint(labeledPoint.label, arr) + } + + /** + * Find discretized value for one (labeledPoint, feature). + * + * NOTE: We cannot use Bucketizer since it handles split thresholds differently than the old + * (mllib) tree API. We want to maintain the same behavior as the old tree API. + * + * @param featureArity 0 for continuous features; number of categories for categorical features. + */ + private def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + featureArity: Int, + thresholds: Array[Double]): Int = { + val featureValue = labeledPoint.features(featureIndex) + + if (featureArity == 0) { + val idx = java.util.Arrays.binarySearch(thresholds, featureValue) + if (idx >= 0) { + idx + } else { + -idx - 1 + } + } else { + // Categorical feature bins are indexed by feature values. + if (featureValue < 0 || featureValue >= featureArity) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } + featureValue.toInt + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala index 089010c81ffb6..572815df0bc4a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala @@ -38,10 +38,10 @@ import org.apache.spark.util.random.XORShiftRandom * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted * dataset support, update. (We store subsampleWeights as Double for this future extension.) */ -private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) +private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) extends Serializable -private[tree] object BaggedPoint { +private[spark] object BaggedPoint { /** * Convert an input dataset into its BaggedPoint representation, @@ -60,7 +60,7 @@ private[tree] object BaggedPoint { subsamplingRate: Double, numSubsamples: Int, withReplacement: Boolean, - seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = { + seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = { if (withReplacement) { convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) } else { @@ -76,7 +76,7 @@ private[tree] object BaggedPoint { input: RDD[Datum], subsamplingRate: Double, numSubsamples: Int, - seed: Int): RDD[BaggedPoint[Datum]] = { + seed: Long): RDD[BaggedPoint[Datum]] = { input.mapPartitionsWithIndex { (partitionIndex, instances) => // Use random seed = seed + partitionIndex + 1 to make generation reproducible. 
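
The threshold lookup in `findBin` above leans on `java.util.Arrays.binarySearch` returning `-(insertionPoint) - 1` on a miss, so `-idx - 1` recovers the bin index directly; a small trace with toy thresholds:

    val thresholds = Array(0.5, 1.5, 2.5)
    def bin(value: Double): Int = {
      val idx = java.util.Arrays.binarySearch(thresholds, value)
      if (idx >= 0) idx else -idx - 1
    }
    // bin(0.2) == 0, bin(1.5) == 1 (exact hit), bin(3.0) == 3 (last bin)
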
val rng = new XORShiftRandom @@ -100,7 +100,7 @@ private[tree] object BaggedPoint { input: RDD[Datum], subsample: Double, numSubsamples: Int, - seed: Int): RDD[BaggedPoint[Datum]] = { + seed: Long): RDD[BaggedPoint[Datum]] = { input.mapPartitionsWithIndex { (partitionIndex, instances) => // Use random seed = seed + partitionIndex + 1 to make generation reproducible. val poisson = new PoissonDistribution(subsample) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index ce8825cc03229..7985ed4b4c0fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.impurity._ * and helps with indexing. * This class is abstract to support learning with and without feature subsampling. */ -private[tree] class DTStatsAggregator( +private[spark] class DTStatsAggregator( val metadata: DecisionTreeMetadata, featureSubset: Option[Array[Int]]) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index f73896e37c05e..380291ac22bd3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -37,7 +37,7 @@ import org.apache.spark.rdd.RDD * I.e., the feature takes values in {0, ..., arity - 1}. * @param numBins Number of bins for each feature. */ -private[tree] class DecisionTreeMetadata( +private[spark] class DecisionTreeMetadata( val numFeatures: Int, val numExamples: Long, val numClasses: Int, @@ -94,7 +94,7 @@ private[tree] class DecisionTreeMetadata( } -private[tree] object DecisionTreeMetadata extends Logging { +private[spark] object DecisionTreeMetadata extends Logging { /** * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index bdd0f576b048d..8f9eb24b57b55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -75,7 +75,7 @@ private[tree] case class NodeIndexUpdater( * (how often should the cache be checkpointed.). */ @DeveloperApi -private[tree] class NodeIdCache( +private[spark] class NodeIdCache( var nodeIdsForInstances: RDD[Array[Int]], val checkpointInterval: Int) { @@ -170,7 +170,7 @@ private[tree] class NodeIdCache( } @DeveloperApi -private[tree] object NodeIdCache { +private[spark] object NodeIdCache { /** * Initialize the node Id cache with initial node Id values. * @param data The RDD of training rows. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala index d215d68c4279e..aac84243d5ce1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -25,7 +25,7 @@ import org.apache.spark.annotation.Experimental * Time tracker implementation which holds labeled timers. 
*/ @Experimental -private[tree] class TimeTracker extends Serializable { +private[spark] class TimeTracker extends Serializable { private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index 50b292e71b067..21919d69a38a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -37,11 +37,11 @@ import org.apache.spark.rdd.RDD * @param binnedFeatures Binned feature values. * Same length as LabeledPoint.features, but values are bin indices. */ -private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) +private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) extends Serializable { } -private[tree] object TreePoint { +private[spark] object TreePoint { /** * Convert an input dataset into its TreePoint representation, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 72eb24c49264a..578749d85a4e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -57,7 +57,7 @@ trait Impurity extends Serializable { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param statsSize Length of the vector of sufficient statistics for one bin. */ -private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable { +private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable { /** * Merge the stats from one bin into another. @@ -95,7 +95,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { +private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) { /** * Make a deep copy of this [[ImpurityCalculator]]. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 2d087c967f679..dc9e0f9f51ffb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -67,7 +67,7 @@ class InformationGainStats(
 }
 
-private[tree] object InformationGainStats {
+private[spark] object InformationGainStats {
   /**
    * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
    * denote that the current split doesn't satisfy minimum info gain or
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index 71b041818d7ee..ebe800e749e05 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -57,7 +57,7 @@ public void runDT() {
     JavaRDD<LabeledPoint> data = sc.parallelize(
       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
-    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
 
     // This tests setters. Training with various options is tested in Scala.
     DecisionTreeRegressor dt = new DecisionTreeRegressor()

From 358e7bf652d6fedd9377593025cd661c142efeca Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Thu, 16 Jul 2015 23:02:06 -0700
Subject: [PATCH 03/58] [SPARK-9126] [MLLIB] do not assert on time taken by
 Thread.sleep()

Measure lower and upper bounds for task time and use them for validation. This PR also implements `Stopwatch.toString`. This suite should finish in less than 1 second.

jkbradley pwendell

Author: Xiangrui Meng

Closes #7457 from mengxr/SPARK-9126 and squashes the following commits:

4b40faa [Xiangrui Meng] simplify tests
739f5bd [Xiangrui Meng] do not assert on time taken by Thread.sleep()
---
 .../apache/spark/ml/util/stopwatches.scala    |  4 +-
 .../apache/spark/ml/util/StopwatchSuite.scala | 64 ++++++++++++-------
 2 files changed, 43 insertions(+), 25 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
index 5fdf878a3df72..8d4174124b5c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
@@ -67,6 +67,8 @@ private[spark] abstract class Stopwatch extends Serializable {
    */
   def elapsed(): Long
 
+  override def toString: String = s"$name: ${elapsed()}ms"
+
   /**
    * Gets the current time in milliseconds.
*/ @@ -145,7 +147,7 @@ private[spark] class MultiStopwatch(@transient private val sc: SparkContext) ext override def toString: String = { stopwatches.values.toArray.sortBy(_.name) - .map(c => s" ${c.name}: ${c.elapsed()}ms") + .map(c => s" $c") .mkString("{\n", ",\n", "\n}") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index 8df6617fe0228..9e6bc7193c13b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.ml.util +import java.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { + import StopwatchSuite._ + private def testStopwatchOnDriver(sw: Stopwatch): Unit = { assert(sw.name === "sw") assert(sw.elapsed() === 0L) @@ -29,18 +33,13 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[AssertionError] { sw.stop() } - sw.start() - Thread.sleep(50) - val duration = sw.stop() - assert(duration >= 50 && duration < 100) // using a loose upper bound + val duration = checkStopwatch(sw) val elapsed = sw.elapsed() assert(elapsed === duration) - sw.start() - Thread.sleep(50) - val duration2 = sw.stop() - assert(duration2 >= 50 && duration2 < 100) + val duration2 = checkStopwatch(sw) val elapsed2 = sw.elapsed() assert(elapsed2 === duration + duration2) + assert(sw.toString === s"sw: ${elapsed2}ms") sw.start() assert(sw.isRunning) intercept[AssertionError] { @@ -61,14 +60,13 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { test("DistributedStopwatch on executors") { val sw = new DistributedStopwatch(sc, "sw") val rdd = sc.parallelize(0 until 4, 4) + val acc = sc.accumulator(0L) rdd.foreach { i => - sw.start() - Thread.sleep(50) - sw.stop() + acc += checkStopwatch(sw) } assert(!sw.isRunning) val elapsed = sw.elapsed() - assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound + assert(elapsed === acc.value) } test("MultiStopwatch") { @@ -81,29 +79,47 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { sw("some") } assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}") - sw("local").start() - sw("spark").start() - Thread.sleep(50) - sw("local").stop() - Thread.sleep(50) - sw("spark").stop() + val localDuration = checkStopwatch(sw("local")) + val sparkDuration = checkStopwatch(sw("spark")) val localElapsed = sw("local").elapsed() val sparkElapsed = sw("spark").elapsed() - assert(localElapsed >= 50 && localElapsed < 100) - assert(sparkElapsed >= 100 && sparkElapsed < 200) + assert(localElapsed === localDuration) + assert(sparkElapsed === sparkDuration) assert(sw.toString === s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") val rdd = sc.parallelize(0 until 4, 4) + val acc = sc.accumulator(0L) rdd.foreach { i => sw("local").start() - sw("spark").start() - Thread.sleep(50) - sw("spark").stop() + val duration = checkStopwatch(sw("spark")) sw("local").stop() + acc += duration } val localElapsed2 = sw("local").elapsed() assert(localElapsed2 === localElapsed) val sparkElapsed2 = sw("spark").elapsed() - assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600) + assert(sparkElapsed2 === sparkElapsed + acc.value) } } + +private object StopwatchSuite extends SparkFunSuite { + + /** + * Checks the input stopwatch on a task that takes a 
random time (<10ms) to finish. Validates and + * returns the duration reported by the stopwatch. + */ + def checkStopwatch(sw: Stopwatch): Long = { + val ubStart = now + sw.start() + val lbStart = now + Thread.sleep(new Random().nextInt(10)) + val lb = now - lbStart + val duration = sw.stop() + val ub = now - ubStart + assert(duration >= lb && duration <= ub) + duration + } + + /** The current time in milliseconds. */ + private def now: Long = System.currentTimeMillis() +} From 111c05538d9dcee06e918dcd4481104ace712dc3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 16 Jul 2015 23:13:06 -0700 Subject: [PATCH 04/58] Added inline comment for the canEqual PR by @cloud-fan. --- sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 5f0592dc1d77b..3623fefbf2604 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -364,8 +364,13 @@ trait Row extends Serializable { false } - protected def canEqual(other: Any) = + protected def canEqual(other: Any) = { + // Note that InternalRow overrides canEqual. These two canEqual's together makes sure that + // comparing the external Row and InternalRow will always yield false. + // In the future, InternalRow should not extend Row. In that case, we can remove these + // canEqual methods. other.isInstanceOf[Row] && !other.isInstanceOf[InternalRow] + } override def equals(o: Any): Boolean = { if (o == null || !canEqual(o)) return false From 3f6d28a5ca98cf7d20c2c029094350cc4f9545a0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 17 Jul 2015 00:59:15 -0700 Subject: [PATCH 05/58] [SPARK-9102] [SQL] Improve project collapse with nondeterministic expressions Currently we will stop project collapse when the lower projection has nondeterministic expressions. However it's overkill sometimes, we should be able to optimize `df.select(Rand(10)).select('a)` to `df.select('a)` Author: Wenchen Fan Closes #7445 from cloud-fan/non-deterministic and squashes the following commits: 0deaef6 [Wenchen Fan] Improve project collapse with nondeterministic expressions --- .../sql/catalyst/optimizer/Optimizer.scala | 38 ++++++++++--------- .../optimizer/ProjectCollapsingSuite.scala | 26 +++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 10 ++--- 3 files changed, 51 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2f94b457f4cdc..d5beeec0ffac1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -206,31 +206,33 @@ object ColumnPruning extends Rule[LogicalPlan] { */ object ProjectCollapsing extends Rule[LogicalPlan] { - /** Returns true if any expression in projectList is non-deterministic. */ - private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = { - projectList.exists(expr => expr.find(!_.deterministic).isDefined) - } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - // We only collapse these two Projects if the child Project's expressions are all - // deterministic. 
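
In DataFrame terms, the distinction the improved rule draws looks roughly like this (a sketch; `df` is an assumed DataFrame with a column `a`):

    import org.apache.spark.sql.functions.{col, rand}

    // Collapsible: the upper projection touches only the deterministic column `a`,
    // so the nondeterministic alias `r` can be dropped without changing semantics.
    val collapsible = df.select(rand(10).as("r"), col("a")).select((col("a") + 1).as("a1"))

    // Not collapsed: inlining `r` would substitute rand(10) into the upper
    // projection and could change how often it is evaluated per row.
    val kept = df.select(rand(10).as("r")).select(col("r") + 1)
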
- case Project(projectList1, Project(projectList2, child)) - if !hasNondeterministic(projectList2) => + case p @ Project(projectList1, Project(projectList2, child)) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). val aliasMap = AttributeMap(projectList2.collect { - case a @ Alias(e, _) => (a.toAttribute, a) + case a: Alias => (a.toAttribute, a) }) - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - // TODO: Fix TransformBase to avoid the cast below. - val substitutedProjection = projectList1.map(_.transform { - case a: Attribute if aliasMap.contains(a) => aliasMap(a) - }).asInstanceOf[Seq[NamedExpression]] + // We only collapse these two Projects if their overlapped expressions are all + // deterministic. + val hasNondeterministic = projectList1.flatMap(_.collect { + case a: Attribute if aliasMap.contains(a) => aliasMap(a).child + }).exists(_.find(!_.deterministic).isDefined) - Project(substitutedProjection, child) + if (hasNondeterministic) { + p + } else { + // Substitute any attributes that are produced by the child projection, so that we safely + // eliminate it. + // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + // TODO: Fix TransformBase to avoid the cast below. + val substitutedProjection = projectList1.map(_.transform { + case a: Attribute => aliasMap.getOrElse(a, a) + }).asInstanceOf[Seq[NamedExpression]] + + Project(substitutedProjection, child) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala index 151654bffbd66..1aa89991cc698 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala @@ -70,4 +70,30 @@ class ProjectCollapsingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("collapse two nondeterministic, independent projects into one") { + val query = testRelation + .select(Rand(10).as('rand)) + .select(Rand(20).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation + .select(Rand(20).as('rand2)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse one nondeterministic, one deterministic, independent projects into one") { + val query = testRelation + .select(Rand(10).as('rand), 'a) + .select(('a + 1).as('a_plus_1)) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation + .select(('a + 1).as('a_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 23244fd310d0f..192cc0a6e5d7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -745,8 +745,8 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("SPARK-8072: Better Exception for Duplicate Columns") { // only one duplicate column present val e = intercept[org.apache.spark.sql.AnalysisException] { - val df1 = Seq((1, 2, 3), 
(2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") - .write.format("parquet").save("temp") + Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") + .write.format("parquet").save("temp") } assert(e.getMessage.contains("Duplicate column(s)")) assert(e.getMessage.contains("parquet")) @@ -755,9 +755,9 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { // multiple duplicate columns present val f = intercept[org.apache.spark.sql.AnalysisException] { - val df2 = Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7)) - .toDF("column1", "column2", "column3", "column1", "column3") - .write.format("json").save("temp") + Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7)) + .toDF("column1", "column2", "column3", "column1", "column3") + .write.format("json").save("temp") } assert(f.getMessage.contains("Duplicate column(s)")) assert(f.getMessage.contains("JSON")) From 5a3c1ad087cb645a9496349ca021168e479ffae9 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 17 Jul 2015 17:00:50 +0900 Subject: [PATCH 06/58] [SPARK-9093] [SPARKR] Fix single-quotes strings in SparkR [[SPARK-9093] Fix single-quotes strings in SparkR - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9093) This is the result of lintr at the rivision:011551620faa87107a787530f074af3d9be7e695 [[SPARK-9093] The result of lintr at 011551620faa87107a787530f074af3d9be7e695](https://gist.github.com/yu-iskw/8c47acf3202796da4d01) Author: Yu ISHIKAWA Closes #7439 from yu-iskw/SPARK-9093 and squashes the following commits: 61c391e [Yu ISHIKAWA] [SPARK-9093][SparkR] Fix single-quotes strings in SparkR --- R/pkg/R/DataFrame.R | 10 +++++----- R/pkg/R/SQLContext.R | 4 ++-- R/pkg/R/serialize.R | 4 ++-- R/pkg/R/sparkR.R | 2 +- R/pkg/inst/tests/test_sparkSQL.R | 4 ++-- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 208813768e264..a58433df3c8c1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1314,7 +1314,7 @@ setMethod("except", #' write.df(df, "myfile", "parquet", "overwrite") #' } setMethod("write.df", - signature(df = "DataFrame", path = 'character'), + signature(df = "DataFrame", path = "character"), function(df, path, source = NULL, mode = "append", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) @@ -1328,7 +1328,7 @@ setMethod("write.df", jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] <- path + options[["path"]] <- path } callJMethod(df@sdf, "save", source, jmode, options) }) @@ -1337,7 +1337,7 @@ setMethod("write.df", #' @aliases saveDF #' @export setMethod("saveDF", - signature(df = "DataFrame", path = 'character'), + signature(df = "DataFrame", path = "character"), function(df, path, source = NULL, mode = "append", ...){ write.df(df, path, source, mode, ...) 
}) @@ -1375,8 +1375,8 @@ setMethod("saveDF", #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", - signature(df = "DataFrame", tableName = 'character', source = 'character', - mode = 'character'), + signature(df = "DataFrame", tableName = "character", source = "character", + mode = "character"), function(df, tableName, source = NULL, mode="append", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 30978bb50d339..110117a18ccbc 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -457,7 +457,7 @@ dropTempTable <- function(sqlContext, tableName) { read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] <- path + options[["path"]] <- path } if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) @@ -506,7 +506,7 @@ loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] <- path + options[["path"]] <- path } sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) dataFrame(sdf) diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 78535eff0d2f6..311021e5d8473 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -140,8 +140,8 @@ writeType <- function(con, class) { jobj = "j", environment = "e", Date = "D", - POSIXlt = 't', - POSIXct = 't', + POSIXlt = "t", + POSIXct = "t", stop(paste("Unsupported type for serialization", class))) writeBin(charToRaw(type), con) } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 172335809dec2..79b79d70943cb 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -140,7 +140,7 @@ sparkR.init <- function( if (!file.exists(path)) { stop("JVM is not ready after 10 seconds") } - f <- file(path, open='rb') + f <- file(path, open="rb") backendPort <- readInt(f) monitorPort <- readInt(f) close(f) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index cdfe6481f60ea..a3039d36c9402 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -57,9 +57,9 @@ test_that("infer types", { expect_equal(infer_type(as.Date("2015-03-11")), "date") expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") expect_equal(infer_type(c(1L, 2L)), - list(type = 'array', elementType = "integer", containsNull = TRUE)) + list(type = "array", elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(1L, 2L)), - list(type = 'array', elementType = "integer", containsNull = TRUE)) + list(type = "array", elementType = "integer", containsNull = TRUE)) testStruct <- infer_type(list(a = 1L, b = "2")) expect_equal(class(testStruct), "structType") checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) From ec8973d1245d4a99edeb7365d7f4b0063ac31ddf Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 17 Jul 2015 01:27:14 -0700 Subject: [PATCH 07/58] [SPARK-9022] [SQL] Generated projections for UnsafeRow Added two projections: GenerateUnsafeProjection and FromUnsafeProjection, which could be used to convert UnsafeRow from/to GenericInternalRow. They will re-use the buffer during projection, similar to MutableProjection (without all the interface MutableProjection has). 
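
A sketch of how the new projections are meant to be used, based on the description above (the `row` input and the copy caveat are assumptions, not part of the patch):

    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.types._

    val schema = StructType(Seq(StructField("i", IntegerType), StructField("l", LongType)))
    val toUnsafe = UnsafeProjection.create(schema)
    // val unsafeRow = toUnsafe(row)  // row: InternalRow matching the schema
    // Because the projection re-uses its buffer, the returned UnsafeRow is only
    // valid until the next call to toUnsafe(...); copy its fields if they must
    // be retained across calls.
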
cc rxin JoshRosen Author: Davies Liu Closes #7437 from davies/unsafe_proj2 and squashes the following commits: dbf538e [Davies Liu] test with all the expression (only for supported types) dc737b2 [Davies Liu] address comment e424520 [Davies Liu] fix scala style 70e231c [Davies Liu] address comments 729138d [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_proj2 5a26373 [Davies Liu] unsafe projections --- .../execution/UnsafeExternalRowSorter.java | 27 ++-- .../spark/sql/catalyst/expressions/Cast.scala | 8 +- .../sql/catalyst/expressions/Projection.scala | 35 +++++ .../expressions/UnsafeRowConverter.scala | 69 +++++----- .../expressions/codegen/CodeGenerator.scala | 15 ++- .../codegen/GenerateProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 125 ++++++++++++++++++ .../expressions/decimalFunctions.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 2 +- .../spark/sql/catalyst/expressions/misc.scala | 17 ++- .../expressions/ExpressionEvalHelper.scala | 34 ++++- 11 files changed, 266 insertions(+), 72 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index b94601cf6d818..d1d81c87bb052 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -28,13 +28,11 @@ import org.apache.spark.TaskContext; import org.apache.spark.sql.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.ObjectUnsafeColumnWriter; import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter; +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter; import org.apache.spark.sql.catalyst.util.ObjectPool; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; @@ -52,10 +50,9 @@ final class UnsafeExternalRowSorter { private long numRowsInserted = 0; private final StructType schema; - private final UnsafeRowConverter rowConverter; + private final UnsafeProjection unsafeProjection; private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - private byte[] rowConversionBuffer = new byte[1024 * 8]; public static abstract class PrefixComputer { abstract long computePrefix(InternalRow row); @@ -67,7 +64,7 @@ public UnsafeExternalRowSorter( PrefixComparator prefixComparator, PrefixComputer prefixComputer) throws IOException { this.schema = schema; - this.rowConverter = new UnsafeRowConverter(schema); + this.unsafeProjection = UnsafeProjection.create(schema); this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); @@ -94,18 +91,12 @@ void setTestSpillFrequency(int frequency) { @VisibleForTesting void insertRow(InternalRow row) throws IOException { - final int sizeRequirement = 
rowConverter.getSizeRequirement(row); - if (sizeRequirement > rowConversionBuffer.length) { - rowConversionBuffer = new byte[sizeRequirement]; - } - final int bytesWritten = rowConverter.writeRow( - row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null); - assert (bytesWritten == sizeRequirement); + UnsafeRow unsafeRow = unsafeProjection.apply(row); final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( - rowConversionBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - sizeRequirement, + unsafeRow.getBaseObject(), + unsafeRow.getBaseOffset(), + unsafeRow.getSizeInBytes(), prefix ); numRowsInserted++; @@ -186,7 +177,7 @@ public Iterator sort(Iterator inputIterator) throws IO public static boolean supportsSchema(StructType schema) { // TODO: add spilling note to explain why we do this for now: for (StructField field : schema.fields()) { - if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) { + if (!UnsafeColumnWriter.canEmbed(field.dataType())) { return false; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 65ae87fe6d166..692b9fddbb041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -424,20 +424,20 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BinaryType, StringType) => defineCodeGen (ctx, ev, c => - s"${ctx.stringType}.fromBytes($c)") + s"UTF8String.fromBytes($c)") case (DateType, StringType) => defineCodeGen(ctx, ev, c => - s"""${ctx.stringType}.fromString( + s"""UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") case (TimestampType, StringType) => defineCodeGen(ctx, ev, c => - s"""${ctx.stringType}.fromString( + s"""UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") case (_, StringType) => - defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") + defineCodeGen(ctx, ev, c => s"UTF8String.fromString(String.valueOf($c))") case (StringType, IntervalType) => defineCodeGen(ctx, ev, c => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index bf47a6c75b809..24b01ea55110e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} +import org.apache.spark.sql.types.{StructType, DataType} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. @@ -73,6 +75,39 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu } } +/** + * A projection that returns UnsafeRow. 
+ */ +abstract class UnsafeProjection extends Projection { + override def apply(row: InternalRow): UnsafeRow +} + +object UnsafeProjection { + def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType)) + + def create(fields: Seq[DataType]): UnsafeProjection = { + val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) + GenerateUnsafeProjection.generate(exprs) + } +} + +/** + * A projection that could turn UnsafeRow into GenericInternalRow + */ +case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection { + + private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) => + new BoundReference(idx, dt, true) + } + + @transient private[this] lazy val generatedProj = + GenerateMutableProjection.generate(expressions)() + + override def apply(input: InternalRow): InternalRow = { + generatedProj(input) + } +} + /** * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 6af5e6200e57b..885ab091fcdf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -147,77 +147,73 @@ private object UnsafeColumnWriter { case t => ObjectUnsafeColumnWriter } } + + /** + * Returns whether the dataType can be embedded into UnsafeRow (not using ObjectPool). + */ + def canEmbed(dataType: DataType): Boolean = { + forType(dataType) != ObjectUnsafeColumnWriter + } } // ------------------------------------------------------------------------------------------------ -private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter -private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter -private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter -private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter -private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter -private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter -private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter -private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter -private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter -private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter -private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { // Primitives don't write to the variable-length region: def getSize(sourceRow: InternalRow, column: Int): Int = 0 } -private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object NullUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setNullAt(column) 0 } } -private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object BooleanUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setBoolean(column, source.getBoolean(column)) 0 } } -private class ByteUnsafeColumnWriter private() extends 
PrimitiveUnsafeColumnWriter { +private object ByteUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setByte(column, source.getByte(column)) 0 } } -private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object ShortUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setShort(column, source.getShort(column)) 0 } } -private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object IntUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setInt(column, source.getInt(column)) 0 } } -private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object LongUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setLong(column, source.getLong(column)) 0 } } -private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object FloatUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setFloat(column, source.getFloat(column)) 0 } } -private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object DoubleUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setDouble(column, source.getDouble(column)) 0 @@ -226,18 +222,21 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { - def getBytes(source: InternalRow, column: Int): Array[Byte] + protected[this] def isString: Boolean + protected[this] def getBytes(source: InternalRow, column: Int): Array[Byte] - def getSize(source: InternalRow, column: Int): Int = { + override def getSize(source: InternalRow, column: Int): Int = { val numBytes = getBytes(source, column).length ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } - protected[this] def isString: Boolean - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - val offset = target.getBaseOffset + cursor val bytes = getBytes(source, column) + write(target, bytes, column, cursor) + } + + def write(target: UnsafeRow, bytes: Array[Byte], column: Int, cursor: Int): Int = { + val offset = target.getBaseOffset + cursor val numBytes = bytes.length if ((numBytes & 0x07) > 0) { // zero-out the padding bytes @@ -256,22 +255,32 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { } } -private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { +private object StringUnsafeColumnWriter extends BytesUnsafeColumnWriter { protected[this] def isString: Boolean = true def getBytes(source: InternalRow, column: Int): Array[Byte] = { source.getAs[UTF8String](column).getBytes } + // TODO(davies): refactor this + // specialized for codegen + def getSize(value: UTF8String): Int = + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.numBytes()) + def write(target: UnsafeRow, value: UTF8String, column: Int, cursor: Int): Int = { + 
write(target, value.getBytes, column, cursor) + } } -private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { - protected[this] def isString: Boolean = false - def getBytes(source: InternalRow, column: Int): Array[Byte] = { +private object BinaryUnsafeColumnWriter extends BytesUnsafeColumnWriter { + protected[this] override def isString: Boolean = false + override def getBytes(source: InternalRow, column: Int): Array[Byte] = { source.getAs[Array[Byte]](column) } + // specialized for codegen + def getSize(value: Array[Byte]): Int = + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.length) } -private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter { - def getSize(sourceRow: InternalRow, column: Int): Int = 0 +private object ObjectUnsafeColumnWriter extends UnsafeColumnWriter { + override def getSize(sourceRow: InternalRow, column: Int): Int = 0 override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { val obj = source.get(column) val idx = target.getPool.put(obj) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 328d635de8743..45dc146488e12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -24,6 +24,7 @@ import com.google.common.cache.{CacheBuilder, CacheLoader} import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -68,9 +69,6 @@ class CodeGenContext { mutableStates += ((javaType, variableName, initialValue)) } - val stringType: String = classOf[UTF8String].getName - val decimalType: String = classOf[Decimal].getName - final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -136,9 +134,9 @@ class CodeGenContext { case LongType | TimestampType => JAVA_LONG case FloatType => JAVA_FLOAT case DoubleType => JAVA_DOUBLE - case dt: DecimalType => decimalType + case dt: DecimalType => "Decimal" case BinaryType => "byte[]" - case StringType => stringType + case StringType => "UTF8String" case _: StructType => "InternalRow" case _: ArrayType => s"scala.collection.Seq" case _: MapType => s"scala.collection.Map" @@ -262,7 +260,12 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(getClass.getClassLoader) - evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow")) + evaluator.setDefaultImports(Array( + classOf[InternalRow].getName, + classOf[UnsafeRow].getName, + classOf[UTF8String].getName, + classOf[Decimal].getName + )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { evaluator.cook(code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 3e5ca308dc31d..8f9fcbf810554 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types._ /** * Java can not access Projection (in package object) */ -abstract class BaseProject extends Projection {} +abstract class BaseProjection extends Projection {} /** * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input @@ -160,7 +160,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificProjection(expr); } - class SpecificProjection extends ${classOf[BaseProject].getName} { + class SpecificProjection extends ${classOf[BaseProjection].getName} { private $exprType[] expressions = null; $mutableStates diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala new file mode 100644 index 0000000000000..a81d545a8ec63 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{NullType, BinaryType, StringType} + + +/** + * Generates a [[Projection]] that returns an [[UnsafeRow]]. + * + * It generates the code for all the expressions, compute the total length for all the columns + * (can be accessed via variables), and then copy the data into a scratch buffer space in the + * form of UnsafeRow (the scratch buffer will grow as needed). + * + * Note: The returned UnsafeRow will be pointed to a scratch buffer inside the projection. 
+ */ +object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer.execute) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(expressions: Seq[Expression]): UnsafeProjection = { + val ctx = newCodeGenContext() + val exprs = expressions.map(_.gen(ctx)) + val allExprs = exprs.map(_.code).mkString("\n") + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val stringWriter = "org.apache.spark.sql.catalyst.expressions.StringUnsafeColumnWriter" + val binaryWriter = "org.apache.spark.sql.catalyst.expressions.BinaryUnsafeColumnWriter" + val additionalSize = expressions.zipWithIndex.map { case (e, i) => + e.dataType match { + case StringType => + s" + (${exprs(i).isNull} ? 0 : $stringWriter.getSize(${exprs(i).primitive}))" + case BinaryType => + s" + (${exprs(i).isNull} ? 0 : $binaryWriter.getSize(${exprs(i).primitive}))" + case _ => "" + } + }.mkString("") + + val writers = expressions.zipWithIndex.map { case (e, i) => + val update = e.dataType match { + case dt if ctx.isPrimitiveType(dt) => + s"${ctx.setColumn("target", dt, i, exprs(i).primitive)}" + case StringType => + s"cursor += $stringWriter.write(target, ${exprs(i).primitive}, $i, cursor)" + case BinaryType => + s"cursor += $binaryWriter.write(target, ${exprs(i).primitive}, $i, cursor)" + case NullType => "" + case _ => + throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") + } + s"""if (${exprs(i).isNull}) { + target.setNullAt($i); + } else { + $update; + }""" + }.mkString("\n ") + + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") + + val code = s""" + private $exprType[] expressions; + + public Object generate($exprType[] expr) { + this.expressions = expr; + return new SpecificProjection(); + } + + class SpecificProjection extends ${classOf[UnsafeProjection].getName} { + + private UnsafeRow target = new UnsafeRow(); + private byte[] buffer = new byte[64]; + + $mutableStates + + public SpecificProjection() {} + + // Scala.Function1 need this + public Object apply(Object row) { + return apply((InternalRow) row); + } + + public UnsafeRow apply(InternalRow i) { + ${allExprs} + + // additionalSize had '+' in the beginning + int numBytes = $fixedSize $additionalSize; + if (numBytes > buffer.length) { + buffer = new byte[numBytes]; + } + target.pointTo(buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + ${expressions.size}, numBytes, null); + int cursor = $fixedSize; + $writers + return target; + } + } + """ + + logDebug(s"code for ${expressions.mkString(",")}:\n$code") + + val c = compile(code) + c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 2fa74b4ffc5da..b9d4736a65e26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -54,7 +54,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends 
Un override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { s""" - ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull($eval, $precision, $scale); + ${ev.primitive} = (new Decimal()).setOrNull($eval, $precision, $scale); ${ev.isNull} = ${ev.primitive} == null; """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index a7ad452ef4943..84b289c4d1a68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -263,7 +263,7 @@ case class Bin(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c) => - s"${ctx.stringType}.fromString(java.lang.Long.toBinaryString($c))") + s"UTF8String.fromString(java.lang.Long.toBinaryString($c))") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a269ec4a1e6dc..8d8d66ddeb341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.catalyst.expressions -import java.security.MessageDigest -import java.security.NoSuchAlgorithmException +import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult + import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -42,7 +41,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => - s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") + s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") } } @@ -93,19 +92,19 @@ case class Sha2(left: Expression, right: Expression) try { java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224"); md.update($eval1); - ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest()); + ${ev.primitive} = UTF8String.fromBytes(md.digest()); } catch (java.security.NoSuchAlgorithmException e) { ${ev.isNull} = true; } } else if ($eval2 == 256 || $eval2 == 0) { ${ev.primitive} = - ${ctx.stringType}.fromString($digestUtils.sha256Hex($eval1)); + UTF8String.fromString($digestUtils.sha256Hex($eval1)); } else if ($eval2 == 384) { ${ev.primitive} = - ${ctx.stringType}.fromString($digestUtils.sha384Hex($eval1)); + UTF8String.fromString($digestUtils.sha384Hex($eval1)); } else if ($eval2 == 512) { ${ev.primitive} = - ${ctx.stringType}.fromString($digestUtils.sha512Hex($eval1)); + UTF8String.fromString($digestUtils.sha512Hex($eval1)); } else { ${ev.isNull} = true; } @@ -129,7 +128,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => - s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" + 
s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 43392df4bec2e..c43486b3ddcf5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -23,7 +23,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} @@ -43,6 +43,9 @@ trait ExpressionEvalHelper { checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) + if (UnsafeColumnWriter.canEmbed(expression.dataType)) { + checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) + } checkEvaluationWithOptimization(expression, catalystValue, inputRow) } @@ -142,6 +145,35 @@ trait ExpressionEvalHelper { } } + protected def checkEvalutionWithUnsafeProjection( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val ctx = GenerateUnsafeProjection.newCodeGenContext() + lazy val evaluated = expression.gen(ctx) + + val plan = try { + GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val unsafeRow = plan(inputRow) + // UnsafeRow cannot be compared with GenericInternalRow directly + val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow) + val expectedRow = InternalRow(expected) + if (actual != expectedRow) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + protected def checkEvaluationWithOptimization( expression: Expression, expected: Any, From c043a3e9df55721f21332f7c44ff351832d20324 Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Fri, 17 Jul 2015 09:38:08 -0500 Subject: [PATCH 08/58] [SPARK-8851] [YARN] In Client mode, make sure the client logs in and updates tokens In client side, the flow is SparkSubmit -> SparkContext -> yarn/Client. Since the yarn client only gets a cloned config and the staging dir is set here, it is not really possible to do re-logins in the SparkContext. So, do the initial logins in Spark Submit and do re-logins as we do now in the AM, but the Client behaves like an executor in this specific context and reads the credentials file to update the tokens. This way, even if the streaming context is started up from checkpoint - it is fine since we have logged in from SparkSubmit itself itself. 
Author: Hari Shreedharan Closes #7394 from harishreedharan/yarn-client-login and squashes the following commits: 9a2166f [Hari Shreedharan] make it possible to use command line args and config parameters together. de08f57 [Hari Shreedharan] Fix import order. 5c4fa63 [Hari Shreedharan] Add a comment explaining what is being done in YarnClientSchedulerBackend. c872caa [Hari Shreedharan] Fix typo in log message. 2c80540 [Hari Shreedharan] Move token renewal to YarnClientSchedulerBackend. 0c48ac2 [Hari Shreedharan] Remove direct use of ExecutorDelegationTokenUpdater in Client. 26f8bfa [Hari Shreedharan] [SPARK-8851][YARN] In Client mode, make sure the client logs in and updates tokens. 58b1969 [Hari Shreedharan] Simple attempt 1. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 29 ++++++++++------- .../org/apache/spark/deploy/SparkSubmit.scala | 10 ++++-- .../org/apache/spark/deploy/yarn/Client.scala | 32 ++++++++++++------- .../cluster/YarnClientSchedulerBackend.scala | 11 +++++-- 4 files changed, 56 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 9f94118829ff1..6b14d407a6380 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -25,6 +25,7 @@ import java.util.{Arrays, Comparator} import scala.collection.JavaConversions._ import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration @@ -248,19 +249,25 @@ class SparkHadoopUtil extends Logging { dir: Path, prefix: String, exclusionSuffix: String): Array[FileStatus] = { - val fileStatuses = remoteFs.listStatus(dir, - new PathFilter { - override def accept(path: Path): Boolean = { - val name = path.getName - name.startsWith(prefix) && !name.endsWith(exclusionSuffix) + try { + val fileStatuses = remoteFs.listStatus(dir, + new PathFilter { + override def accept(path: Path): Boolean = { + val name = path.getName + name.startsWith(prefix) && !name.endsWith(exclusionSuffix) + } + }) + Arrays.sort(fileStatuses, new Comparator[FileStatus] { + override def compare(o1: FileStatus, o2: FileStatus): Int = { + Longs.compare(o1.getModificationTime, o2.getModificationTime) } }) - Arrays.sort(fileStatuses, new Comparator[FileStatus] { - override def compare(o1: FileStatus, o2: FileStatus): Int = { - Longs.compare(o1.getModificationTime, o2.getModificationTime) - } - }) - fileStatuses + fileStatuses + } catch { + case NonFatal(e) => + logWarning("Error while attempting to list files from application staging dir", e) + Array.empty + } } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 036cb6e054791..0b39ee8fe3ba0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -508,8 +508,14 @@ object SparkSubmit { } // Let YARN know it's a pyspark app, so it distributes needed libraries. 
- if (clusterManager == YARN && args.isPython) { - sysProps.put("spark.yarn.isPython", "true") + if (clusterManager == YARN) { + if (args.isPython) { + sysProps.put("spark.yarn.isPython", "true") + } + if (args.principal != null) { + require(args.keytab != null, "Keytab must be specified when the keytab is specified") + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } } // In yarn-cluster mode, use yarn.Client as a wrapper around the user class diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index b74ea9a10afb2..bc28ce5eeae72 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -80,10 +80,12 @@ private[spark] class Client( private val isClusterMode = args.isClusterMode private var loginFromKeytab = false + private var principal: String = null + private var keytab: String = null + private val fireAndForget = isClusterMode && !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) - def stop(): Unit = yarnClient.stop() /** @@ -339,7 +341,7 @@ private[spark] class Client( if (loginFromKeytab) { logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + " via the YARN Secure Distributed Cache.") - val (_, localizedPath) = distribute(args.keytab, + val (_, localizedPath) = distribute(keytab, destName = Some(sparkConf.get("spark.yarn.keytab")), appMasterOnly = true) require(localizedPath != null, "Keytab file already distributed.") @@ -785,19 +787,27 @@ private[spark] class Client( } def setupCredentials(): Unit = { - if (args.principal != null) { - require(args.keytab != null, "Keytab must be specified when principal is specified.") + loginFromKeytab = args.principal != null || sparkConf.contains("spark.yarn.principal") + if (loginFromKeytab) { + principal = + if (args.principal != null) args.principal else sparkConf.get("spark.yarn.principal") + keytab = { + if (args.keytab != null) { + args.keytab + } else { + sparkConf.getOption("spark.yarn.keytab").orNull + } + } + + require(keytab != null, "Keytab must be specified when principal is specified.") logInfo("Attempting to login to the Kerberos" + - s" using principal: ${args.principal} and keytab: ${args.keytab}") - val f = new File(args.keytab) + s" using principal: $principal and keytab: $keytab") + val f = new File(keytab) // Generate a file name that can be used for the keytab file, that does not conflict // with any user file. val keytabFileName = f.getName + "-" + UUID.randomUUID().toString - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) - loginFromKeytab = true sparkConf.set("spark.yarn.keytab", keytabFileName) - sparkConf.set("spark.yarn.principal", args.principal) - logInfo("Successfully logged into the KDC.") + sparkConf.set("spark.yarn.principal", principal) } credentials = UserGroupInformation.getCurrentUser.getCredentials } @@ -1162,7 +1172,7 @@ object Client extends Logging { * * If not a "local:" file and no alternate name, the environment is not modified. * - * @parma conf Spark configuration. + * @param conf Spark configuration. * @param uri URI to add to classpath (optional). * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. 
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 3a0b9443d2d7b..d97fa2e2151bc 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -20,10 +20,9 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} -import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.spark.{SparkException, Logging, SparkContext} -import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( @@ -62,6 +61,13 @@ private[spark] class YarnClientSchedulerBackend( super.start() waitForApplication() + + // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver + // reads the credentials from HDFS, just like the executors and updates its own credentials + // cache. + if (conf.contains("spark.yarn.credentials.file")) { + YarnSparkHadoopUtil.get.startExecutorDelegationTokenRenewer(conf) + } monitorThread = asyncMonitorApplication() monitorThread.start() } @@ -158,6 +164,7 @@ private[spark] class YarnClientSchedulerBackend( } super.stop() client.stop() + YarnSparkHadoopUtil.get.stopExecutorDelegationTokenRenewer() logInfo("Stopped") } From 441e072a227378cae31afc45a608318b58ce2ac4 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 17 Jul 2015 09:00:41 -0700 Subject: [PATCH 09/58] [MINOR] [ML] fix wrong annotation of RFormula.formula fix wrong annotation of RFormula.formula Author: Yanbo Liang Closes #7470 from yanboliang/RFormula and squashes the following commits: 61f1919 [Yanbo Liang] fix wrong annotation --- mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d9a36bda386b3..56169f2a01fc9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -42,7 +42,7 @@ class RFormula(override val uid: String) /** * R formula parameter. The formula is provided in string form. - * @group setParam + * @group param */ val formula: Param[String] = new Param(this, "formula", "R model formula") From 59d24c226a441db5f08c58ec407ba5873bd3b954 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 17 Jul 2015 09:31:13 -0700 Subject: [PATCH 10/58] [SPARK-9130][SQL] throw exception when check equality between external and internal row instead of return false, throw exception when check equality between external and internal row is better. 
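Concretely, the new behaviour can be sketched as follows (mirroring the RowTest cases added in this patch; the values are arbitrary):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.InternalRow

    val externalRow = Row(1, 2)
    val internalRow = InternalRow(1, 2)

    externalRow == Row(1, 2)           // true: external vs. external still compares values
    internalRow == InternalRow(1, 2)   // true: internal vs. internal still compares values

    // Previously this silently returned false; it now throws
    // UnsupportedOperationException("cannot check equality between external and internal rows").
    externalRow == internalRow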
Author: Wenchen Fan Closes #7460 from cloud-fan/row-compare and squashes the following commits: 8a20911 [Wenchen Fan] improve equals 402daa8 [Wenchen Fan] throw exception when check equality between external and internal row --- .../main/scala/org/apache/spark/sql/Row.scala | 27 ++++++++++++++----- .../spark/sql/catalyst/InternalRow.scala | 7 ++++- .../scala/org/apache/spark/sql/RowTest.scala | 26 ++++++++++++++++++ 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 3623fefbf2604..2cb64d00935de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -364,18 +364,33 @@ trait Row extends Serializable { false } - protected def canEqual(other: Any) = { - // Note that InternalRow overrides canEqual. These two canEqual's together makes sure that - // comparing the external Row and InternalRow will always yield false. + /** + * Returns true if we can check equality for these 2 rows. + * Equality check between external row and internal row is not allowed. + * Here we do this check to prevent call `equals` on external row with internal row. + */ + protected def canEqual(other: Row) = { + // Note that `Row` is not only the interface of external row but also the parent + // of `InternalRow`, so we have to ensure `other` is not a internal row here to prevent + // call `equals` on external row with internal row. + // `InternalRow` overrides canEqual, and these two canEquals together makes sure that + // equality check between external Row and InternalRow will always fail. // In the future, InternalRow should not extend Row. In that case, we can remove these // canEqual methods. - other.isInstanceOf[Row] && !other.isInstanceOf[InternalRow] + !other.isInstanceOf[InternalRow] } override def equals(o: Any): Boolean = { - if (o == null || !canEqual(o)) return false - + if (!o.isInstanceOf[Row]) return false val other = o.asInstanceOf[Row] + + if (!canEqual(other)) { + throw new UnsupportedOperationException( + "cannot check equality between external and internal rows") + } + + if (other eq null) return false + if (length != other.length) { return false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index e2fafb88ee43e..024973a6b9fcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -54,7 +54,12 @@ abstract class InternalRow extends Row { // A default implementation to change the return type override def copy(): InternalRow = this - protected override def canEqual(other: Any) = other.isInstanceOf[InternalRow] + /** + * Returns true if we can check equality for these 2 rows. + * Equality check between external row and internal row is not allowed. + * Here we do this check to prevent call `equals` on internal row with external row. + */ + protected override def canEqual(other: Row) = other.isInstanceOf[InternalRow] // Custom hashCode function that matches the efficient code generated version. 
override def hashCode: Int = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index bbb9739e9cc76..878a1bb9b7e6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} import org.apache.spark.sql.types._ import org.scalatest.{Matchers, FunSpec} @@ -68,4 +69,29 @@ class RowTest extends FunSpec with Matchers { sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected } } + + describe("row equals") { + val externalRow = Row(1, 2) + val externalRow2 = Row(1, 2) + val internalRow = InternalRow(1, 2) + val internalRow2 = InternalRow(1, 2) + + it("equality check for external rows") { + externalRow shouldEqual externalRow2 + } + + it("equality check for internal rows") { + internalRow shouldEqual internalRow2 + } + + it("throws an exception when check equality between external and internal rows") { + def assertError(f: => Unit): Unit = { + val e = intercept[UnsupportedOperationException](f) + e.getMessage.contains("cannot check equality between external and internal rows") + } + + assertError(internalRow.equals(externalRow)) + assertError(externalRow.equals(internalRow)) + } + } } From 305e77cd83f3dbe680a920d5329c2e8c58452d5b Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Fri, 17 Jul 2015 09:32:27 -0700 Subject: [PATCH 11/58] [SPARK-8209[SQL]Add function conv cc chenghao-intel adrian-wang Author: zhichao.li Closes #6872 from zhichao-li/conv and squashes the following commits: 6ef3b37 [zhichao.li] add unittest and comments 78d9836 [zhichao.li] polish dataframe api and add unittest e2bace3 [zhichao.li] update to use ImplicitCastInputTypes cbcad3f [zhichao.li] add function conv --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 191 ++++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 21 +- .../org/apache/spark/sql/functions.scala | 18 ++ .../spark/sql/MathExpressionsSuite.scala | 13 ++ 5 files changed, 242 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e0beafe710079..a45181712dbdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -99,6 +99,7 @@ object FunctionRegistry { expression[Ceil]("ceil"), expression[Ceil]("ceiling"), expression[Cos]("cos"), + expression[Conv]("conv"), expression[EulerNumber]("e"), expression[Exp]("exp"), expression[Expm1]("expm1"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 84b289c4d1a68..7a543ff36afd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} +import java.util.Arrays import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} @@ -139,6 +140,196 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") +/** + * Convert a num from one base to another + * @param numExpr the number to be converted + * @param fromBaseExpr from which base + * @param toBaseExpr to which base + */ +case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) + extends Expression with ImplicitCastInputTypes{ + + override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable + + override def nullable: Boolean = numExpr.nullable || fromBaseExpr.nullable || toBaseExpr.nullable + + override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) + + /** Returns the result of evaluating this expression on a given input Row */ + override def eval(input: InternalRow): Any = { + val num = numExpr.eval(input) + val fromBase = fromBaseExpr.eval(input) + val toBase = toBaseExpr.eval(input) + if (num == null || fromBase == null || toBase == null) { + null + } else { + conv(num.asInstanceOf[UTF8String].getBytes, + fromBase.asInstanceOf[Int], toBase.asInstanceOf[Int]) + } + } + + /** + * Returns the [[DataType]] of the result of evaluating this expression. It is + * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false). + */ + override def dataType: DataType = StringType + + private val value = new Array[Byte](64) + + /** + * Divide x by m as if x is an unsigned 64-bit integer. Examples: + * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2 + * unsignedLongDiv(0, 5) == 0 + * + * @param x is treated as unsigned + * @param m is treated as signed + */ + private def unsignedLongDiv(x: Long, m: Int): Long = { + if (x >= 0) { + x / m + } else { + // Let uval be the value of the unsigned long with the same bits as x + // Two's complement => x = uval - 2*MAX - 2 + // => uval = x + 2*MAX + 2 + // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c + (x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m) + } + } + + /** + * Decode v into value[]. + * + * @param v is treated as an unsigned 64-bit integer + * @param radix must be between MIN_RADIX and MAX_RADIX + */ + private def decode(v: Long, radix: Int): Unit = { + var tmpV = v + Arrays.fill(value, 0.asInstanceOf[Byte]) + var i = value.length - 1 + while (tmpV != 0) { + val q = unsignedLongDiv(tmpV, radix) + value(i) = (tmpV - q * radix).asInstanceOf[Byte] + tmpV = q + i -= 1 + } + } + + /** + * Convert value[] into a long. On overflow, return -1 (as mySQL does). If a + * negative digit is found, ignore the suffix starting there. + * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first element that should be conisdered + * @return the result should be treated as an unsigned 64-bit integer. 
+ */ + private def encode(radix: Int, fromPos: Int): Long = { + var v: Long = 0L + val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once + // val + // exceeds this value + var i = fromPos + while (i < value.length && value(i) >= 0) { + if (v >= bound) { + // Check for overflow + if (unsignedLongDiv(-1 - value(i), radix) < v) { + return -1 + } + } + v = v * radix + value(i) + i += 1 + } + return v + } + + /** + * Convert the bytes in value[] to the corresponding chars. + * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first nonzero element + */ + private def byte2char(radix: Int, fromPos: Int): Unit = { + var i = fromPos + while (i < value.length) { + value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte] + i += 1 + } + } + + /** + * Convert the chars in value[] to the corresponding integers. Convert invalid + * characters to -1. + * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first nonzero element + */ + private def char2byte(radix: Int, fromPos: Int): Unit = { + var i = fromPos + while ( i < value.length) { + value(i) = Character.digit(value(i), radix).asInstanceOf[Byte] + i += 1 + } + } + + /** + * Convert numbers between different number bases. If toBase>0 the result is + * unsigned, otherwise it is signed. + * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv + */ + private def conv(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { + if (n == null || fromBase == null || toBase == null || n.isEmpty) { + return null + } + + if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX + || Math.abs(toBase) < Character.MIN_RADIX + || Math.abs(toBase) > Character.MAX_RADIX) { + return null + } + + var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) + + // Copy the digits in the right side of the array + var i = 1 + while (i <= n.length - first) { + value(value.length - i) = n(n.length - i) + i += 1 + } + char2byte(fromBase, value.length - n.length + first) + + // Do the conversion by going through a 64 bit integer + var v = encode(fromBase, value.length - n.length + first) + if (negative && toBase > 0) { + if (v < 0) { + v = -1 + } else { + v = -v + } + } + if (toBase < 0 && v < 0) { + v = -v + negative = true + } + decode(v, Math.abs(toBase)) + + // Find the first non-zero digit or the last digits if all are zero. 
+ val firstNonZeroPos = { + val firstNonZero = value.indexWhere( _ != 0) + if (firstNonZero != -1) firstNonZero else value.length - 1 + } + + byte2char(Math.abs(toBase), firstNonZeroPos) + + var resultStartPos = firstNonZeroPos + if (negative && toBase < 0) { + resultStartPos = firstNonZeroPos - 1 + value(resultStartPos) = '-' + } + UTF8String.fromBytes( Arrays.copyOfRange(value, resultStartPos, value.length)) + } +} + case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 52a874a9d89ef..ca35c7ef8ae5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst.expressions -import scala.math.BigDecimal.RoundingMode - import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ + class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { /** @@ -95,6 +94,24 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) } + test("conv") { + checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") + checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") + checkEvaluation(Conv(Literal(null), Literal(36), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal(null), Literal(16)), null) + checkEvaluation( + Conv(Literal("1234"), Literal(10), Literal(37)), null) + checkEvaluation( + Conv(Literal(""), Literal(10), Literal(16)), null) + checkEvaluation( + Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") + // If there is an invalid digit in the number, the longest valid prefix should be converted. + checkEvaluation( + Conv(Literal("11abc"), Literal(10), Literal(16)), "B") + } + test("e") { testLeaf(EulerNumber, math.E) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d6da284a4c788..fe511c296cfd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -68,6 +68,24 @@ object functions { */ def column(colName: String): Column = Column(colName) + /** + * Convert a number from one base to another for the specified expressions + * + * @group math_funcs + * @since 1.5.0 + */ + def conv(num: Column, fromBase: Int, toBase: Int): Column = + Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) + + /** + * Convert a number from one base to another for the specified expressions + * + * @group math_funcs + * @since 1.5.0 + */ + def conv(numColName: String, fromBase: Int, toBase: Int): Column = + conv(Column(numColName), fromBase, toBase) + /** * Creates a [[Column]] of literal value. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 087126bb2e513..8eb3fec756b4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -178,6 +178,19 @@ class MathExpressionsSuite extends QueryTest { Row(0.0, 1.0, 2.0)) } + test("conv") { + val df = Seq(("333", 10, 2)).toDF("num", "fromBase", "toBase") + checkAnswer(df.select(conv('num, 10, 16)), Row("14D")) + checkAnswer(df.select(conv("num", 10, 16)), Row("14D")) + checkAnswer(df.select(conv(lit(100), 2, 16)), Row("4")) + checkAnswer(df.select(conv(lit(3122234455L), 10, 16)), Row("BA198457")) + checkAnswer(df.selectExpr("conv(num, fromBase, toBase)"), Row("101001101")) + checkAnswer(df.selectExpr("""conv("100", 2, 10)"""), Row("4")) + checkAnswer(df.selectExpr("""conv("-10", 16, -10)"""), Row("-16")) + checkAnswer( + df.selectExpr("""conv("9223372036854775807", 36, -16)"""), Row("-1")) // for overflow + } + test("floor") { testOneToOneMathFunction(floor, math.floor) } From eba6a1af4c8ffb21934a59a61a419d625f37cceb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 17 Jul 2015 09:38:08 -0700 Subject: [PATCH 12/58] [SPARK-8945][SQL] Add add and subtract expressions for IntervalType JIRA: https://issues.apache.org/jira/browse/SPARK-8945 Add add and subtract expressions for IntervalType. Author: Liang-Chi Hsieh This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #7398 from viirya/interval_add_subtract and squashes the following commits: acd1f1e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract 5abae28 [Liang-Chi Hsieh] For comments. 6f5b72e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract dbe3906 [Liang-Chi Hsieh] For comments. 13a2fc5 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract 83ec129 [Liang-Chi Hsieh] Remove intervalMethod. acfe1ab [Liang-Chi Hsieh] Fix scala style. d3e9d0e [Liang-Chi Hsieh] Add add and subtract expressions for IntervalType. 
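As a hedged, expression-level illustration of what the patch enables (the Interval(months, microseconds) constructor is assumed from how this patch uses the class; the literals are arbitrary):

    import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Subtract, UnaryMinus}
    import org.apache.spark.unsafe.types.Interval

    // Literal(...) resolves to IntervalType via the new `case i: Interval` branch below.
    val threeMonths = Literal(new Interval(3, 0L))
    val oneMonth = Literal(new Interval(1, 0L))

    Add(threeMonths, oneMonth).eval(null)       // an interval of 4 months
    Subtract(threeMonths, oneMonth).eval(null)  // an interval of 2 months
    UnaryMinus(oneMonth).eval(null)             // an interval of -1 months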
--- .../sql/catalyst/expressions/arithmetic.scala | 60 ++++++++++++++++--- .../expressions/codegen/CodeGenerator.scala | 4 +- .../sql/catalyst/expressions/literals.scala | 3 +- .../spark/sql/types/AbstractDataType.scala | 6 ++ .../ExpressionTypeCheckingSuite.scala | 6 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 17 ++++++ .../apache/spark/unsafe/types/Interval.java | 16 +++++ .../spark/unsafe/types/IntervalSuite.java | 38 ++++++++++++ 8 files changed, 136 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 382cbe3b84a07..1616d1bc0aed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -21,11 +21,12 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.Interval case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) override def dataType: DataType = child.dataType @@ -36,15 +37,22 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") + case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } - protected override def nullSafeEval(input: Any): Any = numeric.negate(input) + protected override def nullSafeEval(input: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input.asInstanceOf[Interval].negate() + } else { + numeric.negate(input) + } + } } case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def prettyName: String = "positive" - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) override def dataType: DataType = child.dataType @@ -95,32 +103,66 @@ private[sql] object BinaryArithmetic { case class Add(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = NumericType + override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "+" - override def decimalMethod: String = "$plus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval]) + } else { + numeric.plus(input1, input2) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + 
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") + case ByteType | ShortType => + defineCodeGen(ctx, ev, + (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case IntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = NumericType + override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "-" - override def decimalMethod: String = "$minus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval]) + } else { + numeric.minus(input1, input2) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") + case ByteType | ShortType => + defineCodeGen(ctx, ev, + (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case IntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 45dc146488e12..7c388bc346306 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -27,7 +27,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types._ // These classes are here to avoid issues with serialization and integration with quasiquotes. 
@@ -69,6 +69,7 @@ class CodeGenContext { mutableStates += ((javaType, variableName, initialValue)) } + final val intervalType: String = classOf[Interval].getName final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -137,6 +138,7 @@ class CodeGenContext { case dt: DecimalType => "Decimal" case BinaryType => "byte[]" case StringType => "UTF8String" + case IntervalType => intervalType case _: StructType => "InternalRow" case _: ArrayType => s"scala.collection.Seq" case _: MapType => s"scala.collection.Map" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 3a7a7ae440036..e1fdb29541fa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types._ object Literal { def apply(v: Any): Literal = v match { @@ -42,6 +42,7 @@ object Literal { case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) + case i: Interval => Literal(i, IntervalType) case null => Literal(null, NullType) case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 076d7b5a5118d..40bf4b299c990 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -91,6 +91,12 @@ private[sql] object TypeCollection { TimestampType, DateType, StringType, BinaryType) + /** + * Types that include numeric types and interval type. They are only used in unary_minus, + * unary_positive, add and subtract operations. 
+ */ + val NumericAndInterval = TypeCollection(NumericType, IntervalType) + def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ed0d20e7de80e..ad15136ee9a2f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "expected to be of type numeric") + assertError(UnaryMinus('stringField), "type (numeric or interval)") assertError(Abs('stringField), "expected to be of type numeric") assertError(BitwiseNot('stringField), "expected to be of type integral") } @@ -78,8 +78,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - assertError(Add('booleanField, 'booleanField), "accepts numeric type") - assertError(Subtract('booleanField, 'booleanField), "accepts numeric type") + assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type") + assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type") assertError(Multiply('booleanField, 'booleanField), "accepts numeric type") assertError(Divide('booleanField, 'booleanField), "accepts numeric type") assertError(Remainder('booleanField, 'booleanField), "accepts numeric type") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 231440892bf0b..5b8b70ed5ae11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1492,4 +1492,21 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Currently we don't yet support nanosecond checkIntervalParseError("select interval 23 nanosecond") } + + test("SPARK-8945: add and subtract expressions for interval type") { + import org.apache.spark.unsafe.types.Interval + + val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") + checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123))) + + checkAnswer(df.select(df("i") + new Interval(2, 123)), + Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123))) + + checkAnswer(df.select(df("i") - new Interval(2, 123)), + Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123))) + + // unary minus + checkAnswer(df.select(-df("i")), + Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))) + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java index 905ea0b7b878c..71b1a85a818ea 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -87,6 +87,22 @@ public Interval(int months, long microseconds) { this.microseconds = microseconds; 
} + public Interval add(Interval that) { + int months = this.months + that.months; + long microseconds = this.microseconds + that.microseconds; + return new Interval(months, microseconds); + } + + public Interval subtract(Interval that) { + int months = this.months - that.months; + long microseconds = this.microseconds - that.microseconds; + return new Interval(months, microseconds); + } + + public Interval negate() { + return new Interval(-this.months, -this.microseconds); + } + @Override public boolean equals(Object other) { if (this == other) return true; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java index 1832d0bc65551..d29517cda66a3 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -101,6 +101,44 @@ public void fromStringTest() { assertEquals(Interval.fromString(input), null); } + @Test + public void addTest() { + String input = "interval 3 month 1 hour"; + String input2 = "interval 2 month 100 hour"; + + Interval interval = Interval.fromString(input); + Interval interval2 = Interval.fromString(input2); + + assertEquals(interval.add(interval2), new Interval(5, 101 * MICROS_PER_HOUR)); + + input = "interval -10 month -81 hour"; + input2 = "interval 75 month 200 hour"; + + interval = Interval.fromString(input); + interval2 = Interval.fromString(input2); + + assertEquals(interval.add(interval2), new Interval(65, 119 * MICROS_PER_HOUR)); + } + + @Test + public void subtractTest() { + String input = "interval 3 month 1 hour"; + String input2 = "interval 2 month 100 hour"; + + Interval interval = Interval.fromString(input); + Interval interval2 = Interval.fromString(input2); + + assertEquals(interval.subtract(interval2), new Interval(1, -99 * MICROS_PER_HOUR)); + + input = "interval -10 month -81 hour"; + input2 = "interval 75 month 200 hour"; + + interval = Interval.fromString(input); + interval2 = Interval.fromString(input2); + + assertEquals(interval.subtract(interval2), new Interval(-85, -281 * MICROS_PER_HOUR)); + } + private void testSingleUnit(String unit, int number, int months, long microseconds) { String input1 = "interval " + number + " " + unit; String input2 = "interval " + number + " " + unit + "s"; From 587c315b204f1439f696620543c38166d95f8a3d Mon Sep 17 00:00:00 2001 From: tien-dungle Date: Fri, 17 Jul 2015 12:11:32 -0700 Subject: [PATCH 13/58] [SPARK-9109] [GRAPHX] Keep the cached edge in the graph The change here is to keep the cached RDDs in the graph object so that when the graph.unpersist() is called these RDDs are correctly unpersisted. 
```java import org.apache.spark.graphx._ import org.apache.spark.rdd.RDD import org.slf4j.LoggerFactory import org.apache.spark.graphx.util.GraphGenerators // Create an RDD for the vertices val users: RDD[(VertexId, (String, String))] = sc.parallelize(Array((3L, ("rxin", "student")), (7L, ("jgonzal", "postdoc")), (5L, ("franklin", "prof")), (2L, ("istoica", "prof")))) // Create an RDD for edges val relationships: RDD[Edge[String]] = sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi"))) // Define a default user in case there are relationship with missing user val defaultUser = ("John Doe", "Missing") // Build the initial Graph val graph = Graph(users, relationships, defaultUser) graph.cache().numEdges graph.unpersist() sc.getPersistentRDDs.foreach( r => println( r._2.toString)) ``` Author: tien-dungle Closes #7469 from tien-dungle/SPARK-9109_Graphx-unpersist and squashes the following commits: 8d87997 [tien-dungle] Keep the cached edge in the graph --- .../scala/org/apache/spark/graphx/impl/GraphImpl.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 90a74d23a26cc..da95314440d86 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -332,9 +332,9 @@ object GraphImpl { edgeStorageLevel: StorageLevel, vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = { val edgeRDD = EdgeRDD.fromEdges(edges)(classTag[ED], classTag[VD]) - .withTargetStorageLevel(edgeStorageLevel).cache() + .withTargetStorageLevel(edgeStorageLevel) val vertexRDD = VertexRDD(vertices, edgeRDD, defaultVertexAttr) - .withTargetStorageLevel(vertexStorageLevel).cache() + .withTargetStorageLevel(vertexStorageLevel) GraphImpl(vertexRDD, edgeRDD) } @@ -346,9 +346,14 @@ object GraphImpl { def apply[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { + + vertices.cache() + // Convert the vertex partitions in edges to the correct type val newEdges = edges.asInstanceOf[EdgeRDDImpl[ED, _]] .mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD]) + .cache() + GraphImpl.fromExistingRDDs(vertices, newEdges) } From f9a82a884e7cb2a466a33ab64912924ce7ee30c1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 17 Jul 2015 12:43:58 -0700 Subject: [PATCH 14/58] [SPARK-9138] [MLLIB] fix Vectors.dense Vectors.dense() should accept numbers directly, like the one in Scala. We already use it in doctests, it worked by luck. cc mengxr jkbradley Author: Davies Liu Closes #7476 from davies/fix_vectors_dense and squashes the following commits: e0fd292 [Davies Liu] fix Vectors.dense --- python/pyspark/mllib/linalg.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 040886f71775b..529bd75894c96 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -30,6 +30,7 @@ basestring = str xrange = range import copyreg as copy_reg + long = int else: from itertools import izip as zip import copy_reg @@ -770,14 +771,18 @@ def sparse(size, *args): return SparseVector(size, *args) @staticmethod - def dense(elements): + def dense(*elements): """ - Create a dense vector of 64-bit floats from a Python list. Always - returns a NumPy array. 
+ Create a dense vector of 64-bit floats from a Python list or numbers. >>> Vectors.dense([1, 2, 3]) DenseVector([1.0, 2.0, 3.0]) + >>> Vectors.dense(1.0, 2.0) + DenseVector([1.0, 2.0]) """ + if len(elements) == 1 and not isinstance(elements[0], (float, int, long)): + # it's list, numpy.array or other iterable object. + elements = elements[0] return DenseVector(elements) @staticmethod From 806c579f43ce66ac1398200cbc773fa3b69b5cb6 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 17 Jul 2015 13:43:19 -0700 Subject: [PATCH 15/58] [SPARK-9062] [ML] Change output type of Tokenizer to Array(String, true) jira: https://issues.apache.org/jira/browse/SPARK-9062 Currently output type of Tokenizer is Array(String, false), which is not compatible with Word2Vec and Other transformers since their input type is Array(String, true). Seq[String] in udf will be treated as Array(String, true) by default. I'm not sure what's the recommended way for Tokenizer to handle the null value in the input. Any suggestion will be welcome. Author: Yuhao Yang Closes #7414 from hhbyyh/tokenizer and squashes the following commits: c01bd7a [Yuhao Yang] change output type of tokenizer --- .../main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 5f9f57a2ebcfa..0b3af4747e693 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -42,7 +42,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S require(inputType == StringType, s"Input type must be string type but got $inputType.") } - override protected def outputDataType: DataType = new ArrayType(StringType, false) + override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) } @@ -113,7 +113,7 @@ class RegexTokenizer(override val uid: String) require(inputType == StringType, s"Input type must be string type but got $inputType.") } - override protected def outputDataType: DataType = new ArrayType(StringType, false) + override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) } From 9974642870404381fa425fadb966c6dd3ac4a94f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 17 Jul 2015 13:55:17 -0700 Subject: [PATCH 16/58] [SPARK-8600] [ML] Naive Bayes API for spark.ml Pipelines Naive Bayes API for spark.ml Pipelines Author: Yanbo Liang Closes #7284 from yanboliang/spark-8600 and squashes the following commits: bc890f7 [Yanbo Liang] remove labels valid check c3de687 [Yanbo Liang] remove labels from ml.NaiveBayesModel a2b3088 [Yanbo Liang] address comments 3220b82 [Yanbo Liang] trigger jenkins 3018a41 [Yanbo Liang] address comments 208e166 [Yanbo Liang] Naive Bayes API for spark.ml Pipelines --- .../spark/ml/classification/NaiveBayes.scala | 178 ++++++++++++++++++ .../mllib/classification/NaiveBayes.scala | 10 +- .../apache/spark/mllib/linalg/Matrices.scala | 6 +- .../classification/JavaNaiveBayesSuite.java | 98 ++++++++++ .../ml/classification/NaiveBayesSuite.scala | 116 ++++++++++++ 5 files changed, 400 insertions(+), 8 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala create mode 100644 
mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala new file mode 100644 index 0000000000000..1f547e4a98af7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkException +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} +import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} +import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + +/** + * Params for Naive Bayes Classifiers. + */ +private[ml] trait NaiveBayesParams extends PredictorParams { + + /** + * The smoothing parameter. + * (default = 1.0). + * @group param + */ + final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.", + ParamValidators.gtEq(0)) + + /** @group getParam */ + final def getLambda: Double = $(lambda) + + /** + * The model type which is a string (case-sensitive). + * Supported options: "multinomial" and "bernoulli". + * (default = multinomial) + * @group param + */ + final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " + + "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.", + ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray)) + + /** @group getParam */ + final def getModelType: String = $(modelType) +} + +/** + * Naive Bayes Classifiers. + * It supports both Multinomial NB + * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]) + * which can handle finitely supported discrete data. For example, by converting documents into + * TF-IDF vectors, it can be used for document classification. By making every vector a + * binary (0/1) data, it can also be used as Bernoulli NB + * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]). + * The input feature values must be nonnegative. 
+ */ +class NaiveBayes(override val uid: String) + extends Predictor[Vector, NaiveBayes, NaiveBayesModel] + with NaiveBayesParams { + + def this() = this(Identifiable.randomUID("nb")) + + /** + * Set the smoothing parameter. + * Default is 1.0. + * @group setParam + */ + def setLambda(value: Double): this.type = set(lambda, value) + setDefault(lambda -> 1.0) + + /** + * Set the model type using a string (case-sensitive). + * Supported options: "multinomial" and "bernoulli". + * Default is "multinomial" + */ + def setModelType(value: String): this.type = set(modelType, value) + setDefault(modelType -> OldNaiveBayes.Multinomial) + + override protected def train(dataset: DataFrame): NaiveBayesModel = { + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType)) + NaiveBayesModel.fromOld(oldModel, this) + } + + override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) +} + +/** + * Model produced by [[NaiveBayes]] + */ +class NaiveBayesModel private[ml] ( + override val uid: String, + val pi: Vector, + val theta: Matrix) + extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams { + + import OldNaiveBayes.{Bernoulli, Multinomial} + + /** + * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. + * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra + * application of this condition (in predict function). + */ + private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match { + case Multinomial => (None, None) + case Bernoulli => + val negTheta = theta.map(value => math.log(1.0 - math.exp(value))) + val ones = new DenseVector(Array.fill(theta.numCols){1.0}) + val thetaMinusNegTheta = theta.map { value => + value - math.log(1.0 - math.exp(value)) + } + (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + } + + override protected def predict(features: Vector): Double = { + $(modelType) match { + case Multinomial => + val prob = theta.multiply(features) + BLAS.axpy(1.0, pi, prob) + prob.argmax + case Bernoulli => + features.foreachActive{ (index, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features") + } + } + val prob = thetaMinusNegTheta.get.multiply(features) + BLAS.axpy(1.0, pi, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + prob.argmax + case _ => + // This should never happen. 
+ throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + } + } + + override def copy(extra: ParamMap): NaiveBayesModel = { + copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) + } + + override def toString: String = { + s"NaiveBayesModel with ${pi.size} classes" + } + +} + +private[ml] object NaiveBayesModel { + + /** Convert a model from the old API */ + def fromOld( + oldModel: OldNaiveBayesModel, + parent: NaiveBayes): NaiveBayesModel = { + val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") + val labels = Vectors.dense(oldModel.labels) + val pi = Vectors.dense(oldModel.pi) + val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length, + oldModel.theta.flatten, true) + new NaiveBayesModel(uid, pi, theta) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 9e379d7d74b2f..8cf4e15efe7a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} * where D is number of features * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli" */ -class NaiveBayesModel private[mllib] ( +class NaiveBayesModel private[spark] ( val labels: Array[Double], val pi: Array[Double], val theta: Array[Array[Double]], @@ -382,7 +382,7 @@ class NaiveBayes private ( BLAS.axpy(1.0, c2._2, c1._2) (c1._1 + c2._1, c1._2) } - ).collect() + ).collect().sortBy(_._1) val numLabels = aggregated.length var numDocuments = 0L @@ -425,13 +425,13 @@ class NaiveBayes private ( object NaiveBayes { /** String name for multinomial model type. */ - private[classification] val Multinomial: String = "multinomial" + private[spark] val Multinomial: String = "multinomial" /** String name for Bernoulli model type. */ - private[classification] val Bernoulli: String = "bernoulli" + private[spark] val Bernoulli: String = "bernoulli" /* Set of modelTypes that NaiveBayes supports */ - private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) + private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli) /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 0df07663405a3..55da0e094d132 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -98,7 +98,7 @@ sealed trait Matrix extends Serializable { /** Map the values of this matrix using a function. Generates a new matrix. Performs the * function on only the backing array. For example, an operation such as addition or * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ - private[mllib] def map(f: Double => Double): Matrix + private[spark] def map(f: Double => Double): Matrix /** Update all the values of this matrix using the function f. Performed in-place on the * backing array. 
For example, an operation such as addition or subtraction will only be @@ -289,7 +289,7 @@ class DenseMatrix( override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) - private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), + private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), isTransposed) private[mllib] def update(f: Double => Double): DenseMatrix = { @@ -555,7 +555,7 @@ class SparseMatrix( new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) } - private[mllib] def map(f: Double => Double) = + private[spark] def map(f: Double => Double) = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed) private[mllib] def update(f: Double => Double): SparseMatrix = { diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java new file mode 100644 index 0000000000000..09a9fba0c19cf --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaNaiveBayesSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + public void validatePrediction(DataFrame predictionAndLabels) { + for (Row r : predictionAndLabels.collect()) { + double prediction = r.getAs(0); + double label = r.getAs(1); + assert(prediction == label); + } + } + + @Test + public void naiveBayesDefaultParams() { + NaiveBayes nb = new NaiveBayes(); + assert(nb.getLabelCol() == "label"); + assert(nb.getFeaturesCol() == "features"); + assert(nb.getPredictionCol() == "prediction"); + assert(nb.getLambda() == 1.0); + assert(nb.getModelType() == "multinomial"); + } + + @Test + public void testNaiveBayes() { + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)), + RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)), + RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)), + RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)), + RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)), + RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(jrdd, schema); + NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial"); + NaiveBayesModel model = nb.fit(dataset); + + DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); + validatePrediction(predictionAndLabels); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala new file mode 100644 index 0000000000000..76381a2741296 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.classification.NaiveBayesSuite._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row + +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + + def validatePrediction(predictionAndLabels: DataFrame): Unit = { + val numOfErrorPredictions = predictionAndLabels.collect().count { + case Row(prediction: Double, label: Double) => + prediction != label + } + // At least 80% of the predictions should be on. + assert(numOfErrorPredictions < predictionAndLabels.count() / 5) + } + + def validateModelFit( + piData: Vector, + thetaData: Matrix, + model: NaiveBayesModel): Unit = { + assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~== + Vectors.dense(piData.toArray.map(math.exp)) absTol 0.05, "pi mismatch") + assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch") + } + + test("params") { + ParamsSuite.checkParams(new NaiveBayes) + val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)), + theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4))) + ParamsSuite.checkParams(model) + } + + test("naive bayes: default params") { + val nb = new NaiveBayes + assert(nb.getLabelCol === "label") + assert(nb.getFeaturesCol === "features") + assert(nb.getPredictionCol === "prediction") + assert(nb.getLambda === 1.0) + assert(nb.getModelType === "multinomial") + } + + test("Naive Bayes Multinomial") { + val nPoints = 1000 + val piArray = Array(0.5, 0.1, 0.4).map(math.log) + val thetaArray = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) + + val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + piArray, thetaArray, nPoints, 42, "multinomial")) + val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial") + val model = nb.fit(testDataset) + + validateModelFit(pi, theta, model) + assert(model.hasParent) + + val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + piArray, thetaArray, nPoints, 17, "multinomial")) + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + + validatePrediction(predictionAndLabels) + } + + test("Naive Bayes Bernoulli") { + val nPoints = 10000 + val piArray = Array(0.5, 0.3, 0.2).map(math.log) + val thetaArray = Array( + Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 + Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + ).map(_.map(math.log)) + val pi = Vectors.dense(piArray) + val theta = new 
DenseMatrix(3, 12, thetaArray.flatten, true) + + val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + piArray, thetaArray, nPoints, 45, "bernoulli")) + val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli") + val model = nb.fit(testDataset) + + validateModelFit(pi, theta, model) + assert(model.hasParent) + + val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + piArray, thetaArray, nPoints, 20, "bernoulli")) + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + + validatePrediction(predictionAndLabels) + } +} From 074085d6781a580017a45101b8b54ffd7bd31294 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 17 Jul 2015 13:57:31 -0700 Subject: [PATCH 17/58] [SPARK-9136] [SQL] fix several bugs in DateTimeUtils.stringToTimestamp a follow up of https://github.com/apache/spark/pull/7353 1. we should use `Calendar.HOUR_OF_DAY` instead of `Calendar.HOUR`(this is for AM, PM). 2. we should call `c.set(Calendar.MILLISECOND, 0)` after `Calendar.getInstance` I'm not sure why the tests didn't fail in jenkins, but I ran latest spark master branch locally and `DateTimeUtilsSuite` failed. Author: Wenchen Fan Closes #7473 from cloud-fan/datetime and squashes the following commits: 66cdaf2 [Wenchen Fan] fix several bugs in DateTimeUtils.stringToTimestamp --- .../spark/sql/catalyst/util/DateTimeUtils.scala | 5 +++-- .../sql/catalyst/util/DateTimeUtilsSuite.scala | 13 +++++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 53c32a0a9802b..f33e34b380bcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -320,16 +320,17 @@ object DateTimeUtils { Calendar.getInstance( TimeZone.getTimeZone(f"GMT${timeZone.get.toChar}${segments(7)}%02d:${segments(8)}%02d")) } + c.set(Calendar.MILLISECOND, 0) if (justTime) { - c.set(Calendar.HOUR, segments(3)) + c.set(Calendar.HOUR_OF_DAY, segments(3)) c.set(Calendar.MINUTE, segments(4)) c.set(Calendar.SECOND, segments(5)) } else { c.set(segments(0), segments(1) - 1, segments(2), segments(3), segments(4), segments(5)) } - Some(c.getTimeInMillis / 1000 * 1000000 + segments(6)) + Some(c.getTimeInMillis * 1000 + segments(6)) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index c65fcbc4d1bc1..5c3a621c6d11f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -243,8 +243,17 @@ class DateTimeUtilsSuite extends SparkFunSuite { UTF8String.fromString("2015-03-18T12:03:17.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) + c = Calendar.getInstance() + c.set(Calendar.HOUR_OF_DAY, 18) + c.set(Calendar.MINUTE, 12) + c.set(Calendar.SECOND, 15) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("18:12:15")).get === + c.getTimeInMillis * 1000) + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) - c.set(Calendar.HOUR, 18) + c.set(Calendar.HOUR_OF_DAY, 18) c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) 
c.set(Calendar.MILLISECOND, 123) @@ -253,7 +262,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.getTimeInMillis * 1000 + 120) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) - c.set(Calendar.HOUR, 18) + c.set(Calendar.HOUR_OF_DAY, 18) c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 123) From ad0954f6de29761e0e7e543212c5bfe1fdcbed9f Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 17 Jul 2015 14:00:31 -0700 Subject: [PATCH 18/58] [SPARK-5681] [STREAMING] Move 'stopReceivers' to the event loop to resolve the race condition This is an alternative way to fix `SPARK-5681`. It minimizes the changes. Closes #4467 Author: zsxwing Author: Liang-Chi Hsieh Closes #6294 from zsxwing/pr4467 and squashes the following commits: 709ac1f [zsxwing] Fix the comment e103e8a [zsxwing] Move ReceiverTracker.stop into ReceiverTracker.stop f637142 [zsxwing] Address minor code style comments a178d37 [zsxwing] Move 'stopReceivers' to the event looop to resolve the race condition 51fb07e [zsxwing] Fix the code style 3cb19a3 [zsxwing] Merge branch 'master' into pr4467 b4c29e7 [zsxwing] Stop receiver only if we start it c41ee94 [zsxwing] Make stopReceivers private 7c73c1f [zsxwing] Use trackerStateLock to protect trackerState a8120c0 [zsxwing] Merge branch 'master' into pr4467 7b1d9af [zsxwing] "case Throwable" => "case NonFatal" 15ed4a1 [zsxwing] Register before starting the receiver fff63f9 [zsxwing] Use a lock to eliminate the race condition when stopping receivers and registering receivers happen at the same time. e0ef72a [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into tracker_status_timeout 19b76d9 [Liang-Chi Hsieh] Remove timeout. 34c18dc [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into tracker_status_timeout c419677 [Liang-Chi Hsieh] Fix style. 9e1a760 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into tracker_status_timeout 355f9ce [Liang-Chi Hsieh] Separate register and start events for receivers. 3d568e8 [Liang-Chi Hsieh] Let receivers get registered first before going started. ae0d9fd [Liang-Chi Hsieh] Merge branch 'master' into tracker_status_timeout 77983f3 [Liang-Chi Hsieh] Add tracker status and stop to receive messages when stopping tracker. 
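Before the diff, a short sketch of the user-facing scenario this change hardens, based on the StreamingContextSuite test added further down. The SparkConf settings and the socket source are illustrative assumptions; only the stop-right-after-start pattern comes from the patch.

```scala
// Sketch only: any receiver-based input stream reproduces the scenario.
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

val conf = new SparkConf().setMaster("local[2]").setAppName("graceful-stop-sketch")
val ssc = new StreamingContext(conf, Milliseconds(100))

val lines = ssc.socketTextStream("localhost", 9999)
lines.foreachRDD(_ => {})

ssc.start()
// Stopping immediately after start() used to race with receiver registration.
// Registration and StopAllReceivers now both run in the tracker's event loop,
// so a graceful stop no longer hangs when a receiver misses StopReceiver.
ssc.stop(stopSparkContext = true, stopGracefully = true)
```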
--- .../receiver/ReceiverSupervisor.scala | 42 ++++-- .../receiver/ReceiverSupervisorImpl.scala | 2 +- .../streaming/scheduler/ReceiverTracker.scala | 139 ++++++++++++------ .../spark/streaming/ReceiverSuite.scala | 2 + .../streaming/StreamingContextSuite.scala | 15 ++ 5 files changed, 138 insertions(+), 62 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index eeb14ca3a49e9..6467029a277b2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -22,6 +22,7 @@ import java.util.concurrent.CountDownLatch import scala.collection.mutable.ArrayBuffer import scala.concurrent._ +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId @@ -36,7 +37,7 @@ private[streaming] abstract class ReceiverSupervisor( conf: SparkConf ) extends Logging { - /** Enumeration to identify current state of the StreamingContext */ + /** Enumeration to identify current state of the Receiver */ object ReceiverState extends Enumeration { type CheckpointState = Value val Initialized, Started, Stopped = Value @@ -97,8 +98,8 @@ private[streaming] abstract class ReceiverSupervisor( /** Called when supervisor is stopped */ protected def onStop(message: String, error: Option[Throwable]) { } - /** Called when receiver is started */ - protected def onReceiverStart() { } + /** Called when receiver is started. Return true if the driver accepts us */ + protected def onReceiverStart(): Boolean /** Called when receiver is stopped */ protected def onReceiverStop(message: String, error: Option[Throwable]) { } @@ -121,13 +122,17 @@ private[streaming] abstract class ReceiverSupervisor( /** Start receiver */ def startReceiver(): Unit = synchronized { try { - logInfo("Starting receiver") - receiver.onStart() - logInfo("Called receiver onStart") - onReceiverStart() - receiverState = Started + if (onReceiverStart()) { + logInfo("Starting receiver") + receiverState = Started + receiver.onStart() + logInfo("Called receiver onStart") + } else { + // The driver refused us + stop("Registered unsuccessfully because Driver refused to start receiver " + streamId, None) + } } catch { - case t: Throwable => + case NonFatal(t) => stop("Error starting receiver " + streamId, Some(t)) } } @@ -136,12 +141,19 @@ private[streaming] abstract class ReceiverSupervisor( def stopReceiver(message: String, error: Option[Throwable]): Unit = synchronized { try { logInfo("Stopping receiver with message: " + message + ": " + error.getOrElse("")) - receiverState = Stopped - receiver.onStop() - logInfo("Called receiver onStop") - onReceiverStop(message, error) + receiverState match { + case Initialized => + logWarning("Skip stopping receiver because it has not yet stared") + case Started => + receiverState = Stopped + receiver.onStop() + logInfo("Called receiver onStop") + onReceiverStop(message, error) + case Stopped => + logWarning("Receiver has been stopped") + } } catch { - case t: Throwable => + case NonFatal(t) => logError("Error stopping receiver " + streamId + t.getStackTraceString) } } @@ -167,7 +179,7 @@ private[streaming] abstract class ReceiverSupervisor( }(futureExecutionContext) } - /** Check if receiver has been marked for stopping */ + /** Check if receiver has been marked for starting */ def 
isReceiverStarted(): Boolean = { logDebug("state = " + receiverState) receiverState == Started diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 6078cdf8f8790..f6ba66b3ae036 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -162,7 +162,7 @@ private[streaming] class ReceiverSupervisorImpl( env.rpcEnv.stop(endpoint) } - override protected def onReceiverStart() { + override protected def onReceiverStart(): Boolean = { val msg = RegisterReceiver( streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) trackerEndpoint.askWithRetry[Boolean](msg) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 644e581cd8279..6910d81d9866e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,7 +20,6 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap} import scala.language.existentials import scala.math.max -import org.apache.spark.rdd._ import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{Logging, SparkEnv, SparkException} @@ -47,6 +46,8 @@ private[streaming] case class ReportError(streamId: Int, message: String, error: private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String) extends ReceiverTrackerMessage +private[streaming] case object StopAllReceivers extends ReceiverTrackerMessage + /** * This class manages the execution of the receivers of ReceiverInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() @@ -71,13 +72,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false ) private val listenerBus = ssc.scheduler.listenerBus + /** Enumeration to identify current state of the ReceiverTracker */ + object TrackerState extends Enumeration { + type TrackerState = Value + val Initialized, Started, Stopping, Stopped = Value + } + import TrackerState._ + + /** State of the tracker. Protected by "trackerStateLock" */ + @volatile private var trackerState = Initialized + // endpoint is created when generator starts. // This not being null means the tracker has been started and not stopped private var endpoint: RpcEndpointRef = null /** Start the endpoint and receiver execution thread. */ def start(): Unit = synchronized { - if (endpoint != null) { + if (isTrackerStarted) { throw new SparkException("ReceiverTracker already started") } @@ -86,20 +97,46 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") + trackerState = Started } } /** Stop the receiver execution thread. 
*/ def stop(graceful: Boolean): Unit = synchronized { - if (!receiverInputStreams.isEmpty && endpoint != null) { + if (isTrackerStarted) { // First, stop the receivers - if (!skipReceiverLaunch) receiverExecutor.stop(graceful) + trackerState = Stopping + if (!skipReceiverLaunch) { + // Send the stop signal to all the receivers + endpoint.askWithRetry[Boolean](StopAllReceivers) + + // Wait for the Spark job that runs the receivers to be over + // That is, for the receivers to quit gracefully. + receiverExecutor.awaitTermination(10000) + + if (graceful) { + val pollTime = 100 + logInfo("Waiting for receiver job to terminate gracefully") + while (receiverInfo.nonEmpty || receiverExecutor.running) { + Thread.sleep(pollTime) + } + logInfo("Waited for receiver job to terminate gracefully") + } + + // Check if all the receivers have been deregistered or not + if (receiverInfo.nonEmpty) { + logWarning("Not all of the receivers have deregistered, " + receiverInfo) + } else { + logInfo("All of the receivers have deregistered successfully") + } + } // Finally, stop the endpoint ssc.env.rpcEnv.stop(endpoint) endpoint = null receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") + trackerState = Stopped } } @@ -145,14 +182,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false host: String, receiverEndpoint: RpcEndpointRef, senderAddress: RpcAddress - ) { + ): Boolean = { if (!receiverInputStreamIds.contains(streamId)) { throw new SparkException("Register received for unexpected id " + streamId) } - receiverInfo(streamId) = ReceiverInfo( - streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) - listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) - logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) + + if (isTrackerStopping || isTrackerStopped) { + false + } else { + // "stopReceivers" won't happen at the same time because both "registerReceiver" and are + // called in the event loop. So here we can assume "stopReceivers" has not yet been called. If + // "stopReceivers" is called later, it should be able to see this receiver. + receiverInfo(streamId) = ReceiverInfo( + streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) + listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) + true + } } /** Deregister a receiver */ @@ -220,20 +266,33 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterReceiver(streamId, typ, host, receiverEndpoint) => - registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) - context.reply(true) + val successful = + registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) + context.reply(successful) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) + case StopAllReceivers => + assert(isTrackerStopping || isTrackerStopped) + stopReceivers() + context.reply(true) + } + + /** Send stop signal to the receivers. 
*/ + private def stopReceivers() { + // Signal the receivers to stop + receiverInfo.values.flatMap { info => Option(info.endpoint)} + .foreach { _.send(StopReceiver) } + logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") } } /** This thread class runs all the receivers on the cluster. */ class ReceiverLauncher { @transient val env = ssc.env - @volatile @transient private var running = false + @volatile @transient var running = false @transient val thread = new Thread() { override def run() { try { @@ -249,31 +308,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false thread.start() } - def stop(graceful: Boolean) { - // Send the stop signal to all the receivers - stopReceivers() - - // Wait for the Spark job that runs the receivers to be over - // That is, for the receivers to quit gracefully. - thread.join(10000) - - if (graceful) { - val pollTime = 100 - logInfo("Waiting for receiver job to terminate gracefully") - while (receiverInfo.nonEmpty || running) { - Thread.sleep(pollTime) - } - logInfo("Waited for receiver job to terminate gracefully") - } - - // Check if all the receivers have been deregistered or not - if (receiverInfo.nonEmpty) { - logWarning("Not all of the receivers have deregistered, " + receiverInfo) - } else { - logInfo("All of the receivers have deregistered successfully") - } - } - /** * Get the list of executors excluding driver */ @@ -358,17 +392,30 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") running = true - ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) - running = false - logInfo("All of the receivers have been terminated") + try { + ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) + logInfo("All of the receivers have been terminated") + } finally { + running = false + } } - /** Stops the receivers. 
*/ - private def stopReceivers() { - // Signal the receivers to stop - receiverInfo.values.flatMap { info => Option(info.endpoint)} - .foreach { _.send(StopReceiver) } - logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") + /** + * Wait until the Spark job that runs the receivers is terminated, or return when + * `milliseconds` elapses + */ + def awaitTermination(milliseconds: Long): Unit = { + thread.join(milliseconds) } } + + /** Check if tracker has been marked for starting */ + private def isTrackerStarted(): Boolean = trackerState == Started + + /** Check if tracker has been marked for stopping */ + private def isTrackerStopping(): Boolean = trackerState == Stopping + + /** Check if tracker has been marked for stopped */ + private def isTrackerStopped(): Boolean = trackerState == Stopped + } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 5d7127627eea5..13b4d17c86183 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -346,6 +346,8 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { def reportError(message: String, throwable: Throwable) { errors += throwable } + + override protected def onReceiverStart(): Boolean = true } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index f588cf5bc1e7c..4bba9691f8aa5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -285,6 +285,21 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } } + test("stop gracefully even if a receiver misses StopReceiver") { + // This is not a deterministic unit. But if this unit test is flaky, then there is definitely + // something wrong. See SPARK-5681 + val conf = new SparkConf().setMaster(master).setAppName(appName) + sc = new SparkContext(conf) + ssc = new StreamingContext(sc, Milliseconds(100)) + val input = ssc.receiverStream(new TestReceiver) + input.foreachRDD(_ => {}) + ssc.start() + // Call `ssc.stop` at once so that it's possible that the receiver will miss "StopReceiver" + failAfter(30000 millis) { + ssc.stop(stopSparkContext = true, stopGracefully = true) + } + } + test("stop slow receiver gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) conf.set("spark.streaming.gracefulStopTimeout", "20000s") From 6da1069696186572c66cbd83947c1a1dbd2bc827 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 17 Jul 2015 14:00:53 -0700 Subject: [PATCH 19/58] [SPARK-9090] [ML] Fix definition of residual in LinearRegressionSummary, EnsembleTestHelper, and SquaredError Make the definition of residuals in Spark consistent with literature. We have been using `prediction - label` for residuals, but literature usually defines `residual = label - prediction`. 
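As a standalone illustration of the sign convention being adopted (the numbers are made up; only the direction of the subtraction matters):

```scala
// Residuals after this patch follow the textbook definition: label - prediction.
val label = 3.0
val prediction = 2.5
val residual = label - prediction        // 0.5, what summary.residuals now reports
val oldConvention = prediction - label   // -0.5, the definition being removed
println(s"new = $residual, old = $oldConvention")
```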
Author: Feynman Liang Closes #7435 from feynmanliang/SPARK-9090-Fix-LinearRegressionSummary-Residuals and squashes the following commits: f4b39d8 [Feynman Liang] Fix doc bc12a92 [Feynman Liang] Tweak EnsembleTestHelper and SquaredError residuals 63f0d60 [Feynman Liang] Fix definition of residual --- .../org/apache/spark/ml/regression/LinearRegression.scala | 4 ++-- .../scala/org/apache/spark/mllib/tree/loss/SquaredError.scala | 4 ++-- .../apache/spark/ml/regression/LinearRegressionSuite.scala | 4 ++-- .../org/apache/spark/mllib/tree/EnsembleTestHelper.scala | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 8fc986056657d..89718e0f3e15a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -355,9 +355,9 @@ class LinearRegressionSummary private[regression] ( */ val r2: Double = metrics.r2 - /** Residuals (predicted value - label value) */ + /** Residuals (label - predicted value) */ @transient lazy val residuals: DataFrame = { - val t = udf { (pred: Double, label: Double) => pred - label} + val t = udf { (pred: Double, label: Double) => label - pred } predictions.select(t(col(predictionCol), col(labelCol)).as("residuals")) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index a5582d3ef3324..011a5d57422f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -42,11 +42,11 @@ object SquaredError extends Loss { * @return Loss gradient */ override def gradient(prediction: Double, label: Double): Double = { - 2.0 * (prediction - label) + - 2.0 * (label - prediction) } override private[mllib] def computeError(prediction: Double, label: Double): Double = { - val err = prediction - label + val err = label - prediction err * err } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index cf120cf2a4b47..374002c5b4fdd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -302,7 +302,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .map { case Row(features: DenseVector, label: Double) => val prediction = features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept - prediction - label + label - prediction } .zip(model.summary.residuals.map(_.getDouble(0))) .collect() @@ -314,7 +314,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { Use the following R code to generate model training results. 
predictions <- predict(fit, newx=features) - residuals <- predictions - label + residuals <- label - predictions > mean(residuals^2) # MSE [1] 0.009720325 > mean(abs(residuals)) # MAD diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala index 8972c229b7ecb..334bf3790fc7a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -70,7 +70,7 @@ object EnsembleTestHelper { metricName: String = "mse") { val predictions = input.map(x => model.predict(x.features)) val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) => - prediction - label + label - prediction } val metric = metricName match { case "mse" => From 830666f6fe1e77faa39eed2c1c3cd8e83bc93ef9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 17 Jul 2015 14:08:06 -0700 Subject: [PATCH 20/58] [SPARK-8792] [ML] Add Python API for PCA transformer Add Python API for PCA transformer Author: Yanbo Liang Closes #7190 from yanboliang/spark-8792 and squashes the following commits: 8f4ac31 [Yanbo Liang] address comments 8a79cc0 [Yanbo Liang] Add Python API for PCA transformer --- python/pyspark/ml/feature.py | 64 +++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 9bca7cc000aa5..86e654dd0779f 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -24,7 +24,7 @@ __all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', - 'Word2Vec', 'Word2VecModel'] + 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel'] @inherit_doc @@ -1048,6 +1048,68 @@ class Word2VecModel(JavaModel): """ +@inherit_doc +class PCA(JavaEstimator, HasInputCol, HasOutputCol): + """ + PCA trains a model to project vectors to a low-dimensional space using PCA. + + >>> from pyspark.mllib.linalg import Vectors + >>> data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + ... (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + ... (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] + >>> df = sqlContext.createDataFrame(data,["features"]) + >>> pca = PCA(k=2, inputCol="features", outputCol="pca_features") + >>> model = pca.fit(df) + >>> model.transform(df).collect()[0].pca_features + DenseVector([1.648..., -4.013...]) + """ + + # a placeholder to make it appear in the generated doc + k = Param(Params._dummy(), "k", "the number of principal components") + + @keyword_only + def __init__(self, k=None, inputCol=None, outputCol=None): + """ + __init__(self, k=None, inputCol=None, outputCol=None) + """ + super(PCA, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid) + self.k = Param(self, "k", "the number of principal components") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, k=None, inputCol=None, outputCol=None): + """ + setParams(self, k=None, inputCol=None, outputCol=None) + Set params for this PCA. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setK(self, value): + """ + Sets the value of :py:attr:`k`. 
+ """ + self._paramMap[self.k] = value + return self + + def getK(self): + """ + Gets the value of k or its default value. + """ + return self.getOrDefault(self.k) + + def _create_model(self, java_model): + return PCAModel(java_model) + + +class PCAModel(JavaModel): + """ + Model fitted by PCA. + """ + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 8b8be1f5d698e796b96a92f1ed2c13162a90944e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 17 Jul 2015 14:10:16 -0700 Subject: [PATCH 21/58] [SPARK-7127] [MLLIB] Adding broadcast of model before prediction for ensembles Broadcast of ensemble models in transformImpl before call to predict Author: Bryan Cutler Closes #6300 from BryanCutler/bcast-ensemble-models-7127 and squashes the following commits: 86e73de [Bryan Cutler] [SPARK-7127] Replaced deprecated callUDF with udf 40a139d [Bryan Cutler] Merge branch 'master' into bcast-ensemble-models-7127 9afad56 [Bryan Cutler] [SPARK-7127] Simplified calls by overriding transformImpl and using broadcasted model in callUDF to make prediction 1f34be4 [Bryan Cutler] [SPARK-7127] Removed accidental newline 171a6ce [Bryan Cutler] [SPARK-7127] Used modelAccessor parameter in predictImpl to access broadcasted model 6fd153c [Bryan Cutler] [SPARK-7127] Applied broadcasting to remaining ensemble models aaad77b [Bryan Cutler] [SPARK-7127] Removed abstract class for broadcasting model, instead passing a prediction function as param to transform 83904bb [Bryan Cutler] [SPARK-7127] Adding broadcast of model before prediction in RandomForestClassifier --- .../main/scala/org/apache/spark/ml/Predictor.scala | 12 ++++++++---- .../spark/ml/classification/GBTClassifier.scala | 11 ++++++++++- .../ml/classification/RandomForestClassifier.scala | 11 ++++++++++- .../apache/spark/ml/regression/GBTRegressor.scala | 11 ++++++++++- .../spark/ml/regression/RandomForestRegressor.scala | 11 ++++++++++- 5 files changed, 48 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 333b42711ec52..19fe039b8fd03 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -169,10 +169,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { - val predictUDF = udf { (features: Any) => - predict(features.asInstanceOf[FeaturesType]) - } - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + transformImpl(dataset) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") @@ -180,6 +177,13 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, } } + protected def transformImpl(dataset: DataFrame): DataFrame = { + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + /** * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 554e3b8e052b2..eb0b1a0a405fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType /** * :: Experimental :: @@ -177,8 +179,15 @@ final class GBTClassificationModel( override def treeWeights: Array[Double] = _treeWeights + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + override protected def predict(features: Vector): Double = { - // TODO: Override transform() to broadcast model: SPARK-7127 // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predict(features)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 490f04c7c7172..fc0693f67cc2e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -31,6 +31,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType /** * :: Experimental :: @@ -143,8 +145,15 @@ final class RandomForestClassificationModel private[ml] ( override def treeWeights: Array[Double] = _treeWeights + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + override protected def predict(features: Vector): Double = { - // TODO: Override transform() to broadcast model. SPARK-7127 // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 // Classifies using majority votes. // Ignore the weights since all are 1.0 for now. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 47c110d027d67..e38dc73ee0ba7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType /** * :: Experimental :: @@ -167,8 +169,15 @@ final class GBTRegressionModel( override def treeWeights: Array[Double] = _treeWeights + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + override protected def predict(features: Vector): Double = { - // TODO: Override transform() to broadcast model. SPARK-7127 // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predict(features)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 5fd5c7c7bd3fc..506a878c2553b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -29,6 +29,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType /** * :: Experimental :: @@ -129,8 +131,15 @@ final class RandomForestRegressionModel private[ml] ( override def treeWeights: Array[Double] = _treeWeights + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + override protected def predict(features: Vector): Double = { - // TODO: Override transform() to broadcast model. SPARK-7127 // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. From 42d8a012f6652df1fa3f560f87c53731ea070640 Mon Sep 17 00:00:00 2001 From: Joshi Date: Fri, 17 Jul 2015 22:47:28 +0100 Subject: [PATCH 22/58] [SPARK-8593] [CORE] Sort app attempts by start time. This makes sure attempts are listed in the order they were executed, and that the app's state matches the state of the most current attempt. 
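Before the diff, a sketch of the new ordering rule in isolation: attempts are compared by descending start time alone, so the head of the sorted list is always the most recently started attempt and the app's displayed state follows it. `Attempt` is a made-up stand-in for `FsApplicationAttemptInfo`, with assumed field names:

// Illustrative stand-in for FsApplicationAttemptInfo; fields are assumptions.
case class Attempt(attemptId: String, startTime: Long, completed: Boolean)

object AttemptOrderingExample {
  def main(args: Array[String]): Unit = {
    val attempts = Seq(
      Attempt("attempt1", startTime = 1L, completed = true),
      Attempt("attempt3", startTime = 3L, completed = true),
      Attempt("attempt2", startTime = 2L, completed = false))

    // New rule: descending start time only (equivalent to the patch's
    // a1.startTime >= a2.startTime); completion status no longer matters.
    val sorted = attempts.sortBy(a => -a.startTime)

    // Head is "attempt3", the most recently started attempt.
    sorted.foreach(a =>
      println(s"${a.attemptId} start=${a.startTime} completed=${a.completed}"))
  }
}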
Author: Joshi Author: Rekha Joshi Closes #7253 from rekhajoshm/SPARK-8593 and squashes the following commits: 874dd80 [Joshi] History Server: updated order for multiple attempts(logcleaner) 716e0b1 [Joshi] History Server: updated order for multiple attempts(descending start time works everytime) 548c753 [Joshi] History Server: updated order for multiple attempts(descending start time works everytime) 83306a8 [Joshi] History Server: updated order for multiple attempts(descending start time) b0fc922 [Joshi] History Server: updated order for multiple attempts(updated comment) cc0fda7 [Joshi] History Server: updated order for multiple attempts(updated test) 304cb0b [Joshi] History Server: updated order for multiple attempts(reverted HistoryPage) 85024e8 [Joshi] History Server: updated order for multiple attempts a41ac4b [Joshi] History Server: updated order for multiple attempts ab65fa1 [Joshi] History Server: some attempt completed to work with showIncomplete 0be142d [Rekha Joshi] Merge pull request #3 from apache/master 106fd8e [Rekha Joshi] Merge pull request #2 from apache/master e3677c9 [Rekha Joshi] Merge pull request #1 from apache/master --- .../deploy/history/FsHistoryProvider.scala | 10 +++----- .../history/FsHistoryProviderSuite.scala | 24 +++++++++---------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2cc465e55fceb..e3060ac3fa1a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -407,8 +407,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Comparison function that defines the sort order for application attempts within the same - * application. Order is: running attempts before complete attempts, running attempts sorted - * by start time, completed attempts sorted by end time. + * application. Order is: attempts are sorted by descending start time. + * The state of the most recent attempt matches the current state of the app. * * Normally applications should have a single running attempt; but failure to call sc.stop() * may cause multiple running attempts to show up.
@@ -418,11 +418,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private def compareAttemptInfo( a1: FsApplicationAttemptInfo, a2: FsApplicationAttemptInfo): Boolean = { - if (a1.completed == a2.completed) { - if (a1.completed) a1.endTime >= a2.endTime else a1.startTime >= a2.startTime - } else { - !a1.completed - } + a1.startTime >= a2.startTime } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 2a62450bcdbad..73cff89544dc3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -243,13 +243,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc appListAfterRename.size should be (1) } - test("apps with multiple attempts") { + test("apps with multiple attempts with order") { val provider = new FsHistoryProvider(createTestConf()) - val attempt1 = newLogFile("app1", Some("attempt1"), inProgress = false) + val attempt1 = newLogFile("app1", Some("attempt1"), inProgress = true) writeFile(attempt1, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 1L, "test", Some("attempt1")), - SparkListenerApplicationEnd(2L) + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test", Some("attempt1")) ) updateAndCheck(provider) { list => @@ -259,7 +258,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val attempt2 = newLogFile("app1", Some("attempt2"), inProgress = true) writeFile(attempt2, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt2")) + SparkListenerApplicationStart("app1", Some("app1"), 2L, "test", Some("attempt2")) ) updateAndCheck(provider) { list => @@ -268,22 +267,21 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc list.head.attempts.head.attemptId should be (Some("attempt2")) } - val completedAttempt2 = newLogFile("app1", Some("attempt2"), inProgress = false) - attempt2.delete() - writeFile(attempt2, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt2")), + val attempt3 = newLogFile("app1", Some("attempt3"), inProgress = false) + writeFile(attempt3, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt3")), SparkListenerApplicationEnd(4L) ) updateAndCheck(provider) { list => list should not be (null) list.size should be (1) - list.head.attempts.size should be (2) - list.head.attempts.head.attemptId should be (Some("attempt2")) + list.head.attempts.size should be (3) + list.head.attempts.head.attemptId should be (Some("attempt3")) } val app2Attempt1 = newLogFile("app2", Some("attempt1"), inProgress = false) - writeFile(attempt2, true, None, + writeFile(attempt1, true, None, SparkListenerApplicationStart("app2", Some("app2"), 5L, "test", Some("attempt1")), SparkListenerApplicationEnd(6L) ) @@ -291,7 +289,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc updateAndCheck(provider) { list => list.size should be (2) list.head.attempts.size should be (1) - list.last.attempts.size should be (2) + list.last.attempts.size should be (3) list.head.attempts.head.attemptId should be (Some("attempt1")) list.foreach { case app => From b2aa490bb60176631c94ecadf87c14564960f12c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 17 Jul 2015 
15:02:13 -0700 Subject: [PATCH 23/58] [SPARK-9142] [SQL] Removing unnecessary self types in Catalyst. Just a small change to add Product type to the base expression/plan abstract classes, based on suggestions on #7434 and offline discussions. Author: Reynold Xin Closes #7479 from rxin/remove-self-types and squashes the following commits: e407ffd [Reynold Xin] [SPARK-9142][SQL] Removing unnecessary self types in Catalyst. --- .../apache/spark/sql/catalyst/analysis/unresolved.scala | 1 - .../spark/sql/catalyst/expressions/Expression.scala | 7 +------ .../spark/sql/catalyst/expressions/aggregates.scala | 3 --- .../spark/sql/catalyst/expressions/arithmetic.scala | 1 - .../spark/sql/catalyst/expressions/conditionals.scala | 1 - .../spark/sql/catalyst/expressions/generators.scala | 2 +- .../org/apache/spark/sql/catalyst/expressions/math.scala | 5 ++--- .../sql/catalyst/expressions/namedExpressions.scala | 4 ++-- .../spark/sql/catalyst/expressions/predicates.scala | 3 --- .../apache/spark/sql/catalyst/expressions/random.scala | 1 - .../sql/catalyst/expressions/windowExpressions.scala | 2 -- .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 9 +-------- .../sql/catalyst/plans/logical/basicOperators.scala | 2 +- .../spark/sql/catalyst/plans/logical/partitioning.scala | 2 -- .../scala/org/apache/spark/sql/execution/SparkPlan.scala | 9 +-------- .../scala/org/apache/spark/sql/execution/commands.scala | 2 -- .../org/apache/spark/sql/parquet/ParquetRelation.scala | 2 -- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 2 -- 18 files changed, 9 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 7089f079b6dde..4a1a1ed61ebe7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -96,7 +96,6 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E * "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis. */ abstract class Star extends LeafExpression with NamedExpression { - self: Product => override def name: String = throw new UnresolvedException(this, "name") override def exprId: ExprId = throw new UnresolvedException(this, "exprId") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f396bd08a8238..c70b5af4aa448 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -43,8 +43,7 @@ import org.apache.spark.sql.types._ * * See [[Substring]] for an example. */ -abstract class Expression extends TreeNode[Expression] { - self: Product => +abstract class Expression extends TreeNode[Expression] with Product { /** * Returns true when an expression is a candidate for static evaluation before the query is @@ -187,7 +186,6 @@ abstract class Expression extends TreeNode[Expression] { * A leaf expression, i.e. one without any child expressions. */ abstract class LeafExpression extends Expression { - self: Product => def children: Seq[Expression] = Nil } @@ -198,7 +196,6 @@ abstract class LeafExpression extends Expression { * if the input is evaluated to null. 
*/ abstract class UnaryExpression extends Expression { - self: Product => def child: Expression @@ -277,7 +274,6 @@ abstract class UnaryExpression extends Expression { * if any input is evaluated to null. */ abstract class BinaryExpression extends Expression { - self: Product => def left: Expression def right: Expression @@ -370,7 +366,6 @@ abstract class BinaryExpression extends Expression { * the analyzer will find the tightest common type and do the proper type casting. */ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { - self: Product => /** * Expected input type from both left/right child expressions, similar to the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 71c943dc79e9e..af9a674ab4958 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet trait AggregateExpression extends Expression { - self: Product => /** * Aggregate expressions should not be foldable. @@ -65,7 +64,6 @@ case class SplitEvaluation( * These partial evaluations can then be combined to compute the actual answer. */ trait PartialAggregate extends AggregateExpression { - self: Product => /** * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. @@ -79,7 +77,6 @@ trait PartialAggregate extends AggregateExpression { */ abstract class AggregateFunction extends LeafExpression with AggregateExpression with Serializable { - self: Product => /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1616d1bc0aed5..c5960eb390ea4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -77,7 +77,6 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes } abstract class BinaryArithmetic extends BinaryOperator { - self: Product => override def dataType: DataType = left.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 9162b73fe56eb..15b33da884dcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -77,7 +77,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } trait CaseWhenLike extends Expression { - self: Product => // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last // element is the value for the default catch-all case (if provided). 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 51dc77ee3fc5f..c58a6d36141c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types._ * requested. The attributes produced by this function will be automatically copied anytime rules * result in changes to the Generator or its children. */ -trait Generator extends Expression { self: Product => +trait Generator extends Expression { // TODO ideally we should return the type of ArrayType(StructType), // however, we don't keep the output field names in the Generator. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 7a543ff36afd1..b05a7b3ed0ea4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -34,7 +34,6 @@ import org.apache.spark.unsafe.types.UTF8String */ abstract class LeafMathExpression(c: Double, name: String) extends LeafExpression with Serializable { - self: Product => override def dataType: DataType = DoubleType override def foldable: Boolean = true @@ -58,7 +57,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with ImplicitCastInputTypes { self: Product => + extends UnaryExpression with Serializable with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -92,7 +91,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ImplicitCastInputTypes { self: Product => + extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8bf7a7ce4e647..c083ac08ded21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -40,7 +40,7 @@ case class ExprId(id: Long) /** * An [[Expression]] that is named. */ -trait NamedExpression extends Expression { self: Product => +trait NamedExpression extends Expression { /** We should never fold named expressions in order to not remove the alias. 
*/ override def foldable: Boolean = false @@ -83,7 +83,7 @@ trait NamedExpression extends Expression { self: Product => } } -abstract class Attribute extends LeafExpression with NamedExpression { self: Product => +abstract class Attribute extends LeafExpression with NamedExpression { override def references: AttributeSet = AttributeSet(this) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index aa6c30e2f79f2..7a6fb2b3788ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -38,8 +38,6 @@ object InterpretedPredicate { * An [[Expression]] that returns a boolean value. */ trait Predicate extends Expression { - self: Product => - override def dataType: DataType = BooleanType } @@ -222,7 +220,6 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P abstract class BinaryComparison extends BinaryOperator with Predicate { - self: Product => override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (ctx.isPrimitiveType(left.dataType)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index e10ba55396664..65093dc72264b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -33,7 +33,6 @@ import org.apache.spark.util.random.XORShiftRandom * Since this expression is stateful, it cannot be a case object. */ abstract class RDG(seed: Long) extends LeafExpression with Serializable { - self: Product => /** * Record ID within each partition. By being transient, the Random Number Generator is diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 344361685853f..c8aa571df64fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -254,8 +254,6 @@ object SpecifiedWindowFrame { * to retrieve value corresponding with these n rows. */ trait WindowFunction extends Expression { - self: Product => - def init(): Unit def reset(): Unit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index adac37231cc4a..dd6c5d43f5714 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -25,8 +25,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode -abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { - self: Product => +abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging with Product { /** * Computes [[Statistics]] for this plan.
The default implementation assumes the output @@ -277,8 +276,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * A logical plan node with no children. */ abstract class LeafNode extends LogicalPlan { - self: Product => - override def children: Seq[LogicalPlan] = Nil } @@ -286,8 +283,6 @@ abstract class LeafNode extends LogicalPlan { * A logical plan node with single child. */ abstract class UnaryNode extends LogicalPlan { - self: Product => - def child: LogicalPlan override def children: Seq[LogicalPlan] = child :: Nil @@ -297,8 +292,6 @@ abstract class UnaryNode extends LogicalPlan { * A logical plan node with a left and right child. */ abstract class BinaryNode extends LogicalPlan { - self: Product => - def left: LogicalPlan def right: LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index fae339808c233..fbe104db016d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -298,7 +298,7 @@ case class Expand( } trait GroupingAnalytics extends UnaryNode { - self: Product => + def groupByExprs: Seq[Expression] def aggregations: Seq[NamedExpression] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 63df2c1ee72ff..1f76b03bcb0f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -24,8 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrd * result have expectations about the distribution and ordering of partitioned input data. */ abstract class RedistributeData extends UnaryNode { - self: Product => - override def output: Seq[Attribute] = child.output } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 632f633d82a2e..ba12056ee7a1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -39,8 +39,7 @@ object SparkPlan { * :: DeveloperApi :: */ @DeveloperApi -abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { - self: Product => +abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Product with Serializable { /** * A handle to the SQL Context that was used to create this plan. 
Since many operators need @@ -239,14 +238,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } private[sql] trait LeafNode extends SparkPlan { - self: Product => - override def children: Seq[SparkPlan] = Nil } private[sql] trait UnaryNode extends SparkPlan { - self: Product => - def child: SparkPlan override def children: Seq[SparkPlan] = child :: Nil @@ -255,8 +250,6 @@ private[sql] trait UnaryNode extends SparkPlan { } private[sql] trait BinaryNode extends SparkPlan { - self: Product => - def left: SparkPlan def right: SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 5e9951f248ff2..bace3f8a9c8d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -35,8 +35,6 @@ import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} * wrapped in `ExecutedCommand` during execution. */ private[sql] trait RunnableCommand extends LogicalPlan with logical.Command { - self: Product => - override def output: Seq[Attribute] = Seq.empty override def children: Seq[LogicalPlan] = Seq.empty def run(sqlContext: SQLContext): Seq[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index e0bea65a15f36..086559e9f7658 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -54,8 +54,6 @@ private[sql] case class ParquetRelation( partitioningAttributes: Seq[Attribute] = Nil) extends LeafNode with MultiInstanceRelation { - self: Product => - /** Schema derived from ParquetFile */ def parquetSchema: MessageType = ParquetTypesConverter diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4b7a782c805a0..6589bc6ea2921 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -596,8 +596,6 @@ private[hive] case class MetastoreRelation (@transient sqlContext: SQLContext) extends LeafNode with MultiInstanceRelation { - self: Product => - override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => databaseName == relation.databaseName && From 15fc2ffe5530c43c64cfc37f2d1ce83f04ce3bd9 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 17 Jul 2015 15:49:31 -0700 Subject: [PATCH 24/58] [SPARK-9080][SQL] add isNaN predicate expression JIRA: https://issues.apache.org/jira/browse/SPARK-9080 cc rxin Author: Yijie Shen Closes #7464 from yijieshen/isNaN and squashes the following commits: 11ae039 [Yijie Shen] add isNaN in functions 666718e [Yijie Shen] add isNaN predicate expression --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/predicates.scala | 50 +++++++++++++++++++ .../catalyst/expressions/PredicateSuite.scala | 12 ++++- .../scala/org/apache/spark/sql/Column.scala | 8 +++ .../org/apache/spark/sql/functions.scala | 10 +++- .../spark/sql/ColumnExpressionSuite.scala | 21 ++++++++ 6 files changed, 100 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a45181712dbdf..7bb2579506a8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -79,6 +79,7 @@ object FunctionRegistry { expression[Explode]("explode"), expression[Greatest]("greatest"), expression[If]("if"), + expression[IsNaN]("isnan"), expression[IsNull]("isnull"), expression[IsNotNull]("isnotnull"), expression[Least]("least"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 7a6fb2b3788ca..2751c8e75f357 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -120,6 +120,56 @@ case class InSet(child: Expression, hset: Set[Any]) } } +/** + * Evaluates to `true` if it's NaN or null + */ +case class IsNaN(child: Expression) extends UnaryExpression + with Predicate with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType)) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + true + } else { + child.dataType match { + case DoubleType => value.asInstanceOf[Double].isNaN + case FloatType => value.asInstanceOf[Float].isNaN + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + child.dataType match { + case FloatType => + s""" + ${eval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (${eval.isNull}) { + ${ev.primitive} = true; + } else { + ${ev.primitive} = Float.isNaN(${eval.primitive}); + } + """ + case DoubleType => + s""" + ${eval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (${eval.isNull}) { + ${ev.primitive} = true; + } else { + ${ev.primitive} = Double.isNaN(${eval.primitive}); + } + """ + } + } +} case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 188ecef9e7679..052abc51af5fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{Decimal, IntegerType, BooleanType} +import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType} class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -116,6 +116,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { true) } + test("IsNaN") { + checkEvaluation(IsNaN(Literal(Double.NaN)), true) + checkEvaluation(IsNaN(Literal(Float.NaN)), true) + checkEvaluation(IsNaN(Literal(math.log(-3))), true) + 
checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true) + checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false) + checkEvaluation(IsNaN(Literal(Float.MaxValue)), false) + checkEvaluation(IsNaN(Literal(5.5f)), false) + } + test("INSET") { val hS = HashSet[Any]() + 1 + 2 val nS = HashSet[Any]() + 1 + 2 + null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 10250264625b2..221cd04c6d288 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -400,6 +400,14 @@ class Column(protected[sql] val expr: Expression) extends Logging { (this >= lowerBound) && (this <= upperBound) } + /** + * True if the current expression is NaN or null + * + * @group expr_ops + * @since 1.5.0 + */ + def isNaN: Column = IsNaN(expr) + /** * True if the current expression is null. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index fe511c296cfd2..b56fd9a71b321 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -620,7 +620,15 @@ object functions { def explode(e: Column): Column = Explode(e.expr) /** - * Converts a string exprsesion to lower case. + * Return true if the column is NaN or null + * + * @group normal_funcs + * @since 1.5.0 + */ + def isNaN(e: Column): Column = IsNaN(e.expr) + + /** + * Converts a string expression to lower case. * * @group normal_funcs * @since 1.3.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 88bb743ab0bc9..8f15479308391 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -201,6 +201,27 @@ class ColumnExpressionSuite extends QueryTest { Row(false, true)) } + test("isNaN") { + val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(Double.NaN, Float.NaN) :: + Row(math.log(-1), math.log(-3).toFloat) :: + Row(null, null) :: + Row(Double.MaxValue, Float.MinValue):: Nil), + StructType(Seq(StructField("a", DoubleType), StructField("b", FloatType)))) + + checkAnswer( + testData.select($"a".isNaN, $"b".isNaN), + Row(true, true) :: Row(true, true) :: Row(true, true) :: Row(false, false) :: Nil) + + checkAnswer( + testData.select(isNaN($"a"), isNaN($"b")), + Row(true, true) :: Row(true, true) :: Row(true, true) :: Row(false, false) :: Nil) + + checkAnswer( + ctx.sql("select isnan(15), isnan('invalid')"), + Row(false, true)) + } + test("===") { checkAnswer( testData2.filter($"a" === 1), From fd6b3101fbb0a8c3ebcf89ce9b4e8664406d9869 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 17 Jul 2015 16:03:33 -0700 Subject: [PATCH 25/58] [SPARK-9113] [SQL] enable analysis check code for self join The check was unreachable before, as `case operator: LogicalPlan` catches everything already. 
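The failure mode here is ordinary Scala pattern-matching behavior: a catch-all case placed before a more specific case makes the specific case dead code. A hedged sketch with toy types (not the Catalyst classes):

// Toy hierarchy standing in for LogicalPlan / Join; these are assumptions,
// not the real Catalyst types.
sealed trait Plan
case class Join(left: String, right: String) extends Plan
case class Other(name: String) extends Plan

object UnreachableCaseExample {
  def check(plan: Plan): String = plan match {
    // The catch-all matches every Plan, so the Join case below can never fire
    // (recent scalac versions warn that it is unreachable). This mirrors how
    // `case operator: LogicalPlan` swallowed the self-join check.
    case p: Plan => s"generic check on $p"
    case j: Join => s"self-join check on $j"
  }

  def main(args: Array[String]): Unit = {
    // Prints the generic branch even for a Join; the self-join branch never runs.
    println(check(Join("a", "a")))
    println(check(Other("x")))
  }
}

Moving the specific `Join` case ahead of the catch-all, as the `CheckAnalysis` diff below does, is the essence of the fix.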
Author: Wenchen Fan Closes #7449 from cloud-fan/tmp and squashes the following commits: 2bb6637 [Wenchen Fan] add test 5493aea [Wenchen Fan] add the check back 27221a7 [Wenchen Fan] remove unnecessary analysis check code for self join --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 28 +++++++++---------- .../plans/logical/basicOperators.scala | 6 ++-- .../analysis/AnalysisErrorSuite.scala | 14 ++++++++-- 4 files changed, 29 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index df8e7f2381fbd..e58f3f64947f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -316,7 +316,7 @@ class Analyzer( ) // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + case j @ Join(left, right, _, _) if !j.selfJoinResolved => val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 476ac2b7cb474..c7f9713344c50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -109,29 +109,27 @@ trait CheckAnalysis { s"resolved attribute(s) $missingAttributes missing from $input " + s"in operator ${operator.simpleString}") - case o if !o.resolved => - failAnalysis( - s"unresolved operator ${operator.simpleString}") - case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => failAnalysis( s"""Only a single table generating function is allowed in a SELECT clause, found: | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) + // Special handling for cases when self-join introduce duplicate expression ids. + case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Join: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + + case o if !o.resolved => + failAnalysis( + s"unresolved operator ${operator.simpleString}") case _ => // Analysis successful! } - - // Special handling for cases when self-join introduce duplicate expression ids. 
- case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) - failAnalysis( - s""" - |Failure when resolving conflicting references in Join: - |$plan - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - |""".stripMargin) - } extendedCheckRules.foreach(_(plan)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index fbe104db016d6..17a91247327f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -123,11 +123,11 @@ case class Join( } } - private def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - // Joins are only resolved if they don't introduce ambiguious expression ids. + // Joins are only resolved if they don't introduce ambiguous expression ids. override lazy val resolved: Boolean = { - childrenResolved && !expressions.exists(!_.resolved) && selfJoinResolved + childrenResolved && expressions.forall(_.resolved) && selfJoinResolved } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index f0f17103991ef..2147d07e09bd3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,10 +23,11 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.{InternalRow, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.types._ case class TestFunction( children: Seq[Expression], @@ -164,4 +165,13 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { assert(message.contains("resolved attribute(s) a#1 missing from a#2")) } + + test("error test for self-join") { + val join = Join(testRelation, testRelation, Inner, None) + val error = intercept[AnalysisException] { + SimpleAnalyzer.checkAnalysis(join) + } + error.message.contains("Failure when resolving conflicting references in Join") + error.message.contains("Conflicting attributes") + } } From bd903ee89f1d1bc4daf63f1f07958cb86d667e1e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 17 Jul 2015 16:28:24 -0700 Subject: [PATCH 26/58] [SPARK-9117] [SQL] fix BooleanSimplification in case-insensitive Author: Wenchen Fan Closes #7452 from cloud-fan/boolean-simplify and squashes the following commits: 2a6e692 [Wenchen Fan] fix style d3cfd26 [Wenchen Fan] fix BooleanSimplification in case-insensitive --- .../sql/catalyst/optimizer/Optimizer.scala | 28 +++++----- .../BooleanSimplificationSuite.scala | 55 +++++++++---------- 2 files changed, 40 insertions(+), 43 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d5beeec0ffac1..0f28a0d2c8fff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -393,26 +393,26 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { // (a || b) && (a || c) => a || (b && c) case _ => // 1. Split left and right to get the disjunctive predicates, - // i.e. lhsSet = (a, b), rhsSet = (a, c) + // i.e. lhs = (a, b), rhs = (a, c) // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff) - val lhsSet = splitDisjunctivePredicates(left).toSet - val rhsSet = splitDisjunctivePredicates(right).toSet - val common = lhsSet.intersect(rhsSet) + val lhs = splitDisjunctivePredicates(left) + val rhs = splitDisjunctivePredicates(right) + val common = lhs.filter(e => rhs.exists(e.semanticEquals(_))) if (common.isEmpty) { // No common factors, return the original predicate and } else { - val ldiff = lhsSet.diff(common) - val rdiff = rhsSet.diff(common) + val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_))) + val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_))) if (ldiff.isEmpty || rdiff.isEmpty) { // (a || b || c || ...) && (a || b) => (a || b) common.reduce(Or) } else { // (a || b || c || ...) && (a || b || d || ...) => // ((c || ...) && (d || ...)) || a || b - (common + And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) + (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) } } } // end of And(left, right) @@ -431,26 +431,26 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { // (a && b) || (a && c) => a && (b || c) case _ => // 1. Split left and right to get the conjunctive predicates, - // i.e. lhsSet = (a, b), rhsSet = (a, c) + // i.e. lhs = (a, b), rhs = (a, c) // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff) - val lhsSet = splitConjunctivePredicates(left).toSet - val rhsSet = splitConjunctivePredicates(right).toSet - val common = lhsSet.intersect(rhsSet) + val lhs = splitConjunctivePredicates(left) + val rhs = splitConjunctivePredicates(right) + val common = lhs.filter(e => rhs.exists(e.semanticEquals(_))) if (common.isEmpty) { // No common factors, return the original predicate or } else { - val ldiff = lhsSet.diff(common) - val rdiff = rhsSet.diff(common) + val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_))) + val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_))) if (ldiff.isEmpty || rdiff.isEmpty) { // (a && b) || (a && b && c && ...) => a && b common.reduce(And) } else { // (a && b && c && ...) || (a && b && d && ...) => // ((c && ...) 
|| (d && ...)) && a && b - (common + Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) + (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) } } } // end of Or(left, right) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 465a5e6914204..d4916ea8d273a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.{AnalysisSuite, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -40,29 +40,11 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) - // The `foldLeft` is required to handle cases like comparing `a && (b && c)` and `(a && b) && c` - def compareConditions(e1: Expression, e2: Expression): Boolean = (e1, e2) match { - case (lhs: And, rhs: And) => - val lhsSet = splitConjunctivePredicates(lhs).toSet - val rhsSet = splitConjunctivePredicates(rhs).toSet - lhsSet.foldLeft(rhsSet) { (set, e) => - set.find(compareConditions(_, e)).map(set - _).getOrElse(set) - }.isEmpty - - case (lhs: Or, rhs: Or) => - val lhsSet = splitDisjunctivePredicates(lhs).toSet - val rhsSet = splitDisjunctivePredicates(rhs).toSet - lhsSet.foldLeft(rhsSet) { (set, e) => - set.find(compareConditions(_, e)).map(set - _).getOrElse(set) - }.isEmpty - - case (l, r) => l == r - } - - def checkCondition(input: Expression, expected: Expression): Unit = { + private def checkCondition(input: Expression, expected: Expression): Unit = { val plan = testRelation.where(input).analyze - val actual = Optimize.execute(plan).expressions.head - compareConditions(actual, expected) + val actual = Optimize.execute(plan) + val correctAnswer = testRelation.where(expected).analyze + comparePlans(actual, correctAnswer) } test("a && a => a") { @@ -86,10 +68,8 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { ('a === 'b && 'c < 1 && 'a === 5) || ('a === 'b && 'b < 5 && 'a > 1) - val expected = - (((('b > 3) && ('c > 2)) || - (('c < 1) && ('a === 5))) || - (('b < 5) && ('a > 1))) && ('a === 'b) + val expected = 'a === 'b && ( + ('b > 3 && 'c > 2) || ('c < 1 && 'a === 5) || ('b < 5 && 'a > 1)) checkCondition(input, expected) } @@ -101,10 +81,27 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2) - checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), ('b > 3 && 'c > 5) || 'a < 2) + checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5)) checkCondition( ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5), - ('b > 3 && 'a > 3 && 'a < 5) || 'a === 'b) + ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5)) + } + + private def caseInsensitiveAnalyse(plan: LogicalPlan) = + AnalysisSuite.caseInsensitiveAnalyzer.execute(plan) + + test("(a && b) || (a && c) => a && (b || c) when case insensitive") { + val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 && 
'b > 3) || ('A > 2 && 'b < 5))) + val actual = Optimize.execute(plan) + val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 && ('b > 3 || 'b < 5))) + comparePlans(actual, expected) + } + + test("(a || b) && (a || c) => a || (b && c) when case insensitive") { + val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5))) + val actual = Optimize.execute(plan) + val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 || ('b > 3 && 'b < 5))) + comparePlans(actual, expected) } } From b13ef7723f254c10c685b93eb8dc08a52527ec73 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 17 Jul 2015 16:43:18 -0700 Subject: [PATCH 27/58] [SPARK-9030] [STREAMING] Add Kinesis.createStream unit tests that actual sends data Current Kinesis unit tests do not test createStream by sending data. This PR is to add such unit test. Note that this unit will not run by default. It will only run when the relevant environment variables are set. Author: Tathagata Das Closes #7413 from tdas/kinesis-tests and squashes the following commits: 0e16db5 [Tathagata Das] Added more comments regarding testOrIgnore 1ea5ce0 [Tathagata Das] Added more comments c7caef7 [Tathagata Das] Address comments a297b59 [Tathagata Das] Reverted unnecessary change in KafkaStreamSuite 90c9bde [Tathagata Das] Removed scalatest.FunSuite deb7f4f [Tathagata Das] Removed scalatest.FunSuite 18c2208 [Tathagata Das] Changed how SparkFunSuite is inherited dbb33a5 [Tathagata Das] Added license 88f6dab [Tathagata Das] Added scala docs c6be0d7 [Tathagata Das] minor changes 24a992b [Tathagata Das] Moved KinesisTestUtils to src instead of test for future python usage 465b55d [Tathagata Das] Made unit tests optional in a nice way 4d70703 [Tathagata Das] Added license 129d436 [Tathagata Das] Minor updates cc36510 [Tathagata Das] Added KinesisStreamSuite --- .../streaming/kinesis/KinesisTestUtils.scala | 197 ++++++++++++++++++ .../streaming/kinesis/KinesisFunSuite.scala | 37 ++++ .../kinesis/KinesisReceiverSuite.scala | 17 -- .../kinesis/KinesisStreamSuite.scala | 120 +++++++++++ 4 files changed, 354 insertions(+), 17 deletions(-) create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala new file mode 100644 index 0000000000000..f6bf552e6bb8e --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer +import java.util.concurrent.TimeUnit + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Random, Success, Try} + +import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient +import com.amazonaws.services.dynamodbv2.document.DynamoDB +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark.Logging + +/** + * Shared utility methods for performing Kinesis tests that actually transfer data + */ +private class KinesisTestUtils( + val endpointUrl: String = "https://kinesis.us-west-2.amazonaws.com", + _regionName: String = "") extends Logging { + + val regionName = if (_regionName.length == 0) { + RegionUtils.getRegionByEndpoint(endpointUrl).getName() + } else { + RegionUtils.getRegion(_regionName).getName() + } + + val streamShardCount = 2 + + private val createStreamTimeoutSeconds = 300 + private val describeStreamPollTimeSeconds = 1 + + @volatile + private var streamCreated = false + private var _streamName: String = _ + + private lazy val kinesisClient = { + val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) + client.setEndpoint(endpointUrl) + client + } + + private lazy val dynamoDB = { + val dynamoDBClient = new AmazonDynamoDBClient(new DefaultAWSCredentialsProviderChain()) + dynamoDBClient.setRegion(RegionUtils.getRegion(regionName)) + new DynamoDB(dynamoDBClient) + } + + def streamName: String = { + require(streamCreated, "Stream not yet created, call createStream() to create one") + _streamName + } + + def createStream(): Unit = { + logInfo("Creating stream") + require(!streamCreated, "Stream already created") + _streamName = findNonExistentStreamName() + + // Create a stream. The number of shards determines the provisioned throughput. + val createStreamRequest = new CreateStreamRequest() + createStreamRequest.setStreamName(_streamName) + createStreamRequest.setShardCount(2) + kinesisClient.createStream(createStreamRequest) + + // The stream is now being created. Wait for it to become active. 
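+    // (waitForStreamToBeActive polls describeStream every describeStreamPollTimeSeconds
+    // seconds and fails if the stream is not ACTIVE within createStreamTimeoutSeconds.)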
+ waitForStreamToBeActive(_streamName) + streamCreated = true + logInfo("Created stream") + } + + /** + * Push data to Kinesis stream and return a map of + * shardId -> seq of (data, seq number) pushed to corresponding shard + */ + def pushData(testData: Seq[Int]): Map[String, Seq[(Int, String)]] = { + require(streamCreated, "Stream not yet created, call createStream() to create one") + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + + testData.foreach { num => + val str = num.toString + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(ByteBuffer.wrap(str.getBytes())) + .withPartitionKey(str) + + val putRecordResult = kinesisClient.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + + logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") + shardIdToSeqNumbers.toMap + } + + def describeStream(streamNameToDescribe: String = streamName): Option[StreamDescription] = { + try { + val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) + val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() + Some(desc) + } catch { + case rnfe: ResourceNotFoundException => + None + } + } + + def deleteStream(): Unit = { + try { + if (describeStream().nonEmpty) { + val deleteStreamRequest = new DeleteStreamRequest() + kinesisClient.deleteStream(streamName) + } + } catch { + case e: Exception => + logWarning(s"Could not delete stream $streamName") + } + } + + def deleteDynamoDBTable(tableName: String): Unit = { + try { + val table = dynamoDB.getTable(tableName) + table.delete() + table.waitForDelete() + } catch { + case e: Exception => + logWarning(s"Could not delete DynamoDB table $tableName") + } + } + + private def findNonExistentStreamName(): String = { + var testStreamName: String = null + do { + Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds)) + testStreamName = s"KinesisTestUtils-${math.abs(Random.nextLong())}" + } while (describeStream(testStreamName).nonEmpty) + testStreamName + } + + private def waitForStreamToBeActive(streamNameToWaitFor: String): Unit = { + val startTime = System.currentTimeMillis() + val endTime = startTime + TimeUnit.SECONDS.toMillis(createStreamTimeoutSeconds) + while (System.currentTimeMillis() < endTime) { + Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds)) + describeStream(streamNameToWaitFor).foreach { description => + val streamStatus = description.getStreamStatus() + logDebug(s"\t- current state: $streamStatus\n") + if ("ACTIVE".equals(streamStatus)) { + return + } + } + } + require(false, s"Stream $streamName never became active") + } +} + +private[kinesis] object KinesisTestUtils { + + val envVarName = "RUN_KINESIS_TESTS" + + val shouldRunTests = sys.env.get(envVarName) == Some("1") + + def isAWSCredentialsPresent: Boolean = { + Try { new DefaultAWSCredentialsProviderChain().getCredentials() }.isSuccess + } + + def getAWSCredentials(): AWSCredentials = { + assert(shouldRunTests, + "Kinesis test not enabled, should not attempt to get AWS credentials") + Try { new DefaultAWSCredentialsProviderChain().getCredentials() } match { + case Success(cred) => cred + case Failure(e) => + throw new Exception("Kinesis tests enabled, but could get not AWS credentials") + 
} + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala new file mode 100644 index 0000000000000..6d011f295e7f7 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.apache.spark.SparkFunSuite + +/** + * Helper class that runs Kinesis real data transfer tests or + * ignores them based on env variable is set or not. + */ +trait KinesisSuiteHelper { self: SparkFunSuite => + import KinesisTestUtils._ + + /** Run the test if environment variable is set or ignore the test */ + def testOrIgnore(testName: String)(testBody: => Unit) { + if (shouldRunTests) { + test(testName)(testBody) + } else { + ignore(s"$testName [enable by setting env var $envVarName=1]")(testBody) + } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 2103dca6b766f..98f2c7c4f1bfb 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -73,23 +73,6 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointStateMock, currentClockMock) } - test("KinesisUtils API") { - val ssc = new StreamingContext(master, framework, batchDuration) - // Tests the API, does not actually test data receiving - val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", Seconds(2), - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) - val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) - val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, - "awsAccessKey", "awsSecretKey") - - ssc.stop() - } - test("check serializability of SerializableAWSCredentials") { Utils.deserialize[SerializableAWSCredentials]( Utils.serialize(new SerializableAWSCredentials("x", "y"))) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala 
b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala new file mode 100644 index 0000000000000..d3dd541fe4371 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.scalatest.concurrent.Eventually +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class KinesisStreamSuite extends SparkFunSuite with KinesisSuiteHelper + with Eventually with BeforeAndAfter with BeforeAndAfterAll { + + private val kinesisTestUtils = new KinesisTestUtils() + + // This is the name that KCL uses to save metadata to DynamoDB + private val kinesisAppName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}" + + private var ssc: StreamingContext = _ + private var sc: SparkContext = _ + + override def beforeAll(): Unit = { + kinesisTestUtils.createStream() + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name + sc = new SparkContext(conf) + } + + override def afterAll(): Unit = { + sc.stop() + // Delete the Kinesis stream as well as the DynamoDB table generated by + // Kinesis Client Library when consuming the stream + kinesisTestUtils.deleteStream() + kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) + } + + before { + // Delete the DynamoDB table generated by Kinesis Client Library when + // consuming from the stream, so that each unit test can start from + // scratch without prior history of data consumption + kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) + } + + after { + if (ssc != null) { + ssc.stop(stopSparkContext = false) + ssc = null + } + } + + test("KinesisUtils API") { + ssc = new StreamingContext(sc, Seconds(1)) + // Tests the API, does not actually test data receiving + val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", Seconds(2), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + 
"https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + } + + + /** + * Test the stream by sending data to a Kinesis stream and receiving from it. + * This test is not run by default as it requires AWS credentials that the test + * environment may not have. Even if there is AWS credentials available, the user + * may not want to run these tests to avoid the Kinesis costs. To enable this test, + * you must have AWS credentials available through the default AWS provider chain, + * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . + */ + testOrIgnore("basic operation") { + ssc = new StreamingContext(sc, Seconds(1)) + val aWSCredentials = KinesisTestUtils.getAWSCredentials() + val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName, + kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST, + Seconds(10), StorageLevel.MEMORY_ONLY, + aWSCredentials.getAWSAccessKeyId, aWSCredentials.getAWSSecretKey) + + val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => + collected ++= rdd.collect() + logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + } + ssc.start() + + val testData = 1 to 10 + eventually(timeout(120 seconds), interval(10 second)) { + kinesisTestUtils.pushData(testData) + assert(collected === testData.toSet, "\nData received does not match data sent") + } + ssc.stop() + } +} From 1707238601690fd0e8e173e2c47f1b4286644a29 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 17 Jul 2015 16:45:46 -0700 Subject: [PATCH 28/58] [SPARK-7026] [SQL] fix left semi join with equi key and non-equi condition When the `condition` extracted by `ExtractEquiJoinKeys` contain join Predicate for left semi join, we can not plan it as semiJoin. 
Such as SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b AND x.a >= y.a + 2 Condition `x.a >= y.a + 2` can not evaluate on table `x`, so it throw errors Author: Daoyuan Wang Closes #5643 from adrian-wang/spark7026 and squashes the following commits: cc09809 [Daoyuan Wang] refactor semijoin and add plan test 575a7c8 [Daoyuan Wang] fix notserializable 27841de [Daoyuan Wang] fix rebase 10bf124 [Daoyuan Wang] fix style 72baa02 [Daoyuan Wang] fix style 8e0afca [Daoyuan Wang] merge commits for rebase --- .../spark/sql/execution/SparkStrategies.scala | 10 +- .../joins/BroadcastLeftSemiJoinHash.scala | 42 ++++----- .../sql/execution/joins/HashOuterJoin.scala | 3 +- .../sql/execution/joins/HashSemiJoin.scala | 91 +++++++++++++++++++ .../execution/joins/LeftSemiJoinHash.scala | 35 ++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 12 +++ .../sql/execution/joins/SemiJoinSuite.scala | 74 +++++++++++++++ 7 files changed, 208 insertions(+), 59 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 73b463471ec5a..240332a80af0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -38,14 +38,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.autoBroadcastJoinThreshold > 0 && right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => - val semiJoin = joins.BroadcastLeftSemiJoinHash( - leftKeys, rightKeys, planLater(left), planLater(right)) - condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil + joins.BroadcastLeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val semiJoin = joins.LeftSemiJoinHash( - leftKeys, rightKeys, planLater(left), planLater(right)) - condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil + joins.LeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index f7b46d6888d7d..2750f58b005ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -33,37 +33,27 @@ case class BroadcastLeftSemiJoinHash( leftKeys: Seq[Expression], rightKeys: Seq[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override val buildSide: BuildSide = BuildRight - - override def output: Seq[Attribute] = left.output + right: SparkPlan, + condition: 
Option[Expression]) extends BinaryNode with HashSemiJoin { protected override def doExecute(): RDD[InternalRow] = { - val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator - val hashSet = new java.util.HashSet[InternalRow]() - var currentRow: InternalRow = null + val buildIter = right.execute().map(_.copy()).collect().toIterator - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - // rowKey may be not serializable (from codegen) - hashSet.add(rowKey.copy()) - } - } - } + if (condition.isEmpty) { + // rowKey may be not serializable (from codegen) + val hashSet = buildKeyHashSet(buildIter, copy = true) + val broadcastedRelation = sparkContext.broadcast(hashSet) - val broadcastedRelation = sparkContext.broadcast(hashSet) + left.execute().mapPartitions { streamIter => + hashSemiJoin(streamIter, broadcastedRelation.value) + } + } else { + val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + val broadcastedRelation = sparkContext.broadcast(hashRelation) - streamedPlan.execute().mapPartitions { streamIter => - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue) - }) + left.execute().mapPartitions { streamIter => + hashSemiJoin(streamIter, broadcastedRelation.value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 0522ee85eeb8a..74a7db7761758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -65,8 +65,7 @@ override def outputPartitioning: Partitioning = joinType match { @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) @transient private[this] lazy val boundCondition = - condition.map( - newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true) + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala new file mode 100644 index 0000000000000..1b983bc3a90f9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan + + +trait HashSemiJoin { + self: SparkPlan => + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val left: SparkPlan + val right: SparkPlan + val condition: Option[Expression] + + override def output: Seq[Attribute] = left.output + + @transient protected lazy val rightKeyGenerator: Projection = + newProjection(rightKeys, right.output) + + @transient protected lazy val leftKeyGenerator: () => MutableProjection = + newMutableProjection(leftKeys, left.output) + + @transient private lazy val boundCondition = + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + + protected def buildKeyHashSet( + buildIter: Iterator[InternalRow], + copy: Boolean): java.util.Set[InternalRow] = { + val hashSet = new java.util.HashSet[InternalRow]() + var currentRow: InternalRow = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = rightKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + if (copy) { + hashSet.add(rowKey.copy()) + } else { + // rowKey may be not serializable (from codegen) + hashSet.add(rowKey) + } + } + } + } + hashSet + } + + protected def hashSemiJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator() + val joinedRow = new JoinedRow + streamIter.filter(current => { + lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue) + !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists { + (build: InternalRow) => boundCondition(joinedRow(current, build)) + } + }) + } + + protected def hashSemiJoin( + streamIter: Iterator[InternalRow], + hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator() + val joinedRow = new JoinedRow + streamIter.filter(current => { + !joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue) + }) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 611ba928a16ec..9eaac817d9268 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -34,36 +34,21 @@ case class LeftSemiJoinHash( leftKeys: Seq[Expression], 
rightKeys: Seq[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override val buildSide: BuildSide = BuildRight + right: SparkPlan, + condition: Option[Expression]) extends BinaryNode with HashSemiJoin { override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = left.output - protected override def doExecute(): RDD[InternalRow] = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashSet = new java.util.HashSet[InternalRow]() - var currentRow: InternalRow = null - - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey) - } - } + right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => + if (condition.isEmpty) { + val hashSet = buildKeyHashSet(buildIter, copy = false) + hashSemiJoin(streamIter, hashSet) + } else { + val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + hashSemiJoin(streamIter, hashRelation) } - - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) - }) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5b8b70ed5ae11..61d5f2061ae18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -395,6 +395,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { ) } + test("left semi greater than predicate and equal operator") { + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"), + Seq(Row(3, 1), Row(3, 2)) + ) + + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"), + Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)) + ) + } + test("index into array of arrays") { checkAnswer( sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala new file mode 100644 index 0000000000000..927e85a7db3dc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} + + +class SemiJoinSuite extends SparkPlanTest{ + val left = Seq( + (1, 2.0), + (1, 2.0), + (2, 1.0), + (2, 1.0), + (3, 3.0) + ).toDF("a", "b") + + val right = Seq( + (2, 3.0), + (2, 3.0), + (3, 2.0), + (4, 1.0) + ).toDF("c", "d") + + val leftKeys: List[Expression] = 'a :: Nil + val rightKeys: List[Expression] = 'c :: Nil + val condition = Some(LessThan('b, 'd)) + + test("left semi join hash") { + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), + Seq( + (2, 1.0), + (2, 1.0) + ).map(Row.fromTuple)) + } + + test("left semi join BNL") { + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinBNL(left, right, condition), + Seq( + (1, 2.0), + (1, 2.0), + (2, 1.0), + (2, 1.0) + ).map(Row.fromTuple)) + } + + test("broadcast left semi join hash") { + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), + Seq( + (2, 1.0), + (2, 1.0) + ).map(Row.fromTuple)) + } +} From 529a2c2d92fef062e0078a8608fa3a8ae848c139 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 17 Jul 2015 17:33:19 -0700 Subject: [PATCH 29/58] [SPARK-8280][SPARK-8281][SQL]Handle NaN, null and Infinity in math JIRA: https://issues.apache.org/jira/browse/SPARK-8280 https://issues.apache.org/jira/browse/SPARK-8281 Author: Yijie Shen Closes #7451 from yijieshen/nan_null2 and squashes the following commits: 47a529d [Yijie Shen] style fix 63dee44 [Yijie Shen] handle log expressions similar to Hive 188be51 [Yijie Shen] null to nan in Math Expression --- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 97 ++++++++++------- .../expressions/MathFunctionsSuite.scala | 102 +++++++++++++++--- .../spark/sql/MathExpressionsSuite.scala | 7 +- .../execution/HiveCompatibilitySuite.scala | 12 ++- 5 files changed, 157 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7bb2579506a8a..ce552a1d65eda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -112,9 +112,9 @@ object FunctionRegistry { expression[Log]("ln"), expression[Log10]("log10"), expression[Log1p]("log1p"), + expression[Log2]("log2"), expression[UnaryMinus]("negative"), expression[Pi]("pi"), - expression[Log2]("log2"), expression[Pow]("pow"), expression[Pow]("power"), expression[Pmod]("pmod"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index b05a7b3ed0ea4..9101f11052218 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -65,22 +65,38 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) override def toString: String = s"$name($child)" protected override def 
nullSafeEval(input: Any): Any = { - val result = f(input.asInstanceOf[Double]) - if (result.isNaN) null else result + f(input.asInstanceOf[Double]) } // name of function in java.lang.Math def funcName: String = name.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { + defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") + } +} + +abstract class UnaryLogExpression(f: Double => Double, name: String) + extends UnaryMathExpression(f, name) { self: Product => + + // values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity + protected val yAsymptote: Double = 0.0 + + protected override def nullSafeEval(input: Any): Any = { + val d = input.asInstanceOf[Double] + if (d <= yAsymptote) null else f(d) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, c => s""" - ${ev.primitive} = java.lang.Math.${funcName}($eval); - if (Double.valueOf(${ev.primitive}).isNaN()) { + if ($c <= $yAsymptote) { ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.${funcName}($c); } """ - }) + ) } } @@ -100,8 +116,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) override def dataType: DataType = DoubleType protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val result = f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) - if (result.isNaN) null else result + f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -398,25 +413,28 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas } } -case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") +case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG") case class Log2(child: Expression) - extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { + extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { + nullSafeCodeGen(ctx, ev, c => s""" - ${ev.primitive} = java.lang.Math.log($eval) / java.lang.Math.log(2); - if (Double.valueOf(${ev.primitive}).isNaN()) { + if ($c <= $yAsymptote) { ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.log($c) / java.lang.Math.log(2); } """ - }) + ) } } -case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") +case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10") -case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") +case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") { + protected override val yAsymptote: Double = -1.0 +} case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { override def funcName: String = "rint" @@ -577,27 +595,18 @@ case class Atan2(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = { // With codegen, the values returned by -0.0 and 0.0 are different. 
Handled with +0.0 - val result = math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) - if (result.isNaN) null else result + math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") } } case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") } } @@ -699,17 +708,33 @@ case class Logarithm(left: Expression, right: Expression) this(EulerNumber(), child) } + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val dLeft = input1.asInstanceOf[Double] + val dRight = input2.asInstanceOf[Double] + // Unlike Hive, we support Log base in (0.0, 1.0] + if (dLeft <= 0.0 || dRight <= 0.0) null else math.log(dRight) / math.log(dLeft) + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val logCode = if (left.isInstanceOf[EulerNumber]) { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)") + if (left.isInstanceOf[EulerNumber]) { + nullSafeCodeGen(ctx, ev, (c1, c2) => + s""" + if ($c2 <= 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.log($c2); + } + """) } else { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") + nullSafeCodeGen(ctx, ev, (c1, c2) => + s""" + if ($c1 <= 0.0 || $c2 <= 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.log($c2) / java.lang.Math.log($c1); + } + """) } - logCode + s""" - if (Double.isNaN(${ev.primitive})) { - ${ev.isNull} = true; - } - """ } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index ca35c7ef8ae5d..df988f57fbfde 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -21,6 +21,10 @@ import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -47,6 +51,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { * @param f The functions in scala.math or elsewhere used to generate expected results * @param domain The set of values to run the function with * @param expectNull Whether the given values should return null or not + * @param expectNaN Whether the given values should eval to NaN or not * @tparam T Generic type for primitives * 
@tparam U Generic type for the output of the given function `f` */ @@ -55,11 +60,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { f: T => U, domain: Iterable[T] = (-20 to 20).map(_ * 0.1), expectNull: Boolean = false, + expectNaN: Boolean = false, evalType: DataType = DoubleType): Unit = { if (expectNull) { domain.foreach { value => checkEvaluation(c(Literal(value)), null, EmptyRow) } + } else if (expectNaN) { + domain.foreach { value => + checkNaN(c(Literal(value)), EmptyRow) + } } else { domain.foreach { value => checkEvaluation(c(Literal(value)), f(value), EmptyRow) @@ -74,16 +84,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { * @param c The DataFrame function * @param f The functions in scala.math * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @param expectNaN Whether the given values should eval to NaN or not */ private def testBinary( c: (Expression, Expression) => Expression, f: (Double, Double) => Double, domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), - expectNull: Boolean = false): Unit = { + expectNull: Boolean = false, expectNaN: Boolean = false): Unit = { if (expectNull) { domain.foreach { case (v1, v2) => checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null)) } + } else if (expectNaN) { + domain.foreach { case (v1, v2) => + checkNaN(c(Literal(v1), Literal(v2)), EmptyRow) + } } else { domain.foreach { case (v1, v2) => checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) @@ -112,6 +128,62 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Conv(Literal("11abc"), Literal(10), Literal(16)), "B") } + private def checkNaN( + expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { + checkNaNWithoutCodegen(expression, inputRow) + checkNaNWithGeneratedProjection(expression, inputRow) + checkNaNWithOptimization(expression, inputRow) + } + + private def checkNaNWithoutCodegen( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if (!actual.asInstanceOf[Double].isNaN) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation (codegen off): $expression, " + + s"actual: $actual, " + + s"expected: NaN") + } + } + + + private def checkNaNWithGeneratedProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): Unit = { + + val plan = try { + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = expression.gen(ctx) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if (!actual.asInstanceOf[Double].isNaN) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN") + } + } + + private def checkNaNWithOptimization( + expression: Expression, + inputRow: InternalRow = EmptyRow): Unit = { + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) + val optimizedPlan = DefaultOptimizer.execute(plan) + checkNaNWithoutCodegen(optimizedPlan.expressions.head, 
inputRow) + } + test("e") { testLeaf(EulerNumber, math.E) } @@ -126,7 +198,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("asin") { testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) - testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNull = true) + testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) } test("sinh") { @@ -139,7 +211,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("acos") { testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) - testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNull = true) + testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) } test("cosh") { @@ -204,18 +276,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("log") { - testUnary(Log, math.log, (0 to 20).map(_ * 0.1)) - testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNull = true) + testUnary(Log, math.log, (1 to 20).map(_ * 0.1)) + testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true) } test("log10") { - testUnary(Log10, math.log10, (0 to 20).map(_ * 0.1)) - testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNull = true) + testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1)) + testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true) } test("log1p") { - testUnary(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) - testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) + testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1)) + testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) } test("bin") { @@ -237,22 +309,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("log2") { def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) - testUnary(Log2, f, (0 to 20).map(_ * 0.1)) - testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true) + testUnary(Log2, f, (1 to 20).map(_ * 0.1)) + testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) } test("sqrt") { testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1)) - testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true) + testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNaN = true) checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) - checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow) - checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow) + checkNaN(Sqrt(Literal(-1.0)), EmptyRow) + checkNaN(Sqrt(Literal(-1.5)), EmptyRow) } test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) - testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) + testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) } test("shift left") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 8eb3fec756b4c..a51523f1a7a0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -68,12 +68,7 @@ class MathExpressionsSuite extends QueryTest { if (f(-1) === math.log1p(-1)) { checkAnswer( nnDoubleData.select(c('b)), - (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) - ) - } else { - checkAnswer( - nnDoubleData.select(c('b)), - (1 to 10).map(n => Row(null)) + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null) ) } diff --git 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 4ada64bc21966..6b8f2f6217a54 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -254,7 +254,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Spark SQL use Long for TimestampType, lose the precision under 1us "timestamp_1", "timestamp_2", - "timestamp_udf" + "timestamp_udf", + + // Unlike Hive, we do support log base in (0, 1.0], therefore disable this + "udf7" ) /** @@ -816,19 +819,18 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf2", "udf5", "udf6", - // "udf7", turn this on after we figure out null vs nan vs infinity "udf8", "udf9", "udf_10_trims", "udf_E", "udf_PI", "udf_abs", - // "udf_acos", turn this on after we figure out null vs nan vs infinity + "udf_acos", "udf_add", "udf_array", "udf_array_contains", "udf_ascii", - // "udf_asin", turn this on after we figure out null vs nan vs infinity + "udf_asin", "udf_atan", "udf_avg", "udf_bigint", @@ -915,7 +917,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_regexp_replace", "udf_repeat", "udf_rlike", - // "udf_round", turn this on after we figure out null vs nan vs infinity + "udf_round", "udf_round_3", "udf_rpad", "udf_rtrim", From 34a889db857f8752a0a78dcedec75ac6cd6cd48d Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 17 Jul 2015 18:30:04 -0700 Subject: [PATCH 30/58] [SPARK-7879] [MLLIB] KMeans API for spark.ml Pipelines I Implemented the KMeans API for spark.ml Pipelines. But it doesn't include clustering abstractions for spark.ml (SPARK-7610). It would fit for another issues. And I'll try it later, since we are trying to add the hierarchical clustering algorithms in another issue. Thanks. [SPARK-7879] KMeans API for spark.ml Pipelines - ASF JIRA https://issues.apache.org/jira/browse/SPARK-7879 Author: Yu ISHIKAWA Closes #6756 from yu-iskw/SPARK-7879 and squashes the following commits: be752de [Yu ISHIKAWA] Add assertions a14939b [Yu ISHIKAWA] Fix the dashed line's length in pyspark.ml.rst 4c61693 [Yu ISHIKAWA] Remove the test about whether "features" and "prediction" columns exist or not in Python fb2417c [Yu ISHIKAWA] Use getInt, instead of get f397be4 [Yu ISHIKAWA] Switch the comparisons. ca78b7d [Yu ISHIKAWA] Add the Scala docs about the constraints of each parameter. 
effc650 [Yu ISHIKAWA] Using expertSetParam and expertGetParam c8dc6e6 [Yu ISHIKAWA] Remove an unnecessary test 19a9d63 [Yu ISHIKAWA] Include spark.ml.clustering to python tests 1abb19c [Yu ISHIKAWA] Add the statements about spark.ml.clustering into pyspark.ml.rst f8338bc [Yu ISHIKAWA] Add the placeholders in Python 4a03003 [Yu ISHIKAWA] Test for contains in Python 6566c8b [Yu ISHIKAWA] Use `get`, instead of `apply` 288e8d5 [Yu ISHIKAWA] Using `contains` to check the column names 5a7d574 [Yu ISHIKAWA] Renamce `validateInitializationMode` to `validateInitMode` and remove throwing exception 97cfae3 [Yu ISHIKAWA] Fix the type of return value of `KMeans.copy` e933723 [Yu ISHIKAWA] Remove the default value of seed from the Model class 978ee2c [Yu ISHIKAWA] Modify the docs of KMeans, according to mllib's KMeans 2ec80bc [Yu ISHIKAWA] Fit on 1 line e186be1 [Yu ISHIKAWA] Make a few variables, setters and getters be expert ones b2c205c [Yu ISHIKAWA] Rename the method `getInitializationSteps` to `getInitSteps` and `setInitializationSteps` to `setInitSteps` in Scala and Python f43f5b4 [Yu ISHIKAWA] Rename the method `getInitializationMode` to `getInitMode` and `setInitializationMode` to `setInitMode` in Scala and Python 3cb5ba4 [Yu ISHIKAWA] Modify the description about epsilon and the validation 4fa409b [Yu ISHIKAWA] Add a comment about the default value of epsilon 2f392e1 [Yu ISHIKAWA] Make some variables `final` and Use `IntParam` and `DoubleParam` 19326f8 [Yu ISHIKAWA] Use `udf`, instead of callUDF 4d2ad1e [Yu ISHIKAWA] Modify the indentations 0ae422f [Yu ISHIKAWA] Add a test for `setParams` 4ff7913 [Yu ISHIKAWA] Add "ml.clustering" to `javacOptions` in SparkBuild.scala 11ffdf1 [Yu ISHIKAWA] Use `===` and the variable 220a176 [Yu ISHIKAWA] Set a random seed in the unit testing 92c3efc [Yu ISHIKAWA] Make the points for a test be fewer c758692 [Yu ISHIKAWA] Modify the parameters of KMeans in Python 6aca147 [Yu ISHIKAWA] Add some unit testings to validate the setter methods 687cacc [Yu ISHIKAWA] Alias mllib.KMeans as MLlibKMeans in KMeansSuite.scala a4dfbef [Yu ISHIKAWA] Modify the last brace and indentations 5bedc51 [Yu ISHIKAWA] Remve an extra new line 444c289 [Yu ISHIKAWA] Add the validation for `runs` e41989c [Yu ISHIKAWA] Modify how to validate `initStep` 7ea133a [Yu ISHIKAWA] Change how to validate `initMode` 7991e15 [Yu ISHIKAWA] Add a validation for `k` c2df35d [Yu ISHIKAWA] Make `predict` private 93aa2ff [Yu ISHIKAWA] Use `withColumn` in `transform` d3a79f7 [Yu ISHIKAWA] Remove the inhefited docs e9532e1 [Yu ISHIKAWA] make `parentModel` of KMeansModel private 8559772 [Yu ISHIKAWA] Remove the `paramMap` parameter of KMeans 6684850 [Yu ISHIKAWA] Rename `initializationSteps` to `initSteps` 99b1b96 [Yu ISHIKAWA] Rename `initializationMode` to `initMode` 79ea82b [Yu ISHIKAWA] Modify the parameters of KMeans docs 6569bcd [Yu ISHIKAWA] Change how to set the default values with `setDefault` 20a795a [Yu ISHIKAWA] Change how to set the default values with `setDefault` 11c2a12 [Yu ISHIKAWA] Limit the imports badb481 [Yu ISHIKAWA] Alias spark.mllib.{KMeans, KMeansModel} f80319a [Yu ISHIKAWA] Rebase mater branch and add copy methods 85d92b1 [Yu ISHIKAWA] Add `KMeans.setPredictionCol` aa9469d [Yu ISHIKAWA] Fix a python test suite error caused by python 3.x c2d6bcb [Yu ISHIKAWA] ADD Java test suites of the KMeans API for spark.ml Pipeline 598ed2e [Yu ISHIKAWA] Implement the KMeans API for spark.ml Pipelines in Python 63ad785 [Yu ISHIKAWA] Implement the KMeans API for spark.ml Pipelines in Scala --- 
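For readers who have not used the new API yet, the sketch below illustrates roughly how the estimator added by this patch plugs into spark.ml. It is only an illustration built from the params and KMeansModel shown in the diff below: the toy input DataFrame is made up, and the setter names setK, setMaxIter and setSeed are assumed to follow the usual spark.ml convention (the commit log above only names setPredictionCol, setInitMode and setInitSteps explicitly).

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.ml.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("KMeansSketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Toy input with the default "features" column expected by KMeansParams.
    val dataset = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
      Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)
    )).map(Tuple1.apply).toDF("features")

    val kmeans = new KMeans()
      .setK(2)          // must be > 1, enforced by the k param defined in the patch below
      .setMaxIter(20)
      .setSeed(1L)

    val model = kmeans.fit(dataset)            // KMeansModel wrapping the spark.mllib model
    val predicted = model.transform(dataset)   // appends an IntegerType "prediction" column
    model.clusterCenters.foreach(println)

Because transform only appends the prediction column, the fitted model can sit in a Pipeline next to ordinary feature transformers, which is the point of exposing KMeans as an Estimator here.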
dev/sparktestsupport/modules.py | 1 + .../apache/spark/ml/clustering/KMeans.scala | 205 +++++++++++++++++ .../spark/mllib/clustering/KMeans.scala | 12 +- .../spark/ml/clustering/JavaKMeansSuite.java | 72 ++++++ .../spark/ml/clustering/KMeansSuite.scala | 114 ++++++++++ project/SparkBuild.scala | 4 +- python/docs/pyspark.ml.rst | 8 + python/pyspark/ml/clustering.py | 206 ++++++++++++++++++ 8 files changed, 617 insertions(+), 5 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala create mode 100644 python/pyspark/ml/clustering.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 993583e2f4119..3073d489bad4a 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -338,6 +338,7 @@ def contains_file(self, filename): python_test_goals=[ "pyspark.ml.feature", "pyspark.ml.classification", + "pyspark.ml.clustering", "pyspark.ml.recommendation", "pyspark.ml.regression", "pyspark.ml.tuning", diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala new file mode 100644 index 0000000000000..dc192add6ca13 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.util.Utils + + +/** + * Common params for KMeans and KMeansModel + */ +private[clustering] trait KMeansParams + extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { + + /** + * Set the number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + + /** @group getParam */ + def getK: Int = $(k) + + /** + * Param the number of runs of the algorithm to execute in parallel. 
We initialize the algorithm + * this many times with random starting conditions (configured by the initialization mode), then + * return the best clustering found over any run. Must be >= 1. Default: 1. + * @group param + */ + final val runs = new IntParam(this, "runs", + "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1) + + /** @group getParam */ + def getRuns: Int = $(runs) + + /** + * Param the distance threshold within which we've consider centers to have converged. + * If all centers move less than this Euclidean distance, we stop iterating one run. + * Must be >= 0.0. Default: 1e-4 + * @group param + */ + final val epsilon = new DoubleParam(this, "epsilon", + "distance threshold within which we've consider centers to have converge", + (value: Double) => value >= 0.0) + + /** @group getParam */ + def getEpsilon: Double = $(epsilon) + + /** + * Param for the initialization algorithm. This can be either "random" to choose random points as + * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ + * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. + * @group expertParam + */ + final val initMode = new Param[String](this, "initMode", "initialization algorithm", + (value: String) => MLlibKMeans.validateInitMode(value)) + + /** @group expertGetParam */ + def getInitMode: String = $(initMode) + + /** + * Param for the number of steps for the k-means|| initialization mode. This is an advanced + * setting -- the default of 5 is almost always enough. Must be > 0. Default: 5. + * @group expertParam + */ + final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||", + (value: Int) => value > 0) + + /** @group expertGetParam */ + def getInitSteps: Int = $(initSteps) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + } +} + +/** + * :: Experimental :: + * Model fitted by KMeans. + * + * @param parentModel a model trained by spark.mllib.clustering.KMeans. + */ +@Experimental +class KMeansModel private[ml] ( + override val uid: String, + private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + + override def copy(extra: ParamMap): KMeansModel = { + val copied = new KMeansModel(uid, parentModel) + copyValues(copied, extra) + } + + override def transform(dataset: DataFrame): DataFrame = { + val predictUDF = udf((vector: Vector) => predict(vector)) + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + + def clusterCenters: Array[Vector] = parentModel.clusterCenters +} + +/** + * :: Experimental :: + * K-means clustering with support for multiple parallel runs and a k-means++ like initialization + * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, + * they are executed together with joint passes over the data for efficiency. 
+ */ +@Experimental +class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams { + + setDefault( + k -> 2, + maxIter -> 20, + runs -> 1, + initMode -> MLlibKMeans.K_MEANS_PARALLEL, + initSteps -> 5, + epsilon -> 1e-4) + + override def copy(extra: ParamMap): KMeans = defaultCopy(extra) + + def this() = this(Identifiable.randomUID("kmeans")) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setK(value: Int): this.type = set(k, value) + + /** @group expertSetParam */ + def setInitMode(value: String): this.type = set(initMode, value) + + /** @group expertSetParam */ + def setInitSteps(value: Int): this.type = set(initSteps, value) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + def setRuns(value: Int): this.type = set(runs, value) + + /** @group setParam */ + def setEpsilon(value: Double): this.type = set(epsilon, value) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + override def fit(dataset: DataFrame): KMeansModel = { + val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + + val algo = new MLlibKMeans() + .setK($(k)) + .setInitializationMode($(initMode)) + .setInitializationSteps($(initSteps)) + .setMaxIterations($(maxIter)) + .setSeed($(seed)) + .setEpsilon($(epsilon)) + .setRuns($(runs)) + val parentModel = algo.run(rdd) + val model = new KMeansModel(uid, parentModel) + copyValues(model) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 68297130a7b03..0a65403f4ec95 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -85,9 +85,7 @@ class KMeans private ( * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. */ def setInitializationMode(initializationMode: String): this.type = { - if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) { - throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode) - } + KMeans.validateInitMode(initializationMode) this.initializationMode = initializationMode this } @@ -550,6 +548,14 @@ object KMeans { v2: VectorWithNorm): Double = { MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) } + + private[spark] def validateInitMode(initMode: String): Boolean = { + initMode match { + case KMeans.RANDOM => true + case KMeans.K_MEANS_PARALLEL => true + case _ => false + } + } } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java new file mode 100644 index 0000000000000..d09fa7fd5637c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaKMeansSuite implements Serializable { + + private transient int k = 5; + private transient JavaSparkContext sc; + private transient DataFrame dataset; + private transient SQLContext sql; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaKMeansSuite"); + sql = new SQLContext(sc); + + dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void fitAndTransform() { + KMeans kmeans = new KMeans().setK(k).setSeed(1); + KMeansModel model = kmeans.fit(dataset); + + Vector[] centers = model.clusterCenters(); + assertEquals(k, centers.length); + + DataFrame transformed = model.transform(dataset); + List columns = Arrays.asList(transformed.columns()); + List expectedColumns = Arrays.asList("features", "prediction"); + for (String column: expectedColumns) { + assertTrue(columns.contains(column)); + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala new file mode 100644 index 0000000000000..1f15ac02f4008 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, SQLContext} + +private[clustering] case class TestRow(features: Vector) + +object KMeansSuite { + def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { + val sc = sql.sparkContext + val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) + .map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } +} + +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { + + final val k = 5 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + } + + test("default parameters") { + val kmeans = new KMeans() + + assert(kmeans.getK === 2) + assert(kmeans.getFeaturesCol === "features") + assert(kmeans.getPredictionCol === "prediction") + assert(kmeans.getMaxIter === 20) + assert(kmeans.getRuns === 1) + assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) + assert(kmeans.getInitSteps === 5) + assert(kmeans.getEpsilon === 1e-4) + } + + test("set parameters") { + val kmeans = new KMeans() + .setK(9) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setMaxIter(33) + .setRuns(7) + .setInitMode(MLlibKMeans.RANDOM) + .setInitSteps(3) + .setSeed(123) + .setEpsilon(1e-3) + + assert(kmeans.getK === 9) + assert(kmeans.getFeaturesCol === "test_feature") + assert(kmeans.getPredictionCol === "test_prediction") + assert(kmeans.getMaxIter === 33) + assert(kmeans.getRuns === 7) + assert(kmeans.getInitMode === MLlibKMeans.RANDOM) + assert(kmeans.getInitSteps === 3) + assert(kmeans.getSeed === 123) + assert(kmeans.getEpsilon === 1e-3) + } + + test("parameters validation") { + intercept[IllegalArgumentException] { + new KMeans().setK(1) + } + intercept[IllegalArgumentException] { + new KMeans().setInitMode("no_such_a_mode") + } + intercept[IllegalArgumentException] { + new KMeans().setInitSteps(0) + } + intercept[IllegalArgumentException] { + new KMeans().setRuns(0) + } + } + + test("fit & transform") { + val predictionColName = "kmeans_prediction" + val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) + val model = kmeans.fit(dataset) + assert(model.clusterCenters.length === k) + + val transformed = model.transform(dataset) + val expectedColumns = Array("features", predictionColName) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + } +} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4291b0be2a616..12828547d7077 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -481,8 +481,8 @@ object Unidoc { "mllib.tree.impurity", "mllib.tree.model", "mllib.util", "mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation", "mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss", - "ml", "ml.attribute", "ml.classification", "ml.evaluation", "ml.feature", "ml.param", - "ml.recommendation", "ml.regression", "ml.tuning" + "ml", "ml.attribute", "ml.classification", "ml.clustering", "ml.evaluation", "ml.feature", + 
"ml.param", "ml.recommendation", "ml.regression", "ml.tuning" ), "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 518b8e774dd5f..86d4186a2c798 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -33,6 +33,14 @@ pyspark.ml.classification module :undoc-members: :inherited-members: +pyspark.ml.clustering module +---------------------------- + +.. automodule:: pyspark.ml.clustering + :members: + :undoc-members: + :inherited-members: + pyspark.ml.recommendation module -------------------------------- diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py new file mode 100644 index 0000000000000..b5e9b6549d9f1 --- /dev/null +++ b/python/pyspark/ml/clustering.py @@ -0,0 +1,206 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import * +from pyspark.mllib.common import inherit_doc +from pyspark.mllib.linalg import _convert_to_vector + +__all__ = ['KMeans', 'KMeansModel'] + + +class KMeansModel(JavaModel): + """ + Model fitted by KMeans. + """ + + def clusterCenters(self): + """Get the cluster centers, represented as a list of NumPy arrays.""" + return [c.toArray() for c in self._call_java("clusterCenters")] + + +@inherit_doc +class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): + """ + K-means Clustering + + >>> from pyspark.mllib.linalg import Vectors + >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), + ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] + >>> df = sqlContext.createDataFrame(data, ["features"]) + >>> kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol("features") + >>> model = kmeans.fit(df) + >>> centers = model.clusterCenters() + >>> len(centers) + 2 + >>> transformed = model.transform(df).select("features", "prediction") + >>> rows = transformed.collect() + >>> rows[0].prediction == rows[1].prediction + True + >>> rows[2].prediction == rows[3].prediction + True + """ + + # a placeholder to make it appear in the generated doc + k = Param(Params._dummy(), "k", "number of clusters to create") + epsilon = Param(Params._dummy(), "epsilon", + "distance threshold within which " + + "we've consider centers to have converged") + runs = Param(Params._dummy(), "runs", "number of runs of the algorithm to execute in parallel") + initMode = Param(Params._dummy(), "initMode", + "the initialization algorithm. 
This can be either \"random\" to " + + "choose random points as initial cluster centers, or \"k-means||\" " + + "to use a parallel variant of k-means++") + initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode") + + @keyword_only + def __init__(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initStep=5): + super(KMeans, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) + self.k = Param(self, "k", "number of clusters to create") + self.epsilon = Param(self, "epsilon", + "distance threshold within which " + + "we've consider centers to have converged") + self.runs = Param(self, "runs", "number of runs of the algorithm to execute in parallel") + self.seed = Param(self, "seed", "random seed") + self.initMode = Param(self, "initMode", + "the initialization algorithm. This can be either \"random\" to " + + "choose random points as initial cluster centers, or \"k-means||\" " + + "to use a parallel variant of k-means++") + self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode") + self._setDefault(k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + def _create_model(self, java_model): + return KMeansModel(java_model) + + @keyword_only + def setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + """ + setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + + Sets params for KMeans. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setK(self, value): + """ + Sets the value of :py:attr:`k`. + + >>> algo = KMeans().setK(10) + >>> algo.getK() + 10 + """ + self._paramMap[self.k] = value + return self + + def getK(self): + """ + Gets the value of `k` + """ + return self.getOrDefault(self.k) + + def setEpsilon(self, value): + """ + Sets the value of :py:attr:`epsilon`. + + >>> algo = KMeans().setEpsilon(1e-5) + >>> abs(algo.getEpsilon() - 1e-5) < 1e-5 + True + """ + self._paramMap[self.epsilon] = value + return self + + def getEpsilon(self): + """ + Gets the value of `epsilon` + """ + return self.getOrDefault(self.epsilon) + + def setRuns(self, value): + """ + Sets the value of :py:attr:`runs`. + + >>> algo = KMeans().setRuns(10) + >>> algo.getRuns() + 10 + """ + self._paramMap[self.runs] = value + return self + + def getRuns(self): + """ + Gets the value of `runs` + """ + return self.getOrDefault(self.runs) + + def setInitMode(self, value): + """ + Sets the value of :py:attr:`initMode`. + + >>> algo = KMeans() + >>> algo.getInitMode() + 'k-means||' + >>> algo = algo.setInitMode("random") + >>> algo.getInitMode() + 'random' + """ + self._paramMap[self.initMode] = value + return self + + def getInitMode(self): + """ + Gets the value of `initMode` + """ + return self.getOrDefault(self.initMode) + + def setInitSteps(self, value): + """ + Sets the value of :py:attr:`initSteps`. 
+ + >>> algo = KMeans().setInitSteps(10) + >>> algo.getInitSteps() + 10 + """ + self._paramMap[self.initSteps] = value + return self + + def getInitSteps(self): + """ + Gets the value of `initSteps` + """ + return self.getOrDefault(self.initSteps) + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.clustering tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) From 1017908205b7690dc0b0ed4753b36fab5641f7ac Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Fri, 17 Jul 2015 20:02:05 -0700 Subject: [PATCH 31/58] [SPARK-9118] [ML] Implement IntArrayParam in mllib Implement IntArrayParam in mllib Author: Rekha Joshi Author: Joshi Closes #7481 from rekhajoshm/SPARK-9118 and squashes the following commits: d3b1766 [Joshi] Implement IntArrayParam 0be142d [Rekha Joshi] Merge pull request #3 from apache/master 106fd8e [Rekha Joshi] Merge pull request #2 from apache/master e3677c9 [Rekha Joshi] Merge pull request #1 from apache/master --- .../scala/org/apache/spark/ml/param/params.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d034d7ec6b60e..824efa5ed4b28 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -295,6 +295,22 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array w(value.asScala.map(_.asInstanceOf[Double]).toArray) } +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Array[Int]]]] for Java. + */ +@DeveloperApi +class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[Int] => Boolean) + extends Param[Array[Int]](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ + def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] = + w(value.asScala.map(_.asInstanceOf[Int]).toArray) +} + /** * :: Experimental :: * A param and its value. From b9ef7ac98c3dee3256c4a393e563b42b4612a4bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Kozikowski?= Date: Sat, 18 Jul 2015 10:12:48 -0700 Subject: [PATCH 32/58] [MLLIB] [DOC] Seed fix in mllib naive bayes example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previous seed resulted in empty test data set. 
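Roughly, the example parses this file into labeled points and then splits it into training and test sets with a fixed seed; with only six rows, the seeded split can put every row into the training set. A sketch of the failure mode (the parsing, split ratios, and seed below are illustrative assumptions, not the guide's exact code):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Assumes an already-running SparkContext `sc`.
// Each line of sample_naive_bayes_data.txt has the form "label,f1 f2 f3".
val parsed = sc.textFile("data/mllib/sample_naive_bayes_data.txt").map { line =>
  val Array(label, features) = line.split(',')
  LabeledPoint(label.toDouble, Vectors.dense(features.split(' ').map(_.toDouble)))
}

// With only 6 rows, a seeded 60/40 split can leave `test` empty;
// the extra rows added by this patch make that much less likely.
val Array(training, test) = parsed.randomSplit(Array(0.6, 0.4), seed = 11L)
```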
Author: Paweł Kozikowski Closes #7477 from mupakoz/patch-1 and squashes the following commits: f5d41ee [Paweł Kozikowski] Mllib Naive Bayes example data set enlarged --- data/mllib/sample_naive_bayes_data.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/data/mllib/sample_naive_bayes_data.txt b/data/mllib/sample_naive_bayes_data.txt index 981da382d6ac8..bd22bea3a59d6 100644 --- a/data/mllib/sample_naive_bayes_data.txt +++ b/data/mllib/sample_naive_bayes_data.txt @@ -1,6 +1,12 @@ 0,1 0 0 0,2 0 0 +0,3 0 0 +0,4 0 0 1,0 1 0 1,0 2 0 +1,0 3 0 +1,0 4 0 2,0 0 1 2,0 0 2 +2,0 0 3 +2,0 0 4 \ No newline at end of file From fba3f5ba85673336c0556ef8731dcbcd175c7418 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 18 Jul 2015 11:06:46 -0700 Subject: [PATCH 33/58] [SPARK-9169][SQL] Improve unit test coverage for null expressions. Author: Reynold Xin Closes #7490 from rxin/unit-test-null-funcs and squashes the following commits: 7b276f0 [Reynold Xin] Move isNaN. 8307287 [Reynold Xin] [SPARK-9169][SQL] Improve unit test coverage for null expressions. --- .../catalyst/expressions/nullFunctions.scala | 81 +++++++++++++++++-- .../sql/catalyst/expressions/predicates.scala | 51 ------------ .../expressions/NullFunctionsSuite.scala | 78 +++++++++--------- .../catalyst/expressions/PredicateSuite.scala | 12 +-- 4 files changed, 119 insertions(+), 103 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 1522bcae08d17..98c67084642e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -21,8 +21,19 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ + +/** + * An expression that is evaluated to the first non-null input. + * + * {{{ + * coalesce(1, 2) => 1 + * coalesce(null, 1, 2) => 1 + * coalesce(null, null, 2) => 2 + * coalesce(null, null, null) => null + * }}} + */ case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. 
*/ @@ -70,6 +81,62 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } + +/** + * Evaluates to `true` if it's NaN or null + */ +case class IsNaN(child: Expression) extends UnaryExpression + with Predicate with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType)) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + true + } else { + child.dataType match { + case DoubleType => value.asInstanceOf[Double].isNaN + case FloatType => value.asInstanceOf[Float].isNaN + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + child.dataType match { + case FloatType => + s""" + ${eval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (${eval.isNull}) { + ${ev.primitive} = true; + } else { + ${ev.primitive} = Float.isNaN(${eval.primitive}); + } + """ + case DoubleType => + s""" + ${eval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (${eval.isNull}) { + ${ev.primitive} = true; + } else { + ${ev.primitive} = Double.isNaN(${eval.primitive}); + } + """ + } + } +} + + +/** + * An expression that is evaluated to true if the input is null. + */ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false @@ -83,13 +150,14 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { ev.primitive = eval.isNull eval.code } - - override def toString: String = s"IS NULL $child" } + +/** + * An expression that is evaluated to true if the input is not null. + */ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false - override def toString: String = s"IS NOT NULL $child" override def eval(input: InternalRow): Any = { child.eval(input) != null @@ -103,12 +171,13 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } } + /** - * A predicate that is evaluated to be true if there are at least `n` non-null values. + * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. 
*/ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false - override def foldable: Boolean = false + override def foldable: Boolean = children.forall(_.foldable) override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2751c8e75f357..bddd2a9eccfc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -120,56 +119,6 @@ case class InSet(child: Expression, hset: Set[Any]) } } -/** - * Evaluates to `true` if it's NaN or null - */ -case class IsNaN(child: Expression) extends UnaryExpression - with Predicate with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType)) - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - val value = child.eval(input) - if (value == null) { - true - } else { - child.dataType match { - case DoubleType => value.asInstanceOf[Double].isNaN - case FloatType => value.asInstanceOf[Float].isNaN - } - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - child.dataType match { - case FloatType => - s""" - ${eval.code} - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (${eval.isNull}) { - ${ev.primitive} = true; - } else { - ${ev.primitive} = Float.isNaN(${eval.primitive}); - } - """ - case DoubleType => - s""" - ${eval.code} - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (${eval.isNull}) { - ${ev.primitive} = true; - } else { - ${ev.primitive} = Double.isNaN(${eval.primitive}); - } - """ - } - } -} case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index ccdada8b56f83..765cc7a969b5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -18,48 +18,52 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{BooleanType, StringType, ShortType} +import org.apache.spark.sql.types._ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("null checking") { - val row = create_row("^Ba*n", null, true, null) - val c1 = 'a.string.at(0) - val c2 = 'a.string.at(1) - val c3 = 
'a.boolean.at(2) - val c4 = 'a.boolean.at(3) - - checkEvaluation(c1.isNull, false, row) - checkEvaluation(c1.isNotNull, true, row) - - checkEvaluation(c2.isNull, true, row) - checkEvaluation(c2.isNotNull, false, row) - - checkEvaluation(Literal.create(1, ShortType).isNull, false) - checkEvaluation(Literal.create(1, ShortType).isNotNull, true) - - checkEvaluation(Literal.create(null, ShortType).isNull, true) - checkEvaluation(Literal.create(null, ShortType).isNotNull, false) + def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = { + testFunc(false, BooleanType) + testFunc(1.toByte, ByteType) + testFunc(1.toShort, ShortType) + testFunc(1, IntegerType) + testFunc(1L, LongType) + testFunc(1.0F, FloatType) + testFunc(1.0, DoubleType) + testFunc(Decimal(1.5), DecimalType.Unlimited) + testFunc(new java.sql.Date(10), DateType) + testFunc(new java.sql.Timestamp(10), TimestampType) + testFunc("abcd", StringType) + } - checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) + test("isnull and isnotnull") { + testAllTypes { (value: Any, tpe: DataType) => + checkEvaluation(IsNull(Literal.create(value, tpe)), false) + checkEvaluation(IsNotNull(Literal.create(value, tpe)), true) + checkEvaluation(IsNull(Literal.create(null, tpe)), true) + checkEvaluation(IsNotNull(Literal.create(null, tpe)), false) + } + } - checkEvaluation( - If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) - checkEvaluation(If(c3, c1, c2), "^Ba*n", row) - checkEvaluation(If(c4, c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), - Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) + test("IsNaN") { + checkEvaluation(IsNaN(Literal(Double.NaN)), true) + checkEvaluation(IsNaN(Literal(Float.NaN)), true) + checkEvaluation(IsNaN(Literal(math.log(-3))), true) + checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true) + checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false) + checkEvaluation(IsNaN(Literal(Float.MaxValue)), false) + checkEvaluation(IsNaN(Literal(5.5f)), false) + } - checkEvaluation(c1 in (c1, c2), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) + test("coalesce") { + testAllTypes { (value: Any, tpe: DataType) => + val lit = Literal.create(value, tpe) + val nullLit = Literal.create(null, tpe) + checkEvaluation(Coalesce(Seq(nullLit)), null) + checkEvaluation(Coalesce(Seq(lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 052abc51af5fd..2173a0c25c645 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -114,16 +114,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), true) - } - test("IsNaN") { - checkEvaluation(IsNaN(Literal(Double.NaN)), true) - checkEvaluation(IsNaN(Literal(Float.NaN)), true) - checkEvaluation(IsNaN(Literal(math.log(-3))), true) - checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true) - checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false) - checkEvaluation(IsNaN(Literal(Float.MaxValue)), false) - checkEvaluation(IsNaN(Literal(5.5f)), false) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) } test("INSET") { From b8aec6cd236f09881cad0fff9a6f1a5692934e21 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Jul 2015 11:08:18 -0700 Subject: [PATCH 34/58] [SPARK-9143] [SQL] Add planner rule for automatically inserting Unsafe <-> Safe row format converters Now that we have two different internal row formats, UnsafeRow and the old Java-object-based row format, we end up having to perform conversions between these two formats. These conversions should not be performed by the operators themselves; instead, the planner should be responsible for inserting appropriate format conversions when they are needed. This patch makes the following changes: - Add two new physical operators for performing row format conversions, `ConvertToUnsafe` and `ConvertFromUnsafe`. - Add new methods to `SparkPlan` to allow operators to express whether they output UnsafeRows and whether they can handle safe or unsafe rows as inputs. - Implement an `EnsureRowFormats` rule to automatically insert converter operators where necessary. Author: Josh Rosen Closes #7482 from JoshRosen/unsafe-converter-planning and squashes the following commits: 7450fa5 [Josh Rosen] Resolve conflicts in favor of choosing UnsafeRow 5220cce [Josh Rosen] Add roundtrip converter test 2bb8da8 [Josh Rosen] Add Union unsafe support + tests to bump up test coverage 6f79449 [Josh Rosen] Add even more assertions to execute() 08ce199 [Josh Rosen] Rename ConvertFromUnsafe -> ConvertToSafe 0e2d548 [Josh Rosen] Add assertion if operators' input rows are in different formats cabb703 [Josh Rosen] Add tests for Filter 3b11ce3 [Josh Rosen] Add missing test file. ae2195a [Josh Rosen] Fixes 0fef0f8 [Josh Rosen] Rename file. 
d5f9005 [Josh Rosen] Finish writing EnsureRowFormats planner rule b5df19b [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-converter-planning 9ba3038 [Josh Rosen] WIP --- .../org/apache/spark/sql/SQLContext.scala | 9 +- .../spark/sql/execution/SparkPlan.scala | 24 ++++ .../spark/sql/execution/basicOperators.scala | 11 ++ .../sql/execution/rowFormatConverters.scala | 107 ++++++++++++++++++ .../execution/RowFormatConvertersSuite.scala | 91 +++++++++++++++ 5 files changed, 239 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 46bd60daa1f78..2dda3ad1211fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -921,12 +921,15 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) /** - * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed. + * Prepares a planned SparkPlan for execution by inserting shuffle operations and internal + * row format conversions as needed. */ @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { - val batches = - Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil + val batches = Seq( + Batch("Add exchange", Once, EnsureRequirements(self)), + Batch("Add row converters", Once, EnsureRowFormats) + ) } protected[sql] def openSession(): SQLSession = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ba12056ee7a1b..f363e9947d5f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -79,12 +79,36 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Product /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + /** Specifies whether this operator outputs UnsafeRows */ + def outputsUnsafeRows: Boolean = false + + /** Specifies whether this operator is capable of processing UnsafeRows */ + def canProcessUnsafeRows: Boolean = false + + /** + * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows + * that are not UnsafeRows). + */ + def canProcessSafeRows: Boolean = true + /** * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute * after adding query plan information to created RDDs for visualization. * Concrete implementations of SparkPlan should override doExecute instead. 
*/ final def execute(): RDD[InternalRow] = { + if (children.nonEmpty) { + val hasUnsafeInputs = children.exists(_.outputsUnsafeRows) + val hasSafeInputs = children.exists(!_.outputsUnsafeRows) + assert(!(hasSafeInputs && hasUnsafeInputs), + "Child operators should output rows in the same format") + assert(canProcessSafeRows || canProcessUnsafeRows, + "Operator must be able to process at least one row format") + assert(!hasSafeInputs || canProcessSafeRows, + "Operator will receive safe rows as input but cannot process safe rows") + assert(!hasUnsafeInputs || canProcessUnsafeRows, + "Operator will receive unsafe rows as input but cannot process unsafe rows") + } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { doExecute() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 4c063c299ba53..82bef269b069f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -64,6 +64,12 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { } override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + + override def canProcessUnsafeRows: Boolean = true + + override def canProcessSafeRows: Boolean = true } /** @@ -104,6 +110,9 @@ case class Sample( case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output: Seq[Attribute] = children.head.output + override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) } @@ -306,6 +315,8 @@ case class UnsafeExternalSort( override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = sortOrder + + override def outputsUnsafeRows: Boolean = true } @DeveloperApi diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala new file mode 100644 index 0000000000000..421d510e6782d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * :: DeveloperApi :: + * Converts Java-object-based rows into [[UnsafeRow]]s. + */ +@DeveloperApi +case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = false + override def canProcessSafeRows: Boolean = true + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val convertToUnsafe = UnsafeProjection.create(child.schema) + iter.map(convertToUnsafe) + } + } +} + +/** + * :: DeveloperApi :: + * Converts [[UnsafeRow]]s back into Java-object-based rows. + */ +@DeveloperApi +case class ConvertToSafe(child: SparkPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType)) + iter.map(convertToSafe) + } + } +} + +private[sql] object EnsureRowFormats extends Rule[SparkPlan] { + + private def onlyHandlesSafeRows(operator: SparkPlan): Boolean = + operator.canProcessSafeRows && !operator.canProcessUnsafeRows + + private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean = + operator.canProcessUnsafeRows && !operator.canProcessSafeRows + + private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean = + operator.canProcessSafeRows && operator.canProcessUnsafeRows + + override def apply(operator: SparkPlan): SparkPlan = operator.transformUp { + case operator: SparkPlan if onlyHandlesSafeRows(operator) => + if (operator.children.exists(_.outputsUnsafeRows)) { + operator.withNewChildren { + operator.children.map { + c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c + } + } + } else { + operator + } + case operator: SparkPlan if onlyHandlesUnsafeRows(operator) => + if (operator.children.exists(!_.outputsUnsafeRows)) { + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + } + } + } else { + operator + } + case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => + if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { + // If this operator's children produce both unsafe and safe rows, then convert everything + // to unsafe rows + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + } + } + } else { + operator + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala new file mode 100644 index 0000000000000..7b75f755918c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.IsNull +import org.apache.spark.sql.test.TestSQLContext + +class RowFormatConvertersSuite extends SparkPlanTest { + + private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { + case c: ConvertToUnsafe => c + case c: ConvertToSafe => c + } + + private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + assert(!outputsSafe.outputsUnsafeRows) + private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + assert(outputsUnsafe.outputsUnsafeRows) + + test("planner should insert unsafe->safe conversions when required") { + val plan = Limit(10, outputsUnsafe) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) + } + + test("filter can process unsafe rows") { + val plan = Filter(IsNull(null), outputsUnsafe) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).isEmpty) + assert(preparedPlan.outputsUnsafeRows) + } + + test("filter can process safe rows") { + val plan = Filter(IsNull(null), outputsSafe) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).isEmpty) + assert(!preparedPlan.outputsUnsafeRows) + } + + test("execute() fails an assertion if inputs rows are of different formats") { + val e = intercept[AssertionError] { + Union(Seq(outputsSafe, outputsUnsafe)).execute() + } + assert(e.getMessage.contains("format")) + } + + test("union requires all of its input rows' formats to agree") { + val plan = Union(Seq(outputsSafe, outputsUnsafe)) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("union can process safe rows") { + val plan = Union(Seq(outputsSafe, outputsSafe)) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(!preparedPlan.outputsUnsafeRows) + } + + test("union can process unsafe rows") { + val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("round trip with ConvertToUnsafe and ConvertToSafe") { + val input = Seq(("hello", 1), ("world", 2)) + checkAnswer( + TestSQLContext.createDataFrame(input), + plan => ConvertToSafe(ConvertToUnsafe(plan)), + input.map(Row.fromTuple) + ) + } +} From 1b4ff05538fbcfe10ca4fa97606bd6e39a8450cb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 18 Jul 2015 11:13:49 -0700 Subject: [PATCH 35/58] [SPARK-9142][SQL] remove more self type in catalyst a follow up of https://github.com/apache/spark/pull/7479. 
The `TreeNode` is the root case of the requirement of `self: Product =>` stuff, so why not make `TreeNode` extend `Product`? Author: Wenchen Fan Closes #7495 from cloud-fan/self-type and squashes the following commits: 8676af7 [Wenchen Fan] remove more self type --- .../apache/spark/sql/catalyst/expressions/Expression.scala | 2 +- .../org/apache/spark/sql/catalyst/expressions/math.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 4 ++-- .../main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c70b5af4aa448..0e128d8bdcd96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.types._ * * See [[Substring]] for an example. */ -abstract class Expression extends TreeNode[Expression] with Product { +abstract class Expression extends TreeNode[Expression] { /** * Returns true when an expression is a candidate for static evaluation before the query is diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 9101f11052218..eb5c065a34123 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -77,7 +77,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) } abstract class UnaryLogExpression(f: Double => Double, name: String) - extends UnaryMathExpression(f, name) { self: Product => + extends UnaryMathExpression(f, name) { // values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity protected val yAsymptote: Double = 0.0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b89e3382f06a9..d06a7a2add754 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { - self: PlanType with Product => + self: PlanType => def output: Seq[Attribute] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index dd6c5d43f5714..bedeaf06adf12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode -abstract class LogicalPlan extends 
QueryPlan[LogicalPlan] with Logging with Product{ +abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** * Computes [[Statistics]] for this plan. The default implementation assumes the output diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 0f95ca688a7a8..122e9fc5ed77f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -54,8 +54,8 @@ object CurrentOrigin { } } -abstract class TreeNode[BaseType <: TreeNode[BaseType]] { - self: BaseType with Product => +abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { + self: BaseType => val origin: Origin = CurrentOrigin.get diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index f363e9947d5f6..b0d56b7bf0b86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -39,7 +39,7 @@ object SparkPlan { * :: DeveloperApi :: */ @DeveloperApi -abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Product with Serializable { +abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { /** * A handle to the SQL Context that was used to create this plan. Since many operators need From 692378c01d949dfe2b2a884add153cd5f8054b5a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 18 Jul 2015 11:25:16 -0700 Subject: [PATCH 36/58] [SPARK-9167][SQL] use UTC Calendar in `stringToDate` Fix 2 bugs introduced in https://github.com/apache/spark/pull/7353: 1. We should use a UTC Calendar when casting a string to a date. Before #7353, we used `DateTimeUtils.fromJavaDate(Date.valueOf(s.toString))` to cast a string to a date, and `fromJavaDate` calls `millisToDays` to avoid the time zone issue. Now that we use `DateTimeUtils.stringToDate(s)`, we should create the Calendar with UTC from the beginning. 2. We should not change the default time zone in test cases. The `threadLocalLocalTimeZone` and `threadLocalTimestampFormat` in `DateTimeUtils` are only evaluated once per thread, so we can't set the default time zone back anymore.
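To make the first point concrete, here is a standalone Scala sketch of the off-by-one-day effect (illustrative only and not part of the patch; the object and method names are invented):

    import java.util.{Calendar, TimeZone}

    object UtcCalendarSketch {
      private val MillisPerDay = 1000L * 3600L * 24L

      // Day number computed with the JVM's default time zone. For a zone east of
      // UTC, local midnight falls before UTC midnight, so the integer division
      // truncates to one day too few.
      def daysSinceEpochLocal(year: Int, month: Int, day: Int): Long = {
        val c = Calendar.getInstance()
        c.set(year, month - 1, day, 0, 0, 0)
        c.set(Calendar.MILLISECOND, 0)
        c.getTimeInMillis / MillisPerDay
      }

      // Day number computed with an explicit UTC ("GMT") calendar, which is what
      // stringToDate does after this patch.
      def daysSinceEpochUtc(year: Int, month: Int, day: Int): Long = {
        val c = Calendar.getInstance(TimeZone.getTimeZone("GMT"))
        c.set(year, month - 1, day, 0, 0, 0)
        c.set(Calendar.MILLISECOND, 0)
        c.getTimeInMillis / MillisPerDay
      }

      def main(args: Array[String]): Unit = {
        TimeZone.setDefault(TimeZone.getTimeZone("Asia/Shanghai")) // UTC+8, no DST
        println(daysSinceEpochLocal(2015, 1, 28)) // 16462 (one day too few)
        println(daysSinceEpochUtc(2015, 1, 28))   // 16463 (the correct day number)
      }
    }

The test changes below follow the same idea: expected values now go through millisToDays, which applies the local-zone offset explicitly instead of relying on the calendar's zone.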
Author: Wenchen Fan Closes #7488 from cloud-fan/datetime and squashes the following commits: 9cd6005 [Wenchen Fan] address comments 21ef293 [Wenchen Fan] fix 2 bugs in datetime --- .../sql/catalyst/util/DateTimeUtils.scala | 9 +++++---- .../sql/catalyst/expressions/CastSuite.scala | 3 --- .../catalyst/util/DateTimeUtilsSuite.scala | 19 ++++++++++--------- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f33e34b380bcf..45e45aef1a349 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -65,8 +65,8 @@ object DateTimeUtils { def millisToDays(millisUtc: Long): Int = { // SPARK-6785: use Math.floor so negative number of days (dates before 1970) // will correctly work as input for function toJavaDate(Int) - val millisLocal = millisUtc.toDouble + threadLocalLocalTimeZone.get().getOffset(millisUtc) - Math.floor(millisLocal / MILLIS_PER_DAY).toInt + val millisLocal = millisUtc + threadLocalLocalTimeZone.get().getOffset(millisUtc) + Math.floor(millisLocal.toDouble / MILLIS_PER_DAY).toInt } // reverse of millisToDays @@ -375,8 +375,9 @@ object DateTimeUtils { segments(2) < 1 || segments(2) > 31) { return None } - val c = Calendar.getInstance() + val c = Calendar.getInstance(TimeZone.getTimeZone("GMT")) c.set(segments(0), segments(1) - 1, segments(2), 0, 0, 0) - Some((c.getTimeInMillis / 1000 / 3600 / 24).toInt) + c.set(Calendar.MILLISECOND, 0) + Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index ef8bcd41f7280..ccf448eee0688 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -281,8 +281,6 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val nts = sts + ".1" val ts = Timestamp.valueOf(nts) - val defaultTimeZone = TimeZone.getDefault - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) var c = Calendar.getInstance() c.set(2015, 2, 8, 2, 30, 0) checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType), @@ -291,7 +289,6 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(2015, 10, 1, 2, 30, 0) checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType), c.getTimeInMillis * 1000) - TimeZone.setDefault(defaultTimeZone) checkEvaluation(cast("abdef", StringType), "abdef") checkEvaluation(cast("abdef", DecimalType.Unlimited), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 5c3a621c6d11f..04c5f09792ac3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -90,34 +90,35 @@ class DateTimeUtilsSuite extends SparkFunSuite { } test("string to date") { - val millisPerDay = 1000L * 3600L * 24L + import DateTimeUtils.millisToDays + var c = Calendar.getInstance() 
c.set(2015, 0, 28, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-01-28")).get === - c.getTimeInMillis / millisPerDay) + millisToDays(c.getTimeInMillis)) c.set(2015, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015")).get === - c.getTimeInMillis / millisPerDay) + millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03")).get === - c.getTimeInMillis / millisPerDay) + millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18")).get === - c.getTimeInMillis / millisPerDay) + millisToDays(c.getTimeInMillis)) assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 ")).get === - c.getTimeInMillis / millisPerDay) + millisToDays(c.getTimeInMillis)) assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 123142")).get === - c.getTimeInMillis / millisPerDay) + millisToDays(c.getTimeInMillis)) assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T123123")).get === - c.getTimeInMillis / millisPerDay) + millisToDays(c.getTimeInMillis)) assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T")).get === - c.getTimeInMillis / millisPerDay) + millisToDays(c.getTimeInMillis)) assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) From 86c50bf72c41d95107a55c16a6853dcda7f3e143 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 18 Jul 2015 11:58:53 -0700 Subject: [PATCH 37/58] [SPARK-9171][SQL] add and improve tests for nondeterministic expressions Author: Wenchen Fan Closes #7496 from cloud-fan/tests and squashes the following commits: 0958f90 [Wenchen Fan] improve test for nondeterministic expressions --- .../scala/org/apache/spark/TaskContext.scala | 2 +- .../expressions/ExpressionEvalHelper.scala | 108 ++++++++++-------- .../expressions/MathFunctionsSuite.scala | 18 +-- .../catalyst/expressions/RandomSuite.scala | 6 +- .../spark/sql/ColumnExpressionSuite.scala | 9 +- .../expression/NondeterministicSuite.scala | 32 ++++++ 6 files changed, 103 insertions(+), 72 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 345bb500a7dec..e93eb93124e51 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -38,7 +38,7 @@ object TaskContext { */ def getPartitionId(): Int = { val tc = taskContext.get() - if (tc == null) { + if (tc eq null) { 0 } else { tc.partitionId() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index c43486b3ddcf5..7a96044d35a09 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -23,7 +23,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} @@ -38,7 +38,7 @@ trait ExpressionEvalHelper { } protected def checkEvaluation( - expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) @@ -51,12 +51,14 @@ trait ExpressionEvalHelper { /** * Check the equality between result of expression and expected value, it will handle - * Array[Byte]. + * Array[Byte] and Spread[Double]. */ protected def checkResult(result: Any, expected: Any): Boolean = { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) + case (result: Double, expected: Spread[Double]) => + expected.isWithin(result) case _ => result == expected } } @@ -65,10 +67,29 @@ trait ExpressionEvalHelper { expression.eval(inputRow) } + protected def generateProject( + generator: => Projection, + expression: Expression): Projection = { + try { + generator + } catch { + case e: Throwable => + val ctx = new CodeGenContext + val evaluated = expression.gen(ctx) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + } + protected def checkEvaluationWithoutCodegen( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } @@ -85,21 +106,11 @@ trait ExpressionEvalHelper { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val plan = try { - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() - } catch { - case e: Throwable => - val ctx = GenerateProjection.newCodeGenContext() - val evaluated = expression.gen(ctx) - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } + val plan = generateProject( + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), + expression) - val actual = plan(inputRow).apply(0) + val actual = plan(inputRow).get(0) if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") @@ -110,24 +121,19 @@ trait ExpressionEvalHelper { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val ctx = GenerateProjection.newCodeGenContext() - lazy val evaluated = expression.gen(ctx) - val plan = try { - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) - } catch { - case e: Throwable => - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } + val plan = generateProject( + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + expression) val actual = 
plan(inputRow) val expectedRow = InternalRow(expected) + + // We reimplement hashCode in generated `SpecificRow`, make sure it's consistent with our + // interpreted version. if (actual.hashCode() != expectedRow.hashCode()) { + val ctx = new CodeGenContext + val evaluated = expression.gen(ctx) fail( s""" |Mismatched hashCodes for values: $actual, $expectedRow @@ -136,9 +142,10 @@ trait ExpressionEvalHelper { |Code: $evaluated """.stripMargin) } + if (actual != expectedRow) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") } if (actual.copy() != expectedRow) { fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") @@ -149,20 +156,10 @@ trait ExpressionEvalHelper { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val ctx = GenerateUnsafeProjection.newCodeGenContext() - lazy val evaluated = expression.gen(ctx) - val plan = try { - GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) - } catch { - case e: Throwable => - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } + val plan = generateProject( + GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + expression) val unsafeRow = plan(inputRow) // UnsafeRow cannot be compared with GenericInternalRow directly @@ -170,7 +167,7 @@ trait ExpressionEvalHelper { val expectedRow = InternalRow(expected) if (actual != expectedRow) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") } } @@ -184,12 +181,23 @@ trait ExpressionEvalHelper { } protected def checkDoubleEvaluation( - expression: Expression, + expression: => Expression, expected: Spread[Double], inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - actual.asInstanceOf[Double] shouldBe expected + checkEvaluationWithoutCodegen(expression, expected) + checkEvaluationWithGeneratedMutableProjection(expression, expected) + checkEvaluationWithOptimization(expression, expected) + + var plan = generateProject( + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + expression) + var actual = plan(inputRow).get(0) + assert(checkResult(actual, expected)) + + plan = generateProject( + GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + expression) + actual = FromUnsafeProjection(expression.dataType :: Nil)(plan(inputRow)).get(0) + assert(checkResult(actual, expected)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index df988f57fbfde..04acd5b5ff4d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -143,7 +143,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { case e: Exception => 
fail(s"Exception evaluating $expression", e) } if (!actual.asInstanceOf[Double].isNaN) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation (codegen off): $expression, " + s"actual: $actual, " + s"expected: NaN") @@ -155,23 +154,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { - val plan = try { - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() - } catch { - case e: Throwable => - val ctx = GenerateProjection.newCodeGenContext() - val evaluated = expression.gen(ctx) - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } + val plan = generateProject( + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), + expression) val actual = plan(inputRow).apply(0) if (!actual.asInstanceOf[Double].isNaN) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 9be2b23a53f27..698c81ba24482 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -21,13 +21,13 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{DoubleType, IntegerType} +import org.apache.spark.sql.types.DoubleType class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { test("random") { - val row = create_row(1.1, 2.0, 3.1, null) - checkDoubleEvaluation(Rand(30), (0.7363714192755834 +- 0.001), row) + checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- 0.001) + checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- 0.001) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 8f15479308391..6bd5804196853 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -450,7 +450,7 @@ class ColumnExpressionSuite extends QueryTest { test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. - val df = ctx.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -460,10 +460,13 @@ class ColumnExpressionSuite extends QueryTest { } test("sparkPartitionId") { - val df = ctx.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + // Make sure we have 2 partitions, each with 2 records. 
+ val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + Iterator(Tuple1(1), Tuple1(2)) + }.toDF("a") checkAnswer( df.select(sparkPartitionId()), - Row(0) + Row(0) :: Row(0) :: Row(1) :: Row(1) :: Nil ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala new file mode 100644 index 0000000000000..99e11fd64b2b9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions. ExpressionEvalHelper +import org.apache.spark.sql.execution.expressions.{SparkPartitionID, MonotonicallyIncreasingID} + +class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { + test("MonotonicallyIncreasingID") { + checkEvaluation(MonotonicallyIncreasingID(), 0) + } + + test("SparkPartitionID") { + checkEvaluation(SparkPartitionID, 0) + } +} From 225de8da2b20ba03b358e222411610e8567aa88d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 18 Jul 2015 12:11:37 -0700 Subject: [PATCH 38/58] [SPARK-9151][SQL] Implement code generation for Abs JIRA: https://issues.apache.org/jira/browse/SPARK-9151 Add codegen support for `Abs`. Author: Liang-Chi Hsieh Closes #7498 from viirya/abs_codegen and squashes the following commits: 0c8410f [Liang-Chi Hsieh] Implement code generation for Abs. 
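For reference, the two code-generation branches added here boil down to the following behavior (an illustrative Scala sketch of the emitted logic, not the actual Java source produced by defineCodeGen):

    import org.apache.spark.sql.types.Decimal

    object AbsCodegenSketch {
      // NumericType branch: the generated code calls java.lang.Math.abs and casts
      // the result back to the column's primitive Java type.
      def absInt(x: Int): Int = java.lang.Math.abs(x)

      // DecimalType branch: the generated code calls the Decimal.abs method that
      // this patch adds to Decimal.
      def absDecimal(d: Decimal): Decimal = d.abs
    }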
--- .../apache/spark/sql/catalyst/expressions/arithmetic.scala | 7 +++++++ .../main/scala/org/apache/spark/sql/types/Decimal.scala | 2 ++ 2 files changed, 9 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index c5960eb390ea4..e83650fc8cb0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -73,6 +73,13 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes private lazy val numeric = TypeUtils.getNumeric(dataType) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, c => s"$c.abs()") + case dt: NumericType => + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))") + } + protected override def nullSafeEval(input: Any): Any = numeric.abs(input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index a85af9e04aedb..bc689810bc292 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -278,6 +278,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { Decimal(-longVal, precision, scale) } } + + def abs: Decimal = if (this.compare(Decimal(0)) < 0) this.unary_- else this } object Decimal { From cdc36eef4160dbae32e19a1eadbb4cf062f2fb2b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 18 Jul 2015 12:25:04 -0700 Subject: [PATCH 39/58] Closes #6122 From 3d2134fc0d90379b89da08de7614aef1ac674b1b Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 18 Jul 2015 12:57:53 -0700 Subject: [PATCH 40/58] [SPARK-9055][SQL] WidenTypes should also support Intersect and Except JIRA: https://issues.apache.org/jira/browse/SPARK-9055 cc rxin Author: Yijie Shen Closes #7491 from yijieshen/widen and squashes the following commits: 079fa52 [Yijie Shen] widenType support for intersect and expect --- .../catalyst/analysis/HiveTypeCoercion.scala | 93 +++++++++++-------- .../plans/logical/basicOperators.scala | 8 ++ .../analysis/HiveTypeCoercionSuite.scala | 34 ++++++- 3 files changed, 94 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 50db7d21f01ca..ff20835e82ba7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -168,52 +168,65 @@ object HiveTypeCoercion { * - LongType to DoubleType */ object WidenTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // TODO: unions with fixed-precision decimals - case u @ Union(left, 
right) if u.childrenResolved && !u.resolved => - val castedInput = left.output.zip(right.output).map { - // When a string is found on one side, make the other side a string too. - case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => - (lhs, Alias(Cast(rhs, StringType), rhs.name)()) - case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => - (Alias(Cast(lhs, StringType), lhs.name)(), rhs) - case (lhs, rhs) if lhs.dataType != rhs.dataType => - logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}") - findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => - val newLeft = - if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() - val newRight = - if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() - - (newLeft, newRight) - }.getOrElse { - // If there is no applicable conversion, leave expression unchanged. - (lhs, rhs) - } + private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan): + (LogicalPlan, LogicalPlan) = { + + // TODO: with fixed-precision decimals + val castedInput = left.output.zip(right.output).map { + // When a string is found on one side, make the other side a string too. + case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => + (lhs, Alias(Cast(rhs, StringType), rhs.name)()) + case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => + (Alias(Cast(lhs, StringType), lhs.name)(), rhs) + + case (lhs, rhs) if lhs.dataType != rhs.dataType => + logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}") + findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => + val newLeft = + if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() + val newRight = + if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() + + (newLeft, newRight) + }.getOrElse { + // If there is no applicable conversion, leave expression unchanged. 
+ (lhs, rhs) + } - case other => other - } + case other => other + } - val (castedLeft, castedRight) = castedInput.unzip + val (castedLeft, castedRight) = castedInput.unzip - val newLeft = - if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logDebug(s"Widening numeric types in union $castedLeft ${left.output}") - Project(castedLeft, left) - } else { - left - } + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}") + Project(castedLeft, left) + } else { + left + } - val newRight = - if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logDebug(s"Widening numeric types in union $castedRight ${right.output}") - Project(castedRight, right) - } else { - right - } + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + logDebug(s"Widening numeric types in $planName $castedRight ${right.output}") + Project(castedRight, right) + } else { + right + } + (newLeft, newRight) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right) Union(newLeft, newRight) + case e @ Except(left, right) if e.childrenResolved && !e.resolved => + val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right) + Except(newLeft, newRight) + case i @ Intersect(left, right) if i.childrenResolved && !i.resolved => + val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right) + Intersect(newLeft, newRight) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 17a91247327f7..986c315b3173a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -141,6 +141,10 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output + + override lazy val resolved: Boolean = + childrenResolved && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } case class InsertIntoTable( @@ -437,4 +441,8 @@ case object OneRowRelation extends LeafNode { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output + + override lazy val resolved: Boolean = + childrenResolved && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index d0fd033b981c8..c9b3c69c6de89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project} +import 
org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -305,6 +305,38 @@ class HiveTypeCoercionSuite extends PlanTest { ) } + test("WidenTypes for union except and intersect") { + def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { + logical.output.zip(expectTypes).foreach { case (attr, dt) => + assert(attr.dataType === dt) + } + } + + val left = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("b", ByteType)(), + AttributeReference("d", DoubleType)()) + val right = LocalRelation( + AttributeReference("s", StringType)(), + AttributeReference("d", DecimalType(2, 1))(), + AttributeReference("f", FloatType)(), + AttributeReference("l", LongType)()) + + val wt = HiveTypeCoercion.WidenTypes + val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType) + + val r1 = wt(Union(left, right)).asInstanceOf[Union] + val r2 = wt(Except(left, right)).asInstanceOf[Except] + val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect] + checkOutput(r1.left, expectedTypes) + checkOutput(r1.right, expectedTypes) + checkOutput(r2.left, expectedTypes) + checkOutput(r2.right, expectedTypes) + checkOutput(r3.left, expectedTypes) + checkOutput(r3.right, expectedTypes) + } + /** * There are rules that need to not fire before child expressions get resolved. * We use this test to make sure those rules do not fire early. From 6e1e2eba696e89ba57bf5450b9c72c4386e43dc8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 18 Jul 2015 14:07:56 -0700 Subject: [PATCH 41/58] [SPARK-8240][SQL] string function: concat Author: Reynold Xin Closes #7486 from rxin/concat and squashes the following commits: 5217d6e [Reynold Xin] Removed Hive's concat test. f5cb7a3 [Reynold Xin] Concat is never nullable. ae4e61f [Reynold Xin] Removed extra import. fddcbbd [Reynold Xin] Fixed NPE. 22e831c [Reynold Xin] Added missing file. 
57a2352 [Reynold Xin] [SPARK-8240][SQL] string function: concat --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/stringOperations.scala | 37 +++ ...ite.scala => StringExpressionsSuite.scala} | 24 +- .../org/apache/spark/sql/functions.scala | 22 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 242 --------------- .../spark/sql/StringFunctionsSuite.scala | 284 ++++++++++++++++++ .../execution/HiveCompatibilitySuite.scala | 4 +- .../apache/spark/unsafe/types/UTF8String.java | 40 ++- .../spark/unsafe/types/UTF8StringSuite.java | 14 + 9 files changed, 421 insertions(+), 247 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/{StringFunctionsSuite.scala => StringExpressionsSuite.scala} (96%) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ce552a1d65eda..d1cda6bc27095 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -152,6 +152,7 @@ object FunctionRegistry { // string functions expression[Ascii]("ascii"), expression[Base64]("base64"), + expression[Concat]("concat"), expression[Encode]("encode"), expression[Decode]("decode"), expression[FormatNumber]("format_number"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index c64afe7b3f19a..b36354eff092a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -27,6 +27,43 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines expressions for string operations. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +/** + * An expression that concatenates multiple input strings into a single string. + * Input expressions that are evaluated to nulls are skipped. + * + * For example, `concat("a", null, "b")` is evaluated to `"ab"`. + * + * Note that this is different from Hive since Hive outputs null if any input is null. + * We never output null. + */ +case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = StringType + + override def nullable: Boolean = false + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = { + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val evals = children.map(_.gen(ctx)) + val inputs = evals.map { eval => s"${eval.isNull} ? 
null : ${eval.primitive}" }.mkString(", ") + evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.primitive} = UTF8String.concat($inputs); + """ + } +} + trait StringRegexExpression extends ImplicitCastInputTypes { self: BinaryExpression => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala similarity index 96% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 5d7763bedf6bd..0ed567a90dd1f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -22,7 +22,29 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ -class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { +class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("concat") { + def testConcat(inputs: String*): Unit = { + val expected = inputs.filter(_ != null).mkString + checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow) + } + + testConcat() + testConcat(null) + testConcat("") + testConcat("ab") + testConcat("a", "b") + testConcat("a", "b", "C") + testConcat("a", null, "C") + testConcat("a", null, null) + testConcat(null, null, null) + + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + testConcat("数据", null, "砖头") + // scalastyle:on + } test("StringComparison") { val row = create_row("abc", null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b56fd9a71b321..c180407389136 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1710,6 +1710,28 @@ object functions { // String functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Concatenates input strings together into a single string. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat(exprs: Column*): Column = Concat(exprs.map(_.expr)) + + /** + * Concatenates input strings together into a single string. + * + * This is the variant of concat that takes in the column names. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat(columnName: String, columnNames: String*): Column = { + concat((columnName +: columnNames).map(Column.apply): _*) + } + /** * Computes the length of a given string / binary value. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6dccdd857b453..29f1197a8543c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -208,169 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest { Row(2743272264L, 2180413220L)) } - test("Levenshtein distance") { - val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") - checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) - checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) - } - - test("string ascii function") { - val df = Seq(("abc", "")).toDF("a", "b") - checkAnswer( - df.select(ascii($"a"), ascii("b")), - Row(97, 0)) - - checkAnswer( - df.selectExpr("ascii(a)", "ascii(b)"), - Row(97, 0)) - } - - test("string base64/unbase64 function") { - val bytes = Array[Byte](1, 2, 3, 4) - val df = Seq((bytes, "AQIDBA==")).toDF("a", "b") - checkAnswer( - df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")), - Row("AQIDBA==", "AQIDBA==", bytes, bytes)) - - checkAnswer( - df.selectExpr("base64(a)", "unbase64(b)"), - Row("AQIDBA==", bytes)) - } - - test("string encode/decode function") { - val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116) - // scalastyle:off - // non ascii characters are not allowed in the code, so we disable the scalastyle here. - val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") - checkAnswer( - df.select( - encode($"a", "utf-8"), - encode("a", "utf-8"), - decode($"c", "utf-8"), - decode("c", "utf-8")), - Row(bytes, bytes, "大千世界", "大千世界")) - - checkAnswer( - df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"), - Row(bytes, "大千世界")) - // scalastyle:on - } - - test("string trim functions") { - val df = Seq((" example ", "")).toDF("a", "b") - - checkAnswer( - df.select(ltrim($"a"), rtrim($"a"), trim($"a")), - Row("example ", " example", "example")) - - checkAnswer( - df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), - Row("example ", " example", "example")) - } - - test("string formatString function") { - val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") - - checkAnswer( - df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), - Row("aa123cc", "aa123cc")) - - checkAnswer( - df.selectExpr("printf(a, b, c)"), - Row("aa123cc")) - } - - test("string instr function") { - val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") - - checkAnswer( - df.select(instr($"a", $"b"), instr("a", "b")), - Row(1, 1)) - - checkAnswer( - df.selectExpr("instr(a, b)"), - Row(1)) - } - - test("string locate function") { - val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") - - checkAnswer( - df.select( - locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1), - locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")), - Row(1, 1, 2, 2, 2, 2)) - - checkAnswer( - df.selectExpr("locate(b, a)", "locate(b, a, d)"), - Row(1, 2)) - } - - test("string padding functions") { - val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") - - checkAnswer( - df.select( - lpad($"a", $"b", $"c"), rpad("a", "b", "c"), - lpad($"a", 1, $"c"), rpad("a", 1, "c")), - Row("???hi", "hi???", "h", "h")) - - checkAnswer( - df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), - Row("???hi", "hi???", "h", "h")) - } - - test("string repeat function") { - val 
df = Seq(("hi", 2)).toDF("a", "b") - - checkAnswer( - df.select( - repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")), - Row("hihi", "hihi", "hihi", "hihi")) - - checkAnswer( - df.selectExpr("repeat(a, 2)", "repeat(a, b)"), - Row("hihi", "hihi")) - } - - test("string reverse function") { - val df = Seq(("hi", "hhhi")).toDF("a", "b") - - checkAnswer( - df.select(reverse($"a"), reverse("b")), - Row("ih", "ihhh")) - - checkAnswer( - df.selectExpr("reverse(b)"), - Row("ihhh")) - } - - test("string space function") { - val df = Seq((2, 3)).toDF("a", "b") - - checkAnswer( - df.select(space($"a"), space("b")), - Row(" ", " ")) - - checkAnswer( - df.selectExpr("space(b)"), - Row(" ")) - } - - test("string split function") { - val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") - - checkAnswer( - df.select( - split($"a", "[1-9]+"), - split("a", "[1-9]+")), - Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc"))) - - checkAnswer( - df.selectExpr("split(a, '[1-9]+')"), - Row(Seq("aa", "bb", "cc"))) - } - test("conditional function: least") { checkAnswer( testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1), @@ -430,83 +267,4 @@ class DataFrameFunctionsSuite extends QueryTest { ) } - test("string / binary length function") { - val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") - checkAnswer( - df.select(length($"a"), length("a"), length($"b"), length("b")), - Row(3, 3, 4, 4)) - - checkAnswer( - df.selectExpr("length(a)", "length(b)"), - Row(3, 4)) - - intercept[AnalysisException] { - checkAnswer( - df.selectExpr("length(c)"), // int type of the argument is unacceptable - Row("5.0000")) - } - } - - test("number format function") { - val tuple = - ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], - 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) - val df = - Seq(tuple) - .toDF( - "a", // string "aa" - "b", // byte 1 - "c", // short 2 - "d", // float 3.13223f - "e", // integer 4 - "f", // long 5L - "g", // double 6.48173d - "h") // decimal 7.128381 - - checkAnswer( - df.select( - format_number($"f", 4), - format_number("f", 4)), - Row("5.0000", "5.0000")) - - checkAnswer( - df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer - Row("1.0000")) - - checkAnswer( - df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer - Row("2.0000")) - - checkAnswer( - df.selectExpr("format_number(d, e)"), // convert the 1st argument to double - Row("3.1322")) - - checkAnswer( - df.selectExpr("format_number(e, e)"), // not convert anything - Row("4.0000")) - - checkAnswer( - df.selectExpr("format_number(f, e)"), // not convert anything - Row("5.0000")) - - checkAnswer( - df.selectExpr("format_number(g, e)"), // not convert anything - Row("6.4817")) - - checkAnswer( - df.selectExpr("format_number(h, e)"), // not convert anything - Row("7.1284")) - - intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable - Row("5.0000")) - } - - intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable - Row("5.0000")) - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala new file mode 100644 index 0000000000000..4eff33ed45042 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.Decimal + + +class StringFunctionsSuite extends QueryTest { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + test("string concat") { + val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") + + checkAnswer( + df.select(concat($"a", $"b", $"c")), + Row("ab")) + + checkAnswer( + df.selectExpr("concat(a, b, c)"), + Row("ab")) + } + + + test("string Levenshtein distance") { + val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") + checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) + checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) + } + + test("string ascii function") { + val df = Seq(("abc", "")).toDF("a", "b") + checkAnswer( + df.select(ascii($"a"), ascii("b")), + Row(97, 0)) + + checkAnswer( + df.selectExpr("ascii(a)", "ascii(b)"), + Row(97, 0)) + } + + test("string base64/unbase64 function") { + val bytes = Array[Byte](1, 2, 3, 4) + val df = Seq((bytes, "AQIDBA==")).toDF("a", "b") + checkAnswer( + df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")), + Row("AQIDBA==", "AQIDBA==", bytes, bytes)) + + checkAnswer( + df.selectExpr("base64(a)", "unbase64(b)"), + Row("AQIDBA==", bytes)) + } + + test("string encode/decode function") { + val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116) + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. 
+ val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") + checkAnswer( + df.select( + encode($"a", "utf-8"), + encode("a", "utf-8"), + decode($"c", "utf-8"), + decode("c", "utf-8")), + Row(bytes, bytes, "大千世界", "大千世界")) + + checkAnswer( + df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"), + Row(bytes, "大千世界")) + // scalastyle:on + } + + test("string trim functions") { + val df = Seq((" example ", "")).toDF("a", "b") + + checkAnswer( + df.select(ltrim($"a"), rtrim($"a"), trim($"a")), + Row("example ", " example", "example")) + + checkAnswer( + df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), + Row("example ", " example", "example")) + } + + test("string formatString function") { + val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), + Row("aa123cc", "aa123cc")) + + checkAnswer( + df.selectExpr("printf(a, b, c)"), + Row("aa123cc")) + } + + test("string instr function") { + val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") + + checkAnswer( + df.select(instr($"a", $"b"), instr("a", "b")), + Row(1, 1)) + + checkAnswer( + df.selectExpr("instr(a, b)"), + Row(1)) + } + + test("string locate function") { + val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") + + checkAnswer( + df.select( + locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1), + locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")), + Row(1, 1, 2, 2, 2, 2)) + + checkAnswer( + df.selectExpr("locate(b, a)", "locate(b, a, d)"), + Row(1, 2)) + } + + test("string padding functions") { + val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") + + checkAnswer( + df.select( + lpad($"a", $"b", $"c"), rpad("a", "b", "c"), + lpad($"a", 1, $"c"), rpad("a", 1, "c")), + Row("???hi", "hi???", "h", "h")) + + checkAnswer( + df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), + Row("???hi", "hi???", "h", "h")) + } + + test("string repeat function") { + val df = Seq(("hi", 2)).toDF("a", "b") + + checkAnswer( + df.select( + repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")), + Row("hihi", "hihi", "hihi", "hihi")) + + checkAnswer( + df.selectExpr("repeat(a, 2)", "repeat(a, b)"), + Row("hihi", "hihi")) + } + + test("string reverse function") { + val df = Seq(("hi", "hhhi")).toDF("a", "b") + + checkAnswer( + df.select(reverse($"a"), reverse("b")), + Row("ih", "ihhh")) + + checkAnswer( + df.selectExpr("reverse(b)"), + Row("ihhh")) + } + + test("string space function") { + val df = Seq((2, 3)).toDF("a", "b") + + checkAnswer( + df.select(space($"a"), space("b")), + Row(" ", " ")) + + checkAnswer( + df.selectExpr("space(b)"), + Row(" ")) + } + + test("string split function") { + val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select( + split($"a", "[1-9]+"), + split("a", "[1-9]+")), + Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc"))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+')"), + Row(Seq("aa", "bb", "cc"))) + } + + test("string / binary length function") { + val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") + checkAnswer( + df.select(length($"a"), length("a"), length($"b"), length("b")), + Row(3, 3, 4, 4)) + + checkAnswer( + df.selectExpr("length(a)", "length(b)"), + Row(3, 4)) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("length(c)"), // int type of the argument is unacceptable + Row("5.0000")) + } + } + + test("number format function") { + val tuple = + ("aa", 1.asInstanceOf[Byte], 
2.asInstanceOf[Short], + 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) + val df = + Seq(tuple) + .toDF( + "a", // string "aa" + "b", // byte 1 + "c", // short 2 + "d", // float 3.13223f + "e", // integer 4 + "f", // long 5L + "g", // double 6.48173d + "h") // decimal 7.128381 + + checkAnswer( + df.select( + format_number($"f", 4), + format_number("f", 4)), + Row("5.0000", "5.0000")) + + checkAnswer( + df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + Row("1.0000")) + + checkAnswer( + df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + Row("2.0000")) + + checkAnswer( + df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + Row("3.1322")) + + checkAnswer( + df.selectExpr("format_number(e, e)"), // not convert anything + Row("4.0000")) + + checkAnswer( + df.selectExpr("format_number(f, e)"), // not convert anything + Row("5.0000")) + + checkAnswer( + df.selectExpr("format_number(g, e)"), // not convert anything + Row("6.4817")) + + checkAnswer( + df.selectExpr("format_number(h, e)"), // not convert anything + Row("7.1284")) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable + Row("5.0000")) + } + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable + Row("5.0000")) + } + } +} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 6b8f2f6217a54..299cc599ff8f7 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -256,6 +256,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "timestamp_2", "timestamp_udf", + // Hive outputs NULL if any concat input has null. We never output null for concat. + "udf_concat", + // Unlike Hive, we do support log base in (0, 1.0], therefore disable this "udf7" ) @@ -846,7 +849,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_case", "udf_ceil", "udf_ceiling", - "udf_concat", "udf_concat_insert1", "udf_concat_insert2", "udf_concat_ws", diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e7f9fbb2bc682..9723b6e0834b2 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.io.UnsupportedEncodingException; +import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import static org.apache.spark.unsafe.PlatformDependent.*; @@ -322,7 +323,7 @@ public int indexOf(UTF8String v, int start) { } i += numBytesForFirstByte(getByte(i)); c += 1; - } while(i < numBytes); + } while (i < numBytes); return -1; } @@ -395,6 +396,39 @@ public UTF8String lpad(int len, UTF8String pad) { } } + /** + * Concatenates input strings together into a single string. A null input is skipped. + * For example, concat("a", null, "c") would yield "ac". + */ + public static UTF8String concat(UTF8String... 
inputs) { + if (inputs == null) { + return fromBytes(new byte[0]); + } + + // Compute the total length of the result. + int totalLength = 0; + for (int i = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + totalLength += inputs[i].numBytes; + } + } + + // Allocate a new byte array, and copy the inputs one by one into it. + final byte[] result = new byte[totalLength]; + int offset = 0; + for (int i = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + int len = inputs[i].numBytes; + PlatformDependent.copyMemory( + inputs[i].base, inputs[i].offset, + result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + len); + offset += len; + } + } + return fromBytes(result); + } + @Override public String toString() { try { @@ -413,7 +447,7 @@ public UTF8String clone() { } @Override - public int compareTo(final UTF8String other) { + public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); // TODO: compare 8 bytes as unsigned long for (int i = 0; i < len; i ++) { @@ -434,7 +468,7 @@ public int compare(final UTF8String other) { public boolean equals(final Object other) { if (other instanceof UTF8String) { UTF8String o = (UTF8String) other; - if (numBytes != o.numBytes){ + if (numBytes != o.numBytes) { return false; } return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 694bdc29f39d1..0db7522b50c1a 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -86,6 +86,20 @@ public void upperAndLower() { testUpperandLower("大千世界 数据砖头", "大千世界 数据砖头"); } + @Test + public void concatTest() { + assertEquals(concat(), fromString("")); + assertEquals(concat(null), fromString("")); + assertEquals(concat(fromString("")), fromString("")); + assertEquals(concat(fromString("ab")), fromString("ab")); + assertEquals(concat(fromString("a"), fromString("b")), fromString("ab")); + assertEquals(concat(fromString("a"), fromString("b"), fromString("c")), fromString("abc")); + assertEquals(concat(fromString("a"), null, fromString("c")), fromString("ac")); + assertEquals(concat(fromString("a"), null, null), fromString("a")); + assertEquals(concat(null, null, null), fromString("")); + assertEquals(concat(fromString("数据"), fromString("砖头")), fromString("数据砖头")); + } + @Test public void contains() { assertTrue(fromString("").contains(fromString(""))); From e16a19a39ed3369dffd375d712066d12add71c9e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 18 Jul 2015 15:29:38 -0700 Subject: [PATCH 42/58] [SPARK-9174][SQL] Add documentation for all public SQLConfs. Author: Reynold Xin Closes #7500 from rxin/sqlconf and squashes the following commits: a5726c8 [Reynold Xin] [SPARK-9174][SQL] Add documentation for all public SQLConfs. 
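These doc strings are what users see when they inspect the configuration at runtime. A small usage sketch (not part of the patch; it assumes the TestSQLContext used by the suites earlier in this series):

    val ctx = org.apache.spark.sql.test.TestSQLContext

    // Entries documented in SQLConf are set and read like any other conf key.
    ctx.setConf("spark.sql.shuffle.partitions", "400")
    assert(ctx.getConf("spark.sql.shuffle.partitions") == "400")

    // "SET -v" should list each public conf together with its doc string.
    ctx.sql("SET -v").collect().foreach(println)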
--- .../scala/org/apache/spark/sql/SQLConf.scala | 144 +++++++----------- 1 file changed, 53 insertions(+), 91 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 6005d35f015a9..2c2f7c35dfdce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -26,6 +26,11 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.sql.catalyst.CatalystConf +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines the configuration options for Spark SQL. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + private[spark] object SQLConf { private val sqlConfEntries = java.util.Collections.synchronizedMap( @@ -184,17 +189,20 @@ private[spark] object SQLConf { val COMPRESS_CACHED = booleanConf("spark.sql.inMemoryColumnarStorage.compressed", defaultValue = Some(true), doc = "When set to true Spark SQL will automatically select a compression codec for each " + - "column based on statistics of the data.") + "column based on statistics of the data.", + isPublic = false) val COLUMN_BATCH_SIZE = intConf("spark.sql.inMemoryColumnarStorage.batchSize", defaultValue = Some(10000), doc = "Controls the size of batches for columnar caching. Larger batch sizes can improve " + - "memory utilization and compression, but risk OOMs when caching data.") + "memory utilization and compression, but risk OOMs when caching data.", + isPublic = false) val IN_MEMORY_PARTITION_PRUNING = booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning", defaultValue = Some(false), - doc = "") + doc = "When true, enable partition pruning for in-memory columnar tables.", + isPublic = false) val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold", defaultValue = Some(10 * 1024 * 1024), @@ -203,29 +211,35 @@ private[spark] object SQLConf { "Note that currently statistics are only supported for Hive Metastore tables where the " + "command ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.") - val DEFAULT_SIZE_IN_BYTES = longConf("spark.sql.defaultSizeInBytes", isPublic = false) + val DEFAULT_SIZE_IN_BYTES = longConf( + "spark.sql.defaultSizeInBytes", + doc = "The default table size used in query planning. By default, it is set to a larger " + + "value than `spark.sql.autoBroadcastJoinThreshold` to be more conservative. That is to say " + + "by default the optimizer will not choose to broadcast a table unless it knows for sure its " + + "size is small enough.", + isPublic = false) val SHUFFLE_PARTITIONS = intConf("spark.sql.shuffle.partitions", defaultValue = Some(200), - doc = "Configures the number of partitions to use when shuffling data for joins or " + - "aggregations.") + doc = "The default number of partitions to use when shuffling data for joins or aggregations.") val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", defaultValue = Some(true), doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query. For some queries with complicated expression this option can lead to " + - "significant speed-ups.
However, for simple queries this can actually slow down query " + - "execution.") + " a specific query.") val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", defaultValue = Some(false), - doc = "") + doc = "When true, use the new optimized Tungsten physical execution backend.") - val DIALECT = stringConf("spark.sql.dialect", defaultValue = Some("sql"), doc = "") + val DIALECT = stringConf( + "spark.sql.dialect", + defaultValue = Some("sql"), + doc = "The default SQL dialect to use.") val CASE_SENSITIVE = booleanConf("spark.sql.caseSensitive", defaultValue = Some(true), - doc = "") + doc = "Whether the query analyzer should be case sensitive or not.") val PARQUET_SCHEMA_MERGING_ENABLED = booleanConf("spark.sql.parquet.mergeSchema", defaultValue = Some(true), @@ -273,9 +287,8 @@ private[spark] object SQLConf { val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf( key = "spark.sql.parquet.followParquetFormatSpec", defaultValue = Some(false), - doc = "Whether to stick to Parquet format specification when converting Parquet schema to " + - "Spark SQL schema and vice versa. Sticks to the specification if set to true; falls back " + - "to compatible mode if set to false.", + doc = "Whether to follow Parquet's format specification when converting Parquet schema to " + + "Spark SQL schema and vice versa.", isPublic = false) val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( @@ -290,7 +303,7 @@ private[spark] object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", defaultValue = Some(false), - doc = "") + doc = "When true, enable filter pushdown for ORC files.") val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", defaultValue = Some(true), @@ -302,7 +315,7 @@ private[spark] object SQLConf { val BROADCAST_TIMEOUT = intConf("spark.sql.broadcastTimeout", defaultValue = Some(5 * 60), - doc = "") + doc = "Timeout in seconds for the broadcast wait time in broadcast joins.") // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. @@ -313,7 +326,7 @@ private[spark] object SQLConf { val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", defaultValue = Some(false), - doc = "") + doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") // This is only used for the thriftserver val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool", @@ -321,16 +334,16 @@ private[spark] object SQLConf { val THRIFTSERVER_UI_STATEMENT_LIMIT = intConf("spark.sql.thriftserver.ui.retainedStatements", defaultValue = Some(200), - doc = "") + doc = "The number of SQL statements kept in the JDBC/ODBC web UI history.") val THRIFTSERVER_UI_SESSION_LIMIT = intConf("spark.sql.thriftserver.ui.retainedSessions", defaultValue = Some(200), - doc = "") + doc = "The number of SQL client sessions kept in the JDBC/ODBC web UI history.") // This is used to set the default data source val DEFAULT_DATA_SOURCE_NAME = stringConf("spark.sql.sources.default", defaultValue = Some("org.apache.spark.sql.parquet"), - doc = "") + doc = "The default data source to use in input/output.") // This is used to control the when we will split a schema's JSON string to multiple pieces // in order to fit the JSON string in metastore's table property (by default, the value has @@ -338,18 +351,20 @@ private[spark] object SQLConf { // to its length exceeds the threshold. 
val SCHEMA_STRING_LENGTH_THRESHOLD = intConf("spark.sql.sources.schemaStringLengthThreshold", defaultValue = Some(4000), - doc = "") + doc = "The maximum length allowed in a single cell when " + + "storing additional schema information in Hive's metastore.", + isPublic = false) // Whether to perform partition discovery when loading external data sources. Default to true. val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled", defaultValue = Some(true), - doc = "") + doc = "When true, automatically discover data partitions.") // Whether to perform partition column type inference. Default to true. val PARTITION_COLUMN_TYPE_INFERENCE = booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled", defaultValue = Some(true), - doc = "") + doc = "When true, automatically infer the data types for partitioned columns.") // The output committer class used by HadoopFsRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. @@ -363,22 +378,28 @@ private[spark] object SQLConf { // Whether to perform eager analysis when constructing a dataframe. // Set to false when debugging requires the ability to look at invalid query plans. - val DATAFRAME_EAGER_ANALYSIS = booleanConf("spark.sql.eagerAnalysis", + val DATAFRAME_EAGER_ANALYSIS = booleanConf( + "spark.sql.eagerAnalysis", defaultValue = Some(true), - doc = "") + doc = "When true, eagerly applies query analysis on DataFrame operations.", + isPublic = false) // Whether to automatically resolve ambiguity in join conditions for self-joins. // See SPARK-6231. - val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = - booleanConf("spark.sql.selfJoinAutoResolveAmbiguity", defaultValue = Some(true), doc = "") + val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = booleanConf( + "spark.sql.selfJoinAutoResolveAmbiguity", + defaultValue = Some(true), + isPublic = false) // Whether to retain group by columns or not in GroupedData.agg. - val DATAFRAME_RETAIN_GROUP_COLUMNS = booleanConf("spark.sql.retainGroupColumns", + val DATAFRAME_RETAIN_GROUP_COLUMNS = booleanConf( + "spark.sql.retainGroupColumns", defaultValue = Some(true), - doc = "") + isPublic = false) - val USE_SQL_SERIALIZER2 = booleanConf("spark.sql.useSerializer2", - defaultValue = Some(true), doc = "") + val USE_SQL_SERIALIZER2 = booleanConf( + "spark.sql.useSerializer2", + defaultValue = Some(true), isPublic = false) val USE_JACKSON_STREAMING_API = booleanConf("spark.sql.json.useJacksonStreamingAPI", defaultValue = Some(true), doc = "") @@ -422,112 +443,53 @@ private[sql] class SQLConf extends Serializable with CatalystConf { */ private[spark] def dialect: String = getConf(DIALECT) - /** When true tables cached using the in-memory columnar caching will be compressed. */ private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED) - /** The compression codec for writing to a Parquetfile */ private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) private[spark] def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) - /** The number of rows that will be */ private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) - /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) - /** When true predicates will be passed to the parquet record reader when possible.
*/ private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) - /** When true uses Parquet implementation based on data source API */ private[spark] def parquetUseDataSourceApi: Boolean = getConf(PARQUET_USE_DATA_SOURCE_API) private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) - /** When true uses verifyPartitionPath to prune the path which is not exists. */ private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) - /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) - /** - * Sort merge join would sort the two side of join first, and then iterate both sides together - * only once to get all matches. Using sort merge join can save a lot of memory usage compared - * to HashJoin. - */ private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - /** - * When set to true, Spark SQL will use the Janino at runtime to generate custom bytecode - * that evaluates expressions found in queries. In general this custom code runs much faster - * than interpreted evaluation, but there are some start-up costs (5-10ms) due to compilation. - */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) - /** - * caseSensitive analysis true by default - */ def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - /** - * When set to true, Spark SQL will use managed memory for certain operations. This option only - * takes effect if codegen is enabled. - * - * Defaults to false as this feature is currently experimental. - */ private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) - /** - * Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0 - */ private[spark] def useJacksonStreamingAPI: Boolean = getConf(USE_JACKSON_STREAMING_API) - /** - * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to - * a broadcast value during the physical executions of join operations. Setting this to -1 - * effectively disables auto conversion. - * - * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is 10000. - */ private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) - /** - * The default size in bytes to assign to a logical operator's estimation statistics. By default, - * it is set to a larger value than `autoBroadcastJoinThreshold`, hence any logical operator - * without a properly implemented estimation of this statistic will not be incorrectly broadcasted - * in joins. - */ private[spark] def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) - /** - * When set to true, we always treat byte arrays in Parquet files as strings. - */ private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) - /** - * When set to true, we always treat INT96Values in Parquet files as timestamp. - */ private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) - /** - * When set to true, sticks to Parquet format spec when converting Parquet schema to Spark SQL - * schema and vice versa. Otherwise, falls back to compatible mode. 
- */ private[spark] def followParquetFormatSpec: Boolean = getConf(PARQUET_FOLLOW_PARQUET_FORMAT_SPEC) - /** - * When set to true, partition pruning for in-memory columnar tables is enabled. - */ private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) private[spark] def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) - /** - * Timeout in seconds for the broadcast wait time in hash join - */ private[spark] def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) private[spark] def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) From 9914b1b2c5d5fe020f54d95f59f03023de2ea78a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 18 Jul 2015 18:18:19 -0700 Subject: [PATCH 43/58] [SPARK-9150][SQL] Create CodegenFallback and Unevaluable trait It is very hard to track which expressions have code gen implemented or not. This patch removes the default fallback gencode implementation from Expression, and moves that into a new trait called CodegenFallback. Each concrete expression needs to either implement code generation, or mix in CodegenFallback. This makes it very easy to track which expressions have code generation implemented already. Additionally, this patch creates an Unevaluable trait that can be used to track expressions that don't support evaluation (e.g. Star). Author: Reynold Xin Closes #7487 from rxin/codegenfallback and squashes the following commits: 14ebf38 [Reynold Xin] Fixed Conv 6c1c882 [Reynold Xin] Fixed Alias. b42611b [Reynold Xin] [SPARK-9150][SQL] Create a trait to track code generation for expressions. cb5c066 [Reynold Xin] Removed extra import. 39cbe40 [Reynold Xin] [SPARK-8240][SQL] string function: concat --- .../sql/catalyst/analysis/unresolved.scala | 43 ++++--------- .../spark/sql/catalyst/expressions/Cast.scala | 7 +-- .../sql/catalyst/expressions/Expression.scala | 28 +++++---- .../sql/catalyst/expressions/ScalaUDF.scala | 4 +- .../sql/catalyst/expressions/SortOrder.scala | 10 +-- .../sql/catalyst/expressions/aggregates.scala | 10 +-- .../sql/catalyst/expressions/arithmetic.scala | 5 +- .../expressions/codegen/CodegenFallback.scala | 40 ++++++++++++ .../expressions/complexTypeCreator.scala | 11 ++-- .../expressions/datetimeFunctions.scala | 5 +- .../sql/catalyst/expressions/generators.scala | 8 +-- .../sql/catalyst/expressions/literals.scala | 7 ++- .../spark/sql/catalyst/expressions/math.scala | 61 ++++++++++--------- .../expressions/namedExpressions.scala | 15 ++--- .../sql/catalyst/expressions/predicates.scala | 7 ++- .../spark/sql/catalyst/expressions/sets.scala | 12 ++-- .../expressions/stringOperations.scala | 48 +++++++++------ .../expressions/windowExpressions.scala | 41 ++++--------- .../plans/physical/partitioning.scala | 16 +---- .../analysis/AnalysisErrorSuite.scala | 4 +- .../analysis/HiveTypeCoercionSuite.scala | 8 +-- .../sql/catalyst/trees/TreeNodeSuite.scala | 3 +- .../spark/sql/execution/pythonUDFs.scala | 6 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 25 ++++---- 24 files changed, 206 insertions(+), 218 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 4a1a1ed61ebe7..0daee1990a6e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.{errors, trees} -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.errors import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode @@ -50,7 +49,7 @@ case class UnresolvedRelation( /** * Holds the name of an attribute that has yet to be resolved. */ -case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute { +case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable { def name: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") @@ -66,10 +65,6 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute { override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) - // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"'$name" } @@ -78,16 +73,14 @@ object UnresolvedAttribute { def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) } -case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { +case class UnresolvedFunction(name: String, children: Seq[Expression]) + extends Expression with Unevaluable { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"'$name(${children.mkString(",")})" } @@ -105,10 +98,6 @@ abstract class Star extends LeafExpression with NamedExpression { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override lazy val resolved = false - // Star gets expanded at runtime so we never evaluate a Star. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] } @@ -120,7 +109,7 @@ abstract class Star extends LeafExpression with NamedExpression { * @param table an optional table that should be the target of the expansion. If omitted all * tables' columns are produced. 
*/ -case class UnresolvedStar(table: Option[String]) extends Star { +case class UnresolvedStar(table: Option[String]) extends Star with Unevaluable { override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { val expandedAttributes: Seq[Attribute] = table match { @@ -149,7 +138,7 @@ case class UnresolvedStar(table: Option[String]) extends Star { * @param names the names to be associated with each output of computing [[child]]. */ case class MultiAlias(child: Expression, names: Seq[String]) - extends UnaryExpression with NamedExpression { + extends UnaryExpression with NamedExpression with CodegenFallback { override def name: String = throw new UnresolvedException(this, "name") @@ -165,9 +154,6 @@ case class MultiAlias(child: Expression, names: Seq[String]) override lazy val resolved = false - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"$child AS $names" } @@ -178,7 +164,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) * * @param expressions Expressions to expand. */ -case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { +case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable { override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") } @@ -192,23 +178,21 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { * can be key of Map, index of Array, field name of Struct. */ case class UnresolvedExtractValue(child: Expression, extraction: Expression) - extends UnaryExpression { + extends UnaryExpression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"$child[$extraction]" } /** * Holds the expression that has yet to be aliased. */ -case class UnresolvedAlias(child: Expression) extends UnaryExpression with NamedExpression { +case class UnresolvedAlias(child: Expression) + extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") @@ -218,7 +202,4 @@ case class UnresolvedAlias(child: Expression) extends UnaryExpression with Named override def name: String = throw new UnresolvedException(this, "name") override lazy val resolved = false - - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 692b9fddbb041..3346d3c9f9e61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -18,12 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} -import java.sql.{Date, Timestamp} -import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{Interval, UTF8String} @@ -106,7 +104,8 @@ object Cast { } /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { +case class Cast(child: Expression, dataType: DataType) + extends UnaryExpression with CodegenFallback { override def checkInputDataTypes(): TypeCheckResult = { if (Cast.canCast(child.dataType, dataType)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 0e128d8bdcd96..d0a1aa9a1e912 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -101,19 +101,7 @@ abstract class Expression extends TreeNode[Expression] { * @param ev an [[GeneratedExpressionCode]] with unique terms. * @return Java source code */ - protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - ctx.references += this - val objectTerm = ctx.freshName("obj") - s""" - /* expression: ${this} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); - boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm; - } - """ - } + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String /** * Returns `true` if this expression and all its children have been resolved to a specific schema @@ -182,6 +170,20 @@ abstract class Expression extends TreeNode[Expression] { } +/** + * An expression that cannot be evaluated. Some expressions don't live past analysis or optimization + * time (e.g. Star). This trait is used by those expressions. + */ +trait Unevaluable { self: Expression => + + override def eval(input: InternalRow = null): Any = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} + + /** * A leaf expression, i.e. one without any child expressions. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 22687acd68a97..11c7950c0613b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types.DataType /** @@ -29,7 +30,8 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputTypes: Seq[DataType] = Nil) extends Expression with ImplicitCastInputTypes { + inputTypes: Seq[DataType] = Nil) + extends Expression with ImplicitCastInputTypes with CodegenFallback { override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index b8f7068c9e5e5..3f436c0eb893c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types.DataType abstract sealed class SortDirection @@ -30,7 +27,8 @@ case object Descending extends SortDirection * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression { +case class SortOrder(child: Expression, direction: SortDirection) + extends UnaryExpression with Unevaluable { /** Sort order is not foldable because we don't have an eval for it. */ override def foldable: Boolean = false @@ -38,9 +36,5 @@ case class SortOrder(child: Expression, direction: SortDirection) extends UnaryE override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable - // SortOrder itself is never evaluated. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index af9a674ab4958..d705a1286065c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet -trait AggregateExpression extends Expression { + +trait AggregateExpression extends Expression with Unevaluable { /** * Aggregate expressions should not be foldable. 
@@ -38,13 +39,6 @@ trait AggregateExpression extends Expression { * of input rows/ */ def newInstance(): AggregateFunction - - /** - * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are - * replaced with a physical aggregate operator at runtime. - */ - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index e83650fc8cb0e..05b5ad88fee8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.Interval @@ -65,7 +65,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects /** * A function that get the absolute value of the numeric value. */ -case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Abs(child: Expression) + extends UnaryExpression with ExpectsInputTypes with CodegenFallback { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala new file mode 100644 index 0000000000000..bf4f600cb26e5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * A trait that can be used to provide a fallback mode for expression code generation. 
+ */ +trait CodegenFallback { self: Expression => + + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + ctx.references += this + val objectTerm = ctx.freshName("obj") + s""" + /* expression: ${this} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.isNull} = $objectTerm == null; + ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm; + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d1e4c458864f1..f9fd04c02aaef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * Returns an Array containing the evaluation of all children expressions. */ -case class CreateArray(children: Seq[Expression]) extends Expression { +case class CreateArray(children: Seq[Expression]) extends Expression with CodegenFallback { override def foldable: Boolean = children.forall(_.foldable) @@ -51,7 +52,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { * Returns a Row containing the evaluation of all children expressions. * TODO: [[CreateStruct]] does not support codegen. */ -case class CreateStruct(children: Seq[Expression]) extends Expression { +case class CreateStruct(children: Seq[Expression]) extends Expression with CodegenFallback { override def foldable: Boolean = children.forall(_.foldable) @@ -83,7 +84,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { * * @param children Seq(name1, val1, name2, val2, ...) 
*/ -case class CreateNamedStruct(children: Seq[Expression]) extends Expression { +case class CreateNamedStruct(children: Seq[Expression]) extends Expression with CodegenFallback { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip @@ -103,11 +104,11 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.") + TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) - if (invalidNames.size != 0) { + if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( s"Odd position only allow foldable and not-null StringType expressions, got :" + s" ${invalidNames.mkString(",")}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index dd5ec330a771b..4bed140cffbfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.types._ * * There is no code generation since this expression should get constant folded by the optimizer. */ -case class CurrentDate() extends LeafExpression { +case class CurrentDate() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -44,7 +45,7 @@ case class CurrentDate() extends LeafExpression { * * There is no code generation since this expression should get constant folded by the optimizer. 
*/ -case class CurrentTimestamp() extends LeafExpression { +case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index c58a6d36141c1..2dbcf2830f876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ /** @@ -73,7 +73,7 @@ case class UserDefinedGenerator( elementTypes: Seq[(DataType, Boolean)], function: Row => TraversableOnce[InternalRow], children: Seq[Expression]) - extends Generator { + extends Generator with CodegenFallback { @transient private[this] var inputRow: InterpretedProjection = _ @transient private[this] var convertToScala: (InternalRow) => Row = _ @@ -100,7 +100,7 @@ case class UserDefinedGenerator( /** * Given an input array produces a sequence of rows for each value in the array. */ -case class Explode(child: Expression) extends UnaryExpression with Generator { +case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e1fdb29541fa8..f25ac32679587 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types._ @@ -75,7 +75,8 @@ object IntegerLiteral { /** * In order to do type checking, use Literal.create() instead of constructor */ -case class Literal protected (value: Any, dataType: DataType) extends LeafExpression { +case class Literal protected (value: Any, dataType: DataType) + extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = value == null @@ -142,7 +143,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres // TODO: Specialize case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) - extends LeafExpression { + extends LeafExpression with CodegenFallback { def update(expression: Expression, input: InternalRow): Unit = { value = expression.eval(input) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index eb5c065a34123..7ce64d29ba59a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} -import java.util.Arrays import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} @@ -29,11 +28,14 @@ import org.apache.spark.unsafe.types.UTF8String /** * A leaf expression specifically for math constants. Math constants expect no input. + * + * There is no code generation because they should get constant folded by the optimizer. + * * @param c The math constant. * @param name The short name of the function */ abstract class LeafMathExpression(c: Double, name: String) - extends LeafExpression with Serializable { + extends LeafExpression with CodegenFallback { override def dataType: DataType = DoubleType override def foldable: Boolean = true @@ -41,13 +43,6 @@ abstract class LeafMathExpression(c: Double, name: String) override def toString: String = s"$name()" override def eval(input: InternalRow): Any = c - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - s""" - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = java.lang.Math.$name; - """ - } } /** @@ -130,8 +125,16 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// +/** + * Euler's number. Note that there is no code generation because this is only + * evaluated by the optimizer during constant folding. + */ case class EulerNumber() extends LeafMathExpression(math.E, "E") +/** + * Pi. Note that there is no code generation because this is only + * evaluated by the optimizer during constant folding. 
+ */ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -161,7 +164,7 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" * @param toBaseExpr to which base */ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) - extends Expression with ImplicitCastInputTypes{ + extends Expression with ImplicitCastInputTypes with CodegenFallback { override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable @@ -171,6 +174,8 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) + override def dataType: DataType = StringType + /** Returns the result of evaluating this expression on a given input Row */ override def eval(input: InternalRow): Any = { val num = numExpr.eval(input) @@ -179,17 +184,13 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre if (num == null || fromBase == null || toBase == null) { null } else { - conv(num.asInstanceOf[UTF8String].getBytes, - fromBase.asInstanceOf[Int], toBase.asInstanceOf[Int]) + conv( + num.asInstanceOf[UTF8String].getBytes, + fromBase.asInstanceOf[Int], + toBase.asInstanceOf[Int]) } } - /** - * Returns the [[DataType]] of the result of evaluating this expression. It is - * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false). - */ - override def dataType: DataType = StringType - private val value = new Array[Byte](64) /** @@ -208,7 +209,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre // Two's complement => x = uval - 2*MAX - 2 // => uval = x + 2*MAX + 2 // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c - (x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m) + x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m } } @@ -220,7 +221,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre */ private def decode(v: Long, radix: Int): Unit = { var tmpV = v - Arrays.fill(value, 0.asInstanceOf[Byte]) + java.util.Arrays.fill(value, 0.asInstanceOf[Byte]) var i = value.length - 1 while (tmpV != 0) { val q = unsignedLongDiv(tmpV, radix) @@ -254,7 +255,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre v = v * radix + value(i) i += 1 } - return v + v } /** @@ -292,16 +293,16 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv */ private def conv(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { - if (n == null || fromBase == null || toBase == null || n.isEmpty) { - return null - } - if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { return null } + if (n.length == 0) { + return null + } + var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) // Copy the digits in the right side of the array @@ -340,7 +341,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre resultStartPos = firstNonZeroPos - 1 value(resultStartPos) = '-' } - UTF8String.fromBytes( Arrays.copyOfRange(value, resultStartPos, value.length)) + 
UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length)) } } @@ -495,8 +496,8 @@ object Hex { * Otherwise if the number is a STRING, it converts each character into its hex representation * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ -case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - // TODO: Create code-gen version. +case class Hex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, BinaryType, StringType)) @@ -539,8 +540,8 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ -case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - // TODO: Create code-gen version. +case class Unhex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { override def inputTypes: Seq[AbstractDataType] = Seq(StringType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index c083ac08ded21..6f173b52ad9b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ object NamedExpression { @@ -122,7 +120,9 @@ case class Alias(child: Expression, name: String)( override def eval(input: InternalRow): Any = child.eval(input) + /** Just a simple passthrough for code generation. */ override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable @@ -177,7 +177,7 @@ case class AttributeReference( override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) - extends Attribute { + extends Attribute with Unevaluable { /** * Returns true iff the expression id is the same for both attributes. @@ -236,10 +236,6 @@ case class AttributeReference( } } - // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"$name#${exprId.id}$typeSuffix" } @@ -247,7 +243,7 @@ case class AttributeReference( * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. 
*/ -case class PrettyAttribute(name: String) extends Attribute { +case class PrettyAttribute(name: String) extends Attribute with Unevaluable { override def toString: String = name @@ -259,7 +255,6 @@ case class PrettyAttribute(name: String) extends Attribute { override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException override def dataType: DataType = NullType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index bddd2a9eccfc0..40ec3df224ce1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ + object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) = create(BindReferences.bindReference(expression, inputSchema)) @@ -91,7 +92,7 @@ case class Not(child: Expression) /** * Evaluates to `true` if `list` contains `value`. */ -case class In(value: Expression, list: Seq[Expression]) extends Predicate { +case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback { override def children: Seq[Expression] = value +: list override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. @@ -109,7 +110,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { * static. */ case class InSet(child: Expression, hset: Set[Any]) - extends UnaryExpression with Predicate { + extends UnaryExpression with Predicate with CodegenFallback { override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. 
override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 49b2026364cd6..5b0fe8dfe2fc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -52,7 +52,7 @@ private[sql] class OpenHashSetUDT( /** * Creates a new set of the specified type */ -case class NewSet(elementType: DataType) extends LeafExpression { +case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback { override def nullable: Boolean = false @@ -82,7 +82,8 @@ case class NewSet(elementType: DataType) extends LeafExpression { * Note: this expression is internal and created only by the GeneratedAggregate, * we don't need to do type check for it. */ -case class AddItemToSet(item: Expression, set: Expression) extends Expression { +case class AddItemToSet(item: Expression, set: Expression) + extends Expression with CodegenFallback { override def children: Seq[Expression] = item :: set :: Nil @@ -134,7 +135,8 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { * Note: this expression is internal and created only by the GeneratedAggregate, * we don't need to do type check for it. */ -case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { +case class CombineSets(left: Expression, right: Expression) + extends BinaryExpression with CodegenFallback { override def nullable: Boolean = left.nullable override def dataType: DataType = left.dataType @@ -181,7 +183,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres * Note: this expression is internal and created only by the GeneratedAggregate, * we don't need to do type check for it. 
*/ -case class CountSet(child: Expression) extends UnaryExpression { +case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index b36354eff092a..560b1bc2d889f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -103,7 +103,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes { * Simple RegEx pattern matching function */ case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { + extends BinaryExpression with StringRegexExpression with CodegenFallback { // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character @@ -133,14 +133,16 @@ case class Like(left: Expression, right: Expression) override def toString: String = s"$left LIKE $right" } + case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { + extends BinaryExpression with StringRegexExpression with CodegenFallback { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" } + trait String2StringExpression extends ImplicitCastInputTypes { self: UnaryExpression => @@ -156,7 +158,8 @@ trait String2StringExpression extends ImplicitCastInputTypes { /** * A function that converts the characters of a string to uppercase. */ -case class Upper(child: Expression) extends UnaryExpression with String2StringExpression { +case class Upper(child: Expression) + extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toUpperCase @@ -301,7 +304,7 @@ case class StringInstr(str: Expression, substr: Expression) * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends Expression with ImplicitCastInputTypes { + extends Expression with ImplicitCastInputTypes with CodegenFallback { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) @@ -342,7 +345,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns str, left-padded with pad to a length of len. */ case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends Expression with ImplicitCastInputTypes with CodegenFallback { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -380,7 +383,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) * Returns str, right-padded with pad to a length of len. 
*/ case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends Expression with ImplicitCastInputTypes with CodegenFallback { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -417,9 +420,9 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression { +case class StringFormat(children: Expression*) extends Expression with CodegenFallback { - require(children.length >=1, "printf() should take at least 1 argument") + require(children.nonEmpty, "printf() should take at least 1 argument") override def foldable: Boolean = children.forall(_.foldable) override def nullable: Boolean = children(0).nullable @@ -436,7 +439,7 @@ case class StringFormat(children: Expression*) extends Expression { val formatter = new java.util.Formatter(sb, Locale.US) val arglist = args.map(_.eval(input).asInstanceOf[AnyRef]) - formatter.format(pattern.asInstanceOf[UTF8String].toString(), arglist: _*) + formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*) UTF8String.fromString(sb.toString) } @@ -483,7 +486,8 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 /** * Returns a n spaces string. */ -case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class StringSpace(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -503,7 +507,7 @@ case class StringSpace(child: Expression) extends UnaryExpression with ImplicitC * Splits str around pat (pattern is a regular expression). */ case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback { override def left: Expression = str override def right: Expression = pattern @@ -524,7 +528,7 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ImplicitCastInputTypes { + extends Expression with ImplicitCastInputTypes with CodegenFallback { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -606,8 +610,6 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") } } - - override def prettyName: String = "length" } /** @@ -632,7 +634,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres /** * Returns the numeric value of the first character of str. */ -case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Ascii(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { + override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -649,7 +653,9 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp /** * Converts the argument from binary to a base 64 string. 
*/ -case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Base64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -663,7 +669,9 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn /** * Converts the argument from a base 64 string to BINARY. */ -case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class UnBase64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { + override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -677,7 +685,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast * If either argument is null, the result will also be null. */ case class Decode(bin: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback { override def left: Expression = bin override def right: Expression = charset @@ -696,7 +704,7 @@ case class Decode(bin: Expression, charset: Expression) * If either argument is null, the result will also be null. */ case class Encode(value: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback { override def left: Expression = value override def right: Expression = charset @@ -715,7 +723,7 @@ case class Encode(value: Expression, charset: Expression) * fractional part. */ case class FormatNumber(x: Expression, d: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with CodegenFallback { override def left: Expression = x override def right: Expression = d diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index c8aa571df64fc..50bbfd644d302 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types.{DataType, NumericType} /** @@ -37,7 +36,7 @@ sealed trait WindowSpec case class WindowSpecDefinition( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], - frameSpecification: WindowFrame) extends Expression with WindowSpec { + frameSpecification: WindowFrame) extends Expression with WindowSpec with Unevaluable { def validate: Option[String] = frameSpecification match { case UnspecifiedFrame => @@ -75,7 +74,6 @@ case class WindowSpecDefinition( override def toString: String = simpleString - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException @@ -274,60 +272,43 @@ trait WindowFunction extends Expression { case class UnresolvedWindowFunction( 
name: String, children: Seq[Expression]) - extends Expression with WindowFunction { + extends Expression with WindowFunction with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def init(): Unit = - throw new UnresolvedException(this, "init") - override def reset(): Unit = - throw new UnresolvedException(this, "reset") + override def init(): Unit = throw new UnresolvedException(this, "init") + override def reset(): Unit = throw new UnresolvedException(this, "reset") override def prepareInputParameters(input: InternalRow): AnyRef = throw new UnresolvedException(this, "prepareInputParameters") - override def update(input: AnyRef): Unit = - throw new UnresolvedException(this, "update") + override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update") override def batchUpdate(inputs: Array[AnyRef]): Unit = throw new UnresolvedException(this, "batchUpdate") - override def evaluate(): Unit = - throw new UnresolvedException(this, "evaluate") - override def get(index: Int): Any = - throw new UnresolvedException(this, "get") - // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def evaluate(): Unit = throw new UnresolvedException(this, "evaluate") + override def get(index: Int): Any = throw new UnresolvedException(this, "get") override def toString: String = s"'$name(${children.mkString(",")})" - override def newInstance(): WindowFunction = - throw new UnresolvedException(this, "newInstance") + override def newInstance(): WindowFunction = throw new UnresolvedException(this, "newInstance") } case class UnresolvedWindowExpression( child: UnresolvedWindowFunction, - windowSpec: WindowSpecReference) extends UnaryExpression { + windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - - // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } case class WindowExpression( windowFunction: WindowFunction, - windowSpec: WindowSpecDefinition) extends Expression { - - override def children: Seq[Expression] = - windowFunction :: windowSpec :: Nil + windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { - override def eval(input: InternalRow): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") + override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil override def dataType: DataType = windowFunction.dataType override def foldable: Boolean = windowFunction.foldable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 42dead7c28425..2dcfa19fec383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Unevaluable, Expression, SortOrder} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -146,8 +144,7 @@ case object BroadcastPartitioning extends Partitioning { * in the same partition. */ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) - extends Expression - with Partitioning { + extends Expression with Partitioning with Unevaluable { override def children: Seq[Expression] = expressions override def nullable: Boolean = false @@ -169,9 +166,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) } override def keyExpressions: Seq[Expression] = expressions - - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } /** @@ -187,8 +181,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * into its child. */ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) - extends Expression - with Partitioning { + extends Expression with Partitioning with Unevaluable { override def children: Seq[SortOrder] = ordering override def nullable: Boolean = false @@ -213,7 +206,4 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) } override def keyExpressions: Seq[Expression] = ordering.map(_.child) - - override def eval(input: InternalRow): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 2147d07e09bd3..dca8c881f21ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.types._ case class TestFunction( children: Seq[Expression], - inputTypes: Seq[AbstractDataType]) extends Expression with ImplicitCastInputTypes { + inputTypes: Seq[AbstractDataType]) + extends Expression with ImplicitCastInputTypes with Unevaluable { override def nullable: Boolean = true - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def dataType: DataType = StringType } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index c9b3c69c6de89..f9442bccc4a7a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -363,26 +363,26 @@ class HiveTypeCoercionSuite extends PlanTest { object HiveTypeCoercionSuite { case class AnyTypeUnaryExpression(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with Unevaluable { override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) override def dataType: DataType = NullType } case class NumericTypeUnaryExpression(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with Unevaluable { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def dataType: DataType = NullType } case class AnyTypeBinaryOperator(left: Expression, right: Expression) - extends BinaryOperator { + extends BinaryOperator with Unevaluable { override def dataType: DataType = NullType override def inputType: AbstractDataType = AnyDataType override def symbol: String = "anytype" } case class NumericTypeBinaryOperator(left: Expression, right: Expression) - extends BinaryOperator { + extends BinaryOperator with Unevaluable { override def dataType: DataType = NullType override def inputType: AbstractDataType = NumericType override def symbol: String = "numerictype" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 1bd7d4e5cdf0f..8fff39906b342 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types.{IntegerType, StringType, NullType} -case class Dummy(optKey: Option[Expression]) extends Expression { +case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { override def children: 
Seq[Expression] = optKey.toSeq override def nullable: Boolean = true override def dataType: NullType = NullType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 6d6e67dace177..e6e27a87c7151 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -51,15 +51,11 @@ private[spark] case class PythonUDF( broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType, - children: Seq[Expression]) extends Expression with SparkLogging { + children: Seq[Expression]) extends Expression with Unevaluable with SparkLogging { override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = { - throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.") - } } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 0bc8adb16afc0..4d23c7035c03d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -36,8 +36,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.hive.HiveShim._ @@ -81,7 +81,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) } private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with Logging { + extends Expression with HiveInspectors with CodegenFallback with Logging { type UDFType = UDF @@ -146,7 +146,7 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) } private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with Logging { + extends Expression with HiveInspectors with CodegenFallback with Logging { type UDFType = GenericUDF override def deterministic: Boolean = isUDFDeterministic @@ -166,8 +166,8 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr @transient protected lazy val isUDFDeterministic = { - val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) - (udfType != null && udfType.deterministic()) + val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) + udfType != null && udfType.deterministic() } override def foldable: Boolean = @@ -301,7 +301,7 @@ private[hive] case class HiveWindowFunction( pivotResult: Boolean, isUDAFBridgeRequired: Boolean, children: Seq[Expression]) extends WindowFunction - with HiveInspectors { + with HiveInspectors with Unevaluable { // Hive window functions are based on GenericUDAFResolver2. 
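The recurring change in the hunks above and below is mechanical: expressions that previously overrode eval only to throw now mix in Unevaluable, while expressions with no specialized code generation mix in CodegenFallback. A rough, self-contained sketch of what the two mix-ins provide (toy types for illustration only, not the actual Spark trait definitions):

trait Expr {
  def eval(input: Any): Any     // interpreted evaluation
  def genCode(): String         // source for the generated evaluator
}

// Placeholders that must be resolved or rewritten before execution:
// both evaluation paths fail fast, instead of every class repeating the throw.
trait Unevaluable extends Expr {
  final override def eval(input: Any): Any =
    throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
  final override def genCode(): String =
    throw new UnsupportedOperationException(s"Cannot generate code for: $this")
}

// Expressions without hand-written codegen: the "generated" code simply calls
// back into the interpreted eval(), so they still compose with generated projections.
trait CodegenFallback extends Expr {
  override def genCode(): String = "/* fallback to interpreted */ this.eval(input)"
}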
type UDFType = GenericUDAFResolver2 @@ -330,7 +330,7 @@ private[hive] case class HiveWindowFunction( evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } - def dataType: DataType = + override def dataType: DataType = if (!pivotResult) { inspectorToDataType(returnInspector) } else { @@ -344,10 +344,7 @@ private[hive] case class HiveWindowFunction( } } - def nullable: Boolean = true - - override def eval(input: InternalRow): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def nullable: Boolean = true @transient lazy val inputProjection = new InterpretedProjection(children) @@ -406,7 +403,7 @@ private[hive] case class HiveWindowFunction( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - override def newInstance: WindowFunction = + override def newInstance(): WindowFunction = new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) } @@ -476,7 +473,7 @@ private[hive] case class HiveUDAF( /** * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a - * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow + * [[Generator]]. Note that the semantics of Generators do not allow * Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning * dependent operations like calls to `close()` before producing output will not operate the same as * in Hive. However, in practice this should not affect compatibility for most sane UDTFs @@ -488,7 +485,7 @@ private[hive] case class HiveUDAF( private[hive] case class HiveGenericUDTF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Generator with HiveInspectors { + extends Generator with HiveInspectors with CodegenFallback { @transient protected lazy val function: GenericUDTF = { From 45d798c323ffe32bc2eba4dbd271c4572f5a30cf Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 18 Jul 2015 20:27:55 -0700 Subject: [PATCH 44/58] [SPARK-8278] Remove non-streaming JSON reader. Author: Reynold Xin Closes #7501 from rxin/jsonrdd and squashes the following commits: 767ec55 [Reynold Xin] More Mima 51f456e [Reynold Xin] Mima exclude. 789cb80 [Reynold Xin] Fixed compilation error. b4cf50d [Reynold Xin] [SPARK-8278] Remove non-streaming JSON reader. --- project/MimaExcludes.scala | 3 + .../apache/spark/sql/DataFrameReader.scala | 15 +- .../scala/org/apache/spark/sql/SQLConf.scala | 5 - .../apache/spark/sql/json/JSONRelation.scala | 48 +- .../org/apache/spark/sql/json/JsonRDD.scala | 449 ------------------ .../org/apache/spark/sql/json/JsonSuite.scala | 27 +- 6 files changed, 29 insertions(+), 518 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4e4e810ec36e3..36417f5df9f2d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -64,6 +64,9 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution"), // Parquet support is considered private. 
excludePackage("org.apache.spark.sql.parquet"), + // The old JSON RDD is removed in favor of streaming Jackson + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), // local function inside a method ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 9ad6e21da7bf7..9b23df4843c06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.json.{JsonRDD, JSONRelation} +import org.apache.spark.sql.json.JSONRelation import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType @@ -236,17 +236,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) { */ def json(jsonRDD: RDD[String]): DataFrame = { val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble - if (sqlContext.conf.useJacksonStreamingAPI) { - sqlContext.baseRelationToDataFrame( - new JSONRelation(() => jsonRDD, None, samplingRatio, userSpecifiedSchema)(sqlContext)) - } else { - val columnNameOfCorruptJsonRecord = sqlContext.conf.columnNameOfCorruptRecord - val appliedSchema = userSpecifiedSchema.getOrElse( - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema(jsonRDD, 1.0, columnNameOfCorruptJsonRecord))) - val rowRDD = JsonRDD.jsonStringToRow(jsonRDD, appliedSchema, columnNameOfCorruptJsonRecord) - sqlContext.internalCreateDataFrame(rowRDD, appliedSchema) - } + sqlContext.baseRelationToDataFrame( + new JSONRelation(() => jsonRDD, None, samplingRatio, userSpecifiedSchema)(sqlContext)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2c2f7c35dfdce..84d3271ceb738 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -401,9 +401,6 @@ private[spark] object SQLConf { "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) - val USE_JACKSON_STREAMING_API = booleanConf("spark.sql.json.useJacksonStreamingAPI", - defaultValue = Some(true), doc = "") - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -473,8 +470,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) - private[spark] def useJacksonStreamingAPI: Boolean = getConf(USE_JACKSON_STREAMING_API) - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 2361d3bf52d2b..25802d054ac00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -157,51 +157,27 @@ private[sql] class JSONRelation( } } - private val useJacksonStreamingAPI: Boolean = sqlContext.conf.useJacksonStreamingAPI - override val needConversion: Boolean = false override lazy val schema = userSpecifiedSchema.getOrElse { - if (useJacksonStreamingAPI) { - InferSchema( - baseRDD(), - samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord) - } else { - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema( - baseRDD(), - samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord)) - } + InferSchema( + baseRDD(), + samplingRatio, + sqlContext.conf.columnNameOfCorruptRecord) } override def buildScan(): RDD[Row] = { - if (useJacksonStreamingAPI) { - JacksonParser( - baseRDD(), - schema, - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } else { - JsonRDD.jsonStringToRow( - baseRDD(), - schema, - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } + JacksonParser( + baseRDD(), + schema, + sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = { - if (useJacksonStreamingAPI) { - JacksonParser( - baseRDD(), - StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } else { - JsonRDD.jsonStringToRow( - baseRDD(), - StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } + JacksonParser( + baseRDD(), + StructType.fromAttributes(requiredColumns), + sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala deleted file mode 100644 index b392a51bf7dce..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ /dev/null @@ -1,449 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.json - -import scala.collection.Map -import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} - -import com.fasterxml.jackson.core.JsonProcessingException -import com.fasterxml.jackson.databind.ObjectMapper - -import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - - -private[sql] object JsonRDD extends Logging { - - private[sql] def jsonStringToRow( - json: RDD[String], - schema: StructType, - columnNameOfCorruptRecords: String): RDD[InternalRow] = { - parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema)) - } - - private[sql] def inferSchema( - json: RDD[String], - samplingRatio: Double = 1.0, - columnNameOfCorruptRecords: String): StructType = { - require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") - val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) - val allKeys = - if (schemaData.isEmpty()) { - Set.empty[(String, DataType)] - } else { - parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) - } - createSchema(allKeys) - } - - private def createSchema(allKeys: Set[(String, DataType)]): StructType = { - // Resolve type conflicts - val resolved = allKeys.groupBy { - case (key, dataType) => key - }.map { - // Now, keys and types are organized in the format of - // key -> Set(type1, type2, ...). - case (key, typeSet) => { - val fieldName = key.substring(1, key.length - 1).split("`.`").toSeq - val dataType = typeSet.map { - case (_, dataType) => dataType - }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - - (fieldName, dataType) - } - } - - def makeStruct(values: Seq[Seq[String]], prefix: Seq[String]): StructType = { - val (topLevel, structLike) = values.partition(_.size == 1) - - val topLevelFields = topLevel.filter { - name => resolved.get(prefix ++ name).get match { - case ArrayType(elementType, _) => { - def hasInnerStruct(t: DataType): Boolean = t match { - case s: StructType => true - case ArrayType(t1, _) => hasInnerStruct(t1) - case o => false - } - - // Check if this array has inner struct. - !hasInnerStruct(elementType) - } - case struct: StructType => false - case _ => true - } - }.map { - a => StructField(a.head, resolved.get(prefix ++ a).get, nullable = true) - } - val topLevelFieldNameSet = topLevelFields.map(_.name) - - val structFields: Seq[StructField] = structLike.groupBy(_(0)).filter { - case (name, _) => !topLevelFieldNameSet.contains(name) - }.map { - case (name, fields) => { - val nestedFields = fields.map(_.tail) - val structType = makeStruct(nestedFields, prefix :+ name) - val dataType = resolved.get(prefix :+ name).get - dataType match { - case array: ArrayType => - // The pattern of this array is ArrayType(...(ArrayType(StructType))). - // Since the inner struct of array is a placeholder (StructType(Nil)), - // we need to replace this placeholder with the actual StructType (structType). 
- def getActualArrayType( - innerStruct: StructType, - currentArray: ArrayType): ArrayType = currentArray match { - case ArrayType(s: StructType, containsNull) => - ArrayType(innerStruct, containsNull) - case ArrayType(a: ArrayType, containsNull) => - ArrayType(getActualArrayType(innerStruct, a), containsNull) - } - Some(StructField(name, getActualArrayType(structType, array), nullable = true)) - case struct: StructType => Some(StructField(name, structType, nullable = true)) - // dataType is StringType means that we have resolved type conflicts involving - // primitive types and complex types. So, the type of name has been relaxed to - // StringType. Also, this field should have already been put in topLevelFields. - case StringType => None - } - } - }.flatMap(field => field).toSeq - - StructType((topLevelFields ++ structFields).sortBy(_.name)) - } - - makeStruct(resolved.keySet.toSeq, Nil) - } - - private[sql] def nullTypeToStringType(struct: StructType): StructType = { - val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable, _) => { - val newType = dataType match { - case NullType => StringType - case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) - case ArrayType(struct: StructType, containsNull) => - ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType => nullTypeToStringType(struct) - case other: DataType => other - } - StructField(fieldName, newType, nullable) - } - } - - StructType(fields) - } - - /** - * Returns the most general data type for two given data types. - */ - private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) match { - case Some(commonType) => commonType - case None => - // t1 or t2 is a StructType, ArrayType, or an unexpected type. - (t1, t2) match { - case (other: DataType, NullType) => other - case (NullType, other: DataType) => other - case (StructType(fields1), StructType(fields2)) => { - val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { - case (name, fieldTypes) => { - val dataType = fieldTypes.map(field => field.dataType).reduce( - (type1: DataType, type2: DataType) => compatibleType(type1, type2)) - StructField(name, dataType, true) - } - } - StructType(newFields.toSeq.sortBy(_.name)) - } - case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) - // TODO: We should use JsonObjectStringType to mark that values of field will be - // strings and every string is a Json object. - case (_, _) => StringType - } - } - } - - private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { - // For Integer values, use LongType by default. - val useLongType: PartialFunction[Any, DataType] = { - case value: IntegerType.InternalType => LongType - } - - useLongType orElse ScalaReflection.typeOfObject orElse { - // Since we do not have a data type backed by BigInteger, - // when we see a Java BigInteger, we use DecimalType. - case value: java.math.BigInteger => DecimalType.Unlimited - // DecimalType's JVMType is scala BigDecimal. - case value: java.math.BigDecimal => DecimalType.Unlimited - // Unexpected data type. - case _ => StringType - } - } - - /** - * Returns the element type of an JSON array. We go through all elements of this array - * to detect any possible type conflict. We use [[compatibleType]] to resolve - * type conflicts. 
- */ - private def typeOfArray(l: Seq[Any]): ArrayType = { - val elements = l.flatMap(v => Option(v)) - if (elements.isEmpty) { - // If this JSON array is empty, we use NullType as a placeholder. - // If this array is not empty in other JSON objects, we can resolve - // the type after we have passed through all JSON objects. - ArrayType(NullType, containsNull = true) - } else { - val elementType = elements.map { - e => e match { - case map: Map[_, _] => StructType(Nil) - // We have an array of arrays. If those element arrays do not have the same - // element types, we will return ArrayType[StringType]. - case seq: Seq[_] => typeOfArray(seq) - case value => typeOfPrimitiveValue(value) - } - }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - - ArrayType(elementType, containsNull = true) - } - } - - /** - * Figures out all key names and data types of values from a parsed JSON object - * (in the format of Map[Stirng, Any]). When the value of a key is an JSON object, we - * only use a placeholder (StructType(Nil)) to mark that it should be a struct - * instead of getting all fields of this struct because a field does not appear - * in this JSON object can appear in other JSON objects. - */ - private def allKeysWithValueTypes(m: Map[String, Any]): Set[(String, DataType)] = { - val keyValuePairs = m.map { - // Quote the key with backticks to handle cases which have dots - // in the field name. - case (key, value) => (s"`$key`", value) - }.toSet - keyValuePairs.flatMap { - case (key: String, struct: Map[_, _]) => { - // The value associated with the key is an JSON object. - allKeysWithValueTypes(struct.asInstanceOf[Map[String, Any]]).map { - case (k, dataType) => (s"$key.$k", dataType) - } ++ Set((key, StructType(Nil))) - } - case (key: String, array: Seq[_]) => { - // The value associated with the key is an array. - // Handle inner structs of an array. - def buildKeyPathForInnerStructs(v: Any, t: DataType): Seq[(String, DataType)] = t match { - case ArrayType(e: StructType, _) => { - // The elements of this arrays are structs. - v.asInstanceOf[Seq[Map[String, Any]]].flatMap(Option(_)).flatMap { - element => allKeysWithValueTypes(element) - }.map { - case (k, t) => (s"$key.$k", t) - } - } - case ArrayType(t1, _) => - v.asInstanceOf[Seq[Any]].flatMap(Option(_)).flatMap { - element => buildKeyPathForInnerStructs(element, t1) - } - case other => Nil - } - val elementType = typeOfArray(array) - buildKeyPathForInnerStructs(array, elementType) :+ (key, elementType) - } - // we couldn't tell what the type is if the value is null or empty string - case (key: String, value) if value == "" || value == null => (key, NullType) :: Nil - case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil - } - } - - /** - * Converts a Java Map/List to a Scala Map/Seq. - * We do not use Jackson's scala module at here because - * DefaultScalaModule in jackson-module-scala will make - * the parsing very slow. - */ - private def scalafy(obj: Any): Any = obj match { - case map: java.util.Map[_, _] => - // .map(identity) is used as a workaround of non-serializable Map - // generated by .mapValues. 
- // This issue is documented at https://issues.scala-lang.org/browse/SI-7005 - JMapWrapper(map).mapValues(scalafy).map(identity) - case list: java.util.List[_] => - JListWrapper(list).map(scalafy) - case atom => atom - } - - private def parseJson( - json: RDD[String], - columnNameOfCorruptRecords: String): RDD[Map[String, Any]] = { - // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72], - // ObjectMapper will not return BigDecimal when - // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled - // (see NumberDeserializer.deserialize for the logic). - // But, we do not want to enable this feature because it will use BigDecimal - // for every float number, which will be slow. - // So, right now, we will have Infinity for those BigDecimal number. - // TODO: Support BigDecimal. - json.mapPartitions(iter => { - // When there is a key appearing multiple times (a duplicate key), - // the ObjectMapper will take the last value associated with this duplicate key. - // For example: for {"key": 1, "key":2}, we will get "key"->2. - val mapper = new ObjectMapper() - iter.flatMap { record => - try { - val parsed = mapper.readValue(record, classOf[Object]) match { - case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil - case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] - case _ => - sys.error( - s"Failed to parse record $record. Please make sure that each line of the file " + - "(or each string in the RDD) is a valid JSON object or an array of JSON objects.") - } - - parsed - } catch { - case e: JsonProcessingException => - Map(columnNameOfCorruptRecords -> UTF8String.fromString(record)) :: Nil - } - } - }) - } - - private def toLong(value: Any): Long = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toLong - case value: java.lang.Long => value.asInstanceOf[Long] - } - } - - private def toDouble(value: Any): Double = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toDouble - case value: java.lang.Long => value.asInstanceOf[Long].toDouble - case value: java.lang.Double => value.asInstanceOf[Double] - } - } - - private def toDecimal(value: Any): Decimal = { - value match { - case value: java.lang.Integer => Decimal(value) - case value: java.lang.Long => Decimal(value) - case value: java.math.BigInteger => Decimal(new java.math.BigDecimal(value)) - case value: java.lang.Double => Decimal(value) - case value: java.math.BigDecimal => Decimal(value) - } - } - - private def toJsonArrayString(seq: Seq[Any]): String = { - val builder = new StringBuilder - builder.append("[") - var count = 0 - seq.foreach { - element => - if (count > 0) builder.append(",") - count += 1 - builder.append(toString(element)) - } - builder.append("]") - - builder.toString() - } - - private def toJsonObjectString(map: Map[String, Any]): String = { - val builder = new StringBuilder - builder.append("{") - var count = 0 - map.foreach { - case (key, value) => - if (count > 0) builder.append(",") - count += 1 - val stringValue = if (value.isInstanceOf[String]) s"""\"$value\"""" else toString(value) - builder.append(s"""\"${key}\":${stringValue}""") - } - builder.append("}") - - builder.toString() - } - - private def toString(value: Any): String = { - value match { - case value: Map[_, _] => toJsonObjectString(value.asInstanceOf[Map[String, Any]]) - case value: Seq[_] => toJsonArrayString(value) - case value => Option(value).map(_.toString).orNull - } - } - - private def toDate(value: Any): 
Int = { - value match { - // only support string as date - case value: java.lang.String => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(value).getTime) - case value: java.sql.Date => DateTimeUtils.fromJavaDate(value) - } - } - - private def toTimestamp(value: Any): Long = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 1000L - case value: java.lang.Long => value * 1000L - case value: java.lang.String => DateTimeUtils.stringToTime(value).getTime * 1000L - } - } - - private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any = { - if (value == null) { - null - } else { - desiredType match { - case StringType => UTF8String.fromString(toString(value)) - case _ if value == null || value == "" => null // guard the non string type - case IntegerType => value.asInstanceOf[IntegerType.InternalType] - case LongType => toLong(value) - case DoubleType => toDouble(value) - case DecimalType() => toDecimal(value) - case BooleanType => value.asInstanceOf[BooleanType.InternalType] - case NullType => null - case ArrayType(elementType, _) => - value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) - case MapType(StringType, valueType, _) => - val map = value.asInstanceOf[Map[String, Any]] - map.map { - case (k, v) => - (UTF8String.fromString(k), enforceCorrectType(v, valueType)) - }.map(identity) - case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) - case DateType => toDate(value) - case TimestampType => toTimestamp(value) - } - } - } - - private def asRow(json: Map[String, Any], schema: StructType): InternalRow = { - // TODO: Reuse the row instead of creating a new one for every record. - val row = new GenericMutableRow(schema.fields.length) - schema.fields.zipWithIndex.foreach { - case (StructField(name, dataType, _, _), i) => - row.update(i, json.get(name).flatMap(v => Option(v)).map( - enforceCorrectType(_, dataType)).orNull) - } - - row - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 8204a584179bb..3475f9dd6787e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -1079,28 +1079,23 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-7565 MapType in JsonRDD") { - val useStreaming = ctx.conf.useJacksonStreamingAPI val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - try{ - for (useStreaming <- List(true, false)) { - ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) - val temp = Utils.createTempDir().getPath - - val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) - df.write.mode("overwrite").parquet(temp) - // order of MapType is not defined - assert(ctx.read.parquet(temp).count() == 5) - - val df2 = ctx.read.json(corruptRecords) - df2.write.mode("overwrite").parquet(temp) - checkAnswer(ctx.read.parquet(temp), df2.collect()) - } + try { + val temp = Utils.createTempDir().getPath + + val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + df.write.mode("overwrite").parquet(temp) + // order of MapType is not defined + assert(ctx.read.parquet(temp).count() == 5) + + val df2 = ctx.read.json(corruptRecords) + df2.write.mode("overwrite").parquet(temp) + 
checkAnswer(ctx.read.parquet(temp), df2.collect()) } finally { - ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } } From 6cb6096c016178b9ce5c97592abe529ddb18cef2 Mon Sep 17 00:00:00 2001 From: Forest Fang Date: Sat, 18 Jul 2015 21:05:44 -0700 Subject: [PATCH 45/58] [SPARK-8443][SQL] Split GenerateMutableProjection Codegen due to JVM Code Size Limits By grouping projection calls into multiple apply function, we are able to push the number of projections codegen can handle from ~1k to ~60k. I have set the unit test to test against 5k as 60k took 15s for the unit test to complete. Author: Forest Fang Closes #7076 from saurfang/codegen_size_limit and squashes the following commits: b7a7635 [Forest Fang] [SPARK-8443][SQL] Execute and verify split projections in test adef95a [Forest Fang] [SPARK-8443][SQL] Use safer factor and rewrite splitting code 1b5aa7e [Forest Fang] [SPARK-8443][SQL] inline execution if one block only 9405680 [Forest Fang] [SPARK-8443][SQL] split projection code by size limit --- .../codegen/GenerateMutableProjection.scala | 39 ++++++++++++++++++- .../expressions/CodeGenerationSuite.scala | 14 ++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 71e47d4f9b620..b82bd6814b487 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import scala.collection.mutable.ArrayBuffer + // MutableProjection is not accessible in Java abstract class BaseMutableProjection extends MutableProjection @@ -45,10 +47,41 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu else ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ - }.mkString("\n") + } + // collect projections into blocks as function has 64kb codesize limit in JVM + val projectionBlocks = new ArrayBuffer[String]() + val blockBuilder = new StringBuilder() + for (projection <- projectionCode) { + if (blockBuilder.length > 16 * 1000) { + projectionBlocks.append(blockBuilder.toString()) + blockBuilder.clear() + } + blockBuilder.append(projection) + } + projectionBlocks.append(blockBuilder.toString()) + + val (projectionFuns, projectionCalls) = { + // inline execution if codesize limit was not broken + if (projectionBlocks.length == 1) { + ("", projectionBlocks.head) + } else { + ( + projectionBlocks.zipWithIndex.map { case (body, i) => + s""" + |private void apply$i(InternalRow i) { + | $body + |} + """.stripMargin + }.mkString, + projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") + ) + } + } + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => s"private $javaType $variableName = $initialValue;" }.mkString("\n ") + val code = s""" public Object generate($exprType[] expr) { return new SpecificProjection(expr); @@ -75,9 +108,11 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } + $projectionFuns + public Object 
apply(Object _i) { InternalRow i = (InternalRow) _i; - $projectionCode + $projectionCalls return mutableRow; } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 481b335d15dfd..e05218a23aa73 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ /** * Additional tests for code generation. */ -class CodeGenerationSuite extends SparkFunSuite { +class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("multithreaded eval") { import scala.concurrent._ @@ -42,4 +42,16 @@ class CodeGenerationSuite extends SparkFunSuite { futures.foreach(Await.result(_, 10.seconds)) } + + test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { + val length = 5000 + val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) + val plan = GenerateMutableProjection.generate(expressions)() + val actual = plan(new GenericMutableRow(length)).toSeq + val expected = Seq.fill(length)(true) + + if (!checkResult(actual, expected)) { + fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") + } + } } From 83b682beec884da76708769414108f4316e620f2 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Sat, 18 Jul 2015 22:48:05 -0700 Subject: [PATCH 46/58] [SPARK-8199][SPARK-8184][SPARK-8183][SPARK-8182][SPARK-8181][SPARK-8180][SPARK-8179][SPARK-8177][SPARK-8178][SPARK-9115][SQL] date functions Jira: https://issues.apache.org/jira/browse/SPARK-8199 https://issues.apache.org/jira/browse/SPARK-8184 https://issues.apache.org/jira/browse/SPARK-8183 https://issues.apache.org/jira/browse/SPARK-8182 https://issues.apache.org/jira/browse/SPARK-8181 https://issues.apache.org/jira/browse/SPARK-8180 https://issues.apache.org/jira/browse/SPARK-8179 https://issues.apache.org/jira/browse/SPARK-8177 https://issues.apache.org/jira/browse/SPARK-8179 https://issues.apache.org/jira/browse/SPARK-9115 Regarding `day`and `dayofmonth` are both necessary? ~~I am going to add `Quarter` to this PR as well.~~ Done. 
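The block-splitting idea from the SPARK-8443 commit above can be sketched in isolation. A minimal, hypothetical helper (not the Spark code itself) that packs generated snippets into separate methods once the accumulated source crosses a conservative margin under the JVM's 64KB-per-method bytecode limit:

import scala.collection.mutable.ArrayBuffer

object SplitProjections {
  // Pack snippets into blocks of at most ~16KB of source each, then either
  // inline the single block or emit one helper method per block plus its call sites.
  def split(snippets: Seq[String], limit: Int = 16 * 1000): (String, String) = {
    val blocks = ArrayBuffer.empty[String]
    val current = new StringBuilder
    for (s <- snippets) {
      if (current.length > limit) {   // close the current block before it grows further
        blocks += current.toString
        current.clear()
      }
      current.append(s).append('\n')
    }
    blocks += current.toString

    if (blocks.length == 1) {
      ("", blocks.head)               // small projection: inline it, no helpers needed
    } else {
      val defs = blocks.zipWithIndex.map { case (body, i) =>
        s"private void apply$i(InternalRow i) {\n$body}\n"
      }.mkString("\n")
      val calls = blocks.indices.map(i => s"apply$i(i);").mkString("\n")
      (defs, calls)                   // helper method definitions + calls to splice in
    }
  }
}

Splitting by source length is only a heuristic: the JVM limit applies to bytecode size, so a 16KB source cut-off leaves ample headroom.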
~~As soon as the Scala coding is reviewed and discussed, I'll add the python api.~~ Done Author: Tarek Auel Author: Tarek Auel Closes #6981 from tarekauel/SPARK-8199 and squashes the following commits: f7b4c8c [Tarek Auel] [SPARK-8199] fixed bug in tests bb567b6 [Tarek Auel] [SPARK-8199] fixed test 3e095ba [Tarek Auel] [SPARK-8199] style and timezone fix 256c357 [Tarek Auel] [SPARK-8199] code cleanup 5983dcc [Tarek Auel] [SPARK-8199] whitespace fix 6e0c78f [Tarek Auel] [SPARK-8199] removed setTimeZone in tests, according to cloud-fans comment in #7488 4afc09c [Tarek Auel] [SPARK-8199] concise leap year handling ea6c110 [Tarek Auel] [SPARK-8199] fix after merging master 70238e0 [Tarek Auel] Merge branch 'master' into SPARK-8199 3c6ae2e [Tarek Auel] [SPARK-8199] removed binary search fb98ba0 [Tarek Auel] [SPARK-8199] python docstring fix cdfae27 [Tarek Auel] [SPARK-8199] cleanup & python docstring fix 746b80a [Tarek Auel] [SPARK-8199] build fix 0ad6db8 [Tarek Auel] [SPARK-8199] minor fix 523542d [Tarek Auel] [SPARK-8199] address comments 2259299 [Tarek Auel] [SPARK-8199] day_of_month alias d01b977 [Tarek Auel] [SPARK-8199] python underscore 56c4a92 [Tarek Auel] [SPARK-8199] update python docu e223bc0 [Tarek Auel] [SPARK-8199] refactoring d6aa14e [Tarek Auel] [SPARK-8199] fixed Hive compatibility b382267 [Tarek Auel] [SPARK-8199] fixed bug in day calculation; removed set TimeZone in HiveCompatibilitySuite for test purposes; removed Hive tests for second and minute, because we can cast '2015-03-18' to a timestamp and extract a minute/second from it 1b2e540 [Tarek Auel] [SPARK-8119] style fix 0852655 [Tarek Auel] [SPARK-8119] changed from ExpectsInputTypes to implicit casts ec87c69 [Tarek Auel] [SPARK-8119] bug fixing and refactoring 1358cdc [Tarek Auel] Merge remote-tracking branch 'origin/master' into SPARK-8199 740af0e [Tarek Auel] implement date function using a calculation based on days 4fb66da [Tarek Auel] WIP: date functions on calculation only 1a436c9 [Tarek Auel] wip f775f39 [Tarek Auel] fixed return type ad17e96 [Tarek Auel] improved implementation c42b444 [Tarek Auel] Removed merge conflict file ccb723c [Tarek Auel] [SPARK-8199] style and fixed merge issues 10e4ad1 [Tarek Auel] Merge branch 'master' into date-functions-fast 7d9f0eb [Tarek Auel] [SPARK-8199] git renaming issue f3e7a9f [Tarek Auel] [SPARK-8199] revert change in DataFrameFunctionsSuite 6f5d95c [Tarek Auel] [SPARK-8199] fixed year interval d9f8ac3 [Tarek Auel] [SPARK-8199] implement fast track 7bc9d93 [Tarek Auel] Merge branch 'master' into SPARK-8199 5a105d9 [Tarek Auel] [SPARK-8199] rebase after #6985 got merged eb6760d [Tarek Auel] Merge branch 'master' into SPARK-8199 f120415 [Tarek Auel] improved runtime a8edebd [Tarek Auel] use Calendar instead of SimpleDateFormat 5fe74e1 [Tarek Auel] fixed python style 3bfac90 [Tarek Auel] fixed style 356df78 [Tarek Auel] rely on cast mechanism of Spark. 
Simplified implementation 02efc5d [Tarek Auel] removed doubled code a5ea120 [Tarek Auel] added python api; changed test to be more meaningful b680db6 [Tarek Auel] added codegeneration to all functions c739788 [Tarek Auel] added support for quarter SPARK-8178 849fb41 [Tarek Auel] fixed stupid test 638596f [Tarek Auel] improved codegen 4d8049b [Tarek Auel] fixed tests and added type check 5ebb235 [Tarek Auel] resolved naming conflict d0e2f99 [Tarek Auel] date functions --- python/pyspark/sql/functions.py | 150 +++++++++++ .../catalyst/analysis/FunctionRegistry.scala | 14 +- .../expressions/datetimeFunctions.scala | 206 +++++++++++++++ .../sql/catalyst/util/DateTimeUtils.scala | 195 +++++++++++++- .../expressions/DateFunctionsSuite.scala | 249 ++++++++++++++++++ .../catalyst/util/DateTimeUtilsSuite.scala | 91 +++++-- .../org/apache/spark/sql/functions.scala | 176 +++++++++++++ .../spark/sql/DateExpressionsSuite.scala | 170 ++++++++++++ .../execution/HiveCompatibilitySuite.scala | 9 +- 9 files changed, 1234 insertions(+), 26 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DateExpressionsSuite.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e0816b3e654bc..0aca3788922aa 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -652,6 +652,156 @@ def ntile(n): return Column(sc._jvm.functions.ntile(int(n))) +@ignore_unicode_prefix +@since(1.5) +def date_format(dateCol, format): + """ + Converts a date/timestamp/string to a value of string in the format specified by the date + format given by the second argument. + + A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All + pattern letters of the Java class `java.text.SimpleDateFormat` can be used. + + NOTE: Use when ever possible specialized functions like `year`. These benefit from a + specialized implementation. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(date_format('a', 'MM/dd/yyy').alias('date')).collect() + [Row(date=u'04/08/2015')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_format(dateCol, format)) + + +@since(1.5) +def year(col): + """ + Extract the year of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(year('a').alias('year')).collect() + [Row(year=2015)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.year(col)) + + +@since(1.5) +def quarter(col): + """ + Extract the quarter of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(quarter('a').alias('quarter')).collect() + [Row(quarter=2)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.quarter(col)) + + +@since(1.5) +def month(col): + """ + Extract the month of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(month('a').alias('month')).collect() + [Row(month=4)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.month(col)) + + +@since(1.5) +def day(col): + """ + Extract the day of the month of a given date as integer. 
+ + >>> sqlContext.createDataFrame([('2015-04-08',)], ['a']).select(day('a').alias('day')).collect() + [Row(day=8)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.day(col)) + + +@since(1.5) +def day_of_month(col): + """ + Extract the day of the month of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(day_of_month('a').alias('day')).collect() + [Row(day=8)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.day_of_month(col)) + + +@since(1.5) +def day_in_year(col): + """ + Extract the day of the year of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(day_in_year('a').alias('day')).collect() + [Row(day=98)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.day_in_year(col)) + + +@since(1.5) +def hour(col): + """ + Extract the hours of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df.select(hour('a').alias('hour')).collect() + [Row(hour=13)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.hour(col)) + + +@since(1.5) +def minute(col): + """ + Extract the minutes of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df.select(minute('a').alias('minute')).collect() + [Row(minute=8)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.minute(col)) + + +@since(1.5) +def second(col): + """ + Extract the seconds of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df.select(second('a').alias('second')).collect() + [Row(second=15)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.second(col)) + + +@since(1.5) +def week_of_year(col): + """ + Extract the week number of a given date as integer. 
+ + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(week_of_year('a').alias('week')).collect() + [Row(week=15)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.week_of_year(col)) + + class UserDefinedFunction(object): """ User defined function in Python diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d1cda6bc27095..159f7eca7acfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -181,7 +181,19 @@ object FunctionRegistry { // datetime functions expression[CurrentDate]("current_date"), - expression[CurrentTimestamp]("current_timestamp") + expression[CurrentTimestamp]("current_timestamp"), + expression[DateFormatClass]("date_format"), + expression[Day]("day"), + expression[DayInYear]("day_in_year"), + expression[Day]("day_of_month"), + expression[Hour]("hour"), + expression[Month]("month"), + expression[Minute]("minute"), + expression[Quarter]("quarter"), + expression[Second]("second"), + expression[WeekOfYear]("week_of_year"), + expression[Year]("year") + ) val builtin: FunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index 4bed140cffbfa..f9cbbb8c6bee0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -17,10 +17,16 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Date +import java.text.SimpleDateFormat +import java.util.{Calendar, TimeZone} + +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Returns the current date at the start of query evaluation. 
@@ -55,3 +61,203 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { System.currentTimeMillis() * 1000L } } + +case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(timestamp: Any): Any = { + DateTimeUtils.getHours(timestamp.asInstanceOf[Long]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getHours($c)""" + ) + } +} + +case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(timestamp: Any): Any = { + DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getMinutes($c)""" + ) + } +} + +case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(timestamp: Any): Any = { + DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getSeconds($c)""" + ) + } +} + +case class DayInYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override def prettyName: String = "day_in_year" + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getDayInYear(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getDayInYear($c)""" + ) + } +} + + +case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getYear(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getYear($c)""" + ) + } +} + +case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getQuarter(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getQuarter($c)""" + ) + } +} + +case class 
Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getMonth(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getMonth($c)""" + ) + } +} + +case class Day(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getDayOfMonth(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getDayOfMonth($c)""" + ) + } +} + +case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override def prettyName: String = "week_of_year" + + override protected def nullSafeEval(date: Any): Any = { + val c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.setFirstDayOfWeek(Calendar.MONDAY) + c.setMinimalDaysInFirstWeek(4) + c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + c.get(Calendar.WEEK_OF_YEAR) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + nullSafeCodeGen(ctx, ev, (time) => { + val cal = classOf[Calendar].getName + val c = ctx.freshName("cal") + s""" + $cal $c = $cal.getInstance(java.util.TimeZone.getTimeZone("UTC")); + $c.setFirstDayOfWeek($cal.MONDAY); + $c.setMinimalDaysInFirstWeek(4); + $c.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.primitive} = $c.get($cal.WEEK_OF_YEAR); + """ + }) +} + +case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression + with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) + + override def prettyName: String = "date_format" + + override protected def nullSafeEval(timestamp: Any, format: Any): Any = { + val sdf = new SimpleDateFormat(format.toString) + UTF8String.fromString(sdf.format(new Date(timestamp.asInstanceOf[Long] / 1000))) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val sdf = classOf[SimpleDateFormat].getName + defineCodeGen(ctx, ev, (timestamp, format) => { + s"""UTF8String.fromString((new $sdf($format.toString())) + .format(new java.sql.Date($timestamp / 1000)))""" + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 45e45aef1a349..a0da73a995a82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} -import java.util.{Calendar, TimeZone} +import java.util.{TimeZone, 
Calendar} import org.apache.spark.unsafe.types.UTF8String @@ -39,6 +39,15 @@ object DateTimeUtils { final val MICROS_PER_SECOND = 1000L * 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L + // number of days in 400 years + final val daysIn400Years: Int = 146097 + // number of days between 1.1.1970 and 1.1.2001 + final val to2001 = -11323 + + // this is year -17999, calculation: 50 * daysIn400Year + final val toYearZero = to2001 + 7304850 + + @transient lazy val defaultTimeZone = TimeZone.getDefault // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. private val threadLocalLocalTimeZone = new ThreadLocal[TimeZone] { @@ -380,4 +389,188 @@ object DateTimeUtils { c.set(Calendar.MILLISECOND, 0) Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) } + + /** + * Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds. + */ + def getHours(timestamp: Long): Int = { + val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) + ((localTs / 1000 / 3600) % 24).toInt + } + + /** + * Returns the minute value of a given timestamp value. The timestamp is expressed in + * microseconds. + */ + def getMinutes(timestamp: Long): Int = { + val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) + ((localTs / 1000 / 60) % 60).toInt + } + + /** + * Returns the second value of a given timestamp value. The timestamp is expressed in + * microseconds. + */ + def getSeconds(timestamp: Long): Int = { + ((timestamp / 1000 / 1000) % 60).toInt + } + + private[this] def isLeapYear(year: Int): Boolean = { + (year % 4) == 0 && ((year % 100) != 0 || (year % 400) == 0) + } + + /** + * Return the number of days since the start of 400 year period. + * The second year of a 400 year period (year 1) starts on day 365. + */ + private[this] def yearBoundary(year: Int): Int = { + year * 365 + ((year / 4 ) - (year / 100) + (year / 400)) + } + + /** + * Calculates the number of years for the given number of days. This depends + * on a 400 year period. + * @param days days since the beginning of the 400 year period + * @return (number of year, days in year) + */ + private[this] def numYears(days: Int): (Int, Int) = { + val year = days / 365 + val boundary = yearBoundary(year) + if (days > boundary) (year, days - boundary) else (year - 1, days - yearBoundary(year - 1)) + } + + /** + * Calculates the year and and the number of the day in the year for the given + * number of days. The given days is the number of days since 1.1.1970. + * + * The calculation uses the fact that the period 1.1.2001 until 31.12.2400 is + * equals to the period 1.1.1601 until 31.12.2000. + */ + private[this] def getYearAndDayInYear(daysSince1970: Int): (Int, Int) = { + // add the difference (in days) between 1.1.1970 and the artificial year 0 (-17999) + val daysNormalized = daysSince1970 + toYearZero + val numOfQuarterCenturies = daysNormalized / daysIn400Years + val daysInThis400 = daysNormalized % daysIn400Years + 1 + val (years, dayInYear) = numYears(daysInThis400) + val year: Int = (2001 - 20000) + 400 * numOfQuarterCenturies + years + (year, dayInYear) + } + + /** + * Returns the 'day in year' value for the given date. The date is expressed in days + * since 1.1.1970. + */ + def getDayInYear(date: Int): Int = { + getYearAndDayInYear(date)._2 + } + + /** + * Returns the year value for the given date. The date is expressed in days + * since 1.1.1970. 
+ */ + def getYear(date: Int): Int = { + getYearAndDayInYear(date)._1 + } + + /** + * Returns the quarter for the given date. The date is expressed in days + * since 1.1.1970. + */ + def getQuarter(date: Int): Int = { + var (year, dayInYear) = getYearAndDayInYear(date) + if (isLeapYear(year)) { + dayInYear = dayInYear - 1 + } + if (dayInYear <= 90) { + 1 + } else if (dayInYear <= 181) { + 2 + } else if (dayInYear <= 273) { + 3 + } else { + 4 + } + } + + /** + * Returns the month value for the given date. The date is expressed in days + * since 1.1.1970. January is month 1. + */ + def getMonth(date: Int): Int = { + var (year, dayInYear) = getYearAndDayInYear(date) + if (isLeapYear(year)) { + if (dayInYear == 60) { + return 2 + } else if (dayInYear > 60) { + dayInYear = dayInYear - 1 + } + } + + if (dayInYear <= 31) { + 1 + } else if (dayInYear <= 59) { + 2 + } else if (dayInYear <= 90) { + 3 + } else if (dayInYear <= 120) { + 4 + } else if (dayInYear <= 151) { + 5 + } else if (dayInYear <= 181) { + 6 + } else if (dayInYear <= 212) { + 7 + } else if (dayInYear <= 243) { + 8 + } else if (dayInYear <= 273) { + 9 + } else if (dayInYear <= 304) { + 10 + } else if (dayInYear <= 334) { + 11 + } else { + 12 + } + } + + /** + * Returns the 'day of month' value for the given date. The date is expressed in days + * since 1.1.1970. + */ + def getDayOfMonth(date: Int): Int = { + var (year, dayInYear) = getYearAndDayInYear(date) + if (isLeapYear(year)) { + if (dayInYear == 60) { + return 29 + } else if (dayInYear > 60) { + dayInYear = dayInYear - 1 + } + } + + if (dayInYear <= 31) { + dayInYear + } else if (dayInYear <= 59) { + dayInYear - 31 + } else if (dayInYear <= 90) { + dayInYear - 59 + } else if (dayInYear <= 120) { + dayInYear - 90 + } else if (dayInYear <= 151) { + dayInYear - 120 + } else if (dayInYear <= 181) { + dayInYear - 151 + } else if (dayInYear <= 212) { + dayInYear - 181 + } else if (dayInYear <= 243) { + dayInYear - 212 + } else if (dayInYear <= 273) { + dayInYear - 243 + } else if (dayInYear <= 304) { + dayInYear - 273 + } else if (dayInYear <= 334) { + dayInYear - 304 + } else { + dayInYear - 334 + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala new file mode 100644 index 0000000000000..49d0b0aceac0d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Timestamp, Date} +import java.text.SimpleDateFormat +import java.util.{TimeZone, Calendar} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{StringType, TimestampType, DateType} + +class DateFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) + val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) + + test("Day in Year") { + val sdfDay = new SimpleDateFormat("D") + (2002 to 2004).foreach { y => + (0 to 11).foreach { m => + (0 to 5).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.DATE, i) + checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + sdfDay.format(c.getTime).toInt) + } + } + } + + (1998 to 2002).foreach { y => + (0 to 11).foreach { m => + (0 to 5).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.DATE, 1) + checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + sdfDay.format(c.getTime).toInt) + } + } + } + + (1969 to 1970).foreach { y => + (0 to 11).foreach { m => + (0 to 5).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.DATE, 1) + checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + sdfDay.format(c.getTime).toInt) + } + } + } + + (2402 to 2404).foreach { y => + (0 to 11).foreach { m => + (0 to 5).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.DATE, 1) + checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + sdfDay.format(c.getTime).toInt) + } + } + } + + (2398 to 2402).foreach { y => + (0 to 11).foreach { m => + (0 to 5).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.DATE, 1) + checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + sdfDay.format(c.getTime).toInt) + } + } + } + } + + test("Year") { + checkEvaluation(Year(Literal.create(null, DateType)), null) + checkEvaluation(Year(Cast(Literal(d), DateType)), 2015) + checkEvaluation(Year(Cast(Literal(sdfDate.format(d)), DateType)), 2015) + checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013) + + val c = Calendar.getInstance() + (2000 to 2010).foreach { y => + (0 to 11 by 11).foreach { m => + c.set(y, m, 28) + (0 to 5 * 24).foreach { i => + c.add(Calendar.HOUR_OF_DAY, 1) + checkEvaluation(Year(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + c.get(Calendar.YEAR)) + } + } + } + } + + test("Quarter") { + checkEvaluation(Quarter(Literal.create(null, DateType)), null) + checkEvaluation(Quarter(Cast(Literal(d), DateType)), 2) + checkEvaluation(Quarter(Cast(Literal(sdfDate.format(d)), DateType)), 2) + checkEvaluation(Quarter(Cast(Literal(ts), DateType)), 4) + + val c = Calendar.getInstance() + (2003 to 2004).foreach { y => + (0 to 11 by 3).foreach { m => + c.set(y, m, 28, 0, 0, 0) + (0 to 5 * 24).foreach { i => + c.add(Calendar.HOUR_OF_DAY, 1) + checkEvaluation(Quarter(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + c.get(Calendar.MONTH) / 3 + 1) + } + } + } + } + + test("Month") { + checkEvaluation(Month(Literal.create(null, DateType)), null) + checkEvaluation(Month(Cast(Literal(d), DateType)), 4) + checkEvaluation(Month(Cast(Literal(sdfDate.format(d)), DateType)), 
4) + checkEvaluation(Month(Cast(Literal(ts), DateType)), 11) + + (2003 to 2004).foreach { y => + (0 to 11).foreach { m => + (0 to 5 * 24).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.HOUR_OF_DAY, i) + checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + c.get(Calendar.MONTH) + 1) + } + } + } + + (1999 to 2000).foreach { y => + (0 to 11).foreach { m => + (0 to 5 * 24).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.HOUR_OF_DAY, i) + checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + c.get(Calendar.MONTH) + 1) + } + } + } + } + + test("Day") { + checkEvaluation(Day(Cast(Literal("2000-02-29"), DateType)), 29) + checkEvaluation(Day(Literal.create(null, DateType)), null) + checkEvaluation(Day(Cast(Literal(d), DateType)), 8) + checkEvaluation(Day(Cast(Literal(sdfDate.format(d)), DateType)), 8) + checkEvaluation(Day(Cast(Literal(ts), DateType)), 8) + + (1999 to 2000).foreach { y => + val c = Calendar.getInstance() + c.set(y, 0, 1, 0, 0, 0) + (0 to 365).foreach { d => + c.add(Calendar.DATE, 1) + checkEvaluation(Day(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + c.get(Calendar.DAY_OF_MONTH)) + } + } + } + + test("Seconds") { + checkEvaluation(Second(Literal.create(null, DateType)), null) + checkEvaluation(Second(Cast(Literal(d), TimestampType)), 0) + checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType)), 15) + checkEvaluation(Second(Literal(ts)), 15) + + val c = Calendar.getInstance() + (0 to 60 by 5).foreach { s => + c.set(2015, 18, 3, 3, 5, s) + checkEvaluation(Second(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + c.get(Calendar.SECOND)) + } + } + + test("WeekOfYear") { + checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) + checkEvaluation(WeekOfYear(Cast(Literal(d), DateType)), 15) + checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15) + checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45) + checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18) + } + + test("DateFormat") { + checkEvaluation(DateFormatClass(Literal.create(null, TimestampType), Literal("y")), null) + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType), + Literal.create(null, StringType)), null) + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType), + Literal("y")), "2015") + checkEvaluation(DateFormatClass(Literal(ts), Literal("y")), "2013") + } + + test("Hour") { + checkEvaluation(Hour(Literal.create(null, DateType)), null) + checkEvaluation(Hour(Cast(Literal(d), TimestampType)), 0) + checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType)), 13) + checkEvaluation(Hour(Literal(ts)), 13) + + val c = Calendar.getInstance() + (0 to 24).foreach { h => + (0 to 60 by 15).foreach { m => + (0 to 60 by 15).foreach { s => + c.set(2015, 18, 3, h, m, s) + checkEvaluation(Hour(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + c.get(Calendar.HOUR_OF_DAY)) + } + } + } + } + + test("Minute") { + checkEvaluation(Minute(Literal.create(null, DateType)), null) + checkEvaluation(Minute(Cast(Literal(d), TimestampType)), 0) + checkEvaluation(Minute(Cast(Literal(sdf.format(d)), TimestampType)), 10) + checkEvaluation(Minute(Literal(ts)), 10) + + val c = Calendar.getInstance() + (0 to 60 by 5).foreach { m => + (0 to 60 by 15).foreach { s => + c.set(2015, 18, 3, 3, m, s) + checkEvaluation(Minute(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + 
c.get(Calendar.MINUTE)) + } + } + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 04c5f09792ac3..fab9eb9cd4c9f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -26,6 +26,11 @@ import org.apache.spark.unsafe.types.UTF8String class DateTimeUtilsSuite extends SparkFunSuite { + private[this] def getInUTCDays(timestamp: Long): Int = { + val tz = TimeZone.getDefault + ((timestamp + tz.getOffset(timestamp)) / DateTimeUtils.MILLIS_PER_DAY).toInt + } + test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) now.setNanos(1000) @@ -277,28 +282,6 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(DateTimeUtils.stringToTimestamp( UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000) - val defaultTimeZone = TimeZone.getDefault - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - - c = Calendar.getInstance() - c.set(2015, 2, 8, 2, 0, 0) - c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( - UTF8String.fromString("2015-3-8 2:0:0")).get === c.getTimeInMillis * 1000) - c.add(Calendar.MINUTE, 30) - assert(DateTimeUtils.stringToTimestamp( - UTF8String.fromString("2015-3-8 3:30:0")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( - UTF8String.fromString("2015-3-8 2:30:0")).get === c.getTimeInMillis * 1000) - - c = Calendar.getInstance() - c.set(2015, 10, 1, 1, 59, 0) - c.set(Calendar.MILLISECOND, 0) - c.add(Calendar.MINUTE, 31) - assert(DateTimeUtils.stringToTimestamp( - UTF8String.fromString("2015-11-1 2:30:0")).get === c.getTimeInMillis * 1000) - TimeZone.setDefault(defaultTimeZone) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("238")).isEmpty) assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) @@ -314,4 +297,68 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(DateTimeUtils.stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) } + + test("hours") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 13, 2, 11) + assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 13) + c.set(2015, 12, 8, 2, 7, 9) + assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 2) + } + + test("minutes") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 13, 2, 11) + assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 2) + c.set(2015, 2, 8, 2, 7, 9) + assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 7) + } + + test("seconds") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 13, 2, 11) + assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 11) + c.set(2015, 2, 8, 2, 7, 9) + assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 9) + } + + test("get day in year") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77) + c.set(2012, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78) + } + + test("get year") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 
2015) + c.set(2012, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2012) + } + + test("get quarter") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 1) + c.set(2012, 11, 18, 0, 0, 0) + assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 4) + } + + test("get month") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 3) + c.set(2012, 11, 18, 0, 0, 0) + assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 12) + } + + test("get day of month") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18) + c.set(2012, 11, 24, 0, 0, 0) + assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c180407389136..cadb25d597d19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1748,6 +1748,182 @@ object functions { */ def length(columnName: String): Column = length(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // DateTime functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Converts a date/timestamp/string to a value of string in the format specified by the date + * format given by the second argument. + * + * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All + * pattern letters of [[java.text.SimpleDateFormat]] can be used. + * + * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a + * specialized implementation. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def date_format(dateExpr: Column, format: String): Column = + DateFormatClass(dateExpr.expr, Literal(format)) + + /** + * Converts a date/timestamp/string to a value of string in the format specified by the date + * format given by the second argument. + * + * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All + * pattern letters of [[java.text.SimpleDateFormat]] can be used. + * + * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a + * specialized implementation. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def date_format(dateColumnName: String, format: String): Column = + date_format(Column(dateColumnName), format) + + /** + * Extracts the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def year(e: Column): Column = Year(e.expr) + + /** + * Extracts the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def year(columnName: String): Column = year(Column(columnName)) + + /** + * Extracts the quarter as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def quarter(e: Column): Column = Quarter(e.expr) + + /** + * Extracts the quarter as an integer from a given date/timestamp/string. 
+ * @group datetime_funcs + * @since 1.5.0 + */ + def quarter(columnName: String): Column = quarter(Column(columnName)) + + /** + * Extracts the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def month(e: Column): Column = Month(e.expr) + + /** + * Extracts the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def month(columnName: String): Column = month(Column(columnName)) + + /** + * Extracts the day of the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def day(e: Column): Column = Day(e.expr) + + /** + * Extracts the day of the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def day(columnName: String): Column = day(Column(columnName)) + + /** + * Extracts the day of the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def day_of_month(e: Column): Column = Day(e.expr) + + /** + * Extracts the day of the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def day_of_month(columnName: String): Column = day_of_month(Column(columnName)) + + /** + * Extracts the day of the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def day_in_year(e: Column): Column = DayInYear(e.expr) + + /** + * Extracts the day of the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def day_in_year(columnName: String): Column = day_in_year(Column(columnName)) + + /** + * Extracts the hours as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def hour(e: Column): Column = Hour(e.expr) + + /** + * Extracts the hours as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def hour(columnName: String): Column = hour(Column(columnName)) + + /** + * Extracts the minutes as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def minute(e: Column): Column = Minute(e.expr) + + /** + * Extracts the minutes as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def minute(columnName: String): Column = minute(Column(columnName)) + + /** + * Extracts the seconds as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def second(e: Column): Column = Second(e.expr) + + /** + * Extracts the seconds as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def second(columnName: String): Column = second(Column(columnName)) + + /** + * Extracts the week number as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def week_of_year(e: Column): Column = WeekOfYear(e.expr) + + /** + * Extracts the week number as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def week_of_year(columnName: String): Column = week_of_year(Column(columnName)) + /** * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, * and returns the result as a string. 
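As a quick reference, here is a hedged usage sketch of the Scala API added above. It is not part of the patch: the `sqlContext`, the sample rows, and the column names are assumptions for illustration only; the expected values in the comments mirror the dates used in the test suites of this patch, and the `date_format` output assumes the session default time zone.

```scala
// Hedged usage sketch for the datetime functions introduced above (not part of the patch).
// Assumes a SQLContext named `sqlContext`; sample dates mirror those used in the tests.
import java.sql.Date
import org.apache.spark.sql.functions._

val df = sqlContext.createDataFrame(Seq(
  (Date.valueOf("2015-04-08"), "a"),
  (Date.valueOf("2013-11-08"), "b")
)).toDF("d", "label")

// Both the Column-based and the column-name-based overloads are available.
df.select(
  year(df("d")),                      // 2015, 2013
  quarter("d"),                       // 2, 4
  month(df("d")),                     // 4, 11
  day_of_month("d"),                  // 8, 8
  week_of_year(df("d")),              // 15, 45
  date_format(df("d"), "dd.MM.yyyy")  // "08.04.2015", "08.11.2013" (default time zone)
).show()
```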
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateExpressionsSuite.scala new file mode 100644 index 0000000000000..d24e3ee1dd8f5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateExpressionsSuite.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Timestamp, Date} +import java.text.SimpleDateFormat + +import org.apache.spark.sql.functions._ + +class DateExpressionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + + import ctx.implicits._ + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) + val ts = new Timestamp(sdf.parse("2013-04-08 13:10:15").getTime) + + + test("date format") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(date_format("a", "y"), date_format("b", "y"), date_format("c", "y")), + Row("2015", "2015", "2013")) + + checkAnswer( + df.selectExpr("date_format(a, 'y')", "date_format(b, 'y')", "date_format(c, 'y')"), + Row("2015", "2015", "2013")) + } + + test("year") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(year("a"), year("b"), year("c")), + Row(2015, 2015, 2013)) + + checkAnswer( + df.selectExpr("year(a)", "year(b)", "year(c)"), + Row(2015, 2015, 2013)) + } + + test("quarter") { + val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) + + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(quarter("a"), quarter("b"), quarter("c")), + Row(2, 2, 4)) + + checkAnswer( + df.selectExpr("quarter(a)", "quarter(b)", "quarter(c)"), + Row(2, 2, 4)) + } + + test("month") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(month("a"), month("b"), month("c")), + Row(4, 4, 4)) + + checkAnswer( + df.selectExpr("month(a)", "month(b)", "month(c)"), + Row(4, 4, 4)) + } + + test("day") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(day("a"), day("b"), day("c")), + Row(8, 8, 8)) + + checkAnswer( + df.selectExpr("day(a)", "day(b)", "day(c)"), + Row(8, 8, 8)) + } + + test("day of month") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(day_of_month("a"), day_of_month("b"), day_of_month("c")), + Row(8, 8, 8)) + + checkAnswer( + df.selectExpr("day_of_month(a)", "day_of_month(b)", "day_of_month(c)"), + Row(8, 8, 8)) + } + + test("day in year") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + 
df.select(day_in_year("a"), day_in_year("b"), day_in_year("c")), + Row(98, 98, 98)) + + checkAnswer( + df.selectExpr("day_in_year(a)", "day_in_year(b)", "day_in_year(c)"), + Row(98, 98, 98)) + } + + test("hour") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(hour("a"), hour("b"), hour("c")), + Row(0, 13, 13)) + + checkAnswer( + df.selectExpr("hour(a)", "hour(b)", "hour(c)"), + Row(0, 13, 13)) + } + + test("minute") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(minute("a"), minute("b"), minute("c")), + Row(0, 10, 10)) + + checkAnswer( + df.selectExpr("minute(a)", "minute(b)", "minute(c)"), + Row(0, 10, 10)) + } + + test("second") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(second("a"), second("b"), second("c")), + Row(0, 15, 15)) + + checkAnswer( + df.selectExpr("second(a)", "second(b)", "second(c)"), + Row(0, 15, 15)) + } + + test("week of year") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(week_of_year("a"), week_of_year("b"), week_of_year("c")), + Row(15, 15, 15)) + + checkAnswer( + df.selectExpr("week_of_year(a)", "week_of_year(b)", "week_of_year(c)"), + Row(15, 15, 15)) + } + +} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 299cc599ff8f7..2689d904d6541 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -115,6 +115,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // This test is totally fine except that it includes wrong queries and expects errors, but error // message format in Hive and Spark SQL differ. Should workaround this later. "udf_to_unix_timestamp", + // we can cast dates likes '2015-03-18' to a timestamp and extract the seconds. + // Hive returns null for second('2015-03-18') + "udf_second", + // we can cast dates likes '2015-03-18' to a timestamp and extract the minutes. + // Hive returns null for minute('2015-03-18') + "udf_minute", + // Cant run without local map/reduce. "index_auto_update", @@ -896,7 +903,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_lpad", "udf_ltrim", "udf_map", - "udf_minute", "udf_modulo", "udf_month", "udf_named_struct", @@ -923,7 +929,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_round_3", "udf_rpad", "udf_rtrim", - "udf_second", "udf_sign", "udf_sin", "udf_smallint", From 04c1b49f5eee915ad1159a32bf12836a3b9f2620 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 18 Jul 2015 22:50:34 -0700 Subject: [PATCH 47/58] Fixed test cases. 
--- .../spark/sql/catalyst/expressions/DateFunctionsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala index 49d0b0aceac0d..f469f42116d21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala @@ -50,7 +50,7 @@ class DateFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (0 to 5).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, 1) + c.add(Calendar.DATE, i) checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), sdfDay.format(c.getTime).toInt) } @@ -62,7 +62,7 @@ class DateFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (0 to 5).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, 1) + c.add(Calendar.DATE, i) checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), sdfDay.format(c.getTime).toInt) } From a9a0d0cebf8ab3c539723488e5945794ebfd6104 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 18 Jul 2015 23:44:38 -0700 Subject: [PATCH 48/58] [SPARK-8638] [SQL] Window Function Performance Improvements
## Description
Performance improvements for Spark Window functions. This PR will also serve as the basis for moving away from Hive UDAFs to Spark UDAFs. See JIRA tickets SPARK-8638 and SPARK-7712 for more information.
## Improvements
* Much better performance (10x) in running cases (e.g. BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) and UNBOUNDED FOLLOWING cases. The current implementation in Spark uses a sliding window approach in these cases. This means that an aggregate is maintained for every row, so space usage is N (N being the number of rows). It also means that all these aggregates need to be updated separately, which takes N*(N-1)/2 updates. The running case differs from the sliding case because we are only adding data to an aggregate function (no reset is required): we only need to maintain one aggregate (like in the UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING case), update the aggregate for each row, and get the aggregate value after each update. This is what the new implementation does (a short sketch of this idea follows this list). This approach only uses 1 buffer, and only requires N updates; I am currently working on data with window sizes of 500-1000 doing running sums and this saves a lot of time. The CURRENT ROW AND UNBOUNDED FOLLOWING case also uses this approach, relying on the fact that aggregate operations are commutative; the one twist is that it processes the input buffer in reverse.
* Fewer comparisons in the sliding case. The current implementation determines frame boundaries for every input row. The new implementation makes more use of the fact that the window is sorted, maintains the boundaries, and only moves them when the current row order changes. This is a minor improvement.
* A single Window node is able to process all types of Frames for the same Partitioning/Ordering. This saves a little time/memory spent buffering and managing partitions. This will be enabled in a follow-up PR.
* A lot of the staging code is moved from the execution phase to the initialization phase. Minor performance improvement, and improves readability of the execution code.
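To make the complexity claim in the first bullet concrete, here is a minimal, self-contained Scala sketch of the running-frame idea. It is an illustration only, using plain arrays and a running sum; the names are invented for the example and this is not the operator's actual code, which operates on InternalRow buffers via WindowFunctionFrame implementations.

```scala
// Illustrative sketch only: running sum over one sorted partition,
// frame = UNBOUNDED PRECEDING AND CURRENT ROW.

// Naive "sliding" style: recompute the frame for every output row => O(N^2) additions.
def runningSumNaive(partition: Array[Double]): Array[Double] =
  partition.indices.map(i => partition.slice(0, i + 1).sum).toArray

// Running style described above: one accumulator, one update per row => O(N) additions.
def runningSumAccumulated(partition: Array[Double]): Array[Double] = {
  var acc = 0.0
  partition.map { v => acc += v; acc }
}

// A shrinking frame (CURRENT ROW AND UNBOUNDED FOLLOWING) can reuse the same trick by
// consuming the buffer in reverse, because the aggregate update is commutative.
def reverseRunningSum(partition: Array[Double]): Array[Double] =
  runningSumAccumulated(partition.reverse).reverse
```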
## Benchmarking
I have done a small benchmark using [on time performance](http://www.transtats.bts.gov) data for the month of April. I have used the origin as a partitioning key; as a result, there is quite some variation in window sizes. The code for the benchmark can be found in the JIRA ticket. These are the results per Frame type:

Frame | Master | SPARK-8638
----- | ------ | ----------
Entire Frame | 2 s | 1 s
Sliding | 18 s | 1 s
Growing | 14 s | 0.9 s
Shrinking | 13 s | 1 s

Author: Herman van Hovell Closes #7057 from hvanhovell/SPARK-8638 and squashes the following commits: 3bfdc49 [Herman van Hovell] Fixed Perfomance Regression for Shrinking Window Frames (+Rebase) 2eb3b33 [Herman van Hovell] Corrected reverse range frame processing. 2cd2d5b [Herman van Hovell] Corrected reverse range frame processing. b0654d7 [Herman van Hovell] Tests for exotic frame specifications. e75b76e [Herman van Hovell] More docs, added support for reverse sliding range frames, and some reorganization of code. 1fdb558 [Herman van Hovell] Changed Data In HiveDataFrameWindowSuite. ac2f682 [Herman van Hovell] Added a few more comments. 1938312 [Herman van Hovell] Added Documentation to the createBoundOrdering methods. bb020e6 [Herman van Hovell] Major overhaul of Window operator. --- .../expressions/windowExpressions.scala | 12 + .../apache/spark/sql/execution/Window.scala | 1072 +++++++++++------ .../sql/hive/HiveDataFrameWindowSuite.scala | 6 +- .../sql/hive/execution/WindowSuite.scala | 79 ++ 4 files changed, 765 insertions(+), 404 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 50bbfd644d302..09ec0e333aa44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -316,3 +316,15 @@ case class WindowExpression( override def toString: String = s"$windowFunction $windowSpec" } + +/** + * Extractor for making working with frame boundaries easier.
+ */ +object FrameBoundaryExtractor { + def unapply(boundary: FrameBoundary): Option[Int] = boundary match { + case CurrentRow => Some(0) + case ValuePreceding(offset) => Some(-offset) + case ValueFollowing(offset) => Some(offset) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 6e127e548a120..a054f52b8b489 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -19,18 +19,64 @@ package org.apache.spark.sql.execution import java.util -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.rdd.RDD import org.apache.spark.util.collection.CompactBuffer +import scala.collection.mutable /** * :: DeveloperApi :: - * For every row, evaluates `windowExpression` containing Window Functions and attaches - * the results with other regular expressions (presented by `projectList`). - * Evert operator handles a single Window Specification, `windowSpec`. + * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) + * partition. The aggregates are calculated for each row in the group. Special processing + * instructions, frames, are used to calculate these aggregates. Frames are processed in the order + * specified in the window specification (the ORDER BY ... clause). There are four different frame + * types: + * - Entire partition: The frame is the entire partition, i.e. + * UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. For this case, window function will take all + * rows as inputs and be evaluated once. + * - Growing frame: We only add new rows into the frame, i.e. UNBOUNDED PRECEDING AND .... + * Every time we move to a new row to process, we add some rows to the frame. We do not remove + * rows from this frame. + * - Shrinking frame: We only remove rows from the frame, i.e. ... AND UNBOUNDED FOLLOWING. + * Every time we move to a new row to process, we remove some rows from the frame. We do not add + * rows to this frame. + * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame + * and we add some rows to the frame. Examples are: + * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. + * + * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame + * boundary can be either Row or Range based: + * - Row Based: A row based boundary is based on the position of the row within the partition. + * An offset indicates the number of rows above or below the current row, the frame for the + * current row starts or ends. For instance, given a row based sliding frame with a lower bound + * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from + * index 4 to index 6. + * - Range based: A range based boundary is based on the actual value of the ORDER BY + * expression(s). 
An offset is used to alter the value of the ORDER BY expression, for + * instance if the current order by expression has a value of 10 and the lower bound offset + * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a + * number of constraints on the ORDER BY expressions: there can be only one expression and this + * expression must have a numerical data type. An exception can be made when the offset is 0, + * because no value modification is needed, in this case multiple and non-numeric ORDER BY + * expression are allowed. + * + * This is quite an expensive operator because every row for a single group must be in the same + * partition and partitions must be sorted according to the grouping and sort order. The operator + * requires the planner to take care of the partitioning and sorting. + * + * The operator is semi-blocking. The window functions and aggregates are calculated one group at + * a time, the result will only be made available after the processing for the entire group has + * finished. The operator is able to process different frame configurations at the same time. This + * is done by delegating the actual frame processing (i.e. calculation of the window functions) to + * specialized classes, see [[WindowFunctionFrame]], which take care of their own frame type: + * Entire Partition, Sliding, Growing & Shrinking. Boundary evaluation is also delegated to a pair + * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]]. */ +@DeveloperApi case class Window( projectList: Seq[Attribute], windowExpression: Seq[NamedExpression], @@ -38,443 +84,667 @@ case class Window( child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = - (projectList ++ windowExpression).map(_.toAttribute) + override def output: Seq[Attribute] = projectList ++ windowExpression.map(_.toAttribute) - override def requiredChildDistribution: Seq[Distribution] = + override def requiredChildDistribution: Seq[Distribution] = { if (windowSpec.partitionSpec.isEmpty) { - // This operator will be very expensive. + // Only show warning when the number of bytes is larger than 100 MB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") AllTuples :: Nil - } else { - ClusteredDistribution(windowSpec.partitionSpec) :: Nil - } - - // Since window functions are adding columns to the input rows, the child's outputPartitioning - // is preserved. - override def outputPartitioning: Partitioning = child.outputPartitioning - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - // The required child ordering has two parts. - // The first part is the expressions in the partition specification. - // We add these expressions to the required ordering to make sure input rows are grouped - // based on the partition specification. So, we only need to process a single partition - // at a time. - // The second part is the expressions specified in the ORDER BY cluase. - // Basically, we first use sort to group rows based on partition specifications and then sort - // Rows in a group based on the order specification. - (windowSpec.partitionSpec.map(SortOrder(_, Ascending)) ++ windowSpec.orderSpec) :: Nil + } else ClusteredDistribution(windowSpec.partitionSpec) :: Nil } - // Since window functions basically add columns to input rows, this operator - // will not change the ordering of input rows. 
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(windowSpec.partitionSpec.map(SortOrder(_, Ascending)) ++ windowSpec.orderSpec) + override def outputOrdering: Seq[SortOrder] = child.outputOrdering - case class ComputedWindow( - unbound: WindowExpression, - windowFunction: WindowFunction, - resultAttribute: AttributeReference) - - // A list of window functions that need to be computed for each group. - private[this] val computedWindowExpressions = windowExpression.flatMap { window => - window.collect { - case w: WindowExpression => - ComputedWindow( - w, - BindReferences.bindReference(w.windowFunction, child.output), - AttributeReference(s"windowResult:$w", w.dataType, w.nullable)()) + /** + * Create a bound ordering object for a given frame type and offset. A bound ordering object is + * used to determine which input row lies within the frame boundaries of an output row. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param frameType to evaluate. This can either be Row or Range based. + * @param offset with respect to the row. + * @return a bound ordering object. + */ + private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { + frameType match { + case RangeFrame => + val (exprs, current, bound) = if (offset == 0) { + // Use the entire order expression when the offset is 0. + val exprs = windowSpec.orderSpec.map(_.child) + val projection = newMutableProjection(exprs, child.output) + (windowSpec.orderSpec, projection(), projection()) + } + else if (windowSpec.orderSpec.size == 1) { + // Use only the first order expression when the offset is non-null. + val sortExpr = windowSpec.orderSpec.head + val expr = sortExpr.child + // Create the projection which returns the current 'value'. + val current = newMutableProjection(expr :: Nil, child.output)() + // Flip the sign of the offset when processing the order is descending + val boundOffset = if (sortExpr.direction == Descending) -offset + else offset + // Create the projection which returns the current 'value' modified by adding the offset. + val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) + val bound = newMutableProjection(boundExpr :: Nil, child.output)() + (sortExpr :: Nil, current, bound) + } + else { + sys.error("Non-Zero range offsets are not supported for windows " + + "with multiple order expressions.") + } + // Construct the ordering. This is used to compare the result of current value projection + // to the result of bound value projection. This is done manually because we want to use + // Code Generation (if it is enabled). + val (sortExprs, schema) = exprs.map { case e => + val ref = AttributeReference("ordExpr", e.dataType, e.nullable)() + (SortOrder(ref, e.direction), ref) + }.unzip + val ordering = newOrdering(sortExprs, schema) + RangeBoundOrdering(ordering, current, bound) + case RowFrame => RowBoundOrdering(offset) } - }.toArray + } - private[this] val windowFrame = - windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + /** + * Create a frame processor. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param frame boundaries. + * @param functions to process in the frame. + * @param ordinal at which the processor starts writing to the output. + * @return a frame processor. 
+ */ + private[this] def createFrameProcessor( + frame: WindowFrame, + functions: Array[WindowFunction], + ordinal: Int): WindowFunctionFrame = frame match { + // Growing Frame. + case SpecifiedWindowFrame(frameType, UnboundedPreceding, FrameBoundaryExtractor(high)) => + val uBoundOrdering = createBoundOrdering(frameType, high) + new UnboundedPrecedingWindowFunctionFrame(ordinal, functions, uBoundOrdering) + + // Shrinking Frame. + case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(low), UnboundedFollowing) => + val lBoundOrdering = createBoundOrdering(frameType, low) + new UnboundedFollowingWindowFunctionFrame(ordinal, functions, lBoundOrdering) + + // Moving Frame. + case SpecifiedWindowFrame(frameType, + FrameBoundaryExtractor(low), FrameBoundaryExtractor(high)) => + val lBoundOrdering = createBoundOrdering(frameType, low) + val uBoundOrdering = createBoundOrdering(frameType, high) + new SlidingWindowFunctionFrame(ordinal, functions, lBoundOrdering, uBoundOrdering) + + // Entire Partition Frame. + case SpecifiedWindowFrame(_, UnboundedPreceding, UnboundedFollowing) => + new UnboundedWindowFunctionFrame(ordinal, functions) + + // Error + case fr => + sys.error(s"Unsupported Frame $fr for functions: $functions") + } - // Create window functions. - private[this] def windowFunctions(): Array[WindowFunction] = { - val functions = new Array[WindowFunction](computedWindowExpressions.length) - var i = 0 - while (i < computedWindowExpressions.length) { - functions(i) = computedWindowExpressions(i).windowFunction.newInstance() - functions(i).init() - i += 1 + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. + */ + private[this] def createResultProjection( + expressions: Seq[Expression]): MutableProjection = { + val unboundToAttr = expressions.map { + e => (e, AttributeReference("windowResult", e.dataType, e.nullable)()) } - functions + val unboundToAttrMap = unboundToAttr.toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap)) + newMutableProjection( + projectList ++ patchedWindowExpression, + child.output ++ unboundToAttr.map(_._2))() } - // The schema of the result of all window function evaluations - private[this] val computedSchema = computedWindowExpressions.map(_.resultAttribute) - - private[this] val computedResultMap = - computedWindowExpressions.map { w => w.unbound -> w.resultAttribute }.toMap + protected override def doExecute(): RDD[InternalRow] = { + // Prepare processing. + // Group the window expressions by their processing frame. + val windowExprs = windowExpression.flatMap { + _.collect { + case e: WindowExpression => e + } + } - private[this] val windowExpressionResult = windowExpression.map { window => - window.transform { - case w: WindowExpression if computedResultMap.contains(w) => computedResultMap(w) + // Create Frame processor factories and order the unbound window expressions by the frame they + // are processed in; this is the order in which their results will be written to the window + // function result buffer. + val framedWindowExprs = windowExprs.groupBy(_.windowSpec.frameSpecification) + val factories = Array.ofDim[() => WindowFunctionFrame](framedWindowExprs.size) + val unboundExpressions = mutable.Buffer.empty[Expression] + framedWindowExprs.zipWithIndex.foreach { + case ((frame, unboundFrameExpressions), index) => + // Track the ordinal. 
+ val ordinal = unboundExpressions.size + + // Track the unbound expressions + unboundExpressions ++= unboundFrameExpressions + + // Bind the expressions. + val functions = unboundFrameExpressions.map { e => + BindReferences.bindReference(e.windowFunction, child.output) + }.toArray + + // Create the frame processor factory. + factories(index) = () => createFrameProcessor(frame, functions, ordinal) } - } - protected override def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + // Start processing. + child.execute().mapPartitions { stream => new Iterator[InternalRow] { - // Although input rows are grouped based on windowSpec.partitionSpec, we need to - // know when we have a new partition. - // This is to manually construct an ordering that can be used to compare rows. - // TODO: We may want to have a newOrdering that takes BoundReferences. - // So, we can take advantave of code gen. - private val partitionOrdering: Ordering[InternalRow] = - RowOrdering.forSchema(windowSpec.partitionSpec.map(_.dataType)) - - // This is used to project expressions for the partition specification. - protected val partitionGenerator = - newMutableProjection(windowSpec.partitionSpec, child.output)() - - // This is ued to project expressions for the order specification. - protected val rowOrderGenerator = - newMutableProjection(windowSpec.orderSpec.map(_.child), child.output)() - - // The position of next output row in the inputRowBuffer. - var rowPosition: Int = 0 - // The number of buffered rows in the inputRowBuffer (the size of the current partition). - var partitionSize: Int = 0 - // The buffer used to buffer rows in a partition. - var inputRowBuffer: CompactBuffer[InternalRow] = _ - // The partition key of the current partition. - var currentPartitionKey: InternalRow = _ - // The partition key of next partition. - var nextPartitionKey: InternalRow = _ - // The first row of next partition. - var firstRowInNextPartition: InternalRow = _ - // Indicates if this partition is the last one in the iter. - var lastPartition: Boolean = false - - def createBoundaryEvaluator(): () => Unit = { - def findPhysicalBoundary( - boundary: FrameBoundary): () => Int = boundary match { - case UnboundedPreceding => () => 0 - case UnboundedFollowing => () => partitionSize - 1 - case CurrentRow => () => rowPosition - case ValuePreceding(value) => - () => - val newPosition = rowPosition - value - if (newPosition > 0) newPosition else 0 - case ValueFollowing(value) => - () => - val newPosition = rowPosition + value - if (newPosition < partitionSize) newPosition else partitionSize - 1 + // Get all relevant projections. + val result = createResultProjection(unboundExpressions) + val grouping = newProjection(windowSpec.partitionSpec, child.output) + + // Manage the stream and the grouping. + var nextRow: InternalRow = EmptyRow + var nextGroup: InternalRow = EmptyRow + var nextRowAvailable: Boolean = false + private[this] def fetchNextRow() { + nextRowAvailable = stream.hasNext + if (nextRowAvailable) { + nextRow = stream.next() + nextGroup = grouping(nextRow) + } else { + nextRow = EmptyRow + nextGroup = EmptyRow } - - def findLogicalBoundary( - boundary: FrameBoundary, - searchDirection: Int, - evaluator: Expression, - joinedRow: JoinedRow): () => Int = boundary match { - case UnboundedPreceding => () => 0 - case UnboundedFollowing => () => partitionSize - 1 - case other => - () => { - // CurrentRow, ValuePreceding, or ValueFollowing. 
- var newPosition = rowPosition + searchDirection - var stopSearch = false - // rowOrderGenerator is a mutable projection. - // We need to make a copy of the returned by rowOrderGenerator since we will - // compare searched row with this currentOrderByValue. - val currentOrderByValue = rowOrderGenerator(inputRowBuffer(rowPosition)).copy() - while (newPosition >= 0 && newPosition < partitionSize && !stopSearch) { - val r = rowOrderGenerator(inputRowBuffer(newPosition)) - stopSearch = - !(evaluator.eval(joinedRow(currentOrderByValue, r)).asInstanceOf[Boolean]) - if (!stopSearch) { - newPosition += searchDirection - } - } - newPosition -= searchDirection - - if (newPosition < 0) { - 0 - } else if (newPosition >= partitionSize) { - partitionSize - 1 - } else { - newPosition - } - } + } + fetchNextRow() + + // Manage the current partition. + var rows: CompactBuffer[InternalRow] = _ + val frames: Array[WindowFunctionFrame] = factories.map(_()) + val numFrames = frames.length + private[this] def fetchNextPartition() { + // Collect all the rows in the current partition. + val currentGroup = nextGroup + rows = new CompactBuffer + while (nextRowAvailable && nextGroup == currentGroup) { + rows += nextRow.copy() + fetchNextRow() } - windowFrame.frameType match { - case RowFrame => - val findStart = findPhysicalBoundary(windowFrame.frameStart) - val findEnd = findPhysicalBoundary(windowFrame.frameEnd) - () => { - frameStart = findStart() - frameEnd = findEnd() - } - case RangeFrame => - val joinedRowForBoundaryEvaluation: JoinedRow = new JoinedRow() - val orderByExpr = windowSpec.orderSpec.head - val currentRowExpr = - BoundReference(0, orderByExpr.dataType, orderByExpr.nullable) - val examedRowExpr = - BoundReference(1, orderByExpr.dataType, orderByExpr.nullable) - val differenceExpr = Abs(Subtract(currentRowExpr, examedRowExpr)) - - val frameStartEvaluator = windowFrame.frameStart match { - case CurrentRow => EqualTo(currentRowExpr, examedRowExpr) - case ValuePreceding(value) => - LessThanOrEqual(differenceExpr, Cast(Literal(value), orderByExpr.dataType)) - case ValueFollowing(value) => - GreaterThanOrEqual(differenceExpr, Cast(Literal(value), orderByExpr.dataType)) - case o => Literal(true) // This is just a dummy expression, we will not use it. - } - - val frameEndEvaluator = windowFrame.frameEnd match { - case CurrentRow => EqualTo(currentRowExpr, examedRowExpr) - case ValuePreceding(value) => - GreaterThanOrEqual(differenceExpr, Cast(Literal(value), orderByExpr.dataType)) - case ValueFollowing(value) => - LessThanOrEqual(differenceExpr, Cast(Literal(value), orderByExpr.dataType)) - case o => Literal(true) // This is just a dummy expression, we will not use it. - } - - val findStart = - findLogicalBoundary( - boundary = windowFrame.frameStart, - searchDirection = -1, - evaluator = frameStartEvaluator, - joinedRow = joinedRowForBoundaryEvaluation) - val findEnd = - findLogicalBoundary( - boundary = windowFrame.frameEnd, - searchDirection = 1, - evaluator = frameEndEvaluator, - joinedRow = joinedRowForBoundaryEvaluation) - () => { - frameStart = findStart() - frameEnd = findEnd() - } + // Setup the frames. + var i = 0 + while (i < numFrames) { + frames(i).prepare(rows) + i += 1 } + + // Setup iteration + rowIndex = 0 + rowsSize = rows.size } - val boundaryEvaluator = createBoundaryEvaluator() - // Indicates if we the specified window frame requires us to maintain a sliding frame - // (e.g. RANGES BETWEEN 1 PRECEDING AND CURRENT ROW) or the window frame - // is the entire partition (e.g. 
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING). - val requireUpdateFrame: Boolean = { - def requireUpdateBoundary(boundary: FrameBoundary): Boolean = boundary match { - case UnboundedPreceding => false - case UnboundedFollowing => false - case _ => true - } + // Iteration + var rowIndex = 0 + var rowsSize = 0 + override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable - requireUpdateBoundary(windowFrame.frameStart) || - requireUpdateBoundary(windowFrame.frameEnd) - } - // The start position of the current frame in the partition. - var frameStart: Int = 0 - // The end position of the current frame in the partition. - var frameEnd: Int = -1 - // Window functions. - val functions: Array[WindowFunction] = windowFunctions() - // Buffers used to store input parameters for window functions. Because we may need to - // maintain a sliding frame, we use this buffer to avoid evaluate the parameters from - // the same row multiple times. - val windowFunctionParameterBuffers: Array[util.LinkedList[AnyRef]] = - functions.map(_ => new util.LinkedList[AnyRef]()) - - // The projection used to generate the final result rows of this operator. - private[this] val resultProjection = - newMutableProjection( - projectList ++ windowExpressionResult, - projectList ++ computedSchema)() - - // The row used to hold results of window functions. - private[this] val windowExpressionResultRow = - new GenericMutableRow(computedSchema.length) - - private[this] val joinedRow = new JoinedRow6 - - // Initialize this iterator. - initialize() - - private def initialize(): Unit = { - if (iter.hasNext) { - val currentRow = iter.next().copy() - // partitionGenerator is a mutable projection. Since we need to track nextPartitionKey, - // we are making a copy of the returned partitionKey at here. - nextPartitionKey = partitionGenerator(currentRow).copy() - firstRowInNextPartition = currentRow + val join = new JoinedRow6 + val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) + override final def next(): InternalRow = { + // Load the next partition if we need to. + if (rowIndex >= rowsSize && nextRowAvailable) { fetchNextPartition() - } else { - // The iter is an empty one. So, we set all of the following variables - // to make sure hasNext will return false. - lastPartition = true - rowPosition = 0 - partitionSize = 0 } - } - - // Indicates if we will have new output row. - override final def hasNext: Boolean = { - !lastPartition || (rowPosition < partitionSize) - } - override final def next(): InternalRow = { - if (hasNext) { - if (rowPosition == partitionSize) { - // All rows of this buffer have been consumed. - // We will move to next partition. - fetchNextPartition() - } - // Get the input row for the current output row. - val inputRow = inputRowBuffer(rowPosition) - // Get all results of the window functions for this output row. + if (rowIndex < rowsSize) { + // Get the results for the window frames. var i = 0 - while (i < functions.length) { - windowExpressionResultRow.update(i, functions(i).get(rowPosition)) + while (i < numFrames) { + frames(i).write(windowFunctionResult) i += 1 } - // Construct the output row. - val outputRow = resultProjection(joinedRow(inputRow, windowExpressionResultRow)) - // We will move to the next one. - rowPosition += 1 - if (requireUpdateFrame && rowPosition < partitionSize) { - // If we need to maintain a sliding frame and - // we will still work on this partition when next is called next time, do the update. 
- updateFrame() - } + // 'Merge' the input row with the window function result + join(rows(rowIndex), windowFunctionResult) + rowIndex += 1 - // Return the output row. - outputRow - } else { - // no more result - throw new NoSuchElementException - } + // Return the projection. + result(join) + } else throw new NoSuchElementException } + } + } + } +} - // Fetch the next partition. - private def fetchNextPartition(): Unit = { - // Create a new buffer for input rows. - inputRowBuffer = new CompactBuffer[InternalRow]() - // We already have the first row for this partition - // (recorded in firstRowInNextPartition). Add it back. - inputRowBuffer += firstRowInNextPartition - // Set the current partition key. - currentPartitionKey = nextPartitionKey - // Now, we will start to find all rows belonging to this partition. - // Create a variable to track if we see the next partition. - var findNextPartition = false - // The search will stop when we see the next partition or there is no - // input row left in the iter. - while (iter.hasNext && !findNextPartition) { - // Make a copy of the input row since we will put it in the buffer. - val currentRow = iter.next().copy() - // Get the partition key based on the partition specification. - // For the below compare method, we do not need to make a copy of partitionKey. - val partitionKey = partitionGenerator(currentRow) - // Check if the current row belongs the current input row. - val comparing = partitionOrdering.compare(currentPartitionKey, partitionKey) - if (comparing == 0) { - // This row is still in the current partition. - inputRowBuffer += currentRow - } else { - // The current input row is in a different partition. - findNextPartition = true - // partitionGenerator is a mutable projection. - // Since we need to track nextPartitionKey and we determine that it should be set - // as partitionKey, we are making a copy of the partitionKey at here. - nextPartitionKey = partitionKey.copy() - firstRowInNextPartition = currentRow - } - } +/** + * Function for comparing boundary values. + */ +private[execution] abstract class BoundOrdering { + def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int +} - // We have not seen a new partition. It means that there is no new row in the - // iter. The current partition is the last partition of the iter. - if (!findNextPartition) { - lastPartition = true - } +/** + * Compare the input index to the bound of the output index. + */ +private[execution] final case class RowBoundOrdering(offset: Int) extends BoundOrdering { + override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int = + inputIndex - (outputIndex + offset) +} - // We have got all rows for the current partition. - // Set rowPosition to 0 (the next output row will be based on the first - // input row of this partition). - rowPosition = 0 - // The size of this partition. - partitionSize = inputRowBuffer.size - // Reset all parameter buffers of window functions. - var i = 0 - while (i < windowFunctionParameterBuffers.length) { - windowFunctionParameterBuffers(i).clear() - i += 1 - } - frameStart = 0 - frameEnd = -1 - // Create the first window frame for this partition. - // If we do not need to maintain a sliding frame, this frame will - // have the entire partition. - updateFrame() - } +/** + * Compare the value of the input index to the value bound of the output index. 
+ */ +private[execution] final case class RangeBoundOrdering( + ordering: Ordering[InternalRow], + current: Projection, + bound: Projection) extends BoundOrdering { + override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int = + ordering.compare(current(input(inputIndex)), bound(input(outputIndex))) +} - /** The function used to maintain the sliding frame. */ - private def updateFrame(): Unit = { - // Based on the difference between the new frame and old frame, - // updates the buffers holding input parameters of window functions. - // We will start to prepare input parameters starting from the row - // indicated by offset in the input row buffer. - def updateWindowFunctionParameterBuffers( - numToRemove: Int, - numToAdd: Int, - offset: Int): Unit = { - // First, remove unneeded entries from the head of every buffer. - var i = 0 - while (i < numToRemove) { - var j = 0 - while (j < windowFunctionParameterBuffers.length) { - windowFunctionParameterBuffers(j).remove() - j += 1 - } - i += 1 - } - // Then, add needed entries to the tail of every buffer. - i = 0 - while (i < numToAdd) { - var j = 0 - while (j < windowFunctionParameterBuffers.length) { - // Ask the function to prepare the input parameters. - val parameters = functions(j).prepareInputParameters(inputRowBuffer(i + offset)) - windowFunctionParameterBuffers(j).add(parameters) - j += 1 - } - i += 1 - } - } +/** + * A window function frame calculates the results of a number of window functions for a window frame. + * Before use a frame must be prepared by passing it all the rows in the current partition. After + * preparation the write method can be called to fill the output rows. + * + * TODO How to improve performance? A few thoughts: + * - Window functions are expensive due to their distribution and ordering requirements. + * Unfortunately it is up to the Spark engine to solve this. Improvements in the form of project + * Tungsten are on the way. + * - The window frame processing bit can be improved though. But before we start doing that we + * need to see how much of the time and resources are spent on partitioning and ordering, and + * how much time and resources are spent processing the partitions. There are a couple of ways to + * improve on the current situation: + * - Reduce memory footprint by performing streaming calculations. This can only be done when + * there are no Unbound/Unbounded Following calculations present. + * - Use Tungsten style memory usage. + * - Use code generation in general, and use the approach to aggregation taken in the + * GeneratedAggregate class in particular. + * + * @param ordinal of the first column written by this frame. + * @param functions to calculate the row values with. + */ +private[execution] abstract class WindowFunctionFrame( + ordinal: Int, + functions: Array[WindowFunction]) { + + // Make sure functions are initialized. + functions.foreach(_.init()) + + /** Number of columns the window function frame is managing */ + val numColumns = functions.length + + /** + * Create a fresh thread safe copy of the frame. + * + * @return the copied frame. + */ + def copy: WindowFunctionFrame + + /** + * Create new instances of the functions. + * + * @return an array containing copies of the current window functions. + */ + protected final def copyFunctions: Array[WindowFunction] = functions.map(_.newInstance()) + + /** + * Prepare the frame for calculating the results for a partition. + * + * @param rows to calculate the frame results for. 
+ */ + def prepare(rows: CompactBuffer[InternalRow]): Unit + + /** + * Write the result for the current row to the given target row. + * + * @param target row to write the result for the current row to. + */ + def write(target: GenericMutableRow): Unit + + /** Reset the current window functions. */ + protected final def reset(): Unit = { + var i = 0 + while (i < numColumns) { + functions(i).reset() + i += 1 + } + } - // Record the current frame start point and end point before - // we update them. - val previousFrameStart = frameStart - val previousFrameEnd = frameEnd - boundaryEvaluator() - updateWindowFunctionParameterBuffers( - frameStart - previousFrameStart, - frameEnd - previousFrameEnd, - previousFrameEnd + 1) - // Evaluate the current frame. - evaluateCurrentFrame() - } + /** Prepare an input row for processing. */ + protected final def prepare(input: InternalRow): Array[AnyRef] = { + val prepared = new Array[AnyRef](numColumns) + var i = 0 + while (i < numColumns) { + prepared(i) = functions(i).prepareInputParameters(input) + i += 1 + } + prepared + } - /** Evaluate the current window frame. */ - private def evaluateCurrentFrame(): Unit = { - var i = 0 - while (i < functions.length) { - // Reset the state of the window function. - functions(i).reset() - // Get all buffered input parameters based on rows of this window frame. - val inputParameters = windowFunctionParameterBuffers(i).toArray() - // Send these input parameters to the window function. - functions(i).batchUpdate(inputParameters) - // Ask the function to evaluate based on this window frame. - functions(i).evaluate() - i += 1 - } - } + /** Evaluate a prepared buffer (iterator). */ + protected final def evaluatePrepared(iterator: java.util.Iterator[Array[AnyRef]]): Unit = { + reset() + while (iterator.hasNext) { + val prepared = iterator.next() + var i = 0 + while (i < numColumns) { + functions(i).update(prepared(i)) + i += 1 } } + evaluate() } + + /** Evaluate a prepared buffer (array). */ + protected final def evaluatePrepared(prepared: Array[Array[AnyRef]], + fromIndex: Int, toIndex: Int): Unit = { + var i = 0 + while (i < numColumns) { + val function = functions(i) + function.reset() + var j = fromIndex + while (j < toIndex) { + function.update(prepared(j)(i)) + j += 1 + } + function.evaluate() + i += 1 + } + } + + /** Update an array of window functions. */ + protected final def update(input: InternalRow): Unit = { + var i = 0 + while (i < numColumns) { + val aggregate = functions(i) + val preparedInput = aggregate.prepareInputParameters(input) + aggregate.update(preparedInput) + i += 1 + } + } + + /** Evaluate the window functions. */ + protected final def evaluate(): Unit = { + var i = 0 + while (i < numColumns) { + functions(i).evaluate() + i += 1 + } + } + + /** Fill a target row with the current window function results. */ + protected final def fill(target: GenericMutableRow, rowIndex: Int): Unit = { + var i = 0 + while (i < numColumns) { + target.update(ordinal + i, functions(i).get(rowIndex)) + i += 1 + } + } +} + +/** + * The sliding window frame calculates frames with the following SQL form: + * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING + * + * @param ordinal of the first column written by this frame. + * @param functions to calculate the row values with. + * @param lbound comparator used to identify the lower bound of an output row. + * @param ubound comparator used to identify the upper bound of an output row. 
+ */ +private[execution] final class SlidingWindowFunctionFrame( + ordinal: Int, + functions: Array[WindowFunction], + lbound: BoundOrdering, + ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + + /** Rows of the partition currently being processed. */ + private[this] var input: CompactBuffer[InternalRow] = null + + /** Index of the first input row with a value greater than the upper bound of the current + * output row. */ + private[this] var inputHighIndex = 0 + + /** Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. */ + private[this] var inputLowIndex = 0 + + /** Buffer used for storing prepared input for the window functions. */ + private[this] val buffer = new util.ArrayDeque[Array[AnyRef]] + + /** Index of the row we are currently writing. */ + private[this] var outputIndex = 0 + + /** Prepare the frame for calculating a new partition. Reset all variables. */ + override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + input = rows + inputHighIndex = 0 + inputLowIndex = 0 + outputIndex = 0 + buffer.clear() + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(target: GenericMutableRow): Unit = { + var bufferUpdated = outputIndex == 0 + + // Add all rows to the buffer for which the input row value is equal to or less than + // the output row upper bound. + while (inputHighIndex < input.size && + ubound.compare(input, inputHighIndex, outputIndex) <= 0) { + buffer.offer(prepare(input(inputHighIndex))) + inputHighIndex += 1 + bufferUpdated = true + } + + // Drop all rows from the buffer for which the input row value is smaller than + // the output row lower bound. + while (inputLowIndex < inputHighIndex && + lbound.compare(input, inputLowIndex, outputIndex) < 0) { + buffer.pop() + inputLowIndex += 1 + bufferUpdated = true + } + + // Only recalculate and update when the buffer changes. + if (bufferUpdated) { + evaluatePrepared(buffer.iterator()) + fill(target, outputIndex) + } + + // Move to the next row. + outputIndex += 1 + } + + /** Copy the frame. */ + override def copy: SlidingWindowFunctionFrame = + new SlidingWindowFunctionFrame(ordinal, copyFunctions, lbound, ubound) +} + +/** + * The unbounded window frame calculates frames with the following SQL forms: + * ... (No Frame Definition) + * ... BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + * + * Its results are the same for each and every row in the partition. This class can be seen as a + * special case of a sliding window, but is optimized for the unbound case. + * + * @param ordinal of the first column written by this frame. + * @param functions to calculate the row values with. + */ +private[execution] final class UnboundedWindowFunctionFrame( + ordinal: Int, + functions: Array[WindowFunction]) extends WindowFunctionFrame(ordinal, functions) { + + /** Index of the row we are currently writing. */ + private[this] var outputIndex = 0 + + /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ + override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + reset() + outputIndex = 0 + val iterator = rows.iterator + while (iterator.hasNext) { + update(iterator.next()) + } + evaluate() + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(target: GenericMutableRow): Unit = { + fill(target, outputIndex) + outputIndex += 1 + } + + /** Copy the frame. 
*/ + override def copy: UnboundedWindowFunctionFrame = + new UnboundedWindowFunctionFrame(ordinal, copyFunctions) +} + +/** + * The UnboundedPreceding window frame calculates frames with the following SQL form: + * ... BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * + * There is only an upper bound. Very common use cases are for instance running sums or counts + * (row_number). Technically this is a special case of a sliding window. However a sliding window + * has to maintain a buffer, and it must do a full evaluation every time the buffer changes. This + * is not the case when there is no lower bound: given the additive nature of most aggregates, + * streaming updates and partial evaluation suffice and no buffering is needed. + * + * @param ordinal of the first column written by this frame. + * @param functions to calculate the row values with. + * @param ubound comparator used to identify the upper bound of an output row. + */ +private[execution] final class UnboundedPrecedingWindowFunctionFrame( + ordinal: Int, + functions: Array[WindowFunction], + ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + + /** Rows of the partition currently being processed. */ + private[this] var input: CompactBuffer[InternalRow] = null + + /** Index of the first input row with a value greater than the upper bound of the current + * output row. */ + private[this] var inputIndex = 0 + + /** Index of the row we are currently writing. */ + private[this] var outputIndex = 0 + + /** Prepare the frame for calculating a new partition. */ + override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + reset() + input = rows + inputIndex = 0 + outputIndex = 0 + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(target: GenericMutableRow): Unit = { + var bufferUpdated = outputIndex == 0 + + // Add all rows to the aggregates for which the input row value is equal to or less than + // the output row upper bound. + while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex) <= 0) { + update(input(inputIndex)) + inputIndex += 1 + bufferUpdated = true + } + + // Only recalculate and update when the buffer changes. + if (bufferUpdated) { + evaluate() + fill(target, outputIndex) + } + + // Move to the next row. + outputIndex += 1 + } + + /** Copy the frame. */ + override def copy: UnboundedPrecedingWindowFunctionFrame = + new UnboundedPrecedingWindowFunctionFrame(ordinal, copyFunctions, ubound) +} + +/** + * The UnboundedFollowing window frame calculates frames with the following SQL form: + * ... BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING + * + * There is only a lower bound. This is a slightly modified version of the sliding window. The + * sliding window operator has to check if both upper and the lower bound change when a new row + * gets processed, whereas the unbounded following only has to check the lower bound. + * + * This is a very expensive operator to use, O(n * (n - 1) / 2), because we need to maintain a + * buffer and must do full recalculation after each row. Reverse iteration would be possible if + * the commutativity of the used window functions can be guaranteed. + * + * @param ordinal of the first column written by this frame. + * @param functions to calculate the row values with. + * @param lbound comparator used to identify the lower bound of an output row. 
+ */ +private[execution] final class UnboundedFollowingWindowFunctionFrame( + ordinal: Int, + functions: Array[WindowFunction], + lbound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + + /** Buffer used for storing prepared input for the window functions. */ + private[this] var buffer: Array[Array[AnyRef]] = _ + + /** Rows of the partition currently being processed. */ + private[this] var input: CompactBuffer[InternalRow] = null + + /** Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. */ + private[this] var inputIndex = 0 + + /** Index of the row we are currently writing. */ + private[this] var outputIndex = 0 + + /** Prepare the frame for calculating a new partition. */ + override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + input = rows + inputIndex = 0 + outputIndex = 0 + val size = input.size + buffer = Array.ofDim(size) + var i = 0 + while (i < size) { + buffer(i) = prepare(input(i)) + i += 1 + } + evaluatePrepared(buffer, 0, buffer.length) + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(target: GenericMutableRow): Unit = { + var bufferUpdated = outputIndex == 0 + + // Drop all rows from the buffer for which the input row value is smaller than + // the output row lower bound. + while (inputIndex < input.size && lbound.compare(input, inputIndex, outputIndex) < 0) { + inputIndex += 1 + bufferUpdated = true + } + + // Only recalculate and update when the buffer changes. + if (bufferUpdated) { + evaluatePrepared(buffer, inputIndex, buffer.length) + fill(target, outputIndex) + } + + // Move to the next row. + outputIndex += 1 + } + + /** Copy the frame. */ + override def copy: UnboundedFollowingWindowFunctionFrame = + new UnboundedFollowingWindowFunctionFrame(ordinal, copyFunctions, lbound) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index efb3f2545db84..15b5f418f0a8c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -183,13 +183,13 @@ class HiveDataFrameWindowSuite extends QueryTest { } test("aggregation and range betweens with unbounded") { - val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( $"key", last("value").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue)) + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)) .equalTo("2") .as("last_v"), avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) @@ -203,7 +203,7 @@ class HiveDataFrameWindowSuite extends QueryTest { """SELECT | key, | last_value(value) OVER - | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2", + | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 preceding) == "2", | avg(key) OVER | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), | avg(key) OVER diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowSuite.scala new file mode 100644 index 
0000000000000..a089d0d165195 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + +/** + * Window expressions are tested extensively by the following test suites: + * [[org.apache.spark.sql.hive.HiveDataFrameWindowSuite]] + * [[org.apache.spark.sql.hive.execution.HiveWindowFunctionQueryWithoutCodeGenSuite]] + * [[org.apache.spark.sql.hive.execution.HiveWindowFunctionQueryFileWithoutCodeGenSuite]] + * However, these suites do not cover all possible (i.e. more exotic) settings. This suite fills + * this gap. + * + * TODO Move this class to the sql/core project when we move to Native Spark UDAFs. + */ +class WindowSuite extends QueryTest { + + test("reverse sliding range frame") { + val df = Seq( + (1, "Thin", "Cell Phone", 6000), + (2, "Normal", "Tablet", 1500), + (3, "Mini", "Tablet", 5500), + (4, "Ultra thin", "Cell Phone", 5500), + (5, "Very thin", "Cell Phone", 6000), + (6, "Big", "Tablet", 2500), + (7, "Bendable", "Cell Phone", 3000), + (8, "Foldable", "Cell Phone", 3000), + (9, "Pro", "Tablet", 4500), + (10, "Pro2", "Tablet", 6500)). + toDF("id", "product", "category", "revenue") + val window = Window. + partitionBy($"category"). + orderBy($"revenue".desc). + rangeBetween(-2000L, 1000L) + checkAnswer( + df.select( + $"id", + avg($"revenue").over(window).cast("int")), + Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: + Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: + Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: + Row(10, 6000) :: Nil) + } + + // This is here to illustrate the fact that reverse order also reverses offsets. + test("reverse unbounded range frame") { + val df = Seq(1, 2, 4, 3, 2, 1). + map(Tuple1.apply). + toDF("value") + val window = Window.orderBy($"value".desc) + checkAnswer( + df.select( + $"value", + sum($"value").over(window.rangeBetween(Long.MinValue, 1)), + sum($"value").over(window.rangeBetween(1, Long.MaxValue))), + Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: + Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + + } +} From 89d135851d928f9d7dcebe785c1b3b6a4d8dfc87 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 18 Jul 2015 23:47:40 -0700 Subject: [PATCH 49/58] Closes #6775 since it is subsumed by other patches. 
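A note on the frame boundary machinery introduced by the Window operator patch above: for a ROWS frame the offsets are compared directly against row positions, so ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING translates to a lower offset of -1 and an upper offset of +1, and an input row i belongs to the frame of output row j exactly when i - (j + low) >= 0 and i - (j + high) <= 0, which is the comparison RowBoundOrdering encodes. The minimal standalone Scala sketch below only illustrates that arithmetic; it is not part of any patch, and the names RowFrameSketch and rowFrameIndices are invented for the example.

object RowFrameSketch {
  // An input row at position i lies inside the ROWS frame of the output row at position j
  // when (i - (j + low)) >= 0 and (i - (j + high)) <= 0; clamping to the partition bounds
  // gives the index range a sliding frame would buffer for that output row.
  def rowFrameIndices(partitionSize: Int, low: Int, high: Int, output: Int): Seq[Int] =
    math.max(output + low, 0) to math.min(output + high, partitionSize - 1)

  def main(args: Array[String]): Unit = {
    // ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING over a partition of five rows.
    (0 until 5).foreach { j =>
      println(s"output row $j -> input rows ${rowFrameIndices(5, -1, 1, j).mkString(", ")}")
    }
    // output row 0 -> input rows 0, 1
    // output row 4 -> input rows 3, 4
  }
}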
From 9b644c41306cac53185ce0d2de4cb72127ada932 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 19 Jul 2015 00:32:56 -0700 Subject: [PATCH 50/58] [SPARK-9166][SQL][PYSPARK] Capture and hide IllegalArgumentException in Python API JIRA: https://issues.apache.org/jira/browse/SPARK-9166 Simply capture and hide `IllegalArgumentException` in Python API. Author: Liang-Chi Hsieh Closes #7497 from viirya/hide_illegalargument and squashes the following commits: 8324dce [Liang-Chi Hsieh] Fix python style. 9ace67d [Liang-Chi Hsieh] Also check exception message. 8b2ce5c [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into hide_illegalargument 7be016a [Liang-Chi Hsieh] Capture and hide IllegalArgumentException in Python. --- python/pyspark/sql/tests.py | 11 +++++++++-- python/pyspark/sql/utils.py | 8 ++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 241eac45cfe36..86706e2dc41a3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -45,9 +45,9 @@ from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase -from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.functions import UserDefinedFunction, sha2 from pyspark.sql.window import Window -from pyspark.sql.utils import AnalysisException +from pyspark.sql.utils import AnalysisException, IllegalArgumentException class UTC(datetime.tzinfo): @@ -894,6 +894,13 @@ def test_capture_analysis_exception(self): # RuntimeException should not be captured self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc")) + def test_capture_illegalargument_exception(self): + self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", + lambda: self.sqlCtx.sql("SET mapred.reduce.tasks=-1")) + df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"]) + self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values", + lambda: df.select(sha2(df.a, 1024)).collect()) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index cc5b2c088b7cc..0f795ca35b38a 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -24,6 +24,12 @@ class AnalysisException(Exception): """ +class IllegalArgumentException(Exception): + """ + Passed an illegal or inappropriate argument. + """ + + def capture_sql_exception(f): def deco(*a, **kw): try: @@ -32,6 +38,8 @@ def deco(*a, **kw): s = e.java_exception.toString() if s.startswith('org.apache.spark.sql.AnalysisException: '): raise AnalysisException(s.split(': ', 1)[1]) + if s.startswith('java.lang.IllegalArgumentException: '): + raise IllegalArgumentException(s.split(': ', 1)[1]) raise return deco From 344d1567e5ac28b3ab8f83f18d2fa9d98acef152 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carl=20Anders=20D=C3=BCvel?= Date: Sun, 19 Jul 2015 09:14:55 +0100 Subject: [PATCH 51/58] [SPARK-9094] [PARENT] Increased io.dropwizard.metrics from 3.1.0 to 3.1.2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We are running Spark 1.4.0 in production and ran into problems because after a network hiccup (which happens often in our current environment) no more metrics were reported to graphite leaving us blindfolded about the current state of our spark applications. 
[This problem](https://github.com/dropwizard/metrics/commit/70559816f1fc3a0a0122b5263d5478ff07396991) was fixed in the current version of the metrics library. We run spark with this change in production now and have seen no problems. We also had a look at the commit history since 3.1.0 and did not detect any potentially incompatible changes but many fixes which could potentially help other users as well. Author: Carl Anders Düvel Closes #7493 from hackbert/bump-metrics-lib-version and squashes the following commits: 6677565 [Carl Anders Düvel] [SPARK-9094] [PARENT] Increased io.dropwizard.metrics from 3.1.0 to 3.1.2 in order to get this fix https://github.com/dropwizard/metrics/commit/70559816f1fc3a0a0122b5263d5478ff07396991 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index c5c655834bdeb..2de0c35fbd51a 100644 --- a/pom.xml +++ b/pom.xml @@ -144,7 +144,7 @@ 0.5.0 2.4.0 2.0.8 - 3.1.0 + 3.1.2 1.7.7 hadoop2 0.7.1 From a53d13f7aa5d44c706e5510f57399a32c7558b80 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Sun, 19 Jul 2015 01:16:01 -0700 Subject: [PATCH 52/58] [SPARK-8199][SQL] follow up; revert change in test rxin / davies Sorry for that unnecessary change. And thanks again for all your support! Author: Tarek Auel Closes #7505 from tarekauel/SPARK-8199-FollowUp and squashes the following commits: d09321c [Tarek Auel] [SPARK-8199] follow up; revert change in test c17397f [Tarek Auel] [SPARK-8199] follow up; revert change in test 67acfe6 [Tarek Auel] [SPARK-8199] follow up; revert change in test --- .../spark/sql/catalyst/expressions/DateFunctionsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala index f469f42116d21..a0991ec998311 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala @@ -74,7 +74,7 @@ class DateFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (0 to 5).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, 1) + c.add(Calendar.DATE, i) checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), sdfDay.format(c.getTime).toInt) } @@ -86,7 +86,7 @@ class DateFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (0 to 5).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, 1) + c.add(Calendar.DATE, i) checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), sdfDay.format(c.getTime).toInt) } From 3427937ea2a4ed19142bd3d66707864879417d61 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 19 Jul 2015 01:17:22 -0700 Subject: [PATCH 53/58] [SQL] Make date/time functions more consistent with other database systems. This pull request fixes some of the problems in #6981. - Added date functions to `__all__` so they get exposed - Rename day_of_month -> dayofmonth - Rename day_in_year -> dayofyear - Rename week_of_year -> weekofyear - Removed "day" from Scala/Python API since it is ambiguous. Only leaving the alias in SQL. 
Author: Reynold Xin This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #7506 from rxin/datetime and squashes the following commits: 0cb24d9 [Reynold Xin] Export all functions in Python. e44a4a0 [Reynold Xin] Removed day function from Scala and Python. 9c08fdc [Reynold Xin] [SQL] Make date/time functions more consistent with other database systems. --- python/pyspark/sql/functions.py | 35 +- .../catalyst/analysis/FunctionRegistry.scala | 8 +- .../expressions/datetimeFunctions.scala | 13 +- .../sql/catalyst/util/DateTimeUtils.scala | 4 +- ...Suite.scala => DateExpressionsSuite.scala} | 26 +- .../org/apache/spark/sql/functions.scala | 338 +++++++++--------- .../apache/spark/sql/DataFrameDateSuite.scala | 56 --- ...nsSuite.scala => DateFunctionsSuite.scala} | 61 ++-- 8 files changed, 239 insertions(+), 302 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/{DateFunctionsSuite.scala => DateExpressionsSuite.scala} (91%) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala rename sql/core/src/test/scala/org/apache/spark/sql/{DateExpressionsSuite.scala => DateFunctionsSuite.scala} (74%) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0aca3788922aa..fd5a3ba8adab3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -55,6 +55,11 @@ __all__ += ['lag', 'lead', 'ntile'] +__all__ += [ + 'date_format', + 'year', 'quarter', 'month', 'hour', 'minute', 'second', + 'dayofmonth', 'dayofyear', 'weekofyear'] + def _create_function(name, doc=""): """ Create a function for aggregator by name""" @@ -713,41 +718,29 @@ def month(col): @since(1.5) -def day(col): - """ - Extract the day of the month of a given date as integer. - - >>> sqlContext.createDataFrame([('2015-04-08',)], ['a']).select(day('a').alias('day')).collect() - [Row(day=8)] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.day(col)) - - -@since(1.5) -def day_of_month(col): +def dayofmonth(col): """ Extract the day of the month of a given date as integer. >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(day_of_month('a').alias('day')).collect() + >>> df.select(dayofmonth('a').alias('day')).collect() [Row(day=8)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.day_of_month(col)) + return Column(sc._jvm.functions.dayofmonth(col)) @since(1.5) -def day_in_year(col): +def dayofyear(col): """ Extract the day of the year of a given date as integer. >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(day_in_year('a').alias('day')).collect() + >>> df.select(dayofyear('a').alias('day')).collect() [Row(day=98)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.day_in_year(col)) + return Column(sc._jvm.functions.dayofyear(col)) @since(1.5) @@ -790,16 +783,16 @@ def second(col): @since(1.5) -def week_of_year(col): +def weekofyear(col): """ Extract the week number of a given date as integer. 
>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(week_of_year('a').alias('week')).collect() + >>> df.select(weekofyear('a').alias('week')).collect() [Row(week=15)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.week_of_year(col)) + return Column(sc._jvm.functions.weekofyear(col)) class UserDefinedFunction(object): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 159f7eca7acfe..4b256adcc60c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -183,15 +183,15 @@ object FunctionRegistry { expression[CurrentDate]("current_date"), expression[CurrentTimestamp]("current_timestamp"), expression[DateFormatClass]("date_format"), - expression[Day]("day"), - expression[DayInYear]("day_in_year"), - expression[Day]("day_of_month"), + expression[DayOfMonth]("day"), + expression[DayOfYear]("dayofyear"), + expression[DayOfMonth]("dayofmonth"), expression[Hour]("hour"), expression[Month]("month"), expression[Minute]("minute"), expression[Quarter]("quarter"), expression[Second]("second"), - expression[WeekOfYear]("week_of_year"), + expression[WeekOfYear]("weekofyear"), expression[Year]("year") ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index f9cbbb8c6bee0..802445509285d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -116,14 +116,12 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn } } -case class DayInYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) override def dataType: DataType = IntegerType - override def prettyName: String = "day_in_year" - override protected def nullSafeEval(date: Any): Any = { DateTimeUtils.getDayInYear(date.asInstanceOf[Int]) } @@ -149,7 +147,7 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => + defineCodeGen(ctx, ev, c => s"""$dtu.getYear($c)""" ) } @@ -191,7 +189,7 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp } } -case class Day(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -215,8 +213,6 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa override def dataType: DataType = IntegerType - override def prettyName: String = "week_of_year" - override protected def nullSafeEval(date: Any): Any = { val c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) c.setFirstDayOfWeek(Calendar.MONDAY) @@ -225,7 +221,7 @@ case class 
WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa c.get(Calendar.WEEK_OF_YEAR) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (time) => { val cal = classOf[Calendar].getName val c = ctx.freshName("cal") @@ -237,6 +233,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa ${ev.primitive} = $c.get($cal.WEEK_OF_YEAR); """ }) + } } case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index a0da73a995a82..07412e73b6a5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -31,14 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String * precision. */ object DateTimeUtils { - final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L - // see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian final val JULIAN_DAY_OF_EPOCH = 2440587 // and .5 final val SECONDS_PER_DAY = 60 * 60 * 24L final val MICROS_PER_SECOND = 1000L * 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L + final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L + // number of days in 400 years final val daysIn400Years: Int = 146097 // number of days between 1.1.1970 and 1.1.2001 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala similarity index 91% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index a0991ec998311..f01589c58ea86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -19,19 +19,19 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Timestamp, Date} import java.text.SimpleDateFormat -import java.util.{TimeZone, Calendar} +import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{StringType, TimestampType, DateType} -class DateFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { +class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") val sdfDate = new SimpleDateFormat("yyyy-MM-dd") val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) - test("Day in Year") { + test("DayOfYear") { val sdfDay = new SimpleDateFormat("D") (2002 to 2004).foreach { y => (0 to 11).foreach { m => @@ -39,7 +39,7 @@ class DateFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), sdfDay.format(c.getTime).toInt) } } @@ -51,7 +51,7 @@ class 
DateFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), sdfDay.format(c.getTime).toInt) } } @@ -63,7 +63,7 @@ class DateFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), sdfDay.format(c.getTime).toInt) } } @@ -163,19 +163,19 @@ class DateFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("Day") { - checkEvaluation(Day(Cast(Literal("2000-02-29"), DateType)), 29) - checkEvaluation(Day(Literal.create(null, DateType)), null) - checkEvaluation(Day(Cast(Literal(d), DateType)), 8) - checkEvaluation(Day(Cast(Literal(sdfDate.format(d)), DateType)), 8) - checkEvaluation(Day(Cast(Literal(ts), DateType)), 8) + test("Day / DayOfMonth") { + checkEvaluation(DayOfMonth(Cast(Literal("2000-02-29"), DateType)), 29) + checkEvaluation(DayOfMonth(Literal.create(null, DateType)), null) + checkEvaluation(DayOfMonth(Cast(Literal(d), DateType)), 8) + checkEvaluation(DayOfMonth(Cast(Literal(sdfDate.format(d)), DateType)), 8) + checkEvaluation(DayOfMonth(Cast(Literal(ts), DateType)), 8) (1999 to 2000).foreach { y => val c = Calendar.getInstance() c.set(y, 0, 1, 0, 0, 0) (0 to 365).foreach { d => c.add(Calendar.DATE, 1) - checkEvaluation(Day(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfMonth(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), c.get(Calendar.DAY_OF_MONTH)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cadb25d597d19..f67c89437bb4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1748,182 +1748,6 @@ object functions { */ def length(columnName: String): Column = length(Column(columnName)) - ////////////////////////////////////////////////////////////////////////////////////////////// - // DateTime functions - ////////////////////////////////////////////////////////////////////////////////////////////// - - /** - * Converts a date/timestamp/string to a value of string in the format specified by the date - * format given by the second argument. - * - * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All - * pattern letters of [[java.text.SimpleDateFormat]] can be used. - * - * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a - * specialized implementation. - * - * @group datetime_funcs - * @since 1.5.0 - */ - def date_format(dateExpr: Column, format: String): Column = - DateFormatClass(dateExpr.expr, Literal(format)) - - /** - * Converts a date/timestamp/string to a value of string in the format specified by the date - * format given by the second argument. - * - * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All - * pattern letters of [[java.text.SimpleDateFormat]] can be used. - * - * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a - * specialized implementation. 
- * - * @group datetime_funcs - * @since 1.5.0 - */ - def date_format(dateColumnName: String, format: String): Column = - date_format(Column(dateColumnName), format) - - /** - * Extracts the year as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def year(e: Column): Column = Year(e.expr) - - /** - * Extracts the year as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def year(columnName: String): Column = year(Column(columnName)) - - /** - * Extracts the quarter as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def quarter(e: Column): Column = Quarter(e.expr) - - /** - * Extracts the quarter as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def quarter(columnName: String): Column = quarter(Column(columnName)) - - /** - * Extracts the month as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def month(e: Column): Column = Month(e.expr) - - /** - * Extracts the month as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def month(columnName: String): Column = month(Column(columnName)) - - /** - * Extracts the day of the month as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def day(e: Column): Column = Day(e.expr) - - /** - * Extracts the day of the month as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def day(columnName: String): Column = day(Column(columnName)) - - /** - * Extracts the day of the month as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def day_of_month(e: Column): Column = Day(e.expr) - - /** - * Extracts the day of the month as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def day_of_month(columnName: String): Column = day_of_month(Column(columnName)) - - /** - * Extracts the day of the year as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def day_in_year(e: Column): Column = DayInYear(e.expr) - - /** - * Extracts the day of the year as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def day_in_year(columnName: String): Column = day_in_year(Column(columnName)) - - /** - * Extracts the hours as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def hour(e: Column): Column = Hour(e.expr) - - /** - * Extracts the hours as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def hour(columnName: String): Column = hour(Column(columnName)) - - /** - * Extracts the minutes as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def minute(e: Column): Column = Minute(e.expr) - - /** - * Extracts the minutes as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def minute(columnName: String): Column = minute(Column(columnName)) - - /** - * Extracts the seconds as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def second(e: Column): Column = Second(e.expr) - - /** - * Extracts the seconds as an integer from a given date/timestamp/string. 
- * @group datetime_funcs - * @since 1.5.0 - */ - def second(columnName: String): Column = second(Column(columnName)) - - /** - * Extracts the week number as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def week_of_year(e: Column): Column = WeekOfYear(e.expr) - - /** - * Extracts the week number as an integer from a given date/timestamp/string. - * @group datetime_funcs - * @since 1.5.0 - */ - def week_of_year(columnName: String): Column = week_of_year(Column(columnName)) - /** * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, * and returns the result as a string. @@ -2409,6 +2233,168 @@ object functions { StringSpace(n.expr) } + ////////////////////////////////////////////////////////////////////////////////////////////// + // DateTime functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Converts a date/timestamp/string to a value of string in the format specified by the date + * format given by the second argument. + * + * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All + * pattern letters of [[java.text.SimpleDateFormat]] can be used. + * + * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a + * specialized implementation. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def date_format(dateExpr: Column, format: String): Column = + DateFormatClass(dateExpr.expr, Literal(format)) + + /** + * Converts a date/timestamp/string to a value of string in the format specified by the date + * format given by the second argument. + * + * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All + * pattern letters of [[java.text.SimpleDateFormat]] can be used. + * + * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a + * specialized implementation. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def date_format(dateColumnName: String, format: String): Column = + date_format(Column(dateColumnName), format) + + /** + * Extracts the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def year(e: Column): Column = Year(e.expr) + + /** + * Extracts the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def year(columnName: String): Column = year(Column(columnName)) + + /** + * Extracts the quarter as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def quarter(e: Column): Column = Quarter(e.expr) + + /** + * Extracts the quarter as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def quarter(columnName: String): Column = quarter(Column(columnName)) + + /** + * Extracts the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def month(e: Column): Column = Month(e.expr) + + /** + * Extracts the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def month(columnName: String): Column = month(Column(columnName)) + + /** + * Extracts the day of the month as an integer from a given date/timestamp/string. 
+ * @group datetime_funcs + * @since 1.5.0 + */ + def dayofmonth(e: Column): Column = DayOfMonth(e.expr) + + /** + * Extracts the day of the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def dayofmonth(columnName: String): Column = dayofmonth(Column(columnName)) + + /** + * Extracts the day of the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def dayofyear(e: Column): Column = DayOfYear(e.expr) + + /** + * Extracts the day of the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def dayofyear(columnName: String): Column = dayofyear(Column(columnName)) + + /** + * Extracts the hours as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def hour(e: Column): Column = Hour(e.expr) + + /** + * Extracts the hours as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def hour(columnName: String): Column = hour(Column(columnName)) + + /** + * Extracts the minutes as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def minute(e: Column): Column = Minute(e.expr) + + /** + * Extracts the minutes as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def minute(columnName: String): Column = minute(Column(columnName)) + + /** + * Extracts the seconds as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def second(e: Column): Column = Second(e.expr) + + /** + * Extracts the seconds as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def second(columnName: String): Column = second(Column(columnName)) + + /** + * Extracts the week number as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def weekofyear(e: Column): Column = WeekOfYear(e.expr) + + /** + * Extracts the week number as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def weekofyear(columnName: String): Column = weekofyear(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala deleted file mode 100644 index a4719a38de1d4..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Date, Timestamp} - -class DataFrameDateTimeSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - test("timestamp comparison with date strings") { - val df = Seq( - (1, Timestamp.valueOf("2015-01-01 00:00:00")), - (2, Timestamp.valueOf("2014-01-01 00:00:00"))).toDF("i", "t") - - checkAnswer( - df.select("t").filter($"t" <= "2014-06-01"), - Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) - - - checkAnswer( - df.select("t").filter($"t" >= "2014-06-01"), - Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) - } - - test("date comparison with date strings") { - val df = Seq( - (1, Date.valueOf("2015-01-01")), - (2, Date.valueOf("2014-01-01"))).toDF("i", "t") - - checkAnswer( - df.select("t").filter($"t" <= "2014-06-01"), - Row(Date.valueOf("2014-01-01")) :: Nil) - - - checkAnswer( - df.select("t").filter($"t" >= "2015"), - Row(Date.valueOf("2015-01-01")) :: Nil) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala similarity index 74% rename from sql/core/src/test/scala/org/apache/spark/sql/DateExpressionsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index d24e3ee1dd8f5..9e80ae86920d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,7 +22,7 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.functions._ -class DateExpressionsSuite extends QueryTest { +class DateFunctionsSuite extends QueryTest { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ @@ -32,6 +32,35 @@ class DateExpressionsSuite extends QueryTest { val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-04-08 13:10:15").getTime) + test("timestamp comparison with date strings") { + val df = Seq( + (1, Timestamp.valueOf("2015-01-01 00:00:00")), + (2, Timestamp.valueOf("2014-01-01 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2014-06-01"), + Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) + } + + test("date comparison with date strings") { + val df = Seq( + (1, Date.valueOf("2015-01-01")), + (2, Date.valueOf("2014-01-01"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Date.valueOf("2014-01-01")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2015"), + Row(Date.valueOf("2015-01-01")) :: Nil) + } test("date format") { val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") @@ -83,39 +112,27 @@ class DateExpressionsSuite extends QueryTest { Row(4, 4, 4)) } - test("day") { - val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") - - checkAnswer( - df.select(day("a"), day("b"), day("c")), - Row(8, 8, 8)) - - checkAnswer( - df.selectExpr("day(a)", "day(b)", "day(c)"), - Row(8, 8, 8)) - } - - test("day of month") { + test("dayofmonth") { val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") checkAnswer( - df.select(day_of_month("a"), day_of_month("b"), day_of_month("c")), + df.select(dayofmonth("a"), dayofmonth("b"), 
dayofmonth("c")), Row(8, 8, 8)) checkAnswer( - df.selectExpr("day_of_month(a)", "day_of_month(b)", "day_of_month(c)"), + df.selectExpr("day(a)", "day(b)", "dayofmonth(c)"), Row(8, 8, 8)) } - test("day in year") { + test("dayofyear") { val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") checkAnswer( - df.select(day_in_year("a"), day_in_year("b"), day_in_year("c")), + df.select(dayofyear("a"), dayofyear("b"), dayofyear("c")), Row(98, 98, 98)) checkAnswer( - df.selectExpr("day_in_year(a)", "day_in_year(b)", "day_in_year(c)"), + df.selectExpr("dayofyear(a)", "dayofyear(b)", "dayofyear(c)"), Row(98, 98, 98)) } @@ -155,15 +172,15 @@ class DateExpressionsSuite extends QueryTest { Row(0, 15, 15)) } - test("week of year") { + test("weekofyear") { val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") checkAnswer( - df.select(week_of_year("a"), week_of_year("b"), week_of_year("c")), + df.select(weekofyear("a"), weekofyear("b"), weekofyear("c")), Row(15, 15, 15)) checkAnswer( - df.selectExpr("week_of_year(a)", "week_of_year(b)", "week_of_year(c)"), + df.selectExpr("weekofyear(a)", "weekofyear(b)", "weekofyear(c)"), Row(15, 15, 15)) } From bc24289f5d54e4ff61cd75a5941338c9d946ff73 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 19 Jul 2015 17:37:25 +0800 Subject: [PATCH 54/58] [SPARK-9179] [BUILD] Allows committers to specify primary author of the PR to be merged It's a common case that some contributor contributes an initial version of a feature/bugfix, and later on some other people (mostly committers) fork and add more improvements. When merging these PRs, we probably want to specify the original author as the primary author. Currently we can only do this by running ``` $ git commit --amend --author="name " ``` manually right before the merge script pushes to Apache Git repo. It would be nice if the script accepts user specified primary author information. Author: Cheng Lian Closes #7508 from liancheng/spark-9179 and squashes the following commits: 218d88e [Cheng Lian] Allows committers to specify primary author of the PR to be merged --- dev/merge_spark_pr.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 4a17d48d8171d..d586a57481aa1 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -130,7 +130,10 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): '--pretty=format:%an <%ae>']).split("\n") distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True) - primary_author = distinct_authors[0] + primary_author = raw_input( + "Enter primary author in the format of \"name \" [%s]: " % + distinct_authors[0]) + commits = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, '--pretty=format:%h [%an] %s']).split("\n\n") @@ -281,7 +284,7 @@ def get_version_json(version_str): resolve = filter(lambda a: a['name'] == "Resolve Issue", asf_jira.transitions(jira_id))[0] resolution = filter(lambda r: r.raw['name'] == "Fixed", asf_jira.resolutions())[0] asf_jira.transition_issue( - jira_id, resolve["id"], fixVersions = jira_fix_versions, + jira_id, resolve["id"], fixVersions = jira_fix_versions, comment = comment, resolution = {'id': resolution.raw['id']}) print "Successfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) @@ -300,7 +303,7 @@ def standardize_jira_ref(text): """ Standardize the [SPARK-XXXXX] [MODULE] prefix Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. 
Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX] [MLLIB] Issue" - + >>> standardize_jira_ref("[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful") '[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful' >>> standardize_jira_ref("[SPARK-4123][Project Infra][WIP]: Show new dependencies added in pull requests") @@ -322,11 +325,11 @@ def standardize_jira_ref(text): """ jira_refs = [] components = [] - + # If the string is compliant, no need to process any further if (re.search(r'^\[SPARK-[0-9]{3,6}\] (\[[A-Z0-9_\s,]+\] )+\S+', text)): return text - + # Extract JIRA ref(s): pattern = re.compile(r'(SPARK[-\s]*[0-9]{3,6})+', re.IGNORECASE) for ref in pattern.findall(text): @@ -348,18 +351,18 @@ def standardize_jira_ref(text): # Assemble full text (JIRA ref(s), module(s), remaining text) clean_text = ' '.join(jira_refs).strip() + " " + ' '.join(components).strip() + " " + text.strip() - + # Replace multiple spaces with a single space, e.g. if no jira refs and/or components were included clean_text = re.sub(r'\s+', ' ', clean_text.strip()) - + return clean_text def main(): global original_head - + os.chdir(SPARK_HOME) original_head = run_cmd("git rev-parse HEAD")[:8] - + branches = get_json("%s/branches" % GITHUB_API_BASE) branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) # Assumes branch names can be sorted lexicographically @@ -448,5 +451,5 @@ def main(): (failure_count, test_count) = doctest.testmod() if failure_count: exit(-1) - + main() From 34ed82bb44c4519819695ddc760e6c9a98bc2e40 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 19 Jul 2015 18:58:19 +0800 Subject: [PATCH 55/58] [HOTFIX] [SQL] Fixes compilation error introduced by PR #7506 PR #7506 breaks master build because of compilation error. Note that #7506 itself looks good, but it seems that `git merge` did something stupid. 
Author: Cheng Lian Closes #7510 from liancheng/hotfix-for-pr-7506 and squashes the following commits: 7ea7e89 [Cheng Lian] Fixes compilation error --- .../spark/sql/catalyst/expressions/DateExpressionsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index f01589c58ea86..f724bab4d8839 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -75,7 +75,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), sdfDay.format(c.getTime).toInt) } } @@ -87,7 +87,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayInYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), sdfDay.format(c.getTime).toInt) } } From a803ac3e060d181c7b34d9501c9350e5f215ba85 Mon Sep 17 00:00:00 2001 From: Nicholas Hwang Date: Sun, 19 Jul 2015 10:30:28 -0700 Subject: [PATCH 56/58] [SPARK-9021] [PYSPARK] Change RDD.aggregate() to do reduce(mapPartitions()) instead of mapPartitions.fold() I'm relatively new to Spark and functional programming, so forgive me if this pull request is just a result of my misunderstanding of how Spark should be used. Currently, if one happens to use a mutable object as `zeroValue` for `RDD.aggregate()`, possibly unexpected behavior can occur. This is because pyspark's current implementation of `RDD.aggregate()` does not serialize or make a copy of `zeroValue` before handing it off to `RDD.mapPartitions(...).fold(...)`. This results in a single reference to `zeroValue` being used for both `RDD.mapPartitions()` and `RDD.fold()` on each partition. This can result in strange accumulator values being fed into each partition's call to `RDD.fold()`, as the `zeroValue` may have been changed in-place during the `RDD.mapPartitions()` call. 
As an illustrative example, submit the following to `spark-submit`: ``` from pyspark import SparkConf, SparkContext import collections def updateCounter(acc, val): print 'update acc:', acc print 'update val:', val acc[val] += 1 return acc def comboCounter(acc1, acc2): print 'combo acc1:', acc1 print 'combo acc2:', acc2 acc1.update(acc2) return acc1 def main(): conf = SparkConf().setMaster("local").setAppName("Aggregate with Counter") sc = SparkContext(conf = conf) print '======= AGGREGATING with ONE PARTITION =======' print sc.parallelize(range(1,10), 1).aggregate(collections.Counter(), updateCounter, comboCounter) print '======= AGGREGATING with TWO PARTITIONS =======' print sc.parallelize(range(1,10), 2).aggregate(collections.Counter(), updateCounter, comboCounter) if __name__ == "__main__": main() ``` One probably expects this to output the following: ``` Counter({1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1}) ``` But it instead outputs this (regardless of the number of partitions): ``` Counter({1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2, 8: 2, 9: 2}) ``` This is because (I believe) `zeroValue` gets passed correctly to each partition, but after `RDD.mapPartitions()` completes, the `zeroValue` object has been updated and is then passed to `RDD.fold()`, which results in all items being double-counted within each partition before being finally reduced at the calling node. I realize that this type of calculation is typically done by `RDD.mapPartitions(...).reduceByKey(...)`, but hopefully this illustrates some potentially confusing behavior. I also noticed that other `RDD` methods use this `deepcopy` approach to creating unique copies of `zeroValue` (i.e., `RDD.aggregateByKey()` and `RDD.foldByKey()`), and that the Scala implementations do seem to serialize the `zeroValue` object appropriately to prevent this type of behavior. Author: Nicholas Hwang Closes #7378 from njhwang/master and squashes the following commits: 659bb27 [Nicholas Hwang] Fixed RDD.aggregate() to perform a reduce operation on collected mapPartitions results, similar to how fold currently is implemented. This prevents an initial combOp being performed on each partition with zeroValue (which leads to unexpected behavior if zeroValue is a mutable object) before being combOp'ed with other partition results. 8d8d694 [Nicholas Hwang] Changed dict construction to be compatible with Python 2.6 (cannot use list comprehensions to make dicts) 56eb2ab [Nicholas Hwang] Fixed whitespace after colon to conform with PEP8 391de4a [Nicholas Hwang] Removed used of collections.Counter from RDD tests for Python 2.6 compatibility; used defaultdict(int) instead. Merged treeAggregate test with mutable zero value into aggregate test to reduce code duplication. 2fa4e4b [Nicholas Hwang] Merge branch 'master' of https://github.com/njhwang/spark ba528bd [Nicholas Hwang] Updated comments regarding protection of zeroValue from mutation in RDD.aggregate(). Added regression tests for aggregate(), fold(), aggregateByKey(), foldByKey(), and treeAggregate(), all with both 1 and 2 partition RDDs. Confirmed that aggregate() is the only problematic implementation as of commit 257236c3e17906098f801cbc2059e7a9054e8cab. Also replaced some parallelizations of ranges with xranges, per the documentation's recommendations of preferring xrange over range. 7820391 [Nicholas Hwang] Updated comments regarding protection of zeroValue from mutation in RDD.aggregate(). 
Added regression tests for aggregate(), fold(), aggregateByKey(), foldByKey(), and treeAggregate(), all with both 1 and 2 partition RDDs. Confirmed that aggregate() is the only problematic implementation as of commit 257236c3e17906098f801cbc2059e7a9054e8cab. 90d1544 [Nicholas Hwang] Made sure RDD.aggregate() makes a deepcopy of zeroValue for all partitions; this ensures that the mapPartitions call works with unique copies of zeroValue in each partition, and prevents a single reference to zeroValue being used for both map and fold calls on each partition (resulting in possibly unexpected behavior). --- python/pyspark/rdd.py | 10 ++- python/pyspark/tests.py | 141 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 137 insertions(+), 14 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 3218bed5c74fc..7e788148d981c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -862,6 +862,9 @@ def func(iterator): for obj in iterator: acc = op(obj, acc) yield acc + # collecting result of mapPartitions here ensures that the copy of + # zeroValue provided to each partition is unique from the one provided + # to the final reduce call vals = self.mapPartitions(func).collect() return reduce(op, vals, zeroValue) @@ -891,8 +894,11 @@ def func(iterator): for obj in iterator: acc = seqOp(acc, obj) yield acc - - return self.mapPartitions(func).fold(zeroValue, combOp) + # collecting result of mapPartitions here ensures that the copy of + # zeroValue provided to each partition is unique from the one provided + # to the final reduce call + vals = self.mapPartitions(func).collect() + return reduce(combOp, vals, zeroValue) def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 21225016805bc..5be9937cb04b2 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -529,10 +529,127 @@ def test_deleting_input_files(self): def test_sampling_default_seed(self): # Test for SPARK-3995 (default seed setting) - data = self.sc.parallelize(range(1000), 1) + data = self.sc.parallelize(xrange(1000), 1) subset = data.takeSample(False, 10) self.assertEqual(len(subset), 10) + def test_aggregate_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregate and treeAggregate to build dict + # representing a counter of ints + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + # Show that single or multiple partitions work + data1 = self.sc.range(10, numSlices=1) + data2 = self.sc.range(10, numSlices=2) + + def seqOp(x, y): + x[y] += 1 + return x + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp) + counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp) + counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + + ground_truth = defaultdict(int, dict((i, 1) for i in range(10))) + self.assertEqual(counts1, ground_truth) + self.assertEqual(counts2, ground_truth) + self.assertEqual(counts3, ground_truth) + self.assertEqual(counts4, ground_truth) + + def test_aggregate_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that + # contains lists of all values for each key in the original RDD + + # list(range(...)) for Python 3.x compatibility (can't use * operator + # on a range object) + # list(zip(...)) for Python 
3.x compatibility (want to parallelize a + # collection, not a zip object) + tuples = list(zip(list(range(10))*2, [1]*20)) + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def seqOp(x, y): + x.append(y) + return x + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.aggregateByKey([], seqOp, comboOp).collect() + values2 = data2.aggregateByKey([], seqOp, comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + ground_truth = [(i, [1]*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + + def test_fold_mutable_zero_value(self): + # Test for SPARK-9021; uses fold to merge an RDD of dict counters into + # a single dict + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + counts1 = defaultdict(int, dict((i, 1) for i in range(10))) + counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8))) + counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7))) + counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6))) + all_counts = [counts1, counts2, counts3, counts4] + # Show that single or multiple partitions work + data1 = self.sc.parallelize(all_counts, 1) + data2 = self.sc.parallelize(all_counts, 2) + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + fold1 = data1.fold(defaultdict(int), comboOp) + fold2 = data2.fold(defaultdict(int), comboOp) + + ground_truth = defaultdict(int) + for counts in all_counts: + for key, val in counts.items(): + ground_truth[key] += val + self.assertEqual(fold1, ground_truth) + self.assertEqual(fold2, ground_truth) + + def test_fold_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains + # lists of all values for each key in the original RDD + + tuples = [(i, range(i)) for i in range(10)]*2 + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.foldByKey([], comboOp).collect() + values2 = data2.foldByKey([], comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + # list(range(...)) for Python 3.x compatibility + ground_truth = [(i, list(range(i))*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + def test_aggregate_by_key(self): data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) @@ -624,8 +741,8 @@ def test_zip_with_different_serializers(self): def test_zip_with_different_object_sizes(self): # regress test for SPARK-5973 - a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i) - b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i) + a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) + b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) self.assertEqual(10000, a.zip(b).count()) def test_zip_with_different_number_of_items(self): @@ -647,7 +764,7 @@ def test_zip_with_different_number_of_items(self): self.assertRaises(Exception, lambda: a.zip(b).count()) def test_count_approx_distinct(self): - rdd = self.sc.parallelize(range(1000)) + rdd = self.sc.parallelize(xrange(1000)) self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) self.assertTrue(950 < 
rdd.map(float).countApproxDistinct(0.03) < 1050) self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) @@ -777,7 +894,7 @@ def test_distinct(self): def test_external_group_by_key(self): self.sc._conf.set("spark.python.worker.memory", "1m") N = 200001 - kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x)) + kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) gkv = kv.groupByKey().cache() self.assertEqual(3, gkv.count()) filtered = gkv.filter(lambda kv: kv[0] == 1) @@ -871,7 +988,7 @@ def test_narrow_dependency_in_join(self): # Regression test for SPARK-6294 def test_take_on_jrdd(self): - rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x)) + rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) rdd._jrdd.first() def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): @@ -1517,13 +1634,13 @@ def run(): self.fail("daemon had been killed") # run a normal job - rdd = self.sc.parallelize(range(100), 1) + rdd = self.sc.parallelize(xrange(100), 1) self.assertEqual(100, rdd.map(str).count()) def test_after_exception(self): def raise_exception(_): raise Exception() - rdd = self.sc.parallelize(range(100), 1) + rdd = self.sc.parallelize(xrange(100), 1) with QuietTest(self.sc): self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) self.assertEqual(100, rdd.map(str).count()) @@ -1539,22 +1656,22 @@ def test_after_jvm_exception(self): with QuietTest(self.sc): self.assertRaises(Exception, lambda: filtered_data.count()) - rdd = self.sc.parallelize(range(100), 1) + rdd = self.sc.parallelize(xrange(100), 1) self.assertEqual(100, rdd.map(str).count()) def test_accumulator_when_reuse_worker(self): from pyspark.accumulators import INT_ACCUMULATOR_PARAM acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x)) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x)) self.assertEqual(sum(range(100)), acc1.value) acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x)) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) self.assertEqual(sum(range(100)), acc2.value) self.assertEqual(sum(range(100)), acc1.value) def test_reuse_worker_after_take(self): - rdd = self.sc.parallelize(range(100000), 1) + rdd = self.sc.parallelize(xrange(100000), 1) self.assertEqual(0, rdd.first()) def count(): From 7a81245345f2d6124423161786bb0d9f1c278ab8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 19 Jul 2015 16:29:50 -0700 Subject: [PATCH 57/58] [SPARK-8638] [SQL] Window Function Performance Improvements - Cleanup This PR contains a few clean-ups that are a part of SPARK-8638: a few style issues got fixed, and a few tests were moved. Git commit message is wrong BTW :(... 
Author: Herman van Hovell Closes #7513 from hvanhovell/SPARK-8638-cleanup and squashes the following commits: 4e69d08 [Herman van Hovell] Fixed Perfomance Regression for Shrinking Window Frames (+Rebase) --- .../apache/spark/sql/execution/Window.scala | 14 ++-- .../sql/hive/HiveDataFrameWindowSuite.scala | 43 ++++++++++ .../sql/hive/execution/WindowSuite.scala | 79 ------------------- 3 files changed, 51 insertions(+), 85 deletions(-) delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index a054f52b8b489..de04132eb1104 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -118,22 +118,24 @@ case class Window( val exprs = windowSpec.orderSpec.map(_.child) val projection = newMutableProjection(exprs, child.output) (windowSpec.orderSpec, projection(), projection()) - } - else if (windowSpec.orderSpec.size == 1) { + } else if (windowSpec.orderSpec.size == 1) { // Use only the first order expression when the offset is non-null. val sortExpr = windowSpec.orderSpec.head val expr = sortExpr.child // Create the projection which returns the current 'value'. val current = newMutableProjection(expr :: Nil, child.output)() // Flip the sign of the offset when processing the order is descending - val boundOffset = if (sortExpr.direction == Descending) -offset - else offset + val boundOffset = + if (sortExpr.direction == Descending) { + -offset + } else { + offset + } // Create the projection which returns the current 'value' modified by adding the offset. val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) val bound = newMutableProjection(boundExpr :: Nil, child.output)() (sortExpr :: Nil, current, bound) - } - else { + } else { sys.error("Non-Zero range offsets are not supported for windows " + "with multiple order expressions.") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 15b5f418f0a8c..c177cbdd991cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -212,4 +212,47 @@ class HiveDataFrameWindowSuite extends QueryTest { | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) | FROM window_table""".stripMargin).collect()) } + + test("reverse sliding range frame") { + val df = Seq( + (1, "Thin", "Cell Phone", 6000), + (2, "Normal", "Tablet", 1500), + (3, "Mini", "Tablet", 5500), + (4, "Ultra thin", "Cell Phone", 5500), + (5, "Very thin", "Cell Phone", 6000), + (6, "Big", "Tablet", 2500), + (7, "Bendable", "Cell Phone", 3000), + (8, "Foldable", "Cell Phone", 3000), + (9, "Pro", "Tablet", 4500), + (10, "Pro2", "Tablet", 6500)). + toDF("id", "product", "category", "revenue") + val window = Window. + partitionBy($"category"). + orderBy($"revenue".desc). 
+ rangeBetween(-2000L, 1000L) + checkAnswer( + df.select( + $"id", + avg($"revenue").over(window).cast("int")), + Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: + Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: + Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: + Row(10, 6000) :: Nil) + } + + // This is here to illustrate the fact that reverse order also reverses offsets. + test("reverse unbounded range frame") { + val df = Seq(1, 2, 4, 3, 2, 1). + map(Tuple1.apply). + toDF("value") + val window = Window.orderBy($"value".desc) + checkAnswer( + df.select( + $"value", + sum($"value").over(window.rangeBetween(Long.MinValue, 1)), + sum($"value").over(window.rangeBetween(1, Long.MaxValue))), + Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: + Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowSuite.scala deleted file mode 100644 index a089d0d165195..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowSuite.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ - -/** - * Window expressions are tested extensively by the following test suites: - * [[org.apache.spark.sql.hive.HiveDataFrameWindowSuite]] - * [[org.apache.spark.sql.hive.execution.HiveWindowFunctionQueryWithoutCodeGenSuite]] - * [[org.apache.spark.sql.hive.execution.HiveWindowFunctionQueryFileWithoutCodeGenSuite]] - * However these suites do not cover all possible (i.e. more exotic) settings. This suite fill - * this gap. - * - * TODO Move this class to the sql/core project when we move to Native Spark UDAFs. - */ -class WindowSuite extends QueryTest { - - test("reverse sliding range frame") { - val df = Seq( - (1, "Thin", "Cell Phone", 6000), - (2, "Normal", "Tablet", 1500), - (3, "Mini", "Tablet", 5500), - (4, "Ultra thin", "Cell Phone", 5500), - (5, "Very thin", "Cell Phone", 6000), - (6, "Big", "Tablet", 2500), - (7, "Bendable", "Cell Phone", 3000), - (8, "Foldable", "Cell Phone", 3000), - (9, "Pro", "Tablet", 4500), - (10, "Pro2", "Tablet", 6500)). - toDF("id", "product", "category", "revenue") - val window = Window. - partitionBy($"category"). - orderBy($"revenue".desc). 
- rangeBetween(-2000L, 1000L) - checkAnswer( - df.select( - $"id", - avg($"revenue").over(window).cast("int")), - Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: - Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: - Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: - Row(10, 6000) :: Nil) - } - - // This is here to illustrate the fact that reverse order also reverses offsets. - test("reverse unbounded range frame") { - val df = Seq(1, 2, 4, 3, 2, 1). - map(Tuple1.apply). - toDF("value") - val window = Window.orderBy($"value".desc) - checkAnswer( - df.select( - $"value", - sum($"value").over(window.rangeBetween(Long.MinValue, 1)), - sum($"value").over(window.rangeBetween(1, Long.MaxValue))), - Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: - Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) - - } -} From 163e3f1df94f6b7d3dadb46a87dbb3a2bade3f95 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 19 Jul 2015 16:48:47 -0700 Subject: [PATCH 58/58] [SPARK-8241][SQL] string function: concat_ws. I also changed the semantics of concat w.r.t. null back to the same behavior as Hive. That is to say, concat now returns null if any input is null. Author: Reynold Xin Closes #7504 from rxin/concat_ws and squashes the following commits: 83fd950 [Reynold Xin] Fixed type casting. 3ae85f7 [Reynold Xin] Write null better. cdc7be6 [Reynold Xin] Added code generation for pure string mode. a61c4e4 [Reynold Xin] Updated comments. 2d51406 [Reynold Xin] [SPARK-8241][SQL] string function: concat_ws. --- .../catalyst/analysis/FunctionRegistry.scala | 11 ++- .../expressions/stringOperations.scala | 72 ++++++++++++++++--- .../org/apache/spark/sql/types/DataType.scala | 2 +- .../analysis/HiveTypeCoercionSuite.scala | 11 ++- .../expressions/StringExpressionsSuite.scala | 31 +++++++- .../org/apache/spark/sql/functions.scala | 24 +++++++ .../spark/sql/StringFunctionsSuite.scala | 19 +++-- .../execution/HiveCompatibilitySuite.scala | 4 +- .../apache/spark/unsafe/types/UTF8String.java | 58 +++++++++++++-- .../spark/unsafe/types/UTF8StringSuite.java | 62 ++++++++++++---- 10 files changed, 256 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 4b256adcc60c6..71e87b98d86fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -153,6 +153,7 @@ object FunctionRegistry { expression[Ascii]("ascii"), expression[Base64]("base64"), expression[Concat]("concat"), + expression[ConcatWs]("concat_ws"), expression[Encode]("encode"), expression[Decode]("decode"), expression[FormatNumber]("format_number"), @@ -211,7 +212,10 @@ object FunctionRegistry { val builder = (expressions: Seq[Expression]) => { if (varargCtor.isDefined) { // If there is an apply method that accepts Seq[Expression], use that one. - varargCtor.get.newInstance(expressions).asInstanceOf[Expression] + Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match { + case Success(e) => e + case Failure(e) => throw new AnalysisException(e.getMessage) + } } else { // Otherwise, find an ctor method that matches the number of arguments, and use that. 
val params = Seq.fill(expressions.size)(classOf[Expression]) @@ -221,7 +225,10 @@ object FunctionRegistry { case Failure(e) => throw new AnalysisException(s"Invalid number of arguments for function $name") } - f.newInstance(expressions : _*).asInstanceOf[Expression] + Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { + case Success(e) => e + case Failure(e) => throw new AnalysisException(e.getMessage) + } } } (name, builder) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 560b1bc2d889f..5f8ac716f79a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -34,19 +34,14 @@ import org.apache.spark.unsafe.types.UTF8String /** * An expression that concatenates multiple input strings into a single string. - * Input expressions that are evaluated to nulls are skipped. - * - * For example, `concat("a", null, "b")` is evaluated to `"ab"`. - * - * Note that this is different from Hive since Hive outputs null if any input is null. - * We never output null. + * If any input is null, concat returns null. */ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) override def dataType: DataType = StringType - override def nullable: Boolean = false + override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) override def eval(input: InternalRow): Any = { @@ -56,15 +51,76 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evals = children.map(_.gen(ctx)) - val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ") + val inputs = evals.map { eval => + s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + }.mkString(", ") evals.map(_.code).mkString("\n") + s""" boolean ${ev.isNull} = false; UTF8String ${ev.primitive} = UTF8String.concat($inputs); + if (${ev.primitive} == null) { + ${ev.isNull} = true; + } """ } } +/** + * An expression that concatenates multiple input strings or array of strings into a single string, + * using a given separator (the first child). + * + * Returns null if the separator is null. Otherwise, concat_ws skips all null values. + */ +case class ConcatWs(children: Seq[Expression]) + extends Expression with ImplicitCastInputTypes with CodegenFallback { + + require(children.nonEmpty, s"$prettyName requires at least one argument.") + + override def prettyName: String = "concat_ws" + + /** The 1st child (separator) is str, and rest are either str or array of str. 
*/ + override def inputTypes: Seq[AbstractDataType] = { + val arrayOrStr = TypeCollection(ArrayType(StringType), StringType) + StringType +: Seq.fill(children.size - 1)(arrayOrStr) + } + + override def dataType: DataType = StringType + + override def nullable: Boolean = children.head.nullable + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = { + val flatInputs = children.flatMap { child => + child.eval(input) match { + case s: UTF8String => Iterator(s) + case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]] + case null => Iterator(null.asInstanceOf[UTF8String]) + } + } + UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + if (children.forall(_.dataType == StringType)) { + // All children are strings. In that case we can construct a fixed size array. + val evals = children.map(_.gen(ctx)) + + val inputs = evals.map { eval => + s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + }.mkString(", ") + + evals.map(_.code).mkString("\n") + s""" + UTF8String ${ev.primitive} = UTF8String.concatWs($inputs); + boolean ${ev.isNull} = ${ev.primitive} == null; + """ + } else { + // Contains a mix of strings and arrays. Fall back to interpreted mode for now. + super.genCode(ctx, ev) + } + } +} + + trait StringRegexExpression extends ImplicitCastInputTypes { self: BinaryExpression => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 2d133eea19fe0..e98fd2583b931 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = this - override private[sql] def acceptsType(other: DataType): Boolean = this == other + override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index f9442bccc4a7a..7ee2333a81dfe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -37,7 +37,6 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(NullType, IntegerType, IntegerType) shouldCast(NullType, DecimalType, DecimalType.Unlimited) - // TODO: write the entire implicit cast table out for test cases. 
shouldCast(ByteType, IntegerType, IntegerType) shouldCast(IntegerType, IntegerType, IntegerType) shouldCast(IntegerType, LongType, LongType) @@ -86,6 +85,16 @@ class HiveTypeCoercionSuite extends PlanTest { DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe => shouldCast(tpe, NumericType, tpe) } + + shouldCast( + ArrayType(StringType, false), + TypeCollection(ArrayType(StringType), StringType), + ArrayType(StringType, false)) + + shouldCast( + ArrayType(StringType, true), + TypeCollection(ArrayType(StringType), StringType), + ArrayType(StringType, true)) } test("ineligible implicit type cast") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 0ed567a90dd1f..96f433be8b065 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -26,7 +26,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("concat") { def testConcat(inputs: String*): Unit = { - val expected = inputs.filter(_ != null).mkString + val expected = if (inputs.contains(null)) null else inputs.mkString checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow) } @@ -46,6 +46,35 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } + test("concat_ws") { + def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { + val inputExprs = inputs.map { + case s: Seq[_] => Literal.create(s, ArrayType(StringType)) + case null => Literal.create(null, StringType) + case s: String => Literal.create(s, StringType) + } + val sepExpr = Literal.create(sep, StringType) + checkEvaluation(ConcatWs(sepExpr +: inputExprs), expected, EmptyRow) + } + + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + testConcatWs(null, null) + testConcatWs(null, null, "a", "b") + testConcatWs("", "") + testConcatWs("ab", "哈哈", "ab") + testConcatWs("a哈哈b", "哈哈", "a", "b") + testConcatWs("a哈哈b", "哈哈", "a", null, "b") + testConcatWs("a哈哈b哈哈c", "哈哈", null, "a", null, "b", "c") + + testConcatWs("ab", "哈哈", Seq("ab")) + testConcatWs("a哈哈b", "哈哈", Seq("a", "b")) + testConcatWs("a哈哈b哈哈c哈哈d", "哈哈", Seq("a", null, "b"), null, "c", Seq(null, "d")) + testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq.empty[String]) + testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq[String](null)) + // scalastyle:on + } + test("StringComparison") { val row = create_row("abc", null) val c1 = 'a.string.at(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f67c89437bb4a..b5140dca0487f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1732,6 +1732,30 @@ object functions { concat((columnName +: columnNames).map(Column.apply): _*) } + /** + * Concatenates input strings together into a single string, using the given separator. 
+ * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat_ws(sep: String, exprs: Column*): Column = { + ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr)) + } + + /** + * Concatenates input strings together into a single string, using the given separator. + * + * This is the variant of concat_ws that takes in the column names. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat_ws(sep: String, columnName: String, columnNames: String*): Column = { + concat_ws(sep, (columnName +: columnNames).map(Column.apply) : _*) + } + /** * Computes the length of a given string / binary value. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 4eff33ed45042..fe4de8d8b855f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -30,14 +30,25 @@ class StringFunctionsSuite extends QueryTest { val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") checkAnswer( - df.select(concat($"a", $"b", $"c")), - Row("ab")) + df.select(concat($"a", $"b"), concat($"a", $"b", $"c")), + Row("ab", null)) checkAnswer( - df.selectExpr("concat(a, b, c)"), - Row("ab")) + df.selectExpr("concat(a, b)", "concat(a, b, c)"), + Row("ab", null)) } + test("string concat_ws") { + val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") + + checkAnswer( + df.select(concat_ws("||", $"a", $"b", $"c")), + Row("a||b")) + + checkAnswer( + df.selectExpr("concat_ws('||', a, b, c)"), + Row("a||b")) + } test("string Levenshtein distance") { val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 2689d904d6541..b12b3838e615c 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -263,9 +263,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "timestamp_2", "timestamp_udf", - // Hive outputs NULL if any concat input has null. We never output null for concat. - "udf_concat", - // Unlike Hive, we do support log base in (0, 1.0], therefore disable this "udf7" ) @@ -856,6 +853,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_case", "udf_ceil", "udf_ceiling", + "udf_concat", "udf_concat_insert1", "udf_concat_insert2", "udf_concat_ws", diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 9723b6e0834b2..3eecd657e6ef9 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -397,26 +397,62 @@ public UTF8String lpad(int len, UTF8String pad) { } /** - * Concatenates input strings together into a single string. A null input is skipped. - * For example, concat("a", null, "c") would yield "ac". + * Concatenates input strings together into a single string. Returns null if any input is null. */ public static UTF8String concat(UTF8String... 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 4eff33ed45042..fe4de8d8b855f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -30,14 +30,25 @@ class StringFunctionsSuite extends QueryTest {
     val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
 
     checkAnswer(
-      df.select(concat($"a", $"b", $"c")),
-      Row("ab"))
+      df.select(concat($"a", $"b"), concat($"a", $"b", $"c")),
+      Row("ab", null))
 
     checkAnswer(
-      df.selectExpr("concat(a, b, c)"),
-      Row("ab"))
+      df.selectExpr("concat(a, b)", "concat(a, b, c)"),
+      Row("ab", null))
   }
 
+  test("string concat_ws") {
+    val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
+
+    checkAnswer(
+      df.select(concat_ws("||", $"a", $"b", $"c")),
+      Row("a||b"))
+
+    checkAnswer(
+      df.selectExpr("concat_ws('||', a, b, c)"),
+      Row("a||b"))
+  }
   test("string Levenshtein distance") {
     val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
 
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 2689d904d6541..b12b3838e615c 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -263,9 +263,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "timestamp_2",
     "timestamp_udf",
 
-    // Hive outputs NULL if any concat input has null. We never output null for concat.
-    "udf_concat",
-
     // Unlike Hive, we do support log base in (0, 1.0], therefore disable this
     "udf7"
   )
@@ -856,6 +853,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "udf_case",
     "udf_ceil",
     "udf_ceiling",
+    "udf_concat",
     "udf_concat_insert1",
     "udf_concat_insert2",
     "udf_concat_ws",
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 9723b6e0834b2..3eecd657e6ef9 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -397,26 +397,62 @@ public UTF8String lpad(int len, UTF8String pad) {
   }
 
   /**
-   * Concatenates input strings together into a single string. A null input is skipped.
-   * For example, concat("a", null, "c") would yield "ac".
+   * Concatenates input strings together into a single string. Returns null if any input is null.
    */
   public static UTF8String concat(UTF8String... inputs) {
-    if (inputs == null) {
-      return fromBytes(new byte[0]);
-    }
-
     // Compute the total length of the result.
     int totalLength = 0;
     for (int i = 0; i < inputs.length; i++) {
       if (inputs[i] != null) {
         totalLength += inputs[i].numBytes;
+      } else {
+        return null;
       }
     }
 
     // Allocate a new byte array, and copy the inputs one by one into it.
     final byte[] result = new byte[totalLength];
     int offset = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      int len = inputs[i].numBytes;
+      PlatformDependent.copyMemory(
+        inputs[i].base, inputs[i].offset,
+        result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+        len);
+      offset += len;
+    }
+    return fromBytes(result);
+  }
+
+  /**
+   * Concatenates input strings together into a single string using the separator.
+   * A null input is skipped. For example, concatWs(",", "a", null, "c") would yield "a,c".
+   */
+  public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
+    if (separator == null) {
+      return null;
+    }
+
+    int numInputBytes = 0;  // total number of bytes from the inputs
+    int numInputs = 0;      // number of non-null inputs
     for (int i = 0; i < inputs.length; i++) {
+      if (inputs[i] != null) {
+        numInputBytes += inputs[i].numBytes;
+        numInputs++;
+      }
+    }
+
+    if (numInputs == 0) {
+      // Return an empty string if there is no input, or all the inputs are null.
+      return fromBytes(new byte[0]);
+    }
+
+    // Allocate a new byte array, and copy the inputs one by one into it.
+    // The size of the new array is the size of all inputs, plus the separators.
+    final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes];
+    int offset = 0;
+
+    for (int i = 0, j = 0; i < inputs.length; i++) {
       if (inputs[i] != null) {
         int len = inputs[i].numBytes;
         PlatformDependent.copyMemory(
@@ -424,6 +460,16 @@ public static UTF8String concat(UTF8String... inputs) {
           result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
           len);
         offset += len;
+
+        j++;
+        // Add separator if this is not the last input.
+        if (j < numInputs) {
+          PlatformDependent.copyMemory(
+            separator.base, separator.offset,
+            result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+            separator.numBytes);
+          offset += separator.numBytes;
+        }
       }
     }
     return fromBytes(result);
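
The allocation above sizes the output buffer exactly: the bytes of all non-null inputs plus one separator between each adjacent pair, i.e. numInputBytes + (numInputs - 1) * separator.numBytes. For example, with the separator "哈哈" (6 bytes in UTF-8) and three one-byte inputs "a", "b", "c", the buffer is 3 + 2 * 6 = 15 bytes. A small Scala sketch of the same two-pass idea over plain byte arrays (illustrative only; concatWsBytes is not a Spark API):

    import java.nio.charset.StandardCharsets.UTF_8

    // Pass 1 sizes the buffer, pass 2 copies inputs and separators into it,
    // mirroring the sizing rule above.
    def concatWsBytes(sep: Array[Byte], inputs: Array[Byte]*): Array[Byte] = {
      val nonNull = inputs.filter(_ != null)
      if (nonNull.isEmpty) return Array.empty[Byte]

      val size = nonNull.map(_.length).sum + (nonNull.length - 1) * sep.length
      val out = new Array[Byte](size)
      var offset = 0
      for ((bytes, i) <- nonNull.zipWithIndex) {
        System.arraycopy(bytes, 0, out, offset, bytes.length)
        offset += bytes.length
        if (i < nonNull.length - 1) {  // separator between inputs, not after the last one
          System.arraycopy(sep, 0, out, offset, sep.length)
          offset += sep.length
        }
      }
      out
    }

    val joined = concatWsBytes("哈哈".getBytes(UTF_8),
      "a".getBytes(UTF_8), null, "b".getBytes(UTF_8))
    new String(joined, UTF_8)  // "a哈哈b": 1 + 6 + 1 = 8 bytes
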
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 0db7522b50c1a..7d0c49e2fb84c 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -88,16 +88,50 @@ public void upperAndLower() {
 
   @Test
   public void concatTest() {
-    assertEquals(concat(), fromString(""));
-    assertEquals(concat(null), fromString(""));
-    assertEquals(concat(fromString("")), fromString(""));
-    assertEquals(concat(fromString("ab")), fromString("ab"));
-    assertEquals(concat(fromString("a"), fromString("b")), fromString("ab"));
-    assertEquals(concat(fromString("a"), fromString("b"), fromString("c")), fromString("abc"));
-    assertEquals(concat(fromString("a"), null, fromString("c")), fromString("ac"));
-    assertEquals(concat(fromString("a"), null, null), fromString("a"));
-    assertEquals(concat(null, null, null), fromString(""));
-    assertEquals(concat(fromString("数据"), fromString("砖头")), fromString("数据砖头"));
+    assertEquals(fromString(""), concat());
+    assertEquals(null, concat((UTF8String) null));
+    assertEquals(fromString(""), concat(fromString("")));
+    assertEquals(fromString("ab"), concat(fromString("ab")));
+    assertEquals(fromString("ab"), concat(fromString("a"), fromString("b")));
+    assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c")));
+    assertEquals(null, concat(fromString("a"), null, fromString("c")));
+    assertEquals(null, concat(fromString("a"), null, null));
+    assertEquals(null, concat(null, null, null));
+    assertEquals(fromString("数据砖头"), concat(fromString("数据"), fromString("砖头")));
+  }
+
+  @Test
+  public void concatWsTest() {
+    // Returns null if the separator is null
+    assertEquals(null, concatWs(null, (UTF8String)null));
+    assertEquals(null, concatWs(null, fromString("a")));
+
+    // If the separator is not null, concatWs skips null inputs and never returns null.
+    UTF8String sep = fromString("哈哈");
+    assertEquals(
+      fromString(""),
+      concatWs(sep, fromString("")));
+    assertEquals(
+      fromString("ab"),
+      concatWs(sep, fromString("ab")));
+    assertEquals(
+      fromString("a哈哈b"),
+      concatWs(sep, fromString("a"), fromString("b")));
+    assertEquals(
+      fromString("a哈哈b哈哈c"),
+      concatWs(sep, fromString("a"), fromString("b"), fromString("c")));
+    assertEquals(
+      fromString("a哈哈c"),
+      concatWs(sep, fromString("a"), null, fromString("c")));
+    assertEquals(
+      fromString("a"),
+      concatWs(sep, fromString("a"), null, null));
+    assertEquals(
+      fromString(""),
+      concatWs(sep, null, null, null));
+    assertEquals(
+      fromString("数据哈哈砖头"),
+      concatWs(sep, fromString("数据"), fromString("砖头")));
   }
 
   @Test
@@ -215,14 +249,18 @@ public void pad() {
     assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????")));
     assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者")));
     assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者")));
-    assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者")));
+    assertEquals(
+      fromString("孙行者孙行者孙行数据砖头"),
+      fromString("数据砖头").lpad(12, fromString("孙行者")));
 
     assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????")));
     assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????")));
     assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????")));
     assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者")));
     assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者")));
-    assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者")));
+    assertEquals(
+      fromString("数据砖头孙行者孙行者孙行"),
+      fromString("数据砖头").rpad(12, fromString("孙行者")));
   }
 
   @Test