Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Prevent deadlocks when mixing threads with queue based execution context #3852

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 1 addition & 3 deletions javalib/src/main/scala/java/lang/Runtime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ class Runtime private () {
}

private def handleSignal(sig: CInt): Unit = {
if (isMultithreadingEnabled) {
scalanative.runtime.Proxy.skipWaitingForNonDeamonThreads()
}
scalanative.runtime.Proxy.disableGracefullShutdown()
Runtime.getRuntime().runHooks()
exit(128 + sig)
}
Expand Down
2 changes: 0 additions & 2 deletions javalib/src/main/scala/java/lang/Thread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ import scala.scalanative.libc.stdatomic.{AtomicLongLong, atomic_thread_fence}
import scala.scalanative.libc.stdatomic.memory_order._
import scala.scalanative.runtime.UnsupportedFeature

import scala.scalanative.runtime.JoinNonDaemonThreads

class Thread private[lang] (
@volatile private var name: String,
private[java] val platformCtx: PlatformThreadContext /* | Null */
Expand Down
2 changes: 1 addition & 1 deletion javalib/src/main/scala/java/lang/ThreadGroup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class ThreadGroup(
if (out == null) throw new NullPointerException()
if (out.length == 0) 0
else {
val aliveThreads = NativeThread.Registry.aliveThreads
val aliveThreads = NativeThread.Registry.aliveThreads.toArray
@tailrec def loop(idx: Int, included: Int): Int =
if (idx == aliveThreads.length || included == out.length) included
else {
Expand Down
3 changes: 2 additions & 1 deletion javalib/src/main/scala/scala/scalanative/runtime/Proxy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ object Proxy {
callback: GCWeakReferencesCollectedCallback
): Unit = GC.setWeakReferencesCollectedCallback(callback)

def skipWaitingForNonDeamonThreads(): Unit = JoinNonDaemonThreads.skip()
def disableGracefullShutdown(): Unit =
MainThreadShutdownContext.gracefully = false
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,57 @@ object NativeExecutionContext {

private[runtime] object QueueExecutionContext
extends ExecutionContextExecutor {
private val queue: ListBuffer[Runnable] = new ListBuffer
override def execute(runnable: Runnable): Unit = queue += runnable
private val queue: Queue =
if (isMultithreadingEnabled) new Queue.Concurrent()
else new Queue.Singlethreaded()

override def execute(runnable: Runnable): Unit = {
queue.enqueue(runnable)
if (isMultithreadingEnabled) {
MainThreadShutdownContext.onTaskEnqueued()
}
}

override def reportFailure(t: Throwable): Unit = t.printStackTrace()

def hasNextTask: Boolean = queue.nonEmpty
def hasNextTask: Boolean = !queue.isEmpty
def availableTasks: Int = queue.size

def executeNextTask(): Unit = if (hasNextTask) {
val runnable = queue.remove(0)
try runnable.run()
catch {
case t: Throwable =>
QueueExecutionContext.reportFailure(t)
queue.dequeue() match {
case null => ()
case runnable =>
try runnable.run()
catch { case t: Throwable => QueueExecutionContext.reportFailure(t) }
}
}

/** Execute all the available tasks. Returns the number of executed tasks */
def executeAvailableTasks(): Unit = while (hasNextTask) {
executeNextTask()
}

private trait Queue {
def enqueue(runnable: Runnable): Unit
def dequeue(): Runnable
def size: Int
def isEmpty: Boolean
}
private object Queue {
class Concurrent() extends Queue {
val backend = new java.util.concurrent.ConcurrentLinkedQueue[Runnable]()
override def enqueue(runnable: Runnable): Unit = backend.add(runnable)
override def dequeue(): Runnable = backend.poll()
override def size: Int = backend.size()
override def isEmpty: Boolean = backend.isEmpty()
}
class Singlethreaded() extends Queue {
val backend = ListBuffer.empty[Runnable]
override def enqueue(runnable: Runnable) = backend += runnable
override def dequeue(): Runnable = backend.remove(0)
override def size: Int = backend.size
override def isEmpty: Boolean = backend.isEmpty
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package scala.scalanative.runtime
package scala.scalanative
package runtime

import scala.scalanative.runtime.Intrinsics._
import scala.scalanative.runtime.GC.{ThreadRoutineArg, ThreadStartRoutine}
Expand All @@ -10,6 +11,7 @@ import scala.scalanative.runtime.ffi.stdatomic.memory_order._
import scala.annotation.nowarn

import java.util.concurrent.ConcurrentHashMap
import java.{util => ju}

trait NativeThread {
import NativeThread._
Expand Down Expand Up @@ -58,6 +60,7 @@ trait NativeThread {
protected def onTermination(): Unit = if (isMultithreadingEnabled) {
state = NativeThread.State.Terminated
Registry.remove(this)
MainThreadShutdownContext.onThreadFinished(this.thread)
}
}

Expand Down Expand Up @@ -104,14 +107,14 @@ object NativeThread {
private[NativeThread] def add(thread: NativeThread): Unit =
_aliveThreads.put(thread.thread.getId(): @nowarn, thread)

private[NativeThread] def remove(thread: NativeThread): Unit =
private[NativeThread] def remove(thread: NativeThread): Unit = {
_aliveThreads.remove(thread.thread.getId(): @nowarn)
}

def aliveThreads: scala.Array[NativeThread] =
_aliveThreads.values.toArray().asInstanceOf[scala.Array[NativeThread]]

def onMainThreadTermination() = {
_aliveThreads.remove(MainThreadId)
@nowarn
def aliveThreads: Iterable[NativeThread] = {
import scala.collection.JavaConverters._
_aliveThreads.values.asScala
}
}

Expand Down
44 changes: 33 additions & 11 deletions nativelib/src/main/scala/scala/scalanative/runtime/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@ import scalanative.unsafe._
import scalanative.unsigned.USize
import scalanative.runtime.Intrinsics._
import scalanative.runtime.monitor._
import scalanative.runtime.ffi.stdatomic.{atomic_thread_fence, memory_order}
import scala.scalanative.meta.LinktimeInfo.isMultithreadingEnabled

// Extract any fields from runtime package to ensure it does not require initialization
private object runtimeState {
var _filename: String = null
}
import java.util.concurrent.locks.LockSupport

package object runtime {
import runtimeState._

def filename = _filename
def filename = ExecInfo.filename

/** Used as a stub right hand of intrinsified methods. */
private[scalanative] def intrinsic: Nothing = throwUndefined()
Expand Down Expand Up @@ -77,16 +72,43 @@ package object runtime {
c += 1
}

_filename = fromCString(argv(0))
ExecInfo.filename = fromCString(argv(0))
args
}

/* Internal shutdown method called after successfully running the main method.
* Ensures that all scheduled tasks / non-deamon threads would finish before exit.
*/
@noinline private[runtime] def onShutdown(): Unit = {
NativeExecutionContext.QueueExecutionContext.executeAvailableTasks()
if (isMultithreadingEnabled) JoinNonDaemonThreads.run()
import MainThreadShutdownContext._
if (isMultithreadingEnabled) {
shutdownThread = Thread.currentThread()
atomic_thread_fence(memory_order.memory_order_seq_cst)
}
def pollNonDaemonThreads = NativeThread.Registry.aliveThreads.iterator
.map(_.thread)
.filter { thread =>
(thread ne shutdownThread) && !thread.isDaemon() &&
thread.isAlive()
}

def queue = NativeExecutionContext.QueueExecutionContext
def shouldWaitForThreads =
if (isMultithreadingEnabled) gracefully && pollNonDaemonThreads.hasNext
else false
def shouldRunQueuedTasks = gracefully && queue.hasNextTask

// Both runnable from the NativeExecutionContext.queue and the running threads can spawn new runnables
while ({
// drain the queue
queue.executeAvailableTasks()
// queue is empty, threads might be still running
if (isMultithreadingEnabled) {
if (shouldWaitForThreads) LockSupport.park()
// When unparked thread has either finished execution or there are new tasks enqueued
}
shouldWaitForThreads || shouldRunQueuedTasks
}) ()
}

private[scalanative] final def executeUncaughtExceptionHandler(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package scala.scalanative.runtime

import java.util.concurrent.locks.LockSupport
import scala.scalanative.meta.LinktimeInfo.isMultithreadingEnabled

// Extracted fields from runtime package to ensure it does not require initialization
private[runtime] object MainThreadShutdownContext {
@volatile var shutdownThread: Thread = _
var gracefully: Boolean = true

def inShutdown: Boolean = shutdownThread != null

/* Notify that thread has */
def onThreadFinished(thread: Thread): Unit = if (!thread.isDaemon()) signal()
def onTaskEnqueued(): Unit = signal()

private def signal() =
if (isMultithreadingEnabled)
if (inShutdown)
LockSupport.unpark(shutdownThread)
}

private object ExecInfo {
var filename: String = null
}
10 changes: 10 additions & 0 deletions scripted-tests/run/shutdown/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,13 @@ runTestThreadsJoin := {
checkThreadsJoin(cmd, joinInMain = false)
}
}

val runTestQueueWithThreads = taskKey[Unit](
"test multithreaded shutdown in mixed environement using Queue and Threads scheduling"
)
runTestQueueWithThreads := {
val cmd = (Compile / nativeLink).value.toString
val proc = new ProcessBuilder(cmd).start()
assert(proc.waitFor(5, TimeUnit.SECONDS))
assert(proc.exitValue == 0)
}
3 changes: 3 additions & 0 deletions scripted-tests/run/shutdown/test
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ $ copy-file variants/SetDeleteOnExit.scala src/main/scala/Main.scala

$ copy-file variants/ThreadsJoin.scala src/main/scala/Main.scala
> runTestThreadsJoin

$ copy-file variants/QueueWithThreads.scala src/main/scala/Main.scala
> runTestQueueWithThreads
37 changes: 37 additions & 0 deletions scripted-tests/run/shutdown/variants/QueueWithThreads.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.CountDownLatch
import scala.scalanative.runtime.NativeExecutionContext

object Test {
def main(args: Array[String]): Unit = {
println("Hello, World!")
def spawnRunnable(name: String)(fn: => Unit) =
NativeExecutionContext.queue
.execute(() => { fn; println(s"task $name done") })

def spawnThread(name: String)(fn: => Unit) = {
val t = new Thread(() => { fn; println(s"thread $name done") })
t.setName(name)
t.start()
}

spawnThread("T1") {
val latch1 = new CountDownLatch(1)
spawnRunnable("R1") { latch1.countDown() }
spawnThread("T2") {
latch1.await() // blocks until T1, R1 are done
val latch2 = new CountDownLatch(1)
val latch3 = new CountDownLatch(3)
spawnThread("T3") {
spawnRunnable("R2") { latch2.await(); latch3.countDown() }
}
spawnThread("T4") {
spawnRunnable("R3") { latch2.await(); latch3.countDown() }
}
spawnRunnable("R4") { latch2.await(); latch3.countDown() }
latch2.countDown()
latch3.await()
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -661,12 +661,6 @@ private[codegen] object Generate {
nir.Sig.Method("executeUncaughtExceptionHandler", Seq(JavaThreadUEHRef, JavaThreadRef, Throwable, nir.Type.Unit))
)

val JoinNonDaemonThreadsModule = nir.Type.Ref(nir.Global.Top("scala.scalanative.runtime.JoinNonDaemonThreads"))
val JoinNonDaemonThreadsRun =
JoinNonDaemonThreadsModule.name
.member(nir.Sig.Method("run", Seq(nir.Type.Unit), nir.Sig.Scope.PublicStatic))
val JoinNonDaemonThreadsRunSig = nir.Type.Function(Seq(), nir.Type.Unit)

val InitSig = nir.Type.Function(Seq.empty, nir.Type.Unit)
val InitDecl = nir.Defn.Declare(nir.Attrs.None, extern("scalanative_GC_init"), InitSig)
val Init = nir.Val.Global(InitDecl.name, nir.Type.Ptr)
Expand Down Expand Up @@ -699,6 +693,6 @@ private[codegen] object Generate {
JavaThreadCurrentThread,
JavaThreadGetUEH,
JavaThreadUEH
) ++ { if (platform.isMultithreadingEnabled) Seq(JoinNonDaemonThreadsRun) else Nil }
)
}
}