[SPARK-47070] Fix invalid aggregation after subquery rewrite
## What changes were proposed in this pull request?

**tl;dr**
This PR fixes a bug where an `exists` variable is lost after an incorrect subquery rewrite, when `exists` is used neither in grouping expressions nor in aggregate functions. We wrap such a variable in the `first()` aggregate function so that the reference to it is not lost.

**Motivation**
Imagine we had a plan with a subquery:

```
Aggregate [a1#0] [CASE WHEN a1#0 IN (list#3999 []) THEN Hello ELSE Hi END AS strCol#13]
:    +- LocalRelation <empty>, [b1#3, b2#4, b3#5]
+- LocalRelation <empty>, [a1#0, a2#1, a3#2]
```

During correlated subquery rewrite, the rule `RewritePredicateSubquery` would rewrite expression `a1#0 IN (list#3999 [])` into `exists#12` and replace the subquery with `ExistenceJoin`, like so:

```
Aggregate [a1#0] [CASE WHEN exists#12 THEN Hello ELSE Hi END AS strCol#13]
+- Join ExistenceJoin(exists#12), (a1#0 = b1#3)
     +- LocalRelation <empty>, [a1#0, a2#1, a3#2]
     +- LocalRelation <empty>, [b1#3, b2#4, b3#5]
```
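
To make the meaning of the `exists#12` flag concrete, here is a minimal sketch of existence-join semantics on plain Scala collections. It is not Spark's implementation: `ARow`, `BRow` and `existenceJoin` are hypothetical names chosen for illustration, the join condition is assumed to be the simple equality `a1 = b1` from the plan above, and NULL handling is ignored.

```scala
object ExistenceJoinSketch {
  // Hypothetical toy rows: only the join keys matter for this illustration.
  final case class ARow(a1: Int)
  final case class BRow(b1: Int)

  // Existence-join semantics (assumed condition: a1 = b1): every left row is kept
  // exactly once, and a boolean flag records whether at least one right row matches.
  def existenceJoin(left: Seq[ARow], right: Seq[BRow]): Seq[(ARow, Boolean)] = {
    val rightKeys = right.map(_.b1).toSet
    left.map(row => (row, rightKeys.contains(row.a1)))
  }
}
```

In this sketch the flag depends only on `a1`, which is why two left rows with the same `a1` always carry the same flag.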

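Note that `exists#12` neither appears in the grouping expressions nor is part of any aggregate function, so this is an invalid aggregation. In particular, the aggregate pushdown rule rewrites this plan into the following, where the top `Project` still references `exists#12` although the aggregation below it only groups by `a1#0`: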

```
Project [CASE WHEN exists#12 THEN Hello WHEN true THEN Hi END AS strCol#13]
+- AggregatePart [a1#0], true
   +- AggregatePart [a1#0], false
      +- Join ExistenceJoin(exists#12), (a1#0 = b1#3)
         :- AggregatePart [a1#0], false
         :     +- LocalRelation <empty>, [a1#0, a2#1, a3#2]
         +- AggregatePart [b1#3], false
               +- LocalRelation <empty>, [b1#3, b2#4, b3#5]
```
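
The validity condition described above can be sketched on a toy expression tree. This is a simplified illustration, not Catalyst code: the `Expr` hierarchy, `refsOutsideAgg` and `isValidAggregate` are hypothetical names, and grouping keys are assumed to be plain attribute names.

```scala
object AggregateValiditySketch {
  // Hypothetical, heavily simplified expression tree.
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class Lit(value: String) extends Expr
  final case class AggCall(fn: String, child: Expr) extends Expr
  final case class CaseWhen(cond: Expr, thenE: Expr, elseE: Expr) extends Expr

  // Collect attribute names that are referenced outside of any aggregate function call.
  def refsOutsideAgg(e: Expr): Set[String] = e match {
    case Attr(n)           => Set(n)
    case Lit(_)            => Set.empty
    case AggCall(_, _)     => Set.empty // references inside an aggregate call are always allowed
    case CaseWhen(c, t, f) => refsOutsideAgg(c) ++ refsOutsideAgg(t) ++ refsOutsideAgg(f)
  }

  // An aggregation is treated as valid here if every attribute referenced outside
  // an aggregate function is one of the grouping keys.
  def isValidAggregate(groupingKeys: Set[String], aggregateExprs: Seq[Expr]): Boolean =
    aggregateExprs.forall(e => refsOutsideAgg(e).subsetOf(groupingKeys))
}
```

Under these assumptions, `CASE WHEN exists THEN Hello ELSE Hi END` grouped by `a1` alone fails the check, while the same expression over `first(exists)` passes.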

**Solution**
We fix the problem by wrapping such `exists` attributes in the `first()` function, which is Spark's executable version of `any_value()`. Note that such an `exists` is always functionally determined by the grouping keys, so wrapping it in any aggregate function that preserves its unique value is safe.

Specifically, we only wrap `exists` attributes if they are referenced among aggregate expressions, but NOT within an aggregate function or its filter. Note that a new `exists` attribute cannot appear in groupingExpressions.
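
A rough sketch of the wrapping step on a toy expression tree is shown below. It is an approximation of the idea only: the `Expr` hierarchy and `wrapIntroduced` are hypothetical, and unlike the actual patch (which first computes the set of attributes to wrap and then applies `transformUp`), this version walks the tree top-down and simply stops at aggregate calls.

```scala
object WrapExistsSketch {
  // Hypothetical, heavily simplified expression tree.
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class Lit(value: String) extends Expr
  final case class AggCall(fn: String, child: Expr) extends Expr
  final case class CaseWhen(cond: Expr, thenE: Expr, elseE: Expr) extends Expr

  // Wrap every newly introduced attribute in first(...) unless it already sits
  // inside an aggregate function call.
  def wrapIntroduced(e: Expr, introduced: Set[String]): Expr = e match {
    case Attr(n) if introduced.contains(n) => AggCall("first", Attr(n))
    case agg: AggCall                      => agg // arguments of aggregate calls stay untouched
    case CaseWhen(c, t, f) =>
      CaseWhen(
        wrapIntroduced(c, introduced),
        wrapIntroduced(t, introduced),
        wrapIntroduced(f, introduced))
    case other => other
  }
}
```

With `introduced = Set("exists")`, the expression `CASE WHEN exists THEN Hello ELSE Hi END` becomes `CASE WHEN first(exists) THEN Hello ELSE Hi END`, while `sum(CASE WHEN exists THEN 1 ELSE 0 END)` is left unchanged.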

**Original proposal (NOT used)**
The original proposal was to fix the bug in `RewritePredicateSubquery` by enforcing the condition that newly introduced variables, if referenced among aggregate expressions, must either participate in aggregate functions or appear in the grouping keys.

With that fix, the plan after `RewritePredicateSubquery` would look like:

```
Aggregate [a1#0, exists#12] [CASE WHEN exists#12 THEN Hello ELSE Hi END AS strCol#13]
+- Join ExistenceJoin(exists#12), (a1#0 = b1#3)
     +- LocalRelation <empty>, [a1#0, a2#1, a3#2]
     +- LocalRelation <empty>, [b1#3, b2#4, b3#5]
```
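
As a small illustration of why adding `exists` to the grouping keys would not change the grouping, the sketch below counts groups on toy data. It assumes, as in the plan above, that the flag is derived from `a1` alone; `joined`, `groupsByA1` and `groupsByA1AndExists` are hypothetical helper names.

```scala
object GroupingKeysSketch {
  // Hypothetical joined rows (a1, exists), where exists depends only on a1.
  def joined(a1Values: Seq[Int], b1Values: Set[Int]): Seq[(Int, Boolean)] =
    a1Values.map(a1 => (a1, b1Values.contains(a1)))

  // Number of groups when grouping by a1 alone.
  def groupsByA1(rows: Seq[(Int, Boolean)]): Int = rows.map(_._1).distinct.size

  // Number of groups when grouping by (a1, exists).
  def groupsByA1AndExists(rows: Seq[(Int, Boolean)]): Int = rows.distinct.size
}
```

Because each `a1` value maps to exactly one flag value, both functions return the same count for any input built by `joined`.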

**NOTE:** It is still possible to manually construct an `ExistenceJoin` (e.g. via the DSL) with an `Aggregate` on top of it that violates the condition.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Query tests

Closes apache#45133 from anton5798/subquery-exists-agg.

Authored-by: Anton Lykov <anton.lykov@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
anton5798 authored and pull[bot] committed Apr 11, 2024
1 parent 3d57f34 commit 2427073
Showing 7 changed files with 736 additions and 7 deletions.
@@ -248,24 +248,74 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     case u: UnaryNode if u.expressions.exists(
         SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
       var newChild = u.child
-      u.mapExpressions(expr => {
-        val (newExpr, p) = rewriteExistentialExpr(Seq(expr), newChild)
+      var introducedAttrs = Seq.empty[Attribute]
+      val updatedNode = u.mapExpressions(expr => {
+        val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
         newChild = p
+        introducedAttrs ++= newAttrs
         // The newExpr can not be None
         newExpr.get
       }).withNewChildren(Seq(newChild))
+      updatedNode match {
+        case a: Aggregate =>
+          // If we have introduced new `exists`-attributes that are referenced by
+          // aggregateExpressions within a non-aggregateFunction expression, we wrap them in
+          // first() aggregate function. first() is Spark's executable version of any_value()
+          // aggregate function.
+          // We do this to keep the aggregation valid, i.e avoid references outside of aggregate
+          // functions that are not in grouping expressions.
+          // Note that the same `exists` attr will never appear in groupingExpressions due to
+          // PullOutGroupingExpressions rule.
+          // Also note: the value of `exists` is functionally determined by grouping expressions,
+          // so applying any aggregate function is semantically safe.
+          val aggFunctionReferences = a.aggregateExpressions.
+            flatMap(extractAggregateExpressions).
+            flatMap(_.references).toSet
+          val nonAggFuncReferences =
+            a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
+          val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)
+
+          // Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
+          val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
+            aggExpr.transformUp {
+              case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
+                new First(attr).toAggregateExpression()
+            }.asInstanceOf[NamedExpression]
+          }
+          a.copy(aggregateExpressions = newAggregateExpressions)
+        case _ => updatedNode
+      }
   }

+  /**
+   * Extract all aggregate expressions from the expression tree routed at `expr`.
+   */
+  private def extractAggregateExpressions(expr: Expression): Seq[AggregateExpression] = {
+    expr match {
+      case a: AggregateExpression => Seq(a)
+      case e: Expression => e.children.flatMap(extractAggregateExpressions)
+    }
+  }
+
   /**
    * Given a predicate expression and an input plan, it rewrites any embedded existential sub-query
-   * into an existential join. It returns the rewritten expression together with the updated plan.
+   * into an existential join. It returns the rewritten expression together with the updated plan,
+   * as well as the newly introduced attributes.
    * Currently, it does not support NOT IN nested inside a NOT expression. This case is blocked in
    * the Analyzer.
    */
   private def rewriteExistentialExpr(
-      exprs: Seq[Expression],
-      plan: LogicalPlan): (Option[Expression], LogicalPlan) = {
+    exprs: Seq[Expression],
+    plan: LogicalPlan): (Option[Expression], LogicalPlan) = {
+    val (newExpr, newPlan, _) = rewriteExistentialExprWithAttrs(exprs, plan)
+    (newExpr, newPlan)
+  }
+
+  private def rewriteExistentialExprWithAttrs(
+      exprs: Seq[Expression],
+      plan: LogicalPlan): (Option[Expression], LogicalPlan, Seq[Attribute]) = {
     var newPlan = plan
+    val introducedAttrs = ArrayBuffer.empty[Attribute]
     val newExprs = exprs.map { e =>
       e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
         case Exists(sub, _, _, conditions, subHint) =>
@@ -275,6 +325,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           newPlan =
             buildJoin(newPlan, rewriteDomainJoinsIfPresent(newPlan, sub, newCondition),
               existenceJoin, newCondition, subHint)
+          introducedAttrs += exists
           exists
         case Not(InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint))) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
@@ -299,6 +350,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           newPlan = Join(newPlan,
             rewriteDomainJoinsIfPresent(newPlan, newSub, Some(finalJoinCond)),
             ExistenceJoin(exists), Some(finalJoinCond), joinHint)
+          introducedAttrs += exists
           Not(exists)
         case InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint)) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
@@ -309,10 +361,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           val joinHint = JoinHint(None, subHint)
           newPlan = Join(newPlan, rewriteDomainJoinsIfPresent(newPlan, newSub, newConditions),
             ExistenceJoin(exists), newConditions, joinHint)
+          introducedAttrs += exists
           exists
       }
     }
-    (newExprs.reduceOption(And), newPlan)
+    (newExprs.reduceOption(And), newPlan, introducedAttrs.toSeq)
   }
 }

@@ -374,3 +374,134 @@ Project [emp_name#x]
+- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+- SubqueryAlias EMP
+- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]


-- !query
SELECT
emp.dept_id,
EXISTS (SELECT dept.dept_id FROM dept)
FROM emp
GROUP BY emp.dept_id ORDER BY emp.dept_id
-- !query analysis
Sort [dept_id#x ASC NULLS FIRST], true
+- Aggregate [dept_id#x], [dept_id#x, exists#x [] AS exists()#x]
: +- Project [dept_id#x]
: +- SubqueryAlias dept
: +- View (`DEPT`, [dept_id#x, dept_name#x, state#x])
: +- Project [cast(dept_id#x as int) AS dept_id#x, cast(dept_name#x as string) AS dept_name#x, cast(state#x as string) AS state#x]
: +- Project [dept_id#x, dept_name#x, state#x]
: +- SubqueryAlias DEPT
: +- LocalRelation [dept_id#x, dept_name#x, state#x]
+- SubqueryAlias emp
+- View (`EMP`, [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x])
+- Project [cast(id#x as int) AS id#x, cast(emp_name#x as string) AS emp_name#x, cast(hiredate#x as date) AS hiredate#x, cast(salary#x as double) AS salary#x, cast(dept_id#x as int) AS dept_id#x]
+- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+- SubqueryAlias EMP
+- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]


-- !query
SELECT
emp.dept_id,
EXISTS (SELECT dept.dept_id FROM dept)
FROM emp
GROUP BY emp.dept_id ORDER BY emp.dept_id
-- !query analysis
Sort [dept_id#x ASC NULLS FIRST], true
+- Aggregate [dept_id#x], [dept_id#x, exists#x [] AS exists()#x]
: +- Project [dept_id#x]
: +- SubqueryAlias dept
: +- View (`DEPT`, [dept_id#x, dept_name#x, state#x])
: +- Project [cast(dept_id#x as int) AS dept_id#x, cast(dept_name#x as string) AS dept_name#x, cast(state#x as string) AS state#x]
: +- Project [dept_id#x, dept_name#x, state#x]
: +- SubqueryAlias DEPT
: +- LocalRelation [dept_id#x, dept_name#x, state#x]
+- SubqueryAlias emp
+- View (`EMP`, [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x])
+- Project [cast(id#x as int) AS id#x, cast(emp_name#x as string) AS emp_name#x, cast(hiredate#x as date) AS hiredate#x, cast(salary#x as double) AS salary#x, cast(dept_id#x as int) AS dept_id#x]
+- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+- SubqueryAlias EMP
+- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]


-- !query
SELECT
emp.dept_id,
NOT EXISTS (SELECT dept.dept_id FROM dept)
FROM emp
GROUP BY emp.dept_id ORDER BY emp.dept_id
-- !query analysis
Sort [dept_id#x ASC NULLS FIRST], true
+- Aggregate [dept_id#x], [dept_id#x, NOT exists#x [] AS (NOT exists())#x]
: +- Project [dept_id#x]
: +- SubqueryAlias dept
: +- View (`DEPT`, [dept_id#x, dept_name#x, state#x])
: +- Project [cast(dept_id#x as int) AS dept_id#x, cast(dept_name#x as string) AS dept_name#x, cast(state#x as string) AS state#x]
: +- Project [dept_id#x, dept_name#x, state#x]
: +- SubqueryAlias DEPT
: +- LocalRelation [dept_id#x, dept_name#x, state#x]
+- SubqueryAlias emp
+- View (`EMP`, [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x])
+- Project [cast(id#x as int) AS id#x, cast(emp_name#x as string) AS emp_name#x, cast(hiredate#x as date) AS hiredate#x, cast(salary#x as double) AS salary#x, cast(dept_id#x as int) AS dept_id#x]
+- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+- SubqueryAlias EMP
+- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]


-- !query
SELECT
emp.dept_id,
SUM(
CASE WHEN EXISTS (SELECT dept.dept_id FROM dept WHERE dept.dept_id = emp.dept_id) THEN 1
ELSE 0 END)
FROM emp
GROUP BY emp.dept_id ORDER BY emp.dept_id
-- !query analysis
Sort [dept_id#x ASC NULLS FIRST], true
+- Aggregate [dept_id#x], [dept_id#x, sum(CASE WHEN exists#x [dept_id#x] THEN 1 ELSE 0 END) AS sum(CASE WHEN exists(dept_id) THEN 1 ELSE 0 END)#xL]
: +- Project [dept_id#x]
: +- Filter (dept_id#x = outer(dept_id#x))
: +- SubqueryAlias dept
: +- View (`DEPT`, [dept_id#x, dept_name#x, state#x])
: +- Project [cast(dept_id#x as int) AS dept_id#x, cast(dept_name#x as string) AS dept_name#x, cast(state#x as string) AS state#x]
: +- Project [dept_id#x, dept_name#x, state#x]
: +- SubqueryAlias DEPT
: +- LocalRelation [dept_id#x, dept_name#x, state#x]
+- SubqueryAlias emp
+- View (`EMP`, [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x])
+- Project [cast(id#x as int) AS id#x, cast(emp_name#x as string) AS emp_name#x, cast(hiredate#x as date) AS hiredate#x, cast(salary#x as double) AS salary#x, cast(dept_id#x as int) AS dept_id#x]
+- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+- SubqueryAlias EMP
+- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]


-- !query
SELECT
cast(EXISTS (SELECT id FROM dept where dept.dept_id = emp.dept_id) AS int)
FROM emp
GROUP BY
cast(EXISTS (SELECT id FROM dept where dept.dept_id = emp.dept_id) AS int)
-- !query analysis
Aggregate [cast(exists#x [id#x && dept_id#x] as int)], [cast(exists#x [id#x && dept_id#x] as int) AS CAST(exists(id, dept_id) AS INT)#x]
: :- Project [outer(id#x)]
: : +- Filter (dept_id#x = outer(dept_id#x))
: : +- SubqueryAlias dept
: : +- View (`DEPT`, [dept_id#x, dept_name#x, state#x])
: : +- Project [cast(dept_id#x as int) AS dept_id#x, cast(dept_name#x as string) AS dept_name#x, cast(state#x as string) AS state#x]
: : +- Project [dept_id#x, dept_name#x, state#x]
: : +- SubqueryAlias DEPT
: : +- LocalRelation [dept_id#x, dept_name#x, state#x]
: +- Project [outer(id#x)]
: +- Filter (dept_id#x = outer(dept_id#x))
: +- SubqueryAlias dept
: +- View (`DEPT`, [dept_id#x, dept_name#x, state#x])
: +- Project [cast(dept_id#x as int) AS dept_id#x, cast(dept_name#x as string) AS dept_name#x, cast(state#x as string) AS state#x]
: +- Project [dept_id#x, dept_name#x, state#x]
: +- SubqueryAlias DEPT
: +- LocalRelation [dept_id#x, dept_name#x, state#x]
+- SubqueryAlias emp
+- View (`EMP`, [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x])
+- Project [cast(id#x as int) AS id#x, cast(emp_name#x as string) AS emp_name#x, cast(hiredate#x as date) AS hiredate#x, cast(salary#x as double) AS salary#x, cast(dept_id#x as int) AS dept_id#x]
+- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+- SubqueryAlias EMP
+- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]