[SPARK-18667][PYSPARK][SQL] Change the way to group rows in BatchEvalPythonExec so the input_file_name function can work with UDFs in PySpark

## What changes were proposed in this pull request?

`input_file_name` doesn't return the file name when it is used with a UDF in PySpark. An example shows the problem:

    from pyspark.sql.functions import *
    from pyspark.sql.types import *

    def filename(path):
        return path

    sourceFile = udf(filename, StringType())
    spark.read.json("tmp.json").select(sourceFile(input_file_name())).show()

    +---------------------------+
    |filename(input_file_name())|
    +---------------------------+
    |                           |
    +---------------------------+

The cause of this issue is that we group rows in `BatchEvalPythonExec` for batch processing of Python UDFs. Currently we group the rows first and then evaluate expressions on them. If the data contains fewer rows than the required group size, the iterator is consumed to the end before the evaluation happens. However, once the iterator reaches the end, we unset the input file name, so the `input_file_name` expression can't return the correct file name.
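To make the ordering concrete, here is a minimal standalone Scala sketch (not the actual Spark code; `readRow` and the toy rows are made up for illustration) of the problematic group-then-evaluate order:

    // grouped(100) pulls up to 100 rows from the source iterator before map ever
    // sees the batch. With fewer than 100 rows the source is fully consumed --
    // and Spark unsets the task's input file name once the source is exhausted --
    // so every evaluation happens after the file name is already gone.
    def readRow(i: Int): String = { println(s"read row $i"); s"row$i" }
    val source = Iterator(1, 2, 3).map(readRow)
    source.grouped(100).map { batch =>
      batch.map { r => println(s"evaluate $r"); s"eval($r)" }
    }.foreach(println)
    // Prints all three "read row" lines before the first "evaluate" line.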

This patch changes how the batches of rows are formed: we evaluate the expressions first and then group the evaluated results into batches.
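Under the same toy setup, a sketch of the fixed evaluate-then-group order (again, not the actual Spark code): `map` on an iterator is lazy, so the projection runs as each row is pulled, while per-row state such as the input file name is still valid, and only already-evaluated results get grouped:

    // map is lazy: each row is evaluated right after it is read, interleaved
    // with the reads; grouped then batches results that are already computed.
    def readRow(i: Int): String = { println(s"read row $i"); s"row$i" }
    val source = Iterator(1, 2, 3).map(readRow)
    source.map { r => println(s"evaluate $r"); s"eval($r)" }
      .grouped(100)
      .foreach(println)
    // Prints "read row 1", "evaluate row1", "read row 2", "evaluate row2", ...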

## How was this patch tested?

Added a unit test to PySpark.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes apache#16115 from viirya/fix-py-udf-input-filename.
viirya authored and uzadude committed Jan 27, 2017
1 parent 6a64e6d commit 3c1f530
Showing 2 changed files with 24 additions and 19 deletions.
8 changes: 8 additions & 0 deletions python/pyspark/sql/tests.py
@@ -412,6 +412,14 @@ def test_udf_with_order_by_and_limit(self):
         res.explain(True)
         self.assertEqual(res.collect(), [Row(id=0, copy=0)])
 
+    def test_udf_with_input_file_name(self):
+        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.types import StringType
+        sourceFile = udf(lambda path: path, StringType())
+        filePath = "python/test_support/sql/people1.json"
+        row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
+        self.assertTrue(row[0].find("people1.json") != -1)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
35 changes: 16 additions & 19 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -119,26 +119,23 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
       val pickle = new Pickler(needConversion)
       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
-      val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { inputRow =>
-          queue.add(inputRow.asInstanceOf[UnsafeRow])
-          val row = projection(inputRow)
-          if (needConversion) {
-            EvaluatePython.toJava(row, schema)
-          } else {
-            // fast path for these types that does not need conversion in Python
-            val fields = new Array[Any](row.numFields)
-            var i = 0
-            while (i < row.numFields) {
-              val dt = dataTypes(i)
-              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
-              i += 1
-            }
-            fields
-          }
-        }.toArray
-        pickle.dumps(toBePickled)
-      }
+      val inputIterator = iter.map { inputRow =>
+        queue.add(inputRow.asInstanceOf[UnsafeRow])
+        val row = projection(inputRow)
+        if (needConversion) {
+          EvaluatePython.toJava(row, schema)
+        } else {
+          // fast path for these types that does not need conversion in Python
+          val fields = new Array[Any](row.numFields)
+          var i = 0
+          while (i < row.numFields) {
+            val dt = dataTypes(i)
+            fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+            i += 1
+          }
+          fields
+        }
+      }.grouped(100).map(x => pickle.dumps(x.toArray))
 
       val context = TaskContext.get()

