From 89443ab1118b0e07acd639609094961f783b01e1 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Wed, 20 Jan 2021 13:36:14 +0000
Subject: [PATCH] [SPARK-34178][SQL] Copy tags for the new node created by
 MultiInstanceRelation.newInstance

### What changes were proposed in this pull request?

Call `copyTagsFrom` for the new node created by `MultiInstanceRelation.newInstance()`.

### Why are the changes needed?

```scala
val df = spark.range(2)
df.join(df, df("id") <=> df("id")).show()
```

For this query, the rule `DetectAmbiguousSelfJoin` is supposed to treat the join as non-ambiguous, because the condition uses the same attribute reference on both sides:

https://github.com/apache/spark/blob/537a49fc0966b0b289b67ac9c6ea20093165b0da/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala#L125

However, `DetectAmbiguousSelfJoin` cannot apply this detection, because the right side plan doesn't contain the dataset_id TreeNodeTag, which goes missing after `MultiInstanceRelation.newInstance`. That's why we should preserve the tag info for the copied node.

Fortunately, the query is still considered a non-ambiguous join, because `DetectAmbiguousSelfJoin` only checks the left side plan and the reference here matches the left side plan. However, this is not the expected behavior, only a coincidence.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Updated a unit test.

Closes #31260 from Ngone51/fix-missing-tags.

Authored-by: yi.wu
Signed-off-by: Wenchen Fan
(cherry picked from commit f4989772229e2ba35f1d005727b7d4d9f1369895)
Signed-off-by: Wenchen Fan
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  6 +++++-
 .../apache/spark/sql/catalyst/trees/TreeNode.scala |  2 +-
 .../org/apache/spark/sql/DataFrameJoinSuite.scala  | 12 +++++++++++-
 3 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fbe60412b2f39..2dbabfcffebca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1012,7 +1012,10 @@ class Analyzer(
       }
       val key = catalog.name +: ident.namespace :+ ident.name
       AnalysisContext.get.relationCache.get(key).map(_.transform {
-        case multi: MultiInstanceRelation => multi.newInstance()
+        case multi: MultiInstanceRelation =>
+          val newRelation = multi.newInstance()
+          newRelation.copyTagsFrom(multi)
+          newRelation
       }).orElse {
         loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
         loaded
@@ -1164,6 +1167,7 @@ class Analyzer(
       case oldVersion: MultiInstanceRelation
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.newInstance()
+        newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
       case oldVersion: SerializeFromObject
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 4dc3627cd6a50..68bbd1b7dac5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -91,7 +91,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    */
   private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty
 
-  protected def copyTagsFrom(other: BaseType): Unit = {
+  def copyTagsFrom(other: BaseType): Unit = {
     // SPARK-32753: it only makes sense to copy tags to a new node
     // but it's too expensive to detect other cases likes node removal
     // so we make a compromise here to copy tags to node with no tags
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index a49f95f1ed134..0cf81b4867be6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Filter, HintInfo,
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin.LogicalPlanWithDatasetId
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
@@ -264,7 +265,16 @@ class DataFrameJoinSuite extends QueryTest
     withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
       val df = spark.range(2)
       // this throws an exception before the fix
-      df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
+      val plan = df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
+
+      plan match {
+        // SPARK-34178: we can't match the plan before the fix because
+        // the right side plan doesn't contain the dataset id.
+        case Join(
+          LogicalPlanWithDatasetId(_, leftId),
+          LogicalPlanWithDatasetId(_, rightId), _, _, _) =>
+          assert(leftId === rightId)
+      }
     }
   }
 
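For context: tags live in a private per-node `mutable.Map` (the `tags` field shown in TreeNode.scala above), so a node returned by `newInstance()` starts with an empty tag map and drops whatever the old node carried. Below is a minimal sketch of that behavior, not part of the patch: the `TreeNodeTag[Long]("dataset_id")` tag and the use of `LocalRelation` are illustrative stand-ins (the real dataset id tag is defined elsewhere, in `Dataset`), and calling `copyTagsFrom` from outside the node assumes this patch's change that makes the method public.

```scala
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

object CopyTagsSketch {
  def main(args: Array[String]): Unit = {
    // An illustrative tag; the real dataset id tag used by
    // DetectAmbiguousSelfJoin is defined in Dataset.
    val datasetIdTag = TreeNodeTag[Long]("dataset_id")

    // LocalRelation is a leaf plan that mixes in MultiInstanceRelation.
    val relation = LocalRelation()
    relation.setTagValue(datasetIdTag, 42L)

    // newInstance() builds a fresh node whose tag map starts empty,
    // so the tag is gone here...
    val fresh = relation.newInstance()
    assert(fresh.getTagValue(datasetIdTag).isEmpty)

    // ...unless the caller copies the tags over, which is exactly what
    // this patch adds at the two newInstance() call sites in Analyzer.scala.
    fresh.copyTagsFrom(relation)
    assert(fresh.getTagValue(datasetIdTag).contains(42L))
  }
}
```

Note that, per the SPARK-32753 comment preserved in TreeNode.scala, `copyTagsFrom` only copies into a node with no tags, which is precisely the state of a freshly created instance as in this sketch.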