From e545811346189cb9770bb54dc31ba93057cdc68e Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sun, 28 Oct 2018 09:38:38 +0800 Subject: [PATCH] [SPARK-19851][SQL] Add support for EVERY and ANY (SOME) aggregates ## What changes were proposed in this pull request? Implements Every, Some, Any aggregates in SQL. These new aggregate expressions are analyzed in normal way and rewritten to equivalent existing aggregate expressions in the optimizer. Every(x) => Min(x) where x is boolean. Some(x) => Max(x) where x is boolean. Any is a synonym for Some. SQL ``` explain extended select every(v) from test_agg group by k; ``` Plan : ``` == Parsed Logical Plan == 'Aggregate ['k], [unresolvedalias('every('v), None)] +- 'UnresolvedRelation `test_agg` == Analyzed Logical Plan == every(v): boolean Aggregate [k#0], [every(v#1) AS every(v)#5] +- SubqueryAlias `test_agg` +- Project [k#0, v#1] +- SubqueryAlias `test_agg` +- LocalRelation [k#0, v#1] == Optimized Logical Plan == Aggregate [k#0], [min(v#1) AS every(v)#5] +- LocalRelation [k#0, v#1] == Physical Plan == *(2) HashAggregate(keys=[k#0], functions=[min(v#1)], output=[every(v)#5]) +- Exchange hashpartitioning(k#0, 200) +- *(1) HashAggregate(keys=[k#0], functions=[partial_min(v#1)], output=[k#0, min#7]) +- LocalTableScan [k#0, v#1] Time taken: 0.512 seconds, Fetched 1 row(s) ``` ## How was this patch tested? Added tests in SQLQueryTestSuite, DataframeAggregateSuite Closes #22809 from dilipbiswal/SPARK-19851-specific-rewrite. Authored-by: Dilip Biswal Signed-off-by: Wenchen Fan --- .../catalyst/analysis/FunctionRegistry.scala | 3 + .../sql/catalyst/expressions/Expression.scala | 26 +++ .../aggregate/UnevaluableAggs.scala | 62 +++++ .../catalyst/optimizer/finishAnalysis.scala | 18 +- .../ExpressionTypeCheckingSuite.scala | 3 + .../resources/sql-tests/inputs/group-by.sql | 66 ++++++ .../sql-tests/results/group-by.sql.out | 214 +++++++++++++++++- 7 files changed, 388 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38f5c02910f79..af6166bcb8692 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -300,6 +300,9 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[EveryAgg]("every"), + expression[AnyAgg]("any"), + expression[SomeAgg]("some"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c215735ab1c98..ccc5b9043a0aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode @@ -282,6 +283,31 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable { override lazy val canonicalized: Expression = child.canonicalized } +/** + * An aggregate expression that gets rewritten (currently by the optimizer) into a + * different aggregate expression for evaluation. This is mainly used to provide compatibility + * with other databases. For example, we use this to support every, any/some aggregates by rewriting + * them with Min and Max respectively. + */ +trait UnevaluableAggregate extends DeclarativeAggregate { + + override def nullable: Boolean = true + + override lazy val aggBufferAttributes = + throw new UnsupportedOperationException(s"Cannot evaluate aggBufferAttributes: $this") + + override lazy val initialValues: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate initialValues: $this") + + override lazy val updateExpressions: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate updateExpressions: $this") + + override lazy val mergeExpressions: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate mergeExpressions: $this") + + override lazy val evaluateExpression: Expression = + throw new UnsupportedOperationException(s"Cannot evaluate evaluateExpression: $this") +} /** * Expressions that don't have SQL representation should extend this trait. Examples are diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala new file mode 100644 index 0000000000000..fc33ef919498b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +abstract class UnevaluableBooleanAggBase(arg: Expression) + extends UnevaluableAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = arg :: Nil + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = { + arg.dataType match { + case dt if dt != BooleanType => + TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " + + s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].") + case _ => TypeCheckResult.TypeCheckSuccess + } + } +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.", + since = "3.0.0") +case class EveryAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = "Every" +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.", + since = "3.0.0") +case class AnyAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = "Any" +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.", + since = "3.0.0") +case class SomeAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = "Some" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index af0837e36e8ad..fe196ec7c9d54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -28,13 +29,24 @@ import org.apache.spark.sql.types._ /** - * Finds all [[RuntimeReplaceable]] expressions and replace them with the expressions that can - * be evaluated. This is mainly used to provide compatibility with other databases. - * For example, we use this to support "nvl" by replacing it with "coalesce". + * Finds all the expressions that are unevaluable and replace/rewrite them with semantically + * equivalent expressions that can be evaluated. Currently we replace two kinds of expressions: + * 1) [[RuntimeReplaceable]] expressions + * 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any + * This is mainly used to provide compatibility with other databases. + * Few examples are: + * we use this to support "nvl" by replacing it with "coalesce". + * we use this to replace Every and Any with Min and Max respectively. + * + * TODO: In future, explore an option to replace aggregate functions similar to + * how RruntimeReplaceable does. */ object ReplaceExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: RuntimeReplaceable => e.child + case SomeAgg(arg) => Max(arg) + case AnyAgg(arg) => Max(arg) + case EveryAgg(arg) => Min(arg) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8eec14842c7e7..3eb3fe66cebc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -144,6 +144,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(Sum('stringField)) assertSuccess(Average('stringField)) assertSuccess(Min('arrayField)) + assertSuccess(new EveryAgg('booleanField)) + assertSuccess(new AnyAgg('booleanField)) + assertSuccess(new SomeAgg('booleanField)) assertError(Min('mapField), "min does not support ordering on type") assertError(Max('mapField), "max does not support ordering on type") diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 433db71527437..ec263ea70bd4a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -80,3 +80,69 @@ SELECT 1 FROM range(10) HAVING true; SELECT 1 FROM range(10) HAVING MAX(id) > 0; SELECT id FROM range(10) HAVING id > 0; + +-- Test data +CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES + (1, true), (1, false), + (2, true), + (3, false), (3, null), + (4, null), (4, null), + (5, null), (5, true), (5, false) AS test_agg(k, v); + +-- empty table +SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0; + +-- all null values +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 4; + +-- aggregates are null Filtering +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5; + +-- group by +SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k; + +-- having +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false; +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL; + +-- basic subquery path to make sure rewrite happens in both parent and child plans. +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Any(v) + FROM test_agg + WHERE k = 1) +GROUP BY k; + +-- basic subquery path to make sure rewrite happens in both parent and child plans. +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Every(v) + FROM test_agg + WHERE k = 1) +GROUP BY k; + +-- input type checking Int +SELECT every(1); + +-- input type checking Short +SELECT some(1S); + +-- input type checking Long +SELECT any(1L); + +-- input type checking String +SELECT every("true"); + +-- every/some/any aggregates are supported as windows expression. +SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; +SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; +SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; + +-- simple explain of queries having every/some/any agregates. Optimized +-- plan should show the rewritten aggregate expression. +EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k; + diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index f9d1ee8a6bcdb..9a8d025331b67 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 30 +-- Number of queries: 47 -- !query 0 @@ -275,3 +275,215 @@ struct<> -- !query 29 output org.apache.spark.sql.AnalysisException grouping expressions sequence is empty, and '`id`' is not an aggregate function. Wrap '()' in windowing function(s) or wrap '`id`' in first() (or first_value) if you don't care which value you get.; + + +-- !query 30 +CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES + (1, true), (1, false), + (2, true), + (3, false), (3, null), + (4, null), (4, null), + (5, null), (5, true), (5, false) AS test_agg(k, v) +-- !query 30 schema +struct<> +-- !query 30 output + + + +-- !query 31 +SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0 +-- !query 31 schema +struct +-- !query 31 output +NULL NULL NULL + + +-- !query 32 +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 4 +-- !query 32 schema +struct +-- !query 32 output +NULL NULL NULL + + +-- !query 33 +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5 +-- !query 33 schema +struct +-- !query 33 output +false true true + + +-- !query 34 +SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k +-- !query 34 schema +struct +-- !query 34 output +1 false true true +2 true true true +3 false false false +4 NULL NULL NULL +5 false true true + + +-- !query 35 +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false +-- !query 35 schema +struct +-- !query 35 output +1 false +3 false +5 false + + +-- !query 36 +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL +-- !query 36 schema +struct +-- !query 36 output +4 NULL + + +-- !query 37 +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Any(v) + FROM test_agg + WHERE k = 1) +GROUP BY k +-- !query 37 schema +struct +-- !query 37 output +2 true + + +-- !query 38 +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Every(v) + FROM test_agg + WHERE k = 1) +GROUP BY k +-- !query 38 schema +struct +-- !query 38 output + + + +-- !query 39 +SELECT every(1) +-- !query 39 schema +struct<> +-- !query 39 output +org.apache.spark.sql.AnalysisException +cannot resolve 'every(1)' due to data type mismatch: Input to function 'every' should have been boolean, but it's [int].; line 1 pos 7 + + +-- !query 40 +SELECT some(1S) +-- !query 40 schema +struct<> +-- !query 40 output +org.apache.spark.sql.AnalysisException +cannot resolve 'some(1S)' due to data type mismatch: Input to function 'some' should have been boolean, but it's [smallint].; line 1 pos 7 + + +-- !query 41 +SELECT any(1L) +-- !query 41 schema +struct<> +-- !query 41 output +org.apache.spark.sql.AnalysisException +cannot resolve 'any(1L)' due to data type mismatch: Input to function 'any' should have been boolean, but it's [bigint].; line 1 pos 7 + + +-- !query 42 +SELECT every("true") +-- !query 42 schema +struct<> +-- !query 42 output +org.apache.spark.sql.AnalysisException +cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 7 + + +-- !query 43 +SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 43 schema +struct +-- !query 43 output +1 false false +1 true false +2 true true +3 NULL NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true false + + +-- !query 44 +SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 44 schema +struct +-- !query 44 output +1 false false +1 true true +2 true true +3 NULL NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true true + + +-- !query 45 +SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 45 schema +struct +-- !query 45 output +1 false false +1 true true +2 true true +3 NULL NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true true + + +-- !query 46 +EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k +-- !query 46 schema +struct +-- !query 46 output +== Parsed Logical Plan == +'Aggregate ['k], ['k, unresolvedalias('every('v), None), unresolvedalias('some('v), None), unresolvedalias('any('v), None)] ++- 'UnresolvedRelation `test_agg` + +== Analyzed Logical Plan == +k: int, every(v): boolean, some(v): boolean, any(v): boolean +Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, any(v#x) AS any(v)#x] ++- SubqueryAlias `test_agg` + +- Project [k#x, v#x] + +- SubqueryAlias `test_agg` + +- LocalRelation [k#x, v#x] + +== Optimized Logical Plan == +Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, max(v#x) AS any(v)#x] ++- LocalRelation [k#x, v#x] + +== Physical Plan == +*HashAggregate(keys=[k#x], functions=[min(v#x), max(v#x)], output=[k#x, every(v)#x, some(v)#x, any(v)#x]) ++- Exchange hashpartitioning(k#x, 200) + +- *HashAggregate(keys=[k#x], functions=[partial_min(v#x), partial_max(v#x)], output=[k#x, min#x, max#x]) + +- LocalTableScan [k#x, v#x]