[SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default

### What changes were proposed in this pull request?

This PR proposes to throw an exception by default when a user uses an untyped UDF (a.k.a. `org.apache.spark.sql.functions.udf(AnyRef, DataType)`).

Users can still use it by setting `spark.sql.legacy.allowUntypedScalaUDF` to `true`.

### Why are the changes needed?

According to apache#23498, since Spark 3.0 the untyped UDF returns the default value of the Java type if the input value is null. For example, with `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` returns 0 in Spark 3.0 but null in Spark 2.4. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.

As a result, this might change data silently and may cause correctness issues if users still expect `null` in some cases. Thus, we'd better encourage users to use typed UDFs to avoid this problem.
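
For illustration, a minimal sketch of the behavior difference (not part of the PR; it assumes a Spark 3.0 session with `spark.implicits._` in scope and the legacy flag enabled so the untyped API still resolves):

```scala
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.IntegerType

// Untyped UDF: no input type information, so a null Int input reaches the
// closure as the Java default (0) and the result is 0 instead of null.
val untyped = udf((x: Int) => x, IntegerType)

// Typed UDF: Spark knows the argument is a primitive Int, so a null input
// is turned into a null result instead of silently becoming 0.
val typed = udf((x: Int) => x)

// Seq(Integer.valueOf(1), null).toDF("x").select(untyped($"x"), typed($"x"))
//   Spark 3.0: untyped -> 0 for the null row, typed -> null
//   Spark 2.4: untyped -> null for the null row
```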

### Does this PR introduce any user-facing change?

Yes. Users will now hit an exception when using an untyped UDF.
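
A hedged sketch of the new user-facing behavior (assuming an active `spark` session; the config key is the `spark.sql.legacy.allowUntypedScalaUDF` flag added to `SQLConf` in this PR):

```scala
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.IntegerType

// Fails analysis by default with an AnalysisException:
// "You're using untyped Scala UDF, which does not have the input type information. ..."
// val f = udf((x: Int) => x, IntegerType)

// Opt back in to the legacy behavior (use with caution: null inputs become 0):
spark.conf.set("spark.sql.legacy.allowUntypedScalaUDF", "true")
val f = udf((x: Int) => x, IntegerType)
```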

### How was this patch tested?

Added a new test and updated some existing tests.

Closes apache#27488 from Ngone51/spark_26580_followup.

Lead-authored-by: yi.wu <yi.wu@databricks.com>
Co-authored-by: wuyi <yi.wu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Ngone51 authored and Seongjin Cho committed Apr 14, 2020
1 parent 45f0f46 commit 419585a
Showing 10 changed files with 62 additions and 31 deletions.
2 changes: 1 addition & 1 deletion docs/sql-migration-guide.md
@@ -63,7 +63,7 @@ license: |

- Since Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match to the pattern defined by the JSON option `timestampFormat`. Set JSON option `inferTimestamp` to `false` to disable such type inferring.

- In Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(Any, DataType)` gets a Scala closure with primitive-type argument, the returned UDF will return null if the input values is null. Since Spark 3.0, the UDF will return the default value of the Java type if the input value is null. For example, `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` will return null in Spark 2.4 and earlier if column `x` is null, and return 0 in Spark 3.0. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.
- Since Spark 3.0, using `org.apache.spark.sql.functions.udf(AnyRef, DataType)` is not allowed by default. Set `spark.sql.legacy.allowUntypedScalaUDF` to true to keep using it. But please note that, in Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(AnyRef, DataType)` gets a Scala closure with primitive-type argument, the returned UDF will return null if the input values is null. However, since Spark 3.0, the UDF will return the default value of the Java type if the input value is null. For example, `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` will return null in Spark 2.4 and earlier if column `x` is null, and return 0 in Spark 3.0. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.

- Since Spark 3.0, Proleptic Gregorian calendar is used in parsing, formatting, and converting dates and timestamps as well as in extracting sub-components like years, days and etc. Spark 3.0 uses Java 8 API classes from the java.time packages that based on ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html). In Spark version 2.4 and earlier, those operations are performed by using the hybrid calendar (Julian + Gregorian, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html). The changes impact on the results for dates before October 15, 1582 (Gregorian) and affect on the following Spark 3.0 API:

5 changes: 3 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml

import scala.annotation.varargs
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
@@ -79,7 +80,7 @@ abstract class Transformer extends PipelineStage {
* result as a new column.
*/
@DeveloperApi
abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
abstract class UnaryTransformer[IN: TypeTag, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
extends Transformer with HasInputCol with HasOutputCol with Logging {

/** @group setParam */
@@ -118,7 +119,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]

override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema, logging = true)
val transformUDF = udf(this.createTransformFunc, outputDataType)
val transformUDF = udf(this.createTransformFunc)
dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))),
outputSchema($(outputCol)).metadata)
}
11 changes: 5 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -98,7 +98,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val transformUDF = udf(hashFunction(_: Vector), DataTypes.createArrayType(new VectorUDT))
val transformUDF = udf(hashFunction(_: Vector))
dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
}

@@ -128,14 +128,13 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
}

// In the origin dataset, find the hash value that hash the same bucket with the key
val sameBucketWithKeyUDF = udf((x: Seq[Vector]) =>
sameBucket(x, keyHash), DataTypes.BooleanType)
val sameBucketWithKeyUDF = udf((x: Seq[Vector]) => sameBucket(x, keyHash))

modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
} else {
// In the origin dataset, find the hash value that is closest to the key
// Limit the use of hashDist since it's controversial
val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash))
val hashDistCol = hashDistUDF(col($(outputCol)))
val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)

@@ -172,7 +171,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
}

// Get the top k nearest neighbor by their distance to the key
val keyDistUDF = udf((x: Vector) => keyDistance(x, key), DataTypes.DoubleType)
val keyDistUDF = udf((x: Vector) => keyDistance(x, key))
val modelSubsetWithDistCol = modelSubset.withColumn(distCol, keyDistUDF(col($(inputCol))))
modelSubsetWithDistCol.sort(distCol).limit(numNearestNeighbors)
}
@@ -290,7 +289,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
.drop(explodeCols: _*).distinct()

// Add a new column to store the distance of the two rows.
val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType)
val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y))
val joinedDatasetWithDist = joinedDataset.select(col("*"),
distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol)
)
11 changes: 7 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -29,10 +29,10 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasPredictionCol
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
FPGrowth => MLlibFPGrowth}
import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth}
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -286,14 +286,17 @@ class FPGrowthModel private[ml] (

val dt = dataset.schema($(itemsCol)).dataType
// For each rule, examine the input items and summarize the consequents
val predictUDF = udf((items: Seq[Any]) => {
val predictUDF = SparkUserDefinedFunction((items: Seq[Any]) => {
if (items != null) {
val itemset = items.toSet
brRules.value.filter(_._1.forall(itemset.contains))
.flatMap(_._2.filter(!itemset.contains(_))).distinct
} else {
Seq.empty
}}, dt)
}},
dt,
Nil
)
dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol))))
}

@@ -76,9 +76,8 @@ private[ml] object LSHTest {

// Perform a cross join and label each pair of same_bucket and distance
val pairs = transformedData.as("a").crossJoin(transformedData.as("b"))
val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0,
DataTypes.BooleanType)
val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y))
val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0)
val result = pairs
.withColumn("same_bucket", sameBucket(col(s"a.$outputCol"), col(s"b.$outputCol")))
.withColumn("distance", distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")))
@@ -110,7 +109,7 @@ private[ml] object LSHTest {
val model = lsh.fit(dataset)

// Compute expected
val distUDF = udf((x: Vector) => model.keyDistance(x, key), DataTypes.DoubleType)
val distUDF = udf((x: Vector) => model.keyDistance(x, key))
val expected = dataset.sort(distUDF(col(model.getInputCol))).limit(k)

// Compute actual
@@ -148,7 +147,7 @@ private[ml] object LSHTest {
val inputCol = model.getInputCol

// Compute expected
val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y))
val expected = datasetA.as("a").crossJoin(datasetB.as("b"))
.filter(distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")) < threshold)

3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
@@ -74,6 +74,9 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.getRuns"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.setRuns"),

// [SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.UnaryTransformer.this"),

// [SPARK-27090][CORE] Removing old LEGACY_DRIVER_IDENTIFIER ("<driver>")
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.LEGACY_DRIVER_IDENTIFIER"),
// [SPARK-25838] Remove formatVersion from Saveable
@@ -2017,6 +2017,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_ALLOW_UNTYPED_SCALA_UDF =
buildConf("spark.sql.legacy.allowUntypedScalaUDF")
.internal()
.doc("When set to true, user is allowed to use org.apache.spark.sql.functions." +
"udf(f: AnyRef, dataType: DataType). Otherwise, exception will be throw.")
.booleanConf
.createWithDefault(false)

val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL =
buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled")
.internal()
@@ -90,7 +90,7 @@ sealed abstract class UserDefinedFunction {
def asNondeterministic(): UserDefinedFunction
}

private[sql] case class SparkUserDefinedFunction(
private[spark] case class SparkUserDefinedFunction(
f: AnyRef,
dataType: DataType,
inputSchemas: Seq[Option[ScalaReflection.Schema]],
9 changes: 9 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4733,6 +4733,15 @@ object functions {
* @since 2.0.0
*/
def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
if (!SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF)) {
val errorMsg = "You're using untyped Scala UDF, which does not have the input type " +
"information. Spark may blindly pass null to the Scala closure with primitive-type " +
"argument, and the closure will see the default value of the Java type for the null " +
"argument, e.g. `udf((x: Int) => x, IntegerType)`, the result is 0 for null input. " +
"You could use other typed Scala UDF APIs to avoid this problem, or set " +
s"${SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key} to true and use this API with caution."
throw new AnalysisException(errorMsg)
}
SparkUserDefinedFunction(f, dataType, inputSchemas = Nil)
}

33 changes: 21 additions & 12 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -134,10 +134,12 @@ class UDFSuite extends QueryTest with SharedSparkSession {
assert(df1.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
assert(df1.head().getDouble(0) >= 0.0)

val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic()
val df2 = testData.select(bar())
assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
assert(df2.head().getDouble(0) >= 0.0)
withSQLConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key -> "true") {
val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic()
val df2 = testData.select(bar())
assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
assert(df2.head().getDouble(0) >= 0.0)
}

val javaUdf = udf(new UDF0[Double] {
override def call(): Double = Math.random()
@@ -441,16 +443,23 @@ class UDFSuite extends QueryTest with SharedSparkSession {
}

test("SPARK-25044 Verify null input handling for primitive types - with udf(Any, DataType)") {
val f = udf((x: Int) => x, IntegerType)
checkAnswer(
Seq(Integer.valueOf(1), null).toDF("x").select(f($"x")),
Row(1) :: Row(0) :: Nil)
withSQLConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key -> "true") {
val f = udf((x: Int) => x, IntegerType)
checkAnswer(
Seq(Integer.valueOf(1), null).toDF("x").select(f($"x")),
Row(1) :: Row(0) :: Nil)

val f2 = udf((x: Double) => x, DoubleType)
checkAnswer(
Seq(java.lang.Double.valueOf(1.1), null).toDF("x").select(f2($"x")),
Row(1.1) :: Row(0.0) :: Nil)
}

val f2 = udf((x: Double) => x, DoubleType)
checkAnswer(
Seq(java.lang.Double.valueOf(1.1), null).toDF("x").select(f2($"x")),
Row(1.1) :: Row(0.0) :: Nil)
}

test("use untyped Scala UDF should fail by default") {
val e = intercept[AnalysisException](udf((x: Int) => x, IntegerType))
assert(e.getMessage.contains("You're using untyped Scala UDF"))
}

test("SPARK-26308: udf with decimal") {
