diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index 9b74b456b88e7..7e52e6924d9fc 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -63,7 +63,7 @@ license: |
 
   - Since Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match to the pattern defined by the JSON option `timestampFormat`. Set JSON option `inferTimestamp` to `false` to disable such type inferring.
 
-  - In Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(Any, DataType)` gets a Scala closure with primitive-type argument, the returned UDF will return null if the input values is null. Since Spark 3.0, the UDF will return the default value of the Java type if the input value is null. For example, `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` will return null in Spark 2.4 and earlier if column `x` is null, and return 0 in Spark 3.0. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.
+  - Since Spark 3.0, using `org.apache.spark.sql.functions.udf(AnyRef, DataType)` is not allowed by default. Set `spark.sql.legacy.allowUntypedScalaUDF` to true to keep using it. Note that in Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(AnyRef, DataType)` gets a Scala closure with a primitive-type argument, the returned UDF returns null if the input value is null, whereas since Spark 3.0 the UDF returns the default value of the Java type. For example, given `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` returns null in Spark 2.4 and earlier if column `x` is null, and returns 0 in Spark 3.0. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.
 
   - Since Spark 3.0, Proleptic Gregorian calendar is used in parsing, formatting, and converting dates and timestamps as well as in extracting sub-components like years, days and etc. Spark 3.0 uses Java 8 API classes from the java.time packages that based on ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html). In Spark version 2.4 and earlier, those operations are performed by using the hybrid calendar (Julian + Gregorian, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html). The changes impact on the results for dates before October 15, 1582 (Gregorian) and affect on the following Spark 3.0 API:
 
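As a quick migration aid, here is a minimal sketch (the local `SparkSession` and column names are assumptions, not part of the patch) contrasting the two APIs described in the note above; the typed variant carries the input type, so a null input is expected to stay null rather than become the Java default value:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.IntegerType

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Untyped API: in Spark 3.0 this throws AnalysisException unless
// spark.sql.legacy.allowUntypedScalaUDF is set to true; when it is allowed,
// a null in column "x" reaches the closure as 0.
// val untyped = udf((x: Int) => x + 1, IntegerType)

// Typed API: the Int input type is captured, so Spark can keep a null input
// null instead of substituting the Java default value.
val typed = udf((x: Int) => x + 1)
Seq(Integer.valueOf(1), null).toDF("x").select(typed($"x")).show()
```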
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 7874fc29db6c8..1652131a9003a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml
 
 import scala.annotation.varargs
+import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.internal.Logging
@@ -79,7 +80,7 @@ abstract class Transformer extends PipelineStage {
  * result as a new column.
  */
 @DeveloperApi
-abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
+abstract class UnaryTransformer[IN: TypeTag, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
   extends Transformer with HasInputCol with HasOutputCol with Logging {
 
   /** @group setParam */
@@ -118,7 +119,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema, logging = true)
-    val transformUDF = udf(this.createTransformFunc, outputDataType)
+    val transformUDF = udf(this.createTransformFunc)
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))),
       outputSchema($(outputCol)).metadata)
   }
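Since `UnaryTransformer` now takes `IN: TypeTag` and `OUT: TypeTag` context bounds, subclasses only need element types that Scala reflection can tag; a rough sketch of a hypothetical subclass follows (the class and names are illustrative, not part of this patch):

```scala
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{DataType, StringType}

// With the change above, transform() builds its UDF via udf(createTransformFunc),
// deriving types from the IN/OUT TypeTags; outputDataType is still used for
// schema validation.
class UpperCaser(override val uid: String)
  extends UnaryTransformer[String, String, UpperCaser] {

  def this() = this(Identifiable.randomUID("upperCaser"))

  override protected def createTransformFunc: String => String = _.toUpperCase

  override protected def outputDataType: DataType = StringType

  override def copy(extra: ParamMap): UpperCaser = defaultCopy(extra)
}
```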
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index 01741019fb546..6d5c7c50dbacc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -98,7 +98,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val transformUDF = udf(hashFunction(_: Vector), DataTypes.createArrayType(new VectorUDT))
+    val transformUDF = udf(hashFunction(_: Vector))
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
   }
 
@@ -128,14 +128,13 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
       }
 
       // In the origin dataset, find the hash value that hash the same bucket with the key
-      val sameBucketWithKeyUDF = udf((x: Seq[Vector]) =>
-        sameBucket(x, keyHash), DataTypes.BooleanType)
+      val sameBucketWithKeyUDF = udf((x: Seq[Vector]) => sameBucket(x, keyHash))
 
       modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
     } else {
       // In the origin dataset, find the hash value that is closest to the key
       // Limit the use of hashDist since it's controversial
-      val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
+      val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash))
       val hashDistCol = hashDistUDF(col($(outputCol)))
       val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)
 
@@ -172,7 +171,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
     }
 
     // Get the top k nearest neighbor by their distance to the key
-    val keyDistUDF = udf((x: Vector) => keyDistance(x, key), DataTypes.DoubleType)
+    val keyDistUDF = udf((x: Vector) => keyDistance(x, key))
     val modelSubsetWithDistCol = modelSubset.withColumn(distCol, keyDistUDF(col($(inputCol))))
     modelSubsetWithDistCol.sort(distCol).limit(numNearestNeighbors)
   }
@@ -290,7 +289,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
       .drop(explodeCols: _*).distinct()
 
     // Add a new column to store the distance of the two rows.
-    val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType)
+    val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y))
     val joinedDatasetWithDist = joinedDataset.select(col("*"),
       distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol)
     )
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 4d001c159eda0..e50d4255b1f37 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -29,10 +29,10 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasPredictionCol
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
-import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
-  FPGrowth => MLlibFPGrowth}
+import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth}
 import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
 import org.apache.spark.sql._
+import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -286,14 +286,17 @@ class FPGrowthModel private[ml] (
     val dt = dataset.schema($(itemsCol)).dataType
 
     // For each rule, examine the input items and summarize the consequents
-    val predictUDF = udf((items: Seq[Any]) => {
+    val predictUDF = SparkUserDefinedFunction((items: Seq[Any]) => {
       if (items != null) {
         val itemset = items.toSet
         brRules.value.filter(_._1.forall(itemset.contains))
           .flatMap(_._2.filter(!itemset.contains(_))).distinct
       } else {
         Seq.empty
-      }}, dt)
+      }},
+      dt,
+      Nil
+    )
     dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol))))
   }
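The FPGrowth change keeps an untyped UDF because the element type of `itemsCol` is only known at runtime; inside Spark's own packages this is now done by constructing the `private[spark]` `SparkUserDefinedFunction` directly. A minimal sketch of that internal pattern follows (the package and object are hypothetical, and the code must live under `org.apache.spark` to see the private API):

```scala
package org.apache.spark.example // hypothetical package inside org.apache.spark

import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction}
import org.apache.spark.sql.types.DataType

object RuntimeTypedUdf {
  // Mirrors FPGrowthModel.transform above: when the return type is only known
  // at runtime, pass the DataType explicitly and leave inputSchemas empty (Nil),
  // since no compile-time type information is available.
  def make(f: Seq[Any] => Seq[Any], returnType: DataType): UserDefinedFunction =
    SparkUserDefinedFunction(f, returnType, Nil)
}
```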
datasetA.as("a").crossJoin(datasetB.as("b")) .filter(distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")) < threshold) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 65ffa228eddec..2a72284c011f5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -74,6 +74,9 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.getRuns"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.setRuns"), + // [SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.UnaryTransformer.this"), + // [SPARK-27090][CORE] Removing old LEGACY_DRIVER_IDENTIFIER ("") ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.LEGACY_DRIVER_IDENTIFIER"), // [SPARK-25838] Remove formatVersion from Saveable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d756014c393cd..7f3cbe31b5e02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2017,6 +2017,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_ALLOW_UNTYPED_SCALA_UDF = + buildConf("spark.sql.legacy.allowUntypedScalaUDF") + .internal() + .doc("When set to true, user is allowed to use org.apache.spark.sql.functions." + + "udf(f: AnyRef, dataType: DataType). Otherwise, exception will be throw.") + .booleanConf + .createWithDefault(false) + val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL = buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 85b2cd379ba24..c50168cf7ac13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -90,7 +90,7 @@ sealed abstract class UserDefinedFunction { def asNondeterministic(): UserDefinedFunction } -private[sql] case class SparkUserDefinedFunction( +private[spark] case class SparkUserDefinedFunction( f: AnyRef, dataType: DataType, inputSchemas: Seq[Option[ScalaReflection.Schema]], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2d5504ac00ffa..c60df14f04817 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4733,6 +4733,15 @@ object functions { * @since 2.0.0 */ def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { + if (!SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF)) { + val errorMsg = "You're using untyped Scala UDF, which does not have the input type " + + "information. Spark may blindly pass null to the Scala closure with primitive-type " + + "argument, and the closure will see the default value of the Java type for the null " + + "argument, e.g. `udf((x: Int) => x, IntegerType)`, the result is 0 for null input. " + + "You could use other typed Scala UDF APIs to avoid this problem, or set " + + s"${SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key} to true and use this API with caution." 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2d5504ac00ffa..c60df14f04817 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4733,6 +4733,15 @@ object functions {
    * @since 2.0.0
    */
   def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
+    if (!SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF)) {
+      val errorMsg = "You're using untyped Scala UDF, which does not have the input type " +
+        "information. Spark may blindly pass null to the Scala closure with primitive-type " +
+        "argument, and the closure will see the default value of the Java type for the null " +
+        "argument, e.g. `udf((x: Int) => x, IntegerType)`, the result is 0 for null input. " +
+        "You could use other typed Scala UDF APIs to avoid this problem, or set " +
+        s"${SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key} to true and use this API with caution."
+      throw new AnalysisException(errorMsg)
+    }
     SparkUserDefinedFunction(f, dataType, inputSchemas = Nil)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index cc3995516dcc2..cbe2e91a20d61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -134,10 +134,12 @@ class UDFSuite extends QueryTest with SharedSparkSession {
     assert(df1.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
     assert(df1.head().getDouble(0) >= 0.0)
 
-    val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic()
-    val df2 = testData.select(bar())
-    assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
-    assert(df2.head().getDouble(0) >= 0.0)
+    withSQLConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key -> "true") {
+      val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic()
+      val df2 = testData.select(bar())
+      assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
+      assert(df2.head().getDouble(0) >= 0.0)
+    }
 
     val javaUdf = udf(new UDF0[Double] {
       override def call(): Double = Math.random()
@@ -441,16 +443,23 @@ class UDFSuite extends QueryTest with SharedSparkSession {
   }
 
   test("SPARK-25044 Verify null input handling for primitive types - with udf(Any, DataType)") {
-    val f = udf((x: Int) => x, IntegerType)
-    checkAnswer(
-      Seq(Integer.valueOf(1), null).toDF("x").select(f($"x")),
-      Row(1) :: Row(0) :: Nil)
+    withSQLConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key -> "true") {
+      val f = udf((x: Int) => x, IntegerType)
+      checkAnswer(
+        Seq(Integer.valueOf(1), null).toDF("x").select(f($"x")),
+        Row(1) :: Row(0) :: Nil)
+
+      val f2 = udf((x: Double) => x, DoubleType)
+      checkAnswer(
+        Seq(java.lang.Double.valueOf(1.1), null).toDF("x").select(f2($"x")),
+        Row(1.1) :: Row(0.0) :: Nil)
+    }
 
-    val f2 = udf((x: Double) => x, DoubleType)
-    checkAnswer(
-      Seq(java.lang.Double.valueOf(1.1), null).toDF("x").select(f2($"x")),
-      Row(1.1) :: Row(0.0) :: Nil)
+  }
+
+  test("use untyped Scala UDF should fail by default") {
+    val e = intercept[AnalysisException](udf((x: Int) => x, IntegerType))
+    assert(e.getMessage.contains("You're using untyped Scala UDF"))
   }
 
   test("SPARK-26308: udf with decimal") {
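With the untyped overload now throwing `AnalysisException` by default (as the new test asserts), callers have two forward paths besides the legacy flag; a hedged sketch of both, with illustrative names:

```scala
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.IntegerType

// Preferred: a typed Scala UDF; input and output types are inferred, and an
// Option return value handles null inputs cleanly.
val strLenTyped = udf((s: String) => Option(s).map(_.length))

// If the return DataType really must be supplied explicitly, the Java UDF
// overloads still accept one and are not affected by the new check.
val strLenJava = udf(new UDF1[String, Integer] {
  override def call(s: String): Integer =
    if (s == null) null else Integer.valueOf(s.length)
}, IntegerType)
```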