Gracefully fallback to old aggregation code path.
yhuai committed Jul 21, 2015
1 parent 8a8ac4a commit e0afca3
Showing 13 changed files with 264 additions and 66 deletions.
@@ -250,18 +250,19 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
   override def dataType: DataType = resultType
 
   // Expected input data type.
-  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
 
   private val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType(precision + 4, scale + 4)
     case DecimalType.Unlimited => DecimalType.Unlimited
-    case _ => DoubleType
+    case _ => child.dataType
   }
 
   private val sumDataType = child.dataType match {
     case _ @ DecimalType() => DecimalType.Unlimited
-    case _ => DoubleType
+    case _ => child.dataType
   }
 
   private val currentSum = AttributeReference("currentSum", sumDataType)()
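
In effect, Sum now keeps the input's own type (a sum of longs stays a long) instead of widening every non-decimal input to double, while decimals still gain headroom digits. A minimal, Spark-free sketch of the new rule; the ADT below is illustrative, not the Catalyst DataType hierarchy the real code matches on:

sealed trait SqlType
case class FixedDecimal(precision: Int, scale: Int) extends SqlType
case object UnlimitedDecimal extends SqlType
case object LongTy extends SqlType

def sumResultType(child: SqlType): SqlType = child match {
  case FixedDecimal(p, s) => FixedDecimal(p + 4, s + 4) // headroom against overflow
  case UnlimitedDecimal   => UnlimitedDecimal
  case other              => other // previously this arm widened to a double type
}

assert(sumResultType(LongTy) == LongTy) // a sum of longs now stays a long
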
@@ -71,8 +71,6 @@ private[sql] case object NoOp extends Expression with Unevaluable {
   override def children: Seq[Expression] = Nil
 }
 
-
-
 /**
  * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
  * (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
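
For orientation, the container that comment describes can be pictured with a small Spark-free sketch; every name below is illustrative, not an actual Catalyst class:

sealed trait AggregateMode
case object Partial extends AggregateMode
case object Final extends AggregateMode

// An aggregate call site couples a function with the mode it is evaluated in
// and a flag recording whether DISTINCT was specified.
case class AggregateCall[F](aggregateFunction: F, mode: AggregateMode, isDistinct: Boolean)
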
@@ -865,7 +865,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       DDLStrategy ::
       TakeOrderedAndProject ::
       HashAggregation ::
-      AggregateOperator2 ::
+      Aggregation ::
       LeftSemiJoin ::
       HashJoin ::
       InMemoryScans ::
@@ -184,8 +184,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        case _ => Nil
      }
 
-    def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean =
-      aggregate.Utils.tryConvert(plan, sqlContext.conf.useSqlAggregate2).isDefined
+    def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = {
+      aggregate.Utils.tryConvert(
+        plan,
+        sqlContext.conf.useSqlAggregate2,
+        sqlContext.conf.codegenEnabled).isDefined
+    }
 
    def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists {
      case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false
@@ -202,50 +206,62 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   /**
    * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
    */
-  object AggregateOperator2 extends Strategy {
+  object Aggregation extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case aggregate.NewAggregation(groupingExpressions, resultExpressions, child)
-        if sqlContext.conf.useSqlAggregate2 =>
-        // Extracts all distinct aggregate expressions from the resultExpressions.
-        val aggregateExpressions = resultExpressions.flatMap { expr =>
-          expr.collect {
-            case agg: AggregateExpression2 => agg
-          }
-        }.toSet.toSeq
-        // For those distinct aggregate expressions, we create a map from the aggregate function
-        // to the corresponding attribute of the function.
-        val aggregateFunctionMap = aggregateExpressions.map { agg =>
-          val aggregateFunction = agg.aggregateFunction
-          (aggregateFunction, agg.isDistinct) ->
-            Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
-        }.toMap
-
-        val (functionsWithDistinct, functionsWithoutDistinct) =
-          aggregateExpressions.partition(_.isDistinct)
-        if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
-          // This is a sanity check. We should not reach here since we check the same thing in
-          // CheckAggregateFunction.
-          sys.error("Having more than one distinct column sets is not allowed.")
-        }
-
+      case p: logical.Aggregate =>
+        val converted =
+          aggregate.Utils.tryConvert(
+            p,
+            sqlContext.conf.useSqlAggregate2,
+            sqlContext.conf.codegenEnabled)
+        converted match {
+          case None => Nil // Cannot convert to new aggregation code path.
+          case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
+            // Extracts all distinct aggregate expressions from the resultExpressions.
+            val aggregateExpressions = resultExpressions.flatMap { expr =>
+              expr.collect {
+                case agg: AggregateExpression2 => agg
+              }
+            }.toSet.toSeq
+            // For those distinct aggregate expressions, we create a map from the
+            // aggregate function to the corresponding attribute of the function.
+            val aggregateFunctionMap = aggregateExpressions.map { agg =>
+              val aggregateFunction = agg.aggregateFunction
+              (aggregateFunction, agg.isDistinct) ->
+                Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+            }.toMap
+
+            val (functionsWithDistinct, functionsWithoutDistinct) =
+              aggregateExpressions.partition(_.isDistinct)
+            if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+              // This is a sanity check. We should not reach here when we have multiple distinct
+              // column sets (aggregate.NewAggregation will not match).
+              sys.error(
+                "Multiple distinct column sets are not supported by the new aggregation " +
+                  "code path.")
+            }
+
+            val aggregateOperator =
+              if (functionsWithDistinct.isEmpty) {
+                aggregate.Utils.planAggregateWithoutDistinct(
+                  groupingExpressions,
+                  aggregateExpressions,
+                  aggregateFunctionMap,
+                  resultExpressions,
+                  planLater(child))
+              } else {
+                aggregate.Utils.planAggregateWithOneDistinct(
+                  groupingExpressions,
+                  functionsWithDistinct,
+                  functionsWithoutDistinct,
+                  aggregateFunctionMap,
+                  resultExpressions,
+                  planLater(child))
+              }
+
+            aggregateOperator
+        }
-      val aggregateOperator =
-        if (functionsWithDistinct.isEmpty) {
-          aggregate.Utils.planAggregateWithoutDistinct(
-            groupingExpressions,
-            aggregateExpressions,
-            aggregateFunctionMap,
-            resultExpressions,
-            planLater(child))
-        } else {
-          aggregate.Utils.planAggregateWithOneDistinct(
-            groupingExpressions,
-            functionsWithDistinct,
-            functionsWithoutDistinct,
-            aggregateFunctionMap,
-            resultExpressions,
-            planLater(child))
-        }
-
-      aggregateOperator
       case _ => Nil
     }
   }
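
Taken together, tryConvert acts as a feasibility probe, and returning Nil is what makes the fallback graceful: the planner simply consults the next applicable strategy, which plans the query with the old aggregation operators. A minimal, Spark-free sketch of that probe-then-commit contract; every name below is illustrative, not a Spark API:

case class Plan(description: String)

// Stand-in for aggregate.Utils.tryConvert: None means "not convertible".
def tryConvert(plan: Plan, featureEnabled: Boolean, codegenEnabled: Boolean): Option[Plan] =
  if (featureEnabled && codegenEnabled && !plan.description.contains("multiple DISTINCT"))
    Some(plan)
  else None

// Stand-in for the Aggregation strategy: Nil hands the plan back to the planner.
def aggregationStrategy(plan: Plan): Seq[Plan] =
  tryConvert(plan, featureEnabled = true, codegenEnabled = true) match {
    case None            => Nil // fall back to the old aggregation path
    case Some(converted) => Seq(Plan(s"Aggregate2Sort(${converted.description})"))
  }

assert(aggregationStrategy(Plan("query with multiple DISTINCT column sets")).isEmpty)
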
@@ -149,7 +149,7 @@ private[sql] abstract class SortAggregationIterator(
     }
   }
 
-  private def initialize(): Unit = {
+  protected def initialize(): Unit = {
     if (inputIter.hasNext) {
       initializeBuffer()
       val currentRow = inputIter.next().copy()
@@ -474,6 +474,31 @@ class FinalSortAggregationIterator(
 
   override protected def initialBufferOffset: Int = groupingExpressions.length
 
+  override def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      if (groupingExpressions.isEmpty) {
+        // If there is no grouping expression, we need to generate a single row as the output.
+        initializeBuffer()
+        // Right now, the buffer only contains initial buffer values. Because merging two
+        // buffers with initial values will generate a row that still stores initial values,
+        // we set currentRow to a copy of the current buffer.
+        val currentRow = buffer.copy()
+        nextGroupingKey = groupGenerator(currentRow).copy()
+        firstRowInNextGroup = currentRow
+      } else {
+        // This iterator is empty.
+        hasNewGroup = false
+      }
+    }
+  }
+
   override protected def processRow(row: InternalRow): Unit = {
     // Process all algebraic aggregate functions.
     algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
@@ -659,6 +684,31 @@ class FinalAndCompleteSortAggregationIterator(
     newMutableProjection(evalExpressions, bufferSchemata)()
   }
 
+  override def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      if (groupingExpressions.isEmpty) {
+        // If there is no grouping expression, we need to generate a single row as the output.
+        initializeBuffer()
+        // Right now, the buffer only contains initial buffer values. Because merging two
+        // buffers with initial values will generate a row that still stores initial values,
+        // we set currentRow to a copy of the current buffer.
+        val currentRow = buffer.copy()
+        nextGroupingKey = groupGenerator(currentRow).copy()
+        firstRowInNextGroup = currentRow
+      } else {
+        // This iterator is empty.
+        hasNewGroup = false
+      }
+    }
+  }
+
   override protected def processRow(row: InternalRow): Unit = {
     val input = joinedRow(buffer, row)
     // For all aggregate functions with mode Complete, update buffers.
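
The two initialize overrides above lean on the invariant spelled out in the buffer comments: merging freshly initialized buffers must itself yield an initialized buffer, which is what lets an empty, ungrouped input produce its single output row directly from the initial buffer. A toy, Spark-free illustration of that invariant:

case class AvgBuffer(sum: Double, count: Long)

val initial = AvgBuffer(0.0, 0L)

def merge(a: AvgBuffer, b: AvgBuffer): AvgBuffer =
  AvgBuffer(a.sum + b.sum, a.count + b.count)

// Merging two freshly initialized buffers is indistinguishable from a fresh buffer,
// so an empty, ungrouped aggregation can emit its single row from `initial` directly.
assert(merge(initial, initial) == initial)
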
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.AnalysisException
@@ -6,10 +23,23 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
 
 object Utils {
+  // Right now, we do not support complex types in the grouping key schema.
+  private def groupingKeySchemaIsSupported(aggregate: Aggregate): Boolean = {
+    val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
+      case array: ArrayType => true
+      case map: MapType => true
+      case struct: StructType => true
+      case _ => false
+    }
+
+    !hasComplexTypes
+  }
+
   private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
-    case p: Aggregate =>
+    case p: Aggregate if groupingKeySchemaIsSupported(p) =>
       val converted = p.transformExpressionsDown {
         case expressions.Average(child) =>
           aggregate.AggregateExpression2(
@@ -76,17 +106,32 @@ object Utils {
        }.isDefined
      }
 
-      if (!hasAggregateExpression1) Some(converted) else None
+      // Check if there are multiple distinct columns.
+      val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression2 => agg
+        }
+      }.toSet.toSeq
+      val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+      val hasMultipleDistinctColumnSets =
+        if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+          true
+        } else {
+          false
+        }
+
+      if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
 
     case other => None
   }
 
   def tryConvert(
       plan: LogicalPlan,
-      useNewAggregation: Boolean): Option[Aggregate] = plan match {
-    case p: Aggregate =>
+      useNewAggregation: Boolean,
+      codeGenEnabled: Boolean): Option[Aggregate] = plan match {
+    case p: Aggregate if useNewAggregation && codeGenEnabled =>
       val converted = tryConvert(p)
-      if (useNewAggregation && converted.isDefined) {
+      if (converted.isDefined) {
        converted
      } else {
        // If the plan cannot be converted, we will do a final round check to see if the original
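
The new check above counts how many distinct argument lists appear among the DISTINCT aggregates; more than one forces the fallback. A Spark-free sketch of the same test, with illustrative names:

case class DistinctAgg(func: String, children: Seq[String])

def hasMultipleDistinctColumnSets(aggs: Seq[DistinctAgg]): Boolean =
  aggs.map(_.children).distinct.length > 1

// count(DISTINCT a) and sum(DISTINCT a) share one column set {a}: convertible.
assert(!hasMultipleDistinctColumnSets(
  Seq(DistinctAgg("count", Seq("a")), DistinctAgg("sum", Seq("a")))))
// count(DISTINCT a) and count(DISTINCT b) use two sets: fall back to the old path.
assert(hasMultipleDistinctColumnSets(
  Seq(DistinctAgg("count", Seq("a")), DistinctAgg("count", Seq("b")))))
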
@@ -67,7 +67,11 @@ abstract class UserDefinedAggregateFunction extends Serializable {
   /** Indicates if this function is deterministic. */
   def deterministic: Boolean
 
-  /** Initializes the given aggregation buffer. */
+  /**
+   * Initializes the given aggregation buffer. Initial values set by this method should satisfy
+   * the condition that when merging two buffers with initial values, the new buffer should
+   * still store initial values.
+   */
   def initialize(buffer: MutableAggregationBuffer): Unit
 
   /** Updates the given aggregation buffer `buffer` with new input data from `input`. */
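
As a concrete illustration of that contract, here is a sketch of a user-defined long-sum whose initial value, 0L, is the additive identity and so satisfies the requirement. It assumes the UDAF API as it stabilized in Spark 1.5; import paths and member names in this snapshot may differ slightly:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class LongSum extends UserDefinedAggregateFunction {
  def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)
  def bufferSchema: StructType = StructType(StructField("sum", LongType) :: Nil)
  def dataType: DataType = LongType
  def deterministic: Boolean = true

  // Initial value 0L: merging two freshly initialized buffers still stores 0L.
  def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0L

  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) buffer(0) = buffer.getLong(0) + input.getLong(0)

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)

  def evaluate(buffer: Row): Any = buffer.getLong(0)
}
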
@@ -23,6 +23,7 @@ import java.sql.Timestamp
 
 import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
 import org.apache.spark.sql.execution.GeneratedAggregate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._
@@ -204,6 +205,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     var hasGeneratedAgg = false
     df.queryExecution.executedPlan.foreach {
       case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
+      case newAggregate: Aggregate2Sort => hasGeneratedAgg = true
       case _ =>
     }
     if (!hasGeneratedAgg) {
@@ -285,7 +287,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       // Aggregate with Code generation handling all null values
       testCodeGen(
         "SELECT sum('a'), avg('a'), count(null) FROM testData",
-        Row(0, null, 0) :: Nil)
+        Row(null, null, 0) :: Nil)
     } finally {
       sqlContext.dropTempTable("testData3x")
       sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
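
The expected-result change above follows standard SQL semantics: casting the string literal 'a' to a numeric type yields NULL for every row, and SUM over only NULLs returns NULL rather than 0, while COUNT(null) is still 0, so Row(null, null, 0) is the correct expectation once these queries run through the new aggregate operators.
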