Skip to content

Commit

Permalink
Fix Constant Folding Bugs & Add More Unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed Apr 30, 2014
1 parent b28e03a commit 27ea3d7
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,37 +114,37 @@ package object dsl {
def attr = analysis.UnresolvedAttribute(s)

/** Creates a new AttributeReference of type boolean */
def boolean = AttributeReference(s, BooleanType, nullable = false)()
def boolean = AttributeReference(s, BooleanType, nullable = true)()

/** Creates a new AttributeReference of type byte */
def byte = AttributeReference(s, ByteType, nullable = false)()
def byte = AttributeReference(s, ByteType, nullable = true)()

/** Creates a new AttributeReference of type short */
def short = AttributeReference(s, ShortType, nullable = false)()
def short = AttributeReference(s, ShortType, nullable = true)()

/** Creates a new AttributeReference of type int */
def int = AttributeReference(s, IntegerType, nullable = false)()
def int = AttributeReference(s, IntegerType, nullable = true)()

/** Creates a new AttributeReference of type long */
def long = AttributeReference(s, LongType, nullable = false)()
def long = AttributeReference(s, LongType, nullable = true)()

/** Creates a new AttributeReference of type float */
def float = AttributeReference(s, FloatType, nullable = false)()
def float = AttributeReference(s, FloatType, nullable = true)()

/** Creates a new AttributeReference of type double */
def double = AttributeReference(s, DoubleType, nullable = false)()
def double = AttributeReference(s, DoubleType, nullable = true)()

/** Creates a new AttributeReference of type string */
def string = AttributeReference(s, StringType, nullable = false)()
def string = AttributeReference(s, StringType, nullable = true)()

/** Creates a new AttributeReference of type decimal */
def decimal = AttributeReference(s, DecimalType, nullable = false)()
def decimal = AttributeReference(s, DecimalType, nullable = true)()

/** Creates a new AttributeReference of type timestamp */
def timestamp = AttributeReference(s, TimestampType, nullable = false)()
def timestamp = AttributeReference(s, TimestampType, nullable = true)()

/** Creates a new AttributeReference of type binary */
def binary = AttributeReference(s, BinaryType, nullable = false)()
def binary = AttributeReference(s, BinaryType, nullable = true)()
}

implicit class DslAttribute(a: AttributeReference) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,6 @@ abstract class Expression extends TreeNode[Expression] {
}
}

/**
* Root class for rewritten 2 operands UDF expression. By default, we assume it produces Null if
* either one of its operands is null. Exceptional case requires to update the optimization rule
* at [[optimizer.ConstantFolding ConstantFolding]]
*/
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
self: Product =>

Expand All @@ -243,11 +238,6 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
self: Product =>
}

/**
* Root class for rewritten single operand UDF expression. By default, we assume it produces Null
* if its operand is null. Exceptional case requires to update the optimization rule
* at [[optimizer.ConstantFolding ConstantFolding]]
*/
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,33 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
override def toString = s"$child[$ordinal]"

