diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38f5c02910f79..af6166bcb8692 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -300,6 +300,9 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[EveryAgg]("every"), + expression[AnyAgg]("any"), + expression[SomeAgg]("some"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c215735ab1c98..ccc5b9043a0aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode @@ -282,6 +283,31 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable { override lazy val canonicalized: Expression = child.canonicalized } +/** + * An aggregate expression that gets rewritten (currently by the optimizer) into a + * different aggregate expression for evaluation. This is mainly used to provide compatibility + * with other databases. 
For example, we use this to support every, any/some aggregates by rewriting + * them with Min and Max respectively. + */ +trait UnevaluableAggregate extends DeclarativeAggregate { + + override def nullable: Boolean = true + + override lazy val aggBufferAttributes = + throw new UnsupportedOperationException(s"Cannot evaluate aggBufferAttributes: $this") + + override lazy val initialValues: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate initialValues: $this") + + override lazy val updateExpressions: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate updateExpressions: $this") + + override lazy val mergeExpressions: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate mergeExpressions: $this") + + override lazy val evaluateExpression: Expression = + throw new UnsupportedOperationException(s"Cannot evaluate evaluateExpression: $this") +} /** * Expressions that don't have SQL representation should extend this trait. Examples are diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala new file mode 100644 index 0000000000000..fc33ef919498b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +abstract class UnevaluableBooleanAggBase(arg: Expression) + extends UnevaluableAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = arg :: Nil + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = { + arg.dataType match { + case dt if dt != BooleanType => + TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " + + s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].") + case _ => TypeCheckResult.TypeCheckSuccess + } + } +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.", + since = "3.0.0") +case class EveryAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = "Every" +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.", + since = "3.0.0") +case class AnyAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = "Any" +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.", + since = "3.0.0") +case class SomeAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def 
nodeName: String = "Some" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index af0837e36e8ad..fe196ec7c9d54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -28,13 +29,24 @@ import org.apache.spark.sql.types._ /** - * Finds all [[RuntimeReplaceable]] expressions and replace them with the expressions that can - * be evaluated. This is mainly used to provide compatibility with other databases. - * For example, we use this to support "nvl" by replacing it with "coalesce". + * Finds all the expressions that are unevaluable and replace/rewrite them with semantically + * equivalent expressions that can be evaluated. Currently we replace two kinds of expressions: + * 1) [[RuntimeReplaceable]] expressions + * 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any + * This is mainly used to provide compatibility with other databases. + * A few examples are: + * we use this to support "nvl" by replacing it with "coalesce". + * we use this to replace Every and Any with Min and Max respectively. + * + * TODO: In the future, explore an option to replace aggregate functions similar to + * how RuntimeReplaceable does. 
*/ object ReplaceExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: RuntimeReplaceable => e.child + case SomeAgg(arg) => Max(arg) + case AnyAgg(arg) => Max(arg) + case EveryAgg(arg) => Min(arg) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8eec14842c7e7..3eb3fe66cebc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -144,6 +144,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(Sum('stringField)) assertSuccess(Average('stringField)) assertSuccess(Min('arrayField)) + assertSuccess(new EveryAgg('booleanField)) + assertSuccess(new AnyAgg('booleanField)) + assertSuccess(new SomeAgg('booleanField)) assertError(Min('mapField), "min does not support ordering on type") assertError(Max('mapField), "max does not support ordering on type") diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 433db71527437..ec263ea70bd4a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -80,3 +80,69 @@ SELECT 1 FROM range(10) HAVING true; SELECT 1 FROM range(10) HAVING MAX(id) > 0; SELECT id FROM range(10) HAVING id > 0; + +-- Test data +CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES + (1, true), (1, false), + (2, true), + (3, false), (3, null), + (4, null), (4, null), + (5, null), (5, true), (5, false) AS test_agg(k, v); + +-- empty table +SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0; + +-- all null values +SELECT every(v), some(v), any(v) FROM 
test_agg WHERE k = 4; + +-- aggregates are null filtering +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5; + +-- group by +SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k; + +-- having +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false; +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL; + +-- basic subquery path to make sure rewrite happens in both parent and child plans. +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Any(v) + FROM test_agg + WHERE k = 1) +GROUP BY k; + +-- basic subquery path to make sure rewrite happens in both parent and child plans. +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Every(v) + FROM test_agg + WHERE k = 1) +GROUP BY k; + +-- input type checking Int +SELECT every(1); + +-- input type checking Short +SELECT some(1S); + +-- input type checking Long +SELECT any(1L); + +-- input type checking String +SELECT every("true"); + +-- every/some/any aggregates are supported as window expressions. +SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; +SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; +SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; + +-- simple explain of queries having every/some/any aggregates. Optimized +-- plan should show the rewritten aggregate expression. 
+EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k; + diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index f9d1ee8a6bcdb..9a8d025331b67 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 30 +-- Number of queries: 47 -- !query 0 @@ -275,3 +275,215 @@ struct<> -- !query 29 output org.apache.spark.sql.AnalysisException grouping expressions sequence is empty, and '`id`' is not an aggregate function. Wrap '()' in windowing function(s) or wrap '`id`' in first() (or first_value) if you don't care which value you get.; + + +-- !query 30 +CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES + (1, true), (1, false), + (2, true), + (3, false), (3, null), + (4, null), (4, null), + (5, null), (5, true), (5, false) AS test_agg(k, v) +-- !query 30 schema +struct<> +-- !query 30 output + + + +-- !query 31 +SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0 +-- !query 31 schema +struct +-- !query 31 output +NULL NULL NULL + + +-- !query 32 +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 4 +-- !query 32 schema +struct +-- !query 32 output +NULL NULL NULL + + +-- !query 33 +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5 +-- !query 33 schema +struct +-- !query 33 output +false true true + + +-- !query 34 +SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k +-- !query 34 schema +struct +-- !query 34 output +1 false true true +2 true true true +3 false false false +4 NULL NULL NULL +5 false true true + + +-- !query 35 +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false +-- !query 35 schema +struct +-- !query 35 output +1 false +3 false +5 false + + +-- !query 36 +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) 
IS NULL +-- !query 36 schema +struct +-- !query 36 output +4 NULL + + +-- !query 37 +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Any(v) + FROM test_agg + WHERE k = 1) +GROUP BY k +-- !query 37 schema +struct +-- !query 37 output +2 true + + +-- !query 38 +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Every(v) + FROM test_agg + WHERE k = 1) +GROUP BY k +-- !query 38 schema +struct +-- !query 38 output + + + +-- !query 39 +SELECT every(1) +-- !query 39 schema +struct<> +-- !query 39 output +org.apache.spark.sql.AnalysisException +cannot resolve 'every(1)' due to data type mismatch: Input to function 'every' should have been boolean, but it's [int].; line 1 pos 7 + + +-- !query 40 +SELECT some(1S) +-- !query 40 schema +struct<> +-- !query 40 output +org.apache.spark.sql.AnalysisException +cannot resolve 'some(1S)' due to data type mismatch: Input to function 'some' should have been boolean, but it's [smallint].; line 1 pos 7 + + +-- !query 41 +SELECT any(1L) +-- !query 41 schema +struct<> +-- !query 41 output +org.apache.spark.sql.AnalysisException +cannot resolve 'any(1L)' due to data type mismatch: Input to function 'any' should have been boolean, but it's [bigint].; line 1 pos 7 + + +-- !query 42 +SELECT every("true") +-- !query 42 schema +struct<> +-- !query 42 output +org.apache.spark.sql.AnalysisException +cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 7 + + +-- !query 43 +SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 43 schema +struct +-- !query 43 output +1 false false +1 true false +2 true true +3 NULL NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true false + + +-- !query 44 +SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 44 schema +struct +-- !query 44 output +1 false false +1 true true +2 true true +3 NULL 
NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true true + + +-- !query 45 +SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 45 schema +struct +-- !query 45 output +1 false false +1 true true +2 true true +3 NULL NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true true + + +-- !query 46 +EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k +-- !query 46 schema +struct +-- !query 46 output +== Parsed Logical Plan == +'Aggregate ['k], ['k, unresolvedalias('every('v), None), unresolvedalias('some('v), None), unresolvedalias('any('v), None)] ++- 'UnresolvedRelation `test_agg` + +== Analyzed Logical Plan == +k: int, every(v): boolean, some(v): boolean, any(v): boolean +Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, any(v#x) AS any(v)#x] ++- SubqueryAlias `test_agg` + +- Project [k#x, v#x] + +- SubqueryAlias `test_agg` + +- LocalRelation [k#x, v#x] + +== Optimized Logical Plan == +Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, max(v#x) AS any(v)#x] ++- LocalRelation [k#x, v#x] + +== Physical Plan == +*HashAggregate(keys=[k#x], functions=[min(v#x), max(v#x)], output=[k#x, every(v)#x, some(v)#x, any(v)#x]) ++- Exchange hashpartitioning(k#x, 200) + +- *HashAggregate(keys=[k#x], functions=[partial_min(v#x), partial_max(v#x)], output=[k#x, min#x, max#x]) + +- LocalTableScan [k#x, v#x]