Skip to content

Commit

Permalink
[SPARK-19851][SQL] Add support for EVERY and ANY (SOME) aggregates
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Implements the Every, Some and Any aggregates in SQL. These new aggregate expressions are analyzed in the normal way and are then rewritten by the optimizer into equivalent existing aggregate expressions.

Every(x) => Min(x)  where x is boolean.
Some(x) => Max(x) where x is boolean.

Any is a synonym for Some.
SQL
```
explain extended select every(v) from test_agg group by k;
```
Plan :
```
== Parsed Logical Plan ==
'Aggregate ['k], [unresolvedalias('every('v), None)]
+- 'UnresolvedRelation `test_agg`

== Analyzed Logical Plan ==
every(v): boolean
Aggregate [k#0], [every(v#1) AS every(v)#5]
+- SubqueryAlias `test_agg`
   +- Project [k#0, v#1]
      +- SubqueryAlias `test_agg`
         +- LocalRelation [k#0, v#1]

== Optimized Logical Plan ==
Aggregate [k#0], [min(v#1) AS every(v)#5]
+- LocalRelation [k#0, v#1]

== Physical Plan ==
*(2) HashAggregate(keys=[k#0], functions=[min(v#1)], output=[every(v)#5])
+- Exchange hashpartitioning(k#0, 200)
   +- *(1) HashAggregate(keys=[k#0], functions=[partial_min(v#1)], output=[k#0, min#7])
      +- LocalTableScan [k#0, v#1]
Time taken: 0.512 seconds, Fetched 1 row(s)
```

## How was this patch tested?
Added tests in SQLQueryTestSuite and DataFrameAggregateSuite.

Closes apache#22809 from dilipbiswal/SPARK-19851-specific-rewrite.

Authored-by: Dilip Biswal <dbiswal@us.ibm.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
dilipbiswal authored and cloud-fan committed Oct 28, 2018
1 parent 41e1416 commit e545811
Show file tree
Hide file tree
Showing 7 changed files with 388 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ object FunctionRegistry {
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
expression[CountMinSketchAgg]("count_min_sketch"),
expression[EveryAgg]("every"),
expression[AnyAgg]("any"),
expression[SomeAgg]("some"),

// string functions
expression[Ascii]("ascii"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.util.Locale

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNode
Expand Down Expand Up @@ -282,6 +283,31 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable {
override lazy val canonicalized: Expression = child.canonicalized
}

/**
 * An aggregate expression that is never evaluated itself: it gets rewritten (currently by the
 * optimizer) into a different aggregate expression for evaluation. This is mainly used to provide
 * compatibility with other databases. For example, we use this to support the every and any/some
 * aggregates by rewriting them with Min and Max respectively.
 *
 * Every buffer/evaluation member below throws [[UnsupportedOperationException]] when it is
 * accessed. They are `lazy val`s so that the exception is raised on first access rather than
 * while the expression is being constructed.
 */
trait UnevaluableAggregate extends DeclarativeAggregate {

  /** The result is always declared nullable. */
  override def nullable: Boolean = true

  // The type is spelled out, like for the other members of this trait, so that the declared
  // type does not depend on what the compiler infers from the `throw` expression.
  override lazy val aggBufferAttributes: Seq[AttributeReference] =
    throw new UnsupportedOperationException(s"Cannot evaluate aggBufferAttributes: $this")

  override lazy val initialValues: Seq[Expression] =
    throw new UnsupportedOperationException(s"Cannot evaluate initialValues: $this")

  override lazy val updateExpressions: Seq[Expression] =
    throw new UnsupportedOperationException(s"Cannot evaluate updateExpressions: $this")

  override lazy val mergeExpressions: Seq[Expression] =
    throw new UnsupportedOperationException(s"Cannot evaluate mergeExpressions: $this")

  override lazy val evaluateExpression: Expression =
    throw new UnsupportedOperationException(s"Cannot evaluate evaluateExpression: $this")
}

/**
* Expressions that don't have SQL representation should extend this trait. Examples are
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

/**
 * Common base of the boolean aggregates that are not evaluated directly (every, any, some).
 * It takes a single argument, produces a boolean result, and only accepts a boolean input.
 *
 * @param arg the expression being aggregated; it is the only child of this node
 */
abstract class UnevaluableBooleanAggBase(arg: Expression)
  extends UnevaluableAggregate with ImplicitCastInputTypes {

  override def children: Seq[Expression] = Seq(arg)

  override def dataType: DataType = BooleanType

  override def inputTypes: Seq[AbstractDataType] = BooleanType :: Nil

  /**
   * Succeeds only when the argument's data type is exactly [[BooleanType]]; any other type
   * yields a failure whose message names this function and the offending type.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    val actualType = arg.dataType
    if (actualType == BooleanType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " +
        s"${BooleanType.simpleString}, but it's [${actualType.catalogString}].")
    }
  }
}

/**
 * The `every` aggregate over a boolean expression. It is not evaluated directly: the
 * `ReplaceExpressions` optimizer rule rewrites `EveryAgg(arg)` into `Min(arg)`.
 *
 * @param arg the boolean expression being aggregated
 */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.",
since = "3.0.0")
case class EveryAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
// Node name for this expression: "Every" (not the class name EveryAgg).
override def nodeName: String = "Every"
}

/**
 * The `any` aggregate over a boolean expression. It is not evaluated directly: the
 * `ReplaceExpressions` optimizer rule rewrites `AnyAgg(arg)` into `Max(arg)`, exactly as it
 * does for `SomeAgg`.
 *
 * @param arg the boolean expression being aggregated
 */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.",
since = "3.0.0")
case class AnyAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
// Node name for this expression: "Any" (not the class name AnyAgg).
override def nodeName: String = "Any"
}

