[SPARK-43780][SQL] Support correlated references in join predicates for scalar and lateral subqueries

### What changes were proposed in this pull request?

This PR adds support for subqueries that involve joins with correlated references in join predicates, e.g.

```sql
select * from t0 join lateral (select * from t1 join t2 on t1a = t2a and t1a = t0a);
```

(full example in https://issues.apache.org/jira/browse/SPARK-43780)

Currently, this support covers only scalar and lateral subqueries.
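
For illustration, a minimal end-to-end sketch of the newly supported shape (a sketch only: the `SparkSession` setup and table contents are assumptions; the final query follows the example above, with the outer column qualified for clarity):

```scala
import org.apache.spark.sql.SparkSession

object CorrelatedJoinPredicateExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SPARK-43780 example")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical single-column tables matching the column names in the example query.
    Seq(1, 2, 3).toDF("t0a").createOrReplaceTempView("t0")
    Seq(1, 2).toDF("t1a").createOrReplaceTempView("t1")
    Seq(2, 3).toDF("t2a").createOrReplaceTempView("t2")

    // The join inside the lateral subquery carries a correlated reference (t1a = t0.t0a)
    // in its join predicate; with this change the query is decorrelated instead of rejected.
    spark.sql(
      """SELECT * FROM t0
        |JOIN LATERAL (SELECT * FROM t1 JOIN t2 ON t1a = t2a AND t1a = t0.t0a)
        |""".stripMargin).show()

    spark.stop()
  }
}
```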

### Why are the changes needed?

This is valid SQL that is not yet supported by Spark SQL.

### Does this PR introduce _any_ user-facing change?

Yes, previously unsupported queries become supported.

### How was this patch tested?

Query and unit tests

Closes apache#41301 from agubichev/spark-43780-corr-predicate.

Authored-by: Andrey Gubichev <andrey.gubichev@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
agubichev authored and vpolet committed Aug 24, 2023
1 parent af96ba6 commit 0ae4713
Showing 13 changed files with 605 additions and 11 deletions.
@@ -1173,6 +1173,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
def canHostOuter(plan: LogicalPlan): Boolean = plan match {
case _: Filter => true
case _: Project => usingDecorrelateInnerQueryFramework
case _: Join => usingDecorrelateInnerQueryFramework
case _ => false
}

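The one-line addition above lets a `Join` operator host outer references, gated on the newer DecorrelateInnerQuery-based framework being in use; otherwise such plans are still rejected during analysis. A minimal sketch of the plan shape this now admits, built from Catalyst classes in the same style as the test suite changes further down (the relations and attribute names are assumptions for the sketch):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, OuterReference}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LocalRelation}
import org.apache.spark.sql.types.IntegerType

// An attribute from the outer query and two attributes for the inner relations.
val x = AttributeReference("x", IntegerType)()
val a = AttributeReference("a", IntegerType)()
val a3 = AttributeReference("a3", IntegerType)()

// A join inside a subquery whose condition references the outer attribute `x`.
// Previously canHostOuter returned false for Join, so this was flagged during
// analysis; now it is accepted when the decorrelation framework is enabled.
val correlatedJoin = Join(
  LocalRelation(a),
  LocalRelation(a3),
  Inner,
  Some(OuterReference(x) === a),
  JoinHint.NONE)
```
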
@@ -804,18 +804,88 @@ object DecorrelateInnerQuery extends PredicateHelper {
(d.copy(child = newChild), joinCond, outerReferenceMap)

case j @ Join(left, right, joinType, condition, _) =>
val outerReferences = collectOuterReferences(j.expressions)
// Join condition containing outer references is not supported.
assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
val newOuterReferences = parentOuterReferences ++ outerReferences
val shouldPushToLeft = joinType match {
// Given 'condition', computes the tuple of
// (correlated, uncorrelated, equalityCond, predicates, equivalences).
// 'correlated' and 'uncorrelated' are the conjuncts with (resp. without)
// outer (correlated) references. Furthermore, correlated conjuncts are split
// into 'equalityCond' (those that are equalities) and all rest ('predicates').
// 'equivalences' track equivalent attributes given 'equalityCond'.
// The split is only performed if 'shouldDecorrelatePredicates' is true.
// The input parameter 'isInnerJoin' is set to true for INNER joins and helps
// determine whether some predicates can be lifted up from the join (this is only
// valid for inner joins).
// Example: For a 'condition' A = outer(X) AND B > outer(Y) AND C = D, the output
// would be:
// correlated = (A = outer(X), B > outer(Y))
// uncorrelated = (C = D)
// equalityCond = (A = outer(X))
// predicates = (B > outer(Y))
// equivalences: (A -> outer(X))
def splitCorrelatedPredicate(
condition: Option[Expression],
isInnerJoin: Boolean,
shouldDecorrelatePredicates: Boolean):
(Seq[Expression], Seq[Expression], Seq[Expression],
Seq[Expression], AttributeMap[Attribute]) = {
// Similar to Filters above, we split the join condition (if present) into correlated
// and uncorrelated predicates, and separately handle joins under set and aggregation
// operations.
if (shouldDecorrelatePredicates) {
val conditions =
if (condition.isDefined) splitConjunctivePredicates(condition.get)
else Seq.empty[Expression]
val (correlated, uncorrelated) = conditions.partition(containsOuter)
var equivalences =
if (underSetOp) AttributeMap.empty[Attribute]
else collectEquivalentOuterReferences(correlated)
var (equalityCond, predicates) =
if (underSetOp) (Seq.empty[Expression], correlated)
else correlated.partition(canPullUpOverAgg)
// Fully preserve the join predicate for non-inner joins.
if (!isInnerJoin) {
predicates = correlated
equalityCond = Seq.empty[Expression]
equivalences = AttributeMap.empty[Attribute]
}
(correlated, uncorrelated, equalityCond, predicates, equivalences)
} else {
(Seq.empty[Expression],
if (condition.isEmpty) Seq.empty[Expression] else Seq(condition.get),
Seq.empty[Expression],
Seq.empty[Expression],
AttributeMap.empty[Attribute])
}
}

val shouldDecorrelatePredicates =
SQLConf.get.getConf(SQLConf.DECORRELATE_JOIN_PREDICATE_ENABLED)
if (!shouldDecorrelatePredicates) {
val outerReferences = collectOuterReferences(j.expressions)
// Join condition containing outer references is not supported.
assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
}
val (correlated, uncorrelated, equalityCond, predicates, equivalences) =
splitCorrelatedPredicate(condition, joinType == Inner, shouldDecorrelatePredicates)
val outerReferences = collectOuterReferences(j.expressions) ++
collectOuterReferences(predicates)
val newOuterReferences =
parentOuterReferences ++ outerReferences -- equivalences.keySet
var shouldPushToLeft = joinType match {
case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true
case _ => hasOuterReferences(left)
}
val shouldPushToRight = joinType match {
case RightOuter | FullOuter => true
case _ => hasOuterReferences(right)
}
if (shouldDecorrelatePredicates && !shouldPushToLeft && !shouldPushToRight
&& !predicates.isEmpty) {
// Neither left nor right children of the join have correlations, but the join
// predicate does, and the correlations can not be replaced via equivalences.
// Introduce a domain join on the left side of the join
// (chosen arbitrarily) to provide values for the correlated attribute reference.
shouldPushToLeft = true
}
val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) {
decorrelate(left, newOuterReferences, aggregated, underSetOp)
} else {
@@ -826,8 +896,13 @@ object DecorrelateInnerQuery extends PredicateHelper {
} else {
(right, Nil, AttributeMap.empty[Attribute])
}
val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap
val newJoinCond = leftJoinCond ++ rightJoinCond
val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap ++
equivalences
val newCorrelated =
if (shouldDecorrelatePredicates) {
replaceOuterReferences(correlated, newOuterReferenceMap)
} else Seq.empty[Expression]
val newJoinCond = leftJoinCond ++ rightJoinCond ++ equalityCond
// If we push the dependent join to both sides, we can augment the join condition
// such that both sides are matched on the domain attributes. For example,
// - Left Map: {outer(c1) = c1}
@@ -836,7 +911,8 @@ object DecorrelateInnerQuery extends PredicateHelper {
val augmentedConditions = leftOuterReferenceMap.flatMap {
case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _))
}
val newCondition = (condition ++ augmentedConditions).reduceOption(And)
val newCondition = (newCorrelated ++ uncorrelated
++ augmentedConditions).reduceOption(And)
val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition)
(newJoin, newJoinCond, newOuterReferenceMap)

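To make the split described in the `splitCorrelatedPredicate` comment concrete, here is a small self-contained sketch. It is a toy model, not the Catalyst implementation: `containsOuter`, `canPullUpOverAgg`, and `collectEquivalentOuterReferences` are reduced to the simple outer-reference and equality checks that the comment's example needs.

```scala
// Toy expression model for the example condition: A = outer(X) AND B > outer(Y) AND C = D.
sealed trait Expr
case class Attr(name: String) extends Expr
case class Outer(name: String) extends Expr
case class Eq(left: Expr, right: Expr) extends Expr
case class Gt(left: Expr, right: Expr) extends Expr

// Does the conjunct mention an outer (correlated) reference anywhere?
def containsOuter(e: Expr): Boolean = e match {
  case Outer(_) => true
  case Eq(l, r) => containsOuter(l) || containsOuter(r)
  case Gt(l, r) => containsOuter(l) || containsOuter(r)
  case _        => false
}

// Simplified stand-in for canPullUpOverAgg: only attr = outer(...) equalities qualify.
def isOuterEquality(e: Expr): Boolean = e match {
  case Eq(Attr(_), Outer(_)) | Eq(Outer(_), Attr(_)) => true
  case _ => false
}

val conjuncts: Seq[Expr] = Seq(
  Eq(Attr("A"), Outer("X")),
  Gt(Attr("B"), Outer("Y")),
  Eq(Attr("C"), Attr("D")))

val (correlated, uncorrelated) = conjuncts.partition(containsOuter)
val (equalityCond, predicates) = correlated.partition(isOuterEquality)
val equivalences: Map[String, String] = equalityCond.collect {
  case Eq(Attr(inner), Outer(outer)) => inner -> outer
  case Eq(Outer(outer), Attr(inner)) => inner -> outer
}.toMap

// correlated   == Seq(Eq(A, outer(X)), Gt(B, outer(Y)))
// uncorrelated == Seq(Eq(C, D))
// equalityCond == Seq(Eq(A, outer(X)))
// predicates   == Seq(Gt(B, outer(Y)))
// equivalences == Map(A -> X)
```
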
@@ -4370,6 +4370,16 @@ object SQLConf {
.checkValue(_ >= 0, "The threshold of cached local relations must not be negative")
.createWithDefault(64 * 1024 * 1024)

val DECORRELATE_JOIN_PREDICATE_ENABLED =
buildConf("spark.sql.optimizer.decorrelateJoinPredicate.enabled")
.internal()
.doc("Decorrelate scalar and lateral subqueries with correlated references in join " +
"predicates. This configuration is only effective when " +
"'${DECORRELATE_INNER_QUERY_ENABLED.key}' is true.")
.version("4.0.0")
.booleanConf
.createWithDefault(true)

/**
* Holds information about keys that have been deprecated.
*
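A quick sketch of exercising the new flag at runtime (assuming a running `SparkSession` named `spark`; per the doc string it only takes effect when the inner-query decorrelation framework itself is enabled):

```scala
// Enabled by default, as per createWithDefault(true) above.
spark.conf.get("spark.sql.optimizer.decorrelateJoinPredicate.enabled")

// Disabling it restores the previous behavior: join conditions inside scalar and
// lateral subqueries must not contain outer references, so such queries are
// rejected during analysis instead of being decorrelated.
spark.conf.set("spark.sql.optimizer.decorrelateJoinPredicate.enabled", "false")
```
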
@@ -35,10 +35,13 @@ class DecorrelateInnerQuerySuite extends PlanTest {
val a3 = AttributeReference("a3", IntegerType)()
val b3 = AttributeReference("b3", IntegerType)()
val c3 = AttributeReference("c3", IntegerType)()
val a4 = AttributeReference("a4", IntegerType)()
val b4 = AttributeReference("b4", IntegerType)()
val t0 = OneRowRelation()
val testRelation = LocalRelation(a, b, c)
val testRelation2 = LocalRelation(x, y, z)
val testRelation3 = LocalRelation(a3, b3, c3)
val testRelation4 = LocalRelation(a4, b4)

private def hasOuterReferences(plan: LogicalPlan): Boolean = {
plan.exists(_.expressions.exists(SubExprUtils.containsOuter))
@@ -198,12 +201,15 @@ class DecorrelateInnerQuerySuite extends PlanTest {
val innerPlan =
Join(
testRelation.as("t1"),
Filter(OuterReference(y) === 3, testRelation),
Filter(OuterReference(y) === b3, testRelation3),
Inner,
Some(OuterReference(x) === a),
JoinHint.NONE)
val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, outerPlan.select()) }
assert(error.getMessage.contains("Correlated column is not allowed in join"))
val correctAnswer =
Join(
testRelation.as("t1"), testRelation3,
Inner, Some(a === a), JoinHint.NONE)
check(innerPlan, outerPlan, correctAnswer, Seq(b3 === y, x === a))
}

test("correlated values in project") {
@@ -454,4 +460,125 @@ class DecorrelateInnerQuerySuite extends PlanTest {
DomainJoin(Seq(x), testRelation))))
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x))
}

test("SPARK-43780: aggregation in subquery with correlated equi-join") {
// Join in the subquery is on equi-predicates, so all the correlated references can be
// substituted by equivalent ones from the outer query, and domain join is not needed.
val outerPlan = testRelation
val innerPlan =
Aggregate(
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
Project(Seq(x, y, a3, b3),
Join(testRelation2, testRelation3, Inner,
Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))

val correctAnswer =
Aggregate(
Seq(y), Seq(Alias(count(Literal(1)), "a")(), y),
Project(Seq(x, y, a3, b3),
Join(testRelation2, testRelation3, Inner, Some(And(y === y, x === a3)), JoinHint.NONE)))
check(innerPlan, outerPlan, correctAnswer, Seq(y === a))
}

test("SPARK-43780: aggregation in subquery with correlated non-equi-join") {
// Join in the subquery is on non-equi-predicate, so we introduce a DomainJoin.
val outerPlan = testRelation
val innerPlan =
Aggregate(
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
Project(Seq(x, y, a3, b3),
Join(testRelation2, testRelation3, Inner,
Some(And(x === a3, y > OuterReference(a))), JoinHint.NONE)))
val correctAnswer =
Aggregate(
Seq(a), Seq(Alias(count(Literal(1)), "a")(), a),
Project(Seq(x, y, a3, b3, a),
Join(
DomainJoin(Seq(a), testRelation2),
testRelation3, Inner, Some(And(x === a3, y > a)), JoinHint.NONE)))
check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
}

test("SPARK-43780: aggregation in subquery with correlated left join") {
// The join in the subquery is a left outer join, so its correlated equality predicate is
// preserved in the join condition rather than lifted up, and a domain join is introduced.
val outerPlan = testRelation
val innerPlan =
Aggregate(
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
Project(Seq(x, y, a3, b3),
Join(testRelation2, testRelation3, LeftOuter,
Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))

val correctAnswer =
Aggregate(
Seq(a), Seq(Alias(count(Literal(1)), "a")(), a),
Project(Seq(x, y, a3, b3, a),
Join(DomainJoin(Seq(a), testRelation2), testRelation3, LeftOuter,
Some(And(y === a, x === a3)), JoinHint.NONE)))
check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
}

test("SPARK-43780: aggregation in subquery with correlated left join, " +
"correlation over right side") {
// Same as above, but the join predicate connects the outer reference with a column from the
// right (optional) side of the left join. A domain join is introduced here as well.
val outerPlan = testRelation
val innerPlan =
Aggregate(
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
Project(Seq(x, y, a3, b3),
Join(testRelation2, testRelation3, LeftOuter,
Some(And(x === a3, b3 === OuterReference(b))), JoinHint.NONE)))

val correctAnswer =
Aggregate(
Seq(b), Seq(Alias(count(Literal(1)), "a")(), b),
Project(Seq(x, y, a3, b3, b),
Join(DomainJoin(Seq(b), testRelation2), testRelation3, LeftOuter,
Some(And(b === b3, x === a3)), JoinHint.NONE)))
check(innerPlan, outerPlan, correctAnswer, Seq(b <=> b))
}

test("SPARK-43780: correlated left join preserves the join predicates") {
// Left outer join preserves both predicates after being decorrelated.
val outerPlan = testRelation
val innerPlan =
Filter(
IsNotNull(c3),
Project(Seq(x, y, a3, b3, c3),
Join(testRelation2, testRelation3, LeftOuter,
Some(And(x === a3, b3 === OuterReference(b))), JoinHint.NONE)))

val correctAnswer =
Filter(
IsNotNull(c3),
Project(Seq(x, y, a3, b3, c3, b),
Join(DomainJoin(Seq(b), testRelation2), testRelation3, LeftOuter,
Some(And(x === a3, b === b3)), JoinHint.NONE)))
check(innerPlan, outerPlan, correctAnswer, Seq(b <=> b))
}

test("SPARK-43780: union all in subquery with correlated join") {
val outerPlan = testRelation
val innerPlan =
Union(
Seq(Project(Seq(x, b3),
Join(testRelation2, testRelation3, Inner,
Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)),
Project(Seq(a4, b4),
testRelation4)))
val correctAnswer =
Union(
Seq(Project(Seq(x, b3, a),
Project(Seq(x, b3, a),
Join(
DomainJoin(Seq(a), testRelation2),
testRelation3, Inner,
Some(And(x === a3, y === a)), JoinHint.NONE))),
Project(Seq(a4, b4, a),
DomainJoin(Seq(a),
Project(Seq(a4, b4), testRelation4)))))
check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
}
}
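
For reference, a SQL-level query in the same spirit as the union test above, issued through `spark.sql` (the view names, data, and the `spark` session are assumptions; the test itself builds the logical plan directly):

```scala
// Hypothetical views whose column names follow the relations used in the suite.
spark.range(0, 3).selectExpr("id AS a", "id AS b", "id AS c").createOrReplaceTempView("outer_t")
spark.range(0, 3).selectExpr("id AS x", "id AS y", "id AS z").createOrReplaceTempView("t2v")
spark.range(0, 3).selectExpr("id AS a3", "id AS b3", "id AS c3").createOrReplaceTempView("t3v")
spark.range(0, 3).selectExpr("id AS a4", "id AS b4").createOrReplaceTempView("t4v")

// A union inside a lateral subquery where one branch joins on a correlated predicate
// (y = outer_t.a). As in the expected plan above, decorrelation introduces a DomainJoin
// in each branch so that both sides of the union expose the domain attribute.
spark.sql(
  """SELECT * FROM outer_t JOIN LATERAL (
    |  SELECT x, b3 FROM t2v JOIN t3v ON x = a3 AND y = outer_t.a
    |  UNION ALL
    |  SELECT a4, b4 FROM t4v
    |)
    |""".stripMargin).show()
```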
@@ -795,6 +795,72 @@ Project [c1#x, c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1)
-- !query analysis
Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
+- LateralJoin lateral-subquery#x [c1#x], Inner
: +- SubqueryAlias __auto_generated_subquery_name
: +- Project [c1#x, c2#x, c1#x, c2#x]
: +- Join Inner, ((c1#x = c1#x) AND (c1#x = outer(c1#x)))
: :- SubqueryAlias spark_catalog.default.t2
: : +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- SubqueryAlias spark_catalog.default.t4
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 != t4.c1 AND t2.c1 != t1.c1)
-- !query analysis
Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
+- LateralJoin lateral-subquery#x [c1#x], Inner
: +- SubqueryAlias __auto_generated_subquery_name
: +- Project [c1#x, c2#x, c1#x, c2#x]
: +- Join Inner, (NOT (c1#x = c1#x) AND NOT (c1#x = outer(c1#x)))
: :- SubqueryAlias spark_catalog.default.t2
: : +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- SubqueryAlias spark_catalog.default.t4
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
SELECT * FROM t1 LEFT JOIN lateral (SELECT * FROM t4 LEFT JOIN t2 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1)
-- !query analysis
Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
+- LateralJoin lateral-subquery#x [c1#x], LeftOuter
: +- SubqueryAlias __auto_generated_subquery_name
: +- Project [c1#x, c2#x, c1#x, c2#x]
: +- Join LeftOuter, ((c1#x = c1#x) AND (c1#x = outer(c1#x)))
: :- SubqueryAlias spark_catalog.default.t4
: : +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- SubqueryAlias spark_catalog.default.t2
: +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1)
-- !query analysis

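The golden output above shows analyzed plans, where the correlated reference is still visible as `outer(c1#x)` inside the join condition; the actual decorrelation happens later, during optimization. A small sketch for inspecting both stages interactively (assuming the `t1`/`t2`/`t4` views from the test file exist in the session):

```scala
val df = spark.sql(
  "SELECT * FROM t1 JOIN LATERAL (SELECT * FROM t2 JOIN t4 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1)")

// Analyzed plan: the lateral subquery still carries outer(c1#x) in its join condition.
println(df.queryExecution.analyzed)

// Optimized plan: after DecorrelateInnerQuery and subquery rewriting, the outer
// reference has been replaced by ordinary join conditions.
println(df.queryExecution.optimizedPlan)
```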