forked from Kotlin/coroutines-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
threadContext.kt
48 lines (39 loc) · 2 KB
/
threadContext.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
package context
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.atomic.AtomicInteger
import kotlin.concurrent.thread
import kotlin.coroutines.experimental.AbstractCoroutineContextElement
import kotlin.coroutines.experimental.Continuation
import kotlin.coroutines.experimental.ContinuationInterceptor
/**
 * Builds a [ThreadContext] whose resumptions run on a pool of [nThreads] worker threads.
 *
 * @param nThreads size of the worker pool.
 * @param name prefix for the worker thread names.
 * @return a new, independent [ThreadContext].
 */
fun newFixedThreadPoolContext(nThreads: Int, name: String): ThreadContext {
    return ThreadContext(nThreads = nThreads, name = name)
}
/**
 * Builds a [ThreadContext] backed by exactly one worker thread.
 *
 * Equivalent to a fixed pool of size one.
 *
 * @param name prefix for the worker thread name.
 * @return a new, independent single-threaded [ThreadContext].
 */
fun newSingleThreadContext(name: String): ThreadContext =
    newFixedThreadPoolContext(nThreads = 1, name = name)
/**
 * Per-thread marker recording which [ThreadContext] (if any) owns the current thread.
 *
 * Every worker thread created by a [ThreadContext] stores its owner here before running any task.
 * This lets the context recognize its own threads.
 */
private val thisThreadContext = ThreadLocal<ThreadContext>()

/**
 * A [ContinuationInterceptor] that resumes coroutines on its own pool of daemon threads.
 *
 * - Worker threads are named `"<name>-1"`, `"<name>-2"`, … in creation order.
 * - A continuation resumed from one of this context's own threads continues inline on that thread.
 * - A continuation resumed from any other thread is handed to [executor].
 *
 * The executor is never shut down by this class. Its threads are daemon threads.
 *
 * @param nThreads number of core threads in the pool; must be at least 1.
 * @param name prefix used for worker thread names.
 * @throws IllegalArgumentException if [nThreads] is less than 1.
 */
class ThreadContext(
    nThreads: Int,
    name: String
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {

    init {
        // ScheduledThreadPoolExecutor accepts a core size of 0 and then lazily spins up a non-core
        // worker. That is meaningless for a "fixed" pool and busy-spins on JDK 8 (JDK-8129861).
        // Negative sizes were already rejected with IllegalArgumentException by the executor.
        require(nThreads >= 1) { "Expected at least one thread, but $nThreads specified" }
    }

    /** Counter used to number worker threads; incremented once per thread the pool creates. */
    val threadNo = AtomicInteger()

    /** Executor that runs resumptions which arrive from threads outside this context. */
    val executor: ScheduledExecutorService = Executors.newScheduledThreadPool(nThreads) { target ->
        thread(start = false, isDaemon = true, name = "$name-${threadNo.incrementAndGet()}") {
            // Tag the thread before running any task so isContextThread() recognizes it.
            thisThreadContext.set(this@ThreadContext)
            target.run()
        }
    }

    /**
     * Wraps [continuation] so that its resumptions are dispatched to this context's threads.
     *
     * Any other [ContinuationInterceptor] elements present in the continuation's context are
     * applied to it first. The result is then wrapped, so the thread switch happens outermost.
     */
    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> {
        val intercepted = continuation.context.fold(continuation) { cont, element ->
            if (element !== this@ThreadContext && element is ContinuationInterceptor) {
                element.interceptContinuation(cont)
            } else {
                cont
            }
        }
        return ThreadContinuation(intercepted)
    }

    /**
     * Runs [action] right away if the caller is already on one of this context's threads.
     * Otherwise it submits [action] to [executor].
     */
    private fun dispatch(action: () -> Unit) {
        if (isContextThread()) action() else executor.execute { action() }
    }

    /** Continuation decorator that routes both resume paths through [dispatch]. */
    private inner class ThreadContinuation<T>(val continuation: Continuation<T>) : Continuation<T> by continuation {
        override fun resume(value: T) = dispatch { continuation.resume(value) }

        override fun resumeWithException(exception: Throwable) =
            dispatch { continuation.resumeWithException(exception) }
    }

    /** True when the calling thread was created by (and tagged with) this very context. */
    private fun isContextThread() = thisThreadContext.get() === this@ThreadContext
}