Skip to content

Commit

Permalink
Use semanticEquals to replace grouping expressions in the output of t…
Browse files Browse the repository at this point in the history
…he aggregate operator.
  • Loading branch information
yhuai committed Jul 22, 2015
1 parent 3b43b24 commit 35b0520
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructType, MapType, ArrayType}

/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object Utils {
// Right now, we do not support complex types in the grouping key schema.
private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
Expand Down Expand Up @@ -214,11 +217,15 @@ object Utils {
expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
}
val rewrittenResultExpressions = resultExpressions.map { expr =>
expr.transform {
expr.transformDown {
case agg: AggregateExpression2 =>
aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
case expression if groupExpressionMap.contains(expression) =>
groupExpressionMap(expression).toAttribute
case expression =>
// We do not rely on the equality check at here since attributes may
// different cosmetically. Instead, we use semanticEquals.
groupExpressionMap.collectFirst {
case (expr, ne) if expr semanticEquals expression => ne.toAttribute
}.getOrElse(expression)
}.asInstanceOf[NamedExpression]
}
val finalAggregate = Aggregate2Sort(
Expand Down Expand Up @@ -334,8 +341,12 @@ object Utils {
expr.transform {
case agg: AggregateExpression2 =>
aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
case expression if groupExpressionMap.contains(expression) =>
groupExpressionMap(expression).toAttribute
case expression =>
// We do not rely on the equality check at here since attributes may
// different cosmetically. Instead, we use semanticEquals.
groupExpressionMap.collectFirst {
case (expr, ne) if expr semanticEquals expression => ne.toAttribute
}.getOrElse(expression)
}.asInstanceOf[NamedExpression]
}
val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,40 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
Row(null, null) :: Nil)
}

test("case in-sensitive resolution") {
checkAnswer(
sqlContext.sql(
"""
|SELECT avg(value), kEY - 100
|FROM agg1
|GROUP BY Key - 100
""".stripMargin),
Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil)

checkAnswer(
sqlContext.sql(
"""
|SELECT sum(distinct value1), kEY - 100, count(distinct value1)
|FROM agg2
|GROUP BY Key - 100
""".stripMargin),
Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil)

checkAnswer(
sqlContext.sql(
"""
|SELECT valUe * key - 100
|FROM agg1
|GROUP BY vAlue * keY - 100
""".stripMargin),
Row(-90) ::
Row(-80) ::
Row(-70) ::
Row(-100) ::
Row(-102) ::
Row(null) :: Nil)
}

test("test average no key in output") {
checkAnswer(
sqlContext.sql(
Expand Down

0 comments on commit 35b0520

Please sign in to comment.