Skip to content

Commit

Permalink
codegen for createArray createStruct & createNamedStruct
Browse files Browse the repository at this point in the history
  • Loading branch information
yjshen committed Jul 20, 2015
1 parent 79ec072 commit 39cefb8
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

/**
* Returns an Array containing the evaluation of all children expressions.
*/
case class CreateArray(children: Seq[Expression]) extends Expression with CodegenFallback {
case class CreateArray(children: Seq[Expression]) extends Expression {

/** True only when every child expression is foldable. */
override def foldable: Boolean = children.forall(child => child.foldable)

Expand All @@ -45,14 +45,31 @@ case class CreateArray(children: Seq[Expression]) extends Expression with Codege
children.map(_.eval(input))
}

/**
 * Emits Java source that allocates a `scala.collection.mutable.ArraySeq<Object>` with one slot
 * per child and stores each child's generated value (or `null`) at the child's position.
 *
 * @param ctx codegen context handed to every child's `gen` call
 * @param ev supplies the names of the result variable (`primitive`) and its null flag (`isNull`)
 * @return the declarations followed by, for each child, its own code and the store into the array
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val arraySeqClass = "scala.collection.mutable.ArraySeq"

  // The array itself is never null; only its individual elements can be.
  val declarations = Seq(
    "",
    s"boolean ${ev.isNull} = false;",
    s"$arraySeqClass<Object> ${ev.primitive} = new $arraySeqClass<Object>(${children.size});",
    ""
  ).mkString("\n")

  val elementStores = children.zipWithIndex.map { case (child, ordinal) =>
    val childCode = child.gen(ctx)
    val store = Seq(
      "",
      s"if (${childCode.isNull}) {",
      s"${ev.primitive}.update($ordinal, null);",
      "} else {",
      s"${ev.primitive}.update($ordinal, ${childCode.primitive});",
      "}",
      ""
    ).mkString("\n")
    // The child's own code has to come first so its result variables exist when they are read.
    childCode.code + store
  }

  declarations + elementStores.mkString("\n")
}

/** Reports "array" as this expression's pretty name. */
override def prettyName: String = "array"
}

/**
* Returns a Row containing the evaluation of all children expressions.
* Code generation is provided by the `genCode` override in this class.
*/
case class CreateStruct(children: Seq[Expression]) extends Expression with CodegenFallback {
case class CreateStruct(children: Seq[Expression]) extends Expression {

/** True only when every child expression is foldable. */
override def foldable: Boolean = children.forall(child => child.foldable)

Expand All @@ -76,6 +93,24 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg
InternalRow(children.map(_.eval(input)): _*)
}

/**
 * Emits Java source that allocates a `GenericMutableRow` with one field per child and stores
 * each child's generated value (or `null`) at the child's position.
 *
 * @param ctx codegen context handed to every child's `gen` call
 * @param ev supplies the names of the result variable (`primitive`) and its null flag (`isNull`)
 * @return the declarations followed by, for each child, its own code and the store into the row
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val rowClass = "org.apache.spark.sql.catalyst.expressions.GenericMutableRow"

  // The row itself is never null; only its individual fields can be.
  val declarations = Seq(
    "",
    s"boolean ${ev.isNull} = false;",
    s"final $rowClass ${ev.primitive} = new $rowClass(${children.size});",
    ""
  ).mkString("\n")

  val fieldStores = children.zipWithIndex.map { case (child, ordinal) =>
    val childCode = child.gen(ctx)
    val store = Seq(
      "",
      s"if (${childCode.isNull}) {",
      s"${ev.primitive}.update($ordinal, null);",
      "} else {",
      s"${ev.primitive}.update($ordinal, ${childCode.primitive});",
      "}",
      ""
    ).mkString("\n")
    // The child's own code has to come first so its result variables exist when they are read.
    childCode.code + store
  }

  declarations + fieldStores.mkString("\n")
}

/** Reports "struct" as this expression's pretty name. */
override def prettyName: String = "struct"
}

Expand All @@ -84,7 +119,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg
*
* @param children Seq(name1, val1, name2, val2, ...)
*/
case class CreateNamedStruct(children: Seq[Expression]) extends Expression with CodegenFallback {
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {

// Children alternate name, value, name, value, ...; split them into two parallel lists.
private lazy val (nameExprs, valExprs) = {
  val pairs = children.grouped(2).map { case Seq(key, item) => (key, item) }
  pairs.toList.unzip
}
Expand Down Expand Up @@ -122,5 +157,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with
InternalRow(valExprs.map(_.eval(input)): _*)
}

/**
 * Emits Java source that allocates a `GenericMutableRow` with one field per value expression
 * and stores each value's generated result (or `null`) at its position. Only `valExprs` are
 * used here; the name expressions contribute no generated code.
 *
 * @param ctx codegen context handed to every value expression's `gen` call
 * @param ev supplies the names of the result variable (`primitive`) and its null flag (`isNull`)
 * @return the declarations followed by, for each value, its own code and the store into the row
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val rowClass = "org.apache.spark.sql.catalyst.expressions.GenericMutableRow"

  // The row itself is never null; only its individual fields can be.
  val declarations = Seq(
    "",
    s"boolean ${ev.isNull} = false;",
    s"final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size});",
    ""
  ).mkString("\n")

  val fieldStores = valExprs.zipWithIndex.map { case (valueExpr, ordinal) =>
    val valueCode = valueExpr.gen(ctx)
    val store = Seq(
      "",
      s"if (${valueCode.isNull}) {",
      s"${ev.primitive}.update($ordinal, null);",
      "} else {",
      s"${ev.primitive}.update($ordinal, ${valueCode.primitive});",
      "}",
      ""
    ).mkString("\n")
    // The value's own code has to come first so its result variables exist when they are read.
    valueCode.code + store
  }

  declarations + fieldStores.mkString("\n")
}

/** Reports "named_struct" as this expression's pretty name. */
override def prettyName: String = "named_struct"
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null)
}

// Builds arrays from int, long and string literals and checks that each one evaluates back to
// the sequence of values it was built from.
test("CreateArray") {
  val ints = Seq(5, 10, 15, 20, 25)
  val longs = ints.map(i => i.toLong)
  val strings = ints.map(i => i.toString)

  val intArray = CreateArray(ints.map(v => Literal(v)))
  val longArray = CreateArray(longs.map(v => Literal(v)))
  val stringArray = CreateArray(strings.map(v => Literal(v)))

  checkEvaluation(intArray, ints, EmptyRow)
  checkEvaluation(longArray, longs, EmptyRow)
  checkEvaluation(stringArray, strings, EmptyRow)
}

test("CreateStruct") {
val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0)
Expand Down

0 comments on commit 39cefb8

Please sign in to comment.