diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 230b616800fb9..84af7b5d64f5e 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -1532,6 +1532,11 @@ "Referencing a lateral column alias in the aggregate function ." ] }, + "LATERAL_COLUMN_ALIAS_IN_GROUP_BY" : { + "message" : [ + "Referencing a lateral column alias via GROUP BY alias/ALL is not supported yet." + ] + }, "LATERAL_JOIN_USING" : { "message" : [ "JOIN USING with LATERAL correlation." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ce273f01c7aa2..28ae09e123cd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -185,8 +185,8 @@ object AnalysisContext { * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. */ -class Analyzer(override val catalogManager: CatalogManager) - extends RuleExecutor[LogicalPlan] with CheckAnalysis with SQLConfHelper { +class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor[LogicalPlan] + with CheckAnalysis with SQLConfHelper with ColumnResolutionHelper { private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog @@ -295,10 +295,7 @@ class Analyzer(override val catalogManager: CatalogManager) ResolveGroupingAnalytics :: ResolvePivot :: ResolveUnpivot :: - ResolveOrderByAll :: - ResolveGroupByAll :: ResolveOrdinalInOrderByAndGroupBy :: - ResolveAggAliasInGroupBy :: ExtractGenerator :: ResolveGenerate :: ResolveFunctions :: @@ -1489,25 +1486,40 @@ class Analyzer(override val catalogManager: CatalogManager) } /** - * Resolves [[UnresolvedAttribute]]s with the following precedence: - * 1. Resolves it to [[AttributeReference]] with the output of the children plans. This includes - * metadata columns as well. - * 2. If the plan is Project/Aggregate, resolves it to lateral column alias, which is the alias - * defined previously in the SELECT list. - * 3. If the plan is UnresolvedHaving/Filter/Sort + Aggregate, resolves it to - * [[TempResolvedColumn]] with the output of Aggregate's child plan. This is to allow - * UnresolvedHaving/Filter/Sort to host grouping expressions and aggregate functions, which - * can be pushed down to the Aggregate later. - * 4. If the plan is Sort/Filter/RepartitionByExpression, resolves it to [[AttributeReference]] - * with the output of a descendant plan node. Spark will propagate the missing attributes from - * the descendant plan node to the Sort/Filter/RepartitionByExpression node. This is to allow - * users to filter/order/repartition by columns that are not in the SELECT clause, which is - * widely supported in other SQL dialects. - * 5. Resolves it to [[OuterReference]] with the outer plan if this is a subquery plan. + * Resolves column references in the query plan. Basically it transform the query plan tree bottom + * up, and only try to resolve references for a plan node if all its children nodes are resolved, + * and there is no conflicting attributes between the children nodes (see `hasConflictingAttrs` + * for details). + * + * The general workflow to resolve references: + * 1. Expands the star in Project/Aggregate/Generate. + * 2. Resolves the columns to [[AttributeReference]] with the output of the children plans. This + * includes metadata columns as well. + * 3. Resolves the columns to literal function which is allowed to be invoked without braces, + * e.g. `SELECT col, current_date FROM t`. + * 4. Resolves the columns to outer references with the outer plan if we are resolving subquery + * expressions. + * + * Some plan nodes have special column reference resolution logic, please read these sub-rules for + * details: + * - [[ResolveReferencesInAggregate]] + * - [[ResolveReferencesInSort]] + * + * Note: even if we use a single rule to resolve columns, it's still non-trivial to have a + * reliable column resolution order, as the rule will be executed multiple times, with other + * rules in the same batch. We should resolve columns with the next option only if all the + * previous options are permanently not applicable. If the current option can be applicable + * in the next iteration (other rules update the plan), we should not try the next option. */ - object ResolveReferences extends Rule[LogicalPlan] { + object ResolveReferences extends Rule[LogicalPlan] with ColumnResolutionHelper { - /** Return true if there're conflicting attributes among children's outputs of a plan */ + /** + * Return true if there're conflicting attributes among children's outputs of a plan + * + * The children logical plans may output columns with conflicting attribute IDs. This may happen + * in cases such as self-join. We should wait for the rule [[DeduplicateRelations]] to eliminate + * conflicting attribute IDs, otherwise we can't resolve columns correctly due to ambiguity. + */ def hasConflictingAttrs(p: LogicalPlan): Boolean = { p.children.length > 1 && { // Note that duplicated attributes are allowed within a single node, @@ -1628,31 +1640,7 @@ class Analyzer(override val catalogManager: CatalogManager) // rule: ResolveDeserializer. case plan if containsDeserializer(plan.expressions) => plan - // SPARK-31670: Resolve Struct field in groupByExpressions and aggregateExpressions - // with CUBE/ROLLUP will be wrapped with alias like Alias(GetStructField, name) with - // different ExprId. This cause aggregateExpressions can't be replaced by expanded - // groupByExpressions in `ResolveGroupingAnalytics.constructAggregateExprs()`, we trim - // unnecessary alias of GetStructField here. - case a: Aggregate => - val planForResolve = a.child match { - // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of - // `AppendColumns`, because `AppendColumns`'s serializer might produce conflict attribute - // names leading to ambiguous references exception. - case appendColumns: AppendColumns => appendColumns - case _ => a - } - - val resolvedGroupingExprs = a.groupingExpressions - .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = true)) - .map(trimTopLevelGetStructFieldAlias) - - val resolvedAggExprsNoOuter = a.aggregateExpressions - .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) - // Aggregate supports Lateral column alias, which has higher priority than outer reference. - val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsNoOuter) - val resolvedAggExprsWithOuter = resolvedAggExprsWithLCA.map(resolveOuterRef) - .map(_.asInstanceOf[NamedExpression]) - a.copy(resolvedGroupingExprs, resolvedAggExprsWithOuter, a.child) + case a: Aggregate => ResolveReferencesInAggregate(a) // Special case for Project as it supports lateral column alias. case p: Project => @@ -1790,82 +1778,13 @@ class Analyzer(override val catalogManager: CatalogManager) Project(child.output, newFilter) } - // Same as Filter, Sort can host both grouping expressions/aggregate functions and missing - // attributes as well. - case s @ Sort(orders, _, child) if !s.resolved || s.missingInput.nonEmpty => - val resolvedNoOuter = orders.map(resolveExpressionByPlanOutput(_, child)) - val resolvedWithAgg = resolvedNoOuter.map(resolveColWithAgg(_, child)) - val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(resolvedWithAgg, child) - // Outer reference has lowermost priority. See the doc of `ResolveReferences`. - val ordering = newOrder.map(e => resolveOuterRef(e).asInstanceOf[SortOrder]) - if (child.output == newChild.output) { - s.copy(order = ordering) - } else { - // Add missing attributes and then project them away. - val newSort = s.copy(order = ordering, child = newChild) - Project(child.output, newSort) - } + case s: Sort if !s.resolved || s.missingInput.nonEmpty => ResolveReferencesInSort(s) case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}") q.mapExpressions(resolveExpressionByPlanChildren(_, q, allowOuter = true)) } - /** - * This method tries to resolve expressions and find missing attributes recursively. - * Specifically, when the expressions used in `Sort` or `Filter` contain unresolved attributes - * or resolved attributes which are missing from child output. This method tries to find the - * missing attributes and add them into the projection. - */ - private def resolveExprsAndAddMissingAttrs( - exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { - // Missing attributes can be unresolved attributes or resolved attributes which are not in - // the output attributes of the plan. - if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { - (exprs, plan) - } else { - plan match { - case p: Project => - // Resolving expressions against current plan. - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, p)) - // Recursively resolving expressions on the child of current plan. - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) - // If some attributes used by expressions are resolvable only on the rewritten child - // plan, we need to add them into original projection. - val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) - (newExprs, Project(p.projectList ++ missingAttrs, newChild)) - - case a @ Aggregate(groupExprs, aggExprs, child) => - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, a)) - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) - val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) - if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { - // All the missing attributes are grouping expressions, valid case. - (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) - } else { - // Need to add non-grouping attributes, invalid case. - (exprs, a) - } - - case g: Generate => - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, g)) - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) - (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild)) - - // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes - // via its children. - case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, u)) - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) - (newExprs, u.withNewChildren(Seq(newChild))) - - // For other operators, we can't recursively resolve and add attributes via its children. - case other => - (exprs.map(resolveExpressionByPlanOutput(_, other)), other) - } - } - } - private object MergeResolvePolicy extends Enumeration { val BOTH, SOURCE, TARGET = Value } @@ -1916,16 +1835,6 @@ class Analyzer(override val catalogManager: CatalogManager) resolved } - // This method is used to trim groupByExpressions/selectedGroupByExpressions's top-level - // GetStructField Alias. Since these expression are not NamedExpression originally, - // we are safe to trim top-level GetStructField Alias. - def trimTopLevelGetStructFieldAlias(e: Expression): Expression = { - e match { - case Alias(s: GetStructField, _) => s - case other => other - } - } - // Expand the star expression using the input plan first. If failed, try resolve // the star expression using the outer query plan and wrap the resolved attributes // in outer references. Otherwise throw the original exception. @@ -2037,277 +1946,6 @@ class Analyzer(override val catalogManager: CatalogManager) exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) } - // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id - private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( - (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), - (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), - (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), - ("user", () => CurrentUser(), toPrettySQL), - (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) - ) - - /** - * Literal functions do not require the user to specify braces when calling them - * When an attributes is not resolvable, we try to resolve it as a literal function. - */ - private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { - if (nameParts.length != 1) return None - val name = nameParts.head - literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { - case (_, getFuncExpr, getAliasName) => - val funcExpr = getFuncExpr() - Alias(funcExpr, getAliasName(funcExpr))() - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by - * traversing the input expression in top-down manner. It must be top-down because we need to - * skip over unbound lambda function expression. The lambda expressions are resolved in a - * different place [[ResolveLambdaVariables]]. - * - * Example : - * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" - * - * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. - */ - private def resolveExpression( - expr: Expression, - resolveColumnByName: Seq[String] => Option[Expression], - getAttrCandidates: () => Seq[Attribute], - throws: Boolean, - allowOuter: Boolean): Expression = { - def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) { - if (e.resolved) return e - val resolved = e match { - case f: LambdaFunction if !f.bound => f - - case GetColumnByOrdinal(ordinal, _) => - val attrCandidates = getAttrCandidates() - assert(ordinal >= 0 && ordinal < attrCandidates.length) - attrCandidates(ordinal) - - case GetViewColumnByNameAndOrdinal( - viewName, colName, ordinal, expectedNumCandidates, viewDDL) => - val attrCandidates = getAttrCandidates() - val matched = attrCandidates.filter(a => resolver(a.name, colName)) - if (matched.length != expectedNumCandidates) { - throw QueryCompilationErrors.incompatibleViewSchemaChangeError( - viewName, colName, expectedNumCandidates, matched, viewDDL) - } - matched(ordinal) - - case u @ UnresolvedAttribute(nameParts) => - val result = withPosition(u) { - resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { - // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level, - // as we should resolve `UnresolvedAttribute` to a named expression. The caller side - // can trim the top-level alias if it's safe to do so. Since we will call - // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. - case Alias(child, _) if !isTopLevel => child - case other => other - }.getOrElse(u) - } - logDebug(s"Resolving $u to $result") - result - - // Re-resolves `TempResolvedColumn` if it has tried to be resolved with Aggregate - // but failed. If we still can't resolve it, we should keep it as `TempResolvedColumn`, - // so that it won't become a fresh `TempResolvedColumn` again. - case t: TempResolvedColumn if t.hasTried => withPosition(t) { - innerResolve(UnresolvedAttribute(t.nameParts), isTopLevel) match { - case _: UnresolvedAttribute => t - case other => other - } - } - - case u @ UnresolvedExtractValue(child, fieldName) => - val newChild = innerResolve(child, isTopLevel = false) - if (newChild.resolved) { - ExtractValue(newChild, fieldName, resolver) - } else { - u.copy(child = newChild) - } - - case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) - } - resolved.copyTagsFrom(e) - resolved - } - - try { - val resolved = innerResolve(expr, isTopLevel = true) - if (allowOuter) resolveOuterRef(resolved) else resolved - } catch { - case ae: AnalysisException if !throws => - logDebug(ae.getMessage) - expr - } - } - - // Resolves `UnresolvedAttribute` to `OuterReference`. - private def resolveOuterRef(e: Expression): Expression = { - val outerPlan = AnalysisContext.get.outerPlan - if (outerPlan.isEmpty) return e - - def resolve(nameParts: Seq[String]): Option[Expression] = try { - outerPlan.get match { - // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions. - // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will - // push them down to Aggregate later. This is similar to what we do in `resolveColumns`. - case u @ UnresolvedHaving(_, agg: Aggregate) => - agg.resolveChildren(nameParts, resolver).orElse(u.resolveChildren(nameParts, resolver)) - .map(wrapOuterReference) - case other => - other.resolveChildren(nameParts, resolver).map(wrapOuterReference) - } - } catch { - case ae: AnalysisException => - logDebug(ae.getMessage) - None - } - - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) { - case u: UnresolvedAttribute => - resolve(u.nameParts).getOrElse(u) - // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with - // Aggregate but failed. - case t: TempResolvedColumn if t.hasTried => - resolve(t.nameParts).getOrElse(t) - } - } - - // Resolves `UnresolvedAttribute` to `TempResolvedColumn` via `plan.child.output` if plan is an - // `Aggregate`. If `TempResolvedColumn` doesn't end up as aggregate function input or grouping - // column, we will undo the column resolution later to avoid confusing error message. E,g,, if - // a table `t` has columns `c1` and `c2`, for query `SELECT ... FROM t GROUP BY c1 HAVING c2 = 0`, - // even though we can resolve column `c2` here, we should undo it and fail with - // "Column c2 not found". - private def resolveColWithAgg(e: Expression, plan: LogicalPlan): Expression = plan match { - case agg: Aggregate => - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE)) { - case u: UnresolvedAttribute => - try { - agg.child.resolve(u.nameParts, resolver).map({ - case a: Alias => TempResolvedColumn(a.child, u.nameParts) - case o => TempResolvedColumn(o, u.nameParts) - }).getOrElse(u) - } catch { - case ae: AnalysisException => - logDebug(ae.getMessage) - u - } - } - case _ => e - } - - private def resolveLateralColumnAlias(selectList: Seq[Expression]): Seq[Expression] = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) return selectList - - // A mapping from lower-cased alias name to either the Alias itself, or the count of aliases - // that have the same lower-cased name. If the count is larger than 1, we won't use it to - // resolve lateral column aliases. - val aliasMap = mutable.HashMap.empty[String, Either[Alias, Int]] - - def resolve(e: Expression): Expression = { - e.transformWithPruning( - _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, LATERAL_COLUMN_ALIAS_REFERENCE)) { - case u: UnresolvedAttribute => - // Lateral column alias does not have qualifiers. We always use the first name part to - // look up lateral column aliases. - val lowerCasedName = u.nameParts.head.toLowerCase(Locale.ROOT) - aliasMap.get(lowerCasedName).map { - case scala.util.Left(alias) => - if (alias.resolved) { - val resolvedAttr = resolveExpressionByPlanOutput( - u, LocalRelation(Seq(alias.toAttribute)), throws = true - ).asInstanceOf[NamedExpression] - assert(resolvedAttr.resolved) - LateralColumnAliasReference(resolvedAttr, u.nameParts, alias.toAttribute) - } else { - // Still returns a `LateralColumnAliasReference` even if the lateral column alias - // is not resolved yet. This is to make sure we won't mistakenly resolve it to - // outer references. - LateralColumnAliasReference(u, u.nameParts, alias.toAttribute) - } - case scala.util.Right(count) => - throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, count) - }.getOrElse(u) - - case LateralColumnAliasReference(u: UnresolvedAttribute, _, _) => - resolve(u) - } - } - - selectList.map { - case a: Alias => - val result = resolve(a) - val lowerCasedName = a.name.toLowerCase(Locale.ROOT) - aliasMap.get(lowerCasedName) match { - case Some(scala.util.Left(_)) => - aliasMap(lowerCasedName) = scala.util.Right(2) - case Some(scala.util.Right(count)) => - aliasMap(lowerCasedName) = scala.util.Right(count + 1) - case None => - aliasMap += lowerCasedName -> scala.util.Left(a) - } - result - case other => resolve(other) - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's output attributes. In order to resolve the nested fields correctly, this function - * makes use of `throws` parameter to control when to raise an AnalysisException. - * - * Example : - * SELECT * FROM t ORDER BY a.b - * - * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` - * if there is no such nested field named "b". We should not fail and wait for other rules to - * resolve it if possible. - */ - def resolveExpressionByPlanOutput( - expr: Expression, - plan: LogicalPlan, - throws: Boolean = false, - allowOuter: Boolean = false): Expression = { - resolveExpression( - expr, - resolveColumnByName = nameParts => { - plan.resolve(nameParts, resolver) - }, - getAttrCandidates = () => plan.output, - throws = throws, - allowOuter = allowOuter) - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's children output attributes. - * - * @param e The expression need to be resolved. - * @param q The LogicalPlan whose children are used to resolve expression's attribute. - * @return resolved Expression. - */ - def resolveExpressionByPlanChildren( - e: Expression, - q: LogicalPlan, - allowOuter: Boolean = false): Expression = { - resolveExpression( - e, - resolveColumnByName = nameParts => { - q.resolveChildren(nameParts, resolver) - }, - getAttrCandidates = () => { - assert(q.children.length == 1) - q.children.head.output - }, - throws = true, - allowOuter = allowOuter) - } - /** * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by * clauses. This rule is to convert ordinal positions to the corresponding expressions in the @@ -2377,36 +2015,6 @@ class Analyzer(override val catalogManager: CatalogManager) } } - /** - * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses. - * This rule is expected to run after [[ResolveReferences]] applied. - */ - object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { - - // This is a strict check though, we put this to apply the rule only if the expression is not - // resolvable by child. - private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean = { - !child.output.exists(a => resolver(a.name, attrName)) - } - - private def mayResolveAttrByAggregateExprs( - exprs: Seq[Expression], aggs: Seq[NamedExpression], child: LogicalPlan): Seq[Expression] = { - exprs.map { _.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { - case u: UnresolvedAttribute if notResolvableByChild(u.name, child) => - aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) - }} - } - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - // mayResolveAttrByAggregateExprs requires the TreePattern UNRESOLVED_ATTRIBUTE. - _.containsAllPatterns(AGGREGATE, UNRESOLVED_ATTRIBUTE), ruleId) { - case agg @ Aggregate(groups, aggs, child) - if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && - groups.exists(!_.resolved) => - agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child)) - } - } - /** * Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bc7b031a73820..c66105d1715d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -238,7 +238,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB // Fail if we still have an unresolved all in group by. This needs to run before the // general unresolved check below to throw a more tailored error message. - ResolveGroupByAll.checkAnalysis(operator) + ResolveReferencesInAggregate.checkUnresolvedGroupByAll(operator) getAllExpressions(operator).foreach(_.foreachUp { case a: Attribute if !a.resolved => @@ -762,8 +762,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB private def getAllExpressions(plan: LogicalPlan): Seq[Expression] = { plan match { - // `groupingExpressions` may rely on `aggregateExpressions`, due to the GROUP BY alias - // feature. We should check errors in `aggregateExpressions` first. + // We only resolve `groupingExpressions` if `aggregateExpressions` is resolved first (See + // `ResolveReferencesInAggregate`). We should check errors in `aggregateExpressions` first. case a: Aggregate => a.aggregateExpressions ++ a.groupingExpressions case _ => plan.expressions } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala new file mode 100644 index 0000000000000..9ac64cf4658d0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.util.Locale + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SubExprUtils.wrapOuterReference +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin +import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf + +trait ColumnResolutionHelper extends Logging { + + def conf: SQLConf + + /** + * This method tries to resolve expressions and find missing attributes recursively. + * Specifically, when the expressions used in `Sort` or `Filter` contain unresolved attributes + * or resolved attributes which are missing from child output. This method tries to find the + * missing attributes and add them into the projection. + */ + protected def resolveExprsAndAddMissingAttrs( + exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { + // Missing attributes can be unresolved attributes or resolved attributes which are not in + // the output attributes of the plan. + if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { + (exprs, plan) + } else { + plan match { + case p: Project => + // Resolving expressions against current plan. + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, p)) + // Recursively resolving expressions on the child of current plan. + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) + // If some attributes used by expressions are resolvable only on the rewritten child + // plan, we need to add them into original projection. + val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) + (newExprs, Project(p.projectList ++ missingAttrs, newChild)) + + case a @ Aggregate(groupExprs, aggExprs, child) => + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, a)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) + val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) + if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { + // All the missing attributes are grouping expressions, valid case. + (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) + } else { + // Need to add non-grouping attributes, invalid case. + (exprs, a) + } + + case g: Generate => + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, g)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) + (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild)) + + // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes + // via its children. + case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, u)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) + (newExprs, u.withNewChildren(Seq(newChild))) + + // For other operators, we can't recursively resolve and add attributes via its children. + case other => + (exprs.map(resolveExpressionByPlanOutput(_, other)), other) + } + } + } + + // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id + private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( + (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), + (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), + (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), + ("user", () => CurrentUser(), toPrettySQL), + (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) + ) + + /** + * Literal functions do not require the user to specify braces when calling them + * When an attributes is not resolvable, we try to resolve it as a literal function. + */ + private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { + if (nameParts.length != 1) return None + val name = nameParts.head + literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { + case (_, getFuncExpr, getAliasName) => + val funcExpr = getFuncExpr() + Alias(funcExpr, getAliasName(funcExpr))() + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by + * traversing the input expression in top-down manner. It must be top-down because we need to + * skip over unbound lambda function expression. The lambda expressions are resolved in a + * different place [[ResolveLambdaVariables]]. + * + * Example : + * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" + * + * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. + */ + private def resolveExpression( + expr: Expression, + resolveColumnByName: Seq[String] => Option[Expression], + getAttrCandidates: () => Seq[Attribute], + throws: Boolean, + allowOuter: Boolean): Expression = { + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) { + if (e.resolved) return e + val resolved = e match { + case f: LambdaFunction if !f.bound => f + + case GetColumnByOrdinal(ordinal, _) => + val attrCandidates = getAttrCandidates() + assert(ordinal >= 0 && ordinal < attrCandidates.length) + attrCandidates(ordinal) + + case GetViewColumnByNameAndOrdinal( + viewName, colName, ordinal, expectedNumCandidates, viewDDL) => + val attrCandidates = getAttrCandidates() + val matched = attrCandidates.filter(a => conf.resolver(a.name, colName)) + if (matched.length != expectedNumCandidates) { + throw QueryCompilationErrors.incompatibleViewSchemaChangeError( + viewName, colName, expectedNumCandidates, matched, viewDDL) + } + matched(ordinal) + + case u @ UnresolvedAttribute(nameParts) => + val result = withPosition(u) { + resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { + // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level, + // as we should resolve `UnresolvedAttribute` to a named expression. The caller side + // can trim the top-level alias if it's safe to do so. Since we will call + // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. + case Alias(child, _) if !isTopLevel => child + case other => other + }.getOrElse(u) + } + logDebug(s"Resolving $u to $result") + result + + // Re-resolves `TempResolvedColumn` if it has tried to be resolved with Aggregate + // but failed. If we still can't resolve it, we should keep it as `TempResolvedColumn`, + // so that it won't become a fresh `TempResolvedColumn` again. + case t: TempResolvedColumn if t.hasTried => withPosition(t) { + innerResolve(UnresolvedAttribute(t.nameParts), isTopLevel) match { + case _: UnresolvedAttribute => t + case other => other + } + } + + case u @ UnresolvedExtractValue(child, fieldName) => + val newChild = innerResolve(child, isTopLevel = false) + if (newChild.resolved) { + ExtractValue(newChild, fieldName, conf.resolver) + } else { + u.copy(child = newChild) + } + + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) + } + resolved.copyTagsFrom(e) + resolved + } + + try { + val resolved = innerResolve(expr, isTopLevel = true) + if (allowOuter) resolveOuterRef(resolved) else resolved + } catch { + case ae: AnalysisException if !throws => + logDebug(ae.getMessage) + expr + } + } + + // Resolves `UnresolvedAttribute` to `OuterReference`. + protected def resolveOuterRef(e: Expression): Expression = { + val outerPlan = AnalysisContext.get.outerPlan + if (outerPlan.isEmpty) return e + + def resolve(nameParts: Seq[String]): Option[Expression] = try { + outerPlan.get match { + // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions. + // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will + // push them down to Aggregate later. This is similar to what we do in `resolveColumns`. + case u @ UnresolvedHaving(_, agg: Aggregate) => + agg.resolveChildren(nameParts, conf.resolver) + .orElse(u.resolveChildren(nameParts, conf.resolver)) + .map(wrapOuterReference) + case other => + other.resolveChildren(nameParts, conf.resolver).map(wrapOuterReference) + } + } catch { + case ae: AnalysisException => + logDebug(ae.getMessage) + None + } + + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) { + case u: UnresolvedAttribute => + resolve(u.nameParts).getOrElse(u) + // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with + // Aggregate but failed. + case t: TempResolvedColumn if t.hasTried => + resolve(t.nameParts).getOrElse(t) + } + } + + // Resolves `UnresolvedAttribute` to `TempResolvedColumn` via `plan.child.output` if plan is an + // `Aggregate`. If `TempResolvedColumn` doesn't end up as aggregate function input or grouping + // column, we will undo the column resolution later to avoid confusing error message. E,g,, if + // a table `t` has columns `c1` and `c2`, for query `SELECT ... FROM t GROUP BY c1 HAVING c2 = 0`, + // even though we can resolve column `c2` here, we should undo it and fail with + // "Column c2 not found". + protected def resolveColWithAgg(e: Expression, plan: LogicalPlan): Expression = plan match { + case agg: Aggregate => + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE)) { + case u: UnresolvedAttribute => + try { + agg.child.resolve(u.nameParts, conf.resolver).map({ + case a: Alias => TempResolvedColumn(a.child, u.nameParts) + case o => TempResolvedColumn(o, u.nameParts) + }).getOrElse(u) + } catch { + case ae: AnalysisException => + logDebug(ae.getMessage) + u + } + } + case _ => e + } + + protected def resolveLateralColumnAlias(selectList: Seq[Expression]): Seq[Expression] = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) return selectList + + // A mapping from lower-cased alias name to either the Alias itself, or the count of aliases + // that have the same lower-cased name. If the count is larger than 1, we won't use it to + // resolve lateral column aliases. + val aliasMap = mutable.HashMap.empty[String, Either[Alias, Int]] + + def resolve(e: Expression): Expression = { + e.transformWithPruning( + _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, LATERAL_COLUMN_ALIAS_REFERENCE)) { + case u: UnresolvedAttribute => + // Lateral column alias does not have qualifiers. We always use the first name part to + // look up lateral column aliases. + val lowerCasedName = u.nameParts.head.toLowerCase(Locale.ROOT) + aliasMap.get(lowerCasedName).map { + case scala.util.Left(alias) => + if (alias.resolved) { + val resolvedAttr = resolveExpressionByPlanOutput( + u, LocalRelation(Seq(alias.toAttribute)), throws = true + ).asInstanceOf[NamedExpression] + assert(resolvedAttr.resolved) + LateralColumnAliasReference(resolvedAttr, u.nameParts, alias.toAttribute) + } else { + // Still returns a `LateralColumnAliasReference` even if the lateral column alias + // is not resolved yet. This is to make sure we won't mistakenly resolve it to + // outer references. + LateralColumnAliasReference(u, u.nameParts, alias.toAttribute) + } + case scala.util.Right(count) => + throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, count) + }.getOrElse(u) + + case LateralColumnAliasReference(u: UnresolvedAttribute, _, _) => + resolve(u) + } + } + + selectList.map { + case a: Alias => + val result = resolve(a) + val lowerCasedName = a.name.toLowerCase(Locale.ROOT) + aliasMap.get(lowerCasedName) match { + case Some(scala.util.Left(_)) => + aliasMap(lowerCasedName) = scala.util.Right(2) + case Some(scala.util.Right(count)) => + aliasMap(lowerCasedName) = scala.util.Right(count + 1) + case None => + aliasMap += lowerCasedName -> scala.util.Left(a) + } + result + case other => resolve(other) + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's output attributes. In order to resolve the nested fields correctly, this function + * makes use of `throws` parameter to control when to raise an AnalysisException. + * + * Example : + * SELECT * FROM t ORDER BY a.b + * + * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` + * if there is no such nested field named "b". We should not fail and wait for other rules to + * resolve it if possible. + */ + def resolveExpressionByPlanOutput( + expr: Expression, + plan: LogicalPlan, + throws: Boolean = false, + allowOuter: Boolean = false): Expression = { + resolveExpression( + expr, + resolveColumnByName = nameParts => { + plan.resolve(nameParts, conf.resolver) + }, + getAttrCandidates = () => plan.output, + throws = throws, + allowOuter = allowOuter) + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's children output attributes. + * + * @param e The expression need to be resolved. + * @param q The LogicalPlan whose children are used to resolve expression's attribute. + * @return resolved Expression. + */ + def resolveExpressionByPlanChildren( + e: Expression, + q: LogicalPlan, + allowOuter: Boolean = false): Expression = { + resolveExpression( + e, + resolveColumnByName = nameParts => { + q.resolveChildren(nameParts, conf.resolver) + }, + getAttrCandidates = () => { + assert(q.children.length == 1) + q.children.head.output + }, + throws = true, + allowOuter = allowOuter) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala deleted file mode 100644 index 8c6ba20cd1af9..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, UNRESOLVED_ATTRIBUTE} - -/** - * Resolve "group by all" in the following SQL pattern: - * `select col1, col2, agg_expr(...) from table group by all`. - * - * The all is expanded to include all non-aggregate columns in the select clause. - */ -object ResolveGroupByAll extends Rule[LogicalPlan] { - - val ALL = "ALL" - - /** - * Returns true iff this is a GROUP BY ALL aggregate. i.e. an Aggregate expression that has - * a single unresolved all grouping expression. - */ - private def matchToken(a: Aggregate): Boolean = { - if (a.groupingExpressions.size != 1) { - return false - } - a.groupingExpressions.head match { - case a: UnresolvedAttribute => a.equalsIgnoreCase(ALL) - case _ => false - } - } - - /** - * Returns all the grouping expressions inferred from a GROUP BY ALL aggregate. - * The result is optional. If Spark fails to infer the grouping columns, it is None. - * Otherwise, it contains all the non-aggregate expressions from the project list of the input - * Aggregate. - */ - private def getGroupingExpressions(a: Aggregate): Option[Seq[Expression]] = { - val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate)) - // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or - // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will - // not automatically infer the grouping column to be "i". - if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) { - None - } else { - Some(groupingExprs) - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - _.containsAllPatterns(UNRESOLVED_ATTRIBUTE, AGGREGATE), ruleId) { - case a: Aggregate - if a.child.resolved && a.aggregateExpressions.forall(_.resolved) && matchToken(a) => - // Only makes sense to do the rewrite once all the aggregate expressions have been resolved. - // Otherwise, we might incorrectly pull an actual aggregate expression over to the grouping - // expression list (because we don't know they would be aggregate expressions until resolved). - val groupingExprs = getGroupingExpressions(a) - - if (groupingExprs.isEmpty) { - // Don't replace the ALL when we fail to infer the grouping columns. We will eventually - // tell the user in checkAnalysis that we cannot resolve the all in group by. - a - } else { - // This is a valid GROUP BY ALL aggregate. - a.copy(groupingExpressions = groupingExprs.get) - } - } - - /** - * Returns true if the expression includes an Attribute outside the aggregate expression part. - * For example: - * "i" -> true - * "i + 2" -> true - * "i + sum(j)" -> true - * "sum(j)" -> false - * "sum(j) / 2" -> false - */ - private def containsAttribute(expr: Expression): Boolean = expr match { - case _ if AggregateExpression.isAggregate(expr) => - // Don't recurse into AggregateExpressions - false - case _: Attribute => - true - case e => - e.children.exists(containsAttribute) - } - - /** - * A check to be used in [[CheckAnalysis]] to see if we have any unresolved group by at the - * end of analysis, so we can tell users that we fail to infer the grouping columns. - */ - def checkAnalysis(operator: LogicalPlan): Unit = operator match { - case a: Aggregate if a.aggregateExpressions.forall(_.resolved) && matchToken(a) => - if (getGroupingExpressions(a).isEmpty) { - operator.failAnalysis( - errorClass = "UNRESOLVED_ALL_IN_GROUP_BY", - messageParameters = Map.empty) - } - case _ => - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveOrderByAll.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveOrderByAll.scala deleted file mode 100644 index 7cf584dadcf34..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveOrderByAll.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis - -import org.apache.spark.sql.catalyst.expressions.SortOrder -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{SORT, UNRESOLVED_ATTRIBUTE} - -/** - * Resolve "order by all" in the following SQL pattern: - * `select col1, col2 from table order by all`. - * - * It orders the query result by all columns, from left to right. The query above becomes: - * - * `select col1, col2 from table order by col1, col2` - * - * This should also support specifying asc/desc, and nulls first/last. - */ -object ResolveOrderByAll extends Rule[LogicalPlan] { - - val ALL = "ALL" - - /** - * An extractor to pull out the SortOrder field in the ORDER BY ALL clause. We pull out that - * SortOrder object so we can pass its direction and null ordering. - */ - object OrderByAll { - def unapply(s: Sort): Option[SortOrder] = { - // This only applies to global ordering. - if (!s.global) { - return None - } - // Don't do this if we have more than one order field. That means it's not order by all. - if (s.order.size != 1) { - return None - } - // Don't do this if there's a child field called ALL. That should take precedence. - if (s.child.output.exists(_.name.toUpperCase() == ALL)) { - return None - } - - s.order.find { so => - so.child match { - case a: UnresolvedAttribute => a.name.toUpperCase() == ALL - case _ => false - } - } - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - _.containsAllPatterns(UNRESOLVED_ATTRIBUTE, SORT), ruleId) { - // This only makes sense if the child is resolved. - case s: Sort if s.child.resolved => - s match { - case OrderByAll(sortOrder) => - // Replace a single order by all with N fields, where N = child's output, while - // retaining the same asc/desc and nulls ordering. - val order = s.child.output.map(a => sortOrder.copy(child = a)) - s.copy(order = order) - case _ => - s - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala new file mode 100644 index 0000000000000..4af2ecc91ab55 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendColumns, LogicalPlan} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_ATTRIBUTE} + +/** + * A virtual rule to resolve [[UnresolvedAttribute]] in [[Aggregate]]. It's only used by the real + * rule `ResolveReferences`. The column resolution order for [[Aggregate]] is: + * 1. Resolves the columns to [[AttributeReference]] with the output of the child plan. This + * includes metadata columns as well. + * 2. Resolves the columns to a literal function which is allowed to be invoked without braces, e.g. + * `SELECT col, current_date FROM t`. + * 3. If aggregate expressions are all resolved, resolve GROUP BY alias and GROUP BY ALL. + * 3.1. If the grouping expressions contain an unresolved column whose name matches an alias in the + * SELECT list, resolves that unresolved column to the alias. This is to support SQL pattern + * like `SELECT a + b AS c, max(col) FROM t GROUP BY c`. + * 3.2. If the grouping expressions only have one single unresolved column named 'ALL', expanded it + * to include all non-aggregate columns in the SELECT list. This is to support SQL pattern like + * `SELECT col1, col2, agg_expr(...) FROM t GROUP BY ALL`. + * 4. Resolves the columns in aggregate expressions to [[LateralColumnAliasReference]] if + * it references the alias defined previously in the SELECT list. The rule + * `ResolveLateralColumnAliasReference` will further resolve [[LateralColumnAliasReference]] and + * rewrite the plan. This is to support SQL pattern like + * `SELECT col1 + 1 AS x, x + 1 AS y, y + 1 AS z FROM t`. + * 5. Resolves the columns to outer references with the outer plan if we are resolving subquery + * expressions. + */ +object ResolveReferencesInAggregate extends SQLConfHelper + with ColumnResolutionHelper with AliasHelper { + def apply(a: Aggregate): Aggregate = { + val planForResolve = a.child match { + // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of + // `AppendColumns`, because `AppendColumns`'s serializer might produce conflict attribute + // names leading to ambiguous references exception. + case appendColumns: AppendColumns => appendColumns + case _ => a + } + + val resolvedGroupExprsNoOuter = a.groupingExpressions + .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) + val resolvedAggExprsNoOuter = a.aggregateExpressions.map( + resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) + val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsNoOuter) + val resolvedAggExprsWithOuter = resolvedAggExprsWithLCA.map(resolveOuterRef) + .map(_.asInstanceOf[NamedExpression]) + // `groupingExpressions` may rely on `aggregateExpressions`, due to features like GROUP BY alias + // and GROUP BY ALL. We only do basic resolution for `groupingExpressions`, and will further + // resolve it after `aggregateExpressions` are all resolved. Note: the basic resolution is + // needed as `aggregateExpressions` may rely on `groupingExpressions` as well, for the session + // window feature. See the rule `SessionWindowing` for more details. + val resolvedGroupExprs = if (resolvedAggExprsWithOuter.forall(_.resolved)) { + val resolved = resolveGroupByAll( + resolvedAggExprsWithOuter, + resolveGroupByAlias(resolvedAggExprsWithOuter, resolvedGroupExprsNoOuter) + ).map(resolveOuterRef) + // TODO: currently we don't support LCA in `groupingExpressions` yet. + if (resolved.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE))) { + throw new AnalysisException( + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_GROUP_BY", + messageParameters = Map.empty) + } + resolved + } else { + // Do not resolve columns in grouping expressions to outer references here, as the aggregate + // expressions are not fully resolved yet and we still have chances to resolve GROUP BY + // alias/ALL in the next iteration. If aggregate expressions end up as unresolved, we don't + // need to resolve grouping expressions at all, as `CheckAnalysis` will report error for + // aggregate expressions first. + resolvedGroupExprsNoOuter + } + a.copy( + // The aliases in grouping expressions are useless and will be removed at the end of analysis + // by the rule `CleanupAliases`. However, some rules need to find the grouping expressions + // from aggregate expressions during analysis. If we don't remove alias here, then these rules + // can't find the grouping expressions via `semanticEquals` and the analysis will fail. + // Example rules: ResolveGroupingAnalytics (See SPARK-31670 for more details) and + // ResolveLateralColumnAliasReference. + groupingExpressions = resolvedGroupExprs.map(trimAliases), + aggregateExpressions = resolvedAggExprsWithOuter) + } + + private def resolveGroupByAlias( + selectList: Seq[NamedExpression], + groupExprs: Seq[Expression]): Seq[Expression] = { + assert(selectList.forall(_.resolved)) + if (conf.groupByAliases) { + groupExprs.map { g => + g.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { + case u: UnresolvedAttribute => + selectList.find(ne => conf.resolver(ne.name, u.name)).getOrElse(u) + } + } + } else { + groupExprs + } + } + + private def resolveGroupByAll( + selectList: Seq[NamedExpression], + groupExprs: Seq[Expression]): Seq[Expression] = { + assert(selectList.forall(_.resolved)) + if (isGroupByAll(groupExprs)) { + val expandedGroupExprs = expandGroupByAll(selectList) + if (expandedGroupExprs.isEmpty) { + // Don't replace the ALL when we fail to infer the grouping columns. We will eventually + // tell the user in checkAnalysis that we cannot resolve the all in group by. + groupExprs + } else { + // This is a valid GROUP BY ALL aggregate. + expandedGroupExprs.get + } + } else { + groupExprs + } + } + + /** + * Returns all the grouping expressions inferred from a GROUP BY ALL aggregate. + * The result is optional. If Spark fails to infer the grouping columns, it is None. + * Otherwise, it contains all the non-aggregate expressions from the project list of the input + * Aggregate. + */ + private def expandGroupByAll(selectList: Seq[NamedExpression]): Option[Seq[Expression]] = { + val groupingExprs = selectList.filter(!_.exists(AggregateExpression.isAggregate)) + // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or + // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will + // not automatically infer the grouping column to be "i". + if (groupingExprs.isEmpty && selectList.exists(containsAttribute)) { + None + } else { + Some(groupingExprs) + } + } + + /** + * Returns true iff this is a GROUP BY ALL: the grouping expressions only have a single column, + * which is an unresolved column named ALL. + */ + private def isGroupByAll(exprs: Seq[Expression]): Boolean = { + if (exprs.length != 1) return false + exprs.head match { + case a: UnresolvedAttribute => a.equalsIgnoreCase("ALL") + case _ => false + } + } + + /** + * Returns true if the expression includes an Attribute outside the aggregate expression part. + * For example: + * "i" -> true + * "i + 2" -> true + * "i + sum(j)" -> true + * "sum(j)" -> false + * "sum(j) / 2" -> false + */ + private def containsAttribute(expr: Expression): Boolean = expr match { + case _ if AggregateExpression.isAggregate(expr) => + // Don't recurse into AggregateExpressions + false + case _: Attribute => + true + case e => + e.children.exists(containsAttribute) + } + + /** + * A check to be used in [[CheckAnalysis]] to see if we have any unresolved group by at the + * end of analysis, so we can tell users that we fail to infer the grouping columns. + */ + def checkUnresolvedGroupByAll(operator: LogicalPlan): Unit = operator match { + case a: Aggregate if a.aggregateExpressions.forall(_.resolved) && + isGroupByAll(a.groupingExpressions) => + if (expandGroupByAll(a.aggregateExpressions).isEmpty) { + operator.failAnalysis( + errorClass = "UNRESOLVED_ALL_IN_GROUP_BY", + messageParameters = Map.empty) + } + case _ => + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala new file mode 100644 index 0000000000000..54044932d9e3b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort} + +/** + * A virtual rule to resolve [[UnresolvedAttribute]] in [[Sort]]. It's only used by the real + * rule `ResolveReferences`. The column resolution order for [[Sort]] is: + * 1. Resolves the column to [[AttributeReference]] with the output of the child plan. This + * includes metadata columns as well. + * 2. Resolves the column to a literal function which is allowed to be invoked without braces, e.g. + * `SELECT col, current_date FROM t`. + * 3. If the child plan is Aggregate, resolves the column to [[TempResolvedColumn]] with the output + * of Aggregate's child plan. This is to allow Sort to host grouping expressions and aggregate + * functions, which can be pushed down to the Aggregate later. For example, + * `SELECT max(a) FROM t GROUP BY b ORDER BY min(a)`. + * 4. Resolves the column to [[AttributeReference]] with the output of a descendant plan node. + * Spark will propagate the missing attributes from the descendant plan node to the Sort node. + * This is to allow users to ORDER BY columns that are not in the SELECT clause, which is + * widely supported in other SQL dialects. For example, `SELECT a FROM t ORDER BY b`. + * 5. If the order by expressions only have one single unresolved column named ALL, expanded it to + * include all columns in the SELECT list. This is to support SQL pattern like + * `SELECT col1, col2 FROM t ORDER BY ALL`. This should also support specifying asc/desc, and + * nulls first/last. + * 6. Resolves the column to outer references with the outer plan if we are resolving subquery + * expressions. + * + * Note, 3 and 4 are actually orthogonal. If the child plan is Aggregate, 4 can only resolve columns + * as the grouping columns, which is completely covered by 3. + */ +object ResolveReferencesInSort extends SQLConfHelper with ColumnResolutionHelper { + + def apply(s: Sort): LogicalPlan = { + val resolvedNoOuter = s.order.map(resolveExpressionByPlanOutput(_, s.child)) + val resolvedWithAgg = resolvedNoOuter.map(resolveColWithAgg(_, s.child)) + val (missingAttrResolved, newChild) = resolveExprsAndAddMissingAttrs(resolvedWithAgg, s.child) + val orderByAllResolved = resolveOrderByAll( + s.global, newChild, missingAttrResolved.map(_.asInstanceOf[SortOrder])) + val finalOrdering = orderByAllResolved.map(e => resolveOuterRef(e).asInstanceOf[SortOrder]) + if (s.child.output == newChild.output) { + s.copy(order = finalOrdering) + } else { + // Add missing attributes and then project them away. + val newSort = s.copy(order = finalOrdering, child = newChild) + Project(s.child.output, newSort) + } + } + + private def resolveOrderByAll( + globalSort: Boolean, + child: LogicalPlan, + orders: Seq[SortOrder]): Seq[SortOrder] = { + // This only applies to global ordering. + if (!globalSort) return orders + // Don't do this if we have more than one order field. That means it's not order by all. + if (orders.length != 1) return orders + + val order = orders.head + order.child match { + case a: UnresolvedAttribute if a.equalsIgnoreCase("ALL") => + // Replace a single order by all with N fields, where N = child's output, while + // retaining the same asc/desc and nulls ordering. + child.output.map(a => order.copy(child = a)) + case _ => orders + } + } +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/column-resolution-aggregate.sql b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-aggregate.sql new file mode 100644 index 0000000000000..4f879fc809d9f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-aggregate.sql @@ -0,0 +1,33 @@ +-- Tests covering column resolution priority in Aggregate. + +CREATE TEMPORARY VIEW v1 AS VALUES (1, 1, 1), (2, 2, 1) AS t(a, b, k); +CREATE TEMPORARY VIEW v2 AS VALUES (1, 1, 1), (2, 2, 1) AS t(x, y, all); + +-- Relation output columns have higher priority than lateral column alias. This query +-- should fail as `b` is not in GROUP BY. +SELECT max(a) AS b, b FROM v1 GROUP BY k; + +-- Lateral column alias has higher priority than outer reference. +SELECT a FROM v1 WHERE (12, 13) IN (SELECT max(x + 10) AS a, a + 1 FROM v2); + +-- Relation output columns have higher priority than GROUP BY alias. This query should +-- fail as `a` is not in GROUP BY. +SELECT a AS k FROM v1 GROUP BY k; + +-- Relation output columns have higher priority than GROUP BY ALL. This query should +-- fail as `x` is not in GROUP BY. +SELECT x FROM v2 GROUP BY all; + +-- GROUP BY alias has higher priority than GROUP BY ALL, this query fails as `b` is not in GROUP BY. +SELECT a AS all, b FROM v1 GROUP BY all; + +-- GROUP BY alias/ALL does not support lateral column alias. +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY k, col; +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY all; + +-- GROUP BY alias still works if it does not directly reference lateral column alias. +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY lca; + +-- GROUP BY ALL has higher priority than outer reference. This query should run as `a` and `b` are +-- in GROUP BY due to the GROUP BY ALL resolution. +SELECT * FROM v2 WHERE EXISTS (SELECT a, b FROM v1 GROUP BY all); diff --git a/sql/core/src/test/resources/sql-tests/inputs/column-resolution-sort.sql b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-sort.sql new file mode 100644 index 0000000000000..2c5b9f9e9dfc7 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-sort.sql @@ -0,0 +1,20 @@ +--SET spark.sql.leafNodeDefaultParallelism=1 +-- Tests covering column resolution priority in Sort. + +CREATE TEMPORARY VIEW v1 AS VALUES (1, 2, 2), (2, 1, 1) AS t(a, b, k); +CREATE TEMPORARY VIEW v2 AS VALUES (1, 2, 2), (2, 1, 1) AS t(a, b, all); + +-- Relation output columns have higher priority than missing reference. +-- Query will fail if we order by the column `v1.b`, as it's not in GROUP BY. +-- Actually results are [1, 2] as we order by `max(a) AS b`. +SELECT max(a) AS b FROM v1 GROUP BY k ORDER BY b; + +-- Missing reference has higher priority than ORDER BY ALL. +-- Results will be [1, 2] if we order by `max(a)`. +-- Actually results are [2, 1] as we order by the grouping column `v2.all`. +SELECT max(a) FROM v2 GROUP BY all ORDER BY all; + +-- ORDER BY ALL has higher priority than outer reference. +-- Results will be [1, 1] if we order by outer reference 'v2.all'. +-- Actually results are [2, 2] as we order by column `v1.b` +SELECT (SELECT b FROM v1 ORDER BY all LIMIT 1) FROM v2; diff --git a/sql/core/src/test/resources/sql-tests/results/column-resolution-aggregate.sql.out b/sql/core/src/test/resources/sql-tests/results/column-resolution-aggregate.sql.out new file mode 100644 index 0000000000000..e8ab766751c43 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/column-resolution-aggregate.sql.out @@ -0,0 +1,129 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE TEMPORARY VIEW v1 AS VALUES (1, 1, 1), (2, 2, 1) AS t(a, b, k) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TEMPORARY VIEW v2 AS VALUES (1, 1, 1), (2, 2, 1) AS t(x, y, all) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT max(a) AS b, b FROM v1 GROUP BY k +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_AGGREGATION", + "sqlState" : "42803", + "messageParameters" : { + "expression" : "\"b\"", + "expressionAnyValue" : "\"any_value(b)\"" + } +} + + +-- !query +SELECT a FROM v1 WHERE (12, 13) IN (SELECT max(x + 10) AS a, a + 1 FROM v2) +-- !query schema +struct +-- !query output +1 +2 + + +-- !query +SELECT a AS k FROM v1 GROUP BY k +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_AGGREGATION", + "sqlState" : "42803", + "messageParameters" : { + "expression" : "\"a\"", + "expressionAnyValue" : "\"any_value(a)\"" + } +} + + +-- !query +SELECT x FROM v2 GROUP BY all +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_AGGREGATION", + "sqlState" : "42803", + "messageParameters" : { + "expression" : "\"x\"", + "expressionAnyValue" : "\"any_value(x)\"" + } +} + + +-- !query +SELECT a AS all, b FROM v1 GROUP BY all +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_AGGREGATION", + "sqlState" : "42803", + "messageParameters" : { + "expression" : "\"b\"", + "expressionAnyValue" : "\"any_value(b)\"" + } +} + + +-- !query +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY k, col +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_GROUP_BY", + "sqlState" : "0A000" +} + + +-- !query +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY all +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_GROUP_BY", + "sqlState" : "0A000" +} + + +-- !query +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY lca +-- !query schema +struct +-- !query output +1 2 + + +-- !query +SELECT * FROM v2 WHERE EXISTS (SELECT a, b FROM v1 GROUP BY all) +-- !query schema +struct +-- !query output +1 1 1 +2 2 1 diff --git a/sql/core/src/test/resources/sql-tests/results/column-resolution-sort.sql.out b/sql/core/src/test/resources/sql-tests/results/column-resolution-sort.sql.out new file mode 100644 index 0000000000000..67323d734c909 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/column-resolution-sort.sql.out @@ -0,0 +1,42 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE TEMPORARY VIEW v1 AS VALUES (1, 2, 2), (2, 1, 1) AS t(a, b, k) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TEMPORARY VIEW v2 AS VALUES (1, 2, 2), (2, 1, 1) AS t(a, b, all) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT max(a) AS b FROM v1 GROUP BY k ORDER BY b +-- !query schema +struct +-- !query output +1 +2 + + +-- !query +SELECT max(a) FROM v2 GROUP BY all ORDER BY all +-- !query schema +struct +-- !query output +2 +1 + + +-- !query +SELECT (SELECT b FROM v1 ORDER BY all LIMIT 1) FROM v2 +-- !query schema +struct +-- !query output +1 +1