[SPARK-31620][SQL] Fix reference binding failure when a final agg contains a subquery

Instead of using `child.output` directly, we should use the `inputAggBufferAttributes` of the current agg expressions for `Final` and `PartialMerge` aggregates to bind references for their `mergeExpressions`.

When planning aggregates, the partial aggregate uses the agg functions' `inputAggBufferAttributes` as its output; see https://github.com/apache/spark/blob/v3.0.0-rc1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala#L105

For the final `HashAggregateExec`, we need to bind `DeclarativeAggregate.mergeExpressions` with the output of the partial aggregate operator; see https://github.com/apache/spark/blob/v3.0.0-rc1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L348
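
Concretely, "binding" rewrites each attribute in `mergeExpressions` into a positional `BoundReference` against the input attributes, matched by `exprId` rather than by name. A minimal sketch of that behavior (illustration only, not the planner's actual code path):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BindReferences}
import org.apache.spark.sql.types.LongType

// Attribute exposed by the partial aggregate as part of its output.
val bufAttr = AttributeReference("sum", LongType)()
// Same name but a fresh exprId, as a copied agg function would produce.
val copied = AttributeReference("sum", LongType)()

BindReferences.bindReference(bufAttr, Seq(bufAttr)) // ok: becomes BoundReference(0, ...)
BindReferences.bindReference(copied, Seq(bufAttr))  // fails: sum#<new id> is not in the input
```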

This is usually fine. However, if we copy the agg func somehow after agg planning, e.g. via `PlanSubqueries`, the `DeclarativeAggregate` is replaced by a new instance with new `inputAggBufferAttributes` and `mergeExpressions`. Then we can't bind the `mergeExpressions` to the output of the partial aggregate operator, as that output still carries the `inputAggBufferAttributes` of the original `DeclarativeAggregate` from before the copy.
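
A subquery is exactly such a copy trigger: `PlanSubqueries` rewrites the expression tree after aggregate planning, producing fresh agg function instances. A spark-shell sketch of the failing shape, adapted from the regression test added below (the `spark` session is provided by the shell):

```scala
// Before this fix, planning the final aggregate for this query failed to bind
// its merge expressions; after it, the query returns csum = 4.
spark.sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
spark.sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
spark.sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2").show()
```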

Note that `ImperativeAggregate` doesn't have this problem, as there are no `mergeExpressions` to bind: it uses a different mechanism to access buffer values, via `mutableAggBufferOffset` and `inputAggBufferOffset`.
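
A toy model of that offset-based access (hypothetical class, not real Spark code): `merge` reads the incoming partial buffer positionally, so no attribute binding and no `exprId` matching is involved.

```scala
// Hypothetical stand-in for an ImperativeAggregate such as count: buffer slots
// are addressed by offset into the shared buffer row, never by attribute.
class CountLike(val mutableAggBufferOffset: Int, val inputAggBufferOffset: Int) {
  def merge(buffer: Array[Long], inputBuffer: Array[Long]): Unit = {
    buffer(mutableAggBufferOffset) += inputBuffer(inputAggBufferOffset)
  }
}
```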

Yes. Previously, users hit an error when running such a query; after this change, the query runs successfully.

Added a regression test.

Closes apache#28496 from Ngone51/spark-31620.

Authored-by: yi.wu <yi.wu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Ngone51 authored and Nick Nicolini committed Jun 11, 2020
1 parent 3c41f4b commit 359e1a0
Showing 5 changed files with 231 additions and 6 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression}
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode}

/**
@@ -45,4 +45,28 @@ trait BaseAggregateExec extends UnaryExecNode {
|Results: $resultString
""".stripMargin
}

protected def inputAttributes: Seq[Attribute] = {
val modes = aggregateExpressions.map(_.mode).distinct
if (modes.contains(Final) || modes.contains(PartialMerge)) {
// SPARK-31620: when planning aggregates, the partial aggregate uses aggregate function's
// `inputAggBufferAttributes` as its output. And Final and PartialMerge aggregate rely on the
// output to bind references for `DeclarativeAggregate.mergeExpressions`. But if we copy the
// aggregate function somehow after aggregate planning, like `PlanSubqueries`, the
// `DeclarativeAggregate` will be replaced by a new instance with new
// `inputAggBufferAttributes` and `mergeExpressions`. Then Final and PartialMerge aggregate
// can't bind the `mergeExpressions` with the output of the partial aggregate, as they use
// the `inputAggBufferAttributes` of the original `DeclarativeAggregate` before copy. Instead,
// we shall use `inputAggBufferAttributes` after copy to match the new `mergeExpressions`.
val aggAttrs = aggregateExpressions
// there are exactly four cases that need `inputAggBufferAttributes` from the child according to the
// agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final,
// Partial -> PartialMerge, PartialMerge -> PartialMerge.
.filter(a => a.mode == Final || a.mode == PartialMerge).map(_.aggregateFunction)
.flatMap(_.inputAggBufferAttributes)
child.output.dropRight(aggAttrs.length) ++ aggAttrs
} else {
child.output
}
}
}
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -121,7 +121,7 @@ case class HashAggregateExec(
resultExpressions,
(expressions, inputSchema) =>
newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
-        child.output,
+        inputAttributes,
iter,
testFallbackStartsAt,
numOutputRows,
@@ -254,7 +254,7 @@
private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
val updateExpr = aggregateExpressions.flatMap { e =>
e.mode match {
case Partial | Complete =>
@@ -817,7 +817,7 @@ case class HashAggregateExec(
}
}

-    val inputAttr = aggregateBufferAttributes ++ child.output
+    val inputAttr = aggregateBufferAttributes ++ inputAttributes
// Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
// generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while
// generating input columns, we use `currentVars`.
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -121,7 +121,7 @@ case class ObjectHashAggregateExec(
resultExpressions,
(expressions, inputSchema) =>
newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
-        child.output,
+        inputAttributes,
iter,
fallbackCountThreshold,
numOutputRows)
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -86,7 +86,7 @@ case class SortAggregateExec(
val outputIter = new SortBasedAggregationIterator(
partIndex,
groupingExpressions,
-      child.output,
+      inputAttributes,
iter,
aggregateExpressions,
aggregateAttributes,
sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -772,4 +772,205 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Row(Seq(0.0f, 0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, Double.NaN)), 2)
)
}

test("SPARK-27581: DataFrame countDistinct(\"*\") shouldn't fail with AnalysisException") {
val df = sql("select id % 100 from range(100000)")
val distinctCount1 = df.select(expr("count(distinct(*))"))
val distinctCount2 = df.select(countDistinct("*"))
checkAnswer(distinctCount1, distinctCount2)

val countAndDistinct = df.select(count("*"), countDistinct("*"))
checkAnswer(countAndDistinct, Row(100000, 100))
}

test("max_by") {
val yearOfMaxEarnings =
sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course")
checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil)

checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
Row("b") :: Nil
)

checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
Row("c") :: Nil
)

checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
Row("c") :: Nil
)

checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
Row("b") :: Nil
)

checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
Row(null) :: Nil
)

// structs as ordering value.
checkAnswer(
sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
"(('c', (10, 60))) AS tab(x, y)"),
Row("c") :: Nil
)

checkAnswer(
sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
"(('c', null)) AS tab(x, y)"),
Row("b") :: Nil
)

withTempView("tempView") {
val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c"))
.toDF("x", "y")
.select($"x", map($"x", $"y").as("y"))
.createOrReplaceTempView("tempView")
val error = intercept[AnalysisException] {
sql("SELECT max_by(x, y) FROM tempView").show
}
assert(
error.message.contains("function max_by does not support ordering on type map<int,string>"))
}
}

test("min_by") {
val yearOfMinEarnings =
sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course")
checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil)

checkAnswer(
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
Row("a") :: Nil
)

checkAnswer(
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
Row("a") :: Nil
)

checkAnswer(
sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
Row("c") :: Nil
)

checkAnswer(
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
Row("a") :: Nil
)

checkAnswer(
sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
Row(null) :: Nil
)

// structs as ordering value.
checkAnswer(
sql("select min_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
"(('c', (10, 60))) AS tab(x, y)"),
Row("a") :: Nil
)

checkAnswer(
sql("select min_by(x, y) FROM VALUES (('a', null)), (('b', (10, 50))), " +
"(('c', (10, 60))) AS tab(x, y)"),
Row("b") :: Nil
)

withTempView("tempView") {
val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c"))
.toDF("x", "y")
.select($"x", map($"x", $"y").as("y"))
.createOrReplaceTempView("tempView")
val error = intercept[AnalysisException] {
sql("SELECT min_by(x, y) FROM tempView").show
}
assert(
error.message.contains("function min_by does not support ordering on type map<int,string>"))
}
}

test("count_if") {
withTempView("tempView") {
Seq(("a", None), ("a", Some(1)), ("a", Some(2)), ("a", Some(3)),
("b", None), ("b", Some(4)), ("b", Some(5)), ("b", Some(6)))
.toDF("x", "y")
.createOrReplaceTempView("tempView")

checkAnswer(
sql("SELECT COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " +
"COUNT_IF(y IS NULL) FROM tempView"),
Row(0L, 3L, 3L, 2L))

checkAnswer(
sql("SELECT x, COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " +
"COUNT_IF(y IS NULL) FROM tempView GROUP BY x"),
Row("a", 0L, 1L, 2L, 1L) :: Row("b", 0L, 2L, 1L, 1L) :: Nil)

checkAnswer(
sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y % 2 = 0) = 1"),
Row("a"))

checkAnswer(
sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y % 2 = 0) = 2"),
Row("b"))

checkAnswer(
sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y IS NULL) > 0"),
Row("a") :: Row("b") :: Nil)

checkAnswer(
sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(NULL) > 0"),
Nil)

val error = intercept[AnalysisException] {
sql("SELECT COUNT_IF(x) FROM tempView")
}
assert(error.message.contains("function count_if requires boolean type"))
}
}

Seq(true, false).foreach { value =>
test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
withTempView("t1", "t2") {
sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")

// test without grouping keys
checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"),
Row(4) :: Nil)

// test with grouping keys
checkAnswer(sql("select c, sum(if(c > (select a from t1), d, 0)) as csum from " +
"t2 group by c"), Row(3, 4) :: Nil)

// test with distinct
checkAnswer(sql("select avg(distinct(d)), sum(distinct(if(c > (select a from t1)," +
" d, 0))) as csum from t2 group by c"), Row(4, 4) :: Nil)

// test subquery with agg
checkAnswer(sql("select sum(distinct(if(c > (select sum(distinct(a)) from t1)," +
" d, 0))) as csum from t2 group by c"), Row(4) :: Nil)

// test SortAggregateExec
var df = sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2")
assert(df.queryExecution.executedPlan
.find { case _: SortAggregateExec => true }.isDefined)
checkAnswer(df, Row("str1") :: Nil)

// test ObjectHashAggregateExec
df = sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum from t2")
assert(df.queryExecution.executedPlan
.find { case _: ObjectHashAggregateExec => true }.isDefined)
checkAnswer(df, Row(Array(4), 4) :: Nil)
}
}
}
}
}
