Skip to content

Commit

Permalink
improve coverage and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 6, 2015
1 parent bad6828 commit f42c732
Show file tree
Hide file tree
Showing 20 changed files with 440 additions and 237 deletions.
Expand Up @@ -45,8 +45,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
s"""
final boolean ${ev.nullTerm} = i.isNullAt($ordinal);
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ?
boolean ${ev.isNull} = i.isNullAt($ordinal);
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ev.isNull} ?
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
"""
}
Expand Down
Expand Up @@ -60,9 +60,9 @@ abstract class Expression extends TreeNode[Expression] {
* @return [[GeneratedExpressionCode]]
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
val nullTerm = ctx.freshName("nullTerm")
val primitiveTerm = ctx.freshName("primitiveTerm")
val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm)
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
ve
}
Expand All @@ -82,11 +82,11 @@ abstract class Expression extends TreeNode[Expression] {
val objectTerm = ctx.freshName("obj")
s"""
/* expression: ${this} */
final Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
final boolean ${ev.nullTerm} = ${objectTerm} == null;
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${objectTerm};
Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.isNull} = ${objectTerm} == null;
${ctx.primitiveType(e.dataType)} ${ev.primitive} = ${ctx.defaultValue(e.dataType)};
if (!${ev.isNull}) {
${ev.primitive} = (${ctx.boxedType(e.dataType)})${objectTerm};
}
"""
}
Expand Down Expand Up @@ -175,18 +175,18 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express

val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
val resultCode = f(eval1.primitive, eval2.primitive)

s"""
${eval1.code}
boolean ${ev.nullTerm} = ${eval1.nullTerm};
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
if (!${ev.nullTerm}) {
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if(!${eval2.nullTerm}) {
${ev.primitiveTerm} = $resultCode;
if(!${eval2.isNull}) {
${ev.primitive} = $resultCode;
} else {
${ev.nullTerm} = true;
${ev.isNull} = true;
}
}
"""
Expand Down Expand Up @@ -216,12 +216,12 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
ev: GeneratedExpressionCode,
f: Term => Code): Code = {
val eval = child.gen(ctx)
// reuse the previous nullTerm
ev.nullTerm = eval.nullTerm
// reuse the previous isNull
ev.isNull = eval.isNull
eval.code + s"""
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = ${f(eval.primitiveTerm)};
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = ${f(eval.primitive)};
}
"""
}
Expand Down
Expand Up @@ -50,6 +50,11 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {

private lazy val numeric = TypeUtils.getNumeric(dataType)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()")
case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
}

protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
}

Expand All @@ -68,6 +73,21 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
if (value < 0) null
else math.sqrt(value)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
if (${eval.primitive} < 0.0) {
${ev.isNull} = true;
} else {
${ev.primitive} = java.lang.Math.sqrt(${eval.primitive});
}
}
"""
}
}

/**
Expand Down Expand Up @@ -216,9 +236,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
s"${eval2.primitiveTerm}.isZero()"
s"${eval2.primitive}.isZero()"
} else {
s"${eval2.primitiveTerm} == 0"
s"${eval2.primitive} == 0"
}
val method = if (left.dataType.isInstanceOf[DecimalType]) {
s".$decimalMethod"
Expand All @@ -227,12 +247,12 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
if (${eval1.isNull} || ${eval2.isNull} || $test) {
${ev.isNull} = true;
} else {
${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm});
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
}
"""
}
Expand Down Expand Up @@ -276,9 +296,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
s"${eval2.primitiveTerm}.isZero()"
s"${eval2.primitive}.isZero()"
} else {
s"${eval2.primitiveTerm} == 0"
s"${eval2.primitive} == 0"
}
val method = if (left.dataType.isInstanceOf[DecimalType]) {
s".$decimalMethod"
Expand All @@ -287,12 +307,12 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
if (${eval1.isNull} || ${eval2.isNull} || $test) {
${ev.isNull} = true;
} else {
${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm});
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
}
"""
}
Expand Down Expand Up @@ -387,6 +407,10 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dataType)})~($c)")
}

protected override def evalInternal(evalE: Any) = not(evalE)
}

Expand Down Expand Up @@ -419,21 +443,21 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
eval1.code + eval2.code + s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};

if (${eval1.nullTerm}) {
${ev.nullTerm} = ${eval2.nullTerm};
${ev.primitiveTerm} = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
${ev.nullTerm} = ${eval1.nullTerm};
${ev.primitiveTerm} = ${eval1.primitiveTerm};
if (${eval1.isNull}) {
${ev.isNull} = ${eval2.isNull};
${ev.primitive} = ${eval2.primitive};
} else if (${eval2.isNull}) {
${ev.isNull} = ${eval1.isNull};
${ev.primitive} = ${eval1.primitive};
} else {
if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
${ev.primitiveTerm} = ${eval1.primitiveTerm};
if (${eval1.primitive} > ${eval2.primitive}) {
${ev.primitive} = ${eval1.primitive};
} else {
${ev.primitiveTerm} = ${eval2.primitiveTerm};
${ev.primitive} = ${eval2.primitive};
}
}
"""
Expand Down Expand Up @@ -475,21 +499,21 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
val eval2 = right.gen(ctx)

eval1.code + eval2.code + s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};

