[SPARK-41162][SQL][3.3] Fix anti- and semi-join for self-join with aggregations

### What changes were proposed in this pull request?
Backport apache#39131 to branch-3.3.

Rule `PushDownLeftSemiAntiJoin` should not push an anti-join below an `Aggregate` when the join condition references an attribute that exists in both the right plan of the join and the child of its left plan. This usually happens when the anti-join / semi-join is a self-join and `DeduplicateRelations` cannot deduplicate those attributes (in this example, due to the projection of `value` to `id`).

This behaviour already exists for `Project` and `Union`, but `Aggregate` lacks this safety guard.
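
For illustration, here is a sketch of that guard, modeled on the `canPushThroughCondition` helper that this PR now also applies to `Aggregate` (see the diff below); this is a paraphrase for exposition, not the verbatim implementation:
```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Pushing the join below an operator is only safe if no attribute referenced
// by the join condition occurs both in the right leg of the join and in the
// children of the left leg: such an attribute would become ambiguous once the
// join sits below the operator.
def canPushThroughCondition(
    plans: Seq[LogicalPlan],
    condition: Option[Expression],
    rightOp: LogicalPlan): Boolean = {
  val childOutputs = AttributeSet(plans.flatMap(_.output))
  condition.forall { cond =>
    cond.references.intersect(rightOp.outputSet).intersect(childOutputs).isEmpty
  }
}
```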

### Why are the changes needed?
Without this change, the optimizer creates an incorrect plan.

This example fails with `distinct()` (an aggregation) and succeeds without `distinct()`, even though the two queries are semantically identical:
```scala
val ids = Seq(1, 2, 3).toDF("id").distinct()
val result = ids.withColumn("id", $"id" + 1).join(ids, Seq("id"), "left_anti").collect()
assert(result.length == 1)
```
With `distinct()`, rule `PushDownLeftSemiAntiJoin` creates a join condition `(value#907 + 1) = value#907`, which can never be true. This effectively removes the anti-join.
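
The `left_semi` variant of the same self-join is affected analogously and is covered by the new `DataFrameJoinSuite` test below; a correct plan must keep the two rows whose shifted `id` (2 and 3) still appears in `ids`:
```scala
val ids = Seq(1, 2, 3).toDF("id").distinct()
// shifted ids are {2, 3, 4}; of these, 2 and 3 exist in ids
val semi = ids.withColumn("id", $"id" + 1)
  .join(ids, Seq("id"), "left_semi")
  .collect()
assert(semi.length == 2)
```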

**Before this PR:**
The anti-join is fully removed from the plan.
```
== Physical Plan ==
AdaptiveSparkPlan (16)
+- == Final Plan ==
   LocalTableScan (1)

(16) AdaptiveSparkPlan
Output [1]: [id#900]
Arguments: isFinalPlan=true
```

This is caused by `PushDownLeftSemiAntiJoin` adding the join condition `(value#907 + 1) = value#907`, which is wrong because `id#910` in `(id#910 + 1) AS id#912` exists in the right child of the join as well as in the left grandchild:
```
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin ===
!Join LeftAnti, (id#912 = id#910)                  Aggregate [id#910], [(id#910 + 1) AS id#912]
!:- Aggregate [id#910], [(id#910 + 1) AS id#912]   +- Project [value#907 AS id#910]
!:  +- Project [value#907 AS id#910]                  +- Join LeftAnti, ((value#907 + 1) = value#907)
!:     +- LocalRelation [value#907]                      :- LocalRelation [value#907]
!+- Aggregate [id#910], [id#910]                         +- Aggregate [id#910], [id#910]
!   +- Project [value#914 AS id#910]                        +- Project [value#914 AS id#910]
!      +- LocalRelation [value#914]                            +- LocalRelation [value#914]
```

The right child of the join and the left grandchild become the two children of the pushed-down join, which creates an invalid join condition.

**After this PR:**
The attribute `id#910`, referenced via `(id#910 + 1) AS id#912` in the join condition, is recognized as ambiguous because both sides of the prospective pushed-down join contain `id#910`. Hence, the join is not pushed down, and the rule no longer applies.
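
As a plain-Scala illustration of that ambiguity check (ordinary string sets standing in for Catalyst's `AttributeSet`, values taken from the plan above):
```scala
// Stand-ins for the attribute sets involved in the pushdown decision:
val joinCondRefs  = Set("id#912", "id#910") // references of (id#912 = id#910)
val rightOutputs  = Set("id#910")           // output of the join's right leg
val leftChildOuts = Set("id#910")           // outputs of the Aggregate's children
// The guard refuses the pushdown because this intersection is non-empty:
val ambiguous = joinCondRefs intersect rightOutputs intersect leftChildOuts
assert(ambiguous == Set("id#910")) // => join stays above the Aggregate
```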

The final plan contains the anti-join:
```
== Physical Plan ==
AdaptiveSparkPlan (24)
+- == Final Plan ==
   * BroadcastHashJoin LeftSemi BuildRight (14)
   :- * HashAggregate (7)
   :  +- AQEShuffleRead (6)
   :     +- ShuffleQueryStage (5), Statistics(sizeInBytes=48.0 B, rowCount=3)
   :        +- Exchange (4)
   :           +- * HashAggregate (3)
   :              +- * Project (2)
   :                 +- * LocalTableScan (1)
   +- BroadcastQueryStage (13), Statistics(sizeInBytes=1024.0 KiB, rowCount=3)
      +- BroadcastExchange (12)
         +- * HashAggregate (11)
            +- AQEShuffleRead (10)
               +- ShuffleQueryStage (9), Statistics(sizeInBytes=48.0 B, rowCount=3)
                  +- ReusedExchange (8)

(8) ReusedExchange [Reuses operator id: 4]
Output [1]: [id#898]

(24) AdaptiveSparkPlan
Output [1]: [id#900]
Arguments: isFinalPlan=true
```

### Does this PR introduce _any_ user-facing change?
Yes, it fixes a correctness issue.

### How was this patch tested?
Unit tests in `DataFrameJoinSuite` and `LeftSemiAntiJoinPushDownSuite`.

Closes apache#39409 from EnricoMi/branch-antijoin-selfjoin-fix-3.3.

Authored-by: Enrico Minack <github@enrico.minack.dev>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
EnricoMi authored and cloud-fan committed Jan 6, 2023
1 parent 977e445 commit b97f79d
Showing 3 changed files with 63 additions and 25 deletions.
@@ -56,9 +56,10 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
}

// LeftSemi/LeftAnti over Aggregate, only push down if join can be planned as broadcast join.
- case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
+ case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), joinCond, _)
if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
!agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
+ canPushThroughCondition(agg.children, joinCond, rightOp) &&
canPlanAsBroadcastHashJoin(join, conf) =>
val aliasMap = getAliasMap(agg)
val canPushDownPredicate = (predicate: Expression) => {
@@ -105,11 +106,11 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
}

/**
- * Check if we can safely push a join through a project or union by making sure that attributes
- * referred in join condition do not contain the same attributes as the plan they are moved
- * into. This can happen when both sides of join refers to the same source (self join). This
- * function makes sure that the join condition refers to attributes that are not ambiguous (i.e
- * present in both the legs of the join) or else the resultant plan will be invalid.
+ * Check if we can safely push a join through a project, aggregate, or union by making sure that
+ * attributes referred in join condition do not contain the same attributes as the plan they are
+ * moved into. This can happen when both sides of join refers to the same source (self join).
+ * This function makes sure that the join condition refers to attributes that are not ambiguous
+ * (i.e present in both the legs of the join) or else the resultant plan will be invalid.
*/
private def canPushThroughCondition(
plans: Seq[LogicalPlan],
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

- class LeftSemiPushdownSuite extends PlanTest {
+ class LeftSemiAntiJoinPushDownSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
@@ -46,7 +46,7 @@ class LeftSemiPushdownSuite extends PlanTest {
val testRelation1 = LocalRelation('d.int)
val testRelation2 = LocalRelation('e.int)

- test("Project: LeftSemiAnti join pushdown") {
+ test("Project: LeftSemi join pushdown") {
val originalQuery = testRelation
.select(star())
.join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -59,7 +59,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

- test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") {
+ test("Project: LeftSemi join no pushdown - non-deterministic proj exprs") {
val originalQuery = testRelation
.select(Rand(1), 'b, 'c)
.join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -68,7 +68,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

- test("Project: LeftSemiAnti join non correlated scalar subq") {
+ test("Project: LeftSemi join pushdown - non-correlated scalar subq") {
val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
val originalQuery = testRelation
.select(subq.as("sum"))
@@ -83,7 +83,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

- test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") {
+ test("Project: LeftSemi join no pushdown - correlated scalar subq in projection list") {
val testRelation2 = LocalRelation('e.int, 'f.int)
val subqPlan = testRelation2.groupBy('e)(sum('f).as("sum")).where('e === 'a)
val subqExpr = ScalarSubquery(subqPlan)
@@ -95,7 +95,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

- test("Aggregate: LeftSemiAnti join pushdown") {
+ test("Aggregate: LeftSemi join pushdown") {
val originalQuery = testRelation
.groupBy('b)('b, sum('c))
.join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -109,7 +109,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

- test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") {
+ test("Aggregate: LeftSemi join no pushdown - non-deterministic aggr expressions") {
val originalQuery = testRelation
.groupBy('b)('b, Rand(10).as('c))
.join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -142,7 +142,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

- test("LeftSemiAnti join over aggregate - no pushdown") {
+ test("Aggregate: LeftSemi join no pushdown") {
val originalQuery = testRelation
.groupBy('b)('b, sum('c).as('sum))
.join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd))
@@ -151,7 +151,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

- test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") {
+ test("Aggregate: LeftSemi join pushdown - non-correlated scalar subq aggr exprs") {
val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
val originalQuery = testRelation
.groupBy('a) ('a, subq.as("sum"))
@@ -166,7 +166,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

- test("LeftSemiAnti join over Window") {
+ test("Window: LeftSemi join pushdown") {
val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))

val originalQuery = testRelation
@@ -184,7 +184,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

- test("Window: LeftSemi partial pushdown") {
+ test("Window: LeftSemi join partial pushdown") {
// Attributes from join condition which does not refer to the window partition spec
// are kept up in the plan as a Filter operator above Window.
val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
@@ -224,7 +224,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

- test("Union: LeftSemiAnti join pushdown") {
+ test("Union: LeftSemi join pushdown") {
val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)

val originalQuery = Union(Seq(testRelation, testRelation2))
@@ -240,7 +240,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

- test("Union: LeftSemiAnti join pushdown in self join scenario") {
+ test("Union: LeftSemi join pushdown in self join scenario") {
val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
val attrX = testRelation2.output.head

@@ -259,7 +259,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

- test("Unary: LeftSemiAnti join pushdown") {
+ test("Unary: LeftSemi join pushdown") {
val originalQuery = testRelation
.select(star())
.repartition(1)
@@ -274,7 +274,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

- test("Unary: LeftSemiAnti join pushdown - empty join condition") {
+ test("Unary: LeftSemi join pushdown - empty join condition") {
val originalQuery = testRelation
.select(star())
.repartition(1)
@@ -289,7 +289,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

- test("Unary: LeftSemi join pushdown - partial pushdown") {
+ test("Unary: LeftSemi join partial pushdown") {
val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
val originalQuery = testRelationWithArrayType
.generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -305,7 +305,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

- test("Unary: LeftAnti join pushdown - no pushdown") {
+ test("Unary: LeftAnti join no pushdown") {
val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
val originalQuery = testRelationWithArrayType
.generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -315,7 +315,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

- test("Unary: LeftSemiAnti join pushdown - no pushdown") {
+ test("Unary: LeftSemi join - no pushdown") {
val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
val originalQuery = testRelationWithArrayType
.generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -325,7 +325,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

- test("Unary: LeftSemi join push down through Expand") {
+ test("Unary: LeftSemi join pushdown through Expand") {
val expand = Expand(Seq(Seq('a, 'b, "null"), Seq('a, "null", 'c)),
Seq('a, 'b, 'c), testRelation)
val originalQuery = expand
@@ -431,6 +431,25 @@ class LeftSemiPushdownSuite extends PlanTest {
}
}

+ Seq(LeftSemi, LeftAnti).foreach { case jt =>
+   test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") {
+     val aggregation = testRelation
+       .select('b.as("id"), 'c)
+       .groupBy('id)('id, sum('c).as("sum"))
+
+     // reference "b" exists in left leg, and the children of the right leg of the join
+     val originalQuery = aggregation.select(('id + 1).as("id_plus_1"), 'sum)
+       .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+     val optimized = Optimize.execute(originalQuery.analyze)
+     val correctAnswer = testRelation
+       .select('b.as("id"), 'c)
+       .groupBy('id)(('id + 1).as("id_plus_1"), sum('c).as("sum"))
+       .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+       .analyze
+     comparePlans(optimized, correctAnswer)
+   }
+ }

Seq(LeftSemi, LeftAnti).foreach { case outerJT =>
Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT =>
test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") {
@@ -288,6 +288,24 @@ class DataFrameJoinSuite extends QueryTest
}
}

+ Seq("left_semi", "left_anti").foreach { joinType =>
+   test(s"SPARK-41162: $joinType self-joined aggregated dataframe") {
+     // aggregated dataframe
+     val ids = Seq(1, 2, 3).toDF("id").distinct()
+
+     // self-joined via joinType
+     val result = ids.withColumn("id", $"id" + 1)
+       .join(ids, usingColumns = Seq("id"), joinType = joinType).collect()
+
+     val expected = joinType match {
+       case "left_semi" => 2
+       case "left_anti" => 1
+       case _ => -1 // unsupported test type, test will always fail
+     }
+     assert(result.length == expected)
+   }
+ }

def extractLeftDeepInnerJoins(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
case j @ Join(left, right, _: InnerLike, _, _) => right +: extractLeftDeepInnerJoins(left)
case Filter(_, child) => extractLeftDeepInnerJoins(child)
