From f78719f1f42fbaa1b82c81690c75dfb252d19afe Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 7 Sep 2022 23:44:54 +0800
Subject: [PATCH] [SPARK-40149][SQL][3.2] Propagate metadata columns through Project

Backport of https://github.com/apache/spark/pull/37758 to 3.2.

This PR fixes a regression caused by https://github.com/apache/spark/pull/32017 .

In https://github.com/apache/spark/pull/32017 , we tried to be more conservative and decided not to propagate metadata columns in certain operators, including `Project`. However, that decision was made with only the SQL API in mind, not the DataFrame API. In fact, it is very common to chain `Project` operators in DataFrame code, e.g. `df.withColumn(...).withColumn(...)...`, and it is very inconvenient if metadata columns are not propagated through `Project`.

This PR makes 2 changes:
1. `Project` should propagate metadata columns.
2. `SubqueryAlias` should only propagate metadata columns if the child is a leaf node or also a `SubqueryAlias`.

The second change is needed to still forbid weird queries like `SELECT m FROM (SELECT a FROM t)`, which was the main motivation of https://github.com/apache/spark/pull/32017 .

Propagating metadata columns exposes a problem introduced by https://github.com/apache/spark/pull/31666 : the natural join metadata columns may confuse the analyzer and lead to a wrong analyzed plan. For example, given `SELECT t1.value FROM t1 LEFT JOIN t2 USING (key) ORDER BY key`, how should we resolve `ORDER BY key`? It should be resolved to `t1.key`, which is in the output of the left join, via the rule `ResolveMissingReferences`. However, if `Project` propagates metadata columns, `ORDER BY key` would be resolved to `t2.key` instead.

To solve this problem, this PR only allows qualified access to the metadata columns of a natural join. This is not a breaking change: previously people could only access natural join metadata columns with a qualifier anyway, and only in the `Project` right after the `Join`. It actually enables more use cases, as people can now access natural join metadata columns in ORDER BY. I've added a test for it.

This fixes a regression.

For the SQL API, there is no change, as a `SubqueryAlias` always comes with a `Project` or `Aggregate`, so we still don't propagate metadata columns through a SELECT group.

For the DataFrame API, the behavior becomes more lenient. The only breaking case is an operator that can propagate metadata columns followed by a `SubqueryAlias`, e.g. `df.filter(...).as("t").select("t.metadata_col")`. But this is a weird use case and I don't think we should have supported it in the first place.

Tested with new tests.

Closes #37818 from cloud-fan/backport.
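As an illustration (a sketch only, not part of the patch: it assumes a spark-shell session where a DSv2 table `testcat.t` exposing an `index` metadata column exists, plus the `t1`/`t2` tables from the example above), the patterns this change is meant to enable look like:

```scala
import spark.implicits._

// Chained Projects created by withColumn() no longer drop metadata columns,
// so the final select can still reach the source's `index` metadata column.
spark.table("testcat.t")
  .withColumn("doubled", $"id" * 2)
  .select("id", "doubled", "index")

// USING-join metadata columns are qualified-access-only: the unqualified `key`
// below still resolves to t1.key via ResolveMissingReferences, while t2.key
// stays reachable through an explicit qualifier, now even in ORDER BY.
spark.sql("SELECT t1.value FROM t1 LEFT JOIN t2 USING (key) ORDER BY key")
spark.sql("SELECT t1.value FROM t1 LEFT JOIN t2 USING (key) ORDER BY t2.key")
```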
Authored-by: Wenchen Fan
Signed-off-by: Wenchen Fan
(cherry picked from commit d566017de441beebfb62d9d9271defd4041ffdc4)
---
 .../sql/catalyst/analysis/Analyzer.scala      |   8 +-
 .../sql/catalyst/analysis/unresolved.scala    |   2 +-
 .../sql/catalyst/expressions/package.scala    |  13 +-
 .../plans/logical/basicLogicalOperators.scala |  13 +-
 .../spark/sql/catalyst/util/package.scala     |  15 +-
 .../resources/sql-tests/inputs/using-join.sql |   2 +
 .../sql-tests/results/using-join.sql.out      |  11 +
 .../sql/connector/DataSourceV2SQLSuite.scala  | 218 -----------------
 .../sql/connector/MetadataColumnSuite.scala   | 219 ++++++++++++++++++
 9 files changed, 263 insertions(+), 238 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ae66a6c7005ac..55e0fe307f52f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1048,9 +1048,11 @@ class Analyzer(override val catalogManager: CatalogManager)
     private def addMetadataCol(plan: LogicalPlan): LogicalPlan = plan match {
       case r: DataSourceV2Relation => r.withMetadataColumns()
       case p: Project =>
-        p.copy(
+        val newProj = p.copy(
           projectList = p.metadataOutput ++ p.projectList,
           child = addMetadataCol(p.child))
+        newProj.copyTagsFrom(p)
+        newProj
       case _ => plan.withNewChildren(plan.children.map(addMetadataCol))
     }
   }
@@ -3532,8 +3534,8 @@ class Analyzer(override val catalogManager: CatalogManager)
         val project = Project(projectList, Join(left, right, joinType, newCondition, hint))
         project.setTagValue(
           Project.hiddenOutputTag,
-          hiddenList.map(_.markAsSupportsQualifiedStar()) ++
-            project.child.metadataOutput.filter(_.supportsQualifiedStar))
+          hiddenList.map(_.markAsQualifiedAccessOnly()) ++
+            project.child.metadataOutput.filter(_.qualifiedAccessOnly))
         project
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 9db038dbf350b..cd02b03e2d00e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -386,7 +386,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
     if (target.isEmpty) return input.output

     // If there is a table specified, use hidden input attributes as well
-    val hiddenOutput = input.metadataOutput.filter(_.supportsQualifiedStar)
+    val hiddenOutput = input.metadataOutput.filter(_.qualifiedAccessOnly)
     val expandedAttributes = (hiddenOutput ++ input.output).filter(
       matchedQualifier(_, target.get, resolver))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 6a4fb099c8b78..7913f396120f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -23,6 +23,7 @@ import com.google.common.collect.Maps

 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
 import org.apache.spark.sql.types.{StructField, StructType}

 /**
@@ -265,7 +266,7 @@ package object expressions {
       case (Seq(), _) =>
         val name = nameParts.head
         val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT)))
-        (attributes, nameParts.tail)
+        (attributes.filterNot(_.qualifiedAccessOnly), nameParts.tail)
       case _ => matches
     }
   }
@@ -314,10 +315,12 @@ package object expressions {
     var i = nameParts.length - 1
     while (i >= 0 && candidates.isEmpty) {
       val name = nameParts(i)
-      candidates = collectMatches(
-        name,
-        nameParts.take(i),
-        direct.get(name.toLowerCase(Locale.ROOT)))
+      val attrsToLookup = if (i == 0) {
+        direct.get(name.toLowerCase(Locale.ROOT)).map(_.filterNot(_.qualifiedAccessOnly))
+      } else {
+        direct.get(name.toLowerCase(Locale.ROOT))
+      }
+      candidates = collectMatches(name, nameParts.take(i), attrsToLookup)
       if (candidates.nonEmpty) {
         nestedFields = nameParts.takeRight(nameParts.length - i - 1)
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 9b699a94d0be7..95d7578bea1aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -88,7 +88,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
     getAllValidConstraints(projectList)

   override def metadataOutput: Seq[Attribute] =
-    getTagValue(Project.hiddenOutputTag).getOrElse(Nil)
+    getTagValue(Project.hiddenOutputTag).getOrElse(child.metadataOutput)

   override protected def withNewChildInternal(newChild: LogicalPlan): Project =
     copy(child = newChild)
@@ -1307,9 +1307,14 @@ case class SubqueryAlias(
   }

   override def metadataOutput: Seq[Attribute] = {
-    val qualifierList = identifier.qualifier :+ alias
-    val nonHiddenMetadataOutput = child.metadataOutput.filter(!_.supportsQualifiedStar)
-    nonHiddenMetadataOutput.map(_.withQualifier(qualifierList))
+    // Propagate metadata columns from leaf nodes through a chain of `SubqueryAlias`.
+    if (child.isInstanceOf[LeafNode] || child.isInstanceOf[SubqueryAlias]) {
+      val qualifierList = identifier.qualifier :+ alias
+      val nonHiddenMetadataOutput = child.metadataOutput.filter(!_.qualifiedAccessOnly)
+      nonHiddenMetadataOutput.map(_.withQualifier(qualifierList))
+    } else {
+      Nil
+    }
   }

   override def maxRows: Option[Long] = child.maxRows
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index 33fe48d44dadb..d1a0aa52f6757 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -206,22 +206,23 @@ package object util extends Logging {
   implicit class MetadataColumnHelper(attr: Attribute) {

     /**
-     * If set, this metadata column is a candidate during qualified star expansions.
+     * If set, this metadata column can only be accessed with qualifiers, e.g. `qualifiers.col` or
+     * `qualifiers.*`. If not set, metadata columns cannot be accessed via star.
      */
-    val SUPPORTS_QUALIFIED_STAR = "__supports_qualified_star"
+    val QUALIFIED_ACCESS_ONLY = "__qualified_access_only"

     def isMetadataCol: Boolean = attr.metadata.contains(METADATA_COL_ATTR_KEY) &&
       attr.metadata.getBoolean(METADATA_COL_ATTR_KEY)

-    def supportsQualifiedStar: Boolean = attr.isMetadataCol &&
-      attr.metadata.contains(SUPPORTS_QUALIFIED_STAR) &&
-      attr.metadata.getBoolean(SUPPORTS_QUALIFIED_STAR)
+    def qualifiedAccessOnly: Boolean = attr.isMetadataCol &&
+      attr.metadata.contains(QUALIFIED_ACCESS_ONLY) &&
+      attr.metadata.getBoolean(QUALIFIED_ACCESS_ONLY)

-    def markAsSupportsQualifiedStar(): Attribute = attr.withMetadata(
+    def markAsQualifiedAccessOnly(): Attribute = attr.withMetadata(
       new MetadataBuilder()
         .withMetadata(attr.metadata)
         .putBoolean(METADATA_COL_ATTR_KEY, true)
-        .putBoolean(SUPPORTS_QUALIFIED_STAR, true)
+        .putBoolean(QUALIFIED_ACCESS_ONLY, true)
         .build()
     )
   }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/using-join.sql b/sql/core/src/test/resources/sql-tests/inputs/using-join.sql
index 336d19f0f2a3d..87390b388764f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/using-join.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/using-join.sql
@@ -19,6 +19,8 @@ SELECT nt1.*, nt2.* FROM nt1 left outer join nt2 using (k);

 SELECT nt1.k, nt2.k FROM nt1 left outer join nt2 using (k);

+SELECT nt1.k, nt2.k FROM nt1 left outer join nt2 using (k) ORDER BY nt2.k;
+
 SELECT k, nt1.k FROM nt1 left outer join nt2 using (k);

 SELECT k, nt2.k FROM nt1 left outer join nt2 using (k);
diff --git a/sql/core/src/test/resources/sql-tests/results/using-join.sql.out b/sql/core/src/test/resources/sql-tests/results/using-join.sql.out
index 1d2ae9d96ecad..db9ac1f10bb00 100644
--- a/sql/core/src/test/resources/sql-tests/results/using-join.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/using-join.sql.out
@@ -71,6 +71,17 @@ three NULL
 two two


+-- !query
+SELECT nt1.k, nt2.k FROM nt1 left outer join nt2 using (k) ORDER BY nt2.k
+-- !query schema
+struct<k:string,k:string>
+-- !query output
+three NULL
+one one
+one one
+two two
+
+
 -- !query
 SELECT k, nt1.k FROM nt1 left outer join nt2 using (k)
 -- !query schema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index e6eb2105763e8..db8cc71bb6e66 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -2524,100 +2524,6 @@ class DataSourceV2SQLSuite
     }
   }

-  test("SPARK-31255: Project a metadata column") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1")
-      val dfQuery = spark.table(t1).select("id", "data", "index", "_partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
-      }
-    }
-  }
-
-  test("SPARK-31255: Projects data column when metadata column has the same name") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (index bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, index), index)")
-      sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")
-
-      val sqlQuery = spark.sql(s"SELECT index, data, _partition FROM $t1")
-      val dfQuery = spark.table(t1).select("index", "data", "_partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
-      }
-    }
-  }
-
-  test("SPARK-31255: * expansion does not include metadata columns") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")
-
-      val sqlQuery = spark.sql(s"SELECT * FROM $t1")
-      val dfQuery = spark.table(t1)
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(3, "c"), Row(2, "b"), Row(1, "a")))
-      }
-    }
-  }
-
-  test("SPARK-31255: metadata column should only be produced when necessary") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-
-      val sqlQuery = spark.sql(s"SELECT * FROM $t1 WHERE index = 0")
-      val dfQuery = spark.table(t1).filter("index = 0")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        assert(query.schema.fieldNames.toSeq == Seq("id", "data"))
-      }
-    }
-  }
-
-  test("SPARK-34547: metadata columns are resolved last") {
-    val t1 = s"${catalogAndNamespace}tableOne"
-    val t2 = "t2"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-      withTempView(t2) {
-        sql(s"CREATE TEMPORARY VIEW $t2 AS SELECT * FROM " +
-          s"VALUES (1, -1), (2, -2), (3, -3) AS $t2(id, index)")
-
-        val sqlQuery = spark.sql(s"SELECT $t1.id, $t2.id, data, index, $t1.index, $t2.index FROM " +
-          s"$t1 JOIN $t2 WHERE $t1.id = $t2.id")
-        val t1Table = spark.table(t1)
-        val t2Table = spark.table(t2)
-        val dfQuery = t1Table.join(t2Table, t1Table.col("id") === t2Table.col("id"))
-          .select(s"$t1.id", s"$t2.id", "data", "index", s"$t1.index", s"$t2.index")
-
-        Seq(sqlQuery, dfQuery).foreach { query =>
-          checkAnswer(query,
-            Seq(
-              Row(1, 1, "a", -1, 0, -1),
-              Row(2, 2, "b", -2, 0, -2),
-              Row(3, 3, "c", -3, 0, -3)
-            )
-          )
-        }
-      }
-    }
-  }
-
   test("SPARK-33505: insert into partitioned table") {
     val t = "testpart.ns1.ns2.tbl"
     withTable(t) {
@@ -2702,27 +2608,6 @@ class DataSourceV2SQLSuite
     }
   }

-  test("SPARK-34555: Resolve DataFrame metadata column") {
-    val tbl = s"${catalogAndNamespace}table"
-    withTable(tbl) {
-      sql(s"CREATE TABLE $tbl (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-      val table = spark.table(tbl)
-      val dfQuery = table.select(
-        table.col("id"),
-        table.col("data"),
-        table.col("index"),
-        table.col("_partition")
-      )
-
-      checkAnswer(
-        dfQuery,
-        Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
-      )
-    }
-  }
-
   test("SPARK-34561: drop/add columns to a dataset of `DESCRIBE TABLE`") {
     val tbl = s"${catalogAndNamespace}tbl"
     withTable(tbl) {
@@ -2785,109 +2670,6 @@ class DataSourceV2SQLSuite
     }
   }

-  test("SPARK-34923: do not propagate metadata columns through Project") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      assertThrows[AnalysisException] {
-        sql(s"SELECT index, _partition from (SELECT id, data FROM $t1)")
-      }
-      assertThrows[AnalysisException] {
-        spark.table(t1).select("id", "data").select("index", "_partition")
-      }
-    }
-  }
-
-  test("SPARK-34923: do not propagate metadata columns through View") {
-    val t1 = s"${catalogAndNamespace}table"
-    val view = "view"
-
-    withTable(t1) {
-      withTempView(view) {
-        sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-          "PARTITIONED BY (bucket(4, id), id)")
-        sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-        sql(s"CACHE TABLE $view AS SELECT * FROM $t1")
-        assertThrows[AnalysisException] {
-          sql(s"SELECT index, _partition FROM $view")
-        }
-      }
-    }
-  }
-
-  test("SPARK-34923: propagate metadata columns through Filter") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 WHERE id > 1")
-      val dfQuery = spark.table(t1).where("id > 1").select("id", "data", "index", "_partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
-      }
-    }
-  }
-
-  test("SPARK-34923: propagate metadata columns through Sort") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 ORDER BY id")
-      val dfQuery = spark.table(t1).orderBy("id").select("id", "data", "index", "_partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
-      }
-    }
-  }
-
-  test("SPARK-34923: propagate metadata columns through RepartitionBy") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      val sqlQuery = spark.sql(
-        s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $t1")
-      val tbl = spark.table(t1)
-      val dfQuery = tbl.repartitionByRange(3, tbl.col("id"))
-        .select("id", "data", "index", "_partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
-      }
-    }
-  }
-
-  test("SPARK-34923: propagate metadata columns through SubqueryAlias") {
-    val t1 = s"${catalogAndNamespace}table"
-    val sbq = "sbq"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      val sqlQuery = spark.sql(
-        s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $t1 as $sbq")
-      val dfQuery = spark.table(t1).as(sbq).select(
-        s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
-      }
-    }
-  }
-
   test("OPTIMIZE is not supported for regular tables") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
new file mode 100644
index 0000000000000..95b9c4f72356a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.functions.struct
+
+class MetadataColumnSuite extends DatasourceV2SQLBase {
+  import testImplicits._
+
+  private val tbl = "testcat.t"
+
+  private def prepareTable(): Unit = {
+    sql(s"CREATE TABLE $tbl (id bigint, data string) PARTITIONED BY (bucket(4, id), id)")
+    sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+  }
+
+  test("SPARK-31255: Project a metadata column") {
+    withTable(tbl) {
+      prepareTable()
+      val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl")
+      val dfQuery = spark.table(tbl).select("id", "data", "index", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
+
+  test("SPARK-31255: Projects data column when metadata column has the same name") {
+    withTable(tbl) {
+      sql(s"CREATE TABLE $tbl (index bigint, data string) PARTITIONED BY (bucket(4, index), index)")
+      sql(s"INSERT INTO $tbl VALUES (3, 'c'), (2, 'b'), (1, 'a')")
+
+      val sqlQuery = sql(s"SELECT index, data, _partition FROM $tbl")
+      val dfQuery = spark.table(tbl).select("index", "data", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
+      }
+    }
+  }
+
+  test("SPARK-31255: * expansion does not include metadata columns") {
+    withTable(tbl) {
+      prepareTable()
+      val sqlQuery = sql(s"SELECT * FROM $tbl")
+      val dfQuery = spark.table(tbl)
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c")))
+      }
+    }
+  }
+
+  test("SPARK-31255: metadata column should only be produced when necessary") {
+    withTable(tbl) {
+      prepareTable()
+      val sqlQuery = sql(s"SELECT * FROM $tbl WHERE index = 0")
+      val dfQuery = spark.table(tbl).filter("index = 0")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        assert(query.schema.fieldNames.toSeq == Seq("id", "data"))
+      }
+    }
+  }
+
+  test("SPARK-34547: metadata columns are resolved last") {
+    withTable(tbl) {
+      prepareTable()
+      withTempView("v") {
+        sql(s"CREATE TEMPORARY VIEW v AS SELECT * FROM " +
+          s"VALUES (1, -1), (2, -2), (3, -3) AS v(id, index)")
+
+        val sqlQuery = sql(s"SELECT $tbl.id, v.id, data, index, $tbl.index, v.index " +
+          s"FROM $tbl JOIN v WHERE $tbl.id = v.id")
+        val tableDf = spark.table(tbl)
+        val viewDf = spark.table("v")
+        val dfQuery = tableDf.join(viewDf, tableDf.col("id") === viewDf.col("id"))
+          .select(s"$tbl.id", "v.id", "data", "index", s"$tbl.index", "v.index")
+
+        Seq(sqlQuery, dfQuery).foreach { query =>
+          checkAnswer(query,
+            Seq(
+              Row(1, 1, "a", -1, 0, -1),
+              Row(2, 2, "b", -2, 0, -2),
+              Row(3, 3, "c", -3, 0, -3)
+            )
+          )
+        }
+      }
+    }
+  }
+
+  test("SPARK-34555: Resolve DataFrame metadata column") {
+    withTable(tbl) {
+      prepareTable()
+      val table = spark.table(tbl)
+      val dfQuery = table.select(
+        table.col("id"),
+        table.col("data"),
+        table.col("index"),
+        table.col("_partition")
+      )
+
+      checkAnswer(
+        dfQuery,
+        Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
+      )
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through Project") {
+    withTable(tbl) {
+      prepareTable()
+      checkAnswer(
+        spark.table(tbl).select("id", "data").select("index", "_partition"),
+        Seq(Row(0, "3/1"), Row(0, "0/2"), Row(0, "1/3"))
+      )
+    }
+  }
+
+  test("SPARK-34923: do not propagate metadata columns through View") {
+    val view = "view"
+    withTable(tbl) {
+      withTempView(view) {
+        prepareTable()
+        sql(s"CACHE TABLE $view AS SELECT * FROM $tbl")
+        assertThrows[AnalysisException] {
+          sql(s"SELECT index, _partition FROM $view")
+        }
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through Filter") {
+    withTable(tbl) {
+      prepareTable()
+      val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl WHERE id > 1")
+      val dfQuery = spark.table(tbl).where("id > 1").select("id", "data", "index", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through Sort") {
+    withTable(tbl) {
+      prepareTable()
+      val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl ORDER BY id")
+      val dfQuery = spark.table(tbl).orderBy("id").select("id", "data", "index", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through RepartitionBy") {
+    withTable(tbl) {
+      prepareTable()
+      val sqlQuery = sql(
+        s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $tbl")
+      val dfQuery = spark.table(tbl).repartitionByRange(3, $"id")
+        .select("id", "data", "index", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through SubqueryAlias if child is leaf node") {
+    val sbq = "sbq"
+    withTable(tbl) {
+      prepareTable()
+      val sqlQuery = sql(
+        s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $tbl $sbq")
+      val dfQuery = spark.table(tbl).as(sbq).select(
+        s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+
+      assertThrows[AnalysisException] {
+        sql(s"SELECT $sbq.index FROM (SELECT id FROM $tbl) $sbq")
+      }
+      assertThrows[AnalysisException] {
+        spark.table(tbl).select($"id").as(sbq).select(s"$sbq.index")
+      }
+    }
+  }
+
+  test("SPARK-40149: select outer join metadata columns with DataFrame API") {
+    val df1 = Seq(1 -> "a").toDF("k", "v").as("left")
+    val df2 = Seq(1 -> "b").toDF("k", "v").as("right")
+    val dfQuery = df1.join(df2, Seq("k"), "outer")
+      .withColumn("left_all", struct($"left.*"))
+      .withColumn("right_all", struct($"right.*"))
+    checkAnswer(dfQuery, Row(1, "a", "b", Row(1, "a"), Row(1, "b")))
+  }
+}