From ef0a76eeea30fabb04499908b04124464225f5fd Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Tue, 21 Mar 2023 21:27:49 +0800 Subject: [PATCH] [SPARK-42851][SQL] Guard EquivalentExpressions.addExpr() with supportedExpression() ### What changes were proposed in this pull request? In `EquivalentExpressions.addExpr()`, add a guard `supportedExpression()` to make it consistent with `addExprTree()` and `getExprState()`. ### Why are the changes needed? This fixes a regression caused by https://github.com/apache/spark/pull/39010 which added the `supportedExpression()` to `addExprTree()` and `getExprState()` but not `addExpr()`. One example of a use case affected by the inconsistency is the `PhysicalAggregation` pattern in physical planning. There, it calls `addExpr()` to deduplicate the aggregate expressions, and then calls `getExprState()` to deduplicate the result expressions. Guarding inconsistently will cause the aggregate and result expressions go out of sync, eventually resulting in query execution error (or whole-stage codegen error). ### Does this PR introduce _any_ user-facing change? This fixes a regression affecting Spark 3.3.2+, where it may manifest as an error running aggregate operators with higher-order functions. Example running the SQL command: ```sql select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2) ``` example error message before the fix: ``` java.lang.IllegalStateException: Couldn't find max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))#4 in [max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))#3] ``` after the fix this error is gone. ### How was this patch tested? Added new test cases to `SubexpressionEliminationSuite` for the immediate issue, and to `DataFrameAggregateSuite` for an example of user-visible symptom. Closes #40473 from rednaxelafx/spark-42851. Authored-by: Kris Mok Signed-off-by: Wenchen Fan --- .../expressions/EquivalentExpressions.scala | 6 +++++- .../SubexpressionEliminationSuite.scala | 18 +++++++++++++++++- .../spark/sql/DataFrameAggregateSuite.scala | 7 +++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 3ffd9f9d88750..f47391c049298 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -40,7 +40,11 @@ class EquivalentExpressions { * Returns true if there was already a matching expression. */ def addExpr(expr: Expression): Boolean = { - updateExprInMap(expr, equivalenceMap) + if (supportedExpression(expr)) { + updateExprInMap(expr, equivalenceMap) + } else { + false + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index b16629f59aa2d..44d8ea3a112e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, ObjectType} class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper { test("Semantic equals and hash") { @@ -449,6 +449,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(e2.getCommonSubexpressions.size == 1) assert(e2.getCommonSubexpressions.head == add) } + + test("SPARK-42851: Handle supportExpression consistently across add and get") { + val expr = { + val function = (lambda: Expression) => Add(lambda, Literal(1)) + val elementType = IntegerType + val colClass = classOf[Array[Int]] + val inputType = ObjectType(colClass) + val inputObject = BoundReference(0, inputType, nullable = true) + objects.MapObjects(function, inputObject, elementType, true, Option(colClass)) + } + val equivalence = new EquivalentExpressions + equivalence.addExpr(expr) + val hasMatching = equivalence.addExpr(expr) + val cseState = equivalence.getExprState(expr) + assert(hasMatching == cseState.isDefined) + } } case class CodegenFallbackExpression(child: Expression) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 737d31cc6e913..2ba9039166f48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1538,6 +1538,13 @@ class DataFrameAggregateSuite extends QueryTest ) checkAnswer(res, Row(1, 1, 1) :: Row(4, 1, 2) :: Nil) } + + test("SPARK-42851: common subexpression should consistently handle aggregate and result exprs") { + val res = sql( + "select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)" + ) + checkAnswer(res, Row(Array(1), Array(1))) + } } case class B(c: Option[Double])