From f42c732febd5fa720ab72dfa0fc442c847012b8c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 6 Jun 2015 00:23:31 -0700 Subject: [PATCH] improve coverage and tests --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 40 ++--- .../sql/catalyst/expressions/arithmetic.scala | 96 +++++++----- .../expressions/codegen/CodeGenerator.scala | 14 +- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateOrdering.scala | 16 +- .../codegen/GeneratePredicate.scala | 2 +- .../codegen/GenerateProjection.scala | 8 +- .../expressions/decimalFunctions.scala | 12 +- .../sql/catalyst/expressions/literals.scala | 44 ++++-- .../expressions/mathfuncs/binary.scala | 24 ++- .../expressions/mathfuncs/unary.scala | 30 +++- .../expressions/namedExpressions.scala | 6 +- .../catalyst/expressions/nullFunctions.scala | 26 ++-- .../sql/catalyst/expressions/predicates.scala | 139 ++++++++++++++---- .../spark/sql/catalyst/expressions/sets.scala | 20 +-- .../expressions/stringOperations.scala | 18 +++ .../ExpressionEvaluationSuite.scala | 86 ++++++++++- .../GeneratedEvaluationSuite.scala | 27 +--- .../GeneratedMutableEvaluationSuite.scala | 61 -------- 20 files changed, 440 insertions(+), 237 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 478ee997a96a2..00fd7294a8966 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -45,8 +45,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { 
s""" - final boolean ${ev.nullTerm} = i.isNullAt($ordinal); - final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ? + boolean ${ev.isNull} = i.isNullAt($ordinal); + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)}); """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6866b1182e0da..87f864a7c0d9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -60,9 +60,9 @@ abstract class Expression extends TreeNode[Expression] { * @return [[GeneratedExpressionCode]] */ def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - val nullTerm = ctx.freshName("nullTerm") - val primitiveTerm = ctx.freshName("primitiveTerm") - val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm) + val isNull = ctx.freshName("isNull") + val primitive = ctx.freshName("primitive") + val ve = GeneratedExpressionCode("", isNull, primitive) ve.code = genCode(ctx, ve) ve } @@ -82,11 +82,11 @@ abstract class Expression extends TreeNode[Expression] { val objectTerm = ctx.freshName("obj") s""" /* expression: ${this} */ - final Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i); - final boolean ${ev.nullTerm} = ${objectTerm} == null; - ${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)}; - if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${objectTerm}; + Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.isNull} = ${objectTerm} == null; + ${ctx.primitiveType(e.dataType)} ${ev.primitive} = ${ctx.defaultValue(e.dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = 
(${ctx.boxedType(e.dataType)})${objectTerm}; } """ } @@ -175,18 +175,18 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) + val resultCode = f(eval1.primitive, eval2.primitive) s""" ${eval1.code} - boolean ${ev.nullTerm} = ${eval1.nullTerm}; - ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; - if (!${ev.nullTerm}) { + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { ${eval2.code} - if(!${eval2.nullTerm}) { - ${ev.primitiveTerm} = $resultCode; + if(!${eval2.isNull}) { + ${ev.primitive} = $resultCode; } else { - ${ev.nullTerm} = true; + ${ev.isNull} = true; } } """ @@ -216,12 +216,12 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio ev: GeneratedExpressionCode, f: Term => Code): Code = { val eval = child.gen(ctx) - // reuse the previous nullTerm - ev.nullTerm = eval.nullTerm + // reuse the previous isNull + ev.isNull = eval.isNull eval.code + s""" - ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; - if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = ${f(eval.primitiveTerm)}; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = ${f(eval.primitive)}; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 0923ab6f59564..c161a514fcd4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -50,6 +50,11 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { private 
lazy val numeric = TypeUtils.getNumeric(dataType) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)") + } + protected override def evalInternal(evalE: Any) = numeric.negate(evalE) } @@ -68,6 +73,21 @@ case class Sqrt(child: Expression) extends UnaryArithmetic { if (value < 0) null else math.sqrt(value) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + if (${eval.primitive} < 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive}); + } + } + """ + } } /** @@ -216,9 +236,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { - s"${eval2.primitiveTerm}.isZero()" + s"${eval2.primitive}.isZero()" } else { - s"${eval2.primitiveTerm} == 0" + s"${eval2.primitive} == 0" } val method = if (left.dataType.isInstanceOf[DecimalType]) { s".$decimalMethod" @@ -227,12 +247,12 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } eval1.code + eval2.code + s""" - boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { - ${ev.nullTerm} = true; + boolean ${ev.isNull} = false; + ${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; + if (${eval1.isNull} || ${eval2.isNull} || $test) { + ${ev.isNull} = true; } else { - ${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm}); + 
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); } """ } @@ -276,9 +296,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { - s"${eval2.primitiveTerm}.isZero()" + s"${eval2.primitive}.isZero()" } else { - s"${eval2.primitiveTerm} == 0" + s"${eval2.primitive} == 0" } val method = if (left.dataType.isInstanceOf[DecimalType]) { s".$decimalMethod" @@ -287,12 +307,12 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } eval1.code + eval2.code + s""" - boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { - ${ev.nullTerm} = true; + boolean ${ev.isNull} = false; + ${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; + if (${eval1.isNull} || ${eval2.isNull} || $test) { + ${ev.isNull} = true; } else { - ${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm}); + ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); } """ } @@ -387,6 +407,10 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic { ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dataType)})~($c)") + } + protected override def evalInternal(evalE: Any) = not(evalE) } @@ -419,21 +443,21 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) eval1.code + eval2.code + s""" - boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + boolean ${ev.isNull} = false; + ${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; - if 
(${eval1.nullTerm}) { - ${ev.nullTerm} = ${eval2.nullTerm}; - ${ev.primitiveTerm} = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - ${ev.nullTerm} = ${eval1.nullTerm}; - ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; } else { - if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + if (${eval1.primitive} > ${eval2.primitive}) { + ${ev.primitive} = ${eval1.primitive}; } else { - ${ev.primitiveTerm} = ${eval2.primitiveTerm}; + ${ev.primitive} = ${eval2.primitive}; } } """ @@ -475,21 +499,21 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { val eval2 = right.gen(ctx) eval1.code + eval2.code + s""" - boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + boolean ${ev.isNull} = false; + ${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.nullTerm}) { - ${ev.nullTerm} = ${eval2.nullTerm}; - ${ev.primitiveTerm} = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - ${ev.nullTerm} = ${eval1.nullTerm}; - ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; } else { - if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + if (${eval1.primitive} < ${eval2.primitive}) { + ${ev.primitive} = ${eval1.primitive}; } else { - ${ev.primitiveTerm} = ${eval2.primitiveTerm}; + ${ev.primitive} = ${eval2.primitive}; } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f6a2a2be1c89f..94b1b4808d759 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -35,12 +35,12 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * Java source for evaluating an [[Expression]] given a [[Row]] of input. * * @param code The sequence of statements required to evaluate the expression. - * @param nullTerm A term that holds a boolean value representing whether the expression evaluated + * @param isNull A term that holds a boolean value representing whether the expression evaluated * to null. - * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not - * valid if `nullTerm` is set to `true`. + * @param primitive A term for a possible primitive value of the result of the evaluation. Not + * valid if `isNull` is set to `true`. 
*/ -case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, var primitiveTerm: Term) +case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported @@ -149,9 +149,9 @@ class CodeGenContext { def defaultValue(dt: DataType): Term = dt match { case BooleanType => "false" case FloatType => "-1.0f" - case ShortType => "-1" - case LongType => "-1" - case ByteType => "-1" + case ShortType => "(short)-1" + case LongType => "-1L" + case ByteType => "(byte)-1" case DoubleType => "-1.0" case IntegerType => "-1" case DateType => "-1" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4b641701008c3..e5ee2accd8a84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -40,10 +40,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val evaluationCode = e.gen(ctx) evaluationCode.code + s""" - if(${evaluationCode.nullTerm}) + if(${evaluationCode.isNull}) mutableRow.setNullAt($i); else - mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitiveTerm)}; + mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index d3c219fddc53c..36e155d164a40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -59,8 +59,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit case BinaryType => s""" { - byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm}; - byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm}; + byte[] x = ${if (asc) evalA.primitive else evalB.primitive}; + byte[] y = ${if (!asc) evalA.primitive else evalB.primitive}; int j = 0; while (j < x.length && j < y.length) { if (x[j] != y[j]) return x[j] - y[j]; @@ -73,8 +73,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit }""" case _: NumericType => s""" - if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) { - if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) { + if (${evalA.primitive} != ${evalB.primitive}) { + if (${evalA.primitive} > ${evalB.primitive}) { return ${if (asc) "1" else "-1"}; } else { return ${if (asc) "-1" else "1"}; @@ -82,7 +82,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit }""" case _ => s""" - int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm}); + int comp = ${evalA.primitive}.compare(${evalB.primitive}); if (comp != 0) { return ${if (asc) "comp" else "-comp"}; }""" @@ -93,11 +93,11 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit ${evalA.code} i = $b; ${evalB.code} - if (${evalA.nullTerm} && ${evalB.nullTerm}) { + if (${evalA.isNull} && ${evalB.isNull}) { // Nothing - } else if (${evalA.nullTerm}) { + } else if (${evalA.isNull}) { return ${if (order.direction == Ascending) "-1" else "1"}; - } else if (${evalB.nullTerm}) { + } else if (${evalB.isNull}) { return ${if (order.direction == Ascending) "1" else "-1"}; } else { $compare diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index dd4474de05df9..4a547b5ce9543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -55,7 +55,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { @Override public boolean eval(Row i) { ${eval.code} - return !${eval.nullTerm} && ${eval.primitiveTerm}; + return !${eval.isNull} && ${eval.primitive}; } }""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 00c856dc02ba1..f621c894833c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -55,9 +55,9 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { { // column$i ${eval.code} - nullBits[$i] = ${eval.nullTerm}; - if(!${eval.nullTerm}) { - c$i = ${eval.primitiveTerm}; + nullBits[$i] = ${eval.isNull}; + if (!${eval.isNull}) { + c$i = ${eval.primitive}; } } """ @@ -122,7 +122,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case LongType => s"$col ^ ($col >>> 32)" case FloatType => s"Float.floatToIntBits($col)" case DoubleType => - s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)" + s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" case _ => s"$col.hashCode()" } s"isNullAt($i) ? 
0 : ($nonNull)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 21f8c812c9ce5..ddfadf314f838 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -62,13 +62,13 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) eval.code + s""" - boolean ${ev.nullTerm} = ${eval.nullTerm}; - ${ctx.decimalType} ${ev.primitiveTerm} = null; + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.decimalType} ${ev.primitive} = null; - if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = (new ${ctx.decimalType}()).setOrNull( - ${eval.primitiveTerm}, $precision, $scale); - ${ev.nullTerm} = ${ev.primitiveTerm} == null; + if (!${ev.isNull}) { + ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull( + ${eval.primitive}, $precision, $scale); + ${ev.isNull} = ${ev.primitive} == null; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e121d39e1d9b4..bce96bd3c1309 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -79,27 +79,53 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def toString: String = if (value != null) value.toString else "null" + override def equals(other: Any): Boolean = other match { + case o: Literal => + dataType.equals(o.dataType) && + (value == null && null == o.value || value != null && value.equals(o.value)) + 
case _ => false + } + override def eval(input: Row): Any = value override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - // change the nullTerm and primitiveTerm to consts, to inline them + // change the isNull and primitive to consts, to inline them if (value == null) { - ev.nullTerm = "true" - ev.primitiveTerm = ctx.defaultValue(dataType) + ev.isNull = "true" + ev.primitive = ctx.defaultValue(dataType) "" } else { dataType match { case BooleanType => - ev.nullTerm = "false" - ev.primitiveTerm = value.toString + ev.isNull = "false" + ev.primitive = value.toString "" case FloatType => // This must go before NumericType - ev.nullTerm = "false" - ev.primitiveTerm = s"${value}f" + val v = value.asInstanceOf[Float] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = s"${value}f" + "" + } + case DoubleType => // This must go before NumericType + val v = value.asInstanceOf[Double] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = s"${value}" + "" + } + + case ByteType | ShortType => // This must go before NumericType + ev.isNull = "false" + ev.primitive = s"(${ctx.primitiveType(dataType)})$value" "" case dt: NumericType if !dt.isInstanceOf[DecimalType] => - ev.nullTerm = "false" - ev.primitiveTerm = value.toString + ev.isNull = "false" + ev.primitive = value.toString "" // eval() version may be faster for non-primitive types case other => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index db853a2b97fad..88211acd7713c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -17,6 +17,7 @@ package 
org.apache.spark.sql.catalyst.expressions.mathfuncs +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} import org.apache.spark.sql.types._ @@ -49,6 +50,10 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + } } case class Atan2(left: Expression, right: Expression) @@ -70,9 +75,26 @@ case class Atan2(left: Expression, right: Expression) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } } case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") -case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") +case class Pow(left: Expression, right: Expression) + extends BinaryMathExpression(math.pow, "POWER") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index 41b422346a02d..ad49c376e981e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -17,6 +17,7 @@ package 
org.apache.spark.sql.catalyst.expressions.mathfuncs +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} import org.apache.spark.sql.types._ @@ -44,6 +45,23 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) if (result.isNaN) null else result } } + + // name of function in java.lang.Math + def funcName: String = name.toLowerCase + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + } + """ + } } case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") @@ -72,7 +90,9 @@ case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") -case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") +case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { + override def funcName: String = "rint" +} case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") @@ -84,6 +104,10 @@ case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") -case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") +case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { + override def funcName: String = "toDegrees" +} -case class ToRadians(child: Expression) extends 
UnaryMathExpression(math.toRadians, "RADIANS") +case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { + override def funcName: String = "toRadians" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 00565ec651a59..2e4b9ba678433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.trees.LeafNode +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types._ object NamedExpression { @@ -116,6 +116,8 @@ case class Alias(child: Expression, name: String)( override def eval(input: Row): Any = child.eval(input) + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable override def metadata: Metadata = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index e3c3489d11aea..ea216b1d0d9f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -55,17 +55,17 @@ case class Coalesce(children: Seq[Expression]) extends Expression { 
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { s""" - boolean ${ev.nullTerm} = true; - ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; + boolean ${ev.isNull} = true; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; """ + children.map { e => val eval = e.gen(ctx) s""" - if (${ev.nullTerm}) { + if (${ev.isNull}) { ${eval.code} - if (!${eval.nullTerm}) { - ${ev.nullTerm} = false; - ${ev.primitiveTerm} = ${eval.primitiveTerm}; + if (!${eval.isNull}) { + ${ev.isNull} = false; + ${ev.primitive} = ${eval.primitive}; } } """ @@ -83,8 +83,8 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) - ev.nullTerm = "false" - ev.primitiveTerm = eval.nullTerm + ev.isNull = "false" + ev.primitive = eval.isNull eval.code } @@ -102,8 +102,8 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) - ev.nullTerm = "false" - ev.primitiveTerm = s"(!(${eval.nullTerm}))" + ev.isNull = "false" + ev.primitive = s"(!(${eval.isNull}))" eval.code } } @@ -137,7 +137,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate s""" if ($nonnull < $n) { ${eval.code} - if (!${eval.nullTerm}) { + if (!${eval.isNull}) { $nonnull += 1; } } @@ -146,8 +146,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate s""" int $nonnull = 0; $code - boolean ${ev.nullTerm} = false; - boolean ${ev.primitiveTerm} = $nonnull >= $n; + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $nonnull >= $n; """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d69324acf0e5a..75af8d71dbd31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -154,17 +154,17 @@ case class And(left: Expression, right: Expression) // The result should be `false`, if any of them is `false` whenever the other is null or not. s""" ${eval1.code} - boolean ${ev.nullTerm} = false; - boolean ${ev.primitiveTerm} = false; + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = false; - if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { + if (!${eval1.isNull} && !${eval1.primitive}) { } else { ${eval2.code} - if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - ${ev.primitiveTerm} = true; + if (!${eval2.isNull} && !${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = true; } else { - ${ev.nullTerm} = true; + ${ev.isNull} = true; } } """ @@ -203,17 +203,17 @@ case class Or(left: Expression, right: Expression) // The result should be `true`, if any of them is `true` whenever the other is null or not. 
s""" ${eval1.code} - boolean ${ev.nullTerm} = false; - boolean ${ev.primitiveTerm} = true; + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = true; - if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { + if (!${eval1.isNull} && ${eval1.primitive}) { } else { ${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - ${ev.primitiveTerm} = false; + if (!${eval2.isNull} && ${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = false; } else { - ${ev.nullTerm} = true; + ${ev.isNull} = true; } } """ @@ -308,11 +308,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val equalCode = ctx.equalFunc(left.dataType)(eval1.primitiveTerm, eval2.primitiveTerm) - ev.nullTerm = "false" + val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive) + ev.isNull = "false" eval1.code + eval2.code + s""" - final boolean ${ev.primitiveTerm} = (${eval1.nullTerm} && ${eval2.nullTerm}) || - (!${eval1.nullTerm} && $equalCode); + boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) || + (!${eval1.isNull} && $equalCode); """ } } @@ -388,6 +388,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi falseValue.eval(input) } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val condEval = predicate.gen(ctx) val trueEval = trueValue.gen(ctx) @@ -395,16 +396,16 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s""" ${condEval.code} - boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; - if (!${condEval.nullTerm} && ${condEval.primitiveTerm}) { + boolean ${ev.isNull} = false; + ${ctx.primitiveType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.primitive}) { ${trueEval.code} - ${ev.nullTerm} = ${trueEval.nullTerm}; - ${ev.primitiveTerm} = ${trueEval.primitiveTerm}; + ${ev.isNull} = ${trueEval.isNull}; + ${ev.primitive} = ${trueEval.primitive}; } else { ${falseEval.code} - ${ev.nullTerm} = ${falseEval.nullTerm}; - ${ev.primitiveTerm} = ${falseEval.primitiveTerm}; + ${ev.isNull} = ${falseEval.isNull}; + ${ev.primitive} = ${falseEval.primitive}; } """ } @@ -493,6 +494,48 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { return res } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (!${cond.isNull} && ${cond.primitive}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + $cases + $other + """ + } + override def toString: String = { "CASE" + branches.sliding(2, 2).map { case Seq(cond, value) => s" WHEN $cond THEN $value" @@ -544,6 +587,52 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW return res } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val keyEval = key.gen(ctx) + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 
1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (${keyEval.isNull} && ${cond.isNull} || + !${keyEval.isNull} && !${cond.isNull} + && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${keyEval.code} + $cases + $other + """ + } + private def equalNullSafe(l: Any, r: Any) = { if (l == null && r == null) { true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 40107c5985481..1038e7a653358 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -64,9 +64,9 @@ case class NewSet(elementType: DataType) extends LeafExpression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { elementType match { case IntegerType | LongType => - ev.nullTerm = "false" + ev.isNull = "false" s""" - ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dataType)}(); + ${ctx.primitiveType(dataType)} ${ev.primitive} = new ${ctx.primitiveType(dataType)}(); """ case _ => super.genCode(ctx, ev) } @@ -111,11 +111,11 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { val setEval = set.gen(ctx) val htype = ctx.primitiveType(dataType) - ev.nullTerm = "false" - ev.primitiveTerm = setEval.primitiveTerm + ev.isNull = "false" + ev.primitive = 
setEval.primitive itemEval.code + setEval.code + s""" - if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { - (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); + if (!${itemEval.isNull} && !${setEval.isNull}) { + (($htype)${setEval.primitive}).add(${itemEval.primitive}); } """ case _ => super.genCode(ctx, ev) @@ -162,11 +162,11 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres val rightEval = right.gen(ctx) val htype = ctx.primitiveType(dataType) - ev.nullTerm = leftEval.nullTerm - ev.primitiveTerm = leftEval.primitiveTerm + ev.isNull = leftEval.isNull + ev.primitive = leftEval.primitive leftEval.code + rightEval.code + s""" - if (!${leftEval.nullTerm} && !${rightEval.nullTerm}) { - ${leftEval.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); + if (!${leftEval.isNull} && !${rightEval.isNull}) { + ${leftEval.primitive}.union((${htype})${rightEval.primitive}); } """ case _ => super.genCode(ctx, ev) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index c4ef9c30907f1..78adb509b470b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ trait StringRegexExpression extends ExpectsInputTypes { @@ -137,6 +138,10 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: UTF8String): UTF8String = v.toUpperCase() override def toString: String = s"Upper($child)" + + override def genCode(ctx: CodeGenContext, ev: 
GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") + } } /** @@ -147,6 +152,10 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: UTF8String): UTF8String = v.toLowerCase() override def toString: String = s"Lower($child)" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") + } } /** A base trait for functions that compare two strings, returning a boolean. */ @@ -181,6 +190,9 @@ trait StringComparison extends ExpectsInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") + } } /** @@ -189,6 +201,9 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") + } } /** @@ -197,6 +212,9 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 5df528770ca6e..bc29f80dede19 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -35,11 +36,20 @@ import org.apache.spark.sql.types._ class ExpressionEvaluationBaseSuite extends SparkFunSuite { + def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + checkEvaluationWithoutCodegen(expression, expected, inputRow) + checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow) + checkEvaluationWithGeneratedProjection(expression, expected, inputRow) + } + def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { expression.eval(inputRow) } - def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + def checkEvaluationWithoutCodegen( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } @@ -49,6 +59,68 @@ class ExpressionEvaluationBaseSuite extends SparkFunSuite { } } + def checkEvaluationWithGeneratedMutableProjection( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + + val plan = try { + 
GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = expression.gen(ctx) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if (actual != expected) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + def checkEvaluationWithGeneratedProjection( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val ctx = GenerateProjection.newCodeGenContext() + lazy val evaluated = expression.gen(ctx) + + val plan = try { + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |Expressions: ${expression} + |Code: ${evaluated} + """.stripMargin) + } + if (actual != expectedRow) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + def checkDoubleEvaluation( expression: Expression, expected: Spread[Double], @@ -69,8 +141,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("literals") { checkEvaluation(Literal(1), 1) checkEvaluation(Literal(true), true) + checkEvaluation(Literal(false), false) checkEvaluation(Literal(0L), 0L) + List(0.0, -0.0, Double.NegativeInfinity, 
Double.PositiveInfinity).foreach { + d => { + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toFloat), d.toFloat) + } + } checkEvaluation(Literal("test"), "test") + checkEvaluation(Literal.create(null, StringType), null) checkEvaluation(Literal(1) + Literal(1), 2) } @@ -1367,6 +1447,10 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { // TODO: Make the tests work with codegen. class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite { + override def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow) = { + checkEvaluationWithoutCodegen(expression, expected, inputRow) + } + test("CreateStruct") { val row = Row(1, 2, 3) val c1 = 'a.int.at(0).as("a") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index b577de1d5aab9..371a73181dad7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -21,34 +21,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ /** - * Overrides our expression evaluation tests to use code generation for evaluation. + * Additional tests for code generation. 
*/ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val plan = try { - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() - } catch { - case e: Throwable => - val ctx = GenerateProjection.newCodeGenContext() - val evaluated = expression.gen(ctx) - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } - - val actual = plan(inputRow).apply(0) - if (actual != expected) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - test("multithreaded eval") { import scala.concurrent._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala deleted file mode 100644 index 9da72521ec3ec..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen._ - -/** - * Overrides our expression evaluation tests to use generated code on mutable rows. - */ -class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val ctx = GenerateProjection.newCodeGenContext() - lazy val evaluated = expression.gen(ctx) - - val plan = try { - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) - } catch { - case e: Throwable => - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } - - val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) - if (actual.hashCode() != expectedRow.hashCode()) { - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |${evaluated.code} - """.stripMargin) - } - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } -}