[SPARK-34178][SQL] Copy tags for the new node created by MultiInstanceRelation.newInstance

### What changes were proposed in this pull request?

Call `copyTagsFrom` for the new node created by `MultiInstanceRelation.newInstance()`.

### Why are the changes needed?

```scala
val df = spark.range(2)
df.join(df, df("id") <=> df("id")).show()
```

For this query, the rule `DetectAmbiguousSelfJoin` is supposed to conclude that the join is non-ambiguous, because the condition refers to the same attribute reference on both sides:

https://github.com/apache/spark/blob/537a49fc0966b0b289b67ac9c6ea20093165b0da/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala#L125

However, `DetectAmbiguousSelfJoin` cannot reach that conclusion, because the right-side plan lacks the dataset_id `TreeNodeTag`, which goes missing after `MultiInstanceRelation.newInstance`. That's why we should preserve the tag info for the copied node.

Fortunately, the query is still treated as a non-ambiguous join, but only because `DetectAmbiguousSelfJoin` checks just the left-side plan and the attribute reference happens to match it. This is a coincidence, not the expected behavior.
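The mechanics can be illustrated directly on `TreeNode` tags. Below is a minimal sketch, not part of the patch: it assumes a Spark build that includes this fix (so `copyTagsFrom` is public), and the tag name `demo_dataset_id` plus the use of `LocalRelation` (a `MultiInstanceRelation`) are purely illustrative:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.types.LongType

// Illustrative tag; the real rule uses Dataset's internal dataset_id tag.
val tag = TreeNodeTag[Long]("demo_dataset_id")

// LocalRelation is a MultiInstanceRelation, so it has newInstance().
val rel = LocalRelation(AttributeReference("id", LongType)())
rel.setTagValue(tag, 1L)

val fresh = rel.newInstance()
assert(fresh.getTagValue(tag).isEmpty)      // tags are lost by newInstance()

fresh.copyTagsFrom(rel)                     // the fix carries them over
assert(fresh.getTagValue(tag).contains(1L))
```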

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Updated a unit test

Closes apache#31260 from Ngone51/fix-missing-tags.

Authored-by: yi.wu <yi.wu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit f498977)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Ngone51 authored and cloud-fan committed Jan 20, 2021
1 parent b5b1da9 commit 89443ab
Showing 3 changed files with 17 additions and 3 deletions.
`sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala`:

```diff
@@ -1012,7 +1012,10 @@ class Analyzer(
         }
         val key = catalog.name +: ident.namespace :+ ident.name
         AnalysisContext.get.relationCache.get(key).map(_.transform {
-          case multi: MultiInstanceRelation => multi.newInstance()
+          case multi: MultiInstanceRelation =>
+            val newRelation = multi.newInstance()
+            newRelation.copyTagsFrom(multi)
+            newRelation
         }).orElse {
           loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
           loaded
@@ -1164,6 +1167,7 @@
       case oldVersion: MultiInstanceRelation
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.newInstance()
+        newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
       case oldVersion: SerializeFromObject
```
`sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala`:

```diff
@@ -91,7 +91,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    */
   private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty
 
-  protected def copyTagsFrom(other: BaseType): Unit = {
+  def copyTagsFrom(other: BaseType): Unit = {
     // SPARK-32753: it only makes sense to copy tags to a new node
     // but it's too expensive to detect other cases likes node removal
     // so we make a compromise here to copy tags to node with no tags
```
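As the comment above notes (see SPARK-32753), `copyTagsFrom` only copies into a node that carries no tags yet; copying onto an already-tagged node is a no-op. A minimal sketch of that guard, with an illustrative tag name:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.types.LongType

val tag = TreeNodeTag[String]("demo_note")           // illustrative tag
val src = LocalRelation(AttributeReference("id", LongType)())
val dst = src.newInstance()

src.setTagValue(tag, "from-src")
dst.setTagValue(tag, "already-set")
dst.copyTagsFrom(src)                                // no-op: dst already has tags
assert(dst.getTagValue(tag).contains("already-set"))
```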
`sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala`:

```diff
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Filter, HintInfo,
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin.LogicalPlanWithDatasetId
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
@@ -264,7 +265,16 @@ class DataFrameJoinSuite extends QueryTest
     withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
       val df = spark.range(2)
       // this throws an exception before the fix
-      df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
+      val plan = df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
+
+      plan match {
+        // SPARK-34178: we can't match the plan before the fix due to
+        // the right side plan doesn't contains dataset id.
+        case Join(
+          LogicalPlanWithDatasetId(_, leftId),
+          LogicalPlanWithDatasetId(_, rightId), _, _, _) =>
+          assert(leftId === rightId)
+      }
     }
   }
```
