diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 6ee479939d25c..d111578530506 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -98,13 +98,19 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate { override def eval(input: Row): Any = { val l = left.eval(input) - val r = right.eval(input) - if (l == false || r == false) { - false - } else if (l == null || r == null ) { - null + if (l == false) { + false } else { - true + val r = right.eval(input) + if (r == false) { + false + } else { + if (l != null && r != null) { + true + } else { + null + } + } } } } @@ -114,13 +120,19 @@ case class Or(left: Expression, right: Expression) extends BinaryPredicate { override def eval(input: Row): Any = { val l = left.eval(input) - val r = right.eval(input) - if (l == true || r == true) { + if (l == true) { true - } else if (l == null || r == null) { - null } else { - false + val r = right.eval(input) + if (r == true) { + true + } else { + if (l != null && r != null) { + false + } else { + null + } + } } } } @@ -133,8 +145,12 @@ case class Equals(left: Expression, right: Expression) extends BinaryComparison def symbol = "=" override def eval(input: Row): Any = { val l = left.eval(input) - val r = right.eval(input) - if (l == null || r == null) null else l == r + if (l == null) { + null + } else { + val r = right.eval(input) + if (r == null) null else l == r + } } } @@ -162,7 +178,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi extends Expression { def children = predicate :: trueValue :: falseValue :: Nil - def nullable = trueValue.nullable || falseValue.nullable + override def nullable = trueValue.nullable || falseValue.nullable def references = children.flatMap(_.references).toSet override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType def dataType = { @@ -175,8 +191,9 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } type EvaluatedType = Any + override def eval(input: Row): Any = { - if (predicate.eval(input).asInstanceOf[Boolean]) { + if (true == predicate.eval(input)) { trueValue.eval(input) } else { falseValue.eval(input) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index d50e2c65b7b36..572902042337f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -248,17 +248,31 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression]) isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable) } + protected lazy val deferedObjects = Array.fill[DeferredObject](children.length)({ + new DeferredObjectAdapter + }) + + // Adapter from Catalyst ExpressionResult to Hive DeferredObject + class DeferredObjectAdapter extends DeferredObject { + private var func: () => Any = _ + def set(func: () => Any) { + this.func = func + } + override def prepare(i: Int) = {} + override def get(): AnyRef = wrap(func()) + } + val dataType: DataType = inspectorToDataType(returnInspector) override def eval(input: Row): Any = { returnInspector // Make sure initialized. 
- val args = children.map { v => - new DeferredObject { - override def prepare(i: Int) = {} - override def get(): AnyRef = wrap(v.eval(input)) - } - }.toArray - unwrap(function.evaluate(args)) + var i = 0 + while (i < children.length) { + val idx = i + deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(() => {children(idx).eval(input)}) + i += 1 + } + unwrap(function.evaluate(deferedObjects)) } }
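
Note on the predicates.scala hunks above: the rewrite gives And, Or and Equals short-circuit evaluation under SQL three-valued logic, so the right child is only evaluated when the left child cannot already decide the result, and a null on either side yields null unless the other side forces the answer. The following is a minimal standalone sketch of those semantics, not Spark code; ShortCircuitSketch, evalAnd and evalOr are illustrative names introduced here.

// Standalone sketch of the short-circuit, three-valued semantics the patch
// gives And/Or. The second argument is by-name so it is only evaluated when
// the left side cannot already decide the result.
object ShortCircuitSketch {
  def evalAnd(l: Any, r: => Any): Any = {
    if (l == false) {
      false                                   // left decides: right is never evaluated
    } else {
      val rv = r
      if (rv == false) false
      else if (l != null && rv != null) true
      else null                               // TRUE AND NULL, NULL AND NULL => NULL
    }
  }

  def evalOr(l: Any, r: => Any): Any = {
    if (l == true) {
      true                                    // left decides: right is never evaluated
    } else {
      val rv = r
      if (rv == true) true
      else if (l != null && rv != null) false
      else null                               // FALSE OR NULL, NULL OR NULL => NULL
    }
  }

  def main(args: Array[String]): Unit = {
    // The right side would throw if evaluated, demonstrating the short circuit.
    println(evalAnd(false, sys.error("right side must not be evaluated")))  // false
    println(evalOr(true,   sys.error("right side must not be evaluated")))  // true
    println(evalAnd(null, true))   // null
    println(evalOr(null, false))   // null
  }
}

This also illustrates why the If hunk compares with true == predicate.eval(input): a null predicate result now falls through to the falseValue branch instead of failing the cast to Boolean.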
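
Note on the hiveUdfs.scala hunk: DeferredObjectAdapter lets the UDF allocate one adapter per child expression once and, for each row, merely re-point its thunk, instead of building a fresh array of anonymous DeferredObjects on every call to eval. The sketch below shows that reuse pattern with a local stand-in for Hive's GenericUDF.DeferredObject so it compiles without a Hive dependency (and it omits the wrap call, which is Hive-specific); everything outside the diff is an illustrative name.

// Sketch of the adapter-reuse pattern: adapters are allocated once,
// only their thunks are swapped per row.
object DeferredAdapterSketch {
  // Local stand-in for org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject.
  trait DeferredObject {
    def prepare(i: Int): Unit
    def get(): AnyRef
  }

  // Mirrors DeferredObjectAdapter: holds a thunk that is re-pointed for each row.
  class DeferredObjectAdapter extends DeferredObject {
    private var func: () => Any = _
    def set(f: () => Any): Unit = { func = f }
    override def prepare(i: Int): Unit = {}
    override def get(): AnyRef = func().asInstanceOf[AnyRef]
  }

  def main(args: Array[String]): Unit = {
    // Stand-ins for the child expressions: each maps a "row" to a value.
    val children: Array[Int => Any] = Array(row => row + 1, row => row * 2)

    // Allocated once and reused for every row -- the point of the patch.
    val adapters = Array.fill(children.length)(new DeferredObjectAdapter)

    for (row <- Seq(1, 2, 3)) {
      var i = 0
      while (i < children.length) {
        val idx = i                             // capture a stable index for the closure
        adapters(i).set(() => children(idx)(row))
        i += 1
      }
      println(adapters.map(_.get()).mkString(", "))
    }
  }
}

The val idx = i capture matches the same pattern in the diff: the closure must close over a value that does not change as the while loop advances i.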