[SPARK-13732][SPARK-13797][SQL] Remove projectList from Window and Eliminate useless Window

#### What changes were proposed in this pull request?

`projectList` is redundant: its value is always identical to `child.output`. This PR removes it from the `Window` class; the removal simplifies the code in the Analyzer and Optimizer.

This PR is based on the discussion started by cloud-fan in a separate PR:
apache#5604 (comment)

This PR also eliminates useless `Window` operators, i.e., `Window` nodes with no window expressions left (for example, after all of their window expressions have been pruned away).
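
Condensed from the diff below, the shape of the change is roughly:

```scala
// Before: the logical Window operator carried a projectList that always
// mirrored child.output.
case class Window(
    projectList: Seq[Attribute],             // removed by this PR
    windowExpressions: Seq[NamedExpression],
    partitionSpec: Seq[Expression],
    orderSpec: Seq[SortOrder],
    child: LogicalPlan) extends UnaryNode

// After: the output is derived directly from the child.
case class Window(
    windowExpressions: Seq[NamedExpression],
    partitionSpec: Seq[Expression],
    orderSpec: Seq[SortOrder],
    child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] =
    child.output ++ windowExpressions.map(_.toAttribute)
}

// New ColumnPruning rules: drop unreferenced window expressions, then replace
// a Window that has no window expressions left with its child.
case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty =>
  p.copy(child = w.copy(
    windowExpressions = w.windowExpressions.filter(p.references.contains)))
case w: Window if w.windowExpressions.isEmpty => w.child
```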

cc cloud-fan yhuai

#### How was this patch tested?

Existing test cases cover it.

Author: gatorsmile <gatorsmile@gmail.com>
Author: xiaoli <lixiao1983@gmail.com>
Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local>

Closes apache#11565 from gatorsmile/removeProjListWindow.
gatorsmile authored and roygao94 committed Mar 22, 2016
1 parent 600ffae commit 1e24581
Showing 7 changed files with 94 additions and 27 deletions.
@@ -421,7 +421,7 @@ class Analyzer(
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
(oldVersion, oldVersion.copy(generatorOutput = newOutput))

case oldVersion @ Window(_, windowExpressions, _, _, child)
case oldVersion @ Window(windowExpressions, _, _, child)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
(oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
@@ -658,10 +658,6 @@ class Analyzer(
case p: Project =>
val missing = missingAttrs -- p.child.outputSet
Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
case w: Window =>
val missing = missingAttrs -- w.child.outputSet
w.copy(projectList = w.projectList ++ missingAttrs,
child = addMissingAttr(w.child, missing))
case a: Aggregate =>
// all the missing attributes should be grouping expressions
// TODO: push down AggregateExpression
@@ -1166,7 +1162,6 @@ class Analyzer(
// Set currentChild to the newly created Window operator.
currentChild =
Window(
currentChild.output,
windowExpressions,
partitionSpec,
orderSpec,
@@ -1436,10 +1431,10 @@ object CleanupAliases extends Rule[LogicalPlan] {
val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
Aggregate(grouping.map(trimAliases), cleanedAggs, child)

case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
case w @ Window(windowExprs, partitionSpec, orderSpec, child) =>
val cleanedWindowExprs =
windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)

// Operators that operate on objects should only have expressions from encoders, which should
@@ -268,6 +268,12 @@ package object dsl {
Aggregate(groupingExprs, aliasedExprs, logicalPlan)
}

def window(
windowExpressions: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder]): LogicalPlan =
Window(windowExpressions, partitionSpec, orderSpec, logicalPlan)

def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan)

def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan)
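// Illustrative usage of the new dsl helper (mirrors the ColumnPruningSuite test
// further down; the relation and column names here are just examples):
//   testRelation.window(
//     WindowExpression(
//       AggregateExpression(Count('b), Complete, isDistinct = false),
//       WindowSpecDefinition('a :: Nil, SortOrder('b, Ascending) :: Nil,
//         UnspecifiedFrame)).as('window) :: Nil,
//     partitionSpec = 'a :: Nil,
//     orderSpec = 'b.asc :: Nil)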
@@ -315,21 +315,17 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
* - LeftSemiJoin
*/
object ColumnPruning extends Rule[LogicalPlan] {
def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
output1.size == output2.size &&
output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Prunes the unused columns from project list of Project/Aggregate/Window/Expand
// Prunes the unused columns from project list of Project/Aggregate/Expand
case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
p.copy(
child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
p.copy(child = w.copy(
projectList = w.projectList.filter(p.references.contains),
windowExpressions = w.windowExpressions.filter(p.references.contains)))
case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
val newOutput = e.output.filter(a.references.contains(_))
val newProjects = e.projections.map { proj =>
@@ -343,11 +339,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty =>
mp.copy(child = prunedChild(child, mp.references))

// Prunes the unused columns from child of Aggregate/Window/Expand/Generate
// Prunes the unused columns from child of Aggregate/Expand/Generate
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
a.copy(child = prunedChild(child, a.references))
case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
w.copy(child = prunedChild(child, w.references))
case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
e.copy(child = prunedChild(child, e.references))
case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
@@ -381,6 +375,14 @@ object ColumnPruning extends Rule[LogicalPlan] {
p
}

// Prune unnecessary window expressions
case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty =>
p.copy(child = w.copy(
windowExpressions = w.windowExpressions.filter(p.references.contains)))

// Eliminate no-op Window
case w: Window if w.windowExpressions.isEmpty => w.child

// Eliminate no-op Projects
case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child
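    // Worked example (mirrors the "Column pruning on Window in select" test below):
    //   Project [a, c]
    //   +- Window [count(b) AS window], partitionSpec = [a], orderSpec = [b ASC]
    //      +- LocalRelation [a, b, c, d]
    // 'window is never referenced, so its window expression is pruned, the now-empty
    // Window is replaced by its child, and the plan reduces to:
    //   Project [a, c]
    //   +- LocalRelation [a, b, c, d]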

@@ -434,14 +434,15 @@ case class Aggregate(
}

case class Window(
projectList: Seq[Attribute],
windowExpressions: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: LogicalPlan) extends UnaryNode {

override def output: Seq[Attribute] =
projectList ++ windowExpressions.map(_.toAttribute)
child.output ++ windowExpressions.map(_.toAttribute)

def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute))
}
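// For example (illustrative): with child.output = [a, b, c] and windowExpressions =
// [count(b) AS window], output is [a, b, c, window] and windowOutputSet is {window}.
// ColumnPruning uses windowOutputSet to find window expressions no parent references.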

private[sql] object Expand {
@@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -33,7 +34,8 @@ class ColumnPruningSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Column pruning", FixedPoint(100),
ColumnPruning) :: Nil
ColumnPruning,
CollapseProject) :: Nil
}
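  // CollapseProject is included alongside ColumnPruning here, presumably because
  // pruning inserts extra Project nodes around the pruned operators; collapsing
  // them keeps the expected plans in the tests below compact.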

test("Column pruning for Generate when Generate.join = false") {
@@ -258,6 +260,68 @@ class ColumnPruningSuite extends PlanTest {
comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
}

test("Column pruning on Window with useless aggregate functions") {
val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)

val originalQuery =
input.groupBy('a, 'c, 'd)('a, 'c, 'd,
WindowExpression(
AggregateExpression(Count('b), Complete, isDistinct = false),
WindowSpecDefinition( 'a :: Nil,
SortOrder('b, Ascending) :: Nil,
UnspecifiedFrame)).as('window)).select('a, 'c)

val correctAnswer = input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze

val optimized = Optimize.execute(originalQuery.analyze)

comparePlans(optimized, correctAnswer)
}

test("Column pruning on Window with selected agg expressions") {
val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)

val originalQuery =
input.select('a, 'b, 'c, 'd,
WindowExpression(
AggregateExpression(Count('b), Complete, isDistinct = false),
WindowSpecDefinition( 'a :: Nil,
SortOrder('b, Ascending) :: Nil,
UnspecifiedFrame)).as('window)).where('window > 1).select('a, 'c)

val correctAnswer =
input.select('a, 'b, 'c)
.window(WindowExpression(
AggregateExpression(Count('b), Complete, isDistinct = false),
WindowSpecDefinition( 'a :: Nil,
SortOrder('b, Ascending) :: Nil,
UnspecifiedFrame)).as('window) :: Nil,
'a :: Nil, 'b.asc :: Nil)
.select('a, 'c, 'window).where('window > 1).select('a, 'c).analyze

val optimized = Optimize.execute(originalQuery.analyze)

comparePlans(optimized, correctAnswer)
}

test("Column pruning on Window in select") {
val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)

val originalQuery =
input.select('a, 'b, 'c, 'd,
WindowExpression(
AggregateExpression(Count('b), Complete, isDistinct = false),
WindowSpecDefinition( 'a :: Nil,
SortOrder('b, Ascending) :: Nil,
UnspecifiedFrame)).as('window)).select('a, 'c)

val correctAnswer = input.select('a, 'c).analyze

val optimized = Optimize.execute(originalQuery.analyze)

comparePlans(optimized, correctAnswer)
}

test("Column pruning on Union") {
val input1 = LocalRelation('a.int, 'b.string, 'c.double)
val input2 = LocalRelation('c.int, 'd.string, 'e.double)
@@ -344,9 +344,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Filter(condition, planLater(child)) :: Nil
case e @ logical.Expand(_, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
execution.Window(
projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
execution.Window(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
@@ -81,14 +81,14 @@ import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, Unsaf
* of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]].
*/
case class Window(
projectList: Seq[Attribute],
windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan)
extends UnaryNode {

override def output: Seq[Attribute] = projectList ++ windowExpression.map(_.toAttribute)
override def output: Seq[Attribute] =
child.output ++ windowExpression.map(_.toAttribute)

override def requiredChildDistribution: Seq[Distribution] = {
if (partitionSpec.isEmpty) {
@@ -275,7 +275,7 @@ case class Window(
val unboundToRefMap = expressions.zip(references).toMap
val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
UnsafeProjection.create(
projectList ++ patchedWindowExpression,
child.output ++ patchedWindowExpression,
child.output)
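// The result projection now emits every child column followed by the evaluated window
// columns, matching the logical operator's output (child.output ++ window attributes);
// e.g. (illustrative) a child row [a, b] with one window expression yields [a, b, window].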
}

