[SPARK-7289] handle project -> limit -> sort efficiently
Make the `TakeOrdered` strategy and operator more general, so that they can optionally handle a projection when necessary.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes apache#6780 from cloud-fan/limit and squashes the following commits:

34aa07b [Wenchen Fan] revert
07d5456 [Wenchen Fan] clean closure
20821ec [Wenchen Fan] fix
3676a82 [Wenchen Fan] address comments
b558549 [Wenchen Fan] address comments
214842b [Wenchen Fan] fix style
2d8be83 [Wenchen Fan] add LimitPushDown
948f740 [Wenchen Fan] fix existing
cloud-fan authored and marmbrus committed Jun 24, 2015
1 parent b84d4b4 commit f04b567
Showing 8 changed files with 62 additions and 40 deletions.
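The optimization targets query plans of the shape project -> limit -> sort. As a rough illustration (mirroring the planner test added in this commit; `testData` and the implicit conversions come from the Spark SQL test harness, so this is a sketch rather than a standalone program):

    // Before this commit this query was planned as Project -> Limit -> Sort;
    // with this change it is planned as a single TakeOrderedAndProject operator.
    val query = testData.sort('key).select('value).limit(2)
    query.explain()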
@@ -39,19 +39,22 @@ object DefaultOptimizer extends Optimizer {
Batch("Distinct", FixedPoint(100),
ReplaceDistinctWithAggregate) ::
Batch("Operator Optimizations", FixedPoint(100),
UnionPushdown,
CombineFilters,
// Operator push down
UnionPushDown,
PushPredicateThroughJoin,
PushPredicateThroughProject,
PushPredicateThroughGenerate,
ColumnPruning,
// Operator combine
ProjectCollapsing,
CombineFilters,
CombineLimits,
// Constant folding
NullPropagation,
OptimizeIn,
ConstantFolding,
LikeSimplification,
BooleanSimplification,
PushPredicateThroughJoin,
RemovePositive,
SimplifyFilters,
SimplifyCasts,
@@ -63,25 +66,25 @@ object DefaultOptimizer extends Optimizer {
}

/**
* Pushes operations to either side of a Union.
*/
object UnionPushdown extends Rule[LogicalPlan] {
* Pushes operations to either side of a Union.
*/
object UnionPushDown extends Rule[LogicalPlan] {

/**
* Maps Attributes from the left side to the corresponding Attribute on the right side.
*/
def buildRewrites(union: Union): AttributeMap[Attribute] = {
* Maps Attributes from the left side to the corresponding Attribute on the right side.
*/
private def buildRewrites(union: Union): AttributeMap[Attribute] = {
assert(union.left.output.size == union.right.output.size)

AttributeMap(union.left.output.zip(union.right.output))
}

/**
* Rewrites an expression so that it can be pushed to the right side of a Union operator.
* This method relies on the fact that the output attributes of a union are always equal
* to the left child's output.
*/
def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = {
* Rewrites an expression so that it can be pushed to the right side of a Union operator.
* This method relies on the fact that the output attributes of a union are always equal
* to the left child's output.
*/
private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
val result = e transform {
case a: Attribute => rewrites(a)
}
@@ -108,7 +111,6 @@ object UnionPushdown extends Rule[LogicalPlan] {
}
}


/**
* Attempts to eliminate the reading of unneeded columns from the query plan using the following
* transformations:
@@ -117,7 +119,6 @@ object UnionPushdown extends Rule[LogicalPlan] {
* - Aggregate
* - Project <- Join
* - LeftSemiJoin
* - Performing alias substitution.
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -159,10 +160,11 @@ object ColumnPruning extends Rule[LogicalPlan] {

Join(left, prunedChild(right, allReferences), LeftSemi, condition)

// Push down project through limit, so that we may have chance to push it further.
case Project(projectList, Limit(exp, child)) =>
Limit(exp, Project(projectList, child))

// push down project if possible when the child is sort
// Push down project if possible when the child is sort
case p @ Project(projectList, s @ Sort(_, _, grandChild))
if s.references.subsetOf(p.outputSet) =>
s.copy(child = Project(projectList, grandChild))
@@ -181,8 +183,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
}

/**
* Combines two adjacent [[Project]] operators into one, merging the
* expressions into one single expression.
* Combines two adjacent [[Project]] operators into one and perform alias substitution,
* merging the expressions into one single expression.
*/
object ProjectCollapsing extends Rule[LogicalPlan] {

@@ -222,10 +224,10 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
object LikeSimplification extends Rule[LogicalPlan] {
// if guards below protect from escapes on trailing %.
// Cases like "something\%" are not optimized, but this does not affect correctness.
val startsWith = "([^_%]+)%".r
val endsWith = "%([^_%]+)".r
val contains = "%([^_%]+)%".r
val equalTo = "([^_%]*)".r
private val startsWith = "([^_%]+)%".r
private val endsWith = "%([^_%]+)".r
private val contains = "%([^_%]+)%".r
private val equalTo = "([^_%]*)".r

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Like(l, Literal(utf, StringType)) =>
@@ -497,7 +499,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
grandChild))
}

def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = {
private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = {
condition transform {
case a: AttributeReference => sourceAliases.getOrElse(a, a)
}
@@ -682,7 +684,7 @@ object DecimalAggregates extends Rule[LogicalPlan] {
import Decimal.MAX_LONG_DIGITS

/** Maximum number of decimal digits representable precisely in a Double */
val MAX_DOUBLE_DIGITS = 15
private val MAX_DOUBLE_DIGITS = 15

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
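One detail worth calling out from the hunks above: the new ColumnPruning case pushes a Project through a Limit, which is what produces the Limit -> Project -> Sort shape the new physical strategy can recognize. A minimal self-contained sketch of that rewrite, using toy stand-in case classes rather than the real Catalyst ones:

    // Toy plan nodes (illustrative stand-ins, not Catalyst classes).
    sealed trait Plan
    case class Relation(name: String) extends Plan
    case class Sort(key: String, child: Plan) extends Plan
    case class Limit(n: Int, child: Plan) extends Plan
    case class Project(columns: Seq[String], child: Plan) extends Plan

    // The rewrite added to ColumnPruning: move a Project below an adjacent Limit.
    def pushProjectThroughLimit(plan: Plan): Plan = plan match {
      case Project(cols, Limit(n, child)) => Limit(n, Project(cols, child))
      case other => other
    }

    // Project -> Limit -> Sort  becomes  Limit -> Project -> Sort
    val before = Project(Seq("value"), Limit(2, Sort("key", Relation("testData"))))
    val after  = pushProjectThroughLimit(before)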
@@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class UnionPushdownSuite extends PlanTest {
class UnionPushDownSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", Once,
EliminateSubQueries) ::
Batch("Union Pushdown", Once,
UnionPushdown) :: Nil
UnionPushDown) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -858,7 +858,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
experimental.extraStrategies ++ (
DataSourceStrategy ::
DDLStrategy ::
TakeOrdered ::
TakeOrderedAndProject ::
HashAggregation ::
LeftSemiJoin ::
HashJoin ::
@@ -169,7 +169,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
log.debug(
s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if(codegenEnabled && expressions.forall(_.isThreadSafe)) {

GenerateMutableProjection.generate(expressions, inputSchema)
} else {
() => new InterpretedMutableProjection(expressions, inputSchema)
@@ -213,10 +213,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
protected lazy val singleRowRdd =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1)

object TakeOrdered extends Strategy {
object TakeOrderedAndProject extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
execution.TakeOrdered(limit, order, planLater(child)) :: Nil
execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
case logical.Limit(
IntegerLiteral(limit),
logical.Project(projectList, logical.Sort(order, true, child))) =>
execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
case _ => Nil
}
}
@@ -39,8 +39,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
@transient lazy val buildProjection = newMutableProjection(projectList, child.output)

protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
val resuableProjection = buildProjection()
iter.map(resuableProjection)
val reusableProjection = buildProjection()
iter.map(reusableProjection)
}

override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -147,21 +147,32 @@ case class Limit(limit: Int, child: SparkPlan)

/**
* :: DeveloperApi ::
* Take the first limit elements as defined by the sortOrder. This is logically equivalent to
* having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
* Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
* Take the first limit elements as defined by the sortOrder, and do projection if needed.
* This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
* or having a [[Project]] operator between them.
* This could have been named TopK, but Spark's top operator does the opposite in ordering
* so we name it TakeOrdered to avoid confusion.
*/
@DeveloperApi
case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
case class TakeOrderedAndProject(
limit: Int,
sortOrder: Seq[SortOrder],
projectList: Option[Seq[NamedExpression]],
child: SparkPlan) extends UnaryNode {

override def output: Seq[Attribute] = child.output

override def outputPartitioning: Partitioning = SinglePartition

private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)

private def collectData(): Array[InternalRow] =
child.execute().map(_.copy()).takeOrdered(limit)(ord)
// TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
@transient private val projection = projectList.map(new InterpretedProjection(_, child.output))

private def collectData(): Array[InternalRow] = {
val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
projection.map(data.map(_)).getOrElse(data)
}

override def executeCollect(): Array[Row] = {
val converter = CatalystTypeConverters.createToScalaConverter(schema)
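The collectData change above boils down to: take the top `limit` rows under the sort order, then apply the projection only if one is present. A conceptual sketch using plain Scala collections instead of RDDs (types simplified; not the actual operator):

    // Conceptual equivalent of TakeOrderedAndProject.collectData().
    def takeOrderedAndProject[T](rows: Seq[T], limit: Int, project: Option[T => T])
                                (implicit ord: Ordering[T]): Seq[T] = {
      // Same result as RDD.takeOrdered(limit)(ord): the `limit` smallest rows under `ord`.
      val topK = rows.sorted(ord).take(limit)
      // Apply the projection only when one was requested, mirroring
      // projection.map(data.map(_)).getOrElse(data) in the diff above.
      project.map(p => topK.map(p)).getOrElse(topK)
    }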
@@ -141,4 +141,10 @@ class PlannerSuite extends SparkFunSuite {

setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}

test("efficient limit -> project -> sort") {
val query = testData.sort('key).select('value).limit(2).logicalPlan
val planned = planner.TakeOrderedAndProject(query)
assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
}
}
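The same intent expressed in SQL should exercise the new code path as well (illustrative only; assumes testData is registered as a temporary table, and the chosen physical plan depends on the analyzed plan shape):

    sqlContext.sql("SELECT value FROM testData ORDER BY key LIMIT 2").explain()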
@@ -442,7 +442,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
HiveCommandStrategy(self),
HiveDDLStrategy,
DDLStrategy,
TakeOrdered,
TakeOrderedAndProject,
ParquetOperations,
InMemoryScans,
ParquetConversion, // Must be before HiveTableScans