Hook generated aggregation in to the planner.
marmbrus committed Jul 9, 2014
1 parent e742640 commit fc522d5
Showing 6 changed files with 110 additions and 52 deletions.
@@ -104,6 +104,62 @@ object PhysicalOperation extends PredicateHelper {
}
}

object PartialAggregation {
type ReturnType =
(Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)

def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
// Collect all aggregate expressions.
val allAggregates =
aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
// Collect all aggregate expressions that can be computed partially.
val partialAggregates =
aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})

// Only do partial aggregation if supported by all aggregate expressions.
if (allAggregates.size == partialAggregates.size) {
// Create a map of expressions to their partial evaluations for all aggregate expressions.
val partialEvaluations: Map[Long, SplitEvaluation] =
partialAggregates.map(a => (a.id, a.asPartial)).toMap

// We need to pass all grouping expressions through so the grouping can happen a second
// time. However, some of them might be unnamed, so we alias them, allowing them to be
// referenced in the second aggregation.
val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
case n: NamedExpression => (n, n)
case other => (other, Alias(other, "PartialGroup")())
}.toMap

// Replace aggregations with a new expression that computes the result from the already
// computed partial evaluations and grouping values.
val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
case e: Expression if partialEvaluations.contains(e.id) =>
partialEvaluations(e.id).finalEvaluation
case e: Expression if namedGroupingExpressions.contains(e) =>
namedGroupingExpressions(e).toAttribute
}).asInstanceOf[Seq[NamedExpression]]

val partialComputation =
(namedGroupingExpressions.values ++
partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq

val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq

Some(
(namedGroupingAttributes,
rewrittenAggregateExpressions,
groupingExpressions,
partialComputation,
child))
} else {
None
}
case _ => None
}
}
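
For orientation, the new extractor can be destructured like any other Catalyst pattern. A minimal sketch, assuming the usual Catalyst imports; the describePlan helper below is purely illustrative and not part of this commit:

// Illustrative only: destructuring an Aggregate with the new extractor.
import org.apache.spark.sql.catalyst.planning.PartialAggregation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

def describePlan(plan: LogicalPlan): String = plan match {
  case PartialAggregation(groupAttrs, finalExprs, groupExprs, partialExprs, child) =>
    // partialExprs run before the shuffle; finalExprs combine the partial results afterwards.
    s"partial: ${partialExprs.mkString(", ")}; final: ${finalExprs.mkString(", ")}"
  case _ =>
    "no partial aggregation possible"
}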


/**
* A pattern that finds joins with equality conditions that can be evaluated using equi-join.
*/
Empty file.
@@ -239,7 +239,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
val strategies: Seq[Strategy] =
CommandStrategy(self) ::
TakeOrdered ::
PartialAggregation ::
HashAggregation ::
LeftSemiJoin ::
HashJoin ::
InMemoryScans ::
@@ -95,58 +95,57 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

object PartialAggregation extends Strategy {
object HashAggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
// Collect all aggregate expressions.
val allAggregates =
aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a })
// Collect all aggregate expressions that can be computed partially.
val partialAggregates =
aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p })
// Aggregations that can be performed in two phases, before and after the shuffle.

// Only do partial aggregation if supported by all aggregate expressions.
if (allAggregates.size == partialAggregates.size) {
// Create a map of expressions to their partial evaluations for all aggregate expressions.
val partialEvaluations: Map[Long, SplitEvaluation] =
partialAggregates.map(a => (a.id, a.asPartial)).toMap

// We need to pass all grouping expressions through so the grouping can happen a second
// time. However, some of them might be unnamed, so we alias them, allowing them to be
// referenced in the second aggregation.
val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
case n: NamedExpression => (n, n)
case other => (other, Alias(other, "PartialGroup")())
}.toMap

// Replace aggregations with a new expression that computes the result from the already
// computed partial evaluations and grouping values.
val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
case e: Expression if partialEvaluations.contains(e.id) =>
partialEvaluations(e.id).finalEvaluation
case e: Expression if namedGroupingExpressions.contains(e) =>
namedGroupingExpressions(e).toAttribute
}).asInstanceOf[Seq[NamedExpression]]

val partialComputation =
(namedGroupingExpressions.values ++
partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq

// Construct two phased aggregation.
execution.Aggregate(
// Where all aggregates can be codegened.
case PartialAggregation(
namedGroupingAttributes,
rewrittenAggregateExpressions,
groupingExpressions,
partialComputation,
child)
if canBeCodeGened(
allAggregates(partialComputation) ++
allAggregates(rewrittenAggregateExpressions)) =>
execution.HashAggregate(
partial = false,
namedGroupingExpressions.values.map(_.toAttribute).toSeq,
namedGroupingAttributes,
rewrittenAggregateExpressions,
execution.Aggregate(
execution.HashAggregate(
partial = true,
groupingExpressions,
partialComputation,
planLater(child))(sqlContext))(sqlContext) :: Nil
} else {
Nil
}


// Where some aggregate cannot be codegened.
case PartialAggregation(
namedGroupingAttributes,
rewrittenAggregateExpressions,
groupingExpressions,
partialComputation,
child) =>
execution.Aggregate(
partial = false,
namedGroupingAttributes,
rewrittenAggregateExpressions,
execution.Aggregate(
partial = true,
groupingExpressions,
partialComputation,
planLater(child))(sqlContext))(sqlContext) :: Nil
case _ => Nil
}

def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists {
case _: Sum | _: Count => false
case _ => true
}

def allAggregates(exprs: Seq[Expression]) =
exprs.flatMap(_.collect { case a: AggregateExpression => a })
}
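
For intuition about the two-phase plan this strategy builds, partial aggregation is per-partition pre-aggregation before the shuffle followed by a final merge after it. A plain-Scala sketch of that idea for a grouped count, using ordinary collections rather than the actual SplitEvaluation machinery (the data here is made up):

// Two-phase counting, analogous to a partial (map-side) then final (post-shuffle) aggregate.
val partitions: Seq[Seq[(String, Int)]] =
  Seq(Seq(("a", 1), ("b", 2)), Seq(("a", 3), ("a", 4)))

// Phase 1: count per group within each partition (before the shuffle).
val partialCounts: Seq[Map[String, Long]] =
  partitions.map(_.groupBy(_._1).map { case (k, vs) => k -> vs.size.toLong })

// Phase 2: sum the partial counts per group (after the shuffle), mirroring how a split
// Count is finalized by summing the partial counts.
val finalCounts: Map[String, Long] =
  partialCounts.flatten.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }
// finalCounts == Map("a" -> 3, "b" -> 1)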

object BroadcastNestedLoopJoin extends Strategy {
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
@@ -33,12 +34,14 @@ import org.apache.spark.sql.catalyst.types._
* @param child the input data source.
*/
case class HashAggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)(@transient sc: SparkContext)
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)(@transient sqlContext: SQLContext)
extends UnaryNode with NoBind {

private def sc = sqlContext.sparkContext

override def requiredChildDistribution =
if (partial) {
UnspecifiedDistribution :: Nil
@@ -50,7 +53,7 @@ case class HashAggregate(
}
}

override def otherCopyArgs = sc :: Nil
override def otherCopyArgs = sqlContext :: Nil

def output = aggregateExpressions.map(_.toAttribute)

@@ -39,22 +39,22 @@ class PlannerSuite extends FunSuite {

test("count is partially aggregated") {
val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed
val planned = PartialAggregation(query).head
val aggregations = planned.collect { case a: Aggregate => a }
val planned = HashAggregation(query).head
val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }

assert(aggregations.size === 2)
}

test("count distinct is not partially aggregated") {
val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed
val planned = PartialAggregation(query)
val planned = HashAggregation(query)
assert(planned.isEmpty)
}

test("mixed aggregates are not partially aggregated") {
val query =
testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed
val planned = PartialAggregation(query)
val planned = HashAggregation(query)
assert(planned.isEmpty)
}
}
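
A similar, hypothetical check (not part of this commit) could be added to the suite above to assert that a plain Sum also plans as two aggregation phases, since Sum and Count are the only expressions accepted by canBeCodeGened:

// Hypothetical follow-up test, following the same pattern as the suite above.
test("sum is partially aggregated") {
  val query = testData.groupBy('value)(Sum('key)).queryExecution.analyzed
  val planned = HashAggregation(query).head
  val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }

  assert(aggregations.size === 2)
}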
