Skip to content

Commit

Permalink
Merge pull request #1734 from vasilmkd/blocking
Browse files Browse the repository at this point in the history
The work stealing pool becomes a BlockContext
  • Loading branch information
djspiewak committed Feb 27, 2021
2 parents 8274788 + dd624df commit 3f18865
Show file tree
Hide file tree
Showing 10 changed files with 483 additions and 104 deletions.

This file was deleted.

43 changes: 43 additions & 0 deletions core/js/src/main/scala/cats/effect/unsafe/workstealing.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect
package unsafe

import scala.concurrent.ExecutionContext

// Can you imagine a thread pool on JS? Have fun trying to extend or instantiate
// this class. Unfortunately, due to the explicit branching, this type leaks
// into the shared source code of IOFiber.scala.
/**
 * Declaration-only stand-in for the work stealing pool type. The class is
 * sealed and abstract, its constructor is private, and this file defines no
 * subclasses, so no instance of it can be created from this source.
 */
private[effect] sealed abstract class WorkStealingThreadPool private ()
extends ExecutionContext {
// Abstract re-declaration of the `ExecutionContext` member of the same
// signature; no implementation is supplied here.
def execute(runnable: Runnable): Unit
// Abstract re-declaration of the `ExecutionContext` member of the same
// signature; no implementation is supplied here.
def reportFailure(cause: Throwable): Unit
// Abstract, package-private entry point taking a fiber directly.
// NOTE(review): presumably mirrors the method of the same name on the JVM
// pool so that shared code type checks -- confirm against the JVM sources.
private[effect] def executeFiber(fiber: IOFiber[_]): Unit
}

// Unfortunately, due to the explicit branching for optimization purposes, this
// type leaks into the shared source code of IOFiber.scala.
/**
 * Declaration-only stand-in for the worker thread type. The class is sealed
 * and abstract, its constructor is private, and this file defines no
 * subclasses, so no instance of it can be created from this source.
 */
private[effect] sealed abstract class WorkerThread private () extends Thread {
// Abstract; no implementation is supplied in this file.
// NOTE(review): presumably mirrors the JVM `WorkerThread` method of the same
// name -- confirm against the JVM sources.
def reschedule(fiber: IOFiber[_]): Unit
// Abstract; no implementation is supplied in this file.
// NOTE(review): presumably mirrors the JVM `WorkerThread` method of the same
// name -- confirm against the JVM sources.
def schedule(fiber: IOFiber[_]): Unit
}

// Unfortunately, due to the explicit branching for optimization purposes, this
// type leaks into the shared source code of IOFiber.scala.
/**
 * Declaration-only stand-in for the helper thread type. The class is sealed
 * and abstract, its constructor is private, and this file defines no
 * subclasses, so no instance of it can be created from this source.
 */
private[effect] sealed abstract class HelperThread private () extends Thread {
// Abstract; no implementation is supplied in this file.
// NOTE(review): presumably mirrors the JVM `HelperThread` method of the same
// name -- confirm against the JVM sources.
def schedule(fiber: IOFiber[_]): Unit
}
6 changes: 3 additions & 3 deletions core/jvm/src/main/scala/cats/effect/IOApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package cats.effect

import scala.concurrent.CancellationException
import scala.concurrent.{blocking, CancellationException}

import java.util.concurrent.CountDownLatch

