Commit d2ad5c5

Refactor putting SQLContext into SparkPlan. Fix ordering, other test cases.

marmbrus committed Jul 22, 2014
1 parent be2cd6b commit d2ad5c5
Showing 14 changed files with 104 additions and 95 deletions.
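
The core of the refactor: rather than threading a curried (@transient sqlContext: SQLContext) argument through every physical operator, the planner now publishes the active context in a thread-local (SparkPlan.currentContext) that each SparkPlan captures in a @transient field when it is constructed. A minimal, self-contained sketch of that pattern follows; the names are illustrative stand-ins (a String for SQLContext, Plan for SparkPlan), not Spark's API.

object PlanContext {
  // Plays the role of SparkPlan.currentContext in this commit.
  val current = new ThreadLocal[String]()
}

abstract class Plan extends Serializable {
  // Captured once, by whichever thread runs the planner; not serialized
  // with the operator when it ships to executors.
  @transient protected val context: String = PlanContext.current.get()
  def describe: String = s"$this (context = $context)"
}

case class Scan(table: String) extends Plan

object Demo extends App {
  PlanContext.current.set("sqlContext-1") // the planner sets the context...
  println(Scan("t").describe)             // ...operators capture it implicitly
}
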
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -17,13 +17,15 @@

package org.apache.spark.sql.catalyst.expressions.codegen

import com.typesafe.scalalogging.slf4j.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{StringType, NumericType}

/**
* Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of
* [[Expression Expressions]].
*/
object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging {
import scala.reflect.runtime.{universe => ru}
import scala.reflect.runtime.universe._

@@ -40,6 +42,22 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
val evalA = expressionEvaluator(order.child)
val evalB = expressionEvaluator(order.child)

val compare = order.child.dataType match {
case _: NumericType =>
q"""
val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
if(comp != 0) {
return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"}
}
"""
case StringType =>
if (order.direction == Ascending) {
q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})"""
} else {
q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})"""
}
}

q"""
i = $a
..${evalA.code}
@@ -52,9 +70,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
} else if (${evalB.nullTerm}) {
return ${if (order.direction == Ascending) q"1" else q"-1"}
} else {
i = a
val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
if(comp != 0) return comp.toInt
$compare
}
"""
}
@@ -76,6 +92,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
}
new $orderingName()
"""
logger.debug(s"Generated Ordering: $code")
toolBox.eval(code).asInstanceOf[Ordering[Row]]
}
}
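
For readers who don't speak quasiquotes, here is an interpreted analogue of the comparison the new compare block generates: numeric keys compare by subtraction, string keys via compare, and a descending direction flips the result (the generated string case swaps the operands instead, which is equivalent). A sketch only; the real logic is spliced into a compiled Ordering[Row].

def compareKey(a: Any, b: Any, ascending: Boolean): Int = {
  val cmp = (a, b) match {
    case (x: Int, y: Int)       => x - y  // mirrors `evalA - evalB` above
    case (x: Double, y: Double) => java.lang.Double.compare(x, y)
    case (x: String, y: String) => x.compare(y)
    case _                      => sys.error("type not covered in this sketch")
  }
  if (ascending) cmp else -cmp
}

// compareKey("apple", "pear", ascending = true) < 0
// compareKey(3, 7, ascending = false) > 0
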
18 changes: 5 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -304,18 +304,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
Batch("Add exchange", Once, AddExchange(self)) ::
Batch("CodeGen", Once, TurnOnCodeGen) :: Nil
}

protected object TurnOnCodeGen extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
if (self.codegenEnabled) {
plan.foreach(p => println(p.simpleString))
plan.foreach(_._codegenEnabled = true)
}
plan
}
Batch("Add exchange", Once, AddExchange(self)) :: Nil
}

/**
@@ -330,7 +319,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
lazy val analyzed = analyzer(logical)
lazy val optimizedPlan = optimizer(analyzed)
// TODO: Don't just pick the first one...
lazy val sparkPlan = planner(optimizedPlan).next()
lazy val sparkPlan = {
SparkPlan.currentContext.set(self)
planner(optimizedPlan).next()
}
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -42,7 +42,7 @@ case class Aggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)(@transient sqlContext: SQLContext)
child: SparkPlan)
extends UnaryNode {

override def requiredChildDistribution =
@@ -56,8 +56,6 @@ case class Aggregate(
}
}

override def otherCopyArgs = sqlContext :: Nil

// HACK: Generators don't correctly preserve their output through serializations so we grab
// out child's output attributes statically here.
private[this] val childOutput = child.output
sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -51,9 +51,11 @@ case class Generate(
if (join) child.output ++ generatorOutput else generatorOutput

/** Codegenned rows are not serializable... */
override def codegenEnabled = false
override val codegenEnabled = false

override def execute() = {
val boundGenerator = BindReferences.bindReference(generator, child.output)

if (join) {
child.execute().mapPartitions { iter =>
val nullValues = Seq.fill(generator.output.size)(Literal(null))
@@ -66,7 +68,7 @@ case class Generate(
val joinedRow = new JoinedRow

iter.flatMap {row =>
val outputRows = generator.eval(row)
val outputRows = boundGenerator.eval(row)
if (outer && outputRows.isEmpty) {
outerProjection(row) :: Nil
} else {
@@ -75,7 +77,7 @@ case class Generate(
}
}
} else {
child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row)))
child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row)))
}
}
}
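
The boundGenerator fix above matters because attribute references do not survive serialization intact; binding rewrites them into positional lookups against the child's output schema before evaluation. A toy model of what BindReferences.bindReference accomplishes, using stand-in types rather than Catalyst's:

sealed trait Expr { def eval(row: Seq[Any]): Any }

case class ByName(name: String) extends Expr {
  def eval(row: Seq[Any]) = sys.error(s"unbound reference: $name")
}

case class ByOrdinal(i: Int) extends Expr {
  def eval(row: Seq[Any]) = row(i)
}

// Analogous in spirit to BindReferences.bindReference(generator, child.output).
def bind(expr: Expr, schema: Seq[String]): Expr = expr match {
  case ByName(n) => ByOrdinal(schema.indexOf(n))
  case bound     => bound
}

// bind(ByName("id"), Seq("id", "value")).eval(Seq(42, "x")) == 42
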
sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -46,11 +46,9 @@ case class GeneratedAggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)(@transient sqlContext: SQLContext)
child: SparkPlan)
extends UnaryNode {

println(s"new $codegenEnabled")

override def requiredChildDistribution =
if (partial) {
UnspecifiedDistribution :: Nil
@@ -62,12 +60,9 @@
}
}

override def otherCopyArgs = sqlContext :: Nil

override def output = aggregateExpressions.map(_.toAttribute)

override def execute() = {
println(s"codegen: $codegenEnabled")
val aggregatesToCompute = aggregateExpressions.flatMap { a =>
a.collect { case agg: AggregateExpression => agg}
}
@@ -160,7 +155,6 @@ case class GeneratedAggregate(
// TODO: Codegening anything other than the updateProjection is probably over kill.
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
var currentRow: Row = null
println(codegenEnabled)

while (iter.hasNext) {
currentRow = iter.next()
@@ -172,7 +166,6 @@
} else {
val buffers = new java.util.HashMap[Row, MutableRow]()

println(codegenEnabled)
var currentRow: Row = null
while (iter.hasNext) {
currentRow = iter.next()
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -18,8 +18,9 @@
package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, Logging, Row}
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
@@ -28,17 +29,35 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.BaseRelation
import org.apache.spark.sql.catalyst.plans.physical._


object SparkPlan {
protected[sql] val currentContext = new ThreadLocal[SQLContext]()
}

/**
* :: DeveloperApi ::
*/
@DeveloperApi
abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
self: Product =>

def codegenEnabled = _codegenEnabled
/**
* A handle to the SQL Context that was used to create this plan. Since many operators need
* access to the sqlContext for RDD operations or configuration this field is automatically
* populated by the query planning infrastructure.
*/
@transient
protected val sqlContext = SparkPlan.currentContext.get()

/** Will be set to true during planning if code generation should be used for this operator. */
private[sql] var _codegenEnabled = false
protected def sparkContext = sqlContext.sparkContext

def logger = log

val codegenEnabled: Boolean = if(sqlContext != null) {
sqlContext.codegenEnabled
} else {
false
}

// TODO: Move to `DistributedPlan`
/** Specifies how data is partitioned across different nodes in the cluster. */
@@ -57,16 +76,22 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
*/
def executeCollect(): Array[Row] = execute().map(_.copy()).collect()

def newProjection(expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection =
protected def newProjection(
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
log.debug(
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if (codegenEnabled) {
GenerateProjection(expressions, inputSchema)
} else {
new InterpretedProjection(expressions, inputSchema)
}
}

def newMutableProjection(
protected def newMutableProjection(
expressions: Seq[Expression],
inputSchema: Seq[Attribute]): () => MutableProjection = {
log.debug(
s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if(codegenEnabled) {
GenerateMutableProjection(expressions, inputSchema)
} else {
@@ -75,15 +100,16 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
}


def newPredicate(expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
protected def newPredicate(
expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
if (codegenEnabled) {
GeneratePredicate(expression, inputSchema)
} else {
InterpretedPredicate(expression, inputSchema)
}
}

def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = {
protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = {
if (codegenEnabled) {
GenerateOrdering(order, inputSchema)
} else {
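
Taken together, the SparkPlan changes mean operator authors no longer pass a curried (sqlContext) argument or override otherCopyArgs; sqlContext, sparkContext, and codegenEnabled are inherited, and the protected new* factories pick generated or interpreted implementations. A hypothetical, simplified operator written against the refactored base class (illustrative only; it is not one of the operators in this commit, and UnaryNode is private[sql], so a real version would live inside org.apache.spark.sql):

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}

case class MyLocalSort(order: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
  override def output = child.output

  override def execute() = child.execute().mapPartitions { iter =>
    // newOrdering transparently chooses GenerateOrdering or the interpreted
    // path based on the inherited codegenEnabled flag; no context plumbing.
    val ordering = newOrdering(order, child.output)
    iter.map(_.copy()).toArray.sorted(ordering).iterator
  }
}
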
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -39,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
execution.LeftSemiJoinBNL(
planLater(left), planLater(right), condition)(sqlContext) :: Nil
planLater(left), planLater(right), condition) :: Nil
case _ => Nil
}
}
@@ -58,7 +58,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
condition: Option[Expression],
side: BuildSide) = {
val broadcastHashJoin = execution.BroadcastHashJoin(
leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext)
leftKeys, rightKeys, side, planLater(left), planLater(right))
condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
}

@@ -118,7 +118,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
partial = true,
groupingExpressions,
partialComputation,
planLater(child))(sqlContext))(sqlContext) :: Nil
planLater(child))) :: Nil

// Cases where some aggregate can not be codegened
case PartialAggregation(
@@ -135,7 +135,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
partial = true,
groupingExpressions,
partialComputation,
planLater(child))(sqlContext))(sqlContext) :: Nil
planLater(child))) :: Nil

case _ => Nil
}
@@ -153,7 +153,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
execution.BroadcastNestedLoopJoin(
planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
planLater(left), planLater(right), joinType, condition) :: Nil
case _ => Nil
}
}
@@ -175,7 +175,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object TakeOrdered extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
execution.TakeOrdered(limit, order, planLater(child)) :: Nil
case _ => Nil
}
}
@@ -187,9 +187,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val relation =
ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
// Note: overwrite=false because otherwise the metadata we just created will be deleted
InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
InsertIntoParquetTable(relation, planLater(child), overwrite=false) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil
case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
val prunePushedDownFilters =
if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -218,7 +218,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
projectList,
filters,
prunePushedDownFilters,
ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil
ParquetTableScan(_, relation, filters)) :: Nil

case _ => Nil
}
@@ -243,7 +243,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Distinct(child) =>
execution.Aggregate(
partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil
partial = false, child.output, child.output, planLater(child)) :: Nil
case logical.Sort(sortExprs, child) =>
// This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
execution.Sort(sortExprs, global = true, planLater(child)):: Nil
@@ -256,17 +256,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
ExistingRdd(
output,
ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child))(sqlContext) :: Nil
execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>
execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
execution.Union(unionChildren.map(planLater)) :: Nil
case logical.Except(left,right) =>
execution.Except(planLater(left),planLater(right)) :: Nil
case logical.Intersect(left, right) =>
