address the comments:
	keep the whole MQTTTestUtils in test and then link to test jar from python
	fix issue under Maven build
	return JavaDStream[String] directly.
Prabeesh K committed Jul 23, 2015
1 parent 97244ec commit 87fc677
Showing 6 changed files with 55 additions and 24 deletions.
3 changes: 2 additions & 1 deletion dev/run-tests.py
@@ -295,7 +295,8 @@ def build_spark_sbt(hadoop_version):
                  "assembly/assembly",
                  "streaming-kafka-assembly/assembly",
                  "streaming-flume-assembly/assembly",
-                 "streaming-mqtt-assembly/assembly"]
+                 "streaming-mqtt-assembly/assembly",
+                 "streaming-mqtt/test:assembly"]
     profiles_and_goals = build_profiles + sbt_goals
 
     print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ",
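The extra `streaming-mqtt/test:assembly` goal builds the MQTT test jar alongside the regular assemblies, so the Python tests can link against `MQTTTestUtils`. As a rough sketch of how a goal list like this reaches sbt (the helper name here is illustrative, not the actual run-tests.py function):

```python
# Illustrative sketch only: run-tests.py concatenates Maven-style profile
# flags and sbt goals into one argument list and runs build/sbt once.
import subprocess

def run_sbt(profiles_and_goals):
    # e.g. ["-Pyarn", "assembly/assembly", "streaming-mqtt/test:assembly"]
    subprocess.check_call(["build/sbt"] + profiles_and_goals)
```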
1 change: 1 addition & 0 deletions external/mqtt-assembly/pom.xml
@@ -58,6 +58,7 @@
         <artifactId>maven-shade-plugin</artifactId>
         <configuration>
           <shadedArtifactAttached>false</shadedArtifactAttached>
+          <outputFile>${project.build.directory}/scala-${scala.binary.version}/spark-streaming-mqtt-assembly-${project.version}.jar</outputFile>
           <artifactSet>
             <includes>
               <include>*:*</include>
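Pinning `<outputFile>` makes the Maven shade plugin write the assembly jar to the same `target/scala-<version>` path the sbt build uses; this is the "fix issue under Maven build" item, since the Python test runner locates the jar by glob. A sketch of that lookup, mirroring `search_mqtt_assembly_jar` further down (the directory literal is an assumption for illustration):

```python
# Sketch: the tests find the shaded jar by glob, so the Maven and sbt builds
# must agree on where the jar lands.
import glob
import os

jars = glob.glob(os.path.join(
    "external/mqtt-assembly", "target/scala-*/spark-streaming-mqtt-assembly-*.jar"))
```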
1 change: 1 addition & 0 deletions external/mqtt/pom.xml
@@ -72,6 +72,7 @@
       <groupId>org.apache.activemq</groupId>
       <artifactId>activemq-core</artifactId>
       <version>5.7.0</version>
+      <scope>test</scope>
     </dependency>
   </dependencies>
   <build>
9 changes: 2 additions & 7 deletions external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
@@ -87,12 +87,7 @@ private class MQTTUtilsPythonHelper {
       brokerUrl: String,
       topic: String,
       storageLevel: StorageLevel
-    ): JavaDStream[Array[Byte]] = {
-    val dstream = MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel)
-    dstream.map(new Function[String, Array[Byte]] {
-      override def call(data: String): Array[Byte] = {
-        data.getBytes("UTF-8")
-      }
-    })
+    ): JavaDStream[String] = {
+    MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel)
   }
 }
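Returning `JavaDStream[String]` directly drops the Scala-side UTF-8 byte conversion; the Python wrapper can deserialize the JVM strings itself. A minimal sketch of that consumer side, following pyspark's usual wrapper conventions (illustrative, not the exact `mqtt.py` source; the helper method name is assumed):

```python
# Sketch: wrap the JVM string DStream with a UTF-8 deserializer in Python
# instead of mapping String -> Array[Byte] in Scala.
from pyspark.serializers import UTF8Deserializer
from pyspark.streaming import DStream

def create_stream(ssc, helper, broker_url, topic, jlevel):
    jstream = helper.createStream(ssc._jssc, broker_url, topic, jlevel)
    return DStream(jstream, ssc, UTF8Deserializer())
```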
6 changes: 3 additions & 3 deletions external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.streaming.mqtt
 
 import java.net.{ServerSocket, URI}
-import java.util.concurrent.{TimeUnit, CountDownLatch}
+import java.util.concurrent.{CountDownLatch, TimeUnit}
 
 import scala.language.postfixOps
 
@@ -27,7 +27,7 @@ import org.apache.commons.lang3.RandomUtils
 import org.eclipse.paho.client.mqttv3._
 import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
 
-import org.apache.spark.streaming.{StreamingContext, Milliseconds}
+import org.apache.spark.streaming.StreamingContext
 import org.apache.spark.streaming.scheduler.StreamingListener
 import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
 import org.apache.spark.util.Utils
@@ -40,7 +40,7 @@ private class MQTTTestUtils extends Logging {
 
   private val persistenceDir = Utils.createTempDir()
   private val brokerHost = "localhost"
-  private var brokerPort = findFreePort()
+  private val brokerPort = findFreePort()
 
   private var broker: BrokerService = _
   private var connector: TransportConnector = _
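The `var` to `val` change works because the broker port is chosen once, eagerly, by `findFreePort()` (its body sits outside this hunk). The usual idiom, sketched here in Python under the assumption that the Scala helper does something equivalent, is to bind port 0 and let the OS assign a free one:

```python
# Sketch of the bind-to-port-0 idiom for reserving a free ephemeral port.
import socket

def find_free_port():
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("localhost", 0))
        return s.getsockname()[1]
    finally:
        s.close()
```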
59 changes: 46 additions & 13 deletions python/pyspark/streaming/tests.py
@@ -850,28 +850,43 @@ def tearDown(self):
     def _randomTopic(self):
         return "topic-%d" % random.randint(0, 10000)
 
-    def _validateStreamResult(self, sendData, dstream):
+    def _startContext(self, topic):
+        # Start the StreamingContext and also collect the result
+        stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic)
         result = []
 
-        def get_output(_, rdd):
+        def getOutput(_, rdd):
             for data in rdd.collect():
                 result.append(data)
 
-        dstream.foreachRDD(get_output)
-        receiveData = ' '.join(result[0])
+        stream.foreachRDD(getOutput)
+        self.ssc.start()
+        return result
+
+    def _publishData(self, topic, data):
+        start_time = time.time()
+        while True:
+            try:
+                self._MQTTTestUtils.publishData(topic, data)
+                break
+            except:
+                if time.time() - start_time < self.timeout:
+                    time.sleep(0.01)
+                else:
+                    raise
+
+    def _validateStreamResult(self, sendData, result):
+        receiveData = ''.join(result[0])
         self.assertEqual(sendData, receiveData)
 
     def test_mqtt_stream(self):
         """Test the Python MQTT stream API."""
-        topic = self._randomTopic()
         sendData = "MQTT demo for spark streaming"
-        ssc = self.ssc
-
-        self._MQTTTestUtils.waitForReceiverToStart(ssc)
-        self._MQTTTestUtils.publishData(topic, sendData)
-
-        stream = MQTTUtils.createStream(ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic)
-        self._validateStreamResult(sendData, stream)
+        topic = self._randomTopic()
+        result = self._startContext(topic)
+        self._publishData(topic, sendData)
+        self.wait_for(result, len(sendData))
+        self._validateStreamResult(sendData, result)


 def search_kafka_assembly_jar():
@@ -928,11 +943,29 @@ def search_mqtt_assembly_jar():
         return jars[0]
 
 
+def search_mqtt_test_jar():
+    SPARK_HOME = os.environ["SPARK_HOME"]
+    mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt")
+    jars = glob.glob(
+        os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar"))
+    if not jars:
+        raise Exception(
+            ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) +
+            "You need to build Spark with "
+            "'build/sbt assembly/assembly streaming-mqtt/test:assembly'")
+    elif len(jars) > 1:
+        raise Exception(("Found multiple Spark Streaming MQTT test JARs in %s; please "
+                         "remove all but one") % mqtt_test_dir)
+    else:
+        return jars[0]
+
 if __name__ == "__main__":
     kafka_assembly_jar = search_kafka_assembly_jar()
     flume_assembly_jar = search_flume_assembly_jar()
     mqtt_assembly_jar = search_mqtt_assembly_jar()
-    jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar)
+    mqtt_test_jar = search_mqtt_test_jar()
+    jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar,
+                            mqtt_assembly_jar, mqtt_test_jar)
 
     os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars
     unittest.main()
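The reworked test polls with `self.wait_for(result, len(sendData))` instead of blocking on `waitForReceiverToStart`. `wait_for` is defined on the shared streaming test base class elsewhere in tests.py; a sketch of the polling contract it is assumed to implement:

```python
# Sketch of the assumed wait_for contract: poll until the collected results
# reach the expected count or the test timeout elapses.
import time

def wait_for(result, n, timeout):
    deadline = time.time() + timeout
    while len(result) < n and time.time() < deadline:
        time.sleep(0.01)
```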
