diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8faf0eda548ea..ed6e17a8eb465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1011,24 +1011,24 @@ class Analyzer( private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = { val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] - /** Make sure a plans' subtree does not contain a tagged predicate. */ - def failOnOuterReferenceInSubTree(p: LogicalPlan, msg: String): Unit = { - if (p.collect(predicateMap).nonEmpty) { - failAnalysis(s"Accessing outer query column is not allowed in $msg: $p") + // Make sure a plan's subtree does not contain outer references + def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { + if (p.collectFirst(predicateMap).nonEmpty) { + failAnalysis(s"Accessing outer query column is not allowed in:\n$p") } } - /** Helper function for locating outer references. */ + // Helper function for locating outer references. def containsOuter(e: Expression): Boolean = { e.find(_.isInstanceOf[OuterReference]).isDefined } - /** Make sure a plans' expressions do not contain a tagged predicate. */ + // Make sure a plan's expressions do not contain outer references def failOnOuterReference(p: LogicalPlan): Unit = { if (p.expressions.exists(containsOuter)) { failAnalysis( "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + - s"clauses: $p") + s"clauses:\n$p") } } @@ -1077,10 +1077,51 @@ class Analyzer( // Simplify the predicates before pulling them out. val transformed = BooleanSimplification(sub) transformUp { - // WARNING: - // Only Filter can host correlated expressions at this time - // Anyone adding a new "case" below needs to add the call to - // "failOnOuterReference" to disallow correlated expressions in it. + + // Whitelist operators allowed in a correlated subquery + // There are 4 categories: + // 1. Operators that are allowed anywhere in a correlated subquery, and, + // by definition of the operators, they either do not contain + // any columns or cannot host outer references. + // 2. Operators that are allowed anywhere in a correlated subquery + // so long as they do not host outer references. + // 3. Operators that need special handlings. These operators are + // Project, Filter, Join, Aggregate, and Generate. + // + // Any operators that are not in the above list are allowed + // in a correlated subquery only if they are not on a correlation path. + // In other word, these operators are allowed only under a correlation point. + // + // A correlation path is defined as the sub-tree of all the operators that + // are on the path from the operator hosting the correlated expressions + // up to the operator producing the correlated values. + + // Category 1: + // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias + case p: BroadcastHint => + p + case p: Distinct => + p + case p: LeafNode => + p + case p: Repartition => + p + case p: SubqueryAlias => + p + + // Category 2: + // These operators can be anywhere in a correlated subquery. + // so long as they do not host outer references in the operators. + case p: Sort => + failOnOuterReference(p) + p + case p: RedistributeData => + failOnOuterReference(p) + p + + // Category 3: + // Filter is one of the two operators allowed to host correlated expressions. + // The other operator is Join. Filter can be anywhere in a correlated subquery. case f @ Filter(cond, child) => // Find all predicates with an outer reference. val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) @@ -1102,14 +1143,24 @@ class Analyzer( predicateMap += child -> xs child } + + // Project cannot host any correlated expressions + // but can be anywhere in a correlated subquery. case p @ Project(expressions, child) => failOnOuterReference(p) + val referencesToAdd = missingReferences(p) if (referencesToAdd.nonEmpty) { Project(expressions ++ referencesToAdd, child) } else { p } + + // Aggregate cannot host any correlated expressions + // It can be on a correlation path if the correlation contains + // only equality correlated predicates. + // It cannot be on a correlation path if the correlation has + // non-equality correlated predicates. case a @ Aggregate(grouping, expressions, child) => failOnOuterReference(a) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) @@ -1120,48 +1171,55 @@ class Analyzer( } else { a } - case w : Window => - failOnOuterReference(w) - failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, w) - w - case j @ Join(left, _, RightOuter, _) => - failOnOuterReference(j) - failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN") - j - // SPARK-18578: Do not allow any correlated predicate - // in a Full (Outer) Join operator and its descendants - case j @ Join(_, _, FullOuter, _) => - failOnOuterReferenceInSubTree(j, "a FULL OUTER JOIN") - j - case j @ Join(_, right, jt, _) if !jt.isInstanceOf[InnerLike] => - failOnOuterReference(j) - failOnOuterReferenceInSubTree(right, "a LEFT (OUTER) JOIN") + + // Join can host correlated expressions. + case j @ Join(left, right, joinType, _) => + joinType match { + // Inner join, like Filter, can be anywhere. + case _: InnerLike => + failOnOuterReference(j) + + // Left outer join's right operand cannot be on a correlation path. + // LeftAnti and ExistenceJoin are special cases of LeftOuter. + // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame + // so it should not show up here in Analysis phase. This is just a safety net. + // + // LeftSemi does not allow output from the right operand. + // Any correlated references in the subplan + // of the right operand cannot be pulled up. + case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => + failOnOuterReference(j) + failOnOuterReferenceInSubTree(right) + + // Likewise, Right outer join's left operand cannot be on a correlation path. + case RightOuter => + failOnOuterReference(j) + failOnOuterReferenceInSubTree(left) + + // Any other join types not explicitly listed above, + // including Full outer join, are treated as Category 4. + case _ => + failOnOuterReferenceInSubTree(j) + } j - case u: Union => - failOnOuterReferenceInSubTree(u, "a UNION") - u - case s: SetOperation => - failOnOuterReferenceInSubTree(s.right, "an INTERSECT/EXCEPT") - s - case e: Expand => - failOnOuterReferenceInSubTree(e, "an EXPAND") - e - case l : LocalLimit => - failOnOuterReferenceInSubTree(l, "a LIMIT") - l - // Since LIMIT is represented as GlobalLimit(, (LocalLimit (, child)) - // and we are walking bottom up, we will fail on LocalLimit before - // reaching GlobalLimit. - // The code below is just a safety net. - case g : GlobalLimit => - failOnOuterReferenceInSubTree(g, "a LIMIT") - g - case s : Sample => - failOnOuterReferenceInSubTree(s, "a TABLESAMPLE") - s - case p => + + // Generator with join=true, i.e., expressed with + // LATERAL VIEW [OUTER], similar to inner join, + // allows to have correlation under it + // but must not host any outer references. + // Note: + // Generator with join=false is treated as Category 4. + case p @ Generate(generator, true, _, _, _, _) => failOnOuterReference(p) p + + // Category 4: Any other operators not in the above 3 categories + // cannot be on a correlation path, that is they are allowed only + // under a correlation point but they and their descendant operators + // are not allowed to have any correlated expressions. + case p => + failOnOuterReferenceInSubTree(p) + p } (transformed, predicateMap.values.flatten.toSeq) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 37f0c8ed19d37..75d9997582aa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -932,7 +932,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { - case _: InnerLike | LeftSemi => + case _: InnerLike | LeftSemi => // push down the single side only join filter for both sides sub queries val newLeft = leftJoinConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8c1faea2394c6..96aff37a4b4f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -542,7 +542,7 @@ class AnalysisErrorSuite extends AnalysisTest { Filter(EqualTo(OuterReference(a), b), LocalRelation(b))) ), LocalRelation(a)) - assertAnalysisError(plan4, "Accessing outer query column is not allowed in a LIMIT" :: Nil) + assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil) val plan5 = Filter( Exists( @@ -551,6 +551,6 @@ class AnalysisErrorSuite extends AnalysisTest { ), LocalRelation(a)) assertAnalysisError(plan5, - "Accessing outer query column is not allowed in a TABLESAMPLE" :: Nil) + "Accessing outer query column is not allowed in" :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 73a53944964fd..0f2f520006e35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -789,4 +789,22 @@ class SubquerySuite extends QueryTest with SharedSQLContext { } } } + + // Generate operator + test("Correlated subqueries in LATERAL VIEW") { + withTempView("t1", "t2") { + Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq[(Int, Array[Int])]((1, Array(1, 2)), (2, Array(-1, -3))) + .toDF("c1", "arr_c2").createTempView("t2") + checkAnswer( + sql( + """ + | select c2 + | from t1 + | where exists (select * + | from t2 lateral view explode(arr_c2) q as c2 + where t1.c1 = t2.c1)""".stripMargin), + Row(1) :: Row(0) :: Nil) + } + } }