From 5572fd0c518acd2e4483ff41bea1eb1cffd543ce Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 14 Jul 2015 21:44:47 -0700 Subject: [PATCH 01/27] [HOTFIX] Adding new names to known contributors --- dev/create-release/known_translations | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index 5f2671a6e5053..e462302f28423 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -129,3 +129,12 @@ yongtang - Yong Tang ypcat - Pei-Lun Lee zhichao-li - Zhichao Li zzcclp - Zhichao Zhang +979969786 - Yuming Wang +Rosstin - Rosstin Murphy +ameyc - Amey Chaugule +animeshbaranawal - Animesh Baranawal +cafreeman - Chris Freeman +lee19 - Lee +lockwobr - Brian Lockwood +navis - Navis Ryu +pparkkin - Paavo Parkkinen From f650a005e03ecd800c9005a496cc6a0d8eb68c93 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Tue, 14 Jul 2015 22:21:01 -0700 Subject: [PATCH 02/27] [SPARK-8808] [SPARKR] Fix assignments in SparkR. Author: Sun Rui Closes #7395 from sun-rui/SPARK-8808 and squashes the following commits: ce603bc [Sun Rui] Use '<-' instead of '='. 88590b1 [Sun Rui] Use '<-' instead of '='. --- R/pkg/R/DataFrame.R | 2 +- R/pkg/R/client.R | 4 ++-- R/pkg/R/group.R | 4 ++-- R/pkg/R/utils.R | 4 ++-- R/pkg/inst/tests/test_binaryFile.R | 2 +- R/pkg/inst/tests/test_binary_function.R | 2 +- R/pkg/inst/tests/test_rdd.R | 4 ++-- R/pkg/inst/tests/test_textFile.R | 2 +- R/pkg/inst/tests/test_utils.R | 2 +- 9 files changed, 13 insertions(+), 13 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 60702824acb46..208813768e264 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1328,7 +1328,7 @@ setMethod("write.df", jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] = path + options[['path']] <- path } callJMethod(df@sdf, "save", source, jmode, options) }) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 78c7a3037ffac..6f772158ddfe8 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -36,9 +36,9 @@ connectBackend <- function(hostname, port, timeout = 6000) { determineSparkSubmitBin <- function() { if (.Platform$OS.type == "unix") { - sparkSubmitBinName = "spark-submit" + sparkSubmitBinName <- "spark-submit" } else { - sparkSubmitBinName = "spark-submit.cmd" + sparkSubmitBinName <- "spark-submit.cmd" } sparkSubmitBinName } diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 8f1c68f7c4d28..576ac72f40fc0 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -87,7 +87,7 @@ setMethod("count", setMethod("agg", signature(x = "GroupedData"), function(x, ...) { - cols = list(...) + cols <- list(...) stopifnot(length(cols) > 0) if (is.character(cols[[1]])) { cols <- varargsToEnv(...) @@ -97,7 +97,7 @@ setMethod("agg", if (!is.null(ns)) { for (n in ns) { if (n != "") { - cols[[n]] = alias(cols[[n]], n) + cols[[n]] <- alias(cols[[n]], n) } } } diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index ea629a64f7158..950ba74dbe017 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -41,8 +41,8 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, if (isInstanceOf(obj, "scala.Tuple2")) { # JavaPairRDD[Array[Byte], Array[Byte]]. 
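# A small sketch of the convention this commit enforces, not taken from the patch
# itself; `x` and `trim` are hypothetical names. In R, `<-` is the assignment
# operator, while `=` is kept for binding named arguments in a call, which is why
# only the assignments in these hunks change.
x <- c(1, 2, 3)          # assignment uses `<-`
mean(x, trim = 0.1)      # `=` here binds the named argument, not a variable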
- keyBytes = callJMethod(obj, "_1") - valBytes = callJMethod(obj, "_2") + keyBytes <- callJMethod(obj, "_1") + valBytes <- callJMethod(obj, "_2") res <- list(unserialize(keyBytes), unserialize(valBytes)) } else { diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index ccaea18ecab2a..f2452ed97d2ea 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -20,7 +20,7 @@ context("functions on binary files") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index 3be8c65a6c1a0..dca0657c57e0d 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -76,7 +76,7 @@ test_that("zipPartitions() on RDDs", { expect_equal(actual, list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index b79692873cec3..6c3aaab8c711e 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -447,7 +447,7 @@ test_that("zipRDD() on RDDs", { expect_equal(actual, list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) @@ -483,7 +483,7 @@ test_that("cartesian() on RDDs", { actual <- collect(cartesian(rdd, emptyRdd)) expect_equal(actual, list()) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 58318dfef71ab..a9cf83dbdbdb1 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -20,7 +20,7 @@ context("the textFile() function") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { fileName <- tempfile(pattern="spark-test", fileext=".tmp") diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index aa0d2a66b9082..12df4cf4f65b7 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -119,7 +119,7 @@ test_that("cleanClosure on R functions", { # Test for overriding variables in base namespace (Issue: SparkR-196). nums <- as.list(1:10) rdd <- parallelize(sc, nums, 2L) - t = 4 # Override base::t in .GlobalEnv. + t <- 4 # Override base::t in .GlobalEnv. f <- function(x) { x > t } newF <- cleanClosure(f) env <- environment(newF) From f23a721c10b64ec5c6768634fc5e9e7b60ee7ca8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 14 Jul 2015 22:52:53 -0700 Subject: [PATCH 03/27] [SPARK-8993][SQL] More comprehensive type checking in expressions. This patch makes the following changes: 1. 
ExpectsInputTypes only defines expected input types, but does not perform any implicit type casting. 2. ImplicitCastInputTypes is a new trait that defines both expected input types, as well as performs implicit type casting. 3. BinaryOperator has a new abstract function "inputType", which defines the expected input type for both left/right. Concrete BinaryOperator expressions no longer perform any implicit type casting. 4. For BinaryOperators, convert NullType (i.e. null literals) into some accepted type so BinaryOperators don't need to handle NullTypes. TODOs needed: fix unit tests for error reporting. I'm intentionally not changing anything in aggregate expressions because yhuai is doing a big refactoring on that right now. Author: Reynold Xin Closes #7348 from rxin/typecheck and squashes the following commits: 8fcf814 [Reynold Xin] Fixed ordering of cases. 3bb63e7 [Reynold Xin] Style fix. f45408f [Reynold Xin] Comment update. aa7790e [Reynold Xin] Moved RemoveNullTypes into ImplicitTypeCasts. 438ea07 [Reynold Xin] space d55c9e5 [Reynold Xin] Removes NullTypes. 360d124 [Reynold Xin] Fixed the rule. fb66657 [Reynold Xin] Convert NullType into some accepted type for BinaryOperators. 2e22330 [Reynold Xin] Fixed unit tests. 4932d57 [Reynold Xin] Style fix. d061691 [Reynold Xin] Rename existing ExpectsInputTypes -> ImplicitCastInputTypes. e4727cc [Reynold Xin] BinaryOperator should not be doing implicit cast. d017861 [Reynold Xin] Improve expression type checking. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../catalyst/analysis/HiveTypeCoercion.scala | 43 ++++++---- .../expressions/ExpectsInputTypes.scala | 17 +++- .../sql/catalyst/expressions/Expression.scala | 44 +++++++++- .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 84 ++++++++----------- .../sql/catalyst/expressions/bitwise.scala | 30 +++---- .../spark/sql/catalyst/expressions/math.scala | 18 ++-- .../spark/sql/catalyst/expressions/misc.scala | 8 +- .../sql/catalyst/expressions/predicates.scala | 83 ++++++++++-------- .../expressions/stringOperations.scala | 36 ++++---- .../spark/sql/catalyst/util/TypeUtils.scala | 8 -- .../spark/sql/types/AbstractDataType.scala | 35 ++++++++ .../analysis/AnalysisErrorSuite.scala | 2 +- .../ExpressionTypeCheckingSuite.scala | 6 +- .../analysis/HiveTypeCoercionSuite.scala | 56 +++++++++++++ .../spark/sql/MathExpressionsSuite.scala | 1 - 17 files changed, 309 insertions(+), 165 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ed69c42dcb825..6b1a94e4b2ad4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8cb71995eb818..15da5eecc8d3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -214,19 +214,6 @@ object 
HiveTypeCoercion { } Union(newLeft, newRight) - - // Also widen types for BinaryOperator. - case q: LogicalPlan => q transformExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e - - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - b.makeCopy(Array(newLeft, newRight)) - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. - } } } @@ -672,20 +659,44 @@ object HiveTypeCoercion { } /** - * Casts types according to the expected input types for Expressions that have the trait - * [[ExpectsInputTypes]]. + * Casts types according to the expected input types for [[Expression]]s. */ object ImplicitTypeCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tighest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.makeCopy(Array(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + + case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) + + case e: ExpectsInputTypes if e.inputTypes.nonEmpty => + // Convert NullType into some specific target type for ExpectsInputTypes that don't do + // general implicit casting. + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => + if (in.dataType == NullType && !expected.acceptsType(NullType)) { + Cast(in, expected.defaultConcreteType) + } else { + in + } + } + e.withNewChildren(children) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 3eb0eb195c80d..ded89e85dea79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -19,10 +19,15 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types.AbstractDataType - +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts /** * An trait that gets mixin to define the expected input types of an expression. + * + * This trait is typically used by operator expressions (e.g. [[Add]], [[Subtract]]) to define + * expected input types without any implicit casting. 
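// A minimal sketch of the split described here, assuming the usual
// org.apache.spark.sql.catalyst.expressions and org.apache.spark.sql.types imports;
// UpperDemo and ConcatDemo are hypothetical names, only the traits and
// BinaryOperator come from this patch.
case class UpperDemo(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes {
  // Function-style expression: the analyzer may insert an implicit Cast on `child`
  // to satisfy the declared input type.
  override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
  override def dataType: DataType = StringType
}

case class ConcatDemo(left: Expression, right: Expression) extends BinaryOperator {
  // Operator-style expression: one expected type for both sides and no implicit
  // casting beyond widening to a tightest common type; a remaining mismatch fails
  // checkInputDataTypes().
  override def inputType: AbstractDataType = StringType
  override def symbol: String = "++"
  override def dataType: DataType = StringType
}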
+ * + * Most function expressions (e.g. [[Substring]] should extends [[ImplicitCastInputTypes]]) instead. */ trait ExpectsInputTypes { self: Expression => @@ -40,7 +45,7 @@ trait ExpectsInputTypes { self: Expression => val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + - s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." + s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." } if (mismatches.isEmpty) { @@ -50,3 +55,11 @@ trait ExpectsInputTypes { self: Expression => } } } + + +/** + * A mixin for the analyzer to perform implicit type casting using [[ImplicitTypeCasts]]. + */ +trait ImplicitCastInputTypes extends ExpectsInputTypes { self: Expression => + // No other methods +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 54ec10444c4f3..3f19ac2b592b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -24,8 +24,20 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines the basic expression abstract classes in Catalyst, including: +// Expression: the base expression abstract class +// LeafExpression +// UnaryExpression +// BinaryExpression +// BinaryOperator +// +// For details, see their classdocs. +//////////////////////////////////////////////////////////////////////////////////////////////////// /** + * An expression in Catalyst. + * * If an expression wants to be exposed in the function registry (so users can call it with * "name(arguments...)", the concrete implementation must be a case class whose constructor * arguments are all Expressions types. @@ -335,15 +347,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express /** - * An expression that has two inputs that are expected to the be same type. If the two inputs have - * different types, the analyzer will find the tightest common type and do the proper type casting. + * A [[BinaryExpression]] that is an operator, with two properties: + * + * 1. The string representation is "x symbol y", rather than "funcName(x, y)". + * 2. Two inputs are expected to the be same type. If the two inputs have different types, + * the analyzer will find the tightest common type and do the proper type casting. */ -abstract class BinaryOperator extends BinaryExpression { +abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { self: Product => + /** + * Expected input type from both left/right child expressions, similar to the + * [[ImplicitCastInputTypes]] trait. + */ + def inputType: AbstractDataType + def symbol: String override def toString: String = s"($left $symbol $right)" + + override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType) + + override def checkInputDataTypes(): TypeCheckResult = { + // First call the checker for ExpectsInputTypes, and then check whether left and right have + // the same type. 
+ super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") + } else { + TypeCheckResult.TypeCheckSuccess + } + case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 6fb3343bb63f2..22687acd68a97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes { + inputTypes: Seq[DataType] = Nil) extends Expression with ImplicitCastInputTypes { override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8476af4a5d8d6..1a55a0876f303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,23 +18,19 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -abstract class UnaryArithmetic extends UnaryExpression { - self: Product => + +case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def dataType: DataType = child.dataType -} -case class UnaryMinus(child: Expression) extends UnaryArithmetic { override def toString: String = s"-$child" - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "operator -") - private lazy val numeric = TypeUtils.getNumeric(dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { @@ -45,9 +41,13 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { protected override def nullSafeEval(input: Any): Any = numeric.negate(input) } -case class UnaryPositive(child: Expression) extends UnaryArithmetic { +case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def prettyName: String = "positive" + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def dataType: DataType = child.dataType + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -57,9 +57,11 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { /** * A function that get the absolute value of the numeric value. 
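// A rough before/after sketch of the new NullType handling for plain
// ExpectsInputTypes expressions, mirroring the HiveTypeCoercionSuite cases added
// later in this patch (catalyst.expressions and types imports assumed): under a
// NumericType contract, a bare null literal child is rewritten to a cast to the
// default concrete numeric type, so operators such as UnaryMinus never see a
// NullType input at eval time.
val beforeCoercion = UnaryMinus(Literal.create(null, NullType))
val afterCoercion  = UnaryMinus(Cast(Literal.create(null, NullType), DoubleType))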
*/ -case class Abs(child: Expression) extends UnaryArithmetic { - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function abs") +case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def dataType: DataType = child.dataType private lazy val numeric = TypeUtils.getNumeric(dataType) @@ -71,18 +73,6 @@ abstract class BinaryArithmetic extends BinaryOperator { override def dataType: DataType = left.dataType - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in ${this.getClass.getSimpleName} " + - s"(${left.dataType} and ${right.dataType}).") - } else { - checkTypesInternal(dataType) - } - } - - protected def checkTypesInternal(t: DataType): TypeCheckResult - /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") @@ -104,62 +94,61 @@ private[sql] object BinaryArithmetic { } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "+" override def decimalMethod: String = "$plus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "-" override def decimalMethod: String = "$minus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2) } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "*" override def decimalMethod: String = "$times" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "/" override def decimalMethod: String = "$div" - override def nullable: Boolean = true override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val div: (Any, Any) => Any = dataType match { case ft: 
FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot @@ -215,17 +204,16 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "%" override def decimalMethod: String = "remainder" - override def nullable: Boolean = true override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] @@ -281,10 +269,11 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { - override def nullable: Boolean = left.nullable && right.nullable + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(t, "function maxOf") + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def nullable: Boolean = left.nullable && right.nullable private lazy val ordering = TypeUtils.getOrdering(dataType) @@ -335,10 +324,11 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { - override def nullable: Boolean = left.nullable && right.nullable + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(t, "function minOf") + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def nullable: Boolean = left.nullable && right.nullable private lazy val ordering = TypeUtils.getOrdering(dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala index 2d47124d247e7..af1abbcd2239b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -29,10 +27,10 @@ import org.apache.spark.sql.types._ * Code generation inherited from BinaryArithmetic. */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "&" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Bitwise + + override def symbol: String = "&" private lazy val and: (Any, Any) => Any = dataType match { case ByteType => @@ -54,10 +52,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * Code generation inherited from BinaryArithmetic. 
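// A short sketch of the shared collection behind the bitwise operators, assuming
// only this patch's definitions (catalyst.expressions and types imports assumed):
// TypeCollection.Bitwise enumerates BooleanType plus the integral types, so &, |, ^
// and ~ are all checked against one list instead of each calling
// TypeUtils.checkForBitwiseExpr.
val bitwiseOk  = BitwiseAnd(Literal(1), Literal(2))      // IntegerType is in the collection
val bitwiseBad = BitwiseAnd(Literal(1.0), Literal(2.0))  // DoubleType is not, so checkInputDataTypes() fails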
*/ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "|" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Bitwise + + override def symbol: String = "|" private lazy val or: (Any, Any) => Any = dataType match { case ByteType => @@ -79,10 +77,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet * Code generation inherited from BinaryArithmetic. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "^" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Bitwise + + override def symbol: String = "^" private lazy val xor: (Any, Any) => Any = dataType match { case ByteType => @@ -101,11 +99,13 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme /** * A function that calculates bitwise not(~) of a number. */ -case class BitwiseNot(child: Expression) extends UnaryArithmetic { - override def toString: String = s"~$child" +case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise) + + override def dataType: DataType = child.dataType + + override def toString: String = s"~$child" private lazy val not: (Any) => Any = dataType match { case ByteType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index c31890e27fb54..4b7fe05dd4980 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -55,7 +55,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends UnaryExpression with Serializable with ImplicitCastInputTypes { self: Product => override def inputTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -89,7 +89,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends BinaryExpression with Serializable with ImplicitCastInputTypes { self: Product => override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -174,7 +174,7 @@ object Factorial { ) } -case class Factorial(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -251,7 +251,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia } case class Bin(child: Expression) - extends UnaryExpression with Serializable with ExpectsInputTypes { + extends 
UnaryExpression with Serializable with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType @@ -285,7 +285,7 @@ object Hex { * Otherwise if the number is a STRING, it converts each character into its hex representation * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ -case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = @@ -329,7 +329,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ -case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = Seq(StringType) @@ -416,7 +416,7 @@ case class Pow(left: Expression, right: Expression) * @param right number of bits to left shift. */ case class ShiftLeft(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -442,7 +442,7 @@ case class ShiftLeft(left: Expression, right: Expression) * @param right number of bits to left shift. */ case class ShiftRight(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -468,7 +468,7 @@ case class ShiftRight(left: Expression, right: Expression) * @param right the number of bits to right shift. */ case class ShiftRightUnsigned(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 3b59cd431b871..a269ec4a1e6dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,7 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String * A function that calculates an MD5 128-bit checksum and returns it as a hex string * For input of type [[BinaryType]] */ -case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -55,7 +55,7 @@ case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes * the hash length is not one of the permitted values, the return value is NULL. 
*/ case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with ExpectsInputTypes { + extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -118,7 +118,7 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ -case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -138,7 +138,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputType * A function that computes a cyclic redundancy check value and returns it as a bigint * For input of type [[BinaryType]] */ -case class Crc32(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f74fd04619714..aa6c30e2f79f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -33,12 +33,17 @@ object InterpretedPredicate { } } + +/** + * An [[Expression]] that returns a boolean value. + */ trait Predicate extends Expression { self: Product => override def dataType: DataType = BooleanType } + trait PredicateHelper { protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { condition match { @@ -70,7 +75,10 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } -case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { + +case class Not(child: Expression) + extends UnaryExpression with Predicate with ImplicitCastInputTypes { + override def toString: String = s"NOT $child" override def inputTypes: Seq[DataType] = Seq(BooleanType) @@ -82,6 +90,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } } + /** * Evaluates to `true` if `list` contains `value`. */ @@ -97,6 +106,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } + /** * Optimized version of In clause, when all filter values of In clause are * static. 
@@ -112,12 +122,12 @@ case class InSet(child: Expression, hset: Set[Any]) } } -case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { - override def toString: String = s"($left && $right)" +case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { + + override def inputType: AbstractDataType = BooleanType - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def symbol: String = "&&" override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -161,12 +171,12 @@ case class And(left: Expression, right: Expression) } } -case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { - override def toString: String = s"($left || $right)" +case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate { - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputType: AbstractDataType = BooleanType + + override def symbol: String = "||" override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -210,21 +220,10 @@ case class Or(left: Expression, right: Expression) } } + abstract class BinaryComparison extends BinaryOperator with Predicate { self: Product => - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in ${this.getClass.getSimpleName} " + - s"(${left.dataType} and ${right.dataType}).") - } else { - checkTypesInternal(dataType) - } - } - - protected def checkTypesInternal(t: DataType): TypeCheckResult - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (ctx.isPrimitiveType(left.dataType)) { // faster version @@ -235,10 +234,12 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } } + private[sql] object BinaryComparison { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) } + /** An extractor that matches both standard 3VL equality and null-safe equality. 
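// A brief sketch of what And gains from extending BinaryOperator, assuming only
// this patch's definitions (catalyst.expressions and types imports assumed):
// inputType is BooleanType for both sides, and the printed form now comes from the
// shared "(left symbol right)" toString, so "(a && b)" renders as before.
val goodAnd = And(Literal(true), Literal(false))  // BooleanType on both sides: type checks
val badAnd  = And(Literal(true), Literal(1))      // IntegerType right-hand side fails checkInputDataTypes()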
*/ private[sql] object Equality { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { @@ -248,10 +249,12 @@ private[sql] object Equality { } } + case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "=" - override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess + override def inputType: AbstractDataType = AnyDataType + + override def symbol: String = "=" protected override def nullSafeEval(input1: Any, input2: Any): Any = { if (left.dataType != BinaryType) input1 == input2 @@ -263,13 +266,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } } + case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { + + override def inputType: AbstractDataType = AnyDataType + override def symbol: String = "<=>" override def nullable: Boolean = false - override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess - override def eval(input: InternalRow): Any = { val input1 = left.eval(input) val input2 = right.eval(input) @@ -298,44 +303,48 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } + case class LessThan(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = "<" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } + case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<=" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = "<=" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } + case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = ">" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = ">" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } + case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = ">=" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = ">=" private lazy val ordering = TypeUtils.getOrdering(left.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index f64899c1ed84c..03b55ce5fe7cc 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends ExpectsInputTypes { +trait StringRegexExpression extends ImplicitCastInputTypes { self: BinaryExpression => def escape(v: String): String @@ -105,7 +105,7 @@ case class RLike(left: Expression, right: Expression) override def toString: String = s"$left RLIKE $right" } -trait String2StringExpression extends ExpectsInputTypes { +trait String2StringExpression extends ImplicitCastInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -142,7 +142,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ExpectsInputTypes { +trait StringComparison extends ImplicitCastInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -241,7 +241,7 @@ case class StringTrimRight(child: Expression) * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ case class StringInstr(str: Expression, substr: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = substr @@ -265,7 +265,7 @@ case class StringInstr(str: Expression, substr: Expression) * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) @@ -306,7 +306,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns str, left-padded with pad to a length of len. */ case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -344,7 +344,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) * Returns str, right-padded with pad to a length of len. */ case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -413,7 +413,7 @@ case class StringFormat(children: Expression*) extends Expression { * Returns the string which repeat the given string value n times. */ case class StringRepeat(str: Expression, times: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = times @@ -447,7 +447,7 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 /** * Returns a n spaces string. 
*/ -case class StringSpace(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -467,7 +467,7 @@ case class StringSpace(child: Expression) extends UnaryExpression with ExpectsIn * Splits str around pat (pattern is a regular expression). */ case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = pattern @@ -488,7 +488,7 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -555,7 +555,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) /** * A function that return the length of the given string expression. */ -case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -573,7 +573,7 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI * A function that return the Levenshtein distance between the two given strings. */ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression - with ExpectsInputTypes { + with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -591,7 +591,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres /** * Returns the numeric value of the first character of str. */ -case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -608,7 +608,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTyp /** * Converts the argument from binary to a base 64 string. */ -case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -622,7 +622,7 @@ case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTy /** * Converts the argument from a base 64 string to BINARY. */ -case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -636,7 +636,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput * If either argument is null, the result will also be null. 
*/ case class Decode(bin: Expression, charset: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = bin override def right: Expression = charset @@ -655,7 +655,7 @@ case class Decode(bin: Expression, charset: Expression) * If either argument is null, the result will also be null. */ case class Encode(value: Expression, charset: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = value override def right: Expression = charset diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 3148309a2166f..0103ddcf9cfb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -32,14 +32,6 @@ object TypeUtils { } } - def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { - if (t.isInstanceOf[IntegralType] || t == NullType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t") - } - } - def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = { if (t.isInstanceOf[AtomicType] || t == NullType) { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 32f87440b4e37..f5715f7a829ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -96,6 +96,24 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) private[sql] object TypeCollection { + /** + * Types that can be ordered/compared. In the long run we should probably make this a trait + * that can be mixed into each data type, and perhaps create an [[AbstractDataType]]. + */ + val Ordered = TypeCollection( + BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType, + TimestampType, DateType, + StringType, BinaryType) + + /** + * Types that can be used in bitwise operations. + */ + val Bitwise = TypeCollection( + BooleanType, + ByteType, ShortType, IntegerType, LongType) + def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { @@ -105,6 +123,23 @@ private[sql] object TypeCollection { } +/** + * An [[AbstractDataType]] that matches any concrete data types. + */ +protected[sql] object AnyDataType extends AbstractDataType { + + // Note that since AnyDataType matches any concrete types, defaultConcreteType should never + // be invoked. + override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException + + override private[sql] def simpleString: String = "any" + + override private[sql] def isSameType(other: DataType): Boolean = false + + override private[sql] def acceptsType(other: DataType): Boolean = true +} + + /** * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. 
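// A minimal sketch of composing TypeCollection with ImplicitCastInputTypes, in the
// same style as Hex/Unhex earlier in this patch (catalyst.expressions and types
// imports assumed); DigestDemo is a hypothetical expression, not part of the patch.
case class DigestDemo(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes {
  // Accept binary or string input; other types may get an implicit Cast from the
  // analyzer, or fail checkInputDataTypes() if no cast applies.
  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(BinaryType, StringType))
  override def dataType: DataType = StringType
}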
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 9d0c69a2451d1..f0f17103991ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ case class TestFunction( children: Seq[Expression], - inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes { + inputTypes: Seq[AbstractDataType]) extends Expression with ImplicitCastInputTypes { override def nullable: Boolean = true override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def dataType: DataType = StringType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8e0551b23eea6..5958acbe009ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -49,7 +49,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { def assertErrorForDifferingTypes(expr: Expression): Unit = { assertError(expr, - s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") + s"differing types in '${expr.prettyString}' (int and boolean)") } test("check types for unary arithmetic") { @@ -58,7 +58,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(BitwiseNot('stringField), "operator ~ accepts integral type") } - test("check types for binary arithmetic") { + ignore("check types for binary arithmetic") { // We will cast String to Double for binary arithmetic assertSuccess(Add('intField, 'stringField)) assertSuccess(Subtract('intField, 'stringField)) @@ -92,7 +92,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") } - test("check types for predicates") { + ignore("check types for predicates") { // We will cast String to Double for binary comparison assertSuccess(EqualTo('intField, 'stringField)) assertSuccess(EqualNullSafe('intField, 'stringField)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index acb9a433de903..8e9b20a3ebe42 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -194,6 +194,32 @@ class HiveTypeCoercionSuite extends PlanTest { Project(Seq(Alias(transformed, "a")()), testRelation)) } + test("cast NullType for expresions that implement ExpectsInputTypes") { + import HiveTypeCoercionSuite._ + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + AnyTypeUnaryExpression(Literal.create(null, NullType)), + AnyTypeUnaryExpression(Literal.create(null, NullType))) + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + NumericTypeUnaryExpression(Literal.create(null, NullType)), + NumericTypeUnaryExpression(Cast(Literal.create(null, NullType), 
DoubleType))) + } + + test("cast NullType for binary operators") { + import HiveTypeCoercionSuite._ + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), + AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), + NumericTypeBinaryOperator( + Cast(Literal.create(null, NullType), DoubleType), + Cast(Literal.create(null, NullType), DoubleType))) + } + test("coalesce casts") { ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1.0) @@ -302,3 +328,33 @@ class HiveTypeCoercionSuite extends PlanTest { ) } } + + +object HiveTypeCoercionSuite { + + case class AnyTypeUnaryExpression(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def dataType: DataType = NullType + } + + case class NumericTypeUnaryExpression(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + override def dataType: DataType = NullType + } + + case class AnyTypeBinaryOperator(left: Expression, right: Expression) + extends BinaryOperator with ExpectsInputTypes { + override def dataType: DataType = NullType + override def inputType: AbstractDataType = AnyDataType + override def symbol: String = "anytype" + } + + case class NumericTypeBinaryOperator(left: Expression, right: Expression) + extends BinaryOperator with ExpectsInputTypes { + override def dataType: DataType = NullType + override def inputType: AbstractDataType = NumericType + override def symbol: String = "numerictype" + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 24bef21b999ea..b30b9f12258b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -375,6 +375,5 @@ class MathExpressionsSuite extends QueryTest { val df = Seq((1, -1, "abc")).toDF("a", "b", "c") checkAnswer(df.selectExpr("positive(a)"), Row(1)) checkAnswer(df.selectExpr("positive(b)"), Row(-1)) - checkAnswer(df.selectExpr("positive(c)"), Row("abc")) } } From c6b1a9e74e34267dc198e57a184c41498ca9d6a3 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 14 Jul 2015 22:57:39 -0700 Subject: [PATCH 04/27] Revert SPARK-6910 and SPARK-9027 Revert #7216 and #7386. 
These patches seem to be causing quite a few test failures: ``` Caused by: java.lang.reflect.InvocationTargetException at sun.reflect.GeneratedMethodAccessor322.invoke(Unknown Source) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:606) at org.apache.spark.sql.hive.client.Shim_v0_13.getPartitionsByFilter(HiveShim.scala:351) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$getPartitionsByFilter$1.apply(ClientWrapper.scala:320) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$getPartitionsByFilter$1.apply(ClientWrapper.scala:318) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$withHiveState$1.apply(ClientWrapper.scala:180) at org.apache.spark.sql.hive.client.ClientWrapper.retryLocked(ClientWrapper.scala:135) at org.apache.spark.sql.hive.client.ClientWrapper.withHiveState(ClientWrapper.scala:172) at org.apache.spark.sql.hive.client.ClientWrapper.getPartitionsByFilter(ClientWrapper.scala:318) at org.apache.spark.sql.hive.client.HiveTable.getPartitions(ClientInterface.scala:78) at org.apache.spark.sql.hive.MetastoreRelation.getHiveQlPartitions(HiveMetastoreCatalog.scala:670) at org.apache.spark.sql.hive.execution.HiveTableScan.doExecute(HiveTableScan.scala:137) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:90) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:90) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:147) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:89) at org.apache.spark.sql.execution.Exchange$$anonfun$doExecute$1.apply(Exchange.scala:164) at org.apache.spark.sql.execution.Exchange$$anonfun$doExecute$1.apply(Exchange.scala:151) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:48) ... 
85 more Caused by: MetaException(message:Filtering is supported only on partition keys of type string) at org.apache.hadoop.hive.metastore.parser.ExpressionTree$FilterBuilder.setError(ExpressionTree.java:185) at org.apache.hadoop.hive.metastore.parser.ExpressionTree$LeafNode.getJdoFilterPushdownParam(ExpressionTree.java:452) at org.apache.hadoop.hive.metastore.parser.ExpressionTree$LeafNode.generateJDOFilterOverPartitions(ExpressionTree.java:357) at org.apache.hadoop.hive.metastore.parser.ExpressionTree$LeafNode.generateJDOFilter(ExpressionTree.java:279) at org.apache.hadoop.hive.metastore.parser.ExpressionTree$TreeNode.generateJDOFilter(ExpressionTree.java:243) at org.apache.hadoop.hive.metastore.parser.ExpressionTree.generateJDOFilterFragment(ExpressionTree.java:590) at org.apache.hadoop.hive.metastore.ObjectStore.makeQueryFilterString(ObjectStore.java:2417) at org.apache.hadoop.hive.metastore.ObjectStore.getPartitionsViaOrmFilter(ObjectStore.java:2029) at org.apache.hadoop.hive.metastore.ObjectStore.access$500(ObjectStore.java:146) at org.apache.hadoop.hive.metastore.ObjectStore$4.getJdoResult(ObjectStore.java:2332) ``` https://amplab.cs.berkeley.edu/jenkins/view/Spark-QA-Test/job/Spark-Master-Maven-with-YARN/2945/HADOOP_PROFILE=hadoop-2.4,label=centos/testReport/junit/org.apache.spark.sql.hive.execution/SortMergeCompatibilitySuite/auto_sortmerge_join_16/ Author: Michael Armbrust Closes #7409 from marmbrus/revertMetastorePushdown and squashes the following commits: 92fabd3 [Michael Armbrust] Revert SPARK-6910 and SPARK-9027 5d3bdf2 [Michael Armbrust] Revert "[SPARK-9027] [SQL] Generalize metastore predicate pushdown" --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 58 +++++++------- .../org/apache/spark/sql/hive/HiveShim.scala | 1 - .../spark/sql/hive/HiveStrategies.scala | 4 +- .../sql/hive/client/ClientInterface.scala | 11 +-- .../spark/sql/hive/client/ClientWrapper.scala | 21 +++-- .../spark/sql/hive/client/HiveShim.scala | 72 +---------------- .../sql/hive/execution/HiveTableScan.scala | 7 +- .../spark/sql/hive/client/FiltersSuite.scala | 78 ------------------- .../spark/sql/hive/client/VersionsSuite.scala | 8 -- .../sql/hive/execution/PruningSuite.scala | 2 +- 10 files changed, 44 insertions(+), 218 deletions(-) delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5bdf68c83fca7..4b7a782c805a0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -301,9 +301,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) - // We're converting the entire table into ParquetRelation, so predicates to Hive metastore - // are empty. 
- val partitions = metastoreRelation.getHiveQlPartitions().map { p => + val partitions = metastoreRelation.hiveQlPartitions.map { p => val location = p.getLocation val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) @@ -646,6 +644,32 @@ private[hive] case class MetastoreRelation new Table(tTable) } + @transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p => + val tPartition = new org.apache.hadoop.hive.metastore.api.Partition + tPartition.setDbName(databaseName) + tPartition.setTableName(tableName) + tPartition.setValues(p.values) + + val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() + tPartition.setSd(sd) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + + sd.setLocation(p.storage.location) + sd.setInputFormat(p.storage.inputFormat) + sd.setOutputFormat(p.storage.outputFormat) + + val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo + sd.setSerdeInfo(serdeInfo) + serdeInfo.setSerializationLib(p.storage.serde) + + val serdeParameters = new java.util.HashMap[String, String]() + serdeInfo.setParameters(serdeParameters) + table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + + new Partition(hiveQlTable, tPartition) + } + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) @@ -666,34 +690,6 @@ private[hive] case class MetastoreRelation } ) - def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { - table.getPartitions(predicates).map { p => - val tPartition = new org.apache.hadoop.hive.metastore.api.Partition - tPartition.setDbName(databaseName) - tPartition.setTableName(tableName) - tPartition.setValues(p.values) - - val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() - tPartition.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) - - sd.setLocation(p.storage.location) - sd.setInputFormat(p.storage.inputFormat) - sd.setOutputFormat(p.storage.outputFormat) - - val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo - sd.setSerdeInfo(serdeInfo) - serdeInfo.setSerializationLib(p.storage.serde) - - val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) - table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } - p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } - - new Partition(hiveQlTable, tPartition) - } - } - /** Only compare database and tablename, not alias. 
*/ override def sameResult(plan: LogicalPlan): Boolean = { plan match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index a357bb39ca7fd..d08c594151654 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -27,7 +27,6 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input, Output} - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9638a8201e190..ed359620a5f7f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -125,7 +125,7 @@ private[hive] trait HiveStrategies { InterpretedPredicate.create(castedPredicate) } - val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part => + val partitions = relation.hiveQlPartitions.filter { part => val partitionValues = part.getValues var i = 0 while (i < partitionValues.size()) { @@ -213,7 +213,7 @@ private[hive] trait HiveStrategies { projectList, otherPredicates, identity[Seq[Expression]], - HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil + HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 1656587d14835..0a1d761a52f88 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -21,7 +21,6 @@ import java.io.PrintStream import java.util.{Map => JMap} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} -import org.apache.spark.sql.catalyst.expressions.Expression private[hive] case class HiveDatabase( name: String, @@ -72,12 +71,7 @@ private[hive] case class HiveTable( def isPartitioned: Boolean = partitionColumns.nonEmpty - def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = { - predicates match { - case Nil => client.getAllPartitions(this) - case _ => client.getPartitionsByFilter(this, predicates) - } - } + def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this) // Hive does not support backticks when passing names to the client. def qualifiedName: String = s"$database.$name" @@ -138,9 +132,6 @@ private[hive] trait ClientInterface { /** Returns all partitions for the given table. */ def getAllPartitions(hTable: HiveTable): Seq[HivePartition] - /** Returns partitions filtered by predicates for the given table. */ - def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition] - /** Loads a static partition into an existing table. 
*/ def loadPartition( loadPath: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 8adda54754230..53f457ad4f3cc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -17,21 +17,25 @@ package org.apache.spark.sql.hive.client -import java.io.{File, PrintStream} -import java.util.{Map => JMap} +import java.io.{BufferedReader, InputStreamReader, File, PrintStream} +import java.net.URI +import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConversions._ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.metastore.api.Database import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} import org.apache.hadoop.hive.metastore.{TableType => HTableType} +import org.apache.hadoop.hive.metastore.api +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.ql.metadata import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.ql.{Driver, metadata} +import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.Driver import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.Expression @@ -312,13 +316,6 @@ private[hive] class ClientWrapper( shim.getAllPartitions(client, qlTable).map(toHivePartition) } - override def getPartitionsByFilter( - hTable: HiveTable, - predicates: Seq[Expression]): Seq[HivePartition] = withHiveState { - val qlTable = toQlTable(hTable) - shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition) - } - override def listTables(dbName: String): Seq[String] = withHiveState { client.getAllTables(dbName) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index d12778c7583df..1fa9d278e2a57 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -31,11 +31,6 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.serde.serdeConstants - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, IntegralType} /** * A shim that defines the interface between ClientWrapper and the underlying Hive library used to @@ -66,8 +61,6 @@ private[client] sealed abstract class Shim { def getAllPartitions(hive: Hive, table: Table): Seq[Partition] - def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition] - def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor def getDriverResults(driver: Driver): Seq[String] @@ -116,7 +109,7 @@ private[client] sealed abstract class Shim { } -private[client] class Shim_v0_12 extends Shim with Logging { +private[client] class Shim_v0_12 extends 
Shim { private lazy val startMethod = findStaticMethod( @@ -203,17 +196,6 @@ private[client] class Shim_v0_12 extends Shim with Logging { override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq - override def getPartitionsByFilter( - hive: Hive, - table: Table, - predicates: Seq[Expression]): Seq[Partition] = { - // getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12. - // See HIVE-4888. - logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " + - "Please use Hive 0.13 or higher.") - getAllPartitions(hive, table) - } - override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor] @@ -285,12 +267,6 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { classOf[Hive], "getAllPartitionsOf", classOf[Table]) - private lazy val getPartitionsByFilterMethod = - findMethod( - classOf[Hive], - "getPartitionsByFilter", - classOf[Table], - classOf[String]) private lazy val getCommandProcessorMethod = findStaticMethod( classOf[CommandProcessorFactory], @@ -312,52 +288,6 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq - /** - * Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e. - * a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...". - * - * Unsupported predicates are skipped. - */ - def convertFilters(table: Table, filters: Seq[Expression]): String = { - // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. - val varcharKeys = table.getPartitionKeys - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) - .map(col => col.getName).toSet - - filters.collect { - case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => - s"${a.name} ${op.symbol} $v" - case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => - s"$v ${op.symbol} ${a.name}" - - case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) - if !varcharKeys.contains(a.name) => - s"""${a.name} ${op.symbol} "$v"""" - case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) - if !varcharKeys.contains(a.name) => - s""""$v" ${op.symbol} ${a.name}""" - }.mkString(" and ") - } - - override def getPartitionsByFilter( - hive: Hive, - table: Table, - predicates: Seq[Expression]): Seq[Partition] = { - - // Hive getPartitionsByFilter() takes a string that represents partition - // predicates like "str_key=\"value\" and int_key=1 ..." 
- val filter = convertFilters(table, predicates) - val partitions = - if (filter.isEmpty) { - getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] - } else { - logDebug(s"Hive metastore filter is '$filter'.") - getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]] - } - - partitions.toSeq - } - override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index ba7eb15a1c0c6..d33da8242cc1d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -44,7 +44,7 @@ private[hive] case class HiveTableScan( requestedAttributes: Seq[Attribute], relation: MetastoreRelation, - partitionPruningPred: Seq[Expression])( + partitionPruningPred: Option[Expression])( @transient val context: HiveContext) extends LeafNode { @@ -56,7 +56,7 @@ case class HiveTableScan( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. - private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => + private[this] val boundPruningPred = partitionPruningPred.map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") @@ -133,8 +133,7 @@ case class HiveTableScan( protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) } override def output: Seq[Attribute] = attributes diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala deleted file mode 100644 index 0efcf80bd4ea7..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive.client - -import scala.collection.JavaConversions._ - -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.serde.serdeConstants - -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -/** - * A set of tests for the filter conversion logic used when pushing partition pruning into the - * metastore - */ -class FiltersSuite extends SparkFunSuite with Logging { - private val shim = new Shim_v0_13 - - private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") - private val varCharCol = new FieldSchema() - varCharCol.setName("varchar") - varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME) - testTable.setPartCols(varCharCol :: Nil) - - filterTest("string filter", - (a("stringcol", StringType) > Literal("test")) :: Nil, - "stringcol > \"test\"") - - filterTest("string filter backwards", - (Literal("test") > a("stringcol", StringType)) :: Nil, - "\"test\" > stringcol") - - filterTest("int filter", - (a("intcol", IntegerType) === Literal(1)) :: Nil, - "intcol = 1") - - filterTest("int filter backwards", - (Literal(1) === a("intcol", IntegerType)) :: Nil, - "1 = intcol") - - filterTest("int and string filter", - (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil, - "1 = intcol and \"a\" = strcol") - - filterTest("skip varchar", - (Literal("") === a("varchar", StringType)) :: Nil, - "") - - private def filterTest(name: String, filters: Seq[Expression], result: String) = { - test(name){ - val converted = shim.convertFilters(testTable, filters) - if (converted != result) { - fail( - s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") - } - } - } - - private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 3eb127e23d486..d52e162acbd04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -20,9 +20,7 @@ package org.apache.spark.sql.hive.client import java.io.File import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils /** @@ -153,12 +151,6 @@ class VersionsSuite extends SparkFunSuite with Logging { client.getAllPartitions(client.getTable("default", "src_part")) } - test(s"$version: getPartitionsByFilter") { - client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo( - AttributeReference("key", IntegerType, false)(NamedExpression.newExprId), - Literal(1)))) - } - test(s"$version: loadPartition") { client.loadPartition( emptyDir, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index e83a7dc77e329..de6a41ce5bfcb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { case p @ HiveTableScan(columns, relation, _) => val columnNames = columns.map(_.name) val partValues = if (relation.table.isPartitioned) { - p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) + p.prunePartitions(relation.hiveQlPartitions).map(_.getValues) } else { Seq.empty } From 4692769655e09d129a62a89a8ffb5d635675aa4d Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Tue, 14 Jul 2015 23:27:42 -0700 Subject: [PATCH 05/27] [SPARK-6259] [MLLIB] Python API for LDA I implemented the Python API for LDA. But I didn't implement a method for `LDAModel.describeTopics()`, because it's a little hard to implement right now. Adding documentation about that and an example would fit better in another issue. TODO: LDAModel.describeTopics() in Python must also be implemented, but it would be better to handle that in another issue. Implementing it is a little hard, since the return value of `describeTopics` in Scala consists of Tuple classes. Author: Yu ISHIKAWA Closes #6791 from yu-iskw/SPARK-6259 and squashes the following commits: 6855f59 [Yu ISHIKAWA] LDA inherits object 28bd165 [Yu ISHIKAWA] Change the place of testing code d7a332a [Yu ISHIKAWA] Remove the doc comment about the optimizer's default value 083e226 [Yu ISHIKAWA] Add the comment about the supported values and the default value of `optimizer` 9f8bed8 [Yu ISHIKAWA] Simplify casting faa9764 [Yu ISHIKAWA] Add some comments for the LDA paramters 98f645a [Yu ISHIKAWA] Remove the interface for `describeTopics`. Because it is not implemented. 57ac03d [Yu ISHIKAWA] Remove the unnecessary import in Python unit testing 73412c3 [Yu ISHIKAWA] Fix the typo 2278829 [Yu ISHIKAWA] Fix the indentation 39514ec [Yu ISHIKAWA] Modify how to cast the input data 8117e18 [Yu ISHIKAWA] Fix the validation problems by `lint-scala` 77fd1b7 [Yu ISHIKAWA] Not use LabeledPoint 68f0653 [Yu ISHIKAWA] Support some parameters for `ALS.train()` in Python 25ef2ac [Yu ISHIKAWA] Resolve conflicts with rebasing --- .../mllib/api/python/PythonMLLibAPI.scala | 33 ++++++++++ python/pyspark/mllib/clustering.py | 66 ++++++++++++++++++- 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e628059c4af8e..c58a64001d9a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -502,6 +502,39 @@ private[python] class PythonMLLibAPI extends Serializable { new MatrixFactorizationModelWrapper(model) } + /** + * Java stub for Python mllib LDA.run() + */ + def trainLDAModel( + data: JavaRDD[java.util.List[Any]], + k: Int, + maxIterations: Int, + docConcentration: Double, + topicConcentration: Double, + seed: java.lang.Long, + checkpointInterval: Int, + optimizer: String): LDAModel = { + val algo = new LDA() + .setK(k) + .setMaxIterations(maxIterations) + .setDocConcentration(docConcentration) + .setTopicConcentration(topicConcentration) + .setCheckpointInterval(checkpointInterval) + .setOptimizer(optimizer) + + if (seed != null) algo.setSeed(seed) + + val documents = data.rdd.map(_.asScala.toArray).map { r => + r(0) match { + case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector]) + case i: java.lang.Long => (i.toLong, 
r(1).asInstanceOf[Vector]) + case _ => throw new IllegalArgumentException("input values contains invalid type value.") + } + } + algo.run(documents) + } + + /** * Java stub for Python mllib FPGrowth.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index ed4d78a2c6788..8a92f6911c24b 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -31,13 +31,15 @@ from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector +from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.stat.distribution import MultivariateGaussian from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable from pyspark.streaming import DStream __all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture', 'PowerIterationClusteringModel', 'PowerIterationClustering', - 'StreamingKMeans', 'StreamingKMeansModel'] + 'StreamingKMeans', 'StreamingKMeansModel', + 'LDA', 'LDAModel'] @inherit_doc @@ -563,6 +565,68 @@ def predictOnValues(self, dstream): return dstream.mapValues(lambda x: self._model.predict(x)) +class LDAModel(JavaModelWrapper): + + """ A clustering model derived from the LDA method. + + Latent Dirichlet Allocation (LDA), a topic model designed for text documents. + Terminology + - "word" = "term": an element of the vocabulary + - "token": instance of a term appearing in a document + - "topic": multinomial distribution over words representing some concept + References: + - Original LDA paper (journal version): + Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + + >>> from pyspark.mllib.linalg import Vectors + >>> from numpy.testing import assert_almost_equal + >>> data = [ + ... [1, Vectors.dense([0.0, 1.0])], + ... [2, SparseVector(2, {0: 1.0})], + ... ] + >>> rdd = sc.parallelize(data) + >>> model = LDA.train(rdd, k=2) + >>> model.vocabSize() + 2 + >>> topics = model.topicsMatrix() + >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]]) + >>> assert_almost_equal(topics, topics_expect, 1) + """ + + def topicsMatrix(self): + """Inferred topics, where each topic is represented by a distribution over terms.""" + return self.call("topicsMatrix").toArray() + + def vocabSize(self): + """Vocabulary size (number of terms or terms in the vocabulary)""" + return self.call("vocabSize") + + +class LDA(object): + + @classmethod + def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, + topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"): + """Train a LDA model. + + :param rdd: RDD of data points + :param k: Number of clusters you want + :param maxIterations: Number of iterations. Default to 20 + :param docConcentration: Concentration parameter (commonly named "alpha") + for the prior placed on documents' distributions over topics ("theta"). + :param topicConcentration: Concentration parameter (commonly named "beta" or "eta") + for the prior placed on topics' distributions over terms. + :param seed: Random Seed + :param checkpointInterval: Period (in iterations) between checkpoints. + :param optimizer: LDAOptimizer used to perform the actual calculation. + Currently "em", "online" are supported. Default to "em". 
+ """ + model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations, + docConcentration, topicConcentration, seed, + checkpointInterval, optimizer) + return LDAModel(model) + + def _test(): import doctest import pyspark.mllib.clustering From 3f6296fed4ee10f53e728eb1e02f13338839b94d Mon Sep 17 00:00:00 2001 From: FlytxtRnD Date: Tue, 14 Jul 2015 23:29:02 -0700 Subject: [PATCH 06/27] [SPARK-8018] [MLLIB] KMeans should accept initial cluster centers as param This allows Kmeans to be initialized using an existing set of cluster centers provided as a KMeansModel object. This mode of initialization performs a single run. Author: FlytxtRnD Closes #6737 from FlytxtRnD/Kmeans-8018 and squashes the following commits: 94b56df [FlytxtRnD] style correction ef95ee2 [FlytxtRnD] style correction c446c58 [FlytxtRnD] documentation and numRuns warning change 06d13ef [FlytxtRnD] numRuns corrected d12336e [FlytxtRnD] numRuns variable modifications 07f8554 [FlytxtRnD] remove setRuns from setIntialModel e721dfe [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 242ead1 [FlytxtRnD] corrected == to === in assert 714acb5 [FlytxtRnD] added numRuns 60c8ce2 [FlytxtRnD] ignore runs parameter and initialModel test suite changed 582e6d9 [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 3f5fc8e [FlytxtRnD] test case modified and one runs condition added cd5dc5c [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 16f1b53 [FlytxtRnD] Merge branch 'Kmeans-8018', remote-tracking branch 'upstream/master' into Kmeans-8018 e9c35d7 [FlytxtRnD] Remove getInitialModel and match cluster count criteria 6959861 [FlytxtRnD] Accept initial cluster centers in KMeans --- docs/mllib-clustering.md | 1 + .../spark/mllib/clustering/KMeans.scala | 41 ++++++++++++++++--- .../spark/mllib/clustering/KMeansSuite.scala | 22 ++++++++++ 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index d72dc20a5ad6e..0fc7036bffeaf 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. +* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed. **Examples** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 0f8d6a399682d..68297130a7b03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -156,6 +156,21 @@ class KMeans private ( this } + // Initial cluster centers can be provided as a KMeansModel object rather than using the + // random or k-means|| initializationMode + private var initialModel: Option[KMeansModel] = None + + /** + * Set the initial starting point, bypassing the random initialization or k-means|| + * The condition model.k == this.k must be met, failure results + * in an IllegalArgumentException. 
+ */ + def setInitialModel(model: KMeansModel): this.type = { + require(model.k == k, "mismatched cluster count") + initialModel = Some(model) + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. @@ -193,20 +208,34 @@ class KMeans private ( val initStartTime = System.nanoTime() - val centers = if (initializationMode == KMeans.RANDOM) { - initRandom(data) + // Only one run is allowed when initialModel is given + val numRuns = if (initialModel.nonEmpty) { + if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.") + 1 } else { - initKMeansParallel(data) + runs } + val centers = initialModel match { + case Some(kMeansCenters) => { + Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) + } + case None => { + if (initializationMode == KMeans.RANDOM) { + initRandom(data) + } else { + initKMeansParallel(data) + } + } + } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + " seconds.") - val active = Array.fill(runs)(true) - val costs = Array.fill(runs)(0.0) + val active = Array.fill(numRuns)(true) + val costs = Array.fill(numRuns)(0.0) - var activeRuns = new ArrayBuffer[Int] ++ (0 until runs) + var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns) var iteration = 0 val iterationStartTime = System.nanoTime() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 0dbbd7127444f..3003c62d9876c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("Initialize using given cluster centers") { + val points = Seq( + Vectors.dense(0.0, 0.0), + Vectors.dense(1.0, 0.0), + Vectors.dense(0.0, 1.0), + Vectors.dense(1.0, 1.0) + ) + val rdd = sc.parallelize(points, 3) + // creating an initial model + val initialModel = new KMeansModel(Array(points(0), points(2))) + + val returnModel = new KMeans() + .setK(2) + .setMaxIterations(0) + .setInitialModel(initialModel) + .run(rdd) + // comparing the returned model and the initial model + assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0)) + assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1)) + } + } object KMeansSuite extends SparkFunSuite { From f0e129740dc2442a21dfa7fbd97360df87291095 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 14 Jul 2015 23:30:41 -0700 Subject: [PATCH 07/27] [SPARK-8279][SQL]Add math function round JIRA: https://issues.apache.org/jira/browse/SPARK-8279 Author: Yijie Shen Closes #6938 from yijieshen/udf_round_3 and squashes the following commits: 07a124c [Yijie Shen] remove useless def children 392b65b [Yijie Shen] add negative scale test in DecimalSuite 61760ee [Yijie Shen] address reviews 302a78a [Yijie Shen] Add dataframe function test 31dfe7c [Yijie Shen] refactor round to make it readable 8c7a949 [Yijie Shen] rebase & inputTypes update 9555e35 [Yijie Shen] tiny style fix d10be4a [Yijie Shen] use TypeCollection to specify wanted input and implicit cast c3b9839 [Yijie Shen] rely on implict cast to handle string input b0bff79 [Yijie Shen] make round's inner method's name more meaningful 9bd6930 [Yijie Shen] revert 
accidental change e6f44c4 [Yijie Shen] refactor eval and genCode 1b87540 [Yijie Shen] modify checkInputDataTypes using foldable 5486b2d [Yijie Shen] DataFrame API modification 2077888 [Yijie Shen] codegen versioned eval 6cd9a64 [Yijie Shen] refactor Round's constructor 9be894e [Yijie Shen] add round functions in o.a.s.sql.functions 7c83e13 [Yijie Shen] more tests on round 56db4bb [Yijie Shen] Add decimal support to Round 7e163ae [Yijie Shen] style fix 653d047 [Yijie Shen] Add math function round --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 203 +++++++++++++++++- .../ExpressionTypeCheckingSuite.scala | 17 ++ .../expressions/MathFunctionsSuite.scala | 44 ++++ .../sql/types/decimal/DecimalSuite.scala | 23 +- .../org/apache/spark/sql/functions.scala | 32 +++ .../spark/sql/MathExpressionsSuite.scala | 15 ++ .../execution/HiveCompatibilitySuite.scala | 7 +- 8 files changed, 329 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6b1a94e4b2ad4..ec75f51d5e4ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -117,6 +117,7 @@ object FunctionRegistry { expression[Pow]("power"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), + expression[Round]("round"), expression[ShiftLeft]("shiftleft"), expression[ShiftRight]("shiftright"), expression[ShiftRightUnsigned]("shiftrightunsigned"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 4b7fe05dd4980..a7ad452ef4943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: Expression) """ } } + +/** + * Round the `child`'s result to `scale` decimal place when `scale` >= 0 + * or round at integral part when `scale` < 0. + * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. + * + * Child of IntegralType would eval to itself when `scale` >= 0. + * Child of FractionalType whose value is NaN or Infinite would always eval to itself. 
+ * + * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]], + * which leads to scale update in DecimalType's [[PrecisionInfo]] + * + * @param child expr to be round, all [[NumericType]] is allowed as Input + * @param scale new scale to be round to, this should be a constant int at runtime + */ +case class Round(child: Expression, scale: Expression) + extends BinaryExpression with ExpectsInputTypes { + + import BigDecimal.RoundingMode.HALF_UP + + def this(child: Expression) = this(child, Literal(0)) + + override def left: Expression = child + override def right: Expression = scale + + // round of Decimal would eval to null if it fails to `changePrecision` + override def nullable: Boolean = true + + override def foldable: Boolean = child.foldable + + override lazy val dataType: DataType = child.dataType match { + // if the new scale is bigger which means we are scaling up, + // keep the original scale as `Decimal` does + case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale) + case t => t + } + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckSuccess => + if (scale.foldable) { + TypeCheckSuccess + } else { + TypeCheckFailure("Only foldable Expression is allowed for scale arguments") + } + case f => f + } + } + + // Avoid repeated evaluation since `scale` is a constant int, + // avoid unnecessary `child` evaluation in both codegen and non-codegen eval + // by checking if scaleV == null as well. + private lazy val scaleV: Any = scale.eval(EmptyRow) + private lazy val _scale: Int = scaleV.asInstanceOf[Int] + + override def eval(input: InternalRow): Any = { + if (scaleV == null) { // if scale is null, no need to eval its child at all + null + } else { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + nullSafeEval(evalE) + } + } + } + + // not overriding since _scale is a constant int at runtime + def nullSafeEval(input1: Any): Any = { + child.dataType match { + case _: DecimalType => + val decimal = input1.asInstanceOf[Decimal] + if (decimal.changePrecision(decimal.precision, _scale)) decimal else null + case ByteType => + BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte + case ShortType => + BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort + case IntegerType => + BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt + case LongType => + BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong + case FloatType => + val f = input1.asInstanceOf[Float] + if (f.isNaN || f.isInfinite) { + f + } else { + BigDecimal(f).setScale(_scale, HALF_UP).toFloat + } + case DoubleType => + val d = input1.asInstanceOf[Double] + if (d.isNaN || d.isInfinite) { + d + } else { + BigDecimal(d).setScale(_scale, HALF_UP).toDouble + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val ce = child.gen(ctx) + + val evaluationCode = child.dataType match { + case _: DecimalType => + s""" + if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.isNull} = true; + }""" + case ByteType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). 
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case ShortType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case IntegerType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case LongType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case FloatType => // if child eval to NaN or Infinity, just return it. + if (_scale == 0) { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + }""" + } + case DoubleType => // if child eval to NaN or Infinity, just return it. + if (_scale == 0) { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). 
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + }""" + } + } + + if (scaleV == null) { // if scale is null, no need to eval its child at all + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + s""" + ${ce.code} + boolean ${ev.isNull} = ${ce.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + $evaluationCode + } + """ + } + } + + override def prettyName: String = "round" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 5958acbe009ca..e885a18254ea0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -52,6 +52,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { s"differing types in '${expr.prettyString}' (int and boolean)") } + def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + assertSuccess(expr) + } + assert(e.getMessage.contains(errorMessage)) + } + test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "operator - accepts numeric type") assertError(Abs('stringField), "function abs accepts numeric type") @@ -171,4 +178,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), "Odd position only allow foldable and not-null StringType expressions") } + + test("check types for ROUND") { + assertErrorWithImplicitCast(Round(Literal(null), 'booleanField), + "data type mismatch: argument 2 is expected to be of type int") + assertErrorWithImplicitCast(Round(Literal(null), 'complexField), + "data type mismatch: argument 2 is expected to be of type int") + assertSuccess(Round(Literal(null), Literal(null))) + assertError(Round('booleanField, 'intField), + "data type mismatch: argument 1 is expected to be of type numeric") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 7ca9e30b2bcd5..52a874a9d89ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.math.BigDecimal.RoundingMode + import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite @@ -336,4 +338,46 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { null, create_row(null)) } + + test("round") { + val domain = -6 to 6 + val doublePi: Double = math.Pi + val shortPi: Short = 31415 + val intPi: Int = 314159265 + val longPi: Long = 31415926535897932L + val bdPi: BigDecimal = BigDecimal(31415927L, 7) + + val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, + 3.1416, 3.14159, 3.141593) + + val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ + Seq.fill[Short](7)(31415) + + val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, + 
314159270) ++ Seq.fill(7)(314159265) + + val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L, + 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ + Seq.fill(7)(31415926535897932L) + + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) + + domain.zipWithIndex.foreach { case (scale, i) => + checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) + checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) + checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) + checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) + } + + // round_scale > current_scale would result in precision increase + // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null + (0 to 7).foreach { i => + checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) + } + (8 to 10).foreach { scale => + checkEvaluation(Round(bdPi, scale), null, EmptyRow) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index 030bb6d21b18b..f0c849d1a1564 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -24,14 +24,14 @@ import org.scalatest.PrivateMethodTester import scala.language.postfixOps class DecimalSuite extends SparkFunSuite with PrivateMethodTester { - test("creating decimals") { - /** Check that a Decimal has the given string representation, precision and scale */ - def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { - assert(d.toString === string) - assert(d.precision === precision) - assert(d.scale === scale) - } + /** Check that a Decimal has the given string representation, precision and scale */ + private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { + assert(d.toString === string) + assert(d.precision === precision) + assert(d.scale === scale) + } + test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) @@ -53,6 +53,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) } + test("creating decimals with negative scale") { + checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3) + checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2) + checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9) + checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10) + checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10) + checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10) + } + test("double and long values") { /** Check that a Decimal converts to the given double and long values */ def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d4e160ed8057..5119ee31d852d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1389,6 +1389,38 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Returns the value of the column `e` rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column): Column = round(e.expr, 0) + + /** + * Returns the value of the given column rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String): Column = round(Column(columnName), 0) + + /** + * Returns the value of `e` rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) + + /** + * Returns the value of the given column rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String, scale: Int): Column = round(Column(columnName), scale) + /** * Shift the the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index b30b9f12258b9..087126bb2e513 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -198,6 +198,21 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(rint, math.rint) } + test("round") { + val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") + checkAnswer( + df.select(round('a), round('a, -1), round('a, -2)), + Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) + + val pi = 3.1415 + checkAnswer( + ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), + Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142)) + ) + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index c884c399281a8..4ada64bc21966 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -221,9 +221,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_when", "udf_case", - // Needs constant object inspectors - "udf_round", - // the table src(key INT, value STRING) is not the same as HIVE unittest. In Hive // is src(key STRING, value STRING), and in the reflect.q, it failed in // Integer.valueOf, which expect the first argument passed as STRING type not INT. 
@@ -918,8 +915,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_regexp_replace", "udf_repeat", "udf_rlike", - "udf_round", - // "udf_round_3", TODO: FIX THIS failed due to cast exception + // "udf_round", turn this on after we figure out null vs nan vs infinity + "udf_round_3", "udf_rpad", "udf_rtrim", "udf_second", From 1bb8accbc95a0f0856a8bb715f1e94c3ff96a8c7 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 14 Jul 2015 23:50:57 -0700 Subject: [PATCH 08/27] [SPARK-8997] [MLLIB] Performance improvements in LocalPrefixSpan Improves the performance of LocalPrefixSpan by implementing optimizations proposed in [SPARK-8997](https://issues.apache.org/jira/browse/SPARK-8997) Author: Feynman Liang Author: Feynman Liang Author: Xiangrui Meng Closes #7360 from feynmanliang/SPARK-8997-improve-prefixspan and squashes the following commits: 59db2f5 [Feynman Liang] Merge pull request #1 from mengxr/SPARK-8997 91e4357 [Xiangrui Meng] update LocalPrefixSpan impl 9212256 [Feynman Liang] MengXR code review comments f055d82 [Feynman Liang] Fix failing scalatest 2e00cba [Feynman Liang] Depth first projections 70b93e3 [Feynman Liang] Performance improvements in LocalPrefixSpan, fix tests --- .../spark/mllib/fpm/LocalPrefixSpan.scala | 95 ++++++++----------- .../apache/spark/mllib/fpm/PrefixSpan.scala | 5 +- .../spark/mllib/fpm/PrefixSpanSuite.scala | 14 +-- 3 files changed, 44 insertions(+), 70 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 39c48b084e550..7ead6327486cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -17,58 +17,49 @@ package org.apache.spark.mllib.fpm +import scala.collection.mutable + import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental /** - * - * :: Experimental :: - * * Calculate all patterns of a projected database in local. */ -@Experimental private[fpm] object LocalPrefixSpan extends Logging with Serializable { /** * Calculate all patterns of a projected database. * @param minCount minimum count * @param maxPatternLength maximum pattern length - * @param prefix prefix - * @param projectedDatabase the projected dabase + * @param prefixes prefixes in reversed order + * @param database the projected database * @return a set of sequential pattern pairs, - * the key of pair is sequential pattern (a list of items), + * the key of pair is sequential pattern (a list of items in reversed order), * the value of pair is the pattern's count. 
*/ def run( minCount: Long, maxPatternLength: Int, - prefix: Array[Int], - projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = { - val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase) - val frequentPatternAndCounts = frequentPrefixAndCounts - .map(x => (prefix ++ Array(x._1), x._2)) - val prefixProjectedDatabases = getPatternAndProjectedDatabase( - prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase) - - val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength - if (continueProcess) { - val nextPatterns = prefixProjectedDatabases - .map(x => run(minCount, maxPatternLength, x._1, x._2)) - .reduce(_ ++ _) - frequentPatternAndCounts ++ nextPatterns - } else { - frequentPatternAndCounts + prefixes: List[Int], + database: Array[Array[Int]]): Iterator[(List[Int], Long)] = { + if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty + val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) + val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains)) + frequentItemAndCounts.iterator.flatMap { case (item, count) => + val newPrefixes = item :: prefixes + val newProjected = project(filteredDatabase, item) + Iterator.single((newPrefixes, count)) ++ + run(minCount, maxPatternLength, newPrefixes, newProjected) } } /** - * calculate suffix sequence following a prefix in a sequence - * @param prefix prefix - * @param sequence sequence + * Calculate suffix sequence immediately after the first occurrence of an item. + * @param item item to get suffix after + * @param sequence sequence to extract suffix from * @return suffix sequence */ - def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = { - val index = sequence.indexOf(prefix) + def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = { + val index = sequence.indexOf(item) if (index == -1) { Array() } else { @@ -76,38 +67,28 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { } } + def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = { + database + .map(getSuffix(prefix, _)) + .filter(_.nonEmpty) + } + /** * Generates frequent items by filtering the input data using minimal count level. - * @param minCount the absolute minimum count - * @param sequences sequences data - * @return array of item and count pair + * @param minCount the minimum count for an item to be frequent + * @param database database of sequences + * @return freq item to count map */ private def getFreqItemAndCounts( minCount: Long, - sequences: Array[Array[Int]]): Array[(Int, Long)] = { - sequences.flatMap(_.distinct) - .groupBy(x => x) - .mapValues(_.length.toLong) - .filter(_._2 >= minCount) - .toArray - } - - /** - * Get the frequent prefixes' projected database. 
- * @param prePrefix the frequent prefixes' prefix - * @param frequentPrefixes frequent prefixes - * @param sequences sequences data - * @return prefixes and projected database - */ - private def getPatternAndProjectedDatabase( - prePrefix: Array[Int], - frequentPrefixes: Array[Int], - sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = { - val filteredProjectedDatabase = sequences - .map(x => x.filter(frequentPrefixes.contains(_))) - frequentPrefixes.map { x => - val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty) - (prePrefix ++ Array(x), sub) - }.filter(x => x._2.nonEmpty) + database: Array[Array[Int]]): mutable.Map[Int, Long] = { + // TODO: use PrimitiveKeyOpenHashMap + val counts = mutable.Map[Int, Long]().withDefaultValue(0L) + database.foreach { sequence => + sequence.distinct.foreach { item => + counts(item) += 1L + } + } + counts.filter(_._2 >= minCount) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 9d8c60ef0fc45..6f52db7b073ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -150,8 +150,9 @@ class PrefixSpan private ( private def getPatternsInLocal( minCount: Long, data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { - data.flatMap { x => - LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) + data.flatMap { case (prefix, projDB) => + LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB) + .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) } } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 413436d3db85f..9f107c89f6d80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.rdd.RDD -class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { +class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { test("PrefixSpan using Integer type") { @@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { def compareResult( expectedValue: Array[(Array[Int], Long)], actualValue: Array[(Array[Int], Long)]): Boolean = { - val sortedExpectedValue = expectedValue.sortWith{ (x, y) => - x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 - } - val sortedActualValue = actualValue.sortWith{ (x, y) => - x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 - } - sortedExpectedValue.zip(sortedActualValue) - .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2) - .reduce(_&&_) + expectedValue.map(x => (x._1.toSeq, x._2)).toSet == + actualValue.map(x => (x._1.toSeq, x._2)).toSet } val prefixspan = new PrefixSpan() From 14935d846a4f6bcd4d2a448a8f112fa5dee769ba Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 15 Jul 2015 00:12:21 -0700 Subject: [PATCH 09/27] [HOTFIX][SQL] Unit test breaking. 
--- .../sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index e885a18254ea0..a4ce1825cab28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -60,9 +60,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "operator - accepts numeric type") - assertError(Abs('stringField), "function abs accepts numeric type") - assertError(BitwiseNot('stringField), "operator ~ accepts integral type") + assertError(UnaryMinus('stringField), "expected to be of type numeric") + assertError(Abs('stringField), "expected to be of type numeric") + assertError(BitwiseNot('stringField), "type (boolean or tinyint or smallint or int or bigint)") } ignore("check types for binary arithmetic") { From adb33d3665770daf2ccb8915d19e198be9dc3b47 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 15 Jul 2015 17:30:57 +0900 Subject: [PATCH 10/27] [SPARK-9012] [WEBUI] Escape Accumulators in the task table If running the following codes, the task table will be broken because accumulators aren't escaped. ``` val a = sc.accumulator(1, "") sc.parallelize(1 to 10).foreach(i => a += i) ``` Before this fix, screen shot 2015-07-13 at 8 02 44 pm After this fix, screen shot 2015-07-13 at 8 14 32 pm Author: zsxwing Closes #7369 from zsxwing/SPARK-9012 and squashes the following commits: a83c9b6 [zsxwing] Escape Accumulators in the task table --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ff0a339a39c65..27b82aaddd2e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -692,7 +692,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val gettingResultTime = getGettingResultTime(info, currentTime) val maybeAccumulators = info.accumulables - val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} + val accumulatorsReadable = maybeAccumulators.map { acc => + StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") + } val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("") From 20bb10f8644a92a57496b5df639008832b30e34d Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 15 Jul 2015 08:25:53 -0700 Subject: [PATCH 11/27] [SPARK-8706] [PYSPARK] [PROJECT INFRA] Add pylint checks to PySpark This adds Pylint checks to PySpark. For now this lazy installs using easy_install to /dev/pylint (similar to the pep8 script). We still need to figure out what rules to be allowed. 
Author: MechCoder Closes #7241 from MechCoder/pylint and squashes the following commits: 2fc7291 [MechCoder] Remove pylint test fail 6d883a2 [MechCoder] Silence warnings and make pylint tests fail to check if it works in jenkins f3a5e17 [MechCoder] undefined-variable ca8b749 [MechCoder] Minor changes 71629f8 [MechCoder] remove trailing whitespace 8498ff9 [MechCoder] Remove blacklisted arguments and pointless statements check 1dbd094 [MechCoder] Disable all checks for now 8b8aa8a [MechCoder] Add pylint configuration file 7871bb1 [MechCoder] [SPARK-8706] [PySpark] [Project infra] Add pylint checks to PySpark --- dev/lint-python | 57 ++++- pylintrc | 404 ++++++++++++++++++++++++++++++ python/pyspark/ml/param/shared.py | 4 +- python/pyspark/tests.py | 3 +- 4 files changed, 457 insertions(+), 11 deletions(-) create mode 100644 pylintrc diff --git a/dev/lint-python b/dev/lint-python index 0c3586462cb37..e02dff220eb87 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -21,12 +21,14 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py" -PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt" +PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" +PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" +PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" cd "$SPARK_ROOT_DIR" # compileall: https://docs.python.org/2/library/compileall.html -python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYTHON_LINT_REPORT_PATH" +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PEP8_REPORT_PATH" compile_status="${PIPESTATUS[0]}" # Get pep8 at runtime so that we don't rely on it being installed on the build server. @@ -47,11 +49,36 @@ if [ ! -e "$PEP8_SCRIPT_PATH" ]; then fi fi +# Easy install pylint in /dev/pylint. To easy_install into a directory, the PYTHONPATH should +# be set to the directory. +# dev/pylint should be appended to the PATH variable as well. +# Jenkins by default installs the pylint3 version, so for now this just checks the code quality +# of python3. +export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" +export "PYLINT_HOME=$PYTHONPATH" +export "PATH=$PYTHONPATH:$PATH" + +if [ ! -d "$PYLINT_HOME" ]; then + mkdir "$PYLINT_HOME" + # Redirect the annoying pylint installation output. + easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" + easy_install_status="$?" + + if [ "$easy_install_status" -ne 0 ]; then + echo "Unable to install pylint locally in \"$PYTHONPATH\"." + cat "$PYLINT_INSTALL_INFO" + exit "$easy_install_status" + fi + + rm "$PYLINT_INSTALL_INFO" + +fi + # There is no need to write this output to a file #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PYTHON_LINT_REPORT_PATH" +python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" pep8_status="${PIPESTATUS[0]}" if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then @@ -61,13 +88,27 @@ else fi if [ "$lint_status" -ne 0 ]; then - echo "Python lint checks failed." - cat "$PYTHON_LINT_REPORT_PATH" + echo "PEP8 checks failed." + cat "$PEP8_REPORT_PATH" +else + echo "PEP8 checks passed." 
+fi + +rm "$PEP8_REPORT_PATH" + +for to_be_checked in "$PATHS_TO_CHECK" +do + pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" +done + +if [ "${PIPESTATUS[0]}" -ne 0 ]; then + lint_status=1 + echo "Pylint checks failed." + cat "$PYLINT_REPORT_PATH" else - echo "Python lint checks passed." + echo "Pylint checks passed." fi -# rm "$PEP8_SCRIPT_PATH" -rm "$PYTHON_LINT_REPORT_PATH" +rm "$PYLINT_REPORT_PATH" exit "$lint_status" diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000000000..061775960393b --- /dev/null +++ b/pylintrc @@ -0,0 +1,404 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Profiled execution. +profile=no + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=pyspark.heapq3 + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=1 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code +extension-pkg-whitelist= + +# Allow optimization of some AST trees. This will activate a peephole AST +# optimizer, which will apply various small optimizations. For instance, it can +# be used to obtain the result of joining multiple strings with the addition +# operator. Joining a lot of strings can lead to a maximum recursion error in +# Pylint and this flag can prevent that. It has one side effect, the resulting +# AST will be different than the one from reality. +optimize-ast=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. See also the "--disable" option for examples. +enable= + +# Disable the message, report, category or checker with the given id(s). 
You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" + +# These errors are arranged in order of number of warning given in pylint. +# If you would like to improve the code quality of pyspark, remove any of these disabled errors +# run ./dev/lint-python and see if the errors raised by pylint can be fixed. + +disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Add a comment according to your evaluation note. This is used by the global +# evaluation report (RP0004). +comment=no + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. 
+notes=FIXME,XXX,TODO + + +[BASIC] + +# Required attributes for module, separated by a comma +required-attributes= + +# List of builtins function names that should not be used, separated by a comma +bad-functions= + +# Good variable names which should always be accepted, separated by a comma +good-names=i,j,k,ex,Run,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names=baz,toto,tutu,tata + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# Regular expression matching correct function names +function-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for function names +function-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct variable names +variable-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for variable names +variable-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct constant names +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Naming hint for constant names +const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Regular expression matching correct attribute names +attr-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for attribute names +attr-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct argument names +argument-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for argument names +argument-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Naming hint for class attribute names +class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ + +# Naming hint for inline iteration names +inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=[A-Z_][a-zA-Z0-9]+$ + +# Naming hint for class names +class-name-hint=[A-Z_][a-zA-Z0-9]+$ + +# Regular expression matching correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Naming hint for module names +module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression matching correct method names +method-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for method names +method-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=__.*__ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=100 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + +# List of optional constructs for which whitespace checking is disabled +no-space-check=trailing-comma,dict-separator + +# Maximum number of lines in a module +max-module-lines=1000 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 
+expected-line-ending-format= + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=_$|dummy + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis +ignored-modules= + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). +ignored-classes=SQLObject + +# When zope mode is activated, add a predefined set of Zope acquired attributes +# to generated-members. +zope=no + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E0201 when accessed. Python regular +# expressions are accepted. +generated-members=REQUEST,acl_users,aq_parent + + +[CLASSES] + +# List of interface methods to ignore, separated by a comma. This is used for +# instance to not check methods defines in Zope's Interface base class. +ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + +# List of member names, which should be excluded from the protected access +# warning. 
+exclude-protected=_asdict,_fields,_replace,_source,_make + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=5 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=Exception diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index bc088e4c29e26..595124726366d 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -444,7 +444,7 @@ class DecisionTreeParams(Params): minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def __init__(self): super(DecisionTreeParams, self).__init__() @@ -460,7 +460,7 @@ def __init__(self): self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. 
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c5c0add49d02c..21225016805bc 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -893,7 +893,8 @@ def test_pipe_functions(self): self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) result = rdd.pipe('cat').collect() result.sort() - [self.assertEqual(x, y) for x, y in zip(data, result)] + for x, y in zip(data, result): + self.assertEqual(x, y) self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) self.assertEqual([], rdd.pipe('grep 4').collect()) From 6f6902597d5d687049c103bc0cf6da30919b92d8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Jul 2015 09:48:33 -0700 Subject: [PATCH 12/27] [SPARK-8840] [SPARKR] Add float coercion on SparkR JIRA: https://issues.apache.org/jira/browse/SPARK-8840 Currently the type coercion rules don't include float type. This PR simply adds it. Author: Liang-Chi Hsieh Closes #7280 from viirya/add_r_float_coercion and squashes the following commits: c86dc0e [Liang-Chi Hsieh] For comments. dbf0c1b [Liang-Chi Hsieh] Implicitly convert Double to Float based on provided schema. 733015a [Liang-Chi Hsieh] Add test case for DataFrame with float type. 30c2a40 [Liang-Chi Hsieh] Update test case. 52b5294 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into add_r_float_coercion 6f9159d [Liang-Chi Hsieh] Add another test case. 8db3244 [Liang-Chi Hsieh] schema also needs to support float. add test case. 0dcc992 [Liang-Chi Hsieh] Add float coercion on SparkR. --- R/pkg/R/deserialize.R | 1 + R/pkg/R/schema.R | 1 + R/pkg/inst/tests/test_sparkSQL.R | 26 +++++++++++++++++++ .../scala/org/apache/spark/api/r/SerDe.scala | 4 +++ .../org/apache/spark/sql/api/r/SQLUtils.scala | 15 ++++++++--- 5 files changed, 44 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index d961bbc383688..7d1f6b0819ed0 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -23,6 +23,7 @@ # Int -> integer # String -> character # Boolean -> logical +# Float -> double # Double -> double # Long -> double # Array[Byte] -> raw diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 15e2bdbd55d79..06df430687682 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -123,6 +123,7 @@ structField.character <- function(x, type, nullable = TRUE) { } options <- c("byte", "integer", + "float", "double", "numeric", "character", diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index b0ea38854304e..76f74f80834a9 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -108,6 +108,32 @@ test_that("create DataFrame from RDD", { expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- jsonFile(sqlContext, jsonPathNa) + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + insertInto(df, "people") + expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) + expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + df2 <- createDataFrame(sqlContext, df.toRDD, schema) + expect_equal(columns(df2), c("name", "age", 
"height")) + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) + + localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7)) + df <- createDataFrame(sqlContext, localDF, schema) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + expect_equal(columns(df), c("name", "age", "height")) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) }) test_that("convert NAs to null type in DataFrames", { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 56adc857d4ce0..d5b4260bf4529 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -179,6 +179,7 @@ private[spark] object SerDe { // Int -> integer // String -> character // Boolean -> logical + // Float -> double // Double -> double // Long -> double // Array[Byte] -> raw @@ -215,6 +216,9 @@ private[spark] object SerDe { case "long" | "java.lang.Long" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "float" | "java.lang.Float" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Float].toDouble) case "double" | "java.lang.Double" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Double]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 43b62f0e822f8..92861ab038f19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -47,6 +47,7 @@ private[r] object SQLUtils { dataType match { case "byte" => org.apache.spark.sql.types.ByteType case "integer" => org.apache.spark.sql.types.IntegerType + case "float" => org.apache.spark.sql.types.FloatType case "double" => org.apache.spark.sql.types.DoubleType case "numeric" => org.apache.spark.sql.types.DoubleType case "character" => org.apache.spark.sql.types.StringType @@ -68,7 +69,7 @@ private[r] object SQLUtils { def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { val num = schema.fields.size - val rowRDD = rdd.map(bytesToRow) + val rowRDD = rdd.map(bytesToRow(_, schema)) sqlContext.createDataFrame(rowRDD, schema) } @@ -76,12 +77,20 @@ private[r] object SQLUtils { df.map(r => rowToRBytes(r)) } - private[this] def bytesToRow(bytes: Array[Byte]): Row = { + private[this] def doConversion(data: Object, dataType: DataType): Object = { + data match { + case d: java.lang.Double if dataType == FloatType => + new java.lang.Float(d) + case _ => data + } + } + + private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = { val bis = new ByteArrayInputStream(bytes) val dis = new DataInputStream(bis) val num = SerDe.readInt(dis) Row.fromSeq((0 until num).map { i => - SerDe.readObject(dis) + doConversion(SerDe.readObject(dis), schema.fields(i).dataType) }.toSeq) } From fa4ec3606a965238423f977808163983c9d56e0a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Jul 2015 10:31:39 -0700 Subject: [PATCH 13/27] [SPARK-9020][SQL] Support mutable state in code gen expressions We can keep expressions' mutable states in generated class(like `SpecificProjection`) as 
member variables, so that we can read and modify them inside codegened expressions. Author: Wenchen Fan Closes #7392 from cloud-fan/mutable-state and squashes the following commits: eb3a221 [Wenchen Fan] fix order 73144d8 [Wenchen Fan] naming improvement 318f41d [Wenchen Fan] address more comments d43b65d [Wenchen Fan] address comments fd45c7a [Wenchen Fan] Support mutable state in code gen expressions --- .../scala/org/apache/spark/TaskContext.scala | 15 ++- .../expressions/codegen/CodeGenerator.scala | 17 +++- .../codegen/GenerateMutableProjection.scala | 4 + .../codegen/GenerateOrdering.scala | 38 +++++-- .../codegen/GeneratePredicate.scala | 4 + .../codegen/GenerateProjection.scala | 99 ++++++++++--------- .../sql/catalyst/expressions/random.scala | 29 +++++- .../MonotonicallyIncreasingID.scala | 19 +++- .../expressions/SparkPartitionID.scala | 12 ++- 9 files changed, 171 insertions(+), 66 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index d09e17dea0911..248339148d9b7 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -32,7 +32,20 @@ object TaskContext { */ def get(): TaskContext = taskContext.get - private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] + /** + * Returns the partition id of currently active TaskContext. It will return 0 + * if there is no active TaskContext for cases like local execution. + */ + def getPartitionId(): Int = { + val tc = taskContext.get() + if (tc == null) { + 0 + } else { + tc.partitionId() + } + } + + private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] // Note: protected[spark] instead of private[spark] to prevent the following two from // showing up in JavaDoc. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9f6329bbda4ec..328d635de8743 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -56,6 +56,18 @@ class CodeGenContext { */ val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + /** + * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a + * 3-tuple: java type, variable name, code to init it. + * They will be kept as member variables in generated classes like `SpecificProjection`. + */ + val mutableStates: mutable.ArrayBuffer[(String, String, String)] = + mutable.ArrayBuffer.empty[(String, String, String)] + + def addMutableState(javaType: String, variableName: String, initialValue: String): Unit = { + mutableStates += ((javaType, variableName, initialValue)) + } + val stringType: String = classOf[UTF8String].getName val decimalType: String = classOf[Decimal].getName @@ -203,7 +215,10 @@ class CodeGenContext { def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) } - +/** + * A wrapper for generated class, defines a `generate` method so that we can pass extra objects + * into generated class. 
+ */ abstract class GeneratedClass { def generate(expressions: Array[Expression]): Any } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index addb8023d9c0b..71e47d4f9b620 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -46,6 +46,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") val code = s""" public Object generate($exprType[] expr) { return new SpecificProjection(expr); @@ -55,6 +58,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu private $exprType[] expressions = null; private $mutableRowType mutableRow = null; + $mutableStates public SpecificProjection($exprType[] expr) { expressions = expr; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index d05dfc108e63a..856ff9f1f96f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -46,30 +46,47 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { val ctx = newCodeGenContext() - val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = order.child.gen(ctx) - val evalB = order.child.gen(ctx) + val comparisons = ordering.map { order => + val eval = order.child.gen(ctx) val asc = order.direction == Ascending + val isNullA = ctx.freshName("isNullA") + val primitiveA = ctx.freshName("primitiveA") + val isNullB = ctx.freshName("isNullB") + val primitiveB = ctx.freshName("primitiveB") s""" i = a; - ${evalA.code} + boolean $isNullA; + ${ctx.javaType(order.child.dataType)} $primitiveA; + { + ${eval.code} + $isNullA = ${eval.isNull}; + $primitiveA = ${eval.primitive}; + } i = b; - ${evalB.code} - if (${evalA.isNull} && ${evalB.isNull}) { + boolean $isNullB; + ${ctx.javaType(order.child.dataType)} $primitiveB; + { + ${eval.code} + $isNullB = ${eval.isNull}; + $primitiveB = ${eval.primitive}; + } + if ($isNullA && $isNullB) { // Nothing - } else if (${evalA.isNull}) { + } else if ($isNullA) { return ${if (order.direction == Ascending) "-1" else "1"}; - } else if (${evalB.isNull}) { + } else if ($isNullB) { return ${if (order.direction == Ascending) "1" else "-1"}; } else { - int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)}; + int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; if (comp != 0) { return ${if (asc) "comp" else "-comp"}; } } """ }.mkString("\n") - + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") val code = s""" 
public SpecificOrdering generate($exprType[] expr) { return new SpecificOrdering(expr); @@ -78,6 +95,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR class SpecificOrdering extends ${classOf[BaseOrdering].getName} { private $exprType[] expressions = null; + $mutableStates public SpecificOrdering($exprType[] expr) { expressions = expr; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 274a42cb69087..9e5a745d512e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,6 +40,9 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.gen(ctx) + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") val code = s""" public SpecificPredicate generate($exprType[] expr) { return new SpecificPredicate(expr); @@ -47,6 +50,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool class SpecificPredicate extends ${classOf[Predicate].getName} { private final $exprType[] expressions; + $mutableStates public SpecificPredicate($exprType[] expr) { expressions = expr; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 3c7ee9cc16599..3e5ca308dc31d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -151,6 +151,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"""if (!nullBits[$i]) arr[$i] = c$i;""" }.mkString("\n ") + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") + val code = s""" public SpecificProjection generate($exprType[] expr) { return new SpecificProjection(expr); @@ -158,6 +162,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { class SpecificProjection extends ${classOf[BaseProject].getName} { private $exprType[] expressions = null; + $mutableStates public SpecificProjection($exprType[] expr) { expressions = expr; @@ -165,65 +170,65 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { @Override public Object apply(Object r) { - return new SpecificRow(expressions, (InternalRow) r); + return new SpecificRow((InternalRow) r); } - } - final class SpecificRow extends ${classOf[MutableRow].getName} { + final class SpecificRow extends ${classOf[MutableRow].getName} { - $columns + $columns - public SpecificRow($exprType[] expressions, InternalRow i) { - $initColumns - } + public SpecificRow(InternalRow i) { + $initColumns + } - public int length() { return ${expressions.length};} - protected boolean[] nullBits = new boolean[${expressions.length}]; - public void setNullAt(int i) { 
nullBits[i] = true; } - public boolean isNullAt(int i) { return nullBits[i]; } + public int length() { return ${expressions.length};} + protected boolean[] nullBits = new boolean[${expressions.length}]; + public void setNullAt(int i) { nullBits[i] = true; } + public boolean isNullAt(int i) { return nullBits[i]; } - public Object get(int i) { - if (isNullAt(i)) return null; - switch (i) { - $getCases + public Object get(int i) { + if (isNullAt(i)) return null; + switch (i) { + $getCases + } + return null; } - return null; - } - public void update(int i, Object value) { - if (value == null) { - setNullAt(i); - return; + public void update(int i, Object value) { + if (value == null) { + setNullAt(i); + return; + } + nullBits[i] = false; + switch (i) { + $updateCases + } } - nullBits[i] = false; - switch (i) { - $updateCases + $specificAccessorFunctions + $specificMutatorFunctions + + @Override + public int hashCode() { + int result = 37; + $hashUpdates + return result; } - } - $specificAccessorFunctions - $specificMutatorFunctions - - @Override - public int hashCode() { - int result = 37; - $hashUpdates - return result; - } - @Override - public boolean equals(Object other) { - if (other instanceof SpecificRow) { - SpecificRow row = (SpecificRow) other; - $columnChecks - return true; + @Override + public boolean equals(Object other) { + if (other instanceof SpecificRow) { + SpecificRow row = (SpecificRow) other; + $columnChecks + return true; + } + return super.equals(other); } - return super.equals(other); - } - @Override - public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; - ${copyColumns} - return new ${classOf[GenericInternalRow].getName}(arr); + @Override + public InternalRow copy() { + Object[] arr = new Object[${expressions.length}]; + ${copyColumns} + return new ${classOf[GenericInternalRow].getName}(arr); + } } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 6cdc3000382e2..e10ba55396664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -38,11 +39,7 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize it. 
*/ - @transient protected lazy val partitionId = TaskContext.get() match { - case null => 0 - case _ => TaskContext.get().partitionId() - } - @transient protected lazy val rng = new XORShiftRandom(seed + partitionId) + @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId) override def deterministic: Boolean = false @@ -61,6 +58,17 @@ case class Rand(seed: Long) extends RDG(seed) { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rngTerm = ctx.freshName("rng") + val className = classOf[XORShiftRandom].getCanonicalName + ctx.addMutableState(className, rngTerm, + s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())") + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble(); + """ + } } /** Generate a random column with i.i.d. gaussian random distribution. */ @@ -73,4 +81,15 @@ case class Randn(seed: Long) extends RDG(seed) { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rngTerm = ctx.freshName("rng") + val className = classOf[XORShiftRandom].getCanonicalName + ctx.addMutableState(className, rngTerm, + s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())") + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian(); + """ + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 437d143e53f3f..69a37750d7525 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{LongType, DataType} /** @@ -40,6 +41,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L + @transient private lazy val partitionMask = TaskContext.getPartitionId.toLong << 33 + override def nullable: Boolean = false override def dataType: DataType = LongType @@ -47,6 +50,20 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { override def eval(input: InternalRow): Long = { val currentCount = count count += 1 - (TaskContext.get().partitionId().toLong << 33) + currentCount + partitionMask + currentCount + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val countTerm = ctx.freshName("count") + val partitionMaskTerm = ctx.freshName("partitionMask") + ctx.addMutableState(ctx.JAVA_LONG, countTerm, "0L") + ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, + "((long) org.apache.spark.TaskContext.getPartitionId()) << 33") + + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $partitionMaskTerm + 
$countTerm; + $countTerm++; + """ } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 822d3d8c9108d..5f1b514f2cff2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{IntegerType, DataType} @@ -32,5 +33,14 @@ private[sql] case object SparkPartitionID extends LeafExpression { override def dataType: DataType = IntegerType - override def eval(input: InternalRow): Int = TaskContext.get().partitionId() + @transient private lazy val partitionId = TaskContext.getPartitionId + + override def eval(input: InternalRow): Int = partitionId + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val idTerm = ctx.freshName("partitionId") + ctx.addMutableState(ctx.JAVA_INT, idTerm, "org.apache.spark.TaskContext.getPartitionId()") + ev.isNull = "false" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = $idTerm;" + } } From a9385271a9f6b97ec6aa619cf56ee556ba2fb0de Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 15 Jul 2015 10:43:38 -0700 Subject: [PATCH 14/27] [SPARK-8221][SQL]Add pmod function https://issues.apache.org/jira/browse/SPARK-8221 One concern is the result would be negative if the divisor is not positive( i.e pmod(7, -3) ), but the behavior is the same as hive. 
Author: zhichao.li Closes #6783 from zhichao-li/pmod2 and squashes the following commits: 7083eb9 [zhichao.li] update to the latest type checking d26dba7 [zhichao.li] add pmod --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../catalyst/analysis/HiveTypeCoercion.scala | 6 ++ .../sql/catalyst/expressions/arithmetic.scala | 94 +++++++++++++++++++ .../ArithmeticExpressionSuite.scala | 16 +++- .../org/apache/spark/sql/functions.scala | 17 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 37 ++++++++ 6 files changed, 170 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ec75f51d5e4ff..d2678ce860701 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -115,6 +115,7 @@ object FunctionRegistry { expression[Log2]("log2"), expression[Pow]("pow"), expression[Pow]("power"), + expression[Pmod]("pmod"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), expression[Round]("round"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 15da5eecc8d3c..25087915b5c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -426,6 +426,12 @@ object HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) + case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + // When we compare 2 decimal types with different precisions, cast them to the smallest // common precision. 
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1a55a0876f303..394ef556e04a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -377,3 +377,97 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "min" override def prettyName: String = symbol } + +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { + + override def toString: String = s"pmod($left, $right)" + + override def symbol: String = "pmod" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "pmod") + + override def inputType: AbstractDataType = NumericType + + protected override def nullSafeEval(left: Any, right: Any) = + dataType match { + case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int]) + case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long]) + case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short]) + case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte]) + case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float]) + case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double]) + case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + dataType match { + case dt: DecimalType => + val decimalAdd = "$plus" + s""" + ${ctx.javaType(dataType)} r = $eval1.remainder($eval2); + if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { + ${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2); + } else { + ${ev.primitive} = r; + } + """ + // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType => + s""" + ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2); + if (r < 0) { + ${ev.primitive} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2); + } else { + ${ev.primitive} = r; + } + """ + case _ => + s""" + ${ctx.javaType(dataType)} r = $eval1 % $eval2; + if (r < 0) { + ${ev.primitive} = (r + $eval2) % $eval2; + } else { + ${ev.primitive} = r; + } + """ + } + }) + } + + private def pmod(a: Int, n: Int): Int = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Long, n: Long): Long = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Byte, n: Byte): Byte = { + val r = a % n + if (r < 0) {((r + n) % n).toByte} else r.toByte + } + + private def pmod(a: Double, n: Double): Double = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Short, n: Short): Short = { + val r = a % n + if (r < 0) {((r + n) % n).toShort} else r.toShort + } + + private def pmod(a: Float, n: Float): Float = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Decimal, n: Decimal): Decimal = { + val r = a % n + if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 6c93698f8017b..e7e5231d32c9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.Decimal - class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { /** @@ -158,4 +157,19 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), Array(1.toByte, 2.toByte)) } + + test("pmod") { + testNumericDataTypes { convert => + val left = Literal(convert(7)) + val right = Literal(convert(3)) + checkEvaluation(Pmod(left, right), convert(1)) + checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null) + checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null) + checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 + } + checkEvaluation(Pmod(-7, 3), 2) + checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005) + checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1)) + checkEvaluation(Pmod(2L, Long.MaxValue), 2) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5119ee31d852d..c7deaca8437a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1371,6 +1371,23 @@ object functions { */ def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + /** + * Returns the positive value of dividend mod divisor. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr) + + /** + * Returns the positive value of dividend mod divisor. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividendColName: String, divisorColName: String): Column = + pmod(Column(dividendColName), Column(divisorColName)) + /** * Returns the double value that is closest in value to the argument and * is equal to a mathematical integer. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6cebec95d2850..70bd78737f69c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -403,4 +403,41 @@ class DataFrameFunctionsSuite extends QueryTest { Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) ) } + + test("pmod") { + val intData = Seq((7, 3), (-7, 3)).toDF("a", "b") + checkAnswer( + intData.select(pmod('a, 'b)), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod('a, lit(3))), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod(lit(-7), 'b)), + Seq(Row(2), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, b)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, 3)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(-7, b)"), + Seq(Row(2), Row(2)) + ) + val doubleData = Seq((7.2, 4.1)).toDF("a", "b") + checkAnswer( + doubleData.select(pmod('a, 'b)), + Seq(Row(3.1000000000000005)) // same as hive + ) + checkAnswer( + doubleData.select(pmod(lit(2), lit(Int.MaxValue))), + Seq(Row(2)) + ) + } } From 9716a727fb2d11380794549039e12e53c771e120 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Jul 2015 10:46:22 -0700 Subject: [PATCH 15/27] [Minor][SQL] Allow spaces in the beginning and ending of string for Interval This is a minor fixing for #7355 to allow spaces in the beginning and ending of string parsed to `Interval`. Author: Liang-Chi Hsieh Closes #7390 from viirya/fix_interval_string and squashes the following commits: 9eb6831 [Liang-Chi Hsieh] Use trim instead of modifying regex. 57861f7 [Liang-Chi Hsieh] Fix scala style. 815a9cb [Liang-Chi Hsieh] Slightly modify regex to allow spaces in the beginning and ending of string. 
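The whole fix is a single trim of the input before the regex is applied, so padded strings parse the same as unpadded ones. A rough illustration of the intended behavior (the values mirror the new test cases; treat this as a sketch, not additional API):

    // Before the patch the padded forms failed to match and fromString returned null;
    // after the patch all three are equivalent.
    Interval.fromString("interval -5 years 23 month")
    Interval.fromString("interval -5 years 23 month   ")
    Interval.fromString("  interval -5 years 23 month ")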
--- .../main/java/org/apache/spark/unsafe/types/Interval.java | 1 + .../java/org/apache/spark/unsafe/types/IntervalSuite.java | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java index eb7475e9df869..905ea0b7b878c 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -62,6 +62,7 @@ public static Interval fromString(String s) { if (s == null) { return null; } + s = s.trim(); Matcher m = p.matcher(s); if (!m.matches() || s.equals("interval")) { return null; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java index 44a949a371f2b..1832d0bc65551 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -75,6 +75,12 @@ public void fromStringTest() { Interval result = new Interval(-5 * 12 + 23, 0); assertEquals(Interval.fromString(input), result); + input = "interval -5 years 23 month "; + assertEquals(Interval.fromString(input), result); + + input = " interval -5 years 23 month "; + assertEquals(Interval.fromString(input), result); + // Error cases input = "interval 3month 1 hour"; assertEquals(Interval.fromString(input), null); From 303c1201c468d360a5f600ce37b8bee75a77a0e6 Mon Sep 17 00:00:00 2001 From: Shuo Xiang Date: Wed, 15 Jul 2015 12:10:53 -0700 Subject: [PATCH 16/27] [SPARK-7555] [DOCS] Add doc for elastic net in ml-guide and mllib-guide jkbradley I put the elastic net under the **Algorithm guide** section. Also add the formula of elastic net in mllib-linear `mllib-linear-methods#regularizers`. dbtsai I left the code tab for you to add example code. Do you think it is the right place? Author: Shuo Xiang Closes #6504 from coderxiang/elasticnet and squashes the following commits: f6061ee [Shuo Xiang] typo 90a7c88 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into elasticnet 0610a36 [Shuo Xiang] move out the elastic net to ml-linear-methods 8747190 [Shuo Xiang] merge master 706d3f7 [Shuo Xiang] add python code 9bc2b4c [Shuo Xiang] typo db32a60 [Shuo Xiang] java code sample aab3b3a [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into elasticnet a0dae07 [Shuo Xiang] simplify code d8616fd [Shuo Xiang] Update the definition of elastic net. 
Add scala code; Mention Lasso and Ridge df5bd14 [Shuo Xiang] use wikipeida page in ml-linear-methods.md 78d9366 [Shuo Xiang] address comments 8ce37c2 [Shuo Xiang] Merge branch 'elasticnet' of github.com:coderxiang/spark into elasticnet 8f24848 [Shuo Xiang] Merge branch 'elastic-net-doc' of github.com:coderxiang/spark into elastic-net-doc 998d766 [Shuo Xiang] Merge branch 'elastic-net-doc' of github.com:coderxiang/spark into elastic-net-doc 89f10e4 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into elastic-net-doc 9262a72 [Shuo Xiang] update 7e07d12 [Shuo Xiang] update b32f21a [Shuo Xiang] add doc for elastic net in sparkml 937eef1 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into elastic-net-doc 180b496 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' aa0717d [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 5f109b4 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' c5c5bfe [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 98804c9 [Shuo Xiang] fix bug in topBykey and update test --- docs/ml-guide.md | 31 +++++++++ docs/ml-linear-methods.md | 129 +++++++++++++++++++++++++++++++++++ docs/mllib-linear-methods.md | 53 +++++++------- 3 files changed, 188 insertions(+), 25 deletions(-) create mode 100644 docs/ml-linear-methods.md diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c74cb1f1ef8ea..8c46adf256a9a 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -3,6 +3,24 @@ layout: global title: Spark ML Programming Guide --- +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. @@ -154,6 +172,19 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. +# Algorithm Guides + +There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. + +**Pipelines API Algorithm Guides** + +* [Feature Extraction, Transformation, and Selection](ml-features.html) +* [Ensembles](ml-ensembles.html) + +**Algorithms in `spark.ml`** + +* [Linear methods with elastic net regularization](ml-linear-methods.html) + # Code Examples This section gives code examples illustrating the functionality discussed above. 
diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md new file mode 100644 index 0000000000000..1ac83d94c9e81 --- /dev/null +++ b/docs/ml-linear-methods.md @@ -0,0 +1,129 @@ +--- +layout: global +title: Linear Methods - ML +displayTitle: ML - Linear Methods +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm: +`\[ +\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1]. +\]` +By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization. + +**Examples** + +
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+ +{% highlight scala %} + +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.mllib.util.MLUtils + +// Load training data +val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + +// Fit the model +val lrModel = lr.fit(training) + +// Print the weights and intercept for logistic regression +println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") + +{% endhighlight %} + +
+</div>
+<div data-lang="java" markdown="1">
+ +{% highlight java %} + +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class LogisticRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("Logistic Regression with Elastic Net Example"); + + SparkContext sc = new SparkContext(conf); + SQLContext sql = new SQLContext(sc); + String path = "sample_libsvm_data.txt"; + + // Load training data + DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // Print the weights and intercept for logistic regression + System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept()); + } +} +{% endhighlight %} +
+</div>
+<div data-lang="python" markdown="1">
+ +{% highlight python %} + +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Load training data +training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + +# Fit the model +lrModel = lr.fit(training) + +# Print the weights and intercept for logistic regression +print("Weights: " + str(lrModel.weights)) +print("Intercept: " + str(lrModel.intercept)) +{% endhighlight %} + +
+</div>
+</div>
+ +### Optimization + +The optimization algorithm underlies the implementation is called [Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net. diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 3927d65fbf8fb..07655baa414b5 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -10,7 +10,7 @@ displayTitle: MLlib - Linear Methods `\[ \newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} +\newcommand{\E}{\mathbb{E}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\wv}{\mathbf{w}} @@ -18,10 +18,10 @@ displayTitle: MLlib - Linear Methods \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} \newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} \newcommand{\zero}{\mathbf{0}} \]` @@ -29,7 +29,7 @@ displayTitle: MLlib - Linear Methods Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e. the task of finding a minimizer of a convex function `$f$` that depends on a variable vector -`$\wv$` (called `weights` in the code), which has `$d$` entries. +`$\wv$` (called `weights` in the code), which has `$d$` entries. Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where the objective function is of the form `\begin{equation} @@ -39,7 +39,7 @@ the objective function is of the form \ . \end{equation}` Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and -`$y_i\in\R$` are their corresponding labels, which we want to predict. +`$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. Several of MLlib's classification and regression algorithms fall into this category, and are discussed here. @@ -99,6 +99,9 @@ regularizers in MLlib:
+ + +
L1$\|\wv\|_1$$\mathrm{sign}(\wv)$
elastic net$\alpha \|\wv\|_1 + (1-\alpha)\frac{1}{2}\|\wv\|_2^2$$\alpha \mathrm{sign}(\wv) + (1-\alpha) \wv$
@@ -107,7 +110,7 @@ of `$\wv$`. L2-regularized problems are generally easier to solve than L1-regularized due to smoothness. However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection. -It is not recommended to train models without any regularization, +[Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization) is a combination of L1 and L2 regularization. It is not recommended to train models without any regularization, especially when the number of training examples is small. ### Optimization @@ -531,7 +534,7 @@ sameModel = LogisticRegressionModel.load(sc, "myModelPath") ### Linear least squares, Lasso, and ridge regression -Linear least squares is the most common formulation for regression problems. +Linear least squares is the most common formulation for regression problems. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss function in the formulation given by the squared loss: `\[ @@ -539,8 +542,8 @@ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2. \]` Various related regression methods are derived by using different types of regularization: -[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or -[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses +[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or +[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2 regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1 regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is @@ -552,7 +555,7 @@ known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_erro
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). @@ -614,7 +617,7 @@ public class LinearRegression { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); JavaSparkContext sc = new JavaSparkContext(conf); - + // Load and parse the data String path = "data/mllib/ridge-data/lpsa.data"; JavaRDD data = sc.textFile(path); @@ -634,7 +637,7 @@ public class LinearRegression { // Building the model int numIterations = 100; - final LinearRegressionModel model = + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); // Evaluate model on training examples and compute training error @@ -665,7 +668,7 @@ public class LinearRegression {
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). @@ -706,8 +709,8 @@ a dependency. ###Streaming linear regression -When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +When data arrive in a streaming fashion, it is useful to fit regression models online, +updating the parameters of the model as new data arrives. MLlib currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. @@ -722,7 +725,7 @@ online to the first stream, and make predictions on the second stream.
-First, we import the necessary classes for parsing our input data and creating the model. +First, we import the necessary classes for parsing our input data and creating the model. {% highlight scala %} @@ -734,7 +737,7 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) -for more info. For this example, we use labeled points in training and testing streams, +for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. {% highlight scala %} @@ -754,7 +757,7 @@ val model = new StreamingLinearRegressionWithSGD() {% endhighlight %} -Now we register the streams for training and testing and start the job. +Now we register the streams for training and testing and start the job. Printing predictions alongside true labels lets us easily see the result. {% highlight scala %} @@ -764,14 +767,14 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - + {% endhighlight %} We can now save text files with data to the training or testing folders. -Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. -As you feed more data to the training directory, the predictions +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions will get better!
From ec9b621647b893abae3afef219bceab382b99564 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Wed, 15 Jul 2015 12:15:35 -0700 Subject: [PATCH 17/27] SPARK-9070 JavaDataFrameSuite teardown NPEs if setup failed fix teardown to skip table delete if hive context is null Author: Steve Loughran Closes #7425 from steveloughran/stevel/patches/SPARK-9070-JavaDataFrameSuite-NPE and squashes the following commits: 1982d38 [Steve Loughran] SPARK-9070 JavaDataFrameSuite teardown NPEs if setup failed --- .../test/org/apache/spark/sql/hive/JavaDataFrameSuite.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index c4828c4717643..741a3cd31c603 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -61,7 +61,9 @@ public void setUp() throws IOException { @After public void tearDown() throws IOException { // Clean up tables. - hc.sql("DROP TABLE IF EXISTS window_table"); + if (hc != null) { + hc.sql("DROP TABLE IF EXISTS window_table"); + } } @Test From 536533cad83a26f8fa7c60042904a31057ab56c2 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 15 Jul 2015 13:32:25 -0700 Subject: [PATCH 18/27] [SPARK-9005] [MLLIB] Fix RegressionMetrics computation of explainedVariance Fixes implementation of `explainedVariance` and `r2` to be consistent with their definitions as described in [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005). Author: Feynman Liang Closes #7361 from feynmanliang/SPARK-9005-RegressionMetrics-bugs and squashes the following commits: f1112fc [Feynman Liang] Add explainedVariance formula 1a3d098 [Feynman Liang] SROwen code review comments 08a0e1b [Feynman Liang] Fix pyspark tests db8605a [Feynman Liang] Style fix bde9761 [Feynman Liang] Fix RegressionMetrics tests, relax assumption predictor is unbiased c235de0 [Feynman Liang] Fix RegressionMetrics tests 4c4e56f [Feynman Liang] Fix RegressionMetrics computation of explainedVariance and r2 --- .../mllib/evaluation/RegressionMetrics.scala | 27 +++++--- .../evaluation/RegressionMetricsSuite.scala | 69 +++++++++++++++++-- python/pyspark/mllib/evaluation.py | 2 +- 3 files changed, 83 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index e577bf87f885e..408847afa800d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -53,14 +53,22 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend ) summary } + private lazy val SSerr = math.pow(summary.normL2(1), 2) + private lazy val SStot = summary.variance(0) * (summary.count - 1) + private lazy val SSreg = { + val yMean = summary.mean(0) + predictionAndObservations.map { + case (prediction, _) => math.pow(prediction - yMean, 2) + }.sum() + } /** - * Returns the explained variance regression score. - * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) - * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + * Returns the variance explained by regression. 
+ * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n + * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] */ def explainedVariance: Double = { - 1 - summary.variance(1) / summary.variance(0) + SSreg / summary.count } /** @@ -76,8 +84,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * expected value of the squared error loss or quadratic loss. */ def meanSquaredError: Double = { - val rmse = summary.normL2(1) / math.sqrt(summary.count) - rmse * rmse + SSerr / summary.count } /** @@ -85,14 +92,14 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * the mean squared error. */ def rootMeanSquaredError: Double = { - summary.normL2(1) / math.sqrt(summary.count) + math.sqrt(this.meanSquaredError) } /** - * Returns R^2^, the coefficient of determination. - * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * Returns R^2^, the unadjusted coefficient of determination. + * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] */ def r2: Double = { - 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1)) + 1 - SSerr / SStot } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 9de2bdb6d7246..4b7f1be58f99b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -23,24 +23,85 @@ import org.apache.spark.mllib.util.TestingUtils._ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { - test("regression metrics") { + test("regression metrics for unbiased (includes intercept term) predictor") { + /* Verify results in R: + preds = c(2.25, -0.25, 1.75, 7.75) + obs = c(3.0, -0.5, 2.0, 7.0) + + SStot = sum((obs - mean(obs))^2) + SSreg = sum((preds - mean(obs))^2) + SSerr = sum((obs - preds)^2) + + explainedVariance = SSreg / length(obs) + explainedVariance + > [1] 8.796875 + meanAbsoluteError = mean(abs(preds - obs)) + meanAbsoluteError + > [1] 0.5 + meanSquaredError = mean((preds - obs)^2) + meanSquaredError + > [1] 0.3125 + rmse = sqrt(meanSquaredError) + rmse + > [1] 0.559017 + r2 = 1 - SSerr / SStot + r2 + > [1] 0.9571734 + */ + val predictionAndObservations = sc.parallelize( + Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservations) + assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch") + } + + test("regression metrics for biased (no intercept term) predictor") { + /* Verify results in R: + preds = c(2.5, 0.0, 2.0, 8.0) + obs = c(3.0, -0.5, 2.0, 7.0) + + SStot = sum((obs - mean(obs))^2) + SSreg = sum((preds - mean(obs))^2) + SSerr = sum((obs - preds)^2) + + explainedVariance = SSreg / length(obs) + explainedVariance + > [1] 8.859375 + meanAbsoluteError = mean(abs(preds - obs)) + meanAbsoluteError + > [1] 0.5 + meanSquaredError = mean((preds - obs)^2) + meanSquaredError + > [1] 0.375 + rmse = 
sqrt(meanSquaredError) + rmse + > [1] 0.6123724 + r2 = 1 - SSerr / SStot + r2 + > [1] 0.9486081 + */ val predictionAndObservations = sc.parallelize( Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5, + assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5, "explained variance regression score mismatch") assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch") assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch") } test("regression metrics with complete fitting") { val predictionAndObservations = sc.parallelize( Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 1.0 absTol 1E-5, + assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5, "explained variance regression score mismatch") assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch") assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch") diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index f21403707e12a..4398ca86f2ec2 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -82,7 +82,7 @@ class RegressionMetrics(JavaModelWrapper): ... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]) >>> metrics = RegressionMetrics(predictionAndObservations) >>> metrics.explainedVariance - 0.95... + 8.859... >>> metrics.meanAbsoluteError 0.5... >>> metrics.meanSquaredError From b9a922e260bec1b211437f020be37fab46a85db0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 15 Jul 2015 14:02:23 -0700 Subject: [PATCH 19/27] [SPARK-6602][Core]Replace Akka Serialization with Spark Serializer Replace Akka Serialization with Spark Serializer and add unit tests. 
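The substance of this change is that the master's persistence engines now round-trip recovery state through Spark's own Serializer API instead of akka.serialization.Serialization. A minimal sketch of that round trip using the same calls the patch introduces (the String payload is only an illustration; the real engines persist ApplicationInfo, DriverInfo and WorkerInfo):

    import java.nio.ByteBuffer
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.JavaSerializer

    val instance = new JavaSerializer(new SparkConf()).newInstance()

    // Serialize to raw bytes, as ZooKeeperPersistenceEngine now does.
    val buf: ByteBuffer = instance.serialize("some_recovery_state")
    val bytes = new Array[Byte](buf.remaining())
    buf.get(bytes)

    // Read the same bytes back.
    val restored = instance.deserialize[String](ByteBuffer.wrap(bytes))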
Author: zsxwing Closes #7159 from zsxwing/remove-akka-serialization and squashes the following commits: fc0fca3 [zsxwing] Merge branch 'master' into remove-akka-serialization cf81a58 [zsxwing] Fix the code style 73251c6 [zsxwing] Add test scope 9ef4af9 [zsxwing] Add AkkaRpcEndpointRef.hashCode 433115c [zsxwing] Remove final be3edb0 [zsxwing] Support deserializing RpcEndpointRef ecec410 [zsxwing] Replace Akka Serialization with Spark Serializer --- core/pom.xml | 5 + .../master/FileSystemPersistenceEngine.scala | 35 ++--- .../apache/spark/deploy/master/Master.scala | 18 +-- .../deploy/master/PersistenceEngine.scala | 8 +- .../deploy/master/RecoveryModeFactory.scala | 9 +- .../master/ZooKeeperPersistenceEngine.scala | 16 +-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 6 + .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 14 +- .../master/CustomRecoveryModeFactory.scala | 31 ++--- .../spark/deploy/master/MasterSuite.scala | 2 +- .../master/PersistenceEngineSuite.scala | 126 ++++++++++++++++++ pom.xml | 6 + 12 files changed, 214 insertions(+), 62 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala diff --git a/core/pom.xml b/core/pom.xml index 558cc3fb9f2f3..73f7a75cab9d3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -372,6 +372,11 @@ junit-interface test + + org.apache.curator + curator-test + test + net.razorvine pyrolite diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index f459ed5b3a1a1..aa379d4cd61e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -21,9 +21,8 @@ import java.io._ import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.Logging +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer} import org.apache.spark.util.Utils @@ -32,11 +31,11 @@ import org.apache.spark.util.Utils * Files are deleted when applications and workers are removed. * * @param dir Directory to store files. Created if non-existent (but not recursively). - * @param serialization Used to serialize our objects. + * @param serializer Used to serialize our objects. 
*/ private[master] class FileSystemPersistenceEngine( val dir: String, - val serialization: Serialization) + val serializer: Serializer) extends PersistenceEngine with Logging { new File(dir).mkdir() @@ -57,27 +56,31 @@ private[master] class FileSystemPersistenceEngine( private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - val out = new FileOutputStream(file) + val fileOut = new FileOutputStream(file) + var out: SerializationStream = null Utils.tryWithSafeFinally { - out.write(serialized) + out = serializer.newInstance().serializeStream(fileOut) + out.writeObject(value) } { - out.close() + fileOut.close() + if (out != null) { + out.close() + } } } private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = { - val fileData = new Array[Byte](file.length().asInstanceOf[Int]) - val dis = new DataInputStream(new FileInputStream(file)) + val fileIn = new FileInputStream(file) + var in: DeserializationStream = null try { - dis.readFully(fileData) + in = serializer.newInstance().deserializeStream(fileIn) + in.readObject[T]() } finally { - dis.close() + fileIn.close() + if (in != null) { + in.close() + } } - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) - serializer.fromBinary(fileData).asInstanceOf[T] } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 245b047e7dfbd..4615febf17d24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -27,11 +27,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.language.postfixOps import scala.util.Random -import akka.serialization.Serialization -import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, @@ -44,6 +41,7 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} @@ -58,9 +56,6 @@ private[master] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") - // TODO Remove it once we don't use akka.serialization.Serialization - private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem - private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -161,20 +156,21 @@ private[master] class Master( masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) + val serializer = new JavaSerializer(conf) val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" 
=> logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem)) + new ZooKeeperRecoveryModeFactory(conf, serializer) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem)) + new FileSystemRecoveryModeFactory(conf, serializer) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory")) - val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) - .newInstance(conf, SerializationExtension(actorSystem)) + val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer]) + .newInstance(conf, serializer) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -213,7 +209,7 @@ private[master] class Master( override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { - val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { RecoveryState.ALIVE } else { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index a03d460509e03..58a00bceee6af 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.master import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEnv import scala.reflect.ClassTag @@ -80,8 +81,11 @@ abstract class PersistenceEngine { * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). 
*/ - final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + final def readPersistedData( + rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { + rpcEnv.deserialize { () => + (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + } } def close() {} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 351db8fab2041..c4c3283fb73f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -17,10 +17,9 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.serializer.Serializer /** * ::DeveloperApi:: @@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi * */ @DeveloperApi -abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) { +abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) { /** * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) @@ -49,7 +48,7 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual * recovery is made by restoring from filesystem. */ -private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") @@ -64,7 +63,7 @@ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: } } -private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) { def createPersistenceEngine(): PersistenceEngine = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 328d95a7a0c68..563831cc6b8dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization +import java.nio.ByteBuffer import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -27,9 +27,10 @@ import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.serializer.Serializer -private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) +private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer) extends PersistenceEngine with Logging { @@ -57,17 +58,16 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat } 
private def serializeIntoFile(path: String, value: AnyRef) { - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) + val serialized = serializer.newInstance().serialize(value) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes) } private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) try { - Some(serializer.fromBinary(fileData).asInstanceOf[T]) + Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData))) } catch { case e: Exception => { logWarning("Exception while reading persisted file, deleting", e) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index c9fcc7a36cc04..29debe8081308 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -139,6 +139,12 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * creating it manually because different [[RpcEnv]] may have different formats. */ def uriOf(systemName: String, address: RpcAddress, endpointName: String): String + + /** + * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object + * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. + */ + def deserialize[T](deserializationAction: () => T): T } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index f2d87f68341af..fc17542abf81d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -28,7 +28,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} -import com.google.common.util.concurrent.MoreExecutors +import akka.serialization.JavaSerializer import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ @@ -239,6 +239,12 @@ private[spark] class AkkaRpcEnv private[akka] ( } override def toString: String = s"${getClass.getSimpleName}($actorSystem)" + + override def deserialize[T](deserializationAction: () => T): T = { + JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) { + deserializationAction() + } + } } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -315,6 +321,12 @@ private[akka] class AkkaRpcEndpointRef( override def toString: String = s"${getClass.getSimpleName}($actorRef)" + final override def equals(that: Any): Boolean = that match { + case other: AkkaRpcEndpointRef => actorRef == other.actorRef + case _ => false + } + + final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode() } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala index f4e56632e426a..8c96b0e71dfdd 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala @@ -19,18 +19,19 @@ // when they are outside of org.apache.spark. package other.supplier +import java.nio.ByteBuffer + import scala.collection.mutable import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.SparkConf import org.apache.spark.deploy.master._ +import org.apache.spark.serializer.Serializer class CustomRecoveryModeFactory( conf: SparkConf, - serialization: Serialization -) extends StandaloneRecoveryModeFactory(conf, serialization) { + serializer: Serializer +) extends StandaloneRecoveryModeFactory(conf, serializer) { CustomRecoveryModeFactory.instantiationAttempts += 1 @@ -40,7 +41,7 @@ class CustomRecoveryModeFactory( * */ override def createPersistenceEngine(): PersistenceEngine = - new CustomPersistenceEngine(serialization) + new CustomPersistenceEngine(serializer) /** * Create an instance of LeaderAgent that decides who gets elected as master. @@ -53,7 +54,7 @@ object CustomRecoveryModeFactory { @volatile var instantiationAttempts = 0 } -class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine { +class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine { val data = mutable.HashMap[String, Array[Byte]]() CustomPersistenceEngine.lastInstance = Some(this) @@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def persist(name: String, obj: Object): Unit = { CustomPersistenceEngine.persistAttempts += 1 - serialization.serialize(obj) match { - case util.Success(bytes) => data += name -> bytes - case util.Failure(cause) => throw new RuntimeException(cause) - } + val serialized = serializer.newInstance().serialize(obj) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + data += name -> bytes } /** @@ -84,15 +85,9 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def read[T: ClassTag](prefix: String): Seq[T] = { CustomPersistenceEngine.readAttempts += 1 - val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] val results = for ((name, bytes) <- data; if name.startsWith(prefix)) - yield serialization.deserialize(bytes, clazz) - - results.find(_.isFailure).foreach { - case util.Failure(cause) => throw new RuntimeException(cause) - } - - results.flatMap(_.toOption).toSeq + yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) + results.toSeq } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 9cb6dd43bac47..a8fbaf1d9da0a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -105,7 +105,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { persistenceEngine.addDriver(driverToPersist) persistenceEngine.addWorker(workerToPersist) - val (apps, drivers, workers) = persistenceEngine.readPersistedData() + val (apps, drivers, workers) = persistenceEngine.readPersistedData(rpcEnv) apps.map(_.id) should contain(appToPersist.id) drivers.map(_.id) should contain(driverToPersist.id) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala 
b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala new file mode 100644 index 0000000000000..11e87bd1dd8eb --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.deploy.master + +import java.net.ServerSocket + +import org.apache.commons.lang3.RandomUtils +import org.apache.curator.test.TestingServer + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} +import org.apache.spark.serializer.{Serializer, JavaSerializer} +import org.apache.spark.util.Utils + +class PersistenceEngineSuite extends SparkFunSuite { + + test("FileSystemPersistenceEngine") { + val dir = Utils.createTempDir() + try { + val conf = new SparkConf() + testPersistenceEngine(conf, serializer => + new FileSystemPersistenceEngine(dir.getAbsolutePath, serializer) + ) + } finally { + Utils.deleteRecursively(dir) + } + } + + test("ZooKeeperPersistenceEngine") { + val conf = new SparkConf() + // TestingServer logs the port conflict exception rather than throwing an exception. + // So we have to find a free port by ourselves. This approach cannot guarantee always starting + // zkTestServer successfully because there is a time gap between finding a free port and + // starting zkTestServer. But the failure possibility should be very low. 
+ val zkTestServer = new TestingServer(findFreePort(conf)) + try { + testPersistenceEngine(conf, serializer => { + conf.set("spark.deploy.zookeeper.url", zkTestServer.getConnectString) + new ZooKeeperPersistenceEngine(conf, serializer) + }) + } finally { + zkTestServer.stop() + } + } + + private def testPersistenceEngine( + conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { + val serializer = new JavaSerializer(conf) + val persistenceEngine = persistenceEngineCreator(serializer) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = rpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = rpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiPort = 0, + publicAddress = "" + ) + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) + assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + } finally { + rpcEnv.shutdown() + rpcEnv.awaitTermination() + } + } + + private def findFreePort(conf: SparkConf): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, conf)._2 + } +} diff --git a/pom.xml b/pom.xml index 370c95dd03632..aa49e2ab7294b 100644 --- a/pom.xml +++ b/pom.xml @@ -748,6 +748,12 @@ curator-framework ${curator.version} + + org.apache.curator + curator-test + ${curator.version} + test + org.apache.hadoop hadoop-client From 674eb2a4c3ff595760f990daf369ba75d2547593 Mon Sep 17 00:00:00 2001 From: KaiXinXiaoLei Date: Wed, 15 Jul 2015 22:31:10 +0100 Subject: [PATCH 20/27] [SPARK-8974] Catch exceptions in allocation schedule task. I meet a problem. When I submit some tasks, the thread spark-dynamic-executor-allocation should seed the message about "requestTotalExecutors", and the new executor should start. 
But I meet a problem about this thread, like: 2015-07-14 19:02:17,461 | WARN | [spark-dynamic-executor-allocation] | Error sending message [message = RequestExecutors(1)] in 1 attempts java.util.concurrent.TimeoutException: Futures timed out after [120 seconds] at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:219) at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223) at scala.concurrent.Await$$anonfun$result$1.apply(package.scala:107) at scala.concurrent.BlockContext$DefaultBlockContext$.blockOn(BlockContext.scala:53) at scala.concurrent.Await$.result(package.scala:107) at org.apache.spark.rpc.RpcEndpointRef.askWithRetry(RpcEndpointRef.scala:102) at org.apache.spark.rpc.RpcEndpointRef.askWithRetry(RpcEndpointRef.scala:78) at org.apache.spark.scheduler.cluster.YarnSchedulerBackend.doRequestTotalExecutors(YarnSchedulerBackend.scala:57) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.requestTotalExecutors(CoarseGrainedSchedulerBackend.scala:351) at org.apache.spark.SparkContext.requestTotalExecutors(SparkContext.scala:1382) at org.apache.spark.ExecutorAllocationManager.addExecutors(ExecutorAllocationManager.scala:343) at org.apache.spark.ExecutorAllocationManager.updateAndSyncNumExecutorsTarget(ExecutorAllocationManager.scala:295) at org.apache.spark.ExecutorAllocationManager.org$apache$spark$ExecutorAllocationManager$$schedule(ExecutorAllocationManager.scala:248) when after some minutes, I find a new ApplicationMaster start, and tasks submitted start to run. The tasks Completed. And after long time (eg, ten minutes), the number of executor does not reduce to zero. I use the default value of "spark.dynamicAllocation.minExecutors". Author: KaiXinXiaoLei Closes #7352 from KaiXinXiaoLei/dym and squashes the following commits: 3603631 [KaiXinXiaoLei] change logError to logWarning efc4f24 [KaiXinXiaoLei] change file --- .../org/apache/spark/ExecutorAllocationManager.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 0c50b4002cf7b..648bcfe28cad2 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.util.concurrent.TimeUnit import scala.collection.mutable +import scala.util.control.ControlThrowable import com.codahale.metrics.{Gauge, MetricRegistry} @@ -211,7 +212,16 @@ private[spark] class ExecutorAllocationManager( listenerBus.addListener(listener) val scheduleTask = new Runnable() { - override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + override def run(): Unit = { + try { + schedule() + } catch { + case ct: ControlThrowable => + throw ct + case t: Throwable => + logWarning(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } } executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } From affbe329ae0100bd50a3c3fb081b0f2b07efce33 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 15 Jul 2015 14:52:02 -0700 Subject: [PATCH 21/27] [SPARK-9071][SQL] MonotonicallyIncreasingID and SparkPartitionID should be marked as nondeterministic. I also took the chance to more explicitly define the semantics of deterministic. 
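As an illustration of that contract (a standalone sketch, not part of the diff below; Expr and the concrete nodes here are simplified stand-ins for Catalyst's Expression classes, not the actual ones):

    trait Expr {
      def children: Seq[Expr] = Nil
      // Fixed inputs from children always give the same result; any node that does
      // not satisfy this must override the flag to false.
      def deterministic: Boolean = true
      def eval(partitionId: Int): Long
    }

    // Depends only on its own constant value, so the default (true) is correct.
    case class LiteralExpr(value: Long) extends Expr {
      override def eval(partitionId: Int): Long = value
    }

    // Depends on implicit per-task context (the partition id) rather than on any
    // child expression, so it must report deterministic = false, for the same
    // reason SparkPartitionID and MonotonicallyIncreasingID are marked below.
    case object PartitionIdExpr extends Expr {
      override def deterministic: Boolean = false
      override def eval(partitionId: Int): Long = partitionId.toLong
    }

    // PartitionIdExpr.eval(0) and PartitionIdExpr.eval(7) differ even though the
    // (empty) child inputs are identical, which is what non-determinism means here.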
Author: Reynold Xin Closes #7428 from rxin/non-deterministic and squashes the following commits: a760827 [Reynold Xin] [SPARK-9071][SQL] MonotonicallyIncreasingID and SparkPartitionID should be marked as nondeterministic. --- .../spark/sql/catalyst/expressions/Expression.scala | 10 ++++++++-- .../expressions/MonotonicallyIncreasingID.scala | 4 +++- .../sql/execution/expressions/SparkPartitionID.scala | 4 +++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3f19ac2b592b5..7b37ae7335253 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -61,9 +61,15 @@ abstract class Expression extends TreeNode[Expression] { def foldable: Boolean = false /** - * Returns true when the current expression always return the same result for fixed input values. + * Returns true when the current expression always return the same result for fixed inputs from + * children. + * + * Note that this means that an expression should be considered as non-deterministic if: + * - if it relies on some mutable internal state, or + * - if it relies on some implicit input that is not part of the children expression list. + * + * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext. */ - // TODO: Need to define explicit input values vs implicit input values. def deterministic: Boolean = true def nullable: Boolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 69a37750d7525..fec403fe2d348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -41,7 +41,9 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L - @transient private lazy val partitionMask = TaskContext.getPartitionId.toLong << 33 + @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33 + + override def deterministic: Boolean = false override def nullable: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 5f1b514f2cff2..7c790c549a5d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -29,11 +29,13 @@ import org.apache.spark.sql.types.{IntegerType, DataType} */ private[sql] case object SparkPartitionID extends LeafExpression { + override def deterministic: Boolean = false + override def nullable: Boolean = false override def dataType: DataType = IntegerType - @transient private lazy val partitionId = TaskContext.getPartitionId + @transient private lazy val partitionId = TaskContext.getPartitionId() override def eval(input: InternalRow): Int = partitionId From b0645195d0da57065885e078e08bd6c42f4f19b0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: 
Wed, 15 Jul 2015 17:50:11 -0700 Subject: [PATCH 22/27] [SPARK-9086][SQL] Remove BinaryNode from TreeNode. These traits are not super useful, and yet cause problems with toString in expressions due to the orders they are mixed in. Author: Reynold Xin Closes #7433 from rxin/remove-binary-node and squashes the following commits: 1881f78 [Reynold Xin] [SPARK-9086][SQL] Remove BinaryNode from TreeNode. --- .../sql/catalyst/expressions/Expression.scala | 17 ++++++++++++++--- .../catalyst/plans/logical/LogicalPlan.scala | 7 ++++++- .../spark/sql/catalyst/trees/TreeNode.scala | 9 --------- .../apache/spark/sql/execution/SparkPlan.scala | 7 ++++++- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 7b37ae7335253..87667316aca67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -187,8 +187,10 @@ abstract class Expression extends TreeNode[Expression] { /** * A leaf expression, i.e. one without any child expressions. */ -abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { +abstract class LeafExpression extends Expression { self: Product => + + def children: Seq[Expression] = Nil } @@ -196,9 +198,13 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] * An expression with one input and one output. The output is by default evaluated to null * if the input is evaluated to null. */ -abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { +abstract class UnaryExpression extends Expression { self: Product => + def child: Expression + + override def children: Seq[Expression] = child :: Nil + override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable @@ -271,9 +277,14 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio * An expression with two inputs and one output. The output is by default evaluated to null * if any input is evaluated to null. */ -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { +abstract class BinaryExpression extends Expression { self: Product => + def left: Expression + def right: Expression + + override def children: Seq[Expression] = Seq(left, right) + override def foldable: Boolean = left.foldable && right.foldable override def nullable: Boolean = left.nullable || right.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e911b907e8536..d7077a0ec907a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -291,6 +291,11 @@ abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] { /** * A logical plan node with a left and right child. 
*/ -abstract class BinaryNode extends LogicalPlan with trees.BinaryNode[LogicalPlan] { +abstract class BinaryNode extends LogicalPlan { self: Product => + + def left: LogicalPlan + def right: LogicalPlan + + override def children: Seq[LogicalPlan] = Seq(left, right) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 09f6c6b0ec423..16844b2f4b680 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -453,15 +453,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } } -/** - * A [[TreeNode]] that has two children, [[left]] and [[right]]. - */ -trait BinaryNode[BaseType <: TreeNode[BaseType]] { - def left: BaseType - def right: BaseType - - def children: Seq[BaseType] = Seq(left, right) -} /** * A [[TreeNode]] with no children. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 4d7d8626a0ecc..9dc7879fa4a1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -247,6 +247,11 @@ private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] { override def outputPartitioning: Partitioning = child.outputPartitioning } -private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] { +private[sql] trait BinaryNode extends SparkPlan { self: Product => + + def left: SparkPlan + def right: SparkPlan + + override def children: Seq[SparkPlan] = Seq(left, right) } From 6960a7938c61cc07f181ca85e0d8152ceeb453d9 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 15 Jul 2015 20:33:06 -0700 Subject: [PATCH 23/27] [SPARK-8774] [ML] Add R model formula with basic support as a transformer This implements minimal R formula support as a feature transformer. Both numeric and string labels are supported, but features must be numeric for now. cc mengxr Author: Eric Liang Closes #7381 from ericl/spark-8774-1 and squashes the following commits: d1959d2 [Eric Liang] clarify comment 2db68aa [Eric Liang] second round of comments dc3c943 [Eric Liang] address comments 5765ec6 [Eric Liang] fix style checks 1f361b0 [Eric Liang] doc fb0826b [Eric Liang] [SPARK-8774] Add R model formula with basic support as a transformer --- .../apache/spark/ml/feature/RFormula.scala | 151 ++++++++++++++++++ .../spark/ml/feature/VectorAssembler.scala | 2 +- .../ml/feature/RFormulaParserSuite.scala | 34 ++++ .../spark/ml/feature/RFormulaSuite.scala | 93 +++++++++++ 4 files changed, 279 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala new file mode 100644 index 0000000000000..d9a36bda386b3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the transforms required for fitting a dataset against an R model formula. Currently + * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula + * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + */ +@Experimental +class RFormula(override val uid: String) + extends Transformer with HasFeaturesCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("rFormula")) + + /** + * R formula parameter. The formula is provided in string form. + * @group setParam + */ + val formula: Param[String] = new Param(this, "formula", "R model formula") + + private var parsedFormula: Option[ParsedRFormula] = None + + /** + * Sets the formula to use for this transformer. Must be called before use. + * @group setParam + * @param value an R formula in string form (e.g. 
"y ~ x + z") + */ + def setFormula(value: String): this.type = { + parsedFormula = Some(RFormulaParser.parse(value)) + set(formula, value) + this + } + + /** @group getParam */ + def getFormula: String = $(formula) + + /** @group getParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group getParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def transformSchema(schema: StructType): StructType = { + checkCanTransform(schema) + val withFeatures = transformFeatures.transformSchema(schema) + if (hasLabelCol(schema)) { + withFeatures + } else { + val nullable = schema(parsedFormula.get.label).dataType match { + case _: NumericType | BooleanType => false + case _ => true + } + StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable)) + } + } + + override def transform(dataset: DataFrame): DataFrame = { + checkCanTransform(dataset.schema) + transformLabel(transformFeatures.transform(dataset)) + } + + override def copy(extra: ParamMap): RFormula = defaultCopy(extra) + + override def toString: String = s"RFormula(${get(formula)})" + + private def transformLabel(dataset: DataFrame): DataFrame = { + if (hasLabelCol(dataset.schema)) { + dataset + } else { + val labelName = parsedFormula.get.label + dataset.schema(labelName).dataType match { + case _: NumericType | BooleanType => + dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) + // TODO(ekl) add support for string-type labels + case other => + throw new IllegalArgumentException("Unsupported type for label: " + other) + } + } + } + + private def transformFeatures: Transformer = { + // TODO(ekl) add support for non-numeric features and feature interactions + new VectorAssembler(uid) + .setInputCols(parsedFormula.get.terms.toArray) + .setOutputCol($(featuresCol)) + } + + private def checkCanTransform(schema: StructType) { + require(parsedFormula.isDefined, "Must call setFormula() first.") + val columnNames = schema.map(_.name) + require(!columnNames.contains($(featuresCol)), "Features column already exists.") + require( + !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, + "Label column already exists and is not of type DoubleType.") + } + + private def hasLabelCol(schema: StructType): Boolean = { + schema.map(_.name).contains($(labelCol)) + } +} + +/** + * Represents a parsed R formula. + */ +private[ml] case class ParsedRFormula(label: String, terms: Seq[String]) + +/** + * Limited implementation of R formula parsing. Currently supports: '~', '+'. 
+ */ +private[ml] object RFormulaParser extends RegexParsers { + def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r + + def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list } + + def formula: Parser[ParsedRFormula] = + (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + + def parse(value: String): ParsedRFormula = parseAll(formula, value) match { + case Success(result, _) => result + case failure: NoSuccess => throw new IllegalArgumentException( + "Could not parse formula: " + value) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 9f83c2ee16178..086917fa680f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String) if (schema.fieldNames.contains(outputColName)) { throw new IllegalArgumentException(s"Output column $outputColName already exists.") } - StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) + StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true)) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala new file mode 100644 index 0000000000000..c8d065f37a605 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite + +class RFormulaParserSuite extends SparkFunSuite { + private def checkParse(formula: String, label: String, terms: Seq[String]) { + val parsed = RFormulaParser.parse(formula) + assert(parsed.label == label) + assert(parsed.terms == terms) + } + + test("parse simple formulas") { + checkParse("y ~ x", "y", Seq("x")) + checkParse("y ~ ._foo ", "y", Seq("._foo")) + checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala new file mode 100644 index 0000000000000..fa8611b243a9f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new RFormula()) + } + + test("transform numeric data") { + val formula = new RFormula().setFormula("id ~ v1 + v2") + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val result = formula.transform(original) + val resultSchema = formula.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0), + (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0)) + ).toDF("id", "v1", "v2", "features", "label") + // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect().toSeq == expected.collect().toSeq) + } + + test("features column already exists") { + val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + intercept[IllegalArgumentException] { + formula.transformSchema(original.schema) + } + intercept[IllegalArgumentException] { + formula.transform(original) + } + } + + test("label column already exists") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val resultSchema = formula.transformSchema(original.schema) + assert(resultSchema.length == 3) + assert(resultSchema.toString == formula.transform(original).schema.toString) + } + + test("label column already exists but is not double type") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") + val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + intercept[IllegalArgumentException] { + formula.transformSchema(original.schema) + } + intercept[IllegalArgumentException] { + formula.transform(original) + } + } + +// TODO(ekl) enable after we implement string label support +// test("transform string label") { +// val formula = new RFormula().setFormula("name ~ id") +// val original = sqlContext.createDataFrame( +// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name") +// val result = formula.transform(original) +// val resultSchema = formula.transformSchema(original.schema) +// val expected = sqlContext.createDataFrame( +// Seq( +// (1, "foo", Vectors.dense(Array(1.0)), 1.0), +// (2, "bar", Vectors.dense(Array(2.0)), 0.0), +// (3, "bar", Vectors.dense(Array(3.0)), 0.0)) +// ).toDF("id", "name", "features", 
"label") +// assert(result.schema.toString == resultSchema.toString) +// assert(result.collect().toSeq == expected.collect().toSeq) +// } +} From 73d92b00b9a6f5dfc2f8116447d17b381cd74f80 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 15 Jul 2015 21:02:42 -0700 Subject: [PATCH 24/27] [SPARK-9018] [MLLIB] add stopwatches Add stopwatches for easy instrumentation of MLlib algorithms. This is based on the `TimeTracker` used in decision trees. The distributed version uses Spark accumulator. jkbradley Author: Xiangrui Meng Closes #7415 from mengxr/SPARK-9018 and squashes the following commits: 40b4347 [Xiangrui Meng] == -> === c477745 [Xiangrui Meng] address Joseph's comments f981a49 [Xiangrui Meng] add stopwatches --- .../apache/spark/ml/util/stopwatches.scala | 151 ++++++++++++++++++ .../apache/spark/ml/util/StopwatchSuite.scala | 109 +++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala new file mode 100644 index 0000000000000..5fdf878a3df72 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.{Accumulator, SparkContext} + +/** + * Abstract class for stopwatches. + */ +private[spark] abstract class Stopwatch extends Serializable { + + @transient private var running: Boolean = false + private var startTime: Long = _ + + /** + * Name of the stopwatch. + */ + val name: String + + /** + * Starts the stopwatch. + * Throws an exception if the stopwatch is already running. + */ + def start(): Unit = { + assume(!running, "start() called but the stopwatch is already running.") + running = true + startTime = now + } + + /** + * Stops the stopwatch and returns the duration of the last session in milliseconds. + * Throws an exception if the stopwatch is not running. + */ + def stop(): Long = { + assume(running, "stop() called but the stopwatch is not running.") + val duration = now - startTime + add(duration) + running = false + duration + } + + /** + * Checks whether the stopwatch is running. + */ + def isRunning: Boolean = running + + /** + * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch + * is running. + */ + def elapsed(): Long + + /** + * Gets the current time in milliseconds. + */ + protected def now: Long = System.currentTimeMillis() + + /** + * Adds input duration to total elapsed time. 
+ */ + protected def add(duration: Long): Unit +} + +/** + * A local [[Stopwatch]]. + */ +private[spark] class LocalStopwatch(override val name: String) extends Stopwatch { + + private var elapsedTime: Long = 0L + + override def elapsed(): Long = elapsedTime + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A distributed [[Stopwatch]] using Spark accumulator. + * @param sc SparkContext + */ +private[spark] class DistributedStopwatch( + sc: SparkContext, + override val name: String) extends Stopwatch { + + private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)") + + override def elapsed(): Long = elapsedTime.value + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A multiple stopwatch that contains local and distributed stopwatches. + * @param sc SparkContext + */ +private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable { + + private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty + + /** + * Adds a local stopwatch. + * @param name stopwatch name + */ + def addLocal(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new LocalStopwatch(name) + this + } + + /** + * Adds a distributed stopwatch. + * @param name stopwatch name + */ + def addDistributed(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new DistributedStopwatch(sc, name) + this + } + + /** + * Gets a stopwatch. + * @param name stopwatch name + */ + def apply(name: String): Stopwatch = stopwatches(name) + + override def toString: String = { + stopwatches.values.toArray.sortBy(_.name) + .map(c => s" ${c.name}: ${c.elapsed()}ms") + .mkString("{\n", ",\n", "\n}") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala new file mode 100644 index 0000000000000..8df6617fe0228 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { + + private def testStopwatchOnDriver(sw: Stopwatch): Unit = { + assert(sw.name === "sw") + assert(sw.elapsed() === 0L) + assert(!sw.isRunning) + intercept[AssertionError] { + sw.stop() + } + sw.start() + Thread.sleep(50) + val duration = sw.stop() + assert(duration >= 50 && duration < 100) // using a loose upper bound + val elapsed = sw.elapsed() + assert(elapsed === duration) + sw.start() + Thread.sleep(50) + val duration2 = sw.stop() + assert(duration2 >= 50 && duration2 < 100) + val elapsed2 = sw.elapsed() + assert(elapsed2 === duration + duration2) + sw.start() + assert(sw.isRunning) + intercept[AssertionError] { + sw.start() + } + } + + test("LocalStopwatch") { + val sw = new LocalStopwatch("sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on driver") { + val sw = new DistributedStopwatch(sc, "sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on executors") { + val sw = new DistributedStopwatch(sc, "sw") + val rdd = sc.parallelize(0 until 4, 4) + rdd.foreach { i => + sw.start() + Thread.sleep(50) + sw.stop() + } + assert(!sw.isRunning) + val elapsed = sw.elapsed() + assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound + } + + test("MultiStopwatch") { + val sw = new MultiStopwatch(sc) + .addLocal("local") + .addDistributed("spark") + assert(sw("local").name === "local") + assert(sw("spark").name === "spark") + intercept[NoSuchElementException] { + sw("some") + } + assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}") + sw("local").start() + sw("spark").start() + Thread.sleep(50) + sw("local").stop() + Thread.sleep(50) + sw("spark").stop() + val localElapsed = sw("local").elapsed() + val sparkElapsed = sw("spark").elapsed() + assert(localElapsed >= 50 && localElapsed < 100) + assert(sparkElapsed >= 100 && sparkElapsed < 200) + assert(sw.toString === + s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") + val rdd = sc.parallelize(0 until 4, 4) + rdd.foreach { i => + sw("local").start() + sw("spark").start() + Thread.sleep(50) + sw("spark").stop() + sw("local").stop() + } + val localElapsed2 = sw("local").elapsed() + assert(localElapsed2 === localElapsed) + val sparkElapsed2 = sw("spark").elapsed() + assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600) + } +} From 9c64a75bfc5e2566d1b4cd0d9b4585a818086ca6 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 15 Jul 2015 21:08:30 -0700 Subject: [PATCH 25/27] [SPARK-9060] [SQL] Revert SPARK-8359, SPARK-8800, and SPARK-8677 JIRA: https://issues.apache.org/jira/browse/SPARK-9060 This PR reverts: * https://github.com/apache/spark/commit/31bd30687bc29c0e457c37308d489ae2b6e5b72a (SPARK-8359) * https://github.com/apache/spark/commit/24fda7381171738cbbbacb5965393b660763e562 (SPARK-8677) * https://github.com/apache/spark/commit/4b5cfc988f23988c2334882a255d494fc93d252e (SPARK-8800) Author: Yin Huai Closes #7426 from yhuai/SPARK-9060 and squashes the following commits: 651264d [Yin Huai] Revert "[SPARK-8359] [SQL] Fix incorrect decimal precision after multiplication" cfda7e4 [Yin Huai] Revert "[SPARK-8677] [SQL] Fix non-terminating decimal expansion for decimal divide operation" 2de9afe [Yin Huai] Revert "[SPARK-8800] [SQL] Fix inaccurate precision/scale of Decimal division operation" --- .../org/apache/spark/sql/types/Decimal.scala | 21 ++----------------- 
.../sql/types/decimal/DecimalSuite.scala | 18 ---------------- 2 files changed, 2 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index f5bd068d60dc4..a85af9e04aedb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import java.math.{MathContext, RoundingMode} - import org.apache.spark.annotation.DeveloperApi /** @@ -138,14 +136,6 @@ final class Decimal extends Ordered[Decimal] with Serializable { } def toBigDecimal: BigDecimal = { - if (decimalVal.ne(null)) { - decimalVal(MathContext.UNLIMITED) - } else { - BigDecimal(longVal, _scale)(MathContext.UNLIMITED) - } - } - - def toLimitedBigDecimal: BigDecimal = { if (decimalVal.ne(null)) { decimalVal } else { @@ -273,15 +263,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) - def / (that: Decimal): Decimal = { - if (that.isZero) { - null - } else { - // To avoid non-terminating decimal expansion problem, we get scala's BigDecimal with limited - // precision and scala. - Decimal(toLimitedBigDecimal / that.toLimitedBigDecimal) - } - } + def / (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) def % (that: Decimal): Decimal = if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index f0c849d1a1564..1d297beb3868d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -171,22 +171,4 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) } - - test("accurate precision after multiplication") { - val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal - assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249") - } - - test("fix non-terminating decimal expansion problem") { - val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3) - // The difference between decimal should not be more than 0.001. 
- assert(decimal.toDouble - 0.333 < 0.001) - } - - test("fix loss of precision/scale when doing division operation") { - val a = Decimal(2) / Decimal(3) - assert(a.toDouble < 1.0 && a.toDouble > 0.6) - val b = Decimal(1) / Decimal(8) - assert(b.toDouble === 0.125) - } } From 42dea3acf90ec506a0b79720b55ae1d753cc7544 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 15 Jul 2015 21:47:21 -0700 Subject: [PATCH 26/27] [SPARK-8245][SQL] FormatNumber/Length Support for Expression - `BinaryType` for `Length` - `FormatNumber` Author: Cheng Hao Closes #7034 from chenghao-intel/expression and squashes the following commits: e534b87 [Cheng Hao] python api style issue 601bbf5 [Cheng Hao] add python API support 3ebe288 [Cheng Hao] update as feedback 52274f7 [Cheng Hao] add support for udf_format_number and length for binary --- python/pyspark/sql/functions.py | 25 ++++- .../catalyst/analysis/FunctionRegistry.scala | 5 +- .../expressions/stringOperations.scala | 94 +++++++++++++++++-- .../expressions/StringFunctionsSuite.scala | 53 +++++++++-- .../org/apache/spark/sql/functions.scala | 32 ++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 93 +++++++++++++++--- 6 files changed, 261 insertions(+), 41 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dca39fa833435..e0816b3e654bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -39,6 +39,8 @@ 'coalesce', 'countDistinct', 'explode', + 'format_number', + 'length', 'log2', 'md5', 'monotonicallyIncreasingId', @@ -47,7 +49,6 @@ 'sha1', 'sha2', 'sparkPartitionId', - 'strlen', 'struct', 'udf', 'when'] @@ -506,14 +507,28 @@ def sparkPartitionId(): @ignore_unicode_prefix @since(1.5) -def strlen(col): - """Calculates the length of a string expression. +def length(col): + """Calculates the length of a string or binary expression. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect() + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() [Row(length=3)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.strlen(_to_java_column(col))) + return Column(sc._jvm.functions.length(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def format_number(col, d): + """Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + and returns the result as a string. 
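+ If d is 0, the result has no decimal point or fractional part.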
+ :param col: the column name of the numeric value to be formatted + :param d: the N decimal places + >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() + [Row(v=u'5.0000')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.format_number(_to_java_column(col), d)) @ignore_unicode_prefix diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d2678ce860701..e0beafe710079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -152,11 +152,12 @@ object FunctionRegistry { expression[Base64]("base64"), expression[Encode]("encode"), expression[Decode]("decode"), - expression[StringInstr]("instr"), + expression[FormatNumber]("format_number"), expression[Lower]("lcase"), expression[Lower]("lower"), - expression[StringLength]("length"), + expression[Length]("length"), expression[Levenshtein]("levenshtein"), + expression[StringInstr]("instr"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 03b55ce5fe7cc..c64afe7b3f19a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import java.text.DecimalFormat import java.util.Locale import java.util.regex.Pattern -import org.apache.commons.lang3.StringUtils - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -553,17 +552,22 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } /** - * A function that return the length of the given string expression. + * A function that return the length of the given string or binary expression. 
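+ * For example, the length of "Spark" is 5; for a binary value it is the number of bytes.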
*/ -case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) - protected override def nullSafeEval(string: Any): Any = - string.asInstanceOf[UTF8String].numChars + protected override def nullSafeEval(value: Any): Any = child.dataType match { + case StringType => value.asInstanceOf[UTF8String].numChars + case BinaryType => value.asInstanceOf[Array[Byte]].length + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).numChars()") + child.dataType match { + case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") + case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") + } } override def prettyName: String = "length" @@ -668,3 +672,77 @@ case class Encode(value: Expression, charset: Expression) } } +/** + * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, + * and returns the result as a string. If D is 0, the result has no decimal point or + * fractional part. + */ +case class FormatNumber(x: Expression, d: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = x + override def right: Expression = d + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + + // Associated with the pattern, for the last d value, and we will update the + // pattern (DecimalFormat) once the new coming d value differ with the last one. + @transient + private var lastDValue: Int = -100 + + // A cached DecimalFormat, for performance concern, we will change it + // only if the d value changed. 
+ @transient + private val pattern: StringBuffer = new StringBuffer() + + @transient + private val numberFormat: DecimalFormat = new DecimalFormat("") + + override def eval(input: InternalRow): Any = { + val xObject = x.eval(input) + if (xObject == null) { + return null + } + + val dObject = d.eval(input) + + if (dObject == null || dObject.asInstanceOf[Int] < 0) { + return null + } + val dValue = dObject.asInstanceOf[Int] + + if (dValue != lastDValue) { + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length()) + pattern.append("#,###,###,###,###,###,##0") + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + val dFormat = new DecimalFormat(pattern.toString()) + lastDValue = dValue; + numberFormat.applyPattern(dFormat.toPattern()) + } + + x.dataType match { + case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte])) + case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short])) + case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float])) + case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int])) + case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long])) + case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double])) + case _: DecimalType => + UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal)) + } + } + + override def prettyName: String = "format_number" +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index b19f4ee37a109..5d7763bedf6bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType} +import org.apache.spark.sql.types._ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("length for string") { - val a = 'a.string.at(0) - checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef")) - checkEvaluation(StringLength(a), 5, create_row("abdef")) - checkEvaluation(StringLength(a), 0, create_row("")) - checkEvaluation(StringLength(a), null, create_row(null)) - checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) - } - test("ascii for string") { val a = 'a.string.at(0) checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef")) @@ -426,4 +417,46 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) } + + test("length for string / binary") { + val a = 'a.string.at(0) + val b = 'b.binary.at(0) + val bytes = Array[Byte](1, 2, 3, 1, 2) + val string = "abdef" + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. 
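+ // "a花花c" is 4 characters but 8 bytes in UTF-8: Length counts characters for StringType and bytes for BinaryType.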
+ checkEvaluation(Length(Literal("a花花c")), 4, create_row(string)) + // scalastyle:on + checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]())) + + checkEvaluation(Length(a), 5, create_row(string)) + checkEvaluation(Length(b), 5, create_row(bytes)) + + checkEvaluation(Length(a), 0, create_row("")) + checkEvaluation(Length(b), 0, create_row(Array[Byte]())) + + checkEvaluation(Length(a), null, create_row(null)) + checkEvaluation(Length(b), null, create_row(null)) + + checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string)) + checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) + } + + test("number format") { + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null) + checkEvaluation( + FormatNumber( + Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)), + "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) + checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c7deaca8437a1..d6da284a4c788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1685,20 +1685,44 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the length of a given string value. + * Computes the length of a given string / binary value. * * @group string_funcs * @since 1.5.0 */ - def strlen(e: Column): Column = StringLength(e.expr) + def length(e: Column): Column = Length(e.expr) /** - * Computes the length of a given string column. + * Computes the length of a given string / binary column. * * @group string_funcs * @since 1.5.0 */ - def strlen(columnName: String): Column = strlen(Column(columnName)) + def length(columnName: String): Column = length(Column(columnName)) + + /** + * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string. + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) + + /** + * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string. + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def format_number(columnXName: String, d: Int): Column = { + format_number(Column(columnXName), d) + } /** * Computes the Levenshtein distance of the two given strings. 
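An illustrative usage sketch for the two new DataFrame functions (not part of the patch; it assumes a SQLContext is in scope with import sqlContext.implicits._, as in the test suite below):

    import org.apache.spark.sql.functions.{format_number, length}

    val df = Seq(("Spark", 3.14159)).toDF("s", "x")
    df.select(length("s"), format_number("x", 2)).collect()
    // expected values: 5 for length("s") and "3.14" for format_number("x", 2)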
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 70bd78737f69c..6dccdd857b453 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -208,17 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest { Row(2743272264L, 2180413220L)) } - test("string length function") { - val df = Seq(("abc", "")).toDF("a", "b") - checkAnswer( - df.select(strlen($"a"), strlen("b")), - Row(3, 0)) - - checkAnswer( - df.selectExpr("length(a)", "length(b)"), - Row(3, 0)) - } - test("Levenshtein distance") { val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) @@ -433,11 +422,91 @@ class DataFrameFunctionsSuite extends QueryTest { val doubleData = Seq((7.2, 4.1)).toDF("a", "b") checkAnswer( doubleData.select(pmod('a, 'b)), - Seq(Row(3.1000000000000005)) // same as hive + Seq(Row(3.1000000000000005)) // same as hive ) checkAnswer( doubleData.select(pmod(lit(2), lit(Int.MaxValue))), Seq(Row(2)) ) } + + test("string / binary length function") { + val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") + checkAnswer( + df.select(length($"a"), length("a"), length($"b"), length("b")), + Row(3, 3, 4, 4)) + + checkAnswer( + df.selectExpr("length(a)", "length(b)"), + Row(3, 4)) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("length(c)"), // int type of the argument is unacceptable + Row("5.0000")) + } + } + + test("number format function") { + val tuple = + ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], + 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) + val df = + Seq(tuple) + .toDF( + "a", // string "aa" + "b", // byte 1 + "c", // short 2 + "d", // float 3.13223f + "e", // integer 4 + "f", // long 5L + "g", // double 6.48173d + "h") // decimal 7.128381 + + checkAnswer( + df.select( + format_number($"f", 4), + format_number("f", 4)), + Row("5.0000", "5.0000")) + + checkAnswer( + df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + Row("1.0000")) + + checkAnswer( + df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + Row("2.0000")) + + checkAnswer( + df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + Row("3.1322")) + + checkAnswer( + df.selectExpr("format_number(e, e)"), // not convert anything + Row("4.0000")) + + checkAnswer( + df.selectExpr("format_number(f, e)"), // not convert anything + Row("5.0000")) + + checkAnswer( + df.selectExpr("format_number(g, e)"), // not convert anything + Row("6.4817")) + + checkAnswer( + df.selectExpr("format_number(h, e)"), // not convert anything + Row("7.1284")) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable + Row("5.0000")) + } + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable + Row("5.0000")) + } + } } From ba33096846dc8061e97a7bf8f3b46f899d530159 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Jul 2015 22:27:39 -0700 Subject: [PATCH 27/27] [SPARK-9068][SQL] refactor the implicit type cast code based on https://github.com/apache/spark/pull/7348 Author: Wenchen Fan Closes #7420 from cloud-fan/type-check and squashes the following commits: 7633fa9 [Wenchen Fan] 
From ba33096846dc8061e97a7bf8f3b46f899d530159 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 15 Jul 2015 22:27:39 -0700
Subject: [PATCH 27/27] [SPARK-9068][SQL] refactor the implicit type cast code

based on https://github.com/apache/spark/pull/7348

Author: Wenchen Fan

Closes #7420 from cloud-fan/type-check and squashes the following commits:

7633fa9 [Wenchen Fan] revert
fe169b0 [Wenchen Fan] improve test
03b70da [Wenchen Fan] enhance implicit type cast
---
 .../catalyst/analysis/HiveTypeCoercion.scala  | 33 +++-----
 .../sql/catalyst/expressions/Expression.scala | 20 +++--
 .../sql/catalyst/expressions/arithmetic.scala |  2 -
 .../sql/catalyst/expressions/bitwise.scala    |  8 +-
 .../catalyst/expressions/conditionals.scala   |  4 +-
 .../spark/sql/types/AbstractDataType.scala    | 45 +++--------
 .../apache/spark/sql/types/ArrayType.scala    |  2 +-
 .../org/apache/spark/sql/types/DataType.scala |  2 +-
 .../apache/spark/sql/types/DecimalType.scala  |  2 +-
 .../org/apache/spark/sql/types/MapType.scala  |  2 +-
 .../apache/spark/sql/types/StructType.scala   |  2 +-
 .../ExpressionTypeCheckingSuite.scala         | 75 +++++++++----------
 .../analysis/HiveTypeCoercionSuite.scala      | 10 +--
 13 files changed, 81 insertions(+), 126 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 25087915b5c35..50db7d21f01ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -675,10 +675,10 @@ object HiveTypeCoercion {
       case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
         findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
           if (b.inputType.acceptsType(commonType)) {
-            // If the expression accepts the tighest common type, cast to that.
+            // If the expression accepts the tightest common type, cast to that.
             val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
             val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
-            b.makeCopy(Array(newLeft, newRight))
+            b.withNewChildren(Seq(newLeft, newRight))
           } else {
             // Otherwise, don't do anything with the expression.
             b
@@ -697,7 +697,7 @@ object HiveTypeCoercion {
         // general implicit casting.
         val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
           if (in.dataType == NullType && !expected.acceptsType(NullType)) {
-            Cast(in, expected.defaultConcreteType)
+            Literal.create(null, expected.defaultConcreteType)
           } else {
             in
           }
@@ -719,27 +719,22 @@ object HiveTypeCoercion {
      @Nullable
      val ret: Expression = (inType, expectedType) match {
        // If the expected type is already a parent of the input type, no need to cast.
-      case _ if expectedType.isSameType(inType) => e
+      case _ if expectedType.acceptsType(inType) => e
 
       // Cast null type (usually from null literals) into target types
       case (NullType, target) => Cast(e, target.defaultConcreteType)
 
-      // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is
-      // already a number, leave it as is.
-      case (_: NumericType, NumericType) => e
-
       // If the function accepts any numeric type and the input is a string, we follow the hive
       // convention and cast that input into a double
       case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)
 
-      // Implicit cast among numeric types
+      // Implicit cast among numeric types. When we reach here, input type is not acceptable.
+
       // If input is a numeric type but not decimal, and we expect a decimal type,
       // cast the input to unlimited precision decimal.
-      case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
-        Cast(e, DecimalType.Unlimited)
+      case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
 
       // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
-      case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
-      case (_: NumericType, target: NumericType) => e
+      case (_: NumericType, target: NumericType) => Cast(e, target)
 
       // Implicit cast between date time types
       case (DateType, TimestampType) => Cast(e, TimestampType)
@@ -753,15 +748,9 @@ object HiveTypeCoercion {
       case (StringType, BinaryType) => Cast(e, BinaryType)
       case (any, StringType) if any != StringType => Cast(e, StringType)
 
-      // Type collection.
-      // First see if we can find our input type in the type collection. If we can, then just
-      // use the current expression; otherwise, find the first one we can implicitly cast.
-      case (_, TypeCollection(types)) =>
-        if (types.exists(_.isSameType(inType))) {
-          e
-        } else {
-          types.flatMap(implicitCast(e, _)).headOption.orNull
-        }
+      // When we reach here, input type is not acceptable for any types in this type collection,
+      // try to find the first one we can implicitly cast.
+      case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull
 
       // Else, just return the same input expression
       case _ => null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 87667316aca67..a655cc8e48ae1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -386,17 +386,15 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
   override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    // First call the checker for ExpectsInputTypes, and then check whether left and right have
-    // the same type.
-    super.checkInputDataTypes() match {
-      case TypeCheckResult.TypeCheckSuccess =>
-        if (left.dataType != right.dataType) {
-          TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
-            s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
-        } else {
-          TypeCheckResult.TypeCheckSuccess
-        }
-      case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg)
+    // First check whether left and right have the same type, then check if the type is acceptable.
+    if (left.dataType != right.dataType) {
+      TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+        s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
+    } else if (!inputType.acceptsType(left.dataType)) {
+      TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," +
+        s" not ${left.dataType.simpleString}")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 394ef556e04a2..382cbe3b84a07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -320,7 +320,6 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
   }
 
   override def symbol: String = "max"
-  override def prettyName: String = symbol
 }
 
 case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -375,7 +374,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
   }
 
   override def symbol: String = "min"
-  override def prettyName: String = symbol
 }
 
 case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
index af1abbcd2239b..a1e48c4210877 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
  */
 case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
 
-  override def inputType: AbstractDataType = TypeCollection.Bitwise
+  override def inputType: AbstractDataType = IntegralType
 
   override def symbol: String = "&"
 
@@ -53,7 +53,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithme
  */
 case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
 
-  override def inputType: AbstractDataType = TypeCollection.Bitwise
+  override def inputType: AbstractDataType = IntegralType
 
   override def symbol: String = "|"
 
@@ -78,7 +78,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
  */
 case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
 
-  override def inputType: AbstractDataType = TypeCollection.Bitwise
+  override def inputType: AbstractDataType = IntegralType
 
   override def symbol: String = "^"
 
@@ -101,7 +101,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise)
+  override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
 
   override def dataType: DataType = child.dataType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index c7f039ede26b3..9162b73fe56eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
       TypeCheckResult.TypeCheckFailure(
         s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
     } else if (trueValue.dataType != falseValue.dataType) {
-      TypeCheckResult.TypeCheckFailure(
-        s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).")
+      TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+        s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
     } else {
       TypeCheckResult.TypeCheckSuccess
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index f5715f7a829ff..076d7b5a5118d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType {
   private[sql] def defaultConcreteType: DataType
 
   /**
-   * Returns true if this data type is the same type as `other`. This is different that equality
-   * as equality will also consider data type parametrization, such as decimal precision.
+   * Returns true if `other` is an acceptable input type for a function that expects this,
+   * possibly abstract DataType.
    *
    * {{{
    *   // this should return true
-   *   DecimalType.isSameType(DecimalType(10, 2))
-   *
-   *   // this should return false
-   *   NumericType.isSameType(DecimalType(10, 2))
-   * }}}
-   */
-  private[sql] def isSameType(other: DataType): Boolean
-
-  /**
-   * Returns true if `other` is an acceptable input type for a function that expectes this,
-   * possibly abstract, DataType.
-   *
-   * {{{
-   *   // this should return true
-   *   DecimalType.isSameType(DecimalType(10, 2))
+   *   DecimalType.acceptsType(DecimalType(10, 2))
    *
    *   // this should return true as well
    *   NumericType.acceptsType(DecimalType(10, 2))
    * }}}
    */
-  private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)
+  private[sql] def acceptsType(other: DataType): Boolean
 
   /** Readable string representation for the type. */
   private[sql] def simpleString: String
@@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
 
   override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType
 
-  override private[sql] def isSameType(other: DataType): Boolean = false
-
   override private[sql] def acceptsType(other: DataType): Boolean =
-    types.exists(_.isSameType(other))
+    types.exists(_.acceptsType(other))
 
   override private[sql] def simpleString: String = {
     types.map(_.simpleString).mkString("(", " or ", ")")
@@ -107,13 +91,6 @@ private[sql] object TypeCollection {
     TimestampType, DateType,
     StringType, BinaryType)
 
-  /**
-   * Types that can be used in bitwise operations.
-   */
-  val Bitwise = TypeCollection(
-    BooleanType,
-    ByteType, ShortType, IntegerType, LongType)
-
   def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
 
   def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
@@ -134,8 +111,6 @@ protected[sql] object AnyDataType extends AbstractDataType {
 
   override private[sql] def simpleString: String = "any"
 
-  override private[sql] def isSameType(other: DataType): Boolean = false
-
   override private[sql] def acceptsType(other: DataType): Boolean = true
 }
 
@@ -183,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType {
 
   override private[sql] def simpleString: String = "numeric"
 
-  override private[sql] def isSameType(other: DataType): Boolean = false
-
   override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]
 }
 
 
-private[sql] object IntegralType {
+private[sql] object IntegralType extends AbstractDataType {
   /**
    * Enables matching against IntegralType for expressions:
    * {{{
@@ -198,6 +171,12 @@ private[sql] object IntegralType {
    *   }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
+
+  override private[sql] def defaultConcreteType: DataType = IntegerType
+
+  override private[sql] def simpleString: String = "integral"
+
+  override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType]
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 76ca7a84c1d1a..5094058164b2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType {
   override private[sql] def defaultConcreteType: DataType =
     ArrayType(NullType, containsNull = true)
 
-  override private[sql] def isSameType(other: DataType): Boolean = {
+  override private[sql] def acceptsType(other: DataType): Boolean = {
     other.isInstanceOf[ArrayType]
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index da83a7f0ba379..2d133eea19fe0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = this
 
-  override private[sql] def isSameType(other: DataType): Boolean = this == other
+  override private[sql] def acceptsType(other: DataType): Boolean = this == other
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index a1cafeab1704d..377c75f6e85a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = Unlimited
 
-  override private[sql] def isSameType(other: DataType): Boolean = {
+  override private[sql] def acceptsType(other: DataType): Boolean = {
     other.isInstanceOf[DecimalType]
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index ddead10bc2171..ac34b642827ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -71,7 +71,7 @@ object MapType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType)
 
-  override private[sql] def isSameType(other: DataType): Boolean = {
+  override private[sql] def acceptsType(other: DataType): Boolean = {
     other.isInstanceOf[MapType]
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index b8097403ec3cc..2ef97a427c37e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -307,7 +307,7 @@ object StructType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = new StructType
 
-  override private[sql] def isSameType(other: DataType): Boolean = {
+  override private[sql] def acceptsType(other: DataType): Boolean = {
     other.isInstanceOf[StructType]
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index a4ce1825cab28..ed0d20e7de80e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{TypeCollection, StringType}
 
 class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
@@ -49,23 +49,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
   def assertErrorForDifferingTypes(expr: Expression): Unit = {
     assertError(expr,
-      s"differing types in '${expr.prettyString}' (int and boolean)")
-  }
-
-  def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = {
-    val e = intercept[AnalysisException] {
-      assertSuccess(expr)
-    }
-    assert(e.getMessage.contains(errorMessage))
+      s"differing types in '${expr.prettyString}'")
   }
 
   test("check types for unary arithmetic") {
     assertError(UnaryMinus('stringField), "expected to be of type numeric")
     assertError(Abs('stringField), "expected to be of type numeric")
-    assertError(BitwiseNot('stringField), "type (boolean or tinyint or smallint or int or bigint)")
+    assertError(BitwiseNot('stringField), "expected to be of type integral")
   }
 
-  ignore("check types for binary arithmetic") {
+  test("check types for binary arithmetic") {
     // We will cast String to Double for binary arithmetic
     assertSuccess(Add('intField, 'stringField))
     assertSuccess(Subtract('intField, 'stringField))
@@ -85,21 +78,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
     assertErrorForDifferingTypes(MinOf('intField, 'booleanField))
 
-    assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type")
-    assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type")
-    assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type")
-    assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type")
-    assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type")
+    assertError(Add('booleanField, 'booleanField), "accepts numeric type")
+    assertError(Subtract('booleanField, 'booleanField), "accepts numeric type")
+    assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
+    assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
+    assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")
 
-    assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type")
-    assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type")
-    assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type")
+    assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral type")
+    assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral type")
+    assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral type")
 
-    assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type")
-    assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type")
+    assertError(MaxOf('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
+    assertError(MinOf('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
   }
 
-  ignore("check types for predicates") {
+  test("check types for predicates") {
     // We will cast String to Double for binary comparison
     assertSuccess(EqualTo('intField, 'stringField))
     assertSuccess(EqualNullSafe('intField, 'stringField))
@@ -112,25 +107,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertSuccess(EqualTo('intField, 'booleanField))
     assertSuccess(EqualNullSafe('intField, 'booleanField))
 
-    assertError(EqualTo('intField, 'complexField), "differing types")
-    assertError(EqualNullSafe('intField, 'complexField), "differing types")
-
+    assertErrorForDifferingTypes(EqualTo('intField, 'complexField))
+    assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField))
     assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
     assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
    assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
     assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
 
-    assertError(
-      LessThan('complexField, 'complexField), "operator < accepts non-complex type")
-    assertError(
-      LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type")
-    assertError(
-      GreaterThan('complexField, 'complexField), "operator > accepts non-complex type")
-    assertError(
-      GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type")
+    assertError(LessThan('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
+    assertError(LessThanOrEqual('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
+    assertError(GreaterThan('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
+    assertError(GreaterThanOrEqual('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
 
-    assertError(
-      If('intField, 'stringField, 'stringField),
+    assertError(If('intField, 'stringField, 'stringField),
       "type of predicate expression in If should be boolean")
 
     assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
@@ -180,12 +173,12 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
   }
 
   test("check types for ROUND") {
-    assertErrorWithImplicitCast(Round(Literal(null), 'booleanField),
-      "data type mismatch: argument 2 is expected to be of type int")
-    assertErrorWithImplicitCast(Round(Literal(null), 'complexField),
-      "data type mismatch: argument 2 is expected to be of type int")
     assertSuccess(Round(Literal(null), Literal(null)))
-    assertError(Round('booleanField, 'intField),
-      "data type mismatch: argument 1 is expected to be of type numeric")
+    assertSuccess(Round('intField, Literal(1)))
+
+    assertError(Round('intField, 'intField), "Only foldable Expression is allowed")
+    assertError(Round('intField, 'booleanField), "expected to be of type int")
+    assertError(Round('intField, 'complexField), "expected to be of type int")
+    assertError(Round('booleanField, 'intField), "expected to be of type numeric")
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 8e9b20a3ebe42..d0fd033b981c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -203,7 +203,7 @@ class HiveTypeCoercionSuite extends PlanTest {
 
     ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
       NumericTypeUnaryExpression(Literal.create(null, NullType)),
-      NumericTypeUnaryExpression(Cast(Literal.create(null, NullType), DoubleType)))
+      NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
   }
 
   test("cast NullType for binary operators") {
@@ -215,9 +215,7 @@ class HiveTypeCoercionSuite extends PlanTest {
 
     ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
       NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
-      NumericTypeBinaryOperator(
-        Cast(Literal.create(null, NullType), DoubleType),
-        Cast(Literal.create(null, NullType), DoubleType)))
+      NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
   }
 
   test("coalesce casts") {
@@ -345,14 +343,14 @@ object HiveTypeCoercionSuite {
   }
 
   case class AnyTypeBinaryOperator(left: Expression, right: Expression)
-    extends BinaryOperator with ExpectsInputTypes {
+    extends BinaryOperator {
    override def dataType: DataType = NullType
    override def inputType: AbstractDataType = AnyDataType
    override def symbol: String = "anytype"
   }
 
   case class NumericTypeBinaryOperator(left: Expression, right: Expression)
-    extends BinaryOperator with ExpectsInputTypes {
+    extends BinaryOperator {
    override def dataType: DataType = NullType
    override def inputType: AbstractDataType = NumericType
    override def symbol: String = "numerictype"