/**
 * The `some` aggregate over a boolean expression. It is not evaluated directly: the
 * `ReplaceExpressions` optimizer rule rewrites `SomeAgg(arg)` into `Max(arg)`, exactly as it
 * does for `AnyAgg`.
 *
 * @param arg the boolean expression being aggregated
 */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.",
since = "3.0.0")
case class SomeAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
// Node name for this expression: "Some" (not the class name SomeAgg).
override def nodeName: String = "Some"
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,32 @@ import scala.collection.mutable

import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._


/**
 * Finds all the expressions that are unevaluable and rewrites them into semantically
 * equivalent expressions that can be evaluated. Two kinds of expressions are handled:
 *  1) [[RuntimeReplaceable]] expressions, which are replaced by their child;
 *  2) [[UnevaluableAggregate]] expressions such as Every, Some and Any.
 * This is mainly used to provide compatibility with other databases. A few examples:
 *  - "nvl" is supported by replacing it with "coalesce";
 *  - Every is replaced with Min, and Some/Any are replaced with Max.
 *
 * TODO: In future, explore an option to replace aggregate functions similar to
 * how RuntimeReplaceable does.
 */
object ReplaceExpressions extends Rule[LogicalPlan] {

  // The individual rewrites. The cases match disjoint expression classes, so their order
  // does not affect the result.
  private val rewrite: PartialFunction[Expression, Expression] = {
    case EveryAgg(input) => Min(input)
    case AnyAgg(input) => Max(input)
    case SomeAgg(input) => Max(input)
    case replaceable: RuntimeReplaceable => replaceable.child
  }

  /** Applies the rewrites to every expression of `plan`. */
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions(rewrite)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(Sum('stringField))
assertSuccess(Average('stringField))
assertSuccess(Min('arrayField))
assertSuccess(new EveryAgg('booleanField))
assertSuccess(new AnyAgg('booleanField))
assertSuccess(new SomeAgg('booleanField))

assertError(Min('mapField), "min does not support ordering on type")
assertError(Max('mapField), "max does not support ordering on type")
Expand Down
66 changes: 66 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/group-by.sql
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,69 @@ SELECT 1 FROM range(10) HAVING true;
SELECT 1 FROM range(10) HAVING MAX(id) > 0;

SELECT id FROM range(10) HAVING id > 0;

-- Test data
CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES
(1, true), (1, false),
(2, true),
(3, false), (3, null),
(4, null), (4, null),
(5, null), (5, true), (5, false) AS test_agg(k, v);

-- empty table
SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0;

-- all null values
SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 4;

-- null values are filtered out by the aggregates
SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5;

-- group by
SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k;

-- having
SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false;
SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL;

-- basic subquery path to make sure rewrite happens in both parent and child plans.
SELECT k,
Every(v) AS every
FROM test_agg
WHERE k = 2
AND v IN (SELECT Any(v)
FROM test_agg
WHERE k = 1)
GROUP BY k;

-- basic subquery path to make sure rewrite happens in both parent and child plans.
SELECT k,
Every(v) AS every
FROM test_agg
WHERE k = 2
AND v IN (SELECT Every(v)
FROM test_agg
WHERE k = 1)
GROUP BY k;

-- input type checking Int
SELECT every(1);

-- input type checking Short
SELECT some(1S);

-- input type checking Long
SELECT any(1L);

-- input type checking String
SELECT every("true");

-- every/some/any aggregates are supported as window expressions.
SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;

-- simple explain of queries having every/some/any aggregates. The optimized
-- plan should show the rewritten aggregate expressions.
EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k;

Loading

0 comments on commit e545811

Please sign in to comment.