Skip to content

Commit

Permalink
Merge pull request #5 from rxin/codegen
Browse files Browse the repository at this point in the history
Code gen code review.
  • Loading branch information
Davies Liu committed Jun 5, 2015
2 parents 2344bc0 + 48c454f commit b5d3617
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 59 deletions.
Expand Up @@ -435,37 +435,57 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
if (evaluated == null) null else cast(evaluated)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = this match {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
// TODO(cg): Add support for more data types.
(child.dataType, dataType) match {

case Cast(child @ BinaryType(), StringType) =>
castOrNull (ctx, ev, c =>
s"new ${ctx.stringType}().set($c)")
case (BinaryType, StringType) =>
defineCodeGen (ctx, ev, c =>
s"new ${ctx.stringType}().set($c)")

case Cast(child @ DateType(), StringType) =>
castOrNull(ctx, ev, c =>
s"""new ${ctx.stringType}().set(
case (DateType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""new ${ctx.stringType}().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")

case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c?1:0)")
case (BooleanType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)")

case Cast(child @ DecimalType(), IntegerType) =>
castOrNull(ctx, ev, c => s"($c).toInt()")
case (_: NumericType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")

case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
case (_: DecimalType, ByteType) =>
defineCodeGen(ctx, ev, c => s"($c).toByte()")

case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")
case (_: DecimalType, ShortType) =>
defineCodeGen(ctx, ev, c => s"($c).toShort()")

// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case Cast(e, StringType) if e.dataType != TimestampType =>
castOrNull(ctx, ev, c =>
s"new ${ctx.stringType}().set(String.valueOf($c))")
case (_: DecimalType, IntegerType) =>
defineCodeGen(ctx, ev, c => s"($c).toInt()")

case other =>
super.genCode(ctx, ev)
case (_: DecimalType, LongType) =>
defineCodeGen(ctx, ev, c => s"($c).toLong()")

case (_: DecimalType, FloatType) =>
defineCodeGen(ctx, ev, c => s"($c).toFloat()")

case (_: DecimalType, DoubleType) =>
defineCodeGen(ctx, ev, c => s"($c).toDouble()")

case (_: DecimalType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")

// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case (TimestampType, StringType) =>
super.genCode(ctx, ev)

case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")

case other =>
super.genCode(ctx, ev)
}
}
}

Expand Down
Expand Up @@ -69,7 +69,9 @@ abstract class Expression extends TreeNode[Expression] {
}

/**
* Returns Java source code for this expression.
* Returns Java source code that can be compiled to evaluate this expression.
* The default behavior is to call the eval method of the expression. Concrete expression
* implementations should override this to do actual code generation.
*
* @param ctx a [[CodeGenContext]]
* @param ev an [[GeneratedExpressionCode]] with unique terms.
Expand All @@ -82,10 +84,10 @@ abstract class Expression extends TreeNode[Expression] {
/* expression: ${this} */
Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.nullTerm} = ${ev.objectTerm} == null;
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(e.dataType)};
if (!${ev.nullTerm}) ${ev.primitiveTerm} =
(${ctx.boxedType(e.dataType)})${ev.objectTerm};
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${ev.objectTerm};
}
"""
}

Expand Down Expand Up @@ -155,17 +157,17 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express

override def toString: String = s"($left $symbol $right)"


/**
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
* the same type. If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f a function from two primitive term names to a tree that evaluates them.
* @param f accepts two variable names and returns Java code to compute the output.
*/
def evaluate(ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (left.dataType != right.dataType) {
// log.warn(s"${left.dataType} != ${right.dataType}")
Expand Down Expand Up @@ -197,9 +199,22 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]

abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>
def castOrNull(ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: String => String): String = {

/**
* Called by unary expressions to generate a code block that returns null if its parent returns
* null, and if not not null, use `f` to generate the expression.
*
* As an example, the following does a boolean inversion (i.e. NOT).
* {{{
* defineCodeGen(ctx, ev, c => s"!($c)")
* }}}
*
* @param f function that accepts a variable name and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: String => String): String = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
Expand Down
Expand Up @@ -87,6 +87,7 @@ case class Abs(child: Expression) extends UnaryArithmetic {
abstract class BinaryArithmetic extends BinaryExpression {
self: Product =>

/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String = ""

override def dataType: DataType = left.dataType
Expand Down Expand Up @@ -119,9 +120,9 @@ abstract class BinaryArithmetic extends BinaryExpression {

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
if (left.dataType.isInstanceOf[DecimalType]) {
evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } )
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
} else {
evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1 $symbol $eval2" } )
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
}

Expand Down Expand Up @@ -205,6 +206,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
}

/**
* Special case handling due to division by 0 => null.
*/
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
Expand All @@ -221,8 +225,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
} else {
Expand Down Expand Up @@ -263,6 +266,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
}

/**
* Special case handling for x % 0 ==> null.
*/
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
Expand All @@ -279,8 +285,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
} else {
Expand Down Expand Up @@ -337,7 +342,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
}

/**
* A function that calculates bitwise xor(^) of two numbers.
* A function that calculates bitwise xor of two numbers.
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "^"
Expand Down
Expand Up @@ -67,14 +67,13 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} =
${ctx.defaultValue(DecimalType())};
org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = ${ctx.defaultValue(DecimalType())};

if (!${ev.nullTerm}) {
${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal();
${ev.primitiveTerm} =
${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale);
${ev.nullTerm} = ${ev.primitiveTerm} == null;
${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal();
${ev.primitiveTerm} =
${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale);
${ev.nullTerm} = ${ev.primitiveTerm} == null;
}
"""
}
Expand Down
Expand Up @@ -88,6 +88,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
"""
} else {
// TODO(cg): Add support for more data types.
dataType match {
case StringType =>
val v = value.asInstanceOf[UTF8String]
Expand All @@ -96,12 +97,12 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
final boolean ${ev.nullTerm} = false;
${ctx.stringType} ${ev.primitiveTerm} = new ${ctx.stringType}().set(${arr});
"""
case FloatType =>
case FloatType => // This must go before NumericType
s"""
final boolean ${ev.nullTerm} = false;
float ${ev.primitiveTerm} = ${value}f;
"""
case dt: DecimalType =>
case dt: DecimalType => // This must go before NumericType
s"""
final boolean ${ev.nullTerm} = false;
${ctx.primitiveType(dt)} ${ev.primitiveTerm} =
Expand Down
Expand Up @@ -61,9 +61,9 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
children.map { e =>
val eval = e.gen(ctx)
s"""
if(${ev.nullTerm}) {
if (${ev.nullTerm}) {
${eval.code}
if(!${eval.nullTerm}) {
if (!${eval.nullTerm}) {
${ev.nullTerm} = false;
${ev.primitiveTerm} = ${eval.primitiveTerm};
}
Expand Down Expand Up @@ -137,9 +137,9 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
val code = children.map { e =>
val eval = e.gen(ctx)
s"""
if($nonnull < $n) {
if ($nonnull < $n) {
${eval.code}
if(!${eval.nullTerm}) {
if (!${eval.nullTerm}) {
$nonnull += 1;
}
}
Expand Down
Expand Up @@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
castOrNull(ctx, ev, c => s"!($c)")
defineCodeGen(ctx, ev, c => s"!($c)")
}
}

Expand Down Expand Up @@ -220,13 +220,13 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
self: Product =>
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
left.dataType match {
case dt: NumericType if ctx.isNativeType(dt) => evaluate (ctx, ev, {
case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
(c1, c3) => s"$c1 $symbol $c3"
})
case TimestampType =>
// java.sql.Timestamp does not have compare()
super.genCode(ctx, ev)
case other => evaluate (ctx, ev, {
case other => defineCodeGen (ctx, ev, {
(c1, c2) => s"$c1.compare($c2) $symbol 0"
})
}
Expand Down Expand Up @@ -277,7 +277,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
evaluate(ctx, ev, ctx.equalFunc(left.dataType))
defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType))
}
}

Expand Down Expand Up @@ -392,7 +392,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
${condEval.code}
if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
if (!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
${trueEval.code}
${ev.nullTerm} = ${trueEval.nullTerm};
${ev.primitiveTerm} = ${trueEval.primitiveTerm};
Expand Down

0 comments on commit b5d3617

Please sign in to comment.