Skip to content

Commit

Permalink
WIP extract codegen utils
Browse files Browse the repository at this point in the history
  • Loading branch information
rednaxelafx committed Feb 21, 2018
1 parent 3fd0ccb commit 249fb93
Show file tree
Hide file tree
Showing 26 changed files with 292 additions and 261 deletions.
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenUtils, ExprCode}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -66,13 +66,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
ev.copy(code = oev.code)
} else {
assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
val javaType = ctx.javaType(dataType)
val javaType = CodegenUtils.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
val defaultValueLiteral = CodegenUtils.defaultValue(dataType)
ev.copy(code =
s"""
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
|$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
|$javaType ${ev.value} = ${ev.isNull} ? $defaultValueLiteral : ($value);
""".stripMargin)
} else {
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
Expand Down
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenUtils._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
Expand Down Expand Up @@ -669,7 +670,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = {
s"""
boolean $resultIsNull = $inputIsNull;
${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)};
${javaType(resultType)} $result = ${defaultValue(resultType)};
if (!$inputIsNull) {
${cast(input, result, resultIsNull)}
}
Expand All @@ -685,7 +686,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val funcName = ctx.freshName("elementToString")
val elementToStringFunc = ctx.addNewFunction(funcName,
s"""
|private UTF8String $funcName(${ctx.javaType(et)} element) {
|private UTF8String $funcName(${javaType(et)} element) {
| UTF8String elementStr = null;
| ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
| return elementStr;
Expand Down Expand Up @@ -723,7 +724,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val dataToStringCode = castToStringCode(dataType, ctx)
ctx.addNewFunction(funcName,
s"""
|private UTF8String $funcName(${ctx.javaType(dataType)} data) {
|private UTF8String $funcName(${javaType(dataType)} data) {
| UTF8String dataStr = null;
| ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)}
| return dataStr;
Expand Down Expand Up @@ -773,7 +774,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
| ${if (i != 0) s"""$buffer.append(" ");""" else ""}
|
| // Append $i field into the string buffer
| ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")};
| ${javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")};
| UTF8String $fieldStr = null;
| ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)}
| $buffer.append($fieldStr);
Expand Down Expand Up @@ -1202,7 +1203,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
$values[$j] = null;
} else {
boolean $fromElementNull = false;
${ctx.javaType(fromType)} $fromElementPrim =
${javaType(fromType)} $fromElementPrim =
${ctx.getValue(c, fromType, j)};
${castCode(ctx, fromElementPrim,
fromElementNull, toElementPrim, toElementNull, toType, elementCast)}
Expand Down Expand Up @@ -1259,7 +1260,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val fromFieldNull = ctx.freshName("ffn")
val toFieldPrim = ctx.freshName("tfp")
val toFieldNull = ctx.freshName("tfn")
val fromType = ctx.javaType(from.fields(i).dataType)
val fromType = javaType(from.fields(i).dataType)
s"""
boolean $fromFieldNull = $tmpInput.isNullAt($i);
if ($fromFieldNull) {
Expand Down
Expand Up @@ -22,6 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenUtils._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -119,15 +120,15 @@ abstract class Expression extends TreeNode[Expression] {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull")
val globalIsNull = ctx.addMutableState(CodegenConstants.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
eval.isNull = globalIsNull
s"$globalIsNull = $localIsNull;"
} else {
""
}

val javaType = ctx.javaType(dataType)
val javaType = CodegenUtils.javaType(dataType)
val newValue = ctx.freshName("value")

val funcName = ctx.freshName(nodeName)
Expand Down Expand Up @@ -411,14 +412,14 @@ abstract class UnaryExpression extends Expression {
ev.copy(code = s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${javaType(dataType)} ${ev.value} = ${defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${childGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${javaType(dataType)} ${ev.value} = ${defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
Expand Down Expand Up @@ -510,15 +511,15 @@ abstract class BinaryExpression extends Expression {

ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${javaType(dataType)} ${ev.value} = ${defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${leftGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${javaType(dataType)} ${ev.value} = ${defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
Expand Down Expand Up @@ -654,15 +655,15 @@ abstract class TernaryExpression extends Expression {

ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${javaType(dataType)} ${ev.value} = ${defaultValue(dataType)};
$nullSafeEval""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${leftGen.code}
${midGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${javaType(dataType)} ${ev.value} = ${defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
Expand Down
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenConstants, CodegenContext, CodegenUtils, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}

/**
Expand Down Expand Up @@ -65,14 +65,14 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
val countTerm = ctx.addMutableState(CodegenConstants.JAVA_LONG, "count")
val partitionMaskTerm = "partitionMask"
ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
ctx.addImmutableStateIfNotExists(CodegenConstants.JAVA_LONG, partitionMaskTerm)
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

ev.copy(code = s"""
final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
final ${CodegenUtils.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = "false")
}

Expand Down
Expand Up @@ -1018,11 +1018,12 @@ case class ScalaUDF(
val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}")
val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})"
val resultConverter = s"$convertersTerm[${children.length}]"
val boxedType = CodegenUtils.boxedType(dataType)
val callFunc =
s"""
|${ctx.boxedType(dataType)} $resultTerm = null;
|$boxedType $resultTerm = null;
|try {
| $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult);
| $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult);
|} catch (Exception e) {
| throw new org.apache.spark.SparkException($errorMsgTerm, e);
|}
Expand All @@ -1035,7 +1036,7 @@ case class ScalaUDF(
|$callFunc
|
|boolean ${ev.isNull} = $resultTerm == null;
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|${CodegenUtils.javaType(dataType)} ${ev.value} = ${CodegenUtils.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|}
Expand Down
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenConstants, CodegenContext, CodegenUtils, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
Expand All @@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val idTerm = "partitionId"
ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
ctx.addImmutableStateIfNotExists(CodegenConstants.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
val javaType = CodegenUtils.javaType(dataType)
ev.copy(code = s"final $javaType ${ev.value} = $idTerm;", isNull = "false")
}
}
Expand Up @@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenUtils, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

Expand Down Expand Up @@ -165,7 +165,7 @@ case class PreciseTimestampConversion(
val eval = child.genCode(ctx)
ev.copy(code = eval.code +
s"""boolean ${ev.isNull} = ${eval.isNull};
|${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
|${CodegenUtils.javaType(dataType)} ${ev.value} = ${eval.value};
""".stripMargin)
}
override def nullSafeEval(input: Any): Any = input
Expand Down

0 comments on commit 249fb93

Please sign in to comment.