[SPARK-31003][TESTS] Fix incorrect uses of assume() in tests
### What changes were proposed in this pull request?

This patch fixes several incorrect uses of `assume()` in our tests.

If a call to `assume(condition)` fails then it will cause the test to be marked as skipped instead of failed: this feature allows test cases to be skipped if certain prerequisites are missing. For example, we use this to skip certain tests when running on Windows (or when Python dependencies are unavailable).
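
As a minimal, self-contained sketch of that legitimate pattern (illustrative only, not code from this PR; the suite name, the `AnyFunSuite` base class from ScalaTest 3.1+, and the commons-lang3 `SystemUtils` OS check are assumptions for the example):

```scala
import org.apache.commons.lang3.SystemUtils
import org.scalatest.funsuite.AnyFunSuite

// Hypothetical suite: the test only makes sense on POSIX hosts,
// so a missing prerequisite should cancel it rather than fail it.
class SymlinkSuite extends AnyFunSuite {
  test("creating symlinks") {
    // A false condition here throws TestCanceledException, so ScalaTest
    // reports the test as canceled (skipped) instead of failed.
    assume(!SystemUtils.IS_OS_WINDOWS, "symlink semantics differ on Windows")
    // ... body that exercises POSIX symlink behavior ...
  }
}
```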

In contrast, `assert(condition)` will fail the test if the condition doesn't hold.

If `assume()` is accidentally substituted for `assert()`, then the resulting test will be marked as skipped in cases where it should have failed, undermining the purpose of the test.
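
To make the two failure modes concrete, here is a small illustrative snippet (not part of the patch) that calls ScalaTest's `Assertions` API directly: a false `assume()` cancels the test, while a false `assert()` fails it.

```scala
import org.scalatest.Assertions._
import org.scalatest.exceptions.{TestCanceledException, TestFailedException}

// assume(): a false condition throws TestCanceledException, so the enclosing
// test is reported as canceled/skipped rather than failed.
intercept[TestCanceledException] {
  assume(1 == 2, "prerequisite missing")
}

// assert(): a false condition throws TestFailedException, so the enclosing
// test is reported as failed.
intercept[TestFailedException] {
  assert(1 == 2, "expectation violated")
}
```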

This patch fixes several such cases, replacing certain `assume()` calls with `assert()`.

Credit to ahirreddy for spotting this problem.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes apache#27754 from JoshRosen/fix-assume-vs-assert.

Lead-authored-by: Josh Rosen <rosenville@gmail.com>
Co-authored-by: Josh Rosen <joshrosen@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
2 people authored and sjincho committed Apr 14, 2020
1 parent 2fc39d9 commit e72d7b5
Showing 10 changed files with 15 additions and 15 deletions.
@@ -106,7 +106,7 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("a", dataType, nullable = true) ::
StructField("b", dataType, nullable = true) :: Nil)
val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
- assume(maybeDataGenerator.isDefined)
+ assert(maybeDataGenerator.isDefined)
val randGenerator = maybeDataGenerator.get
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
for (_ <- 1 to 50) {
@@ -195,7 +195,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
}

test("SPARK-1669: cacheTable should be idempotent") {
- assume(!spark.table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
+ assert(!spark.table("testData").logicalPlan.isInstanceOf[InMemoryRelation])

spark.catalog.cacheTable("testData")
assertCached(spark.table("testData"))
@@ -1033,7 +1033,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
df.write.insertInto("students")
spark.catalog.cacheTable("students")
checkAnswer(spark.table("students"), df)
- assume(spark.catalog.isCached("students"), "bad test: table was not cached in the first place")
+ assert(spark.catalog.isCached("students"), "bad test: table was not cached in the first place")
sql("ALTER TABLE students RENAME TO teachers")
sql("CREATE TABLE students (age INT, name STRING) USING parquet")
// Now we have both students and teachers.
@@ -1959,7 +1959,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
Seq("json", "parquet").foreach { format =>
withTable("rectangles") {
data.write.format(format).saveAsTable("rectangles")
assume(spark.table("rectangles").collect().nonEmpty,
assert(spark.table("rectangles").collect().nonEmpty,
"bad test; table was empty to begin with")

sql("TRUNCATE TABLE rectangles")
@@ -89,9 +89,9 @@ class CatalogSuite extends SharedSparkSession {
val columns = dbName
.map { db => spark.catalog.listColumns(db, tableName) }
.getOrElse { spark.catalog.listColumns(tableName) }
- assume(tableMetadata.schema.nonEmpty, "bad test")
- assume(tableMetadata.partitionColumnNames.nonEmpty, "bad test")
- assume(tableMetadata.bucketSpec.isDefined, "bad test")
+ assert(tableMetadata.schema.nonEmpty, "bad test")
+ assert(tableMetadata.partitionColumnNames.nonEmpty, "bad test")
+ assert(tableMetadata.bucketSpec.isDefined, "bad test")
assert(columns.collect().map(_.name).toSet == tableMetadata.schema.map(_.name).toSet)
val bucketColumnNames = tableMetadata.bucketSpec.map(_.bucketColumnNames).getOrElse(Nil).toSet
columns.collect().foreach { col =>
@@ -42,7 +42,7 @@ import org.apache.spark.util.collection.BitSet
class BucketedReadWithoutHiveSupportSuite extends BucketedReadSuite with SharedSparkSession {
protected override def beforeAll(): Unit = {
super.beforeAll()
- assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
+ assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
}
}

@@ -33,7 +33,7 @@ import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
class BucketedWriteWithoutHiveSupportSuite extends BucketedWriteSuite with SharedSparkSession {
protected override def beforeAll(): Unit = {
super.beforeAll()
- assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
+ assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
}

override protected def fileFormatsToTest: Seq[String] = Seq("parquet", "json")
@@ -981,8 +981,8 @@ class HiveDDLSuite
val expectedSerdePropsString =
expectedSerdeProps.map { case (k, v) => s"'$k'='$v'" }.mkString(", ")
val oldPart = catalog.getPartition(TableIdentifier("boxes"), Map("width" -> "4"))
- assume(oldPart.storage.serde != Some(expectedSerde), "bad test: serde was already set")
- assume(oldPart.storage.properties.filterKeys(expectedSerdeProps.contains) !=
+ assert(oldPart.storage.serde != Some(expectedSerde), "bad test: serde was already set")
+ assert(oldPart.storage.properties.filterKeys(expectedSerdeProps.contains) !=
expectedSerdeProps, "bad test: serde properties were already set")
sql(s"""ALTER TABLE boxes PARTITION (width=4)
| SET SERDE '$expectedSerde'
@@ -1735,7 +1735,7 @@ class HiveDDLSuite
Seq("json", "parquet").foreach { format =>
withTable("rectangles") {
data.write.format(format).saveAsTable("rectangles")
assume(spark.table("rectangles").collect().nonEmpty,
assert(spark.table("rectangles").collect().nonEmpty,
"bad test; table was empty to begin with")

sql("TRUNCATE TABLE rectangles")
@@ -212,7 +212,7 @@ private[hive] class TestHiveSparkSession(
}
}

- assume(sc.conf.get(CATALOG_IMPLEMENTATION) == "hive")
+ assert(sc.conf.get(CATALOG_IMPLEMENTATION) == "hive")

@transient
override lazy val sharedState: TestHiveSharedState = {
@@ -23,6 +23,6 @@ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
class BucketedReadWithHiveSupportSuite extends BucketedReadSuite with TestHiveSingleton {
protected override def beforeAll(): Unit = {
super.beforeAll()
- assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive")
+ assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive")
}
}
@@ -23,7 +23,7 @@ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
class BucketedWriteWithHiveSupportSuite extends BucketedWriteSuite with TestHiveSingleton {
protected override def beforeAll(): Unit = {
super.beforeAll()
- assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive")
+ assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive")
}

override protected def fileFormatsToTest: Seq[String] = Seq("parquet", "orc")
