From 032d17933b4009ed8a9d70585434ccdbf4d1d7df Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Wed, 10 Jun 2020 16:38:59 +0900
Subject: [PATCH] [SPARK-31945][SQL][PYSPARK] Enable cache for the same Python
 function

### What changes were proposed in this pull request?

This PR proposes to make `PythonFunction` hold `Seq[Byte]` instead of `Array[Byte]` so that the cache manager can compare the serialized command bytes by value.

### Why are the changes needed?

Currently the cache manager doesn't use the cache for a `udf` if the `udf` is created again, even if the underlying function is the same:

```py
>>> func = lambda x: x

>>> df = spark.range(1)
>>> df.select(udf(func)("id")).cache()
```
```py
>>> df.select(udf(func)("id")).explain()
== Physical Plan ==
*(2) Project [pythonUDF0#14 AS (id)#12]
+- BatchEvalPython [(id#0L)], [pythonUDF0#14]
   +- *(1) Range (0, 1, step=1, splits=12)
```

This is because `PythonFunction` holds `Array[Byte]`, and the `equals` method of `Array` returns `true` only when both arrays are the same instance.

### Does this PR introduce _any_ user-facing change?

Yes. If the user reuses the Python function for the UDF, the cache manager will detect the same function and use the cache for it.

### How was this patch tested?

Added a test case, and also tested manually:

```py
>>> df.select(udf(func)("id")).explain()
== Physical Plan ==
InMemoryTableScan [(id)#12]
   +- InMemoryRelation [(id)#12], StorageLevel(disk, memory, deserialized, 1 replicas)
         +- *(2) Project [pythonUDF0#5 AS (id)#3]
            +- BatchEvalPython [(id#0L)], [pythonUDF0#5]
               +- *(1) Range (0, 1, step=1, splits=12)
```

Closes #28774 from ueshin/issues/SPARK-31945/udf_cache.

Authored-by: Takuya UESHIN
Signed-off-by: HyukjinKwon
---
 .../org/apache/spark/api/python/PythonRDD.scala    | 16 ++++++++++++++--
 .../apache/spark/api/python/PythonRunner.scala     |  2 +-
 python/pyspark/sql/tests/test_udf.py               |  9 +++++++++
 .../sql/execution/python/PythonUDFRunner.scala     |  2 +-
 4 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index a577194a48006..726cff6703dcb 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -74,13 +74,25 @@ private[spark] class PythonRDD(
  * runner.
  */
 private[spark] case class PythonFunction(
-    command: Array[Byte],
+    command: Seq[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: PythonAccumulatorV2)
+    accumulator: PythonAccumulatorV2) {
+
+  def this(
+      command: Array[Byte],
+      envVars: JMap[String, String],
+      pythonIncludes: JList[String],
+      pythonExec: String,
+      pythonVer: String,
+      broadcastVars: JList[Broadcast[PythonBroadcast]],
+      accumulator: PythonAccumulatorV2) = {
+    this(command.toSeq, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator)
+  }
+}
 
 /**
  * A wrapper for chained Python functions (from bottom to top).
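As background for the change above: `Seq` compares elements structurally, while Scala's `Array` inherits Java reference equality, so two identical serialized commands never compare equal. A minimal, self-contained Scala sketch of the difference (the `HoldsArray`/`HoldsSeq` case classes are hypothetical stand-ins for `PythonFunction`, not Spark code):

```scala
object ArrayVsSeqEquality {
  // Hypothetical stand-ins for PythonFunction's `command` field.
  case class HoldsArray(command: Array[Byte])
  case class HoldsSeq(command: Seq[Byte])

  def main(args: Array[String]): Unit = {
    val a: Array[Byte] = Array(1, 2, 3)
    val b: Array[Byte] = Array(1, 2, 3)

    // Array uses reference equality: equal contents are not enough.
    println(a == b)                                  // false
    println(HoldsArray(a) == HoldsArray(b))          // false: cached plans never match

    // Seq compares element by element, so case-class equality works.
    println(a.toSeq == b.toSeq)                      // true
    println(HoldsSeq(a.toSeq) == HoldsSeq(b.toSeq))  // true: cache lookup succeeds
  }
}
```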
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index f34316424c4ca..d7a09b599794e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -613,7 +613,7 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
     protected override def writeCommand(dataOut: DataOutputStream): Unit = {
       val command = funcs.head.funcs.head.command
       dataOut.writeInt(command.length)
-      dataOut.write(command)
+      dataOut.write(command.toArray)
     }
 
     protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 061d3f5e1f7ac..2689b9c33d576 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -642,6 +642,15 @@ def f(*a):
         r = df.select(fUdf(*df.columns))
         self.assertEqual(r.first()[0], "success")
 
+    def test_udf_cache(self):
+        func = lambda x: x
+
+        df = self.spark.range(1)
+        df.select(udf(func)("id")).cache()
+
+        self.assertEqual(df.select(udf(func)("id"))._jdf.queryExecution()
+                         .withCachedData().getClass().getSimpleName(), 'InMemoryRelation')
+
 
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 0a250b27ccb94..d341d7019f0ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -104,7 +104,7 @@ object PythonUDFRunner {
       dataOut.writeInt(chained.funcs.length)
       chained.funcs.foreach { f =>
         dataOut.writeInt(f.command.length)
-        dataOut.write(f.command)
+        dataOut.write(f.command.toArray)
       }
     }
   }
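At the write sites above, only a `toArray` conversion is needed because `java.io.DataOutputStream.write` accepts `Array[Byte]`; the wire format (a 4-byte length prefix followed by the raw command bytes) is unchanged. A self-contained sketch of the pattern (hypothetical names, not the actual `PythonRunner` code):

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

object WriteCommandSketch {
  // The command is held as Seq[Byte] for value equality and converted
  // back to Array[Byte] only at the moment it is written.
  def writeCommand(dataOut: DataOutputStream, command: Seq[Byte]): Unit = {
    dataOut.writeInt(command.length)  // same length prefix as before
    dataOut.write(command.toArray)    // DataOutputStream only accepts Array[Byte]
  }

  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream()
    val command = "pickled-udf-bytes".getBytes("UTF-8").toSeq
    writeCommand(new DataOutputStream(buffer), command)
    println(buffer.size == 4 + command.length)  // true: length prefix + payload
  }
}
```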