[SPARK-13807] De-duplicate Python*Helper instantiation code in PySpark streaming

This patch de-duplicates the code in PySpark streaming that loads the `Python*Helper` classes, consolidating the reflection boilerplate into a single `_get_helper` method per module. I also changed a few `raise e` statements to bare `raise` statements in order to preserve the full exception stack trace when re-throwing.

Here's a link to the whitespace-change-free diff: https://github.com/apache/spark/compare/master...JoshRosen:pyspark-reflection-deduplication?w=0
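For context, the duplicated block that each module now centralizes looks roughly like the following (a minimal sketch, not part of the patch; the standalone name `_load_helper` and the printed message are hypothetical). The helper classes ship in external jars, so they are loaded through the JVM thread's context class loader rather than referenced directly through Py4J:

    from py4j.protocol import Py4JJavaError

    def _load_helper(sc, class_name):
        # e.g. class_name = "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper"
        try:
            loader = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader()
            return loader.loadClass(class_name).newInstance()
        except Py4JJavaError as e:
            if 'ClassNotFoundException' in str(e.java_exception):
                print("Streaming helper jar not on the classpath")  # illustrative message
            raise  # bare raise keeps the original traceback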

Author: Josh Rosen <joshrosen@databricks.com>

Closes apache#11641 from JoshRosen/pyspark-reflection-deduplication.
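Regarding the `raise e` to `raise` change: under Python 2, which PySpark supports, re-raising a caught exception with `raise e` starts a new traceback at the re-raise site, whereas a bare `raise` re-raises the active exception with its original traceback intact. A minimal sketch of the behavior (the function name is made up for illustration):

    def fail():
        raise ValueError("original failure site")

    try:
        fail()
    except ValueError as e:
        # ... inspect or log e here ...
        # raise e   # Python 2: traceback would now start here, hiding fail()
        raise       # traceback still points into fail()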
JoshRosen authored and roygao94 committed Mar 22, 2016
1 parent bbd429e commit 3050c55
Showing 4 changed files with 60 additions and 84 deletions.
40 changes: 17 additions & 23 deletions python/pyspark/streaming/flume.py
@@ -55,17 +55,8 @@ def createStream(ssc, hostname, port,
         :return: A DStream object
         """
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
-
-        try:
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
-                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                FlumeUtils._printErrorMsg(ssc.sparkContext)
-            raise e
-
+        helper = FlumeUtils._get_helper(ssc._sc)
+        jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
         return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
 
     @staticmethod
@@ -95,18 +86,9 @@ def createPollingStream(ssc, addresses,
         for (host, port) in addresses:
             hosts.append(host)
             ports.append(port)
-
-        try:
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jstream = helper.createPollingStream(
-                ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                FlumeUtils._printErrorMsg(ssc.sparkContext)
-            raise e
-
+        helper = FlumeUtils._get_helper(ssc._sc)
+        jstream = helper.createPollingStream(
+            ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
         return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
 
     @staticmethod
@@ -126,6 +108,18 @@ def func(event):
                 return (headers, body)
             return stream.map(func)
 
+    @staticmethod
+    def _get_helper(sc):
+        try:
+            helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
+            return helperClass.newInstance()
+        except Py4JJavaError as e:
+            # TODO: use --jar once it also work on driver
+            if 'ClassNotFoundException' in str(e.java_exception):
+                FlumeUtils._printErrorMsg(sc)
+            raise
+
     @staticmethod
     def _printErrorMsg(sc):
         print("""
100 changes: 41 additions & 59 deletions python/pyspark/streaming/kafka.py
@@ -66,18 +66,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None,
         if not isinstance(topics, dict):
             raise TypeError("topics should be dict")
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
-
-        try:
-            # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
-                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
-        except Py4JJavaError as e:
-            # TODO: use --jar once it also work on driver
-            if 'ClassNotFoundException' in str(e.java_exception):
-                KafkaUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+        helper = KafkaUtils._get_helper(ssc._sc)
+        jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
         ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
         stream = DStream(jstream, ssc, ser)
         return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
@@ -129,27 +119,20 @@ def funcWithMessageHandler(m):
             m._set_value_decoder(valueDecoder)
             return messageHandler(m)
 
-        try:
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-
-            jfromOffsets = dict([(k._jTopicAndPartition(helper),
-                                  v) for (k, v) in fromOffsets.items()])
-            if messageHandler is None:
-                ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-                func = funcWithoutMessageHandler
-                jstream = helper.createDirectStreamWithoutMessageHandler(
-                    ssc._jssc, kafkaParams, set(topics), jfromOffsets)
-            else:
-                ser = AutoBatchedSerializer(PickleSerializer())
-                func = funcWithMessageHandler
-                jstream = helper.createDirectStreamWithMessageHandler(
-                    ssc._jssc, kafkaParams, set(topics), jfromOffsets)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                KafkaUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+        helper = KafkaUtils._get_helper(ssc._sc)
+
+        jfromOffsets = dict([(k._jTopicAndPartition(helper),
+                              v) for (k, v) in fromOffsets.items()])
+        if messageHandler is None:
+            ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+            func = funcWithoutMessageHandler
+            jstream = helper.createDirectStreamWithoutMessageHandler(
+                ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+        else:
+            ser = AutoBatchedSerializer(PickleSerializer())
+            func = funcWithMessageHandler
+            jstream = helper.createDirectStreamWithMessageHandler(
+                ssc._jssc, kafkaParams, set(topics), jfromOffsets)
 
         stream = DStream(jstream, ssc, ser).map(func)
         return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
@@ -189,28 +172,35 @@ def funcWithMessageHandler(m):
             m._set_value_decoder(valueDecoder)
             return messageHandler(m)
 
+        helper = KafkaUtils._get_helper(sc)
+
+        joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
+        jleaders = dict([(k._jTopicAndPartition(helper),
+                          v._jBroker(helper)) for (k, v) in leaders.items()])
+        if messageHandler is None:
+            jrdd = helper.createRDDWithoutMessageHandler(
+                sc._jsc, kafkaParams, joffsetRanges, jleaders)
+            ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+            rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
+        else:
+            jrdd = helper.createRDDWithMessageHandler(
+                sc._jsc, kafkaParams, joffsetRanges, jleaders)
+            rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
+
+        return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
+
+    @staticmethod
+    def _get_helper(sc):
         try:
+            # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
             helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
                 .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
-            jleaders = dict([(k._jTopicAndPartition(helper),
-                              v._jBroker(helper)) for (k, v) in leaders.items()])
-            if messageHandler is None:
-                jrdd = helper.createRDDWithoutMessageHandler(
-                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
-                ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-                rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
-            else:
-                jrdd = helper.createRDDWithMessageHandler(
-                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
-                rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
+            return helperClass.newInstance()
         except Py4JJavaError as e:
+            # TODO: use --jar once it also work on driver
             if 'ClassNotFoundException' in str(e.java_exception):
                 KafkaUtils._printErrorMsg(sc)
-            raise e
-
-        return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
+            raise
 
     @staticmethod
     def _printErrorMsg(sc):
@@ -333,16 +323,8 @@ def offsetRanges(self):
         Get the OffsetRange of specific KafkaRDD.
         :return: A list of OffsetRange
         """
-        try:
-            helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                KafkaUtils._printErrorMsg(self.ctx)
-            raise e
-
+        helper = KafkaUtils._get_helper(self.ctx)
+        joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
         ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset())
                   for o in joffsetRanges]
         return ranges
2 changes: 1 addition & 1 deletion python/pyspark/streaming/kinesis.py
@@ -83,7 +83,7 @@ def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName,
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 KinesisUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+            raise
         stream = DStream(jstream, ssc, NoOpSerializer())
         return stream.map(lambda v: decoder(v))

2 changes: 1 addition & 1 deletion python/pyspark/streaming/mqtt.py
@@ -48,7 +48,7 @@ def createStream(ssc, brokerUrl, topic,
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 MQTTUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+            raise
 
         return DStream(jstream, ssc, UTF8Deserializer())
