diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index a577194a48006..726cff6703dcb 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -74,13 +74,25 @@ private[spark] class PythonRDD(
  * runner.
  */
 private[spark] case class PythonFunction(
-    command: Array[Byte],
+    command: Seq[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: PythonAccumulatorV2)
+    accumulator: PythonAccumulatorV2) {
+
+  def this(
+      command: Array[Byte],
+      envVars: JMap[String, String],
+      pythonIncludes: JList[String],
+      pythonExec: String,
+      pythonVer: String,
+      broadcastVars: JList[Broadcast[PythonBroadcast]],
+      accumulator: PythonAccumulatorV2) = {
+    this(command.toSeq, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator)
+  }
+}
 
 /**
  * A wrapper for chained Python functions (from bottom to top).
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index f34316424c4ca..d7a09b599794e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -613,7 +613,7 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
   protected override def writeCommand(dataOut: DataOutputStream): Unit = {
     val command = funcs.head.funcs.head.command
     dataOut.writeInt(command.length)
-    dataOut.write(command)
+    dataOut.write(command.toArray)
   }
 
   protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 061d3f5e1f7ac..2689b9c33d576 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -642,6 +642,15 @@ def f(*a):
         r = df.select(fUdf(*df.columns))
         self.assertEqual(r.first()[0], "success")
 
+    def test_udf_cache(self):
+        func = lambda x: x
+
+        df = self.spark.range(1)
+        df.select(udf(func)("id")).cache()
+
+        self.assertEqual(df.select(udf(func)("id"))._jdf.queryExecution()
+                         .withCachedData().getClass().getSimpleName(), 'InMemoryRelation')
+
 
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 0a250b27ccb94..d341d7019f0ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -104,7 +104,7 @@ object PythonUDFRunner {
       dataOut.writeInt(chained.funcs.length)
       chained.funcs.foreach { f =>
         dataOut.writeInt(f.command.length)
-        dataOut.write(f.command)
+        dataOut.write(f.command.toArray)
       }
     }
   }
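
Why the `Array[Byte]` → `Seq[Byte]` change makes the new `test_udf_cache` test pass: case class equality in Scala delegates to field equality, and Scala arrays (like Java arrays) compare by reference, while `Seq` compares by contents. So two `PythonFunction`s built from the same pickled UDF bytes were never equal with an `Array[Byte]` field, and the cache lookup could not match the cached plan. The sketch below is not part of the patch; the class names are illustrative only, but the equality behavior it demonstrates is standard Scala.

```scala
// Minimal sketch (hypothetical names) showing the equality difference
// that this patch relies on.
object EqualitySketch {
  case class WithArray(command: Array[Byte]) // old field type
  case class WithSeq(command: Seq[Byte])     // new field type

  def main(args: Array[String]): Unit = {
    // Two structurally identical "pickled commands".
    val a = "pickled-udf".getBytes("UTF-8")
    val b = "pickled-udf".getBytes("UTF-8")

    // Array fields use reference equality, so the case classes differ
    // even though the bytes are identical.
    println(WithArray(a) == WithArray(b)) // false

    // Seq fields compare element-by-element, so equal bytes mean
    // equal case classes, and a cached plan can be recognized.
    println(WithSeq(a.toSeq) == WithSeq(b.toSeq)) // true
  }
}
```

The `.toArray` calls added in `PythonRunner` and `PythonUDFRunner` convert back at the serialization boundary, since `DataOutputStream.write` takes an `Array[Byte]`; the secondary constructor keeps the existing `Array[Byte]` call sites compiling unchanged.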