diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index 8df542b367d27..356ae340387df 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.receiver +import java.util.concurrent.atomic.AtomicInteger + import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter} import org.apache.spark.{Logging, SparkConf} @@ -34,12 +36,28 @@ import org.apache.spark.{Logging, SparkConf} */ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { - private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0) - private lazy val rateLimiter = GuavaRateLimiter.create(desiredRate) + // treated as an upper limit + private val maxRateLimit = conf.getInt("spark.streaming.receiver.maxRate", 0) + private[receiver] var currentRateLimit = new AtomicInteger(maxRateLimit) + private lazy val rateLimiter = GuavaRateLimiter.create(currentRateLimit.get()) def waitToPush() { - if (desiredRate > 0) { + if (currentRateLimit.get() > 0) { rateLimiter.acquire() } } + + private[receiver] def updateRate(newRate: Int): Unit = + if (newRate > 0) { + try { + if (maxRateLimit > 0) { + currentRateLimit.set(newRate.min(maxRateLimit)) + } + else { + currentRateLimit.set(newRate) + } + } finally { + rateLimiter.setRate(currentRateLimit.get()) + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala index 7bf3c33319491..1eb55affaa9d0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala @@ -23,4 +23,5 @@ import org.apache.spark.streaming.Time private[streaming] sealed trait ReceiverMessage extends Serializable private[streaming] object StopReceiver extends ReceiverMessage private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage - +private[streaming] case class UpdateRateLimit(elementsPerSecond: Long) + extends ReceiverMessage diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 6078cdf8f8790..6e819460b1b23 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -77,6 +77,8 @@ private[streaming] class ReceiverSupervisorImpl( case CleanupOldBlocks(threshTime) => logDebug("Received delete old batch signal") cleanupOldBlocks(threshTime) + case UpdateRateLimit(eps) => + blockGenerator.updateRate(eps.toInt) } }) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 644e581cd8279..604d1a0dae289 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -27,7 +27,7 @@ import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, - StopReceiver} + StopReceiver, UpdateRateLimit} import org.apache.spark.util.SerializableConfiguration /** @@ -180,6 +180,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logError(s"Deregistered receiver for stream $streamId: $messageWithError") } + /** Update a receiver's maximum rate from an estimator's update */ + def sendRateUpdate(streamUID: Int, newRate: Long): Unit = { + for (info <- receiverInfo.get(streamUID); eP <- Option(info.endpoint)) + eP.send(UpdateRateLimit(newRate)) + } + /** Add new blocks for the given stream */ private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { receivedBlockTracker.addBlock(receivedBlockInfo) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala new file mode 100644 index 0000000000000..904c7773c5f2c --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.receiver + +import org.apache.spark.SparkConf +import org.apache.spark.SparkFunSuite + +/** Testsuite for testing the network receiver behavior */ +class RateLimiterSuite extends SparkFunSuite { + + test("rate limiter initializes even without a maxRate set") { + val conf = new SparkConf() + val rateLimiter = new RateLimiter(conf){} + rateLimiter.updateRate(105) + assert(rateLimiter.currentRateLimit.get == 105) + } + + test("rate limiter updates when below maxRate") { + val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "110") + val rateLimiter = new RateLimiter(conf){} + rateLimiter.updateRate(105) + assert(rateLimiter.currentRateLimit.get == 105) + } + + test("rate limiter stays below maxRate despite large updates") { + val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "100") + val rateLimiter = new RateLimiter(conf){} + rateLimiter.updateRate(105) + assert(rateLimiter.currentRateLimit.get == 100) + } + +}