From 359e1a018c4381ac0b4166a3084e1558d2836ed7 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Fri, 15 May 2020 15:36:28 +0000
Subject: [PATCH] [SPARK-31620][SQL] Fix reference binding failure in case of a final agg that contains a subquery

Instead of using `child.output` directly, we should use `inputAggBufferAttributes` from the current agg expression for `Final` and `PartialMerge` aggregates to bind references for their `mergeExpressions`.

When planning aggregates, the partial aggregate uses the agg funcs' `inputAggBufferAttributes` as its output, see https://github.com/apache/spark/blob/v3.0.0-rc1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala#L105

For the final `HashAggregateExec`, we need to bind the `DeclarativeAggregate.mergeExpressions` with the output of the partial aggregate operator, see https://github.com/apache/spark/blob/v3.0.0-rc1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L348

This is usually fine. However, if we copy the agg func somehow after agg planning, as `PlanSubqueries` does, the `DeclarativeAggregate` will be replaced by a new instance with new `inputAggBufferAttributes` and `mergeExpressions`. Then we can't bind the `mergeExpressions` with the output of the partial aggregate operator, as that output uses the `inputAggBufferAttributes` of the original `DeclarativeAggregate` from before the copy. (A standalone sketch of this failure mode is appended after the patch.)

Note that `ImperativeAggregate` doesn't have this problem, as we don't need to bind its `mergeExpressions`. It has a different mechanism to access buffer values, via `mutableAggBufferOffset` and `inputAggBufferOffset`.

Previously, users hit a binding error for such queries; after this change the queries run successfully.

Added a regression test.

Closes #28496 from Ngone51/spark-31620.

Authored-by: yi.wu
Signed-off-by: Wenchen Fan
---
 .../aggregate/BaseAggregateExec.scala         |  26 ++-
 .../aggregate/HashAggregateExec.scala         |   6 +-
 .../aggregate/ObjectHashAggregateExec.scala   |   2 +-
 .../aggregate/SortAggregateExec.scala         |   2 +-
 .../spark/sql/DataFrameAggregateSuite.scala   | 201 ++++++++++++++++++
 5 files changed, 231 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
index 0eaa0f53fdac..74c00ccbaf25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression}
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
 import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode}
 
 /**
@@ -45,4 +45,28 @@ trait BaseAggregateExec extends UnaryExecNode {
        |Results: $resultString
      """.stripMargin
   }
+
+  protected def inputAttributes: Seq[Attribute] = {
+    val modes = aggregateExpressions.map(_.mode).distinct
+    if (modes.contains(Final) || modes.contains(PartialMerge)) {
+      // SPARK-31620: when planning aggregates, the partial aggregate uses the aggregate
+      // function's `inputAggBufferAttributes` as its output, and Final and PartialMerge
+      // aggregates rely on that output to bind references for the
+      // `DeclarativeAggregate.mergeExpressions`. But if we copy the aggregate function
+      // somehow after aggregate planning, as `PlanSubqueries` does, the `DeclarativeAggregate`
+      // is replaced by a new instance with new `inputAggBufferAttributes` and
+      // `mergeExpressions`, which can no longer be bound to the output of the partial
+      // aggregate, since that output uses the `inputAggBufferAttributes` of the original
+      // `DeclarativeAggregate`. Instead, we use the `inputAggBufferAttributes` after the copy.
+      val aggAttrs = aggregateExpressions
+        // There are exactly four cases that need `inputAggBufferAttributes` from the child,
+        // according to the agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final,
+        // Partial -> PartialMerge, and PartialMerge -> PartialMerge.
+        .filter(a => a.mode == Final || a.mode == PartialMerge).map(_.aggregateFunction)
+        .flatMap(_.inputAggBufferAttributes)
+      child.output.dropRight(aggAttrs.length) ++ aggAttrs
+    } else {
+      child.output
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 12a7a75e4327..617d69bfa75e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -121,7 +121,7 @@ case class HashAggregateExec(
           resultExpressions,
           (expressions, inputSchema) =>
             newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
-          child.output,
+          inputAttributes,
           iter,
           testFallbackStartsAt,
           numOutputRows,
@@ -254,7 +254,7 @@ case class HashAggregateExec(
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
     val updateExpr = aggregateExpressions.flatMap { e =>
       e.mode match {
         case Partial | Complete =>
@@ -817,7 +817,7 @@ case class HashAggregateExec(
       }
     }
 
-    val inputAttr = aggregateBufferAttributes ++ child.output
+    val inputAttr = aggregateBufferAttributes ++ inputAttributes
     // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
     // generating code for buffer columns, we use `INPUT_ROW` (which will be the buffer row),
     // while generating input columns, we use `currentVars`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index fc615e3e81ed..10b9f17f6d82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -121,7 +121,7 @@ case class ObjectHashAggregateExec(
           resultExpressions,
           (expressions, inputSchema) =>
             newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
-          child.output,
+          inputAttributes,
           iter,
           fallbackCountThreshold,
           numOutputRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 34c3d8eb238a..be4bdc355ad6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -86,7 +86,7 @@ case class SortAggregateExec(
         val outputIter = new SortBasedAggregationIterator(
           partIndex,
           groupingExpressions,
-          child.output,
+          inputAttributes,
           iter,
           aggregateExpressions,
           aggregateAttributes,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 73259a0ed3b5..1a6925579acd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -772,4 +772,205 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       Row(Seq(0.0f, 0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, Double.NaN)), 2)
     )
   }
+
+  test("SPARK-27581: DataFrame countDistinct(\"*\") shouldn't fail with AnalysisException") {
+    val df = sql("select id % 100 from range(100000)")
+    val distinctCount1 = df.select(expr("count(distinct(*))"))
+    val distinctCount2 = df.select(countDistinct("*"))
+    checkAnswer(distinctCount1, distinctCount2)
+
+    val countAndDistinct = df.select(count("*"), countDistinct("*"))
+    checkAnswer(countAndDistinct, Row(100000, 100))
+  }
+
+  test("max_by") {
+    val yearOfMaxEarnings =
+      sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course")
+    checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil)
+
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
+      Row("b") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
+      Row("c") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
+      Row("c") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
+      Row("b") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
+      Row(null) :: Nil
+    )
+
+    // structs as ordering values.
+    checkAnswer(
+      sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
+        "(('c', (10, 60))) AS tab(x, y)"),
+      Row("c") :: Nil
+    )
+
+    checkAnswer(
+      sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
+        "(('c', null)) AS tab(x, y)"),
+      Row("b") :: Nil
+    )
+
+    withTempView("tempView") {
+      val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c"))
+        .toDF("x", "y")
+        .select($"x", map($"x", $"y").as("y"))
+        .createOrReplaceTempView("tempView")
+      val error = intercept[AnalysisException] {
+        sql("SELECT max_by(x, y) FROM tempView").show
+      }
+      assert(
+        error.message.contains("function max_by does not support ordering on type map"))
+    }
+  }
+
+  test("min_by") {
+    val yearOfMinEarnings =
+      sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course")
+    checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil)
+
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
+      Row("a") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
+      Row("a") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
+      Row("c") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
+      Row("a") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
+      Row(null) :: Nil
+    )
+
+    // structs as ordering values.
+    checkAnswer(
+      sql("select min_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
+        "(('c', (10, 60))) AS tab(x, y)"),
+      Row("a") :: Nil
+    )
+
+    checkAnswer(
+      sql("select min_by(x, y) FROM VALUES (('a', null)), (('b', (10, 50))), " +
+        "(('c', (10, 60))) AS tab(x, y)"),
+      Row("b") :: Nil
+    )
+
+    withTempView("tempView") {
+      val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c"))
+        .toDF("x", "y")
+        .select($"x", map($"x", $"y").as("y"))
+        .createOrReplaceTempView("tempView")
+      val error = intercept[AnalysisException] {
+        sql("SELECT min_by(x, y) FROM tempView").show
+      }
+      assert(
+        error.message.contains("function min_by does not support ordering on type map"))
+    }
+  }
+
+  test("count_if") {
+    withTempView("tempView") {
+      Seq(("a", None), ("a", Some(1)), ("a", Some(2)), ("a", Some(3)),
+        ("b", None), ("b", Some(4)), ("b", Some(5)), ("b", Some(6)))
+        .toDF("x", "y")
+        .createOrReplaceTempView("tempView")
+
+      checkAnswer(
+        sql("SELECT COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " +
+          "COUNT_IF(y IS NULL) FROM tempView"),
+        Row(0L, 3L, 3L, 2L))
+
+      checkAnswer(
+        sql("SELECT x, COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " +
+          "COUNT_IF(y IS NULL) FROM tempView GROUP BY x"),
+        Row("a", 0L, 1L, 2L, 1L) :: Row("b", 0L, 2L, 1L, 1L) :: Nil)
+
+      checkAnswer(
+        sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y % 2 = 0) = 1"),
+        Row("a"))
+
+      checkAnswer(
+        sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y % 2 = 0) = 2"),
+        Row("b"))
+
+      checkAnswer(
+        sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y IS NULL) > 0"),
+        Row("a") :: Row("b") :: Nil)
+
+      checkAnswer(
+        sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(NULL) > 0"),
+        Nil)
+
+      val error = intercept[AnalysisException] {
+        sql("SELECT COUNT_IF(x) FROM tempView")
+      }
+      assert(error.message.contains("function count_if requires boolean type"))
+    }
+  }
+
+  Seq(true, false).foreach { value =>
+    test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
+        withTempView("t1", "t2") {
+          sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
+          sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
+
+          // test without grouping keys
+          checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"),
+            Row(4) :: Nil)
+
+          // test with grouping keys
+          checkAnswer(sql("select c, sum(if(c > (select a from t1), d, 0)) as csum from " +
+            "t2 group by c"), Row(3, 4) :: Nil)
+
+          // test with distinct
+          checkAnswer(sql("select avg(distinct(d)), sum(distinct(if(c > (select a from t1)," +
+            " d, 0))) as csum from t2 group by c"), Row(4, 4) :: Nil)
+
+          // test subquery with agg
+          checkAnswer(sql("select sum(distinct(if(c > (select sum(distinct(a)) from t1)," +
+            " d, 0))) as csum from t2 group by c"), Row(4) :: Nil)
+
+          // test SortAggregateExec
+          var df = sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2")
+          assert(df.queryExecution.executedPlan
+            .find(_.isInstanceOf[SortAggregateExec]).isDefined)
+          checkAnswer(df, Row("str1") :: Nil)
+
+          // test ObjectHashAggregateExec
+          df = sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum from t2")
+          assert(df.queryExecution.executedPlan
+            .find(_.isInstanceOf[ObjectHashAggregateExec]).isDefined)
+          checkAnswer(df, Row(Array(4), 4) :: Nil)
+        }
+      }
+    }
+  }
 }
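
To make the failure mode described in the commit message concrete, here is a minimal, standalone sketch in plain Scala (no Spark dependency). The names `Attr`, `ToyDeclarativeAgg`, `bindsOk`, and `freshId` are illustrative stand-ins, not Catalyst APIs: `Attr` plays the role of `AttributeReference` (matched by expression id, not by name), and each new `ToyDeclarativeAgg` instance mints fresh buffer attributes, just as copying a real `DeclarativeAggregate` does.

object RebindSketch {
  private var counter = 0L
  private def freshId(): Long = { counter += 1; counter }

  // Stand-in for AttributeReference: identity is the id, not the name.
  case class Attr(name: String, id: Long)

  // Toy stand-in for a DeclarativeAggregate. Copying it (which is effectively
  // what PlanSubqueries does when it rewrites a subquery inside the function)
  // produces a new instance with brand-new buffer attribute ids.
  class ToyDeclarativeAgg(bufferNames: Seq[String]) {
    val inputAggBufferAttributes: Seq[Attr] = bufferNames.map(n => Attr(n, freshId()))
    // The merge expressions reference this instance's buffer attributes.
    def mergeExpressionRefs: Seq[Attr] = inputAggBufferAttributes
    def freshCopy: ToyDeclarativeAgg = new ToyDeclarativeAgg(bufferNames)
  }

  // Binding succeeds only if every referenced attribute appears (by id) in the input.
  def bindsOk(refs: Seq[Attr], input: Seq[Attr]): Boolean =
    refs.forall(r => input.exists(_.id == r.id))

  def main(args: Array[String]): Unit = {
    val agg = new ToyDeclarativeAgg(Seq("sum", "count"))
    // The partial aggregate's output is the original function's buffer attributes.
    val partialOutput = agg.inputAggBufferAttributes

    // Without any copy, the final aggregate binds fine against child.output.
    assert(bindsOk(agg.mergeExpressionRefs, partialOutput))

    // After a copy, the ids no longer line up with the partial output: this is
    // the pre-fix failure when binding mergeExpressions against child.output.
    val copied = agg.freshCopy
    assert(!bindsOk(copied.mergeExpressionRefs, partialOutput))

    // The fix: rebuild the tail of the input from the *current* (copied)
    // function's buffer attributes instead of trusting child.output.
    val fixedInput = partialOutput.dropRight(copied.inputAggBufferAttributes.length) ++
      copied.inputAggBufferAttributes
    assert(bindsOk(copied.mergeExpressionRefs, fixedInput))
    println("rebinding with the copied function's buffer attributes succeeds")
  }
}

The last assertion mirrors what `BaseAggregateExec.inputAttributes` does in the patch: for Final and PartialMerge modes it computes `child.output.dropRight(aggAttrs.length) ++ aggAttrs`, where `aggAttrs` are the `inputAggBufferAttributes` of the current (possibly copied) aggregate functions.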