diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index f27241512245c..1cb27710e0480 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -25,6 +25,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
 
+/**
+ * Utility functions used by the query planner to convert our plan to the new aggregation code path.
+ */
 object Utils {
   // Right now, we do not support complex types in the grouping key schema.
   private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
@@ -214,11 +217,15 @@ object Utils {
       expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
     }
     val rewrittenResultExpressions = resultExpressions.map { expr =>
-      expr.transform {
+      expr.transformDown {
         case agg: AggregateExpression2 =>
           aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
-        case expression if groupExpressionMap.contains(expression) =>
-          groupExpressionMap(expression).toAttribute
+        case expression =>
+          // We do not rely on the equality check here since attributes may
+          // differ cosmetically. Instead, we use semanticEquals.
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
       }.asInstanceOf[NamedExpression]
     }
     val finalAggregate = Aggregate2Sort(
@@ -334,8 +341,12 @@ object Utils {
       expr.transform {
         case agg: AggregateExpression2 =>
           aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
-        case expression if groupExpressionMap.contains(expression) =>
-          groupExpressionMap(expression).toAttribute
+        case expression =>
+          // We do not rely on the equality check here since attributes may
+          // differ cosmetically. Instead, we use semanticEquals.
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
       }.asInstanceOf[NamedExpression]
     }
     val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index bfd8805474caf..0375eb79add95 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -187,6 +187,40 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
       Row(null, null) :: Nil)
   }
 
+  test("case insensitive resolution") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(value), kEY - 100
+          |FROM agg1
+          |GROUP BY Key - 100
+        """.stripMargin),
+      Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT sum(distinct value1), kEY - 100, count(distinct value1)
+          |FROM agg2
+          |GROUP BY Key - 100
+        """.stripMargin),
+      Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT valUe * key - 100
+          |FROM agg1
+          |GROUP BY vAlue * keY - 100
+        """.stripMargin),
+      Row(-90) ::
+        Row(-80) ::
+        Row(-70) ::
+        Row(-100) ::
+        Row(-102) ::
+        Row(null) :: Nil)
+  }
+
   test("test average no key in output") {
     checkAnswer(
       sqlContext.sql(