Skip to content

Commit b2472a5

Browse files
author
Aleksandar Prokopec
committed
Fixed a bug with setting execution contexts.
Ported most of the future tests.
1 parent 9c4baa9 commit b2472a5

File tree

7 files changed

+201
-57
lines changed

7 files changed

+201
-57
lines changed

src/library/scala/concurrent/ConcurrentPackageObject.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ abstract class ConcurrentPackageObject {
2222
*/
2323
lazy val defaultExecutionContext = new impl.ExecutionContextImpl(null)
2424

25-
private val currentExecutionContext = new ThreadLocal[ExecutionContext]
25+
val currentExecutionContext = new ThreadLocal[ExecutionContext]
2626

2727
val handledFutureException: PartialFunction[Throwable, Throwable] = {
2828
case t: Throwable if isFutureThrowable(t) => t
@@ -82,7 +82,7 @@ abstract class ConcurrentPackageObject {
8282
* - TimeoutException - in the case that the blockable object timed out
8383
*/
8484
def blocking[T](body: =>T): T =
85-
blocking(impl.Future.body2awaitable(body), Duration.fromNanos(0))
85+
blocking(impl.Future.body2awaitable(body), Duration.Inf)
8686

8787
/** Blocks on an awaitable object.
8888
*
@@ -93,11 +93,12 @@ abstract class ConcurrentPackageObject {
9393
* - InterruptedException - in the case that a wait within the blockable object was interrupted
9494
* - TimeoutException - in the case that the blockable object timed out
9595
*/
96-
def blocking[T](awaitable: Awaitable[T], atMost: Duration): T =
96+
def blocking[T](awaitable: Awaitable[T], atMost: Duration): T = {
9797
currentExecutionContext.get match {
98-
case null => Await.result(awaitable, atMost)
98+
case null => awaitable.result(atMost)(Await.canAwaitEvidence)
9999
case ec => ec.internalBlockingCall(awaitable, atMost)
100100
}
101+
}
101102

102103
@inline implicit final def int2durationops(x: Int): DurationOps = new DurationOps(x)
103104
}

src/library/scala/concurrent/ExecutionContext.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,18 @@ object ExecutionContext {
4747

4848
/** Creates an `ExecutionContext` from the given `ExecutorService`.
4949
*/
50-
def fromExecutorService(e: ExecutorService): ExecutionContext with Executor = new impl.ExecutionContextImpl(e)
50+
def fromExecutorService(e: ExecutorService, reporter: Throwable => Unit = defaultReporter): ExecutionContext with Executor = new impl.ExecutionContextImpl(e, reporter)
5151

5252
/** Creates an `ExecutionContext` from the given `Executor`.
5353
*/
54-
def fromExecutor(e: Executor): ExecutionContext with Executor = new impl.ExecutionContextImpl(e)
54+
def fromExecutor(e: Executor, reporter: Throwable => Unit = defaultReporter): ExecutionContext with Executor = new impl.ExecutionContextImpl(e, reporter)
55+
56+
def defaultReporter: Throwable => Unit = {
57+
// `Error`s are currently wrapped by `resolver`.
58+
// Also, re-throwing `Error`s here causes an exception handling test to fail.
59+
//case e: Error => throw e
60+
case t => t.printStackTrace()
61+
}
5562

5663
}
5764

src/library/scala/concurrent/impl/ExecutionContextImpl.scala

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ import scala.concurrent.util.{ Duration }
1717

1818

1919

20-
private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext with Executor {
20+
private[scala] class ExecutionContextImpl(es: AnyRef, reporter: Throwable => Unit = ExecutionContext.defaultReporter)
21+
extends ExecutionContext with Executor {
2122
import ExecutionContextImpl._
2223

2324
val executorService: AnyRef = if (es eq null) getExecutorService else es
@@ -26,7 +27,7 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w
2627
def executorsThreadFactory = new ThreadFactory {
2728
def newThread(r: Runnable) = new Thread(new Runnable {
2829
override def run() {
29-
currentExecutionContext.set(ExecutionContextImpl.this)
30+
scala.concurrent.currentExecutionContext.set(ExecutionContextImpl.this)
3031
r.run()
3132
}
3233
})
@@ -36,7 +37,7 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w
3637
def forkJoinPoolThreadFactory = new ForkJoinPool.ForkJoinWorkerThreadFactory {
3738
def newThread(fjp: ForkJoinPool) = new ForkJoinWorkerThread(fjp) {
3839
override def onStart() {
39-
currentExecutionContext.set(ExecutionContextImpl.this)
40+
scala.concurrent.currentExecutionContext.set(ExecutionContextImpl.this)
4041
}
4142
}
4243
}
@@ -92,22 +93,13 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w
9293
}
9394
}
9495

95-
def reportFailure(t: Throwable) = t match {
96-
// `Error`s are currently wrapped by `resolver`.
97-
// Also, re-throwing `Error`s here causes an exception handling test to fail.
98-
//case e: Error => throw e
99-
case t => t.printStackTrace()
100-
}
96+
def reportFailure(t: Throwable) = reporter(t)
10197

10298
}
10399

104100

105101
private[concurrent] object ExecutionContextImpl {
106-
107-
private[concurrent] def currentExecutionContext: ThreadLocal[ExecutionContext] = new ThreadLocal[ExecutionContext] {
108-
override protected def initialValue = null
109-
}
110-
102+
111103
}
112104

113105

src/library/scala/concurrent/impl/Promise.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ object Promise {
7878
*/
7979
class DefaultPromise[T](implicit val executor: ExecutionContext) extends AbstractPromise with Promise[T] { self =>
8080
updater.set(this, Nil) // Start at "No callbacks" //FIXME switch to Unsafe instead of ARFU
81-
81+
8282
protected final def tryAwait(atMost: Duration): Boolean = {
8383
@tailrec
8484
def awaitUnsafe(waitTimeNanos: Long): Boolean = {
@@ -88,7 +88,7 @@ object Promise {
8888
val start = System.nanoTime()
8989
try {
9090
synchronized {
91-
while (!isCompleted) wait(ms, ns)
91+
if (!isCompleted) wait(ms, ns) // previously - this was a `while`, ending up in an infinite loop
9292
}
9393
} catch {
9494
case e: InterruptedException =>
@@ -99,7 +99,7 @@ object Promise {
9999
isCompleted
100100
}
101101
//FIXME do not do this if there'll be no waiting
102-
blocking(Future.body2awaitable(awaitUnsafe(if (atMost.isFinite) atMost.toNanos else Long.MaxValue)), atMost)
102+
awaitUnsafe(if (atMost.isFinite) atMost.toNanos else Long.MaxValue)
103103
}
104104

105105
@throws(classOf[TimeoutException])
@@ -147,7 +147,9 @@ object Promise {
147147
}
148148
tryComplete(resolveEither(value))
149149
} finally {
150-
synchronized { notifyAll() } //Notify any evil blockers
150+
synchronized { //Notify any evil blockers
151+
notifyAll()
152+
}
151153
}
152154
}
153155

src/library/scala/concurrent/package.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,15 @@ package concurrent {
2626
object Await {
2727
private[concurrent] implicit val canAwaitEvidence = new CanAwait {}
2828

29-
def ready[T](awaitable: Awaitable[T], atMost: Duration): Awaitable[T] = awaitable.ready(atMost)
29+
def ready[T <: Awaitable[_]](awaitable: T, atMost: Duration): T = {
30+
blocking(awaitable, atMost)
31+
awaitable
32+
}
33+
34+
def result[T](awaitable: Awaitable[T], atMost: Duration): T = {
35+
blocking(awaitable, atMost)
36+
}
3037

31-
def result[T](awaitable: Awaitable[T], atMost: Duration): T = awaitable.result(atMost)
3238
}
3339

3440
final class DurationOps private[concurrent] (x: Int) {

src/library/scala/concurrent/util/Duration.scala

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,13 @@ object Duration {
115115
* Parse TimeUnit from string representation.
116116
*/
117117
protected[util] def timeUnit(unit: String): TimeUnit = unit.toLowerCase match {
118-
case "d" | "day" | "days" DAYS
119-
case "h" | "hour" | "hours" HOURS
120-
case "min" | "minute" | "minutes" MINUTES
121-
case "s" | "sec" | "second" | "seconds" SECONDS
122-
case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" MILLISECONDS
123-
case "µs" | "micro" | "micros" | "microsecond" | "microseconds" MICROSECONDS
124-
case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" NANOSECONDS
118+
case "d" | "day" | "days" => DAYS
119+
case "h" | "hour" | "hours" => HOURS
120+
case "min" | "minute" | "minutes" => MINUTES
121+
case "s" | "sec" | "second" | "seconds" => SECONDS
122+
case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" => MILLISECONDS
123+
case "µs" | "micro" | "micros" | "microsecond" | "microseconds" => MICROSECONDS
124+
case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" => NANOSECONDS
125125
}
126126

127127
val Zero: FiniteDuration = new FiniteDuration(0, NANOSECONDS)
@@ -138,26 +138,26 @@ object Duration {
138138
}
139139

140140
trait Infinite {
141-
this: Duration
141+
this: Duration =>
142142

143143
def +(other: Duration): Duration =
144144
other match {
145-
case _: this.type this
146-
case _: Infinite throw new IllegalArgumentException("illegal addition of infinities")
147-
case _ this
145+
case _: this.type => this
146+
case _: Infinite => throw new IllegalArgumentException("illegal addition of infinities")
147+
case _ => this
148148
}
149149
def -(other: Duration): Duration =
150150
other match {
151-
case _: this.type throw new IllegalArgumentException("illegal subtraction of infinities")
152-
case _ this
151+
case _: this.type => throw new IllegalArgumentException("illegal subtraction of infinities")
152+
case _ => this
153153
}
154154
def *(factor: Double): Duration = this
155155
def /(factor: Double): Duration = this
156156
def /(other: Duration): Double =
157157
other match {
158-
case _: Infinite throw new IllegalArgumentException("illegal division of infinities")
158+
case _: Infinite => throw new IllegalArgumentException("illegal division of infinities")
159159
// maybe questionable but pragmatic: Inf / 0 => Inf
160-
case x Double.PositiveInfinity * (if ((this > Zero) ^ (other >= Zero)) -1 else 1)
160+
case x => Double.PositiveInfinity * (if ((this > Zero) ^ (other >= Zero)) -1 else 1)
161161
}
162162

163163
def finite_? = false
@@ -300,20 +300,20 @@ class FiniteDuration(val length: Long, val unit: TimeUnit) extends Duration {
300300
def toUnit(u: TimeUnit) = toNanos.toDouble / NANOSECONDS.convert(1, u)
301301

302302
override def toString = this match {
303-
case Duration(1, DAYS) "1 day"
304-
case Duration(x, DAYS) x + " days"
305-
case Duration(1, HOURS) "1 hour"
306-
case Duration(x, HOURS) x + " hours"
307-
case Duration(1, MINUTES) "1 minute"
308-
case Duration(x, MINUTES) x + " minutes"
309-
case Duration(1, SECONDS) "1 second"
310-
case Duration(x, SECONDS) x + " seconds"
311-
case Duration(1, MILLISECONDS) "1 millisecond"
312-
case Duration(x, MILLISECONDS) x + " milliseconds"
313-
case Duration(1, MICROSECONDS) "1 microsecond"
314-
case Duration(x, MICROSECONDS) x + " microseconds"
315-
case Duration(1, NANOSECONDS) "1 nanosecond"
316-
case Duration(x, NANOSECONDS) x + " nanoseconds"
303+
case Duration(1, DAYS) => "1 day"
304+
case Duration(x, DAYS) => x + " days"
305+
case Duration(1, HOURS) => "1 hour"
306+
case Duration(x, HOURS) => x + " hours"
307+
case Duration(1, MINUTES) => "1 minute"
308+
case Duration(x, MINUTES) => x + " minutes"
309+
case Duration(1, SECONDS) => "1 second"
310+
case Duration(x, SECONDS) => x + " seconds"
311+
case Duration(1, MILLISECONDS) => "1 millisecond"
312+
case Duration(x, MILLISECONDS) => x + " milliseconds"
313+
case Duration(1, MICROSECONDS) => "1 microsecond"
314+
case Duration(x, MICROSECONDS) => x + " microseconds"
315+
case Duration(1, NANOSECONDS) => "1 nanosecond"
316+
case Duration(x, NANOSECONDS) => x + " nanoseconds"
317317
}
318318

319319
def compare(other: Duration) =

0 commit comments

Comments
 (0)