From 072209fdc4777a078f5c85c8f2e0296210118ec4 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 13 Jul 2015 21:25:45 -0700 Subject: [PATCH] Bug fix: Handle expressions in grouping columns that are not attribute references. --- .../spark/sql/execution/SparkStrategies.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d0ee8b1118e3b..226eead5c67e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -220,13 +220,16 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // 2. Create an Aggregate Operator for partial aggregations. val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne + case ne: NamedExpression => ne -> ne // If the expression is not a NamedExpressions, we add an alias. // So, when we generate the result of the operator, the Aggregate Operator // can directly get the Seq of attributes representing the grouping expressions. - case other => Alias(other, other.toString)() + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias } - val namedGroupingAttributes = namedGroupingExpressions.map(_.toAttribute) + val groupExpressionMap = namedGroupingExpressions.toMap + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) val partialAggregateExpressions = aggregateExpressions.map { case AggregateExpression2(aggregateFunction, mode, isDistinct) => AggregateExpression2(aggregateFunction, Partial, isDistinct) @@ -237,7 +240,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val partialAggregate = Aggregate2Sort( true, - namedGroupingExpressions, + namedGroupingExpressions.map(_._2), partialAggregateExpressions, partialAggregateAttributes, namedGroupingAttributes ++ partialAggregateAttributes, @@ -256,6 +259,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { expr.transform { case agg: AggregateExpression2 => aggregateFunctionMap(agg.aggregateFunction).toAttribute + case expression if groupExpressionMap.contains(expression) => + groupExpressionMap(expression).toAttribute }.asInstanceOf[NamedExpression] } val finalAggregate = Aggregate2Sort(