Permalink
Browse files

SI-7029 - Makes sure that uncaught exceptions are propagated to the U…

…EH for the global ExecutionContext
  • Loading branch information...
1 parent 42c4cc7 commit 3f78bee128bd6a478bef6a66c5574f77a2d6dd74 @viktorklang viktorklang committed with phaller Jan 30, 2013
@@ -25,11 +25,15 @@ private[scala] class ExecutionContextImpl private[impl] (es: Executor, reporter:
case some => some
}
+ private val uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = new Thread.UncaughtExceptionHandler {
+ def uncaughtException(thread: Thread, cause: Throwable): Unit = reporter(cause)
+ }
+
// Implement BlockContext on FJP threads
class DefaultThreadFactory(daemonic: Boolean) extends ThreadFactory with ForkJoinPool.ForkJoinWorkerThreadFactory {
def wire[T <: Thread](thread: T): T = {
thread.setDaemon(daemonic)
- //Potentially set things like uncaught exception handler, name etc
+ thread.setUncaughtExceptionHandler(uncaughtExceptionHandler)
thread
}
@@ -73,7 +77,7 @@ private[scala] class ExecutionContextImpl private[impl] (es: Executor, reporter:
new ForkJoinPool(
desiredParallelism,
threadFactory,
- null, //FIXME we should have an UncaughtExceptionHandler, see what Akka does
+ uncaughtExceptionHandler,
true) // Async all the way baby
} catch {
case NonFatal(t) =>
@@ -94,13 +98,13 @@ private[scala] class ExecutionContextImpl private[impl] (es: Executor, reporter:
def execute(runnable: Runnable): Unit = executor match {
case fj: ForkJoinPool =>
+ val fjt = runnable match {
+ case t: ForkJoinTask[_] => t
+ case r => new ExecutionContextImpl.AdaptedForkJoinTask(r)
+ }
Thread.currentThread match {
- case fjw: ForkJoinWorkerThread if fjw.getPool eq fj =>
- (runnable match {
- case fjt: ForkJoinTask[_] => fjt
- case _ => ForkJoinTask.adapt(runnable)
- }).fork
- case _ => fj.execute(runnable)
+ case fjw: ForkJoinWorkerThread if fjw.getPool eq fj => fjt.fork()
+ case _ => fj execute fjt
}
case generic => generic execute runnable
}
@@ -111,6 +115,20 @@ private[scala] class ExecutionContextImpl private[impl] (es: Executor, reporter:
private[concurrent] object ExecutionContextImpl {
+ final class AdaptedForkJoinTask(runnable: Runnable) extends ForkJoinTask[Unit] {
+ final override def setRawResult(u: Unit): Unit = ()
+ final override def getRawResult(): Unit = ()
+ final override def exec(): Boolean = try { runnable.run(); true } catch {
+ case anything: Throwable
+ val t = Thread.currentThread
+ t.getUncaughtExceptionHandler match {
+ case null
+ case some some.uncaughtException(t, anything)
+ }
+ throw anything
+ }
+ }
+
def fromExecutor(e: Executor, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl = new ExecutionContextImpl(e, reporter)
def fromExecutorService(es: ExecutorService, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl with ExecutionContextExecutorService =
new ExecutionContextImpl(es, reporter) with ExecutionContextExecutorService {
@@ -34,7 +34,7 @@ private class CallbackRunnable[T](val executor: ExecutionContext, val onComplete
value = v
// Note that we cannot prepare the ExecutionContext at this point, since we might
// already be running on a different thread!
- executor.execute(this)
+ try executor.execute(this) catch { case NonFatal(t) => executor reportFailure t }
}
}
@@ -74,6 +74,30 @@ object FutureTests extends MinimalScalaTest {
"A future with global ExecutionContext" should {
import ExecutionContext.Implicits._
+ "output uncaught exceptions" in {
+ import java.io.{ ByteArrayOutputStream, PrintStream }
+
+ val baos = new ByteArrayOutputStream(1 << 16) { def isEmpty: Boolean = count == 0 }
+ val tmpErr = new PrintStream(baos)
+
+ def assertPrintedToErr(t: Throwable): Unit = {
+ t.printStackTrace(tmpErr)
+ tmpErr.flush()
+ val expected = baos.toByteArray.toIndexedSeq
+ baos.reset()
+ val f = Future(throw t)
+ val d = Deadline.now + 5.seconds
+ while(d.hasTimeLeft && baos.isEmpty) Thread.sleep(10)
+ baos.toByteArray.toIndexedSeq mustBe (expected)
+ }
+
+ val oldErr = System.err
+ System.setErr(tmpErr)
+ try {
+ assertPrintedToErr(new NotImplementedError("foo"))
+ } finally System.setErr(oldErr)
+ }
+
"compose with for-comprehensions" in {
def async(x: Int) = future { (x * 2).toString }
val future0 = future[Any] {

0 comments on commit 3f78bee

Please sign in to comment.