diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 66cdfd91cd831..24b01ea55110e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} +import org.apache.spark.sql.types.{StructType, DataType} + /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. * @param expressions a sequence of expressions that determine the value of each column of the * output row. */ -class InterpretedProjection(expressions: Seq[Expression], mutableRow: Boolean = false) - extends Projection { - - def this( - expressions: Seq[Expression], - inputSchema: Seq[Attribute], - mutableRow: Boolean = false) = { - this(expressions.map(BindReferences.bindReference(_, inputSchema)), mutableRow) - } +class InterpretedProjection(expressions: Seq[Expression]) extends Projection { + def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = + this(expressions.map(BindReferences.bindReference(_, inputSchema))) // null check is required for when Kryo invokes the no-arg constructor. protected val exprArray = if (expressions != null) expressions.toArray else null @@ -42,7 +40,7 @@ class InterpretedProjection(expressions: Seq[Expression], mutableRow: Boolean = outputArray(i) = exprArray(i).eval(input) i += 1 } - if (mutableRow) new GenericMutableRow(outputArray) else new GenericInternalRow(outputArray) + new GenericInternalRow(outputArray) } override def toString: String = s"Row => [${exprArray.mkString(",")}]" @@ -77,6 +75,39 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu } } +/** + * A projection that returns UnsafeRow. + */ +abstract class UnsafeProjection extends Projection { + override def apply(row: InternalRow): UnsafeRow +} + +object UnsafeProjection { + def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType)) + + def create(fields: Seq[DataType]): UnsafeProjection = { + val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) + GenerateUnsafeProjection.generate(exprs) + } +} + +/** + * A projection that could turn UnsafeRow into GenericInternalRow + */ +case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection { + + private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) => + new BoundReference(idx, dt, true) + } + + @transient private[this] lazy val generatedProj = + GenerateMutableProjection.generate(expressions)() + + override def apply(input: InternalRow): InternalRow = { + generatedProj(input) + } +} + /** * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. @@ -114,7 +145,7 @@ class JoinedRow extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -208,7 +239,7 @@ class JoinedRow2 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -296,7 +327,7 @@ class JoinedRow3 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -384,7 +415,7 @@ class JoinedRow4 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -472,7 +503,7 @@ class JoinedRow5 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -560,7 +591,7 @@ class JoinedRow6 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean =