diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index d4ef04c2294a2..c04bd6cd85187 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -266,11 +266,11 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ - { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) } | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => lexical.normalizeKeyword(udfName) match { case "sum" => SumDistinct(exprs.head) case "count" => CountDistinct(exprs) - case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT") + case name => UnresolvedFunction(name, exprs, isDistinct = true) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1ae1f5a3c7976..4e1bd0e3116b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.aggregate2.{Complete, AggregateExpression2, AggregateFunction2} +import org.apache.spark.sql.catalyst.expressions.aggregate2.{DistinctAggregateExpression1, Complete, AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -278,7 +278,7 @@ class Analyzer( Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) - case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) => + case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil @@ -518,10 +518,12 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { - case u @ UnresolvedFunction(name, children) => + case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { registry.lookupFunction(name, children) match { - case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, false) + case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct) + case agg1: AggregateExpression1 if isDistinct => + DistinctAggregateExpression1(agg1) case other => other } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 0daee1990a6e0..03da45b09f928 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -73,7 +73,10 @@ object UnresolvedAttribute { def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) } -case class UnresolvedFunction(name: String, children:
Seq[Expression]) +case class UnresolvedFunction( + name: String, + children: Seq[Expression], + isDistinct: Boolean) extends Expression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala index 104bd58500314..a75c8803f5196 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala @@ -68,6 +68,16 @@ private[sql] case object NoOp extends Expression with Unevaluable { override def children: Seq[Expression] = Nil } +private[sql] case class DistinctAggregateExpression1( + aggregateExpression: AggregateExpression1) extends AggregateExpression { + override def children: Seq[Expression] = aggregateExpression :: Nil + override def dataType: DataType = aggregateExpression.dataType + override def foldable: Boolean = aggregateExpression.foldable + override def nullable: Boolean = aggregateExpression.nullable + + override def toString: String = s"DISTINCT ${aggregateExpression.toString}" +} + /** * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 70197993730f6..474081cf05fbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} -import org.apache.spark.sql.execution.aggregate2.Aggregate2Sort +import org.apache.spark.sql.execution.aggregate2.{FinalAndCompleteAggregate2Sort, Aggregate2Sort} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.parquet._ import org.apache.spark.sql.sources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} @@ -200,6 +200,180 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ object AggregateOperator2 extends Strategy { + private def planAggregateWithoutDistinct( + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateFunctionMap: Map[AggregateFunction2, Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + // 1. Create an Aggregate Operator for partial aggregations. + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpression, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions.
+ case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + val partialAggregateExpressions = aggregateExpressions.map { + case AggregateExpression2(aggregateFunction, mode, isDistinct) => + AggregateExpression2(aggregateFunction, Partial, isDistinct) + } + val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => + agg.aggregateFunction.bufferAttributes + } + val partialAggregate = + Aggregate2Sort( + None: Option[Seq[Expression]], + namedGroupingExpressions.map(_._2), + partialAggregateExpressions, + partialAggregateAttributes, + namedGroupingAttributes ++ partialAggregateAttributes, + child) + + // 2. Create an Aggregate Operator for final aggregations. + val finalAggregateExpressions = aggregateExpressions.map { + case AggregateExpression2(aggregateFunction, mode, isDistinct) => + AggregateExpression2(aggregateFunction, Final, isDistinct) + } + val finalAggregateAttributes = + finalAggregateExpressions.map { + expr => aggregateFunctionMap(expr.aggregateFunction) + } + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transform { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction).toAttribute + case expression if groupExpressionMap.contains(expression) => + groupExpressionMap(expression).toAttribute + }.asInstanceOf[NamedExpression] + } + val finalAggregate = Aggregate2Sort( + Some(namedGroupingAttributes), + namedGroupingAttributes, + finalAggregateExpressions, + finalAggregateAttributes, + rewrittenResultExpressions, + partialAggregate) + + finalAggregate :: Nil + } + + private def planAggregateWithOneDistinct( + groupingExpressions: Seq[Expression], + functionsWithDistinct: Seq[AggregateExpression2], + functionsWithoutDistinct: Seq[AggregateExpression2], + aggregateFunctionMap: Map[AggregateFunction2, Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + // 1. Create an Aggregate Operator for partial aggregations. + // The grouping expressions are the original groupingExpressions and the + // distinct columns. For example, for avg(distinct value) ... group by key + // the grouping expressions of this Aggregate Operator will be [key, value]. + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpression, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + + // It is safe to call head here since functionsWithDistinct has at least one + // AggregateExpression2.
+ val distinctColumnExpressions = + functionsWithDistinct.head.aggregateFunction.children + val namedDistinctColumnExpressions = distinctColumnExpressions.map { + case ne: NamedExpression => ne -> ne + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap + val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute) + + val partialAggregateExpressions = functionsWithoutDistinct.map { + case AggregateExpression2(aggregateFunction, mode, _) => + AggregateExpression2(aggregateFunction, Partial, false) + } + val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => + agg.aggregateFunction.bufferAttributes + } + val partialAggregate = + Aggregate2Sort( + None: Option[Seq[Expression]], + (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2), + partialAggregateExpressions, + partialAggregateAttributes, + namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes, + child) + + // 2. Create an Aggregate Operator for partial merge aggregations. + val partialMergeAggregateExpressions = functionsWithoutDistinct.map { + case AggregateExpression2(aggregateFunction, mode, _) => + AggregateExpression2(aggregateFunction, PartialMerge, false) + } + val partialMergeAggregateAttributes = + partialMergeAggregateExpressions.map { + expr => aggregateFunctionMap(expr.aggregateFunction) + } + val partialMergeAggregate = + Aggregate2Sort( + Some(namedGroupingAttributes), + namedGroupingAttributes ++ distinctColumnAttributes, + partialMergeAggregateExpressions, + partialMergeAggregateAttributes, + namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes, + partialAggregate) + + // 3. Create an Aggregate Operator for the final and complete aggregations.
+ // TODO: Rewrite the children of these functions to reference distinctColumnAttributes. + val finalAggregateExpressions = functionsWithoutDistinct.map { + case AggregateExpression2(aggregateFunction, mode, _) => + AggregateExpression2(aggregateFunction, Final, false) + } + val finalAggregateAttributes = + finalAggregateExpressions.map { + expr => aggregateFunctionMap(expr.aggregateFunction) + } + val completeAggregateExpressions = functionsWithDistinct.map { + case AggregateExpression2(aggregateFunction, mode, _) => + AggregateExpression2(aggregateFunction, Complete, false) + } + val completeAggregateAttributes = + completeAggregateExpressions.map { + expr => aggregateFunctionMap(expr.aggregateFunction) + } + + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transform { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction).toAttribute + case expression if groupExpressionMap.contains(expression) => + groupExpressionMap(expression).toAttribute + case expression if distinctColumnExpressionMap.contains(expression) => + distinctColumnExpressionMap(expression).toAttribute + }.asInstanceOf[NamedExpression] + } + val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort( + namedGroupingAttributes, + finalAggregateExpressions, + finalAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + rewrittenResultExpressions, + partialMergeAggregate) + + finalAndCompleteAggregate :: Nil + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Aggregate(groupingExpressions, resultExpressions, child) if sqlContext.conf.useSqlAggregate2 => @@ -216,58 +390,32 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateFunction -> Alias(aggregateFunction, aggregateFunction.toString)().toAttribute }.toMap - // 2. Create an Aggregate Operator for partial aggregations. - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - val partialAggregateExpressions = aggregateExpressions.map { - case AggregateExpression2(aggregateFunction, mode, isDistinct) => - AggregateExpression2(aggregateFunction, Partial, isDistinct) + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here since we check the same thing in + // CheckAggregateFunction. + sys.error("Having more than one distinct column set is not allowed.") } - val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } - val partialAggregate = - Aggregate2Sort( - namedGroupingExpressions.map(_._2), - partialAggregateExpressions, - partialAggregateAttributes, - namedGroupingAttributes ++ partialAggregateAttributes, - planLater(child)) - - // 3. Create an Aggregate Operator for final aggregations.
- val finalAggregateExpressions = aggregateExpressions.map { - case AggregateExpression2(aggregateFunction, mode, isDistinct) => - AggregateExpression2(aggregateFunction, Final, isDistinct) - } - val finalAggregateAttributes = - finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction) + val aggregate = + if (functionsWithDistinct.isEmpty) { + planAggregateWithoutDistinct( + groupingExpressions, + aggregateExpressions, + aggregateFunctionMap, + resultExpressions, + planLater(child)) + } else { + planAggregateWithOneDistinct( + groupingExpressions, + functionsWithDistinct, + functionsWithoutDistinct, + aggregateFunctionMap, + resultExpressions, + planLater(child)) } - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction).toAttribute - case expression if groupExpressionMap.contains(expression) => - groupExpressionMap(expression).toAttribute - }.asInstanceOf[NamedExpression] - } - val finalAggregate = Aggregate2Sort( - namedGroupingAttributes, - finalAggregateExpressions, - finalAggregateAttributes, - rewrittenResultExpressions, - partialAggregate) - finalAggregate :: Nil + aggregate case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala deleted file mode 100644 index bc102f2e9d01d..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate2 - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate2._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} - -case class Aggregate2Sort( - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateAttributes: Seq[Attribute], - resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - /** Indicates if this operator is for partial aggregations. 
*/ - val partialAggregation: Boolean = { - aggregateExpressions.map(_.mode).distinct.toList match { - case Partial :: Nil => true - case Final :: Nil => false - case other => - sys.error( - s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " + - s"modes $other in this operator.") - } - } - - override def references: AttributeSet = { - val referencesInResults = - AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes) - - AttributeSet( - groupingExpressions.flatMap(_.references) ++ - aggregateExpressions.flatMap(_.references) ++ - referencesInResults) - } - - override def requiredChildDistribution: List[Distribution] = { - if (partialAggregation) { - UnspecifiedDistribution :: Nil - } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - groupingExpressions.map(SortOrder(_, Ascending)) :: Nil - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - val aggregationIterator = - if (partialAggregation) { - new PartialSortAggregationIterator( - groupingExpressions, - aggregateExpressions, - newMutableProjection, - child.output, - iter) - } else { - new FinalSortAggregationIterator( - groupingExpressions, - aggregateExpressions, - aggregateAttributes, - resultExpressions, - newMutableProjection, - child.output, - iter) - } - - aggregationIterator - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/aggregateOperators.scala new file mode 100644 index 0000000000000..89bb325a8edc0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/aggregateOperators.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate2 + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate2._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} + +case class Aggregate2Sort( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def references: AttributeSet = { + val referencesInResults = + AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes) + + AttributeSet( + groupingExpressions.flatMap(_.references) ++ + aggregateExpressions.flatMap(_.references) ++ + referencesInResults) + } + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + if (aggregateExpressions.isEmpty) { + new GroupingIterator( + groupingExpressions, + newMutableProjection, + child.output, + iter) + } else { + val partialAggregation: Boolean = { + aggregateExpressions.map(_.mode).distinct.toList match { + case Partial :: Nil => true + case Final :: Nil => false + // TODO: Handle the PartialMerge mode. + case other => + sys.error( + s"Could not evaluate ${aggregateExpressions} because we do not support evaluation " + + s"modes $other in this operator.") + } + } + val aggregationIterator = + if (partialAggregation) { + new PartialSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + child.output, + iter) + } else { + new FinalSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + resultExpressions, + newMutableProjection, + child.output, + iter) + } + aggregationIterator + } + } + } +} + +case class FinalAndCompleteAggregate2Sort( + groupingExpressions: Seq[NamedExpression], + finalAggregateExpressions: Seq[AggregateExpression2], + finalAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + override def references: AttributeSet = { + val referencesInResults = + AttributeSet(resultExpressions.flatMap(_.references)) -- + AttributeSet(finalAggregateAttributes) -- + AttributeSet(completeAggregateAttributes) + + AttributeSet( + groupingExpressions.flatMap(_.references) ++ + finalAggregateExpressions.flatMap(_.references) ++ + completeAggregateExpressions.flatMap(_.references) ++ + referencesInResults) + } + + override def
requiredChildDistribution: List[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + + new FinalAndCompleteSortAggregationIterator( + groupingExpressions, + finalAggregateExpressions, + finalAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + resultExpressions, + newMutableProjection, + child.output, + iter) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala index 4f041305ff88d..b93d1b516a882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala @@ -19,17 +19,23 @@ package org.apache.spark.sql.execution.aggregate2 import org.apache.spark.sql.{SQLConf, AnalysisException, SQLContext} import org.apache.spark.sql.catalyst.expressions.{Average => Average1, AggregateExpression1} -import org.apache.spark.sql.catalyst.expressions.aggregate2.{Average => Average2, AggregateExpression2, Complete} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.expressions.aggregate2.{Average => Average2, DistinctAggregateExpression1, AggregateExpression2, Complete} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule case class ConvertAggregateFunction(context: SQLContext) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p: LogicalPlan if !p.childrenResolved => p - case p if context.conf.useSqlAggregate2 => p.transformExpressionsUp { + case p: Aggregate if context.conf.useSqlAggregate2 => p.transformExpressionsDown { + case DistinctAggregateExpression1(Average1(child)) => + AggregateExpression2(Average2(child), Complete, true) case Average1(child) => AggregateExpression2(Average2(child), Complete, false) } + case p: Aggregate if !context.conf.useSqlAggregate2 => p.transformExpressionsDown { + // If aggregate2 is not enabled, just remove DistinctAggregateExpression1. + case DistinctAggregateExpression1(agg1) => agg1 + } } } @@ -37,15 +43,35 @@ case class CheckAggregateFunction(context: SQLContext) extends (LogicalPlan => U def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) } def apply(plan: LogicalPlan): Unit = plan.foreachUp { - case p if context.conf.useSqlAggregate2 => p.transformExpressionsUp { - case agg: AggregateExpression1 => - failAnalysis( - s"${SQLConf.USE_SQL_AGGREGATE2.key} is enabled. Please disable it to use $agg.") + case p: Aggregate if context.conf.useSqlAggregate2 => { + p.transformExpressionsUp { + case agg: AggregateExpression1 => + failAnalysis( + s"${SQLConf.USE_SQL_AGGREGATE2.key} is enabled. Please disable it to use $agg.") + case DistinctAggregateExpression1(agg: AggregateExpression1) => + failAnalysis( + s"${SQLConf.USE_SQL_AGGREGATE2.key} is enabled.
" + + s"Please disable it to use $agg with DISTINCT keyword.") + } + + val distinctColumnSets = p.aggregateExpressions.flatMap { expr => + expr.collect { + case AggregateExpression2(func, mode, isDistinct) if isDistinct => func.children + } + }.distinct + if (distinctColumnSets.length > 1) { + // TODO: Provide more information in the error message. + // There are more than one distinct column sets. For example, sum(distinct a) and + // sum(distinct b) will generate two distinct column sets, {a} and {b}. + failAnalysis(s"When ${SQLConf.USE_SQL_AGGREGATE2.key} is enabled, " + + s"only a single distinct column set is supported.") + } } - case p if !context.conf.useSqlAggregate2 => p.transformExpressionsUp { + case p: Aggregate if !context.conf.useSqlAggregate2 => p.transformExpressionsUp { case agg: AggregateExpression2 => failAnalysis( s"${SQLConf.USE_SQL_AGGREGATE2.key} is disabled. Please enable it to use $agg.") } + case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/sortBasedIterators.scala index 35ccc6981e301..d40ee344820a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/sortBasedIterators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/sortBasedIterators.scala @@ -232,6 +232,27 @@ private[sql] abstract class SortAggregationIterator( initialize() } +class GroupingIterator( + groupingExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + Nil, + newMutableProjection, + inputAttributes, + inputIter) { + override protected def initialBufferOffset: Int = 0 + + override protected def processRow(row: InternalRow): Unit = { + // Since we only do grouping, there is nothing to do. + } + + override protected def generateOutput(): InternalRow = { + currentGroupingKey + } +} class PartialSortAggregationIterator( groupingExpressions: Seq[NamedExpression], @@ -371,3 +392,162 @@ class FinalSortAggregationIterator( resultProjection(joinedRow(currentGroupingKey, aggregateResult)) } } + +class FinalAndCompleteSortAggregationIterator( + groupingExpressions: Seq[NamedExpression], + finalAggregateExpressions: Seq[AggregateExpression2], + finalAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + // This ordering is important. Because the format of its input is + // groupingExprs | distinctExprs | intermediate results for non-distinct aggs. 
+ completeAggregateExpressions ++ finalAggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) { + + private val aggregateResult: MutableRow = + new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length) + + private val resultProjection = { + val inputSchema = + groupingExpressions.map(_.toAttribute) ++ + completeAggregateAttributes ++ + finalAggregateAttributes + newMutableProjection(resultExpressions, inputSchema)() + } + + private val offsetAttributes = + Seq.fill(initialBufferOffset)(AttributeReference("offset", NullType)()) + + private val completeAggregateFunctions: Array[AggregateFunction2] = { + val functions = new Array[AggregateFunction2](completeAggregateExpressions.length) + var i = 0 + while (i < completeAggregateExpressions.length) { + functions(i) = aggregateFunctions(i) + i += 1 + } + functions + } + + private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + completeAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // This projection is used to update buffer values for all AlgebraicAggregates. + private val completeAlgebraicUpdateProjection = { + val bufferSchema = completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } + val updateExpressions = completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => NoOp :: Nil + } + newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) + } + + private val finalOffsetExpressions = { + val size = + initialBufferOffset + completeAggregateFunctions.map(_.bufferAttributes.length).sum + Seq.fill(size)(NoOp) + } + + private val finalOffsetAttributes = { + val size = + initialBufferOffset + completeAggregateFunctions.map(_.bufferAttributes.length).sum + Seq.fill(size)(AttributeReference("offset", NullType)()) + } + + private val finalAggregateFunctions: Array[AggregateFunction2] = { + val functions = new Array[AggregateFunction2](finalAggregateExpressions.length) + var i = 0 + while (i < finalAggregateExpressions.length) { + functions(i) = aggregateFunctions(completeAggregateExpressions.length + i) + i += 1 + } + functions + } + + private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + finalAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + private val finalAlgebraicMergeProjection = { + val bufferSchemata = + finalOffsetAttributes ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ finalOffsetAttributes ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val mergeExpressions = finalOffsetExpressions ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => NoOp :: Nil + } + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + // This projection is used to evaluate all AlgebraicAggregates.
+ private val algebraicEvalProjection = { + val bufferSchemata = + offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val evalExpressions = aggregateFunctions.map { + case ae: AlgebraicAggregate => ae.evaluateExpression + case agg: AggregateFunction2 => NoOp + } + + newMutableProjection(evalExpressions, bufferSchemata)() + } + + override protected def initialBufferOffset: Int = groupingExpressions.length + + override protected def processRow(row: InternalRow): Unit = { + val input = joinedRow(buffer, row) + completeAlgebraicUpdateProjection(input) + var i = 0 + while (i < completeNonAlgebraicAggregateFunctions.length) { + completeNonAlgebraicAggregateFunctions(i).update(buffer, row) + i += 1 + } + + finalAlgebraicMergeProjection.target(buffer)(input) + i = 0 + while (i < finalNonAlgebraicAggregateFunctions.length) { + finalNonAlgebraicAggregateFunctions(i).merge(buffer, row) + i += 1 + } + } + + override protected def generateOutput(): InternalRow = { + algebraicEvalProjection.target(aggregateResult)(buffer) + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + aggregateResult.update( + nonAlgebraicAggregateFunctionPositions(i), + nonAlgebraicAggregateFunctions(i).eval(buffer)) + i += 1 + } + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b5140dca0487f..4738d6072c512 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2758,7 +2758,7 @@ object functions { * @since 1.5.0 */ def callUDF(udfName: String, cols: Column*): Column = { - UnresolvedFunction(udfName, cols.map(_.expr)) + UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } /** @@ -2787,7 +2787,7 @@ object functions { exprs(i) = cols(i).expr i += 1 } - UnresolvedFunction(udfName, exprs) + UnresolvedFunction(udfName, exprs, isDistinct = false) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7fc517b646b20..0370a2b0115b2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1464,9 +1464,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr)) + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) + // Aggregate function with DISTINCT keyword.
+ case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil) + UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) /* Literals */ case Token("TOK_NULL", Nil) => Literal.create(null, NullType) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala index 59b10a2562cf2..7f3bd1450f39b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala @@ -18,21 +18,22 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row} import org.scalatest.BeforeAndAfterAll import test.org.apache.spark.sql.hive.aggregate2.MyDoubleSum -class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { +class Aggregate2Suite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - protected lazy val ctx = TestHive - import ctx.implicits._ + override val sqlContext = TestHive + import sqlContext.implicits._ var originalUseAggregate2: Boolean = _ override def beforeAll(): Unit = { - originalUseAggregate2 = ctx.conf.useSqlAggregate2 - ctx.sql("set spark.sql.useAggregate2=true") - val data = Seq[(Integer, Integer)]( + originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 + sqlContext.sql("set spark.sql.useAggregate2=true") + val data1 = Seq[(Integer, Integer)]( (1, 10), (null, -60), (1, 20), @@ -46,20 +47,52 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { (3, null), (null, null), (3, null)).toDF("key", "value") + data1.write.saveAsTable("agg1") - data.write.saveAsTable("agg2") + val data2 = Seq[(Integer, Integer)]( + (1, 10), + (null, -60), + (1, 30), + (1, 30), + (2, 1), + (null, -10), + (2, -1), + (2, 1), + (2, null), + (null, 100), + (3, null), + (null, null), + (3, null)).toDF("key", "value") + data2.write.saveAsTable("agg2") // Register a UDAF val javaUDAF = new MyDoubleSum - ctx.udaf.register("mydoublesum", javaUDAF) + sqlContext.udaf.register("mydoublesum", javaUDAF) + } + + override def afterAll(): Unit = { + sqlContext.sql("DROP TABLE IF EXISTS agg1") + sqlContext.sql("DROP TABLE IF EXISTS agg2") + sqlContext.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2") + } + + test("only do grouping") { + checkAnswer( + sqlContext.sql( + """ + |SELECT key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) } test("test average2 no key in output") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT avg(value) - |FROM agg2 + |FROM agg1 |GROUP BY key """.stripMargin), Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil) @@ -67,41 +100,41 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { test("test average2") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT key, avg(value) - |FROM agg2 + |FROM agg1 |GROUP BY key """.stripMargin), Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT avg(value), key - |FROM agg2 + |FROM agg1 |GROUP BY key """.stripMargin), Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil) checkAnswer( - 
ctx.sql( + sqlContext.sql( """ |SELECT avg(value) + 1.5, key + 10 - |FROM agg2 + |FROM agg1 |GROUP BY key + 10 """.stripMargin), Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ - |SELECT avg(value) FROM agg2 + |SELECT avg(value) FROM agg1 """.stripMargin), Row(11.125) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT avg(null) """.stripMargin), @@ -110,7 +143,7 @@ test("udaf") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | key, @@ -118,7 +151,7 @@ | avg(value - key), | mydoublesum(cast(value as double) - 1.5 * key), | avg(value) - |FROM agg2 + |FROM agg1 |GROUP BY key """.stripMargin), Row(1, 64.5, 19.0, 55.5, 20.0) :: @@ -128,26 +161,24 @@ } test("non-AlgebraicAggregate aggreguate function") { - - checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT mydoublesum(cast(value as double)), key - |FROM agg2 + |FROM agg1 |GROUP BY key """.stripMargin), Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ - |SELECT mydoublesum(cast(value as double)) FROM agg2 + |SELECT mydoublesum(cast(value as double)) FROM agg1 """.stripMargin), Row(89.0) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT mydoublesum(null) """.stripMargin), @@ -156,10 +187,10 @@ test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT mydoublesum(cast(value as double)), key, avg(value) - |FROM agg2 + |FROM agg1 |GROUP BY key """.stripMargin), Row(60.0, 1, 20.0) :: @@ -168,7 +199,7 @@ Row(30.0, null, 10.0) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | mydoublesum(cast(value as double) + 1.5 * key), @@ -176,7 +207,7 @@ | key, | mydoublesum(cast(value as double) - 1.5 * key), | avg(value) - |FROM agg2 + |FROM agg1 |GROUP BY key """.stripMargin), Row(64.5, 19.0, 1, 55.5, 20.0) :: @@ -187,15 +218,15 @@ test("Cannot use AggregateExpression1 and AggregateExpressions2 together") { Seq(true, false).foreach { useAggregate2 => - ctx.sql(s"set spark.sql.useAggregate2=$useAggregate2") + sqlContext.sql(s"set spark.sql.useAggregate2=$useAggregate2") val errorMessage = intercept[AnalysisException] { - ctx.sql( + sqlContext.sql( """ |SELECT | key, | sum(cast(value as double) + 1.5 * key), | mydoublesum(value) - |FROM agg2 + |FROM agg1 |GROUP BY key """.stripMargin).collect() }.getMessage @@ -205,10 +236,17 @@ assert(errorMessage.contains(expectedErrorMessage)) } - ctx.sql(s"set spark.sql.useAggregate2=true") + sqlContext.sql("set spark.sql.useAggregate2=true") } - override def afterAll(): Unit = { - ctx.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2") + test("single distinct column set") { + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(distinct value) FROM agg2 + """.stripMargin), + Row(10.0) :: Nil) + + // TODO: add a test with both distinct and non-distinct aggregates in the same query. + } }
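Reviewer note on the distinct planning strategy: planAggregateWithOneDistinct implements the standard expanded-grouping-key rewrite for a single DISTINCT column set. The partial and partial-merge Aggregate2Sort operators group on groupingExpressions ++ distinctColumnExpressions, so duplicate values of the distinct columns collapse within each group (when every aggregate function is DISTINCT, functionsWithoutDistinct is empty and these two operators carry no aggregate expressions at all, which is the case the new GroupingIterator handles). FinalAndCompleteAggregate2Sort then regroups on the original grouping attributes, merging the non-distinct buffers in Final mode and evaluating the distinct functions from scratch in Complete mode over the deduplicated rows. A minimal sketch of the same rewrite in plain Scala, with no Spark dependencies and illustrative names only (avgDistinctByKey is not part of this patch):

object DistinctAggRewriteSketch {
  // avg(DISTINCT value) ... GROUP BY key, evaluated the way the planner does:
  // phase 1 collapses duplicate (key, value) pairs (the expanded grouping key);
  // phase 2 regroups by key alone and averages the surviving values.
  def avgDistinctByKey(rows: Seq[(Integer, Integer)]): Map[Integer, Option[Double]] = {
    rows.distinct                      // phase 1: "group by" (key, value)
      .groupBy(_._1)                   // phase 2: regroup by key only
      .map { case (key, pairs) =>
        val values = pairs.flatMap { case (_, v) => Option(v).map(_.doubleValue) }
        key -> (if (values.isEmpty) None else Some(values.sum / values.size))
      }
  }

  def main(args: Array[String]): Unit = {
    val data = Seq[(Integer, Integer)]((1, 10), (1, 30), (1, 30), (2, 1), (2, 1), (2, null))
    // key 1: avg over the distinct values {10, 30} is 20.0, not (10 + 30 + 30) / 3;
    // the result contains 1 -> Some(20.0) and 2 -> Some(1.0).
    println(avgDistinctByKey(data))
  }
}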
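The new CheckAggregateFunction logic rejects queries whose DISTINCT aggregates reference more than one distinct column set. A hedged sketch of a negative test that could be added to Aggregate2Suite (not part of this patch; the asserted fragment follows the failAnalysis message added in rules.scala). avg is used for both calls because it is the only function this patch routes through DistinctAggregateExpression1 to the AggregateExpression2 path; other functions with DISTINCT would fail earlier with a different error.

test("more than one distinct column set is rejected") {
  // {value} and {key} are two different distinct column sets.
  val errorMessage = intercept[AnalysisException] {
    sqlContext.sql(
      """
        |SELECT avg(distinct value), avg(distinct key)
        |FROM agg2
      """.stripMargin).collect()
  }.getMessage
  assert(errorMessage.contains("only a single distinct column set is supported"))
}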