Skip to content

Commit

Permalink
[SPARK-13398][STREAMING] Move away from thread pool task support to f…
Browse files Browse the repository at this point in the history
…orkjoin

## What changes were proposed in this pull request?

Remove old deprecated ThreadPoolExecutor and replace with ExecutionContext using a ForkJoinPool. The downside of this is that scala's ForkJoinPool doesn't give us a way to specify the thread pool name (and is a wrapper of Java's in 2.12) except by providing a custom factory. Note that we can't use Java's ForkJoinPool directly in Scala 2.11 since it uses a ExecutionContext which reports system parallelism. One other implicit change that happens is the old ExecutionContext would have reported a different default parallelism since it used system parallelism rather than threadpool parallelism (this was likely not intended but also likely not a huge difference).

The previous version of this PR attempted to use an execution context constructed on the ThreadPool (but not the deprecated ThreadPoolExecutor class) so as to keep the ability to have human readable named threads but this reported system parallelism.

## How was this patch tested?

unit tests: streaming/testOnly org.apache.spark.streaming.util.*

Author: Holden Karau <holden@us.ibm.com>

Closes apache#11423 from holdenk/SPARK-13398-move-away-from-ThreadPoolTaskSupport-java-forkjoin.
  • Loading branch information
holdenk authored and roygao94 committed Mar 22, 2016
1 parent f7ddcb9 commit c8a76b7
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 14 deletions.
18 changes: 18 additions & 0 deletions core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.util
import java.util.concurrent._

import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
import scala.util.control.NonFatal

import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
Expand Down Expand Up @@ -156,4 +157,21 @@ private[spark] object ThreadUtils {
result
}
}

/**
* Construct a new Scala ForkJoinPool with a specified max parallelism and name prefix.
*/
def newForkJoinPool(prefix: String, maxThreadNumber: Int): SForkJoinPool = {
// Custom factory to set thread names
val factory = new SForkJoinPool.ForkJoinWorkerThreadFactory {
override def newThread(pool: SForkJoinPool) =
new SForkJoinWorkerThread(pool) {
setName(prefix + "-" + super.getName)
}
}
new SForkJoinPool(maxThreadNumber, factory,
null, // handler
false // asyncMode
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ package org.apache.spark.streaming.util

import java.nio.ByteBuffer
import java.util.{Iterator => JIterator}
import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor}
import java.util.concurrent.RejectedExecutionException

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.ThreadPoolTaskSupport
import scala.collection.parallel.ExecutionContextTaskSupport
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.language.postfixOps

Expand Down Expand Up @@ -62,8 +62,8 @@ private[streaming] class FileBasedWriteAheadLog(
private val threadpoolName = {
"WriteAheadLogManager" + callerName.map(c => s" for $c").getOrElse("")
}
private val threadpool = ThreadUtils.newDaemonCachedThreadPool(threadpoolName, 20)
private val executionContext = ExecutionContext.fromExecutorService(threadpool)
private val forkJoinPool = ThreadUtils.newForkJoinPool(threadpoolName, 20)
private val executionContext = ExecutionContext.fromExecutorService(forkJoinPool)

override protected def logName = {
getClass.getName.stripSuffix("$") +
Expand Down Expand Up @@ -144,7 +144,7 @@ private[streaming] class FileBasedWriteAheadLog(
} else {
// For performance gains, it makes sense to parallelize the recovery if
// closeFileAfterWrite = true
seqToParIterator(threadpool, logFilesToRead, readFile).asJava
seqToParIterator(executionContext, logFilesToRead, readFile).asJava
}
}

Expand Down Expand Up @@ -283,16 +283,17 @@ private[streaming] object FileBasedWriteAheadLog {

/**
* This creates an iterator from a parallel collection, by keeping at most `n` objects in memory
* at any given time, where `n` is the size of the thread pool. This is crucial for use cases
* where we create `FileBasedWriteAheadLogReader`s during parallel recovery. We don't want to
* open up `k` streams altogether where `k` is the size of the Seq that we want to parallelize.
* at any given time, where `n` is at most the max of the size of the thread pool or 8. This is
* crucial for use cases where we create `FileBasedWriteAheadLogReader`s during parallel recovery.
* We don't want to open up `k` streams altogether where `k` is the size of the Seq that we want
* to parallelize.
*/
def seqToParIterator[I, O](
tpool: ThreadPoolExecutor,
executionContext: ExecutionContext,
source: Seq[I],
handler: I => Iterator[O]): Iterator[O] = {
val taskSupport = new ThreadPoolTaskSupport(tpool)
val groupSize = tpool.getMaximumPoolSize.max(8)
val taskSupport = new ExecutionContextTaskSupport(executionContext)
val groupSize = taskSupport.parallelismLevel.max(8)
source.grouped(groupSize).flatMap { group =>
val parallelCollection = group.par
parallelCollection.tasksupport = taskSupport
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ class FileBasedWriteAheadLogSuite
the list of files.
*/
val numThreads = 8
val tpool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "wal-test-thread-pool")
val fpool = ThreadUtils.newForkJoinPool("wal-test-thread-pool", numThreads)
val executionContext = ExecutionContext.fromExecutorService(fpool)

class GetMaxCounter {
private val value = new AtomicInteger()
@volatile private var max: Int = 0
Expand Down Expand Up @@ -258,7 +260,8 @@ class FileBasedWriteAheadLogSuite
val t = new Thread() {
override def run() {
// run the calculation on a separate thread so that we can release the latch
val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](tpool, testSeq, handle)
val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](executionContext,
testSeq, handle)
collected = iterator.toSeq
}
}
Expand All @@ -273,7 +276,7 @@ class FileBasedWriteAheadLogSuite
// make sure we didn't open too many Iterators
assert(counter.getMax() <= numThreads)
} finally {
tpool.shutdownNow()
fpool.shutdownNow()
}
}

Expand Down

0 comments on commit c8a76b7

Please sign in to comment.