[SPARK-40149][SQL][3.2] Propagate metadata columns through Project
What changes were proposed in this pull request?

Backport apache#37758 to 3.2.

This PR fixes a regression caused by apache#32017.

In apache#32017, we tried to be more conservative and decided not to propagate metadata columns through certain operators, including `Project`. However, that decision considered only the SQL API, not the DataFrame API. In fact, chaining `Project` operators is very common in the DataFrame API, e.g. `df.withColumn(...).withColumn(...)...`, and it is very inconvenient if metadata columns are not propagated through `Project`.
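A minimal sketch of the DataFrame pattern this enables (the table name `t`, the derived column names, and the `index`/`_partition` metadata columns are illustrative assumptions mirroring the DSv2 test sources below, not part of the patch):

    import org.apache.spark.sql.functions.{col, upper}

    // Each withColumn call wraps the plan in another Project node.
    val df = spark.table("t")
      .withColumn("id2", col("id") * 2)
      .withColumn("data_upper", upper(col("data")))

    // With this change the metadata columns survive the chained Projects,
    // so they can still be selected afterwards.
    df.select("id", "id2", "index", "_partition").show()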

This PR makes 2 changes:
1. `Project` should propagate metadata columns
2. `SubqueryAlias` should only propagate metadata columns if its child is a leaf node or another `SubqueryAlias`

The second change is needed so that weird queries like `SELECT m from (SELECT a from t)` are still forbidden, which was the main motivation of apache#32017.
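A rough sketch of the intended behavior, assuming a table `t` whose source exposes a metadata column named `index` (the names are illustrative):

    // Still an error: the alias wraps the inner SELECT's Project, which is
    // neither a leaf node nor a SubqueryAlias, so the metadata column is not
    // propagated out of the sub-select.
    spark.sql("SELECT index FROM (SELECT id FROM t)")   // AnalysisException

    // Fine: the alias wraps the relation (a leaf node) directly.
    spark.sql("SELECT s.index FROM t AS s").show()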

After propagating metadata columns, a problem from apache#31666 is exposed: the metadata columns of a natural or USING join may confuse the analyzer and lead to a wrong analyzed plan. For example, in `SELECT t1.value FROM t1 LEFT JOIN t2 USING (key) ORDER BY key`, how should `ORDER BY key` be resolved? It should be resolved to `t1.key`, which is in the output of the left join, via the rule `ResolveMissingReferences`. However, if `Project` propagates metadata columns, `ORDER BY key` may be resolved to `t2.key` instead.

To solve this problem, this PR only allows qualified access to the metadata columns of a natural or USING join. This is not a breaking change: previously, these metadata columns could only be accessed with a qualifier anyway, and only in the `Project` directly on top of the `Join`. It actually enables more use cases, as the columns can now also be accessed in ORDER BY. A test has been added for this.
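A sketch of the resulting behavior for the example above (the `t1`/`t2` tables and columns follow the description; the added golden-file test below uses `nt1`/`nt2` instead):

    // Unqualified `key` in ORDER BY still resolves to t1.key from the join
    // output, because the hidden t2.key column now requires a qualifier.
    spark.sql("SELECT t1.value FROM t1 LEFT JOIN t2 USING (key) ORDER BY key")

    // Qualified access to the hidden column is allowed, and now also works
    // outside the Project directly on top of the Join, e.g. in ORDER BY.
    spark.sql("SELECT t1.value FROM t1 LEFT JOIN t2 USING (key) ORDER BY t2.key")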

Why are the changes needed?

It fixes a regression.

Does this PR introduce any user-facing change?

For the SQL API there is no change, as a `SubqueryAlias` always comes with a `Project` or `Aggregate`, so we still do not propagate metadata columns through a SELECT group.

For the DataFrame API the behavior becomes more lenient. The only breaking case is an operator that can propagate metadata columns followed by a `SubqueryAlias`, e.g. `df.filter(...).as("t").select("t.metadata_col")`. But this is a contrived use case, and we should not have supported it in the first place.
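A sketch of that one breaking pattern, again assuming a hypothetical metadata column `index`:

    import org.apache.spark.sql.functions.col

    // Filter still propagates metadata columns, but the SubqueryAlias on top
    // of it no longer re-exposes them, so the qualified metadata column can
    // no longer be resolved.
    spark.table("t")
      .filter(col("id") > 1)
      .as("t")
      .select("t.index")   // now fails with AnalysisException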

How was this patch tested?

New tests.

Closes apache#37818 from cloud-fan/backport.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit d566017)
cloud-fan authored and huaxingao committed Oct 30, 2022
1 parent 28d196c commit a19762d
Showing 9 changed files with 264 additions and 239 deletions.
@@ -1048,9 +1048,11 @@ class Analyzer(override val catalogManager: CatalogManager)
     private def addMetadataCol(plan: LogicalPlan): LogicalPlan = plan match {
       case r: DataSourceV2Relation => r.withMetadataColumns()
       case p: Project =>
-        p.copy(
+        val newProj = p.copy(
           projectList = p.metadataOutput ++ p.projectList,
           child = addMetadataCol(p.child))
+        newProj.copyTagsFrom(p)
+        newProj
       case _ => plan.withNewChildren(plan.children.map(addMetadataCol))
     }
   }
@@ -3532,8 +3534,8 @@ class Analyzer(override val catalogManager: CatalogManager)
       val project = Project(projectList, Join(left, right, joinType, newCondition, hint))
       project.setTagValue(
         Project.hiddenOutputTag,
-        hiddenList.map(_.markAsSupportsQualifiedStar()) ++
-          project.child.metadataOutput.filter(_.supportsQualifiedStar))
+        hiddenList.map(_.markAsQualifiedAccessOnly()) ++
+          project.child.metadataOutput.filter(_.qualifiedAccessOnly))
       project
     }

@@ -386,7 +386,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
     if (target.isEmpty) return input.output

     // If there is a table specified, use hidden input attributes as well
-    val hiddenOutput = input.metadataOutput.filter(_.supportsQualifiedStar)
+    val hiddenOutput = input.metadataOutput.filter(_.qualifiedAccessOnly)
     val expandedAttributes = (hiddenOutput ++ input.output).filter(
       matchedQualifier(_, target.get, resolver))

@@ -23,6 +23,7 @@ import com.google.common.collect.Maps

 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
 import org.apache.spark.sql.types.{StructField, StructType}

 /**
@@ -265,7 +266,7 @@ package object expressions {
       case (Seq(), _) =>
         val name = nameParts.head
         val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT)))
-        (attributes, nameParts.tail)
+        (attributes.filterNot(_.qualifiedAccessOnly), nameParts.tail)
       case _ => matches
     }
   }
@@ -314,10 +315,12 @@
     var i = nameParts.length - 1
     while (i >= 0 && candidates.isEmpty) {
       val name = nameParts(i)
-      candidates = collectMatches(
-        name,
-        nameParts.take(i),
-        direct.get(name.toLowerCase(Locale.ROOT)))
+      val attrsToLookup = if (i == 0) {
+        direct.get(name.toLowerCase(Locale.ROOT)).map(_.filterNot(_.qualifiedAccessOnly))
+      } else {
+        direct.get(name.toLowerCase(Locale.ROOT))
+      }
+      candidates = collectMatches(name, nameParts.take(i), attrsToLookup)
       if (candidates.nonEmpty) {
         nestedFields = nameParts.takeRight(nameParts.length - i - 1)
       }
@@ -88,7 +88,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
     getAllValidConstraints(projectList)

   override def metadataOutput: Seq[Attribute] =
-    getTagValue(Project.hiddenOutputTag).getOrElse(Nil)
+    getTagValue(Project.hiddenOutputTag).getOrElse(child.metadataOutput)

   override protected def withNewChildInternal(newChild: LogicalPlan): Project =
     copy(child = newChild)
@@ -1307,9 +1307,14 @@ case class SubqueryAlias(
   }

   override def metadataOutput: Seq[Attribute] = {
-    val qualifierList = identifier.qualifier :+ alias
-    val nonHiddenMetadataOutput = child.metadataOutput.filter(!_.supportsQualifiedStar)
-    nonHiddenMetadataOutput.map(_.withQualifier(qualifierList))
+    // Propagate metadata columns from leaf nodes through a chain of `SubqueryAlias`.
+    if (child.isInstanceOf[LeafNode] || child.isInstanceOf[SubqueryAlias]) {
+      val qualifierList = identifier.qualifier :+ alias
+      val nonHiddenMetadataOutput = child.metadataOutput.filter(!_.qualifiedAccessOnly)
+      nonHiddenMetadataOutput.map(_.withQualifier(qualifierList))
+    } else {
+      Nil
+    }
   }

   override def maxRows: Option[Long] = child.maxRows
@@ -206,22 +206,23 @@ package object util extends Logging {

   implicit class MetadataColumnHelper(attr: Attribute) {
     /**
-     * If set, this metadata column is a candidate during qualified star expansions.
+     * If set, this metadata column can only be accessed with qualifiers, e.g. `qualifiers.col` or
+     * `qualifiers.*`. If not set, metadata columns cannot be accessed via star.
      */
-    val SUPPORTS_QUALIFIED_STAR = "__supports_qualified_star"
+    val QUALIFIED_ACCESS_ONLY = "__qualified_access_only"

     def isMetadataCol: Boolean = attr.metadata.contains(METADATA_COL_ATTR_KEY) &&
       attr.metadata.getBoolean(METADATA_COL_ATTR_KEY)

-    def supportsQualifiedStar: Boolean = attr.isMetadataCol &&
-      attr.metadata.contains(SUPPORTS_QUALIFIED_STAR) &&
-      attr.metadata.getBoolean(SUPPORTS_QUALIFIED_STAR)
+    def qualifiedAccessOnly: Boolean = attr.isMetadataCol &&
+      attr.metadata.contains(QUALIFIED_ACCESS_ONLY) &&
+      attr.metadata.getBoolean(QUALIFIED_ACCESS_ONLY)

-    def markAsSupportsQualifiedStar(): Attribute = attr.withMetadata(
+    def markAsQualifiedAccessOnly(): Attribute = attr.withMetadata(
       new MetadataBuilder()
         .withMetadata(attr.metadata)
         .putBoolean(METADATA_COL_ATTR_KEY, true)
-        .putBoolean(SUPPORTS_QUALIFIED_STAR, true)
+        .putBoolean(QUALIFIED_ACCESS_ONLY, true)
         .build()
     )
   }
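For context, a minimal sketch of how the renamed helper above is meant to be used; the attribute here is made up for illustration:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
    import org.apache.spark.sql.types.StringType

    val attr = AttributeReference("file_name", StringType)()
    // Marking the attribute also tags it as a metadata column.
    val hidden = attr.markAsQualifiedAccessOnly()
    assert(hidden.qualifiedAccessOnly)   // reachable only as `alias.file_name` or `alias.*`
    assert(!attr.qualifiedAccessOnly)    // the original attribute is unchanged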
sql/core/src/test/resources/sql-tests/inputs/using-join.sql (2 additions, 0 deletions)
@@ -19,6 +19,8 @@ SELECT nt1.*, nt2.* FROM nt1 left outer join nt2 using (k);

 SELECT nt1.k, nt2.k FROM nt1 left outer join nt2 using (k);

+SELECT nt1.k, nt2.k FROM nt1 left outer join nt2 using (k) ORDER BY nt2.k;
+
 SELECT k, nt1.k FROM nt1 left outer join nt2 using (k);

 SELECT k, nt2.k FROM nt1 left outer join nt2 using (k);
sql/core/src/test/resources/sql-tests/results/using-join.sql.out (11 additions, 0 deletions)
@@ -71,6 +71,17 @@ three NULL
 two two


+-- !query
+SELECT nt1.k, nt2.k FROM nt1 left outer join nt2 using (k) ORDER BY nt2.k
+-- !query schema
+struct<k:string,k:string>
+-- !query output
+three NULL
+one one
+one one
+two two
+
+
 -- !query
 SELECT k, nt1.k FROM nt1 left outer join nt2 using (k)
 -- !query schema
@@ -2524,100 +2524,6 @@ class DataSourceV2SQLSuite
     }
   }

-  test("SPARK-31255: Project a metadata column") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1")
-      val dfQuery = spark.table(t1).select("id", "data", "index", "_partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
-      }
-    }
-  }
-
-  test("SPARK-31255: Projects data column when metadata column has the same name") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (index bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, index), index)")
-      sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")
-
-      val sqlQuery = spark.sql(s"SELECT index, data, _partition FROM $t1")
-      val dfQuery = spark.table(t1).select("index", "data", "_partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
-      }
-    }
-  }
-
-  test("SPARK-31255: * expansion does not include metadata columns") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")
-
-      val sqlQuery = spark.sql(s"SELECT * FROM $t1")
-      val dfQuery = spark.table(t1)
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(3, "c"), Row(2, "b"), Row(1, "a")))
-      }
-    }
-  }
-
-  test("SPARK-31255: metadata column should only be produced when necessary") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-
-      val sqlQuery = spark.sql(s"SELECT * FROM $t1 WHERE index = 0")
-      val dfQuery = spark.table(t1).filter("index = 0")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        assert(query.schema.fieldNames.toSeq == Seq("id", "data"))
-      }
-    }
-  }
-
-  test("SPARK-34547: metadata columns are resolved last") {
-    val t1 = s"${catalogAndNamespace}tableOne"
-    val t2 = "t2"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-      withTempView(t2) {
-        sql(s"CREATE TEMPORARY VIEW $t2 AS SELECT * FROM " +
-          s"VALUES (1, -1), (2, -2), (3, -3) AS $t2(id, index)")
-
-        val sqlQuery = spark.sql(s"SELECT $t1.id, $t2.id, data, index, $t1.index, $t2.index FROM " +
-          s"$t1 JOIN $t2 WHERE $t1.id = $t2.id")
-        val t1Table = spark.table(t1)
-        val t2Table = spark.table(t2)
-        val dfQuery = t1Table.join(t2Table, t1Table.col("id") === t2Table.col("id"))
-          .select(s"$t1.id", s"$t2.id", "data", "index", s"$t1.index", s"$t2.index")
-
-        Seq(sqlQuery, dfQuery).foreach { query =>
-          checkAnswer(query,
-            Seq(
-              Row(1, 1, "a", -1, 0, -1),
-              Row(2, 2, "b", -2, 0, -2),
-              Row(3, 3, "c", -3, 0, -3)
-            )
-          )
-        }
-      }
-    }
-  }
-
   test("SPARK-33505: insert into partitioned table") {
     val t = "testpart.ns1.ns2.tbl"
     withTable(t) {
@@ -2702,27 +2608,6 @@
     }
   }

-  test("SPARK-34555: Resolve DataFrame metadata column") {
-    val tbl = s"${catalogAndNamespace}table"
-    withTable(tbl) {
-      sql(s"CREATE TABLE $tbl (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-      val table = spark.table(tbl)
-      val dfQuery = table.select(
-        table.col("id"),
-        table.col("data"),
-        table.col("index"),
-        table.col("_partition")
-      )
-
-      checkAnswer(
-        dfQuery,
-        Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
-      )
-    }
-  }
-
   test("SPARK-34561: drop/add columns to a dataset of `DESCRIBE TABLE`") {
     val tbl = s"${catalogAndNamespace}tbl"
     withTable(tbl) {
@@ -2785,110 +2670,7 @@
     }
   }

-  test("SPARK-34923: do not propagate metadata columns through Project") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      assertThrows[AnalysisException] {
-        sql(s"SELECT index, _partition from (SELECT id, data FROM $t1)")
-      }
-      assertThrows[AnalysisException] {
-        spark.table(t1).select("id", "data").select("index", "_partition")
-      }
-    }
-  }
-
-  test("SPARK-34923: do not propagate metadata columns through View") {
-    val t1 = s"${catalogAndNamespace}table"
-    val view = "view"
-
-    withTable(t1) {
-      withTempView(view) {
-        sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-          "PARTITIONED BY (bucket(4, id), id)")
-        sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-        sql(s"CACHE TABLE $view AS SELECT * FROM $t1")
-        assertThrows[AnalysisException] {
-          sql(s"SELECT index, _partition FROM $view")
-        }
-      }
-    }
-  }
-
-  test("SPARK-34923: propagate metadata columns through Filter") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 WHERE id > 1")
-      val dfQuery = spark.table(t1).where("id > 1").select("id", "data", "index", "_partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
-      }
-    }
-  }
-
-  test("SPARK-34923: propagate metadata columns through Sort") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 ORDER BY id")
-      val dfQuery = spark.table(t1).orderBy("id").select("id", "data", "index", "_partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
-      }
-    }
-  }
-
-  test("SPARK-34923: propagate metadata columns through RepartitionBy") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      val sqlQuery = spark.sql(
-        s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $t1")
-      val tbl = spark.table(t1)
-      val dfQuery = tbl.repartitionByRange(3, tbl.col("id"))
-        .select("id", "data", "index", "_partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
-      }
-    }
-  }
-
-  test("SPARK-34923: propagate metadata columns through SubqueryAlias") {
-    val t1 = s"${catalogAndNamespace}table"
-    val sbq = "sbq"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
-      val sqlQuery = spark.sql(
-        s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $t1 as $sbq")
-      val dfQuery = spark.table(t1).as(sbq).select(
-        s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")
-
-      Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
-      }
-    }
-  }
-
-  test("OPTIMIZE is not supported for regular tables") {
+  test("OPTIMIZE is not supported for regular tables") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
       sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")