[SPARK-42003][SQL] Reduce duplicate code in ResolveGroupByAll
### What changes were proposed in this pull request?

Reduce duplicate code in ResolveGroupByAll by moving the group by expression inference into a new method.
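
For context only (this PR is a pure refactor and does not change behavior): the rule infers the `GROUP BY ALL` grouping columns from the output expressions that contain no aggregate function. A minimal sketch, assuming a Spark build that supports `GROUP BY ALL`; the view name `t` and columns `i`, `j` mirror the `i + sum(j)` example in the rule's comments but are otherwise made up:

```scala
import org.apache.spark.sql.SparkSession

object GroupByAllExample {
  def main(args: Array[String]): Unit = {
    // Illustrative local session and data; names are hypothetical, not part of this patch.
    val spark = SparkSession.builder().master("local[1]").appName("group-by-all-example").getOrCreate()
    import spark.implicits._
    Seq((1, 10), (1, 20), (2, 5)).toDF("i", "j").createOrReplaceTempView("t")

    // "i" is the only output expression without an aggregate, so ResolveGroupByAll
    // rewrites GROUP BY ALL into GROUP BY i.
    spark.sql("SELECT i, SUM(j) FROM t GROUP BY ALL").show()

    // Every output expression is an aggregate and no attribute appears outside one,
    // so the empty grouping list is kept: a valid global aggregate.
    spark.sql("SELECT SUM(j) FROM t GROUP BY ALL").show()

    spark.stop()
  }
}
```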

### Why are the changes needed?

Code clean up

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing UT

Closes apache#39523 from gengliangwang/refactorAll.

Authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
gengliangwang authored and vicennial committed Jan 17, 2023
1 parent aef2bb2 commit 0b6b504
Showing 1 changed file with 25 additions and 11 deletions.
@@ -47,25 +47,40 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Returns all the grouping expressions inferred from a GROUP BY ALL aggregate.
+   * The result is optional. If Spark fails to infer the grouping columns, it is None.
+   * Otherwise, it contains all the non-aggregate expressions from the project list of the input
+   * Aggregate.
+   */
+  private def getGroupingExpressions(a: Aggregate): Option[Seq[Expression]] = {
+    val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
+    // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or
+    // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will
+    // not automatically infer the grouping column to be "i".
+    if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
+      None
+    } else {
+      Some(groupingExprs)
+    }
+  }
+
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
     _.containsAllPatterns(UNRESOLVED_ATTRIBUTE, AGGREGATE), ruleId) {
     case a: Aggregate
       if a.child.resolved && a.aggregateExpressions.forall(_.resolved) && matchToken(a) =>
       // Only makes sense to do the rewrite once all the aggregate expressions have been resolved.
       // Otherwise, we might incorrectly pull an actual aggregate expression over to the grouping
       // expression list (because we don't know they would be aggregate expressions until resolved).
-      val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
+      val groupingExprs = getGroupingExpressions(a)
 
-      // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or
-      // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will
-      // not automatically infer the grouping column to be "i".
-      if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
-        // Case (2): don't replace the ALL. We will eventually tell the user in checkAnalysis
-        // that we cannot resolve the all in group by.
+      if (groupingExprs.isEmpty) {
+        // Don't replace the ALL when we fail to infer the grouping columns. We will eventually
+        // tell the user in checkAnalysis that we cannot resolve the all in group by.
         a
       } else {
-        // Case (1): this is a valid global aggregate.
-        a.copy(groupingExpressions = groupingExprs)
+        // This is a valid GROUP BY ALL aggregate.
+        a.copy(groupingExpressions = groupingExprs.get)
       }
   }
 
@@ -94,8 +109,7 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
    */
   def checkAnalysis(operator: LogicalPlan): Unit = operator match {
     case a: Aggregate if a.aggregateExpressions.forall(_.resolved) && matchToken(a) =>
-      val noAgg = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
-      if (noAgg.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
+      if (getGroupingExpressions(a).isEmpty) {
         operator.failAnalysis(
           errorClass = "UNRESOLVED_ALL_IN_GROUP_BY",
           messageParameters = Map.empty)
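For the failure path that the shared helper also feeds: when every output expression contains an aggregate but still references an attribute, as in the `i + sum(j)` case from the code comments, `getGroupingExpressions` returns None and checkAnalysis reports UNRESOLVED_ALL_IN_GROUP_BY. A hedged sketch, reusing the hypothetical `spark` session and temp view `t(i, j)` from the example above:

```scala
import org.apache.spark.sql.AnalysisException

// Reuses the hypothetical `spark` session and temp view `t(i, j)` defined earlier.
try {
  // The only output expression contains SUM yet also references `i`,
  // so no grouping columns can be inferred from GROUP BY ALL.
  spark.sql("SELECT i + SUM(j) FROM t GROUP BY ALL").show()
} catch {
  case e: AnalysisException =>
    // Analysis fails with error class UNRESOLVED_ALL_IN_GROUP_BY.
    println(e.getMessage)
}
```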
