Commit

Apply Comet diff for Spark 3.4.2

sunchao committed Mar 7, 2024
1 parent 0c0e7d4 commit f7c15aa
Showing 47 changed files with 352 additions and 70 deletions.
21 changes: 21 additions & 0 deletions pom.xml
@@ -148,6 +148,8 @@
<chill.version>0.10.0</chill.version>
<ivy.version>2.5.1</ivy.version>
<oro.version>2.0.8</oro.version>
<spark.version.short>3.4</spark.version.short>
<comet.version>0.1.0-SNAPSHOT</comet.version>
<!--
If you changes codahale.metrics.version, you also need to change
the link to metrics.dropwizard.io in docs/monitoring.md.
@@ -2766,6 +2768,25 @@
<artifactId>arpack</artifactId>
<version>${netlib.ludovic.dev.version}</version>
</dependency>
<dependency>
<groupId>org.apache.comet</groupId>
<artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId>
<version>${comet.version}</version>
<exclusions>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>spark-catalyst_${scala.binary.version}</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
</dependencyManagement>

4 changes: 4 additions & 0 deletions sql/core/pom.xml
@@ -77,6 +77,10 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.apache.comet</groupId>
<artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId>
</dependency>

<!--
This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude
18 changes: 15 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -26,6 +26,8 @@ import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import org.apache.comet.CometConf

import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable}
import org.apache.spark.api.java.JavaRDD
@@ -102,7 +104,7 @@
sc: SparkContext,
initialSessionOptions: java.util.HashMap[String, String]) = {
this(sc, None, None,
SparkSession.applyExtensions(
SparkSession.applyExtensions(sc,
sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
new SparkSessionExtensions), initialSessionOptions.asScala.toMap)
}
@@ -1028,7 +1030,7 @@ object SparkSession extends Logging {
}

loadExtensions(extensions)
applyExtensions(
applyExtensions(sparkContext,
sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
extensions)

@@ -1282,14 +1284,24 @@
}
}

private def loadCometExtension(sparkContext: SparkContext): Seq[String] = {
if (sparkContext.getConf.getBoolean(CometConf.COMET_ENABLED.key, false)) {
Seq("org.apache.comet.CometSparkSessionExtensions")
} else {
Seq.empty
}
}

/**
* Initialize extensions for given extension classnames. The classes will be applied to the
* extensions passed into this function.
*/
private def applyExtensions(
sparkContext: SparkContext,
extensionConfClassNames: Seq[String],
extensions: SparkSessionExtensions): SparkSessionExtensions = {
extensionConfClassNames.foreach { extensionConfClassName =>
val extensionClassNames = extensionConfClassNames ++ loadCometExtension(sparkContext)
extensionClassNames.foreach { extensionConfClassName =>
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
val extensionConf = extensionConfClass.getConstructor().newInstance()
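
With loadCometExtension in place, Comet no longer needs to appear in spark.sql.extensions: applyExtensions appends org.apache.comet.CometSparkSessionExtensions whenever spark.comet.enabled is set on the SparkContext conf. A minimal sketch of a session picking this up, assuming the Comet jar is on the classpath (the app name and temp path below are made up):

import org.apache.spark.sql.SparkSession

// No explicit spark.sql.extensions entry is required after this patch; the
// flag alone is enough for applyExtensions to inject the Comet extension.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("comet-injection-sketch")
  .config("spark.comet.enabled", "true")
  .getOrCreate()

// Parquet scans issued through this session may now plan as Comet operators
// (e.g. CometScanExec instead of FileSourceScanExec), subject to Comet's own
// support checks.
val path = java.nio.file.Files.createTempDirectory("comet-demo").toString
spark.range(100).write.mode("overwrite").parquet(path)
spark.read.parquet(path).filter("id > 42").explain()
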
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
@@ -67,6 +68,7 @@ private[execution] object SparkPlanInfo {
// dump the file scan metadata (e.g file path) to event log
val metadata = plan match {
case fileScan: FileSourceScanExec => fileScan.metadata
case cometScan: CometScanExec => cometScan.metadata
case _ => Map[String, String]()
}
new SparkPlanInfo(
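Because SparkPlanInfo now copies metadata from CometScanExec as well, file-level details such as the scan location keep showing up in the SQL event log when Comet replaces the native scan. A hedged sketch of reading that metadata from a listener, assuming an active SparkSession named spark; the "Location" key mirrors FileSourceScanExec.metadata, and Comet is assumed to expose a similar key set:

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart

// Walk the reported plan tree and keep every node that carries scan metadata;
// with this patch CometScanExec nodes contribute entries too.
def scanMetadata(info: SparkPlanInfo): Seq[Map[String, String]] =
  (if (info.metadata.nonEmpty) Seq(info.metadata) else Nil) ++
    info.children.flatMap(scanMetadata)

spark.sparkContext.addSparkListener(new SparkListener {
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case e: SparkListenerSQLExecutionStart =>
      scanMetadata(e.sparkPlanInfo).foreach(m => println(m.get("Location")))
    case _ => // ignore other events
  }
})
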
@@ -2,3 +2,4 @@

--SET spark.sql.adaptive.enabled=true
--SET spark.sql.maxMetadataStringLength = 500
--SET spark.comet.enabled = false
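
The --SET lines added to these SQL test inputs are applied by SQLQueryTestSuite as session configs for those files; they keep the explain golden files and the PostgreSQL-ported aggregate/int8 results stable by taking Comet out of the picture. The same two knobs can be toggled directly on a session; a small sketch, assuming an active SparkSession named spark (the config keys are the ones used in this diff):

// Turn Comet off entirely for this session (what the explain inputs do) ...
spark.conf.set("spark.comet.enabled", "false")

// ... or keep Comet scans but fall back to Spark execution operators, which is
// what the aggregate/int8 inputs below do to sidestep floating point
// precision differences.
spark.conf.set("spark.comet.enabled", "true")
spark.conf.set("spark.comet.exec.enabled", "false")
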
@@ -1,5 +1,6 @@
--SET spark.sql.cbo.enabled=true
--SET spark.sql.maxMetadataStringLength = 500
--SET spark.comet.enabled = false

CREATE TABLE explain_temp1(a INT, b INT) USING PARQUET;
CREATE TABLE explain_temp2(c INT, d INT) USING PARQUET;
1 change: 1 addition & 0 deletions sql/core/src/test/resources/sql-tests/inputs/explain.sql
@@ -1,6 +1,7 @@
--SET spark.sql.codegen.wholeStage = true
--SET spark.sql.adaptive.enabled = false
--SET spark.sql.maxMetadataStringLength = 500
--SET spark.comet.enabled = false

-- Test tables
CREATE table explain_temp1 (key int, val int) USING PARQUET;
@@ -7,6 +7,9 @@

-- avoid bit-exact output here because operations may not be bit-exact.
-- SET extra_float_digits = 0;
-- Disable Comet exec due to floating point precision difference
--SET spark.comet.exec.enabled = false


-- Test aggregate operator with codegen on and off.
--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
@@ -5,6 +5,9 @@
-- AGGREGATES [Part 3]
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/aggregates.sql#L352-L605

-- Disable Comet exec due to floating point precision difference
--SET spark.comet.exec.enabled = false

-- Test aggregate operator with codegen on and off.
--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
@@ -1,6 +1,10 @@
--
-- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group
--

-- Disable Comet exec due to floating point precision difference
--SET spark.comet.exec.enabled = false

--
-- INT8
-- Test int8 64-bit integers.
@@ -1,6 +1,10 @@
--
-- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group
--

-- Disable Comet exec due to floating point precision difference
--SET spark.comet.exec.enabled = false

--
-- SELECT_HAVING
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql
@@ -435,7 +435,9 @@ class DataFrameJoinSuite extends QueryTest

withTempDatabase { dbName =>
withTable(table1Name, table2Name) {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
"spark.comet.enabled" -> "false") {
spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
spark.range(100).write.saveAsTable(s"$dbName.$table2Name")

39 changes: 39 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DisableComet.scala
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.scalactic.source.Position
import org.scalatest.Tag

import org.apache.spark.sql.test.SQLTestUtils

case class DisableComet(reason: String) extends Tag("DisableComet")

/**
* Helper trait that disables Comet for all tests regardless of default config values.
*/
trait DisableCometSuite extends SQLTestUtils {

[Inline review comments on this trait]

advancedxy (Mar 8, 2024):
I think this trait and the case class name don't represent what they do.
A more appropriate name may be: IgnoreWhenCometOnSuite or something like that.

sunchao (author, Mar 8, 2024):
Hmm how about IgnoreCometSuite? I feel IgnoreWhenCometOnSuite is a bit too long :)
I think it is very easy to get what the case class and trait do.

advancedxy (Mar 8, 2024):
> Hmm how about IgnoreCometSuite?
SGTM

override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
(implicit pos: Position): Unit = {
if (isCometEnabled) {
ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun)
} else {
super.test(testName, testTags: _*)(testFun)
}
}
}
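
Elsewhere in this commit the new helpers are used in two ways: individual tests are tagged with DisableComet("reason") (ExplainSuite and QueryExecutionErrorsSuite below), while whole suites mix in DisableCometSuite (PlanStabilitySuite). A hedged sketch of both patterns; the suite name and test body are made up, and only the DisableComet/DisableCometSuite usages mirror this commit:

import org.apache.spark.sql.{DisableComet, QueryTest}
import org.apache.spark.sql.test.SharedSparkSession

class MyParquetSuite extends QueryTest with SharedSparkSession {

  // Tag an individual test; the tag is matched elsewhere in the test harness
  // (outside this hunk) so the test can be skipped when Comet is on.
  test("explain output matches golden file",
      DisableComet("Comet explain output is different")) {
    spark.sql("SELECT 1").explain(true)
  }
}

// Or mix the trait into a suite that should never run with Comet on, as
// PlanStabilitySuite does: every test() is then registered via ignore()
// whenever isCometEnabled is true.
// class MyPlanStabilitySuite extends MyParquetSuite with DisableCometSuite
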
@@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.plans.ExistenceJoin
import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, InMemoryTableWithV2FilterCatalog}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._
@@ -262,6 +263,9 @@ abstract class DynamicPartitionPruningSuiteBase
case s: BatchScanExec => s.runtimeFilters.collect {
case d: DynamicPruningExpression => d.child
}
case s: CometScanExec => s.partitionFilters.collect {
case d: DynamicPruningExpression => d.child
}
case _ => Nil
}
}
@@ -1729,6 +1733,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
case s: BatchScanExec =>
// we use f1 col for v2 tables due to schema pruning
s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1")))
case s: CometScanExec =>
s.output.exists(_.exists(_.argString(maxFields = 100).contains("fid")))
case _ => false
}
assert(scanOption.isDefined)
@@ -463,7 +463,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}

test("Explain formatted output for scan operator for datasource V2") {
test("Explain formatted output for scan operator for datasource V2",
DisableComet("Comet explain output is different")) {
withTempDir { dir =>
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
val basePath = dir.getCanonicalPath + "/" + fmt
@@ -33,6 +33,7 @@ import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec}
import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.FilePartition
@@ -875,6 +876,7 @@ class FileBasedDataSourceSuite extends QueryTest

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _), _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@@ -916,6 +918,7 @@ class FileBasedDataSourceSuite extends QueryTest

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _), _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
@@ -1100,6 +1103,8 @@ class FileBasedDataSourceSuite extends QueryTest
val filters = df.queryExecution.executedPlan.collect {
case f: FileSourceScanLike => f.dataFilters
case b: BatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters
case b: CometScanExec => b.dataFilters
case b: CometBatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters
}.flatten
assert(filters.contains(GreaterThan(scan.logicalPlan.output.head, Literal(5L))))
}
5 changes: 3 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.comet._
import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
@@ -1371,7 +1372,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
// No extra sort on left side before last sort merge join
assert(collect(plan) { case _: SortExec => true }.size === 5)
assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 5)
}

// Test output ordering is not preserved
@@ -1382,7 +1383,7 @@
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
// Have sort on left side before last sort merge join
assert(collect(plan) { case _: SortExec => true }.size === 6)
assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 6)
}

// Test single partition
@@ -69,7 +69,7 @@ import org.apache.spark.tags.ExtendedSQLTest
* }}}
*/
// scalastyle:on line.size.limit
trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite {
trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite with DisableCometSuite {

protected val baseResourcePath = {
// use the same way as `SQLQueryTestSuite` to get the resource path
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, LogicalPlan, Project, Sort, Union}
import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.datasources.FileScanRDD
@@ -1543,6 +1544,12 @@ class SubquerySuite extends QueryTest
fs.inputRDDs().forall(
_.asInstanceOf[FileScanRDD].filePartitions.forall(
_.files.forall(_.urlEncodedPath.contains("p=0"))))
case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(
fs @ CometScanExec(_, _, _, partitionFilters, _, _, _, _, _, _)))) =>
partitionFilters.exists(ExecSubqueryExpression.hasSubquery) &&
fs.inputRDDs().forall(
_.asInstanceOf[FileScanRDD].filePartitions.forall(
_.files.forall(_.urlEncodedPath.contains("p=0"))))
case _ => false
})
}
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkConf
import org.apache.spark.sql.{AnalysisException, QueryTest}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
@@ -184,7 +185,11 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession {
val df = spark.read.format(format).load(path.getCanonicalPath)
checkAnswer(df, inputData.toDF())
assert(
df.queryExecution.executedPlan.exists(_.isInstanceOf[FileSourceScanExec]))
df.queryExecution.executedPlan.exists {
case _: FileSourceScanExec | _: CometScanExec => true
case _ => false
}
)
}
} finally {
spark.listenerManager.unregister(listener)
@@ -27,7 +27,7 @@ import org.apache.hadoop.fs.permission.FsPermission
import org.mockito.Mockito.{mock, spy, when}

import org.apache.spark._
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SaveMode}
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, DisableComet, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.util.BadRecordException
import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions}
@@ -248,7 +248,8 @@ class QueryExecutionErrorsSuite
}

test("INCONSISTENT_BEHAVIOR_CROSS_VERSION: " +
"compatibility with Spark 2.4/3.2 in reading/writing dates") {
"compatibility with Spark 2.4/3.2 in reading/writing dates",
DisableComet("Comet doesn't completely support datetime rebase mode yet")) {

// Fail to read ancient datetime values.
withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_READ.key -> EXCEPTION.toString) {
