[SPARK-27992][SPARK-28881][PYTHON][2.4] Allow Python to join with connection thread to propagate errors

### What changes were proposed in this pull request?

This PR proposes to backport apache#24834 with minimised changes, and the tests added at apache#25594.

apache#24834 was not backported before because it mainly targeted a better error message by propagating the exception from the JVM.

However, that PR also accidentally fixed another problem (see apache#25594 and [SPARK-28881](https://issues.apache.org/jira/browse/SPARK-28881)). The regression appears to have been introduced by apache#21546.

The root cause seems to be that, in

https://github.com/apache/spark/blob/23bed0d3c08e03085d3f0c3a7d457eedd30bd67f/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L3370-L3384

`runJob` with a `resultHandler` can write partial output to the socket.

The JVM throws an exception but, because that exception is not propagated to the Python process, the Python side cannot tell whether the JVM failed; it only sees the socket being closed and silently returns whatever data arrived, which results in the following:

```
./bin/pyspark --conf spark.driver.maxResultSize=1m
```
```python
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.range(10000000).toPandas()
```
```
Empty DataFrame
Columns: [id]
Index: []
```

With this change, the Python process catches exceptions raised on the JVM side.
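
For illustration only, here is a minimal, self-contained sketch of the pattern this change adopts (hypothetical names, not Spark's actual classes): a producer writes to a socket from a background thread and may fail mid-write; the consumer reads until the stream closes, then joins the serving thread and re-raises any recorded exception instead of silently returning partial data.

```python
import socket
import threading


class ServingThread(threading.Thread):
    """Writes chunks to a socket in the background and records any exception."""

    def __init__(self, conn, chunks, fail_after=None):
        super(ServingThread, self).__init__()
        self.conn = conn
        self.chunks = chunks
        self.fail_after = fail_after
        self.exception = None

    def run(self):
        try:
            for i, chunk in enumerate(self.chunks):
                if self.fail_after is not None and i >= self.fail_after:
                    # Simulates a failure mid-write, e.g. exceeding a result-size limit.
                    raise RuntimeError("total size of results is bigger than the limit")
                self.conn.sendall(chunk)
        except Exception as e:
            self.exception = e
        finally:
            # The consumer only ever observes the socket closing.
            self.conn.close()

    def get_result(self):
        """Join the serving thread and re-raise any exception it recorded."""
        self.join()
        if self.exception is not None:
            raise self.exception


producer_sock, consumer_sock = socket.socketpair()
producer = ServingThread(producer_sock, [b"x" * 10] * 5, fail_after=2)
producer.start()

received = []
try:
    try:
        while True:
            data = consumer_sock.recv(1024)
            if not data:
                break
            received.append(data)
    finally:
        # Without this join, the partial `received` data would be returned silently.
        producer.get_result()
except RuntimeError as e:
    print("Caught producer failure instead of returning partial data: %s" % e)
```

In the real change, the join-and-re-raise step is `jsocket_auth_server.getResult()` inside `_collectAsArrow`, as shown in the diff below.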

### Why are the changes needed?

It fixes a correctness issue: incorrect (empty or partial) results can be returned when an exception happens on the JVM side. This is a regression; the same code works fine in Spark 2.3.3.

### Does this PR introduce any user-facing change?

Yes.

```
./bin/pyspark --conf spark.driver.maxResultSize=1m
```
```python
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.range(10000000).toPandas()
```

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../pyspark/sql/dataframe.py", line 2122, in toPandas
    batches = self._collectAsArrow()
  File "/.../pyspark/sql/dataframe.py", line 2184, in _collectAsArrow
    jsocket_auth_server.getResult()  # Join serving thread and raise any exceptions
  File "/.../lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
  File "/.../pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/.../lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o42.getResult.
: org.apache.spark.SparkException: Exception thrown in awaitResult:
    ...
Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Total size of serialized results of 1 tasks (6.5 MB) is bigger than spark.driver.maxResultSize (1024.0 KB)
```

The call now throws an exception as expected.

### How was this patch tested?

Manually tested as described above; a unit test was also added.
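
Assuming the standard PySpark test runner is available in this branch, the added test class can presumably be run on its own with something like the following (the exact invocation is an assumption, not part of this commit):

```
./python/run-tests --testnames 'pyspark.sql.tests MaxResultArrowTests'
```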

Closes apache#25593 from HyukjinKwon/SPARK-27992.

Lead-authored-by: HyukjinKwon <gurwls223@apache.org>
Co-authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
2 people authored and Raphaël Luta committed Sep 17, 2019
1 parent 1834ed3 commit 4a95667
Showing 4 changed files with 72 additions and 5 deletions.
37 changes: 37 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -440,6 +440,29 @@ private[spark] object PythonRDD extends Logging {
    Array(port, secret)
  }

+  /**
+   * Create a socket server object and background thread to execute the writeFunc
+   * with the given OutputStream.
+   *
+   * This is the same as serveToStream, only it returns a server object that
+   * can be used to sync in Python.
+   */
+  private[spark] def serveToStreamWithSync(
+      threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
+
+    val handleFunc = (sock: Socket) => {
+      val out = new BufferedOutputStream(sock.getOutputStream())
+      Utils.tryWithSafeFinally {
+        writeFunc(out)
+      } {
+        out.close()
+      }
+    }
+
+    val server = new SocketFuncServer(authHelper, threadName, handleFunc)
+    Array(server.port, server.secret, server)
+  }
+
  private def getMergedConf(confAsMap: java.util.HashMap[String, String],
      baseConf: Configuration): Configuration = {
    val conf = PythonHadoopUtil.mapToConf(confAsMap)
@@ -957,3 +980,17 @@ private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
  }
}

+/**
+ * Create a socket server class and run user function on the socket in a background thread.
+ * This is the same as calling SocketAuthServer.setupOneConnectionServer except it creates
+ * a server object that can then be synced from Python.
+ */
+private[spark] class SocketFuncServer(
+    authHelper: SocketAuthHelper,
+    threadName: String,
+    func: Socket => Unit) extends PythonServer[Unit](authHelper, threadName) {
+
+  override def handleConnection(sock: Socket): Unit = {
+    func(sock)
+  }
+}
10 changes: 7 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -2175,9 +2175,13 @@ def _collectAsArrow(self):
        .. note:: Experimental.
        """
-        with SCCallSiteSync(self._sc) as css:
-            sock_info = self._jdf.collectAsArrowToPython()
-        return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
+        with SCCallSiteSync(self._sc):
+            from pyspark.rdd import _load_from_socket
+            port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
+            try:
+                return list(_load_from_socket((port, auth_secret), ArrowStreamSerializer()))
+            finally:
+                jsocket_auth_server.getResult()  # Join serving thread and raise any exceptions

    ##########################################################################################
    # Pandas compatibility
28 changes: 27 additions & 1 deletion python/pyspark/sql/tests.py
@@ -80,7 +80,7 @@
_have_pyarrow = _pyarrow_requirement_message is None
_test_compiled = _test_not_compiled_message is None

-from pyspark import SparkContext
+from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
@@ -4550,6 +4550,32 @@ def test_timestamp_dst(self):
        self.assertPandasEqual(pdf, df_from_pandas.toPandas())


+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
+class MaxResultArrowTests(unittest.TestCase):
+    # These tests are separate as 'spark.driver.maxResultSize' configuration
+    # is a static configuration to Spark context.
+
+    @classmethod
+    def setUpClass(cls):
+        cls.spark = SparkSession(SparkContext(
+            'local[4]', cls.__name__, conf=SparkConf().set("spark.driver.maxResultSize", "10k")))
+
+        # Explicitly enable Arrow and disable fallback.
+        cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
+
+    @classmethod
+    def tearDownClass(cls):
+        if hasattr(cls, "spark"):
+            cls.spark.stop()
+
+    def test_exception_by_max_results(self):
+        with self.assertRaisesRegexp(Exception, "is bigger than"):
+            self.spark.range(0, 10000, 1, 100).toPandas()
+
+
class EncryptionArrowTests(ArrowTests):

    @classmethod
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3284,7 +3284,7 @@ class Dataset[T] private[sql](
    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone

    withAction("collectAsArrowToPython", queryExecution) { plan =>
-      PythonRDD.serveToStream("serve-Arrow") { out =>
+      PythonRDD.serveToStreamWithSync("serve-Arrow") { out =>
        val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
        val arrowBatchRdd = toArrowBatchRdd(plan)
        val numPartitions = arrowBatchRdd.partitions.length
