[SPARK-32861][SQL] GenerateExec should require column ordering
### What changes were proposed in this pull request?
This PR updates the `RemoveRedundantProjects` rule to make `GenerateExec` require column ordering.

### Why are the changes needed?
`GenerateExec` was originally treated as a node that does not require column ordering. However, `GenerateExec` binds its input rows directly against its `requiredChildOutput`, without using the child's actual output schema.
In `doExecute()`:
```scala
// Both the expression list and the input schema are `output`, so each attribute
// is bound to the ordinal that `output` assigns it: the runtime rows are assumed
// to arrive in exactly this column layout.
val proj = UnsafeProjection.create(output, output)
```
In `doConsume()`:
```scala
// The child's columns are forwarded as-is, in whatever order they arrive;
// nothing re-projects them into requiredChildOutput order.
val values = if (requiredChildOutput.nonEmpty) {
  input
} else {
  Seq.empty
}
```
Consequently, changing the input column ordering makes `GenerateExec` bind the wrong schema to the input columns. For example, if the child's columns are not required to be ordered, the `requiredChildOutput` [a, b, c] binds directly to input columns that actually arrive as [c, b, a], which is incorrect (a standalone sketch of this failure mode follows the plan below):
```
GenerateExec explode(array(a, b, c)), [a, b, c], false, [d]
  HashAggregate(keys=[a, b, c], functions=[], output=[c, b, a])
    ...
```
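To see this failure mode in isolation, here is a minimal, Spark-free Scala sketch of reading rows against an assumed column layout; `Row3`, `readAsABC`, and `PositionalBindingDemo` are invented names for illustration only:

```scala
// Hypothetical demo, not Spark code: a reader built for the assumed layout
// [a, b, c] fetches ordinals 0, 1, 2, so a child that actually emits
// [c, b, a] silently swaps the values of a and c.
case class Row3(col0: Long, col1: Long, col2: Long)

object PositionalBindingDemo {
  // Pretends the row's columns are laid out as [a, b, c].
  def readAsABC(row: Row3): (Long, Long, Long) = (row.col0, row.col1, row.col2)

  def main(args: Array[String]): Unit = {
    // For a = 1, b = 2, c = 3, a reordered child emits [c, b, a] = (3, 2, 1).
    val reordered = Row3(3L, 2L, 1L)
    val (a, b, c) = readAsABC(reordered)
    println(s"a=$a b=$b c=$c") // prints a=3 b=2 c=1: silently wrong values
  }
}
```

This is exactly why the rule must not remove a column-reordering `ProjectExec` underneath `GenerateExec`.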

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit test

Closes apache#29734 from allisonwang-db/generator.

Authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
allisonwang-db authored and cloud-fan committed Sep 16, 2020
Commit 2e3aa2f (parent 6051755)
Showing 2 changed files with 51 additions and 7 deletions.
RemoveRedundantProjects.scala:

```diff
@@ -62,7 +62,9 @@ case class RemoveRedundantProjects(conf: SQLConf) extends Rule[SparkPlan] {
         val keepOrdering = a.aggregateExpressions
           .exists(ae => ae.mode.equals(Final) || ae.mode.equals(PartialMerge))
         a.mapChildren(removeProject(_, keepOrdering))
-      case g: GenerateExec => g.mapChildren(removeProject(_, false))
+      // GenerateExec requires column ordering since it binds input rows directly with its
+      // requiredChildOutput without using child's output schema.
+      case g: GenerateExec => g.mapChildren(removeProject(_, true))
       // JoinExec ordering requirement will inherit from its parent. If there is no ProjectExec in
       // its ancestors, JoinExec should require output columns to be ordered.
       case o => o.mapChildren(removeProject(_, requireOrdering))
```
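For context on the `removeProject(_, true)` change above, here is a hypothetical, heavily simplified sketch of the redundancy decision a rule like `RemoveRedundantProjects` makes; `RedundantProjectSketch` and `isRedundant` are invented for illustration and do not mirror the real implementation:

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute

object RedundantProjectSketch {
  // A ProjectExec over childOutput is redundant only if it keeps the same
  // attributes -- and, when ordering matters (as it now does above
  // GenerateExec), keeps them in the same order.
  def isRedundant(
      projectOutput: Seq[Attribute],
      childOutput: Seq[Attribute],
      requireOrdering: Boolean): Boolean = {
    if (requireOrdering) {
      // Exact positional match: same attributes, same order.
      projectOutput.map(_.exprId) == childOutput.map(_.exprId)
    } else {
      // Only the attribute set must match; order is free.
      projectOutput.map(_.exprId).toSet == childOutput.map(_.exprId).toSet
    }
  }
}
```

With `requireOrdering = true` for `GenerateExec`'s child, a project that merely reorders columns is no longer considered redundant, so it survives and keeps the input layout aligned with `requiredChildOutput`.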
RemoveRedundantProjectsSuite.scala:

```diff
@@ -18,17 +18,21 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
+import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
-class RemoveRedundantProjectsSuite extends QueryTest with SharedSparkSession with SQLTestUtils {
+abstract class RemoveRedundantProjectsSuiteBase
+  extends QueryTest
+  with SharedSparkSession
+  with AdaptiveSparkPlanHelper {
 
   private def assertProjectExecCount(df: DataFrame, expected: Int): Unit = {
     withClue(df.queryExecution) {
       val plan = df.queryExecution.executedPlan
-      val actual = plan.collectWithSubqueries { case p: ProjectExec => p }.size
+      val actual = collectWithSubqueries(plan) { case p: ProjectExec => p }.size
       assert(actual == expected)
     }
   }
@@ -115,9 +119,41 @@ class RemoveRedundantProjectsSuite extends QueryTest with SharedSparkSession wit
     assertProjectExec(query, 1, 2)
   }
 
-  test("generate") {
-    val query = "select a, key, explode(d) from testView where a > 10"
-    assertProjectExec(query, 0, 1)
+  test("generate should require column ordering") {
+    withTempView("testData") {
+      spark.range(0, 10, 1)
+        .selectExpr("id as key", "id * 2 as a", "id * 3 as b")
+        .createOrReplaceTempView("testData")
+
+      val data = sql("select key, a, b, count(*) from testData group by key, a, b limit 2")
+      val df = data.selectExpr("a", "b", "key", "explode(array(key, a, b)) as d").filter("d > 0")
+      df.collect()
+      val plan = df.queryExecution.executedPlan
+      val numProjects = collectWithSubqueries(plan) { case p: ProjectExec => p }.length
+
+      // Create a new plan that reverses the GenerateExec's requiredChildOutput and adds a new
+      // ProjectExec between GenerateExec and its child. This checks that if the ProjectExec
+      // were removed, the output of the query would be incorrect.
+      val newPlan = stripAQEPlan(plan) transform {
+        case g @ GenerateExec(_, requiredChildOutput, _, _, child) =>
+          g.copy(requiredChildOutput = requiredChildOutput.reverse,
+            child = ProjectExec(requiredChildOutput.reverse, child))
+      }
+
+      // Re-apply the RemoveRedundantProjects rule.
+      val rule = RemoveRedundantProjects(spark.sessionState.conf)
+      val newExecutedPlan = rule.apply(newPlan)
+      // The manually added ProjectExec node shouldn't be removed.
+      assert(collectWithSubqueries(newExecutedPlan) {
+        case p: ProjectExec => p
+      }.size == numProjects + 1)
+
+      // Check that the original plan's output and the new plan's output are the same.
+      val expectedRows = plan.executeCollect()
+      val actualRows = newExecutedPlan.executeCollect()
+      assert(expectedRows.length == actualRows.length)
+      expectedRows.zip(actualRows).foreach { case (expected, actual) => assert(expected == actual) }
+    }
   }
 
   test("subquery") {
@@ -131,3 +167,9 @@ class RemoveRedundantProjectsSuite extends QueryTest with SharedSparkSession wit
     }
   }
 }
+
+class RemoveRedundantProjectsSuite extends RemoveRedundantProjectsSuiteBase
+  with DisableAdaptiveExecutionSuite
+
+class RemoveRedundantProjectsSuiteAE extends RemoveRedundantProjectsSuiteBase
+  with EnableAdaptiveExecutionSuite
```
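As a usage note: in a standard Spark development checkout, suites like these can typically be run through sbt, e.g. `build/sbt "sql/testOnly *RemoveRedundantProjectsSuite*"`; the exact command is an assumption about the usual Spark workflow, not something taken from this commit.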