diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 2879aaca72151..a5e681535cb82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
 import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
 import org.apache.spark.sql.internal.SQLConf
@@ -255,7 +255,7 @@ case class AdaptiveSparkPlanExec(
     //    and display SQL metrics correctly.
     // 2. If the `QueryExecution` does not match the current execution ID, it means the execution
     //    ID belongs to another (parent) query, and we should not call update UI in this query.
-    //    e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanExec`.
+    //    e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanLike`.
     //
     // That means only the root `AdaptiveSparkPlanExec` of the main query that triggers this
     // query execution need to do a plan update for the UI.
@@ -558,9 +558,9 @@ case class AdaptiveSparkPlanExec(
           }
       }
 
-    case i: InMemoryTableScanExec =>
-      // There is no reuse for `InMemoryTableScanExec`, which is different from `Exchange`. If we
-      // hit it the first time, we should always create a new query stage.
+    case i: InMemoryTableScanLike =>
+      // There is no reuse for `InMemoryTableScanLike`, which is different from `Exchange`.
+      // If we hit it the first time, we should always create a new query stage.
       val newStage = newQueryStage(i)
       CreateStageResult(
         newPlan = newStage,
@@ -605,12 +605,12 @@ case class AdaptiveSparkPlanExec(
           }
           BroadcastQueryStageExec(currentStageId, newPlan, e.canonicalized)
         }
-      case i: InMemoryTableScanExec =>
+      case i: InMemoryTableScanLike =>
         // Apply `queryStageOptimizerRules` so that we can reuse subquery.
-        // No need to apply `postStageCreationRules` for `InMemoryTableScanExec`
+        // No need to apply `postStageCreationRules` for `InMemoryTableScanLike`
         // as it's a leaf node.
         val newPlan = optimizeQueryStage(i, isFinalStage = false)
-        if (!newPlan.isInstanceOf[InMemoryTableScanExec]) {
+        if (!newPlan.isInstanceOf[InMemoryTableScanLike]) {
           throw SparkException.internalError(
             "Custom AQE rules cannot transform table scan node to something else.")
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index 88954d6f822d0..7db9271aee0c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.columnar.CachedBatch
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -261,7 +261,7 @@ case class BroadcastQueryStageExec(
 }
 
 /**
- * A table cache query stage whose child is a [[InMemoryTableScanExec]].
+ * A table cache query stage whose child is a [[InMemoryTableScanLike]].
  *
  * @param id the query stage id.
  * @param plan the underlying plan.
@@ -271,7 +271,7 @@ case class TableCacheQueryStageExec(
     override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val inMemoryTableScan = plan match {
-    case i: InMemoryTableScanExec => i
+    case i: InMemoryTableScanLike => i
     case _ =>
       throw SparkException.internalError(s"wrong plan for table cache stage:\n ${plan.treeString}")
   }
@@ -294,5 +294,5 @@ case class TableCacheQueryStageExec(
 
   override protected def doMaterialize(): Future[Any] = future
 
-  override def getRuntimeStatistics: Statistics = inMemoryTableScan.relation.computeStats()
+  override def getRuntimeStatistics: Statistics = inMemoryTableScan.runtimeStatistics
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 064a46369055f..cfcfd282e5480 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.columnar.CachedBatch
 import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, WholeStageCodegenExec}
@@ -28,11 +29,32 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+/**
+ * Common trait for all InMemoryTableScans implementations to facilitate pattern matching.
+ */
+trait InMemoryTableScanLike extends LeafExecNode {
+
+  /**
+   * Returns whether the cache buffer is loaded
+   */
+  def isMaterialized: Boolean
+
+  /**
+   * Returns the actual cached RDD without filters and serialization of row/columnar.
+   */
+  def baseCacheRDD(): RDD[CachedBatch]
+
+  /**
+   * Returns the runtime statistics after materialization.
+   */
+  def runtimeStatistics: Statistics
+}
+
 case class InMemoryTableScanExec(
     attributes: Seq[Attribute],
     predicates: Seq[Expression],
     @transient relation: InMemoryRelation)
-  extends LeafExecNode {
+  extends InMemoryTableScanLike {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -176,13 +198,18 @@ case class InMemoryTableScanExec(
     columnarInputRDD
   }
 
-  def isMaterialized: Boolean = relation.cacheBuilder.isCachedColumnBuffersLoaded
+  override def isMaterialized: Boolean = relation.cacheBuilder.isCachedColumnBuffersLoaded
 
   /**
    * This method is only used by AQE which executes the actually cached RDD that without filter and
    * serialization of row/columnar.
    */
-  def baseCacheRDD(): RDD[CachedBatch] = {
+  override def baseCacheRDD(): RDD[CachedBatch] = {
     relation.cacheBuilder.cachedColumnBuffers
   }
+
+  /**
+   * Returns the runtime statistics after shuffle materialization.
+   */
+  override def runtimeStatistics: Statistics = relation.computeStats()
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index f528c5584fee9..39f6aa8505b32 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{CollectLimitExec, ColumnarToRowExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec}
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, InMemoryTableScanLike}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
@@ -2763,7 +2763,7 @@ class AdaptiveQueryExecSuite
         case s: SortExec => s
       }.size == (if (firstAccess) 2 else 0))
       assert(collect(initialExecutedPlan) {
-        case i: InMemoryTableScanExec => i
+        case i: InMemoryTableScanLike => i
       }.head.isMaterialized != firstAccess)
 
       df.collect()
@@ -2775,7 +2775,7 @@ class AdaptiveQueryExecSuite
         case s: SortExec => s
       }.isEmpty)
       assert(collect(initialExecutedPlan) {
-        case i: InMemoryTableScanExec => i
+        case i: InMemoryTableScanLike => i
       }.head.isMaterialized)
     }
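For context, a minimal, hypothetical sketch of what an alternative cached-scan operator could look like once it extends the new `InMemoryTableScanLike` trait, so that AQE wraps it in a `TableCacheQueryStageExec` just like `InMemoryTableScanExec`. This is not part of the patch: `ExampleCachedScanExec`, `cachedBuffers`, and `stats` are illustrative names, and the `doExecute` body is deliberately elided.

```scala
// Hypothetical sketch only: an alternative cached-scan operator that participates in AQE
// by extending the InMemoryTableScanLike trait introduced in this patch.
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.columnar.CachedBatch
import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike

case class ExampleCachedScanExec(
    output: Seq[Attribute],
    @transient cachedBuffers: RDD[CachedBatch],
    stats: Statistics) extends InMemoryTableScanLike {

  // Whether the cache buffers have already been computed; a real implementation would
  // consult its cache builder rather than hard-coding this.
  override def isMaterialized: Boolean = true

  // The raw cached RDD that AQE materializes inside a TableCacheQueryStageExec.
  override def baseCacheRDD(): RDD[CachedBatch] = cachedBuffers

  // Post-materialization statistics consumed by TableCacheQueryStageExec.getRuntimeStatistics.
  override def runtimeStatistics: Statistics = stats

  // Decoding the cached batches back into rows is beyond the scope of this sketch.
  override protected def doExecute(): RDD[InternalRow] =
    throw new UnsupportedOperationException("sketch only")
}
```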