diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index f9fbe2ff858ce..9abbf4a7a3971 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.util import java.util.concurrent._ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} @@ -156,4 +157,21 @@ private[spark] object ThreadUtils { result } } + + /** + * Construct a new Scala ForkJoinPool with a specified max parallelism and name prefix. + */ + def newForkJoinPool(prefix: String, maxThreadNumber: Int): SForkJoinPool = { + // Custom factory to set thread names + val factory = new SForkJoinPool.ForkJoinWorkerThreadFactory { + override def newThread(pool: SForkJoinPool) = + new SForkJoinWorkerThread(pool) { + setName(prefix + "-" + super.getName) + } + } + new SForkJoinPool(maxThreadNumber, factory, + null, // handler + false // asyncMode + ) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 314263f26ee60..a3b7e783acd8d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -18,11 +18,11 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import java.util.{Iterator => JIterator} -import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor} +import java.util.concurrent.RejectedExecutionException import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import scala.collection.parallel.ThreadPoolTaskSupport +import scala.collection.parallel.ExecutionContextTaskSupport import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.postfixOps @@ -62,8 +62,8 @@ private[streaming] class FileBasedWriteAheadLog( private val threadpoolName = { "WriteAheadLogManager" + callerName.map(c => s" for $c").getOrElse("") } - private val threadpool = ThreadUtils.newDaemonCachedThreadPool(threadpoolName, 20) - private val executionContext = ExecutionContext.fromExecutorService(threadpool) + private val forkJoinPool = ThreadUtils.newForkJoinPool(threadpoolName, 20) + private val executionContext = ExecutionContext.fromExecutorService(forkJoinPool) override protected def logName = { getClass.getName.stripSuffix("$") + @@ -144,7 +144,7 @@ private[streaming] class FileBasedWriteAheadLog( } else { // For performance gains, it makes sense to parallelize the recovery if // closeFileAfterWrite = true - seqToParIterator(threadpool, logFilesToRead, readFile).asJava + seqToParIterator(executionContext, logFilesToRead, readFile).asJava } } @@ -283,16 +283,17 @@ private[streaming] object FileBasedWriteAheadLog { /** * This creates an iterator from a parallel collection, by keeping at most `n` objects in memory - * at any given time, where `n` is the size of the thread pool. This is crucial for use cases - * where we create `FileBasedWriteAheadLogReader`s during parallel recovery. We don't want to - * open up `k` streams altogether where `k` is the size of the Seq that we want to parallelize. + * at any given time, where `n` is at most the max of the size of the thread pool or 8. This is + * crucial for use cases where we create `FileBasedWriteAheadLogReader`s during parallel recovery. + * We don't want to open up `k` streams altogether where `k` is the size of the Seq that we want + * to parallelize. */ def seqToParIterator[I, O]( - tpool: ThreadPoolExecutor, + executionContext: ExecutionContext, source: Seq[I], handler: I => Iterator[O]): Iterator[O] = { - val taskSupport = new ThreadPoolTaskSupport(tpool) - val groupSize = tpool.getMaximumPoolSize.max(8) + val taskSupport = new ExecutionContextTaskSupport(executionContext) + val groupSize = taskSupport.parallelismLevel.max(8) source.grouped(groupSize).flatMap { group => val parallelCollection = group.par parallelCollection.tasksupport = taskSupport diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 7460e8629b696..8c980dee2cc06 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -228,7 +228,9 @@ class FileBasedWriteAheadLogSuite the list of files. */ val numThreads = 8 - val tpool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "wal-test-thread-pool") + val fpool = ThreadUtils.newForkJoinPool("wal-test-thread-pool", numThreads) + val executionContext = ExecutionContext.fromExecutorService(fpool) + class GetMaxCounter { private val value = new AtomicInteger() @volatile private var max: Int = 0 @@ -258,7 +260,8 @@ class FileBasedWriteAheadLogSuite val t = new Thread() { override def run() { // run the calculation on a separate thread so that we can release the latch - val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](tpool, testSeq, handle) + val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](executionContext, + testSeq, handle) collected = iterator.toSeq } } @@ -273,7 +276,7 @@ class FileBasedWriteAheadLogSuite // make sure we didn't open too many Iterators assert(counter.getMax() <= numThreads) } finally { - tpool.shutdownNow() + fpool.shutdownNow() } }