if (${eval1.nullTerm}) {
${ev.nullTerm} = ${eval2.nullTerm};
${ev.primitiveTerm} = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
${ev.nullTerm} = ${eval1.nullTerm};
${ev.primitiveTerm} = ${eval1.primitiveTerm};
if (${eval1.isNull}) {
${ev.isNull} = ${eval2.isNull};
${ev.primitive} = ${eval2.primitive};
} else if (${eval2.isNull}) {
${ev.isNull} = ${eval1.isNull};
${ev.primitive} = ${eval1.primitive};
} else {
if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) {
${ev.primitiveTerm} = ${eval1.primitiveTerm};
if (${eval1.primitive} < ${eval2.primitive}) {
${ev.primitive} = ${eval1.primitive};
} else {
${ev.primitiveTerm} = ${eval2.primitiveTerm};
${ev.primitive} = ${eval2.primitive};
}
}
"""
Expand Down
Expand Up @@ -35,12 +35,12 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
* Java source for evaluating an [[Expression]] given a [[Row]] of input.
*
* @param code The sequence of statements required to evaluate the expression.
* @param nullTerm A term that holds a boolean value representing whether the expression evaluated
* @param isNull A term that holds a boolean value representing whether the expression evaluated
* to null.
* @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
* valid if `nullTerm` is set to `true`.
* @param primitive A term for a possible primitive value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, var primitiveTerm: Term)
case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term)

/**
* A context for codegen, which is used to bookkeeping the expressions those are not supported
Expand Down Expand Up @@ -149,9 +149,9 @@ class CodeGenContext {
def defaultValue(dt: DataType): Term = dt match {
case BooleanType => "false"
case FloatType => "-1.0f"
case ShortType => "-1"
case LongType => "-1"
case ByteType => "-1"
case ShortType => "(short)-1"
case LongType => "-1L"
case ByteType => "(byte)-1"
case DoubleType => "-1.0"
case IntegerType => "-1"
case DateType => "-1"
Expand Down
Expand Up @@ -40,10 +40,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
val evaluationCode = e.gen(ctx)
evaluationCode.code +
s"""
if(${evaluationCode.nullTerm})
if(${evaluationCode.isNull})
mutableRow.setNullAt($i);
else
mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitiveTerm)};
mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)};
"""
}.mkString("\n")
val code = s"""
Expand Down
Expand Up @@ -59,8 +59,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
case BinaryType =>
s"""
{
byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm};
byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm};
byte[] x = ${if (asc) evalA.primitive else evalB.primitive};
byte[] y = ${if (!asc) evalB.primitive else evalA.primitive};
int j = 0;
while (j < x.length && j < y.length) {
if (x[j] != y[j]) return x[j] - y[j];
Expand All @@ -73,16 +73,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
}"""
case _: NumericType =>
s"""
if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) {
if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) {
if (${evalA.primitive} != ${evalB.primitive}) {
if (${evalA.primitive} > ${evalB.primitive}) {
return ${if (asc) "1" else "-1"};
} else {
return ${if (asc) "-1" else "1"};
}
}"""
case _ =>
s"""
int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm});
int comp = ${evalA.primitive}.compare(${evalB.primitive});
if (comp != 0) {
return ${if (asc) "comp" else "-comp"};
}"""
Expand All @@ -93,11 +93,11 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
${evalA.code}
i = $b;
${evalB.code}
if (${evalA.nullTerm} && ${evalB.nullTerm}) {
if (${evalA.isNull} && ${evalB.isNull}) {
// Nothing
} else if (${evalA.nullTerm}) {
} else if (${evalA.isNull}) {
return ${if (order.direction == Ascending) "-1" else "1"};
} else if (${evalB.nullTerm}) {
} else if (${evalB.isNull}) {
return ${if (order.direction == Ascending) "1" else "-1"};
} else {
$compare
Expand Down
Expand Up @@ -55,7 +55,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
@Override
public boolean eval(Row i) {
${eval.code}
return !${eval.nullTerm} && ${eval.primitiveTerm};
return !${eval.isNull} && ${eval.primitive};
}
}"""

Expand Down
Expand Up @@ -55,9 +55,9 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
{
// column$i
${eval.code}
nullBits[$i] = ${eval.nullTerm};
if(!${eval.nullTerm}) {
c$i = ${eval.primitiveTerm};
nullBits[$i] = ${eval.isNull};
if (!${eval.isNull}) {
c$i = ${eval.primitive};
}
}
"""
Expand Down Expand Up @@ -122,7 +122,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
case LongType => s"$col ^ ($col >>> 32)"
case FloatType => s"Float.floatToIntBits($col)"
case DoubleType =>
s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)"
s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
case _ => s"$col.hashCode()"
}
s"isNullAt($i) ? 0 : ($nonNull)"
Expand Down
Expand Up @@ -62,13 +62,13 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
${ctx.decimalType} ${ev.primitiveTerm} = null;
boolean ${ev.isNull} = ${eval.isNull};
${ctx.decimalType} ${ev.primitive} = null;

if (!${ev.nullTerm}) {
${ev.primitiveTerm} = (new ${ctx.decimalType}()).setOrNull(
${eval.primitiveTerm}, $precision, $scale);
${ev.nullTerm} = ${ev.primitiveTerm} == null;
if (!${ev.isNull}) {
${ev.primitive} = (new ${ctx.decimalType}()).setOrNull(
${eval.primitive}, $precision, $scale);
${ev.isNull} = ${ev.primitive} == null;
}
"""
}
Expand Down

0 comments on commit f42c732

Please sign in to comment.