diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fbe60412b2f39..2dbabfcffebca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1012,7 +1012,10 @@ class Analyzer(
     }
     val key = catalog.name +: ident.namespace :+ ident.name
     AnalysisContext.get.relationCache.get(key).map(_.transform {
-      case multi: MultiInstanceRelation => multi.newInstance()
+      case multi: MultiInstanceRelation =>
+        val newRelation = multi.newInstance()
+        newRelation.copyTagsFrom(multi)
+        newRelation
     }).orElse {
       loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
       loaded
@@ -1164,6 +1167,7 @@ class Analyzer(
       case oldVersion: MultiInstanceRelation
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.newInstance()
+        newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
       case oldVersion: SerializeFromObject
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 4dc3627cd6a50..68bbd1b7dac5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -91,7 +91,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    */
   private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty
 
-  protected def copyTagsFrom(other: BaseType): Unit = {
+  def copyTagsFrom(other: BaseType): Unit = {
     // SPARK-32753: it only makes sense to copy tags to a new node
     // but it's too expensive to detect other cases likes node removal
     // so we make a compromise here to copy tags to node with no tags
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index a49f95f1ed134..0cf81b4867be6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Filter, HintInfo,
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin.LogicalPlanWithDatasetId
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
@@ -264,7 +265,16 @@ class DataFrameJoinSuite extends QueryTest
     withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
       val df = spark.range(2)
       // this throws an exception before the fix
-      df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
+      val plan = df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
+
+      plan match {
+        // SPARK-34178: before the fix this pattern would not match, because the
+        // right side of the self-join did not carry a dataset id.
+        case Join(
+          LogicalPlanWithDatasetId(_, leftId),
+          LogicalPlanWithDatasetId(_, rightId), _, _, _) =>
+          assert(leftId === rightId)
+      }
     }
   }
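For reviewers who want to reproduce the scenario outside the test suite, here is a minimal sketch of the self-join the updated DataFrameJoinSuite test exercises. The object name SelfJoinTagRepro and the local-mode session setup are my own additions for illustration, not part of the patch; the join itself mirrors the test body. With copyTagsFrom applied to the relation returned by newInstance(), the dataset-id tag survives re-instantiation, which is what the new LogicalPlanWithDatasetId assertion checks.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical standalone reproduction of the test case added in this patch.
object SelfJoinTagRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // Mirror the test's withSQLConf setting (SQLConf.CROSS_JOINS_ENABLED).
    spark.conf.set("spark.sql.crossJoin.enabled", "false")

    val df = spark.range(2)
    // The self-join from the test: with the tags copied in the Analyzer, both
    // join children should keep the same dataset id, which is what the test's
    // LogicalPlanWithDatasetId pattern match asserts on the optimized plan.
    val joined = df.join(df, df("id") <=> df("id"))
    joined.explain(true)

    spark.stop()
  }
}
```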