Permalink
Browse files

Fixed a bug with setting execution contexts.

Ported most of the future tests.
  • Loading branch information...
axel22 committed Apr 30, 2012
1 parent 9c4baa9 commit b2472a5f057ca5e18dc8b3205600c14bf31dc1d1
@@ -22,7 +22,7 @@ abstract class ConcurrentPackageObject {
*/
lazy val defaultExecutionContext = new impl.ExecutionContextImpl(null)
- private val currentExecutionContext = new ThreadLocal[ExecutionContext]
+ val currentExecutionContext = new ThreadLocal[ExecutionContext]
val handledFutureException: PartialFunction[Throwable, Throwable] = {
case t: Throwable if isFutureThrowable(t) => t
@@ -82,7 +82,7 @@ abstract class ConcurrentPackageObject {
* - TimeoutException - in the case that the blockable object timed out
*/
def blocking[T](body: =>T): T =
- blocking(impl.Future.body2awaitable(body), Duration.fromNanos(0))
+ blocking(impl.Future.body2awaitable(body), Duration.Inf)
/** Blocks on an awaitable object.
*
@@ -93,11 +93,12 @@ abstract class ConcurrentPackageObject {
* - InterruptedException - in the case that a wait within the blockable object was interrupted
* - TimeoutException - in the case that the blockable object timed out
*/
- def blocking[T](awaitable: Awaitable[T], atMost: Duration): T =
+ def blocking[T](awaitable: Awaitable[T], atMost: Duration): T = {
currentExecutionContext.get match {
- case null => Await.result(awaitable, atMost)
+ case null => awaitable.result(atMost)(Await.canAwaitEvidence)
case ec => ec.internalBlockingCall(awaitable, atMost)
}
+ }
@inline implicit final def int2durationops(x: Int): DurationOps = new DurationOps(x)
}
@@ -47,11 +47,18 @@ object ExecutionContext {
/** Creates an `ExecutionContext` from the given `ExecutorService`.
*/
- def fromExecutorService(e: ExecutorService): ExecutionContext with Executor = new impl.ExecutionContextImpl(e)
+ def fromExecutorService(e: ExecutorService, reporter: Throwable => Unit = defaultReporter): ExecutionContext with Executor = new impl.ExecutionContextImpl(e, reporter)
/** Creates an `ExecutionContext` from the given `Executor`.
*/
- def fromExecutor(e: Executor): ExecutionContext with Executor = new impl.ExecutionContextImpl(e)
+ def fromExecutor(e: Executor, reporter: Throwable => Unit = defaultReporter): ExecutionContext with Executor = new impl.ExecutionContextImpl(e, reporter)
+
+ def defaultReporter: Throwable => Unit = {
+ // `Error`s are currently wrapped by `resolver`.
+ // Also, re-throwing `Error`s here causes an exception handling test to fail.
+ //case e: Error => throw e
+ case t => t.printStackTrace()
+ }
}
@@ -17,7 +17,8 @@ import scala.concurrent.util.{ Duration }
-private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext with Executor {
+private[scala] class ExecutionContextImpl(es: AnyRef, reporter: Throwable => Unit = ExecutionContext.defaultReporter)
+extends ExecutionContext with Executor {
import ExecutionContextImpl._
val executorService: AnyRef = if (es eq null) getExecutorService else es
@@ -26,7 +27,7 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w
def executorsThreadFactory = new ThreadFactory {
def newThread(r: Runnable) = new Thread(new Runnable {
override def run() {
- currentExecutionContext.set(ExecutionContextImpl.this)
+ scala.concurrent.currentExecutionContext.set(ExecutionContextImpl.this)
r.run()
}
})
@@ -36,7 +37,7 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w
def forkJoinPoolThreadFactory = new ForkJoinPool.ForkJoinWorkerThreadFactory {
def newThread(fjp: ForkJoinPool) = new ForkJoinWorkerThread(fjp) {
override def onStart() {
- currentExecutionContext.set(ExecutionContextImpl.this)
+ scala.concurrent.currentExecutionContext.set(ExecutionContextImpl.this)
}
}
}
@@ -92,22 +93,13 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w
}
}
- def reportFailure(t: Throwable) = t match {
- // `Error`s are currently wrapped by `resolver`.
- // Also, re-throwing `Error`s here causes an exception handling test to fail.
- //case e: Error => throw e
- case t => t.printStackTrace()
- }
+ def reportFailure(t: Throwable) = reporter(t)
}
private[concurrent] object ExecutionContextImpl {
-
- private[concurrent] def currentExecutionContext: ThreadLocal[ExecutionContext] = new ThreadLocal[ExecutionContext] {
- override protected def initialValue = null
- }
-
+
}
@@ -78,7 +78,7 @@ object Promise {
*/
class DefaultPromise[T](implicit val executor: ExecutionContext) extends AbstractPromise with Promise[T] { self =>
updater.set(this, Nil) // Start at "No callbacks" //FIXME switch to Unsafe instead of ARFU
-
+
protected final def tryAwait(atMost: Duration): Boolean = {
@tailrec
def awaitUnsafe(waitTimeNanos: Long): Boolean = {
@@ -88,7 +88,7 @@ object Promise {
val start = System.nanoTime()
try {
synchronized {
- while (!isCompleted) wait(ms, ns)
+ if (!isCompleted) wait(ms, ns) // previously - this was a `while`, ending up in an infinite loop
}
} catch {
case e: InterruptedException =>
@@ -99,7 +99,7 @@ object Promise {
isCompleted
}
//FIXME do not do this if there'll be no waiting
- blocking(Future.body2awaitable(awaitUnsafe(if (atMost.isFinite) atMost.toNanos else Long.MaxValue)), atMost)
+ awaitUnsafe(if (atMost.isFinite) atMost.toNanos else Long.MaxValue)
}
@throws(classOf[TimeoutException])
@@ -147,7 +147,9 @@ object Promise {
}
tryComplete(resolveEither(value))
} finally {
- synchronized { notifyAll() } //Notify any evil blockers
+ synchronized { //Notify any evil blockers
+ notifyAll()
+ }
}
}
@@ -26,9 +26,15 @@ package concurrent {
object Await {
private[concurrent] implicit val canAwaitEvidence = new CanAwait {}
- def ready[T](awaitable: Awaitable[T], atMost: Duration): Awaitable[T] = awaitable.ready(atMost)
+ def ready[T <: Awaitable[_]](awaitable: T, atMost: Duration): T = {
+ blocking(awaitable, atMost)
+ awaitable
+ }
+
+ def result[T](awaitable: Awaitable[T], atMost: Duration): T = {
+ blocking(awaitable, atMost)
+ }
- def result[T](awaitable: Awaitable[T], atMost: Duration): T = awaitable.result(atMost)
}
final class DurationOps private[concurrent] (x: Int) {
@@ -115,13 +115,13 @@ object Duration {
* Parse TimeUnit from string representation.
*/
protected[util] def timeUnit(unit: String): TimeUnit = unit.toLowerCase match {
- case "d" | "day" | "days" DAYS
- case "h" | "hour" | "hours" HOURS
- case "min" | "minute" | "minutes" MINUTES
- case "s" | "sec" | "second" | "seconds" SECONDS
- case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" MILLISECONDS
- case "µs" | "micro" | "micros" | "microsecond" | "microseconds" MICROSECONDS
- case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" NANOSECONDS
+ case "d" | "day" | "days" => DAYS
+ case "h" | "hour" | "hours" => HOURS
+ case "min" | "minute" | "minutes" => MINUTES
+ case "s" | "sec" | "second" | "seconds" => SECONDS
+ case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" => MILLISECONDS
+ case "µs" | "micro" | "micros" | "microsecond" | "microseconds" => MICROSECONDS
+ case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" => NANOSECONDS
}
val Zero: FiniteDuration = new FiniteDuration(0, NANOSECONDS)
@@ -138,26 +138,26 @@ object Duration {
}
trait Infinite {
- this: Duration
+ this: Duration =>
def +(other: Duration): Duration =
other match {
- case _: this.type this
- case _: Infinite throw new IllegalArgumentException("illegal addition of infinities")
- case _ this
+ case _: this.type => this
+ case _: Infinite => throw new IllegalArgumentException("illegal addition of infinities")
+ case _ => this
}
def -(other: Duration): Duration =
other match {
- case _: this.type throw new IllegalArgumentException("illegal subtraction of infinities")
- case _ this
+ case _: this.type => throw new IllegalArgumentException("illegal subtraction of infinities")
+ case _ => this
}
def *(factor: Double): Duration = this
def /(factor: Double): Duration = this
def /(other: Duration): Double =
other match {
- case _: Infinite throw new IllegalArgumentException("illegal division of infinities")
+ case _: Infinite => throw new IllegalArgumentException("illegal division of infinities")
// maybe questionable but pragmatic: Inf / 0 => Inf
- case x Double.PositiveInfinity * (if ((this > Zero) ^ (other >= Zero)) -1 else 1)
+ case x => Double.PositiveInfinity * (if ((this > Zero) ^ (other >= Zero)) -1 else 1)
}
def finite_? = false
@@ -300,20 +300,20 @@ class FiniteDuration(val length: Long, val unit: TimeUnit) extends Duration {
def toUnit(u: TimeUnit) = toNanos.toDouble / NANOSECONDS.convert(1, u)
override def toString = this match {
- case Duration(1, DAYS) "1 day"
- case Duration(x, DAYS) x + " days"
- case Duration(1, HOURS) "1 hour"
- case Duration(x, HOURS) x + " hours"
- case Duration(1, MINUTES) "1 minute"
- case Duration(x, MINUTES) x + " minutes"
- case Duration(1, SECONDS) "1 second"
- case Duration(x, SECONDS) x + " seconds"
- case Duration(1, MILLISECONDS) "1 millisecond"
- case Duration(x, MILLISECONDS) x + " milliseconds"
- case Duration(1, MICROSECONDS) "1 microsecond"
- case Duration(x, MICROSECONDS) x + " microseconds"
- case Duration(1, NANOSECONDS) "1 nanosecond"
- case Duration(x, NANOSECONDS) x + " nanoseconds"
+ case Duration(1, DAYS) => "1 day"
+ case Duration(x, DAYS) => x + " days"
+ case Duration(1, HOURS) => "1 hour"
+ case Duration(x, HOURS) => x + " hours"
+ case Duration(1, MINUTES) => "1 minute"
+ case Duration(x, MINUTES) => x + " minutes"
+ case Duration(1, SECONDS) => "1 second"
+ case Duration(x, SECONDS) => x + " seconds"
+ case Duration(1, MILLISECONDS) => "1 millisecond"
+ case Duration(x, MILLISECONDS) => x + " milliseconds"
+ case Duration(1, MICROSECONDS) => "1 microsecond"
+ case Duration(x, MICROSECONDS) => x + " microseconds"
+ case Duration(1, NANOSECONDS) => "1 nanosecond"
+ case Duration(x, NANOSECONDS) => x + " nanoseconds"
}
def compare(other: Duration) =
Oops, something went wrong.

0 comments on commit b2472a5

Please sign in to comment.