-
Notifications
You must be signed in to change notification settings - Fork 506
/
TestContext.scala
117 lines (106 loc) · 3.1 KB
/
TestContext.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/*
* Copyright 2017 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cats.effect.laws.util
import cats.effect.internals.NonFatal
import cats.effect.laws.util.TestContext.State
import scala.annotation.tailrec
import scala.collection.immutable.Queue
import scala.concurrent.ExecutionContext
import scala.util.Random
/**
 * A `scala.concurrent.ExecutionContext` implementation that can be
 * used for testing purposes.
 *
 * Submitted tasks are never run on `execute`; they are only queued,
 * and they run when `tick()` is called.
 *
 * Usage:
 *
 * {{{
 *   implicit val ec = TestContext()
 *
 *   ec.execute(new Runnable { def run() = println("task1") })
 *
 *   ec.execute(new Runnable {
 *     def run() = {
 *       println("outer")
 *
 *       ec.execute(new Runnable {
 *         def run() = println("inner")
 *       })
 *     }
 *   })
 *
 *   // Nothing executes until `tick` gets called
 *   ec.tick()
 *
 *   // Testing the resulting state
 *   assert(ec.state.tasks.isEmpty)
 *   assert(ec.state.lastReportedFailure == None)
 * }}}
 */
final class TestContext private () extends ExecutionContext {
  // Pending tasks plus the last recorded failure. Every read and write
  // of this field happens inside `synchronized` on this instance.
  private[this] var stateRef = State(Queue.empty, None)

  /**
   * Enqueues `r` for later execution; it does not run until `tick()`
   * is called.
   */
  def execute(r: Runnable): Unit =
    synchronized {
      stateRef = stateRef.copy(tasks = stateRef.tasks.enqueue(r))
    }

  /**
   * Records `cause` as the `lastReportedFailure`, replacing any failure
   * recorded earlier.
   */
  def reportFailure(cause: Throwable): Unit =
    synchronized {
      stateRef = stateRef.copy(lastReportedFailure = Some(cause))
    }

  /**
   * Returns the internal state of the `TestContext`, useful for testing
   * that certain execution conditions have been met.
   */
  def state: State =
    synchronized(stateRef)

  /**
   * Triggers execution by going through the queue of scheduled tasks and
   * executing them all, until no tasks remain in the queue to execute.
   *
   * Order of execution isn't guaranteed, the queued `Runnable`s are
   * being shuffled in order to simulate the needed non-determinism
   * that happens with multi-threading.
   *
   * Exceptions matched by `NonFatal` are handed to `reportFailure`.
   * Anything else propagates out of `tick()`, and the remaining tasks
   * of the batch being run are not re-queued, because the batch was
   * already removed from the state.
   */
  @tailrec def tick(): Unit = {
    // Atomically take the whole pending queue. Tasks scheduled while this
    // batch runs land in a fresh queue and are handled by the next cycle.
    val queue = synchronized {
      val ref = stateRef.tasks
      stateRef = stateRef.copy(tasks = Queue.empty)
      ref
    }

    if (queue.nonEmpty) {
      // Simulating non-deterministic execution
      val batch = Random.shuffle(queue)
      batch.foreach { r =>
        try r.run()
        catch { case NonFatal(ex) => reportFailure(ex) }
      }
      tick() // Next cycle please
    }
  }
}
object TestContext {
  /**
   * Snapshot of the internals of a [[TestContext]], as exposed by its
   * `state` accessor.
   *
   * @param tasks               `Runnable`s that were enqueued via `execute`
   *                            and have not yet been taken by a `tick()` cycle
   * @param lastReportedFailure the most recently recorded failure, if any
   */
  final case class State(
      tasks: Queue[Runnable],
      lastReportedFailure: Option[Throwable])

  /**
   * Creates a fresh [[TestContext]], starting with an empty task queue
   * and no recorded failure.
   */
  def apply(): TestContext = {
    new TestContext()
  }
}