Skip to content

Commit

Permalink
[SPARK-28321][SQL] 0-args Java UDF should not be called only once
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

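Only the 0-args Java UDF path invokes the user function eagerly, before it is wrapped into an expression.
As a result, the function is called once on the driver side and the UDF always returns that same value.
This looks like a mistake; the call should be deferred so the function is invoked each time the UDF is evaluated.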

## How was this patch tested?

A unit test was added to UDFSuite.

Closes apache#25108 from HyukjinKwon/SPARK-28321.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
HyukjinKwon authored and vinodkc committed Jul 18, 2019
1 parent 50add0f commit 6280475
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
val version = if (i == 0) "2.3.0" else "1.3.0"
val funcCall = if (i == 0) "() => func" else "func"
val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
println(s"""
|/**
| * Register a deterministic Java UDF$i instance as user-defined function (UDF).
| * @since $version
| */
|def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
| val func = f$anyCast.call($anyParams)
| val func = $funcCall
| def builder(e: Seq[Expression]) = if (e.length == $i) {
| ScalaUDF($funcCall, returnType, e, e.map(_ => false), udfName = Some(name))
| ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
| } else {
| throw new AnalysisException("Invalid number of arguments for function " + name +
| ". Expected: $i; Found: " + e.length)
Expand Down Expand Up @@ -717,9 +717,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 2.3.0
*/
def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF0[Any]].call()
val func = () => f.asInstanceOf[UDF0[Any]].call()
def builder(e: Seq[Expression]) = if (e.length == 0) {
ScalaUDF(() => func, returnType, e, e.map(_ => false), udfName = Some(name))
ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 0; Found: " + e.length)
Expand Down
10 changes: 5 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3932,7 +3932,7 @@ object functions {
val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
val funcCall = if (i == 0) "() => func" else "func"
val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
println(s"""
|/**
| * Defines a Java UDF$i instance as user-defined function (UDF).
Expand All @@ -3944,8 +3944,8 @@ object functions {
| * @since 2.3.0
| */
|def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = {
| val func = f$anyCast.call($anyParams)
| SparkUserDefinedFunction($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
| val func = $funcCall
| SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill($i)(None))
|}""".stripMargin)
}
Expand Down Expand Up @@ -4145,8 +4145,8 @@ object functions {
* @since 2.3.0
*/
def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = {
val func = f.asInstanceOf[UDF0[Any]].call()
SparkUserDefinedFunction(() => func, returnType, inputSchemas = Seq.fill(0)(None))
val func = () => f.asInstanceOf[UDF0[Any]].call()
SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(0)(None))
}

/**
Expand Down
9 changes: 9 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -514,4 +514,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
assert(df.collect().toSeq === Seq(Row(expected)))
}
}

// Regression test for SPARK-28321: a 0-args Java UDF must be invoked on each
// evaluation rather than once up front with the result reused.
test("SPARK-28321 0-args Java UDF should not be called only once") {
// Zero-argument Java UDF that returns a fresh random Int on every call().
val nonDeterministicJavaUDF = udf(
new UDF0[Int] {
override def call(): Int = scala.util.Random.nextInt()
}, IntegerType).asNondeterministic()

// Two input rows must produce two distinct values; if call() ran only once,
// both rows would carry the same value and distinct().count() would be 1.
// NOTE(review): two random Ints could collide, but only with probability ~2^-32.
assert(spark.range(2).select(nonDeterministicJavaUDF()).distinct().count() == 2)
}
}

0 comments on commit 6280475

Please sign in to comment.