override def eval(input: Row): Any = {
if (child.dataType.isInstanceOf[ArrayType]) {
val baseValue = child.eval(input).asInstanceOf[Seq[_]]
val o = ordinal.eval(input).asInstanceOf[Int]
if (baseValue == null) {
null
} else if (o >= baseValue.size || o < 0) {
null
} else {
baseValue(o)
}
val value = child.eval(input)
if(value == null) {
null
} else {
val baseValue = child.eval(input).asInstanceOf[Map[Any, _]]
val key = ordinal.eval(input)
if (baseValue == null) {
if(key == null) {
null
} else {
baseValue.get(key).orNull
if (child.dataType.isInstanceOf[ArrayType]) {
val baseValue = value.asInstanceOf[Seq[_]]
val o = key.asInstanceOf[Int]
if (baseValue == null) {
null
} else if (o >= baseValue.size || o < 0) {
null
} else {
baseValue(o)
}
} else {
val baseValue = value.asInstanceOf[Map[Any, _]]
val key = ordinal.eval(input)
if (baseValue == null) {
null
} else {
baseValue.get(key).orNull
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ object NullPropagation extends Rule[LogicalPlan] {
case q: LogicalPlan => q transformExpressionsUp {
// Skip redundant folding of literals.
case l: Literal => l
case e @ Count(Literal(null, _)) => Literal(null, e.dataType)
case e @ Count(Literal(null, _)) => Literal(0, e.dataType)
case e @ Sum(Literal(null, _)) => Literal(null, e.dataType)
case e @ Average(Literal(null, _)) => Literal(null, e.dataType)
case e @ IsNull(c @ Rand) => Literal(false, BooleanType)
case e @ IsNotNull(c @ Rand) => Literal(true, BooleanType)
case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType)
case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
Expand All @@ -122,13 +122,32 @@ object NullPropagation extends Rule[LogicalPlan] {
case Literal(candidate, _) if(candidate == v) => true
case _ => false
})) => Literal(true, BooleanType)
// Put exceptional cases(Unary & Binary Expression if it doesn't produce null with constant
// null operand) before here.
case e: UnaryExpression => e.child match {
case e: UnaryMinus => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: BinaryExpression => e.children match {
case e: Cast => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: Not => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: And => e // leave it for BooleanSimplification
case e: Or => e // leave it for BooleanSimplification
// Put exceptional cases above
case e: BinaryArithmetic => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
}
case e: BinaryPredicate => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
}
case e: StringRegexExpression => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class ExpressionEvaluationSuite extends FunSuite {

test("LIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType).like("a"), null)
checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)
checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null)
checkEvaluation("abdef" like "abdef", true)
checkEvaluation("a_%b" like "a\\__b", true)
Expand Down Expand Up @@ -157,9 +158,14 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))

checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%")))
}

test("RLIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType) rlike "abdef", null)
checkEvaluation("abdef" rlike Literal(null, StringType), null)
checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null)
checkEvaluation("abdef" rlike "abdef", true)
checkEvaluation("abbbbc" rlike "a.*c", true)

Expand Down Expand Up @@ -244,17 +250,19 @@ class ExpressionEvaluationSuite extends FunSuite {

intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}

assert(("abcdef" cast StringType).nullable === false)
assert(("abcdef" cast BinaryType).nullable === false)
assert(("abcdef" cast BooleanType).nullable === false)
assert(("abcdef" cast TimestampType).nullable === true)
assert(("abcdef" cast LongType).nullable === true)
assert(("abcdef" cast IntegerType).nullable === true)
assert(("abcdef" cast ShortType).nullable === true)
assert(("abcdef" cast ByteType).nullable === true)
assert(("abcdef" cast DecimalType).nullable === true)
assert(("abcdef" cast DoubleType).nullable === true)
assert(("abcdef" cast FloatType).nullable === true)
checkEvaluation(("abcdef" cast StringType).nullable, false)
checkEvaluation(("abcdef" cast BinaryType).nullable,false)
checkEvaluation(("abcdef" cast BooleanType).nullable, false)
checkEvaluation(("abcdef" cast TimestampType).nullable, true)
checkEvaluation(("abcdef" cast LongType).nullable, true)
checkEvaluation(("abcdef" cast IntegerType).nullable, true)
checkEvaluation(("abcdef" cast ShortType).nullable, true)
checkEvaluation(("abcdef" cast ByteType).nullable, true)
checkEvaluation(("abcdef" cast DecimalType).nullable, true)
checkEvaluation(("abcdef" cast DoubleType).nullable, true)
checkEvaluation(("abcdef" cast FloatType).nullable, true)

checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null)
}

test("timestamp") {
Expand Down Expand Up @@ -285,5 +293,108 @@ class ExpressionEvaluationSuite extends FunSuite {
// A test for higher precision than millis
checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001)
}

test("null checking") {
val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
val c3 = 'a.boolean.at(2)
val c4 = 'a.boolean.at(3)

checkEvaluation(IsNull(c1), false, row)
checkEvaluation(IsNotNull(c1), true, row)

checkEvaluation(IsNull(c2), true, row)
checkEvaluation(IsNotNull(c2), false, row)

checkEvaluation(IsNull(Literal(1, ShortType)), false)
checkEvaluation(IsNotNull(Literal(1, ShortType)), true)

checkEvaluation(IsNull(Literal(null, ShortType)), true)
checkEvaluation(IsNotNull(Literal(null, ShortType)), false)

checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row)

checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row)
checkEvaluation(If(c3, c1, c2), "^Ba*n", row)
checkEvaluation(If(c4, c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row)
checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(false, BooleanType),
Literal("a", StringType), Literal("b", StringType)), "b", row)

checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
}

test("complex type") {
val row = new GenericRow(Array[Any](
"^Ba*n", // 0
null.asInstanceOf[String], // 1
new GenericRow(Array[Any]("aa", "bb")), // 2
Map("aa"->"bb"), // 3
Seq("aa", "bb") // 4
))

val typeS = StructType(
StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
)
val typeMap = MapType(StringType, StringType)
val typeArray = ArrayType(StringType)

checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
Literal("aa")), "bb", row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row)
checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
Literal(null, StringType)), null, row)

checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
Literal(1)), "bb", row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row)
checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
Literal(null, IntegerType)), null, row)

checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row)
checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
}

test("arithmetic") {
val row = new GenericRow(Array[Any](1, 2, 3, null))
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.int.at(3)

checkEvaluation(UnaryMinus(c1), -1, row)
checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100)

checkEvaluation(Add(c1, c4), null, row)
checkEvaluation(Add(c1, c2), 3, row)
checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row)
checkEvaluation(Add(Literal(null, IntegerType), c2), null, row)
checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
}

test("BinaryComparison") {
val row = new GenericRow(Array[Any](1, 2, 3, null))
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.int.at(3)

checkEvaluation(LessThan(c1, c4), null, row)
checkEvaluation(LessThan(c1, c2), true, row)
checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row)
checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row)
checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
}
}

0 comments on commit 27ea3d7

Please sign in to comment.