From ba33096846dc8061e97a7bf8f3b46f899d530159 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Jul 2015 22:27:39 -0700 Subject: [PATCH] [SPARK-9068][SQL] refactor the implicit type cast code based on https://github.com/apache/spark/pull/7348 Author: Wenchen Fan Closes #7420 from cloud-fan/type-check and squashes the following commits: 7633fa9 [Wenchen Fan] revert fe169b0 [Wenchen Fan] improve test 03b70da [Wenchen Fan] enhance implicit type cast --- .../catalyst/analysis/HiveTypeCoercion.scala | 33 +++----- .../sql/catalyst/expressions/Expression.scala | 20 +++-- .../sql/catalyst/expressions/arithmetic.scala | 2 - .../sql/catalyst/expressions/bitwise.scala | 8 +- .../catalyst/expressions/conditionals.scala | 4 +- .../spark/sql/types/AbstractDataType.scala | 45 +++-------- .../apache/spark/sql/types/ArrayType.scala | 2 +- .../org/apache/spark/sql/types/DataType.scala | 2 +- .../apache/spark/sql/types/DecimalType.scala | 2 +- .../org/apache/spark/sql/types/MapType.scala | 2 +- .../apache/spark/sql/types/StructType.scala | 2 +- .../ExpressionTypeCheckingSuite.scala | 75 +++++++++---------- .../analysis/HiveTypeCoercionSuite.scala | 10 +-- 13 files changed, 81 insertions(+), 126 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 25087915b5c35..50db7d21f01ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -675,10 +675,10 @@ object HiveTypeCoercion { case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tighest common type, cast to that. + // If the expression accepts the tightest common type, cast to that. val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.makeCopy(Array(newLeft, newRight)) + b.withNewChildren(Seq(newLeft, newRight)) } else { // Otherwise, don't do anything with the expression. b @@ -697,7 +697,7 @@ object HiveTypeCoercion { // general implicit casting. val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => if (in.dataType == NullType && !expected.acceptsType(NullType)) { - Cast(in, expected.defaultConcreteType) + Literal.create(null, expected.defaultConcreteType) } else { in } @@ -719,27 +719,22 @@ object HiveTypeCoercion { @Nullable val ret: Expression = (inType, expectedType) match { // If the expected type is already a parent of the input type, no need to cast. - case _ if expectedType.isSameType(inType) => e + case _ if expectedType.acceptsType(inType) => e // Cast null type (usually from null literals) into target types case (NullType, target) => Cast(e, target.defaultConcreteType) - // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is - // already a number, leave it as is. - case (_: NumericType, NumericType) => e - // If the function accepts any numeric type and the input is a string, we follow the hive // convention and cast that input into a double case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType) - // Implicit cast among numeric types + // Implicit cast among numeric types. When we reach here, input type is not acceptable. + // If input is a numeric type but not decimal, and we expect a decimal type, // cast the input to unlimited precision decimal. - case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] => - Cast(e, DecimalType.Unlimited) + case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited) // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long - case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target) - case (_: NumericType, target: NumericType) => e + case (_: NumericType, target: NumericType) => Cast(e, target) // Implicit cast between date time types case (DateType, TimestampType) => Cast(e, TimestampType) @@ -753,15 +748,9 @@ object HiveTypeCoercion { case (StringType, BinaryType) => Cast(e, BinaryType) case (any, StringType) if any != StringType => Cast(e, StringType) - // Type collection. - // First see if we can find our input type in the type collection. If we can, then just - // use the current expression; otherwise, find the first one we can implicitly cast. - case (_, TypeCollection(types)) => - if (types.exists(_.isSameType(inType))) { - e - } else { - types.flatMap(implicitCast(e, _)).headOption.orNull - } + // When we reach here, input type is not acceptable for any types in this type collection, + // try to find the first one we can implicitly cast. + case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull // Else, just return the same input expression case _ => null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 87667316aca67..a655cc8e48ae1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -386,17 +386,15 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType) override def checkInputDataTypes(): TypeCheckResult = { - // First call the checker for ExpectsInputTypes, and then check whether left and right have - // the same type. - super.checkInputDataTypes() match { - case TypeCheckResult.TypeCheckSuccess => - if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + - s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") - } else { - TypeCheckResult.TypeCheckSuccess - } - case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg) + // First check whether left and right have the same type, then check if the type is acceptable. + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") + } else if (!inputType.acceptsType(left.dataType)) { + TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," + + s" not ${left.dataType.simpleString}") + } else { + TypeCheckResult.TypeCheckSuccess } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 394ef556e04a2..382cbe3b84a07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -320,7 +320,6 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } override def symbol: String = "max" - override def prettyName: String = symbol } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { @@ -375,7 +374,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } override def symbol: String = "min" - override def prettyName: String = symbol } case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala index af1abbcd2239b..a1e48c4210877 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = TypeCollection.Bitwise + override def inputType: AbstractDataType = IntegralType override def symbol: String = "&" @@ -53,7 +53,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme */ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = TypeCollection.Bitwise + override def inputType: AbstractDataType = IntegralType override def symbol: String = "|" @@ -78,7 +78,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = TypeCollection.Bitwise + override def inputType: AbstractDataType = IntegralType override def symbol: String = "^" @@ -101,7 +101,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme */ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise) + override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index c7f039ede26b3..9162b73fe56eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( s"type of predicate expression in If should be boolean, not ${predicate.dataType}") } else if (trueValue.dataType != falseValue.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") + TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index f5715f7a829ff..076d7b5a5118d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType { private[sql] def defaultConcreteType: DataType /** - * Returns true if this data type is the same type as `other`. This is different that equality - * as equality will also consider data type parametrization, such as decimal precision. + * Returns true if `other` is an acceptable input type for a function that expects this, + * possibly abstract DataType. * * {{{ * // this should return true - * DecimalType.isSameType(DecimalType(10, 2)) - * - * // this should return false - * NumericType.isSameType(DecimalType(10, 2)) - * }}} - */ - private[sql] def isSameType(other: DataType): Boolean - - /** - * Returns true if `other` is an acceptable input type for a function that expectes this, - * possibly abstract, DataType. - * - * {{{ - * // this should return true - * DecimalType.isSameType(DecimalType(10, 2)) + * DecimalType.acceptsType(DecimalType(10, 2)) * * // this should return true as well * NumericType.acceptsType(DecimalType(10, 2)) * }}} */ - private[sql] def acceptsType(other: DataType): Boolean = isSameType(other) + private[sql] def acceptsType(other: DataType): Boolean /** Readable string representation for the type. */ private[sql] def simpleString: String @@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType - override private[sql] def isSameType(other: DataType): Boolean = false - override private[sql] def acceptsType(other: DataType): Boolean = - types.exists(_.isSameType(other)) + types.exists(_.acceptsType(other)) override private[sql] def simpleString: String = { types.map(_.simpleString).mkString("(", " or ", ")") @@ -107,13 +91,6 @@ private[sql] object TypeCollection { TimestampType, DateType, StringType, BinaryType) - /** - * Types that can be used in bitwise operations. - */ - val Bitwise = TypeCollection( - BooleanType, - ByteType, ShortType, IntegerType, LongType) - def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { @@ -134,8 +111,6 @@ protected[sql] object AnyDataType extends AbstractDataType { override private[sql] def simpleString: String = "any" - override private[sql] def isSameType(other: DataType): Boolean = false - override private[sql] def acceptsType(other: DataType): Boolean = true } @@ -183,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType { override private[sql] def simpleString: String = "numeric" - override private[sql] def isSameType(other: DataType): Boolean = false - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] } -private[sql] object IntegralType { +private[sql] object IntegralType extends AbstractDataType { /** * Enables matching against IntegralType for expressions: * {{{ @@ -198,6 +171,12 @@ private[sql] object IntegralType { * }}} */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] + + override private[sql] def defaultConcreteType: DataType = IntegerType + + override private[sql] def simpleString: String = "integral" + + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 76ca7a84c1d1a..5094058164b2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[ArrayType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index da83a7f0ba379..2d133eea19fe0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = this - override private[sql] def isSameType(other: DataType): Boolean = this == other + override private[sql] def acceptsType(other: DataType): Boolean = this == other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index a1cafeab1704d..377c75f6e85a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = Unlimited - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[DecimalType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index ddead10bc2171..ac34b642827ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -71,7 +71,7 @@ object MapType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[MapType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b8097403ec3cc..2ef97a427c37e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -307,7 +307,7 @@ object StructType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = new StructType - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[StructType] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index a4ce1825cab28..ed0d20e7de80e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{TypeCollection, StringType} class ExpressionTypeCheckingSuite extends SparkFunSuite { @@ -49,23 +49,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { def assertErrorForDifferingTypes(expr: Expression): Unit = { assertError(expr, - s"differing types in '${expr.prettyString}' (int and boolean)") - } - - def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = { - val e = intercept[AnalysisException] { - assertSuccess(expr) - } - assert(e.getMessage.contains(errorMessage)) + s"differing types in '${expr.prettyString}'") } test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "expected to be of type numeric") assertError(Abs('stringField), "expected to be of type numeric") - assertError(BitwiseNot('stringField), "type (boolean or tinyint or smallint or int or bigint)") + assertError(BitwiseNot('stringField), "expected to be of type integral") } - ignore("check types for binary arithmetic") { + test("check types for binary arithmetic") { // We will cast String to Double for binary arithmetic assertSuccess(Add('intField, 'stringField)) assertSuccess(Subtract('intField, 'stringField)) @@ -85,21 +78,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") - assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") - assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") - assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") - assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") + assertError(Add('booleanField, 'booleanField), "accepts numeric type") + assertError(Subtract('booleanField, 'booleanField), "accepts numeric type") + assertError(Multiply('booleanField, 'booleanField), "accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "accepts numeric type") + assertError(Remainder('booleanField, 'booleanField), "accepts numeric type") - assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") - assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") - assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral type") - assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") - assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") + assertError(MaxOf('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(MinOf('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") } - ignore("check types for predicates") { + test("check types for predicates") { // We will cast String to Double for binary comparison assertSuccess(EqualTo('intField, 'stringField)) assertSuccess(EqualNullSafe('intField, 'stringField)) @@ -112,25 +107,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(EqualTo('intField, 'booleanField)) assertSuccess(EqualNullSafe('intField, 'booleanField)) - assertError(EqualTo('intField, 'complexField), "differing types") - assertError(EqualNullSafe('intField, 'complexField), "differing types") - + assertErrorForDifferingTypes(EqualTo('intField, 'complexField)) + assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField)) assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError( - LessThan('complexField, 'complexField), "operator < accepts non-complex type") - assertError( - LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") - assertError( - GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") - assertError( - GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + assertError(LessThan('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(LessThanOrEqual('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(GreaterThan('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(GreaterThanOrEqual('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") - assertError( - If('intField, 'stringField, 'stringField), + assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) @@ -180,12 +173,12 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for ROUND") { - assertErrorWithImplicitCast(Round(Literal(null), 'booleanField), - "data type mismatch: argument 2 is expected to be of type int") - assertErrorWithImplicitCast(Round(Literal(null), 'complexField), - "data type mismatch: argument 2 is expected to be of type int") assertSuccess(Round(Literal(null), Literal(null))) - assertError(Round('booleanField, 'intField), - "data type mismatch: argument 1 is expected to be of type numeric") + assertSuccess(Round('intField, Literal(1))) + + assertError(Round('intField, 'intField), "Only foldable Expression is allowed") + assertError(Round('intField, 'booleanField), "expected to be of type int") + assertError(Round('intField, 'complexField), "expected to be of type int") + assertError(Round('booleanField, 'intField), "expected to be of type numeric") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 8e9b20a3ebe42..d0fd033b981c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -203,7 +203,7 @@ class HiveTypeCoercionSuite extends PlanTest { ruleTest(HiveTypeCoercion.ImplicitTypeCasts, NumericTypeUnaryExpression(Literal.create(null, NullType)), - NumericTypeUnaryExpression(Cast(Literal.create(null, NullType), DoubleType))) + NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } test("cast NullType for binary operators") { @@ -215,9 +215,7 @@ class HiveTypeCoercionSuite extends PlanTest { ruleTest(HiveTypeCoercion.ImplicitTypeCasts, NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), - NumericTypeBinaryOperator( - Cast(Literal.create(null, NullType), DoubleType), - Cast(Literal.create(null, NullType), DoubleType))) + NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } test("coalesce casts") { @@ -345,14 +343,14 @@ object HiveTypeCoercionSuite { } case class AnyTypeBinaryOperator(left: Expression, right: Expression) - extends BinaryOperator with ExpectsInputTypes { + extends BinaryOperator { override def dataType: DataType = NullType override def inputType: AbstractDataType = AnyDataType override def symbol: String = "anytype" } case class NumericTypeBinaryOperator(left: Expression, right: Expression) - extends BinaryOperator with ExpectsInputTypes { + extends BinaryOperator { override def dataType: DataType = NullType override def inputType: AbstractDataType = NumericType override def symbol: String = "numerictype"