Skip to content

Commit

Permalink
enhance implicit type cast
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Jul 15, 2015
1 parent adb33d3 commit 03b70da
Show file tree
Hide file tree
Showing 13 changed files with 70 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ object HiveTypeCoercion {
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
if (b.inputType.acceptsType(commonType)) {
// If the expression accepts the tighest common type, cast to that.
// If the expression accepts the tightest common type, cast to that.
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
b.makeCopy(Array(newLeft, newRight))
Expand Down Expand Up @@ -713,27 +713,22 @@ object HiveTypeCoercion {
@Nullable val ret: Expression = (inType, expectedType) match {

// If the expected type is already a parent of the input type, no need to cast.
case _ if expectedType.isSameType(inType) => e
case _ if expectedType.acceptsType(inType) => e

// Cast null type (usually from null literals) into target types
case (NullType, target) => Cast(e, target.defaultConcreteType)

// If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is
// already a number, leave it as is.
case (_: NumericType, NumericType) => e

// If the function accepts any numeric type and the input is a string, we follow the hive
// convention and cast that input into a double
case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)

// Implicit cast among numeric types
// Implicit cast among numeric types. When we reach here, input type is not acceptable.

// If input is a numeric type but not decimal, and we expect a decimal type,
// cast the input to unlimited precision decimal.
case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
Cast(e, DecimalType.Unlimited)
case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
// For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
case (_: NumericType, target: NumericType) => e
case (_: NumericType, target: NumericType) => Cast(e, target)

// Implicit cast between date time types
case (DateType, TimestampType) => Cast(e, TimestampType)
Expand All @@ -747,15 +742,9 @@ object HiveTypeCoercion {
case (StringType, BinaryType) => Cast(e, BinaryType)
case (any, StringType) if any != StringType => Cast(e, StringType)

// Type collection.
// First see if we can find our input type in the type collection. If we can, then just
// use the current expression; otherwise, find the first one we can implicitly cast.
case (_, TypeCollection(types)) =>
if (types.exists(_.isSameType(inType))) {
e
} else {
types.flatMap(implicitCast(e, _)).headOption.orNull
}
// When we reach here, input type is not acceptable for any types in this type collection,
// try to find the first one we can implicitly cast.
case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull

// Else, no implicit cast applies: return null to signal that the input cannot be cast
case _ => null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
* 2. Two inputs are expected to be the same type. If the two inputs have different types,
* the analyzer will find the tightest common type and do the proper type casting.
*/
abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
abstract class BinaryOperator extends BinaryExpression {
self: Product =>

/**
Expand All @@ -366,20 +366,16 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {

override def toString: String = s"($left $symbol $right)"

override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)

override def checkInputDataTypes(): TypeCheckResult = {
// First call the checker for ExpectsInputTypes, and then check whether left and right have
// the same type.
super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
if (left.dataType != right.dataType) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
} else {
TypeCheckResult.TypeCheckSuccess
}
case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg)
// First check whether left and right have the same type, then check if the type is acceptable.
if (left.dataType != right.dataType) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
} else if (!inputType.acceptsType(left.dataType)) {
TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," +
s" not ${left.dataType.simpleString}")
} else {
TypeCheckResult.TypeCheckSuccess
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}

override def symbol: String = "max"
override def prettyName: String = symbol
}

case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
Expand Down Expand Up @@ -375,5 +374,4 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}

override def symbol: String = "min"
override def prettyName: String = symbol
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
*/
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = TypeCollection.Bitwise
override def inputType: AbstractDataType = IntegralType

override def symbol: String = "&"

Expand All @@ -53,7 +53,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
*/
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = TypeCollection.Bitwise
override def inputType: AbstractDataType = IntegralType

override def symbol: String = "|"

Expand All @@ -78,7 +78,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = TypeCollection.Bitwise
override def inputType: AbstractDataType = IntegralType

override def symbol: String = "^"

Expand All @@ -101,7 +101,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
*/
case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise)
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)

override def dataType: DataType = child.dataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
TypeCheckResult.TypeCheckFailure(
s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
} else if (trueValue.dataType != falseValue.dataType) {
TypeCheckResult.TypeCheckFailure(
s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).")
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
} else {
TypeCheckResult.TypeCheckSuccess
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType {
private[sql] def defaultConcreteType: DataType

/**
* Returns true if this data type is the same type as `other`. This is different that equality
* as equality will also consider data type parametrization, such as decimal precision.
* Returns true if `other` is an acceptable input type for a function that expects this,
* possibly abstract, DataType.
*
* {{{
* // this should return true
* DecimalType.isSameType(DecimalType(10, 2))
*
* // this should return false
* NumericType.isSameType(DecimalType(10, 2))
* }}}
*/
private[sql] def isSameType(other: DataType): Boolean

/**
* Returns true if `other` is an acceptable input type for a function that expectes this,
* possibly abstract, DataType.
*
* {{{
* // this should return true
* DecimalType.isSameType(DecimalType(10, 2))
* DecimalType.acceptsType(DecimalType(10, 2))
*
* // this should return true as well
* NumericType.acceptsType(DecimalType(10, 2))
* }}}
*/
private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)
private[sql] def acceptsType(other: DataType): Boolean

/** Readable string representation for the type. */
private[sql] def simpleString: String
Expand All @@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])

override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType

override private[sql] def isSameType(other: DataType): Boolean = false

override private[sql] def acceptsType(other: DataType): Boolean =
types.exists(_.isSameType(other))
types.exists(_.acceptsType(other))

override private[sql] def simpleString: String = {
types.map(_.simpleString).mkString("(", " or ", ")")
Expand All @@ -107,13 +91,6 @@ private[sql] object TypeCollection {
TimestampType, DateType,
StringType, BinaryType)

/**
* Types that can be used in bitwise operations.
*/
val Bitwise = TypeCollection(
BooleanType,
ByteType, ShortType, IntegerType, LongType)

def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)

def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
Expand All @@ -134,8 +111,6 @@ protected[sql] object AnyDataType extends AbstractDataType {

override private[sql] def simpleString: String = "any"

override private[sql] def isSameType(other: DataType): Boolean = false

override private[sql] def acceptsType(other: DataType): Boolean = true
}

Expand Down Expand Up @@ -183,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType {

override private[sql] def simpleString: String = "numeric"

override private[sql] def isSameType(other: DataType): Boolean = false

override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]
}


private[sql] object IntegralType {
private[sql] object IntegralType extends AbstractDataType {
/**
* Enables matching against IntegralType for expressions:
* {{{
Expand All @@ -198,6 +171,12 @@ private[sql] object IntegralType {
* }}}
*/
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]

override private[sql] def defaultConcreteType: DataType = IntegerType

override private[sql] def simpleString: String = "integral"

override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType]
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)

override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[ArrayType]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = this

override private[sql] def isSameType(other: DataType): Boolean = this == other
override private[sql] def acceptsType(other: DataType): Boolean = this == other
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = Unlimited

override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[DecimalType]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ object MapType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType)

override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[MapType]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ object StructType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = new StructType

override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[StructType]
}

Expand Down
Loading

0 comments on commit 03b70da

Please sign in to comment.