Skip to content

Commit

Permalink
Back out Projection changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Jul 21, 2015
1 parent c5419b3 commit adc8239
Showing 1 changed file with 47 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,18 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection}
import org.apache.spark.sql.types.{StructType, DataType}

/**
* A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
* @param expressions a sequence of expressions that determine the value of each column of the
* output row.
*/
class InterpretedProjection(expressions: Seq[Expression], mutableRow: Boolean = false)
extends Projection {

def this(
expressions: Seq[Expression],
inputSchema: Seq[Attribute],
mutableRow: Boolean = false) = {
this(expressions.map(BindReferences.bindReference(_, inputSchema)), mutableRow)
}
class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))

// null check is required for when Kryo invokes the no-arg constructor.
protected val exprArray = if (expressions != null) expressions.toArray else null
Expand All @@ -42,7 +40,7 @@ class InterpretedProjection(expressions: Seq[Expression], mutableRow: Boolean =
outputArray(i) = exprArray(i).eval(input)
i += 1
}
if (mutableRow) new GenericMutableRow(outputArray) else new GenericInternalRow(outputArray)
new GenericInternalRow(outputArray)
}

override def toString: String = s"Row => [${exprArray.mkString(",")}]"
Expand Down Expand Up @@ -77,6 +75,39 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
}
}

/**
* A projection that returns UnsafeRow.
*/
abstract class UnsafeProjection extends Projection {
override def apply(row: InternalRow): UnsafeRow
}

object UnsafeProjection {
def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))

def create(fields: Seq[DataType]): UnsafeProjection = {
val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
GenerateUnsafeProjection.generate(exprs)
}
}

/**
* A projection that could turn UnsafeRow into GenericInternalRow
*/
case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {

private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
new BoundReference(idx, dt, true)
}

@transient private[this] lazy val generatedProj =
GenerateMutableProjection.generate(expressions)()

override def apply(input: InternalRow): InternalRow = {
generatedProj(input)
}
}

/**
* A mutable wrapper that makes two rows appear as a single concatenated row. Designed to
* be instantiated once per thread and reused.
Expand Down Expand Up @@ -114,7 +145,7 @@ class JoinedRow extends InternalRow {

override def length: Int = row1.length + row2.length

override def apply(i: Int): Any =
override def get(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)

override def isNullAt(i: Int): Boolean =
Expand Down Expand Up @@ -208,7 +239,7 @@ class JoinedRow2 extends InternalRow {

override def length: Int = row1.length + row2.length

override def apply(i: Int): Any =
override def get(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)

override def isNullAt(i: Int): Boolean =
Expand Down Expand Up @@ -296,7 +327,7 @@ class JoinedRow3 extends InternalRow {

override def length: Int = row1.length + row2.length

override def apply(i: Int): Any =
override def get(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)

override def isNullAt(i: Int): Boolean =
Expand Down Expand Up @@ -384,7 +415,7 @@ class JoinedRow4 extends InternalRow {

override def length: Int = row1.length + row2.length

override def apply(i: Int): Any =
override def get(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)

override def isNullAt(i: Int): Boolean =
Expand Down Expand Up @@ -472,7 +503,7 @@ class JoinedRow5 extends InternalRow {

override def length: Int = row1.length + row2.length

override def apply(i: Int): Any =
override def get(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)

override def isNullAt(i: Int): Boolean =
Expand Down Expand Up @@ -560,7 +591,7 @@ class JoinedRow6 extends InternalRow {

override def length: Int = row1.length + row2.length

override def apply(i: Int): Any =
override def get(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)

override def isNullAt(i: Int): Boolean =
Expand Down

0 comments on commit adc8239

Please sign in to comment.