[SPARK-47398][SQL] Extract a trait for InMemoryTableScanExec to allow for extending functionality

### What changes were proposed in this pull request?
We are proposing to allow users to provide a custom `InMemoryTableScanExec`. To accomplish this, we follow the same convention used for `ShuffleExchangeLike` and `BroadcastExchangeLike`: extract a common trait, `InMemoryTableScanLike`, and have AQE match on the trait rather than on the concrete class.
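In rough terms, the convention looks like the sketch below. It condenses what the diff in this commit adds (the trait itself and the way AQE keys off it); nothing here is new API beyond that:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.columnar.CachedBatch
import org.apache.spark.sql.execution.LeafExecNode

// Common trait for all InMemoryTableScan implementations, analogous to
// ShuffleExchangeLike / BroadcastExchangeLike for exchange nodes.
trait InMemoryTableScanLike extends LeafExecNode {
  // Whether the cached column buffers have been loaded.
  def isMaterialized: Boolean

  // The underlying cached RDD, without filters or row/columnar serialization.
  def baseCacheRDD(): RDD[CachedBatch]

  // Statistics available after materialization.
  def runtimeStatistics: Statistics
}

// AdaptiveSparkPlanExec then creates table-cache query stages by matching on
// the trait instead of the concrete class:
//   case i: InMemoryTableScanLike => newQueryStage(i)
```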

### Why are the changes needed?
In the PR added by ulysses-you, `InMemoryTableScanExec` is wrapped inside `TableCacheQueryStageExec`. This can cause problems, especially in the RAPIDS Accelerator for Apache Spark, which replaces `InMemoryTableScanExec` with a customized version that carries its own optimizations. In that situation we could lose the benefits of [SPARK-42101](https://issues.apache.org/jira/browse/SPARK-42101), or Spark could even throw an exception.
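For illustration only, a replacement scan node just has to implement the trait for AQE to keep recognizing it and wrapping it in `TableCacheQueryStageExec`. The class below is a hypothetical sketch (it is not part of this PR or of the RAPIDS plugin) that simply delegates to the built-in scan:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.columnar.CachedBatch
import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, InMemoryTableScanLike}

// Hypothetical plugin node that wraps the built-in scan and delegates to it.
// Because it extends InMemoryTableScanLike, AdaptiveSparkPlanExec still
// creates a TableCacheQueryStageExec for it instead of failing.
case class CustomInMemoryTableScanExec(delegate: InMemoryTableScanExec)
  extends InMemoryTableScanLike {

  override def output: Seq[Attribute] = delegate.output

  // Members required by the trait, delegated to the built-in implementation.
  override def isMaterialized: Boolean = delegate.isMaterialized
  override def baseCacheRDD(): RDD[CachedBatch] = delegate.baseCacheRDD()
  override def runtimeStatistics: Statistics = delegate.runtimeStatistics

  // A real plugin would substitute its optimized scan here; this sketch just
  // falls back to the built-in row-based execution.
  override protected def doExecute(): RDD[InternalRow] = delegate.execute()
}
```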

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Ran the existing tests

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#45525 from razajafri/extract-inmem-trait.

Authored-by: Raza Jafri <rjafri@nvidia.com>
Signed-off-by: Thomas Graves <tgraves@apache.org>
razajafri authored and tgravescs committed Mar 21, 2024
1 parent 25ecde9 commit 6a27789
Showing 4 changed files with 45 additions and 18 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
import org.apache.spark.sql.internal.SQLConf
@@ -255,7 +255,7 @@ case class AdaptiveSparkPlanExec(
// and display SQL metrics correctly.
// 2. If the `QueryExecution` does not match the current execution ID, it means the execution
// ID belongs to another (parent) query, and we should not call update UI in this query.
-// e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanExec`.
+// e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanLike`.
//
// That means only the root `AdaptiveSparkPlanExec` of the main query that triggers this
// query execution need to do a plan update for the UI.
@@ -558,9 +558,9 @@ case class AdaptiveSparkPlanExec(
}
}

-case i: InMemoryTableScanExec =>
-// There is no reuse for `InMemoryTableScanExec`, which is different from `Exchange`. If we
-// hit it the first time, we should always create a new query stage.
+case i: InMemoryTableScanLike =>
+// There is no reuse for `InMemoryTableScanLike`, which is different from `Exchange`.
+// If we hit it the first time, we should always create a new query stage.
val newStage = newQueryStage(i)
CreateStageResult(
newPlan = newStage,
@@ -605,12 +605,12 @@ case class AdaptiveSparkPlanExec(
}
BroadcastQueryStageExec(currentStageId, newPlan, e.canonicalized)
}
-case i: InMemoryTableScanExec =>
+case i: InMemoryTableScanLike =>
// Apply `queryStageOptimizerRules` so that we can reuse subquery.
-// No need to apply `postStageCreationRules` for `InMemoryTableScanExec`
+// No need to apply `postStageCreationRules` for `InMemoryTableScanLike`
// as it's a leaf node.
val newPlan = optimizeQueryStage(i, isFinalStage = false)
-if (!newPlan.isInstanceOf[InMemoryTableScanExec]) {
+if (!newPlan.isInstanceOf[InMemoryTableScanLike]) {
throw SparkException.internalError(
"Custom AQE rules cannot transform table scan node to something else.")
}
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.columnar.CachedBatch
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -261,7 +261,7 @@ case class BroadcastQueryStageExec(
}

/**
-* A table cache query stage whose child is a [[InMemoryTableScanExec]].
+* A table cache query stage whose child is a [[InMemoryTableScanLike]].
*
* @param id the query stage id.
* @param plan the underlying plan.
@@ -271,7 +271,7 @@ case class TableCacheQueryStageExec(
override val plan: SparkPlan) extends QueryStageExec {

@transient val inMemoryTableScan = plan match {
-case i: InMemoryTableScanExec => i
+case i: InMemoryTableScanLike => i
case _ =>
throw SparkException.internalError(s"wrong plan for table cache stage:\n ${plan.treeString}")
}
@@ -294,5 +294,5 @@

override protected def doMaterialize(): Future[Any] = future

-override def getRuntimeStatistics: Statistics = inMemoryTableScan.relation.computeStats()
+override def getRuntimeStatistics: Statistics = inMemoryTableScan.runtimeStatistics
}
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -21,18 +21,40 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.columnar.CachedBatch
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
* Common trait for all InMemoryTableScans implementations to facilitate pattern matching.
*/
trait InMemoryTableScanLike extends LeafExecNode {

/**
* Returns whether the cache buffer is loaded
*/
def isMaterialized: Boolean

/**
* Returns the actual cached RDD without filters and serialization of row/columnar.
*/
def baseCacheRDD(): RDD[CachedBatch]

/**
* Returns the runtime statistics after materialization.
*/
def runtimeStatistics: Statistics
}

case class InMemoryTableScanExec(
attributes: Seq[Attribute],
predicates: Seq[Expression],
@transient relation: InMemoryRelation)
-extends LeafExecNode {
+extends InMemoryTableScanLike {

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -176,13 +198,18 @@ case class InMemoryTableScanExec(
columnarInputRDD
}

-def isMaterialized: Boolean = relation.cacheBuilder.isCachedColumnBuffersLoaded
+override def isMaterialized: Boolean = relation.cacheBuilder.isCachedColumnBuffersLoaded

/**
* This method is only used by AQE which executes the actually cached RDD that without filter and
* serialization of row/columnar.
*/
-def baseCacheRDD(): RDD[CachedBatch] = {
+override def baseCacheRDD(): RDD[CachedBatch] = {
relation.cacheBuilder.cachedColumnBuffers
}

/**
* Returns the runtime statistics after shuffle materialization.
*/
override def runtimeStatistics: Statistics = relation.computeStats()
}
sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.execution.{CollectLimitExec, ColumnarToRowExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, InMemoryTableScanLike}
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
@@ -2763,7 +2763,7 @@ class AdaptiveQueryExecSuite
case s: SortExec => s
}.size == (if (firstAccess) 2 else 0))
assert(collect(initialExecutedPlan) {
-case i: InMemoryTableScanExec => i
+case i: InMemoryTableScanLike => i
}.head.isMaterialized != firstAccess)

df.collect()
@@ -2775,7 +2775,7 @@ class AdaptiveQueryExecSuite
case s: SortExec => s
}.isEmpty)
assert(collect(initialExecutedPlan) {
-case i: InMemoryTableScanExec => i
+case i: InMemoryTableScanLike => i
}.head.isMaterialized)
}
