diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 268d7f21e6f85..9ab0a20a5276e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -91,7 +91,6 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: - DistinctAggregationRewriter(conf) :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 7518946a943f4..38c1641f73d9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} @@ -100,13 +99,10 @@ import org.apache.spark.sql.types.IntegerType * we could improve this in the current rule by applying more advanced expression cannocalization * techniques. */ -case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalPlan] { +object DistinctAggregationRewriter extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case p if !p.resolved => p - // We need to wait until this Aggregate operator is resolved. + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a: Aggregate => rewrite(a) - case p => p } def rewrite(a: Aggregate): Aggregate = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index deea7238f564c..7455e68ee8f64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubqueryAliases} +import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} @@ -42,7 +42,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // we do not eliminate subqueries or compute current time in the analyzer. Batch("Finish Analysis", Once, EliminateSubqueryAliases, - ComputeCurrentTime) :: + ComputeCurrentTime, + DistinctAggregationRewriter) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index ed85856f017df..ccce9871e29fc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -139,7 +139,6 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { """.stripMargin) } - test("intersect") { checkHiveQl("SELECT * FROM t0 INTERSECT SELECT * FROM t0") } @@ -367,9 +366,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT * FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) WHERE 1=0") } - // TODO Enable this - // Query plans transformed by DistinctAggregationRewriter are not recognized yet - ignore("multi-distinct columns") { + test("multi-distinct columns") { checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM parquet_t2 GROUP BY a") }