Skip to content

Commit

Permalink
[SPARK-24063][SS] Add maximum epoch queue threshold for ContinuousExe…
Browse files Browse the repository at this point in the history
…cution

## What changes were proposed in this pull request?

Continuous processing is waiting on epochs which are not yet complete (for example one partition is not making progress) and stores pending items in queues. These queues are unbounded and can consume up all the memory easily. In this PR I've added `spark.sql.streaming.continuous.epochBacklogQueueSize` configuration possibility to make them bounded. If the related threshold reached then the query will stop with `IllegalStateException`.

## How was this patch tested?

Existing + additional unit tests.

Closes apache#23156 from gaborgsomogyi/SPARK-24063.

Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com>
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
  • Loading branch information
gaborgsomogyi authored and mccheah committed May 15, 2019
1 parent 982df04 commit 38556e7
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 3 deletions.
Expand Up @@ -1441,6 +1441,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.epochBacklogQueueSize")
.doc("The max number of entries to be stored in queue to wait for late epochs. " +
"If this parameter is exceeded by the size of the queue, stream will stop with an error.")
.intConf
.createWithDefault(10000)

val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.executorQueueSize")
.internal()
Expand Down Expand Up @@ -2073,6 +2080,9 @@ class SQLConf extends Serializable with Logging {

def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)

def continuousStreamingEpochBacklogQueueSize: Int =
getConf(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE)

def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE)

def continuousStreamingExecutorPollIntervalMs: Long =
Expand Down
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.continuous

import java.util.UUID
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
import java.util.function.UnaryOperator

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -57,6 +58,9 @@ class ContinuousExecution(
// For use only in test harnesses.
private[sql] var currentEpochCoordinatorId: String = _

// Throwable that caused the execution to fail
private val failure: AtomicReference[Throwable] = new AtomicReference[Throwable](null)

override val logicalPlan: LogicalPlan = {
val v2ToRelationMap = MutableMap[StreamingRelationV2, StreamingDataSourceV2Relation]()
var nextSourceId = 0
Expand Down Expand Up @@ -253,6 +257,11 @@ class ContinuousExecution(
lastExecution.toRdd
}
}

val f = failure.get()
if (f != null) {
throw f
}
} catch {
case t: Throwable if StreamExecution.isInterruptionException(t, sparkSession.sparkContext) &&
state.get() == RECONFIGURING =>
Expand Down Expand Up @@ -381,6 +390,35 @@ class ContinuousExecution(
}
}

/**
* Stores error and stops the query execution thread to terminate the query in new thread.
*/
def stopInNewThread(error: Throwable): Unit = {
if (failure.compareAndSet(null, error)) {
logError(s"Query $prettyIdString received exception $error")
stopInNewThread()
}
}

/**
* Stops the query execution thread to terminate the query in new thread.
*/
private def stopInNewThread(): Unit = {
new Thread("stop-continuous-execution") {
setDaemon(true)

override def run(): Unit = {
try {
ContinuousExecution.this.stop()
} catch {
case e: Throwable =>
logError(e.getMessage, e)
throw e
}
}
}.start()
}

/**
* Stops the query execution thread to terminate the query.
*/
Expand Down
Expand Up @@ -123,6 +123,9 @@ private[continuous] class EpochCoordinator(
override val rpcEnv: RpcEnv)
extends ThreadSafeRpcEndpoint with Logging {

private val epochBacklogQueueSize =
session.sqlContext.conf.continuousStreamingEpochBacklogQueueSize

private var queryWritesStopped: Boolean = false

private var numReaderPartitions: Int = _
Expand Down Expand Up @@ -212,6 +215,7 @@ private[continuous] class EpochCoordinator(
if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
partitionCommits.put((epoch, partitionId), message)
resolveCommitsAtEpoch(epoch)
checkProcessingQueueBoundaries()
}

case ReportPartitionOffset(partitionId, epoch, offset) =>
Expand All @@ -223,6 +227,22 @@ private[continuous] class EpochCoordinator(
query.addOffset(epoch, stream, thisEpochOffsets.toSeq)
resolveCommitsAtEpoch(epoch)
}
checkProcessingQueueBoundaries()
}

private def checkProcessingQueueBoundaries() = {
if (partitionOffsets.size > epochBacklogQueueSize) {
query.stopInNewThread(new IllegalStateException("Size of the partition offset queue has " +
"exceeded its maximum"))
}
if (partitionCommits.size > epochBacklogQueueSize) {
query.stopInNewThread(new IllegalStateException("Size of the partition commit queue has " +
"exceeded its maximum"))
}
if (epochsWaitingToBeCommitted.size > epochBacklogQueueSize) {
query.stopInNewThread(new IllegalStateException("Size of the epoch queue has " +
"exceeded its maximum"))
}
}

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
Expand Down
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.test.TestSparkSession

Expand Down Expand Up @@ -343,3 +344,33 @@ class ContinuousMetaSuite extends ContinuousSuiteBase {
}
}
}

class ContinuousEpochBacklogSuite extends ContinuousSuiteBase {
import testImplicits._

override protected def createSparkSession = new TestSparkSession(
new SparkContext(
"local[1]",
"continuous-stream-test-sql-context",
sparkConf.set("spark.sql.testkey", "true")))

// This test forces the backlog to overflow by not standing up enough executors for the query
// to make progress.
test("epoch backlog overflow") {
withSQLConf((CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE.key, "10")) {
val df = spark.readStream
.format("rate")
.option("numPartitions", "2")
.option("rowsPerSecond", "500")
.load()
.select('value)

testStream(df, useV2Sink = true)(
StartStream(Trigger.Continuous(1)),
ExpectFailure[IllegalStateException] { e =>
e.getMessage.contains("queue has exceeded its maximum")
}
)
}
}
}
Expand Up @@ -17,16 +17,17 @@

package org.apache.spark.sql.streaming.continuous

import org.mockito.{ArgumentCaptor, InOrder}
import org.mockito.ArgumentMatchers.{any, eq => eqTo}
import org.mockito.InOrder
import org.mockito.Mockito.{inOrder, never, verify}
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfterEach
import org.scalatest.mockito.MockitoSugar

import org.apache.spark._
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.LocalSparkSession
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset}
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
Expand All @@ -43,14 +44,19 @@ class EpochCoordinatorSuite
private var writeSupport: StreamingWrite = _
private var query: ContinuousExecution = _
private var orderVerifier: InOrder = _
private val epochBacklogQueueSize = 10

override def beforeEach(): Unit = {
val stream = mock[ContinuousStream]
writeSupport = mock[StreamingWrite]
query = mock[ContinuousExecution]
orderVerifier = inOrder(writeSupport, query)

spark = new TestSparkSession()
spark = new TestSparkSession(
new SparkContext(
"local[2]", "test-sql-context",
new SparkConf().set("spark.sql.testkey", "true")
.set(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE, epochBacklogQueueSize)))

epochCoordinator
= EpochCoordinatorRef.create(writeSupport, stream, query, "test", 1, spark, SparkEnv.get)
Expand Down Expand Up @@ -186,6 +192,66 @@ class EpochCoordinatorSuite
verifyCommitsInOrderOf(List(1, 2, 3, 4, 5))
}

test("several epochs, max epoch backlog reached by partitionOffsets") {
setWriterPartitions(1)
setReaderPartitions(1)

reportPartitionOffset(0, 1)
// Commit messages not arriving
for (i <- 2 to epochBacklogQueueSize + 1) {
reportPartitionOffset(0, i)
}

makeSynchronousCall()

for (i <- 1 to epochBacklogQueueSize + 1) {
verifyNoCommitFor(i)
}
verifyStoppedWithException("Size of the partition offset queue has exceeded its maximum")
}

test("several epochs, max epoch backlog reached by partitionCommits") {
setWriterPartitions(1)
setReaderPartitions(1)

commitPartitionEpoch(0, 1)
// Offset messages not arriving
for (i <- 2 to epochBacklogQueueSize + 1) {
commitPartitionEpoch(0, i)
}

makeSynchronousCall()

for (i <- 1 to epochBacklogQueueSize + 1) {
verifyNoCommitFor(i)
}
verifyStoppedWithException("Size of the partition commit queue has exceeded its maximum")
}

test("several epochs, max epoch backlog reached by epochsWaitingToBeCommitted") {
setWriterPartitions(2)
setReaderPartitions(2)

commitPartitionEpoch(0, 1)
reportPartitionOffset(0, 1)

// For partition 2 epoch 1 messages never arriving
// +2 because the first epoch not yet arrived
for (i <- 2 to epochBacklogQueueSize + 2) {
commitPartitionEpoch(0, i)
reportPartitionOffset(0, i)
commitPartitionEpoch(1, i)
reportPartitionOffset(1, i)
}

makeSynchronousCall()

for (i <- 1 to epochBacklogQueueSize + 2) {
verifyNoCommitFor(i)
}
verifyStoppedWithException("Size of the epoch queue has exceeded its maximum")
}

private def setWriterPartitions(numPartitions: Int): Unit = {
epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions))
}
Expand Down Expand Up @@ -221,4 +287,13 @@ class EpochCoordinatorSuite
private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = {
epochs.foreach(verifyCommit)
}

private def verifyStoppedWithException(msg: String): Unit = {
val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]);
verify(query, atLeastOnce()).stopInNewThread(exceptionCaptor.capture())

import scala.collection.JavaConverters._
val throwable = exceptionCaptor.getAllValues.asScala.find(_.getMessage === msg)
assert(throwable != null, "Stream stopped with an exception but expected message is missing")
}
}

0 comments on commit 38556e7

Please sign in to comment.