Commit a5a8f9f
added Python test
prabeesh committed Jul 11, 2015
1 parent 9767d82 commit a5a8f9f
Showing 7 changed files with 216 additions and 114 deletions.
3 changes: 2 additions & 1 deletion dev/run-tests.py
@@ -294,7 +294,8 @@ def build_spark_sbt(hadoop_version):
sbt_goals = ["package",
"assembly/assembly",
"streaming-kafka-assembly/assembly",
"streaming-flume-assembly/assembly"]
"streaming-flume-assembly/assembly",
"streaming-mqtt-assembly/assembly"]
profiles_and_goals = build_profiles + sbt_goals

print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ",
3 changes: 2 additions & 1 deletion dev/sparktestsupport/modules.py
@@ -170,6 +170,7 @@ def contains_file(self, filename):
dependencies=[streaming],
source_file_regexes=[
"external/mqtt",
"external/mqtt-assembly",
],
sbt_test_goals=[
"streaming-mqtt/test",
@@ -290,7 +291,7 @@ def contains_file(self, filename):

pyspark_streaming = Module(
name="pyspark-streaming",
dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly],
dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly, streaming_mqtt],
source_file_regexes=[
"python/pyspark/streaming"
],
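
With streaming_mqtt added to pyspark_streaming's dependencies and "external/mqtt-assembly" added to the MQTT module's source_file_regexes, changes under those paths now pull the MQTT build and test goals into a test run. Below is a rough Python sketch of the matching mechanism implied by contains_file above; the helper name and the prefix-regex semantics are assumptions for illustration, not a copy of modules.py.

import re

# Hypothetical helper: a module "owns" a changed file when any of its
# source_file_regexes matches the start of the file's path.
def contains_file(source_file_regexes, filename):
    return any(re.match(regex, filename) for regex in source_file_regexes)

streaming_mqtt_regexes = ["external/mqtt", "external/mqtt-assembly"]

# A change under the new assembly module now maps back to the MQTT tests.
print(contains_file(streaming_mqtt_regexes, "external/mqtt-assembly/pom.xml"))  # True
print(contains_file(streaming_mqtt_regexes, "external/kafka/pom.xml"))          # False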
2 changes: 1 addition & 1 deletion docs/streaming-programming-guide.md
@@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea
{:.no_toc}

<span class="badge" style="background-color: grey">Python API</span> As of Spark {{site.SPARK_VERSION_SHORT}},
out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future.
out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources to the Python API in the future.
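
Since the guide now lists MQTT among the Python-accessible sources, here is a minimal PySpark sketch of consuming an MQTT topic. The pyspark.streaming.mqtt.MQTTUtils import and the createStream(ssc, brokerUrl, topic) signature are assumptions based on the Python API this commit is testing, and the broker URL and topic are placeholders.

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.mqtt import MQTTUtils  # assumed Python API, see note above

sc = SparkContext("local[2]", "MQTTWordCountSketch")
ssc = StreamingContext(sc, 1)

# Placeholder broker URL and topic; point these at a reachable MQTT broker.
lines = MQTTUtils.createStream(ssc, "tcp://localhost:1883", "foo/bar")
counts = lines.flatMap(lambda line: line.split(" ")) \
    .map(lambda word: (word, 1)) \
    .reduceByKey(lambda a, b: a + b)
counts.pprint()

ssc.start()
ssc.awaitTermination()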

This category of sources requires interfacing with external non-Spark libraries, some of them with
complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts
1 change: 0 additions & 1 deletion external/mqtt/pom.xml
@@ -72,7 +72,6 @@
<groupId>org.apache.activemq</groupId>
<artifactId>activemq-core</artifactId>
<version>5.7.0</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
124 additions: MQTTTestUtils.scala (new file)
@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.streaming.mqtt

import java.net.{ServerSocket, URI}
import java.util.concurrent.{TimeUnit, CountDownLatch}

import scala.language.postfixOps

import org.apache.activemq.broker.{BrokerService, TransportConnector}
import org.apache.commons.lang3.RandomUtils
import org.eclipse.paho.client.mqttv3._
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence

import org.apache.spark.streaming.{StreamingContext, Milliseconds}
import org.apache.spark.streaming.scheduler.StreamingListener
import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SparkConf}

/**
 * Shared code for Scala and Python unit tests.
 */
private class MQTTTestUtils extends Logging {

  private val persistenceDir = Utils.createTempDir()
  private val brokerHost = "localhost"
  private var brokerPort = findFreePort()

  private var broker: BrokerService = _
  private var connector: TransportConnector = _

  def brokerUri: String = {
    s"$brokerHost:$brokerPort"
  }

  def setup(): Unit = {
    broker = new BrokerService()
    broker.setDataDirectoryFile(Utils.createTempDir())
    connector = new TransportConnector()
    connector.setName("mqtt")
    connector.setUri(new URI("mqtt://" + brokerUri))
    broker.addConnector(connector)
    broker.start()
  }

  def teardown(): Unit = {
    if (broker != null) {
      broker.stop()
      broker = null
    }
    if (connector != null) {
      connector.stop()
      connector = null
    }
  }

  private def findFreePort(): Int = {
    val candidatePort = RandomUtils.nextInt(1024, 65536)
    Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
      val socket = new ServerSocket(trialPort)
      socket.close()
      (null, trialPort)
    }, new SparkConf())._2
  }

  def publishData(topic: String, data: String): Unit = {
    var client: MqttClient = null
    try {
      val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
      client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence)
      client.connect()
      if (client.isConnected) {
        val msgTopic = client.getTopic(topic)
        val message = new MqttMessage(data.getBytes("utf-8"))
        message.setQos(1)
        message.setRetained(true)

        for (i <- 0 to 10) {
          try {
            msgTopic.publish(message)
          } catch {
            case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
              // wait for Spark streaming to consume something from the message queue
              Thread.sleep(50)
          }
        }
      }
    } finally {
      // guard against a null client in case construction or connect() failed
      if (client != null) {
        client.disconnect()
        client.close()
        client = null
      }
    }
  }

  /**
   * Block until at least one receiver has started or timeout occurs.
   */
  def waitForReceiverToStart(ssc: StreamingContext): Unit = {
    val latch = new CountDownLatch(1)
    ssc.addStreamingListener(new StreamingListener {
      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
        latch.countDown()
      }
    })

    assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
  }
}
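
The class comment above says this utility is shared with the Python unit tests, which are among the files not expanded in this view. A heavily hedged Python sketch of how such a test might drive it through the Py4J gateway follows; the gateway path, the brokerUri()/publishData() call forms, the MQTTUtils.createStream signature, and the polling loop are all assumptions for illustration, not the commit's actual test code.

import time

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.mqtt import MQTTUtils  # assumed Python API under test

sc = SparkContext("local[2]", "MQTTPythonTestSketch")
ssc = StreamingContext(sc, 0.5)

# Assumption: with the streaming-mqtt assembly on the classpath, the JVM-side
# helper is reachable through SparkContext's Py4J gateway.
mqtt_test_utils = sc._jvm.org.apache.spark.streaming.mqtt.MQTTTestUtils()
mqtt_test_utils.setup()

topic = "topic"
broker_uri = "tcp://" + mqtt_test_utils.brokerUri()

received = []
stream = MQTTUtils.createStream(ssc, broker_uri, topic)
stream.foreachRDD(lambda rdd: received.extend(rdd.collect()))
ssc.start()

# publishData() retains the message (setRetained(true) above), so a receiver
# that subscribes slightly later should still observe it.
mqtt_test_utils.publishData(topic, "MQTT demo for spark streaming")

# Poll for up to ~10 seconds, mirroring the eventually block in the Scala suite.
deadline = time.time() + 10
while not received and time.time() < deadline:
    time.sleep(0.1)
assert received and received[0] == "MQTT demo for spark streaming"

ssc.stop(stopSparkContext=True)
mqtt_test_utils.teardown()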
MQTTStreamSuite.scala
@@ -17,61 +17,45 @@

package org.apache.spark.streaming.mqtt

import java.net.{URI, ServerSocket}
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit

import scala.concurrent.duration._
import scala.language.postfixOps

import org.apache.activemq.broker.{TransportConnector, BrokerService}
import org.apache.commons.lang3.RandomUtils
import org.eclipse.paho.client.mqttv3._
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence

import org.scalatest.BeforeAndAfter
import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Eventually

import org.apache.spark.streaming.{Milliseconds, StreamingContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.scheduler.StreamingListener
import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils

class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

private val batchDuration = Milliseconds(500)
private val master = "local[2]"
private val framework = this.getClass.getSimpleName
private val freePort = findFreePort()
private val brokerUri = "//localhost:" + freePort
private val topic = "def"
private val persistenceDir = Utils.createTempDir()
class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {

private val topic = "topic"
private var ssc: StreamingContext = _
private var broker: BrokerService = _
private var connector: TransportConnector = _
private var MQTTTestUtils: MQTTTestUtils = _

before {
ssc = new StreamingContext(master, framework, batchDuration)
setupMQTT()
override def beforeAll(): Unit = {
MQTTTestUtils = new MQTTTestUtils
MQTTTestUtils.setup()
}

after {
override def afterAll(): Unit = {
if (ssc != null) {
ssc.stop()
ssc = null
}
Utils.deleteRecursively(persistenceDir)
tearDownMQTT()

if (MQTTTestUtils != null) {
MQTTTestUtils.teardown()
MQTTTestUtils = null
}
}

test("mqtt input stream") {
val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
ssc = new StreamingContext(sparkConf, Milliseconds(500))
val sendMessage = "MQTT demo for spark streaming"
val receiveStream =
MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY)
MQTTUtils.createStream(ssc, "tcp://" + MQTTTestUtils.brokerUri, topic, StorageLevel.MEMORY_ONLY)
@volatile var receiveMessage: List[String] = List()
receiveStream.foreachRDD { rdd =>
if (rdd.collect.length > 0) {
@@ -83,85 +67,13 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter

// wait for the receiver to start before publishing data, or we risk failing
// the test nondeterministically. See SPARK-4631
waitForReceiverToStart()
MQTTTestUtils.waitForReceiverToStart(ssc)

MQTTTestUtils.publishData(topic, sendMessage)

publishData(sendMessage)
eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
assert(sendMessage.equals(receiveMessage(0)))
}
ssc.stop()
}

private def setupMQTT() {
broker = new BrokerService()
broker.setDataDirectoryFile(Utils.createTempDir())
connector = new TransportConnector()
connector.setName("mqtt")
connector.setUri(new URI("mqtt:" + brokerUri))
broker.addConnector(connector)
broker.start()
}

private def tearDownMQTT() {
if (broker != null) {
broker.stop()
broker = null
}
if (connector != null) {
connector.stop()
connector = null
}
}

private def findFreePort(): Int = {
val candidatePort = RandomUtils.nextInt(1024, 65536)
Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
val socket = new ServerSocket(trialPort)
socket.close()
(null, trialPort)
}, new SparkConf())._2
}

def publishData(data: String): Unit = {
var client: MqttClient = null
try {
val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence)
client.connect()
if (client.isConnected) {
val msgTopic = client.getTopic(topic)
val message = new MqttMessage(data.getBytes("utf-8"))
message.setQos(1)
message.setRetained(true)

for (i <- 0 to 10) {
try {
msgTopic.publish(message)
} catch {
case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
// wait for Spark streaming to consume something from the message queue
Thread.sleep(50)
}
}
}
} finally {
client.disconnect()
client.close()
client = null
}
}

/**
* Block until at least one receiver has started or timeout occurs.
*/
private def waitForReceiverToStart() = {
val latch = new CountDownLatch(1)
ssc.addStreamingListener(new StreamingListener {
override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
latch.countDown()
}
})

assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
}
}
(The remaining changed files in this commit, including the Python test itself, are not expanded in this view.)
