diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 48c38a9bd4cf3..c7346809f3fd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1173,6 +1173,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB def canHostOuter(plan: LogicalPlan): Boolean = plan match { case _: Filter => true case _: Project => usingDecorrelateInnerQueryFramework + case _: Join => usingDecorrelateInnerQueryFramework case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala index 86fa78e96a5f6..a3e264579f4de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala @@ -804,11 +804,73 @@ object DecorrelateInnerQuery extends PredicateHelper { (d.copy(child = newChild), joinCond, outerReferenceMap) case j @ Join(left, right, joinType, condition, _) => - val outerReferences = collectOuterReferences(j.expressions) - // Join condition containing outer references is not supported. - assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j") - val newOuterReferences = parentOuterReferences ++ outerReferences - val shouldPushToLeft = joinType match { + // Given 'condition', computes the tuple of + // (correlated, uncorrelated, equalityCond, predicates, equivalences). + // 'correlated' and 'uncorrelated' are the conjuncts with (resp. without) + // outer (correlated) references. 
Furthermore, correlated conjuncts are split + // into 'equalityCond' (those that are equalities) and all rest ('predicates'). + // 'equivalences' track equivalent attributes given 'equalityCond'. + // The split is only performed if 'shouldDecorrelatePredicates' is true. + // The input parameter 'isInnerJoin' is set to true for INNER joins and helps + // determine whether some predicates can be lifted up from the join (this is only + // valid for inner joins). + // Example: For a 'condition' A = outer(X) AND B > outer(Y) AND C = D, the output + // would be: + // correlated = (A = outer(X), B > outer(Y)) + // uncorrelated = (C = D) + // equalityCond = (A = outer(X)) + // predicates = (B > outer(Y)) + // equivalences: (A -> outer(X)) + def splitCorrelatedPredicate( + condition: Option[Expression], + isInnerJoin: Boolean, + shouldDecorrelatePredicates: Boolean): + (Seq[Expression], Seq[Expression], Seq[Expression], + Seq[Expression], AttributeMap[Attribute]) = { + // Similar to Filters above, we split the join condition (if present) into correlated + // and uncorrelated predicates, and separately handle joins under set and aggregation + // operations. + if (shouldDecorrelatePredicates) { + val conditions = + if (condition.isDefined) splitConjunctivePredicates(condition.get) + else Seq.empty[Expression] + val (correlated, uncorrelated) = conditions.partition(containsOuter) + var equivalences = + if (underSetOp) AttributeMap.empty[Attribute] + else collectEquivalentOuterReferences(correlated) + var (equalityCond, predicates) = + if (underSetOp) (Seq.empty[Expression], correlated) + else correlated.partition(canPullUpOverAgg) + // Fully preserve the join predicate for non-inner joins. 
+ if (!isInnerJoin) { + predicates = correlated + equalityCond = Seq.empty[Expression] + equivalences = AttributeMap.empty[Attribute] + } + (correlated, uncorrelated, equalityCond, predicates, equivalences) + } else { + (Seq.empty[Expression], + if (condition.isEmpty) Seq.empty[Expression] else Seq(condition.get), + Seq.empty[Expression], + Seq.empty[Expression], + AttributeMap.empty[Attribute]) + } + } + + val shouldDecorrelatePredicates = + SQLConf.get.getConf(SQLConf.DECORRELATE_JOIN_PREDICATE_ENABLED) + if (!shouldDecorrelatePredicates) { + val outerReferences = collectOuterReferences(j.expressions) + // Join condition containing outer references is not supported. + assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j") + } + val (correlated, uncorrelated, equalityCond, predicates, equivalences) = + splitCorrelatedPredicate(condition, joinType == Inner, shouldDecorrelatePredicates) + val outerReferences = collectOuterReferences(j.expressions) ++ + collectOuterReferences(predicates) + val newOuterReferences = + parentOuterReferences ++ outerReferences -- equivalences.keySet + var shouldPushToLeft = joinType match { case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true case _ => hasOuterReferences(left) } @@ -816,6 +878,14 @@ object DecorrelateInnerQuery extends PredicateHelper { case RightOuter | FullOuter => true case _ => hasOuterReferences(right) } + if (shouldDecorrelatePredicates && !shouldPushToLeft && !shouldPushToRight + && !predicates.isEmpty) { + // Neither left nor right children of the join have correlations, but the join + // predicate does, and the correlations can not be replaced via equivalences. + // Introduce a domain join on the left side of the join + // (chosen arbitrarily) to provide values for the correlated attribute reference. 
+ shouldPushToLeft = true; + } val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) { decorrelate(left, newOuterReferences, aggregated, underSetOp) } else { @@ -826,8 +896,13 @@ object DecorrelateInnerQuery extends PredicateHelper { } else { (right, Nil, AttributeMap.empty[Attribute]) } - val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap - val newJoinCond = leftJoinCond ++ rightJoinCond + val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap ++ + equivalences + val newCorrelated = + if (shouldDecorrelatePredicates) { + replaceOuterReferences(correlated, newOuterReferenceMap) + } else Seq.empty[Expression] + val newJoinCond = leftJoinCond ++ rightJoinCond ++ equalityCond // If we push the dependent join to both sides, we can augment the join condition // such that both sides are matched on the domain attributes. For example, // - Left Map: {outer(c1) = c1} @@ -836,7 +911,8 @@ object DecorrelateInnerQuery extends PredicateHelper { val augmentedConditions = leftOuterReferenceMap.flatMap { case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _)) } - val newCondition = (condition ++ augmentedConditions).reduceOption(And) + val newCondition = (newCorrelated ++ uncorrelated + ++ augmentedConditions).reduceOption(And) val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition) (newJoin, newJoinCond, newOuterReferenceMap) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e4f335a9a08f0..ced3f3458c082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4370,6 +4370,16 @@ object SQLConf { .checkValue(_ >= 0, "The threshold of cached local relations must not be negative") .createWithDefault(64 * 1024 * 1024) + val 
DECORRELATE_JOIN_PREDICATE_ENABLED = + buildConf("spark.sql.optimizer.decorrelateJoinPredicate.enabled") + .internal() + .doc("Decorrelate scalar and lateral subqueries with correlated references in join " + + "predicates. This configuration is only effective when " + + s"'${DECORRELATE_INNER_QUERY_ENABLED.key}' is true.") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + /** * Holds information about keys that have been deprecated. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala index 304f7de4c6ab9..21ac8849fe224 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala @@ -35,10 +35,13 @@ class DecorrelateInnerQuerySuite extends PlanTest { val a3 = AttributeReference("a3", IntegerType)() val b3 = AttributeReference("b3", IntegerType)() val c3 = AttributeReference("c3", IntegerType)() + val a4 = AttributeReference("a4", IntegerType)() + val b4 = AttributeReference("b4", IntegerType)() val t0 = OneRowRelation() val testRelation = LocalRelation(a, b, c) val testRelation2 = LocalRelation(x, y, z) val testRelation3 = LocalRelation(a3, b3, c3) + val testRelation4 = LocalRelation(a4, b4) private def hasOuterReferences(plan: LogicalPlan): Boolean = { plan.exists(_.expressions.exists(SubExprUtils.containsOuter)) @@ -198,12 +201,15 @@ class DecorrelateInnerQuerySuite extends PlanTest { val innerPlan = Join( testRelation.as("t1"), - Filter(OuterReference(y) === 3, testRelation), + Filter(OuterReference(y) === b3, testRelation3), Inner, Some(OuterReference(x) === a), JoinHint.NONE) - val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, outerPlan.select()) } - assert(error.getMessage.contains("Correlated column is 
not allowed in join")) + val correctAnswer = + Join( + testRelation.as("t1"), testRelation3, + Inner, Some(a === a), JoinHint.NONE) + check(innerPlan, outerPlan, correctAnswer, Seq(b3 === y, x === a)) } test("correlated values in project") { @@ -454,4 +460,125 @@ class DecorrelateInnerQuerySuite extends PlanTest { DomainJoin(Seq(x), testRelation)))) check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x)) } + + test("SPARK-43780: aggregation in subquery with correlated equi-join") { + // Join in the subquery is on equi-predicates, so all the correlated references can be + // substituted by equivalent ones from the outer query, and domain join is not needed. + val outerPlan = testRelation + val innerPlan = + Aggregate( + Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()), + Project(Seq(x, y, a3, b3), + Join(testRelation2, testRelation3, Inner, + Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE))) + + val correctAnswer = + Aggregate( + Seq(y), Seq(Alias(count(Literal(1)), "a")(), y), + Project(Seq(x, y, a3, b3), + Join(testRelation2, testRelation3, Inner, Some(And(y === y, x === a3)), JoinHint.NONE))) + check(innerPlan, outerPlan, correctAnswer, Seq(y === a)) + } + + test("SPARK-43780: aggregation in subquery with correlated non-equi-join") { + // Join in the subquery is on non-equi-predicate, so we introduce a DomainJoin. 
+ val outerPlan = testRelation + val innerPlan = + Aggregate( + Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()), + Project(Seq(x, y, a3, b3), + Join(testRelation2, testRelation3, Inner, + Some(And(x === a3, y > OuterReference(a))), JoinHint.NONE))) + val correctAnswer = + Aggregate( + Seq(a), Seq(Alias(count(Literal(1)), "a")(), a), + Project(Seq(x, y, a3, b3, a), + Join( + DomainJoin(Seq(a), testRelation2), + testRelation3, Inner, Some(And(x === a3, y > a)), JoinHint.NONE))) + check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a)) + } + + test("SPARK-43780: aggregation in subquery with correlated left join") { + // Join in the subquery is a left outer join, so its correlated equi-predicate must be + // preserved in the join condition; a DomainJoin supplies the outer reference's values. + val outerPlan = testRelation + val innerPlan = + Aggregate( + Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()), + Project(Seq(x, y, a3, b3), + Join(testRelation2, testRelation3, LeftOuter, + Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE))) + + val correctAnswer = + Aggregate( + Seq(a), Seq(Alias(count(Literal(1)), "a")(), a), + Project(Seq(x, y, a3, b3, a), + Join(DomainJoin(Seq(a), testRelation2), testRelation3, LeftOuter, + Some(And(y === a, x === a3)), JoinHint.NONE))) + check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a)) + } + + test("SPARK-43780: aggregation in subquery with correlated left join, " + + "correlation over right side") { + // Same as above, but the join predicate connects the outer reference and the column from the + // right (optional) side of the left join. A DomainJoin is still needed. 
+ val outerPlan = testRelation + val innerPlan = + Aggregate( + Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()), + Project(Seq(x, y, a3, b3), + Join(testRelation2, testRelation3, LeftOuter, + Some(And(x === a3, b3 === OuterReference(b))), JoinHint.NONE))) + + val correctAnswer = + Aggregate( + Seq(b), Seq(Alias(count(Literal(1)), "a")(), b), + Project(Seq(x, y, a3, b3, b), + Join(DomainJoin(Seq(b), testRelation2), testRelation3, LeftOuter, + Some(And(b === b3, x === a3)), JoinHint.NONE))) + check(innerPlan, outerPlan, correctAnswer, Seq(b <=> b)) + } + + test("SPARK-43780: correlated left join preserves the join predicates") { + // Left outer join preserves both predicates after being decorrelated. + val outerPlan = testRelation + val innerPlan = + Filter( + IsNotNull(c3), + Project(Seq(x, y, a3, b3, c3), + Join(testRelation2, testRelation3, LeftOuter, + Some(And(x === a3, b3 === OuterReference(b))), JoinHint.NONE))) + + val correctAnswer = + Filter( + IsNotNull(c3), + Project(Seq(x, y, a3, b3, c3, b), + Join(DomainJoin(Seq(b), testRelation2), testRelation3, LeftOuter, + Some(And(x === a3, b === b3)), JoinHint.NONE))) + check(innerPlan, outerPlan, correctAnswer, Seq(b <=> b)) + } + + test("SPARK-43780: union all in subquery with correlated join") { + val outerPlan = testRelation + val innerPlan = + Union( + Seq(Project(Seq(x, b3), + Join(testRelation2, testRelation3, Inner, + Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)), + Project(Seq(a4, b4), + testRelation4))) + val correctAnswer = + Union( + Seq(Project(Seq(x, b3, a), + Project(Seq(x, b3, a), + Join( + DomainJoin(Seq(a), testRelation2), + testRelation3, Inner, + Some(And(x === a3, y === a)), JoinHint.NONE))), + Project(Seq(a4, b4, a), + DomainJoin(Seq(a), + Project(Seq(a4, b4), testRelation4))))) + check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a)) + } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out index 5225996c16b73..2d1eebc65c66e 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out @@ -795,6 +795,72 @@ Project [c1#x, c2#x] +- LocalRelation [col1#x, col2#x] +-- !query +SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1) +-- !query analysis +Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x] ++- LateralJoin lateral-subquery#x [c1#x], Inner + : +- SubqueryAlias __auto_generated_subquery_name + : +- Project [c1#x, c2#x, c1#x, c2#x] + : +- Join Inner, ((c1#x = c1#x) AND (c1#x = outer(c1#x))) + : :- SubqueryAlias spark_catalog.default.t2 + : : +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x]) + : : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + : : +- LocalRelation [col1#x, col2#x] + : +- SubqueryAlias spark_catalog.default.t4 + : +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x]) + : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias spark_catalog.default.t1 + +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 != t4.c1 AND t2.c1 != t1.c1) +-- !query analysis +Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x] ++- LateralJoin lateral-subquery#x [c1#x], Inner + : +- SubqueryAlias __auto_generated_subquery_name + : +- Project [c1#x, c2#x, c1#x, c2#x] + : +- Join Inner, (NOT (c1#x = c1#x) AND NOT (c1#x = outer(c1#x))) + : :- SubqueryAlias spark_catalog.default.t2 + : : +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x]) + : : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + : : +- LocalRelation [col1#x, col2#x] + : +- 
SubqueryAlias spark_catalog.default.t4 + : +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x]) + : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias spark_catalog.default.t1 + +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT * FROM t1 LEFT JOIN lateral (SELECT * FROM t4 LEFT JOIN t2 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1) +-- !query analysis +Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x] ++- LateralJoin lateral-subquery#x [c1#x], LeftOuter + : +- SubqueryAlias __auto_generated_subquery_name + : +- Project [c1#x, c2#x, c1#x, c2#x] + : +- Join LeftOuter, ((c1#x = c1#x) AND (c1#x = outer(c1#x))) + : :- SubqueryAlias spark_catalog.default.t4 + : : +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x]) + : : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + : : +- LocalRelation [col1#x, col2#x] + : +- SubqueryAlias spark_catalog.default.t2 + : +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x]) + : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias spark_catalog.default.t1 + +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + -- !query SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index a55b1e717be15..76c9bec5fb8a5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1258,3 +1258,84 @@ Project [id#xL] : +- Range (1, 2, step=1, splits=None) +- SubqueryAlias t1 +- Range (1, 3, step=1, splits=None) + + +-- !query +SELECT * FROM t0 WHERE t0a < +(SELECT sum(t1c) FROM + (SELECT t1c + FROM t1 JOIN t2 ON (t1a = t0a AND t2b = t1b)) +) +-- !query analysis +Project [t0a#x, t0b#x] ++- Filter (cast(t0a#x as bigint) < scalar-subquery#x [t0a#x]) + : +- Aggregate [sum(t1c#x) AS sum(t1c)#xL] + : +- SubqueryAlias __auto_generated_subquery_name + : +- Project [t1c#x] + : +- Join Inner, ((t1a#x = outer(t0a#x)) AND (t2b#x = t1b#x)) + : :- SubqueryAlias t1 + : : +- View (`t1`, [t1a#x,t1b#x,t1c#x]) + : : +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] + : : +- LocalRelation [col1#x, col2#x, col3#x] + : +- SubqueryAlias t2 + : +- View (`t2`, [t2a#x,t2b#x,t2c#x]) + : +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x as int) AS t2b#x, cast(col3#x as int) AS t2c#x] + : +- LocalRelation [col1#x, col2#x, col3#x] + +- SubqueryAlias t0 + +- View (`t0`, [t0a#x,t0b#x]) + +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS t0b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT * FROM t0 WHERE t0a < +(SELECT sum(t1c) FROM + (SELECT t1c + FROM t1 JOIN t2 ON (t1a < t0a AND t2b >= t1b)) +) +-- !query analysis +Project [t0a#x, t0b#x] ++- Filter (cast(t0a#x as bigint) < scalar-subquery#x [t0a#x]) + : +- Aggregate [sum(t1c#x) AS sum(t1c)#xL] + : +- SubqueryAlias __auto_generated_subquery_name + : +- Project [t1c#x] + : +- Join Inner, ((t1a#x < outer(t0a#x)) AND (t2b#x >= t1b#x)) + : :- SubqueryAlias t1 + : : +- View (`t1`, [t1a#x,t1b#x,t1c#x]) + : : +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] + : : +- LocalRelation [col1#x, col2#x, col3#x] + : +- SubqueryAlias t2 + : +- View (`t2`, [t2a#x,t2b#x,t2c#x]) + : +- 
Project [cast(col1#x as int) AS t2a#x, cast(col2#x as int) AS t2b#x, cast(col3#x as int) AS t2c#x] + : +- LocalRelation [col1#x, col2#x, col3#x] + +- SubqueryAlias t0 + +- View (`t0`, [t0a#x,t0b#x]) + +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS t0b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT * FROM t0 WHERE t0a < +(SELECT sum(t1c) FROM + (SELECT t1c + FROM t1 LEFT JOIN t2 ON (t1a = t0a AND t2b = t0b)) +) +-- !query analysis +Project [t0a#x, t0b#x] ++- Filter (cast(t0a#x as bigint) < scalar-subquery#x [t0a#x && t0b#x]) + : +- Aggregate [sum(t1c#x) AS sum(t1c)#xL] + : +- SubqueryAlias __auto_generated_subquery_name + : +- Project [t1c#x] + : +- Join LeftOuter, ((t1a#x = outer(t0a#x)) AND (t2b#x = outer(t0b#x))) + : :- SubqueryAlias t1 + : : +- View (`t1`, [t1a#x,t1b#x,t1c#x]) + : : +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] + : : +- LocalRelation [col1#x, col2#x, col3#x] + : +- SubqueryAlias t2 + : +- View (`t2`, [t2a#x,t2b#x,t2c#x]) + : +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x as int) AS t2b#x, cast(col3#x as int) AS t2c#x] + : +- LocalRelation [col1#x, col2#x, col3#x] + +- SubqueryAlias t0 + +- View (`t0`, [t0a#x,t0b#x]) + +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS t0b#x] + +- LocalRelation [col1#x, col2#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out index 790d9da94e149..3f9eeb2cd5922 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out @@ -1802,3 +1802,83 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "SELECT sum(t0a) as d\n FROM t1" } 
] } + + +-- !query +SELECT t0a, (SELECT sum(t1b) FROM + (SELECT t1b + FROM t1 join t2 ON (t1a = t0a and t1b = t2b) + UNION ALL + SELECT t2b + FROM t1 join t2 ON (t2a = t0a and t1a = t2a)) +) +FROM t0 +-- !query analysis +Project [t0a#x, scalar-subquery#x [t0a#x && t0a#x] AS scalarsubquery(t0a, t0a)#xL] +: +- Aggregate [sum(t1b#x) AS sum(t1b)#xL] +: +- SubqueryAlias __auto_generated_subquery_name +: +- Union false, false +: :- Project [t1b#x] +: : +- Join Inner, ((t1a#x = outer(t0a#x)) AND (t1b#x = t2b#x)) +: : :- SubqueryAlias t1 +: : : +- View (`t1`, [t1a#x,t1b#x,t1c#x]) +: : : +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] +: : : +- LocalRelation [col1#x, col2#x, col3#x] +: : +- SubqueryAlias t2 +: : +- View (`t2`, [t2a#x,t2b#x,t2c#x]) +: : +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x as int) AS t2b#x, cast(col3#x as int) AS t2c#x] +: : +- LocalRelation [col1#x, col2#x, col3#x] +: +- Project [t2b#x] +: +- Join Inner, ((t2a#x = outer(t0a#x)) AND (t1a#x = t2a#x)) +: :- SubqueryAlias t1 +: : +- View (`t1`, [t1a#x,t1b#x,t1c#x]) +: : +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] +: : +- LocalRelation [col1#x, col2#x, col3#x] +: +- SubqueryAlias t2 +: +- View (`t2`, [t2a#x,t2b#x,t2c#x]) +: +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x as int) AS t2b#x, cast(col3#x as int) AS t2c#x] +: +- LocalRelation [col1#x, col2#x, col3#x] ++- SubqueryAlias t0 + +- View (`t0`, [t0a#x,t0b#x]) + +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS t0b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT t0a, (SELECT sum(t1b) FROM + (SELECT t1b + FROM t1 left join t2 ON (t1a = t0a and t1b = t2b) + UNION ALL + SELECT t2b + FROM t1 join t2 ON (t2a = t0a + 1 and t1a = t2a)) +) +FROM t0 +-- !query analysis +Project [t0a#x, scalar-subquery#x [t0a#x && t0a#x] AS scalarsubquery(t0a, t0a)#xL] +: +- Aggregate [sum(t1b#x) AS sum(t1b)#xL] +: +- 
SubqueryAlias __auto_generated_subquery_name +: +- Union false, false +: :- Project [t1b#x] +: : +- Join LeftOuter, ((t1a#x = outer(t0a#x)) AND (t1b#x = t2b#x)) +: : :- SubqueryAlias t1 +: : : +- View (`t1`, [t1a#x,t1b#x,t1c#x]) +: : : +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] +: : : +- LocalRelation [col1#x, col2#x, col3#x] +: : +- SubqueryAlias t2 +: : +- View (`t2`, [t2a#x,t2b#x,t2c#x]) +: : +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x as int) AS t2b#x, cast(col3#x as int) AS t2c#x] +: : +- LocalRelation [col1#x, col2#x, col3#x] +: +- Project [t2b#x] +: +- Join Inner, ((t2a#x = (outer(t0a#x) + 1)) AND (t1a#x = t2a#x)) +: :- SubqueryAlias t1 +: : +- View (`t1`, [t1a#x,t1b#x,t1c#x]) +: : +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] +: : +- LocalRelation [col1#x, col2#x, col3#x] +: +- SubqueryAlias t2 +: +- View (`t2`, [t2a#x,t2b#x,t2c#x]) +: +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x as int) AS t2b#x, cast(col3#x as int) AS t2c#x] +: +- LocalRelation [col1#x, col2#x, col3#x] ++- SubqueryAlias t0 + +- View (`t0`, [t0a#x,t0b#x]) + +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS t0b#x] + +- LocalRelation [col1#x, col2#x] diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql index 29ff29d6630b9..2787a86597567 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql @@ -101,6 +101,11 @@ SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a)); -- lateral join inside correlated subquery SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a) WHERE c1 = t1.c1); +-- join condition has a correlated reference to the left side of the lateral join +SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 = 
t4.c1 AND t2.c1 = t1.c1); +SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 != t4.c1 AND t2.c1 != t1.c1); +SELECT * FROM t1 LEFT JOIN lateral (SELECT * FROM t4 LEFT JOIN t2 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1); + -- COUNT bug with a single aggregate expression SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql index e015d57754999..a49f30773ca22 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql @@ -405,3 +405,23 @@ from range(1, 3) t1 where (select t2.id c from range (1, 2) t2 where t1.id = t2.id ) is not null; + +-- Correlated references in join predicates +SELECT * FROM t0 WHERE t0a < +(SELECT sum(t1c) FROM + (SELECT t1c + FROM t1 JOIN t2 ON (t1a = t0a AND t2b = t1b)) +); + +SELECT * FROM t0 WHERE t0a < +(SELECT sum(t1c) FROM + (SELECT t1c + FROM t1 JOIN t2 ON (t1a < t0a AND t2b >= t1b)) +); + +SELECT * FROM t0 WHERE t0a < +(SELECT sum(t1c) FROM + (SELECT t1c + FROM t1 LEFT JOIN t2 ON (t1a = t0a AND t2b = t0b)) +); + diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-set-op.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-set-op.sql index 8f03f7e41004b..39e456611c03b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-set-op.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-set-op.sql @@ -619,3 +619,23 @@ SELECT t0a, (SELECT sum(d) FROM FROM t2) ) FROM t0; + +-- Correlated references in join predicates +SELECT t0a, (SELECT sum(t1b) FROM + (SELECT t1b + FROM t1 join t2 
ON (t1a = t0a and t1b = t2b) + UNION ALL + SELECT t2b + FROM t1 join t2 ON (t2a = t0a and t1a = t2a)) +) +FROM t0; + + +SELECT t0a, (SELECT sum(t1b) FROM + (SELECT t1b + FROM t1 left join t2 ON (t1a = t0a and t1b = t2b) + UNION ALL + SELECT t2b + FROM t1 join t2 ON (t2a = t0a + 1 and t1a = t2a)) +) +FROM t0; diff --git a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out index 0bb83be0f03dd..33f084f3d869e 100644 --- a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out @@ -572,6 +572,45 @@ struct 0 1 +-- !query +SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1) +-- !query schema +struct +-- !query output +0 1 0 2 0 1 +0 1 0 2 0 2 +0 1 0 3 0 1 +0 1 0 3 0 2 + + +-- !query +SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 != t4.c1 AND t2.c1 != t1.c1) +-- !query schema +struct +-- !query output +1 2 0 2 1 1 +1 2 0 2 1 3 +1 2 0 3 1 1 +1 2 0 3 1 3 + + +-- !query +SELECT * FROM t1 LEFT JOIN lateral (SELECT * FROM t4 LEFT JOIN t2 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1) +-- !query schema +struct +-- !query output +0 1 0 1 0 2 +0 1 0 1 0 3 +0 1 0 2 0 2 +0 1 0 2 0 3 +0 1 1 1 NULL NULL +0 1 1 3 NULL NULL +1 2 0 1 NULL NULL +1 2 0 2 NULL NULL +1 2 1 1 NULL NULL +1 2 1 3 NULL NULL + + -- !query SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index ef5d941dc9791..302c5e6dd7e30 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -660,3 +660,40 @@ where (select t2.id c struct -- !query output 1 + + +-- !query +SELECT * FROM t0 WHERE t0a < +(SELECT sum(t1c) FROM + (SELECT t1c + FROM t1 JOIN t2 ON (t1a = t0a AND t2b = t1b)) +) +-- !query schema +struct +-- !query output +1 1 + + +-- !query +SELECT * FROM t0 WHERE t0a < +(SELECT sum(t1c) FROM + (SELECT t1c + FROM t1 JOIN t2 ON (t1a < t0a AND t2b >= t1b)) +) +-- !query schema +struct +-- !query output +2 0 + + +-- !query +SELECT * FROM t0 WHERE t0a < +(SELECT sum(t1c) FROM + (SELECT t1c + FROM t1 LEFT JOIN t2 ON (t1a = t0a AND t2b = t0b)) +) +-- !query schema +struct +-- !query output +1 1 +2 0 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out index 2799728d48a6a..33a57a73be08e 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out @@ -1041,3 +1041,35 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "SELECT sum(t0a) as d\n FROM t1" } ] } + + +-- !query +SELECT t0a, (SELECT sum(t1b) FROM + (SELECT t1b + FROM t1 join t2 ON (t1a = t0a and t1b = t2b) + UNION ALL + SELECT t2b + FROM t1 join t2 ON (t2a = t0a and t1a = t2a)) +) +FROM t0 +-- !query schema +struct +-- !query output +1 2 +2 NULL + + +-- !query +SELECT t0a, (SELECT sum(t1b) FROM + (SELECT t1b + FROM t1 left join t2 ON (t1a = t0a and t1b = t2b) + UNION ALL + SELECT t2b + FROM t1 join t2 ON (t2a = t0a + 1 and t1a = t2a)) +) +FROM t0 +-- !query schema +struct +-- !query output +1 1 +2 1