Expand Down Expand Up @@ -55,7 +55,7 @@ trait IOApp {
if (latch.getCount() > 0) {
val cancelLatch = new CountDownLatch(1)
fiber.cancel.unsafeRunAsync(_ => cancelLatch.countDown())(runtime)
cancelLatch.await()
blocking(cancelLatch.await())
}

// Clean up after ourselves, relevant for running IOApps in sbt,
Expand All @@ -76,7 +76,7 @@ trait IOApp {
}

try {
latch.await()
blocking(latch.await())
if (error != null) {
// Runtime has already been shutdown in IOFiber.
throw error
Expand Down
3 changes: 2 additions & 1 deletion core/jvm/src/main/scala/cats/effect/IOPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package cats.effect

import scala.annotation.unchecked.uncheckedVariance
import scala.concurrent.blocking
import scala.concurrent.duration._

import java.util.concurrent.{CompletableFuture, CountDownLatch, TimeUnit}
Expand All @@ -37,7 +38,7 @@ abstract private[effect] class IOPlatform[+A] { self: IO[A] =>
latch.countDown()
}

if (latch.await(limit.toNanos, TimeUnit.NANOSECONDS)) {
if (blocking(latch.await(limit.toNanos, TimeUnit.NANOSECONDS))) {
results.fold(throw _, a => Some(a))
} else {
None
Expand Down
182 changes: 182 additions & 0 deletions core/jvm/src/main/scala/cats/effect/unsafe/HelperThread.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect
package unsafe

import scala.concurrent.{BlockContext, CanAwait}

import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}

/**
* A helper thread which is spawned whenever a blocking action is being executed
* by a [[WorkerThread]]. The purpose of this thread is to continue executing
* the fibers of the blocked [[WorkerThread]], one of which might ultimately
* unblock the currently blocked thread. Since all [[WorkerThread]]s drain their
* local queues before entering a blocking region, the helper threads do not
* actually steal fibers from the [[WorkerThread]]s. Instead, they operate
* solely on the `overflow` queue, where all drained fibers end up, as well as
* incoming fibers scheduled from outside the runtime. The helper thread loops
* until the [[WorkerThread]] which spawned it has exited the blocking section
* (by setting the `signal` variable of this thread), or until the `overflow`
* queue has been exhausted, whichever comes first.
*
* The helper thread itself extends [[scala.concurrent.BlockContext]], which
* means that it also has the ability to anticipate blocking actions. If
* blocking does occur on a helper thread, another helper thread is started to
* take its place. Similarly, that thread sticks around until it has been
* signalled to go away, or the `overflow` queue has been exhausted.
*
* As for why we're not simply using other [[WorkerThread]]s to take the place
* of other blocked [[WorkerThread]]s, it comes down to optimization and
* simplicity of implementation. Blocking is simply not expected to occur
* frequently on the compute pool of Cats Effect, and over time, the users of
* Cats Effect are expected to learn and use machinery such as `IO.blocking` to
* properly delineate blocking actions. If blocking were to be anticipated in
* the [[WorkerThread]]s, their implementation (especially in the trickiest
* cases of proper finalization of the threads) would be much more complex. This
* way, both [[WorkerThread]] and [[HelperThread]] get to enjoy a somewhat
* simpler, more maintainable implementation. The [[WorkStealingThreadPool]]
* itself is heavily optimized for operating with a fixed number of
* [[WorkerThread]]s, and having a dynamic number of [[WorkerThread]] instances
* introduces more logic on the hot path.
*/
private[effect] final class HelperThread(
    private[this] val threadPrefix: String,
    private[this] val blockingThreadCounter: AtomicInteger,
    private[this] val overflow: ConcurrentLinkedQueue[IOFiber[_]],
    private[this] val pool: WorkStealingThreadPool)
    extends Thread
    with BlockContext {

  /**
   * Signalling mechanism through which the thread which spawned this
   * [[HelperThread]] signals that it has exited the blocking code region and
   * that this [[HelperThread]] should finalize.
   */
  private[this] val signal: AtomicBoolean = new AtomicBoolean(false)

  /**
   * A flag which is set while this thread is inside a blocking code region.
   * It is used to detect nested blocking regions, in order to avoid
   * unnecessarily spawning extra [[HelperThread]]s. The flag is only read and
   * written from within [[blockOn]].
   */
  private[this] var blocking: Boolean = false

  /**
   * Called by the thread which spawned this [[HelperThread]], to notify the
   * [[HelperThread]] that the spawning thread has finished blocking and is
   * returning to normal operation. The [[HelperThread]] observes the flag at
   * the top of its run loop and exits.
   */
  private[unsafe] def setSignal(): Unit = {
    signal.lazySet(true)
  }

  /**
   * Schedules a fiber on the `overflow` queue. [[HelperThread]]s exclusively
   * work with fibers from the `overflow` queue.
   *
   * @param fiber the fiber to be scheduled on the `overflow` queue
   */
  def schedule(fiber: IOFiber[_]): Unit = {
    overflow.offer(fiber)
    ()
  }

  /**
   * The run loop of the [[HelperThread]]. A loop iteration consists of
   * polling the `overflow` queue for available work. If no fiber can be
   * obtained, the [[HelperThread]] exits its run loop and dies. If a fiber is
   * obtained, it is executed.
   *
   * Each iteration of the loop is preceded by a check of `pool.done` and of
   * the `signal` variable. If either has been set, the [[HelperThread]] exits
   * its run loop and dies.
   */
  override def run(): Unit = {
    // Check for exit condition. Do not continue if the pool has been shut
    // down, or the thread which spawned this `HelperThread` has finished
    // blocking.
    while (!pool.done && !signal.get()) {
      val fiber = overflow.poll()

      if (fiber eq null) {
        // There are no more fibers on the overflow queue. Since the overflow
        // queue is not a blocking queue, there is no point in busy waiting,
        // especially since there is no guarantee that the thread which
        // spawned this `HelperThread` will ever exit the blocking region, and
        // new external work may never arrive on the `overflow` queue. This
        // pathological case is not handled as it is a case of uncontrolled
        // blocking on a fixed thread pool, an inherently careless and unsafe
        // situation.
        return
      } else {
        fiber.run()
      }
    }
  }

  /**
   * A mechanism for executing support code before executing a blocking action.
   *
   * Outside of a blocking region, a fresh [[HelperThread]] is started to take
   * the place of this thread, `thunk` is evaluated, and afterwards the helper
   * is signalled and joined. Inside an already active blocking region, `thunk`
   * is evaluated directly.
   *
   * The helper is signalled and joined, and the `blocking` flag is reset, even
   * when `thunk` throws. Without this, a failing blocking action (e.g. a
   * timeout) would leave the helper unsignalled and this thread permanently
   * marked as blocking, so that no later blocking region on it would spawn a
   * helper.
   *
   * @param thunk the blocking action to evaluate
   * @param permission evidence that blocking is permitted
   * @return the value produced by `thunk`
   * @throws java.lang.InterruptedException if this thread is interrupted while
   *         waiting for the spawned helper to terminate; any exception thrown
   *         by `thunk` itself is propagated as well
   */
  override def blockOn[T](thunk: => T)(implicit permission: CanAwait): T = {
    if (blocking) {
      // This `HelperThread` is already inside an enclosing blocking region.
      // There is no need to spawn another `HelperThread`. Instead, directly
      // execute the blocking action.
      thunk
    } else {
      // Spawn a new `HelperThread` to take the place of this thread, as the
      // current thread prepares to execute a blocking action.
      val helper = new HelperThread(threadPrefix, blockingThreadCounter, overflow, pool)
      helper.setName(
        s"$threadPrefix-blocking-helper-${blockingThreadCounter.incrementAndGet()}")
      helper.setDaemon(true)
      helper.start()

      // Logically enter the blocking region. This happens only once the
      // helper is running, so that a failure to start it cannot leave the
      // flag set.
      blocking = true

      try {
        // With another `HelperThread` started, it is time to execute the
        // blocking action. Its value is the result of this method.
        thunk
      } finally {
        try {
          // Blocking is finished (normally or exceptionally). Time to signal
          // the spawned helper thread.
          helper.setSignal()

          // Do not proceed until the helper thread has fully died. This is
          // terrible for performance, but it is justified in this case as the
          // stability of the `WorkStealingThreadPool` is of utmost importance
          // in the face of blocking, which in itself is **not** what the pool
          // is optimized for.
          helper.join()
        } finally {
          // Logically exit the blocking region, even if `join` was
          // interrupted.
          blocking = false
        }
      }
    }
  }
}
66 changes: 66 additions & 0 deletions core/jvm/src/main/scala/cats/effect/unsafe/LocalQueue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,72 @@ private final class LocalQueue {
null
}

/**
* Steals all enqueued fibers and transfers them to the provided list.
*
* This method is called by the runtime when blocking is detected in order to
* give a chance to the fibers enqueued behind the `head` of the queue to run
* on another thread. More often than not, these fibers are the ones that will
* ultimately unblock the blocking fiber currently executing. Ideally, other
* [[WorkerThread]]s would steal the contents of this [[LocalQueue]].
* Unfortunately, in practice, careless blocking has a tendency to quickly
* spread around the [[WorkerThread]]s (and there's a fixed number of them) in
* the runtime and halt any and all progress. For that reason, this method is
* used to completely drain any remaining fibers and transfer them to other
* helper threads which will continue executing fibers until the blocked thread
* has been unblocked.
*
* Conceptually, this method is identical to [[LocalQueue#dequeue]], with the
* main difference being that the `head` of the queue is moved forward to
* match the `tail` of the queue, thus securing ''all'' remaining fibers.
*
* @param dst the destination list in which all remaining fibers are
* transferred
*/
def drain(dst: ArrayList[IOFiber[_]]): Unit = {
  // Snapshot of the tail, read once up front and reused by every attempt.
  val tailSnapshot = tail

  // Becomes true once the queue is found empty, or once the head has been
  // moved up to the tail snapshot and the fibers have been handed over.
  var finished = false

  // Retry until one compare-and-set on the head succeeds (or nothing is left).
  while (!finished) {
    val packedHead = head.get()
    val realHead = lsb(packedHead)

    if (tailSnapshot == realHead) {
      // The "real" head already equals the tail: there is nothing to drain.
      finished = true
    } else {
      // If the "steal" tag differs from the "real" head, keep the tag as it
      // is; otherwise advance it together with the "real" head.
      val stealTag = msb(packedHead)
      val updatedHead =
        if (stealTag == realHead) pack(tailSnapshot, tailSnapshot)
        else pack(stealTag, tailSnapshot)

      if (head.compareAndSet(packedHead, updatedHead)) {
        // The head now points at the tail snapshot, so every fiber between
        // the old "real" head and the tail belongs to this call. Clear each
        // slot and append its fiber to the destination list.
        val count = unsignedShortSubtraction(tailSnapshot, realHead)
        var offset = 0
        while (offset < count) {
          val slot = index(realHead + offset)
          val fiber = buffer(slot)
          buffer(slot) = null
          dst.add(fiber)
          offset += 1
        }

        finished = true
      }
    }
  }
}

/**
* Checks whether the local queue is empty.
*
Expand Down
Loading

0 comments on commit 3f18865

Please sign in to comment.