Skip to content

Commit

Permalink
fix: Prevent deadlocks when mixing threads with queue based execution…
Browse files Browse the repository at this point in the history
… context (#3852)

Prevent deadlocks and prevent pre-early exit when mixing threads with queue based execution context
NativeThread.registry.aliveThreads now returns iterator instead of array
  • Loading branch information
WojciechMazur committed Mar 28, 2024
1 parent 6123483 commit 677b1bc
Show file tree
Hide file tree
Showing 13 changed files with 163 additions and 65 deletions.
4 changes: 1 addition & 3 deletions javalib/src/main/scala/java/lang/Runtime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ class Runtime private () {
}

private def handleSignal(sig: CInt): Unit = {
if (isMultithreadingEnabled) {
scalanative.runtime.Proxy.skipWaitingForNonDeamonThreads()
}
scalanative.runtime.Proxy.disableGracefullShutdown()
Runtime.getRuntime().runHooks()
exit(128 + sig)
}
Expand Down
2 changes: 0 additions & 2 deletions javalib/src/main/scala/java/lang/Thread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ import scala.scalanative.libc.stdatomic.{AtomicLongLong, atomic_thread_fence}
import scala.scalanative.libc.stdatomic.memory_order._
import scala.scalanative.runtime.UnsupportedFeature

import scala.scalanative.runtime.JoinNonDaemonThreads

class Thread private[lang] (
@volatile private var name: String,
private[java] val platformCtx: PlatformThreadContext /* | Null */
Expand Down
2 changes: 1 addition & 1 deletion javalib/src/main/scala/java/lang/ThreadGroup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class ThreadGroup(
if (out == null) throw new NullPointerException()
if (out.length == 0) 0
else {
val aliveThreads = NativeThread.Registry.aliveThreads
val aliveThreads = NativeThread.Registry.aliveThreads.toArray
@tailrec def loop(idx: Int, included: Int): Int =
if (idx == aliveThreads.length || included == out.length) included
else {
Expand Down
3 changes: 2 additions & 1 deletion javalib/src/main/scala/scala/scalanative/runtime/Proxy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ object Proxy {
callback: GCWeakReferencesCollectedCallback
): Unit = GC.setWeakReferencesCollectedCallback(callback)

def skipWaitingForNonDeamonThreads(): Unit = JoinNonDaemonThreads.skip()
def disableGracefullShutdown(): Unit =
MainThreadShutdownContext.gracefully = false
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,57 @@ object NativeExecutionContext {

private[runtime] object QueueExecutionContext
extends ExecutionContextExecutor {
private val queue: ListBuffer[Runnable] = new ListBuffer
override def execute(runnable: Runnable): Unit = queue += runnable
private val queue: Queue =
if (isMultithreadingEnabled) new Queue.Concurrent()
else new Queue.Singlethreaded()

override def execute(runnable: Runnable): Unit = {
queue.enqueue(runnable)
if (isMultithreadingEnabled) {
MainThreadShutdownContext.onTaskEnqueued()
}
}

override def reportFailure(t: Throwable): Unit = t.printStackTrace()

def hasNextTask: Boolean = queue.nonEmpty
def hasNextTask: Boolean = !queue.isEmpty
def availableTasks: Int = queue.size

def executeNextTask(): Unit = if (hasNextTask) {
val runnable = queue.remove(0)
try runnable.run()
catch {
case t: Throwable =>
QueueExecutionContext.reportFailure(t)
queue.dequeue() match {
case null => ()
case runnable =>
try runnable.run()
catch { case t: Throwable => QueueExecutionContext.reportFailure(t) }
}
}

/** Execute all the available tasks. Returns the number of executed tasks */
def executeAvailableTasks(): Unit = while (hasNextTask) {
executeNextTask()
}

private trait Queue {
def enqueue(runnable: Runnable): Unit
def dequeue(): Runnable
def size: Int
def isEmpty: Boolean
}
private object Queue {
class Concurrent() extends Queue {
val backend = new java.util.concurrent.ConcurrentLinkedQueue[Runnable]()
override def enqueue(runnable: Runnable): Unit = backend.add(runnable)
override def dequeue(): Runnable = backend.poll()
override def size: Int = backend.size()
override def isEmpty: Boolean = backend.isEmpty()
}
class Singlethreaded() extends Queue {
val backend = ListBuffer.empty[Runnable]
override def enqueue(runnable: Runnable) = backend += runnable
override def dequeue(): Runnable = backend.remove(0)
override def size: Int = backend.size
override def isEmpty: Boolean = backend.isEmpty
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package scala.scalanative.runtime
package scala.scalanative
package runtime

import scala.scalanative.runtime.Intrinsics._
import scala.scalanative.runtime.GC.{ThreadRoutineArg, ThreadStartRoutine}
Expand All @@ -10,6 +11,7 @@ import scala.scalanative.runtime.ffi.stdatomic.memory_order._
import scala.annotation.nowarn

import java.util.concurrent.ConcurrentHashMap
import java.{util => ju}

trait NativeThread {
import NativeThread._
Expand Down Expand Up @@ -58,6 +60,7 @@ trait NativeThread {
protected def onTermination(): Unit = if (isMultithreadingEnabled) {
state = NativeThread.State.Terminated
Registry.remove(this)
MainThreadShutdownContext.onThreadFinished(this.thread)
}
}

Expand Down Expand Up @@ -104,14 +107,14 @@ object NativeThread {
private[NativeThread] def add(thread: NativeThread): Unit =
_aliveThreads.put(thread.thread.getId(): @nowarn, thread)

private[NativeThread] def remove(thread: NativeThread): Unit =
private[NativeThread] def remove(thread: NativeThread): Unit = {
_aliveThreads.remove(thread.thread.getId(): @nowarn)
}

def aliveThreads: scala.Array[NativeThread] =
_aliveThreads.values.toArray().asInstanceOf[scala.Array[NativeThread]]

def onMainThreadTermination() = {
_aliveThreads.remove(MainThreadId)
@nowarn
def aliveThreads: Iterable[NativeThread] = {
import scala.collection.JavaConverters._
_aliveThreads.values.asScala
}
}

Expand Down
44 changes: 33 additions & 11 deletions nativelib/src/main/scala/scala/scalanative/runtime/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@ import scalanative.unsafe._
import scalanative.unsigned.USize
import scalanative.runtime.Intrinsics._
import scalanative.runtime.monitor._
import scalanative.runtime.ffi.stdatomic.{atomic_thread_fence, memory_order}
import scala.scalanative.meta.LinktimeInfo.isMultithreadingEnabled

// Extract any fields from runtime package to ensure it does not require initialization
private object runtimeState {
var _filename: String = null
}
import java.util.concurrent.locks.LockSupport

package object runtime {
import runtimeState._

def filename = _filename
def filename = ExecInfo.filename

/** Used as a stub right hand of intrinsified methods. */
private[scalanative] def intrinsic: Nothing = throwUndefined()
Expand Down Expand Up @@ -77,16 +72,43 @@ package object runtime {
c += 1
}

_filename = fromCString(argv(0))
ExecInfo.filename = fromCString(argv(0))
args
}

/* Internal shutdown method called after successfully running the main method.
* Ensures that all scheduled tasks / non-deamon threads would finish before exit.
*/
@noinline private[runtime] def onShutdown(): Unit = {
NativeExecutionContext.QueueExecutionContext.executeAvailableTasks()
if (isMultithreadingEnabled) JoinNonDaemonThreads.run()
import MainThreadShutdownContext._
if (isMultithreadingEnabled) {
shutdownThread = Thread.currentThread()
atomic_thread_fence(memory_order.memory_order_seq_cst)
}
def pollNonDaemonThreads = NativeThread.Registry.aliveThreads.iterator
.map(_.thread)
.filter { thread =>
(thread ne shutdownThread) && !thread.isDaemon() &&
thread.isAlive()
}

def queue = NativeExecutionContext.QueueExecutionContext
def shouldWaitForThreads =
if (isMultithreadingEnabled) gracefully && pollNonDaemonThreads.hasNext
else false
def shouldRunQueuedTasks = gracefully && queue.hasNextTask

// Both runnable from the NativeExecutionContext.queue and the running threads can spawn new runnables
while ({
// drain the queue
queue.executeAvailableTasks()
// queue is empty, threads might be still running
if (isMultithreadingEnabled) {
if (shouldWaitForThreads) LockSupport.park()
// When unparked thread has either finished execution or there are new tasks enqueued
}
shouldWaitForThreads || shouldRunQueuedTasks
}) ()
}

private[scalanative] final def executeUncaughtExceptionHandler(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package scala.scalanative.runtime

import java.util.concurrent.locks.LockSupport
import scala.scalanative.meta.LinktimeInfo.isMultithreadingEnabled

// Extracted fields from runtime package to ensure it does not require initialization
private[runtime] object MainThreadShutdownContext {
@volatile var shutdownThread: Thread = _
var gracefully: Boolean = true

def inShutdown: Boolean = shutdownThread != null

/* Notify that thread has */
def onThreadFinished(thread: Thread): Unit = if (!thread.isDaemon()) signal()
def onTaskEnqueued(): Unit = signal()

private def signal() =
if (isMultithreadingEnabled)
if (inShutdown)
LockSupport.unpark(shutdownThread)
}

private object ExecInfo {
var filename: String = null
}
10 changes: 10 additions & 0 deletions scripted-tests/run/shutdown/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,13 @@ runTestThreadsJoin := {
checkThreadsJoin(cmd, joinInMain = false)
}
}

val runTestQueueWithThreads = taskKey[Unit](
"test multithreaded shutdown in mixed environement using Queue and Threads scheduling"
)
runTestQueueWithThreads := {
val cmd = (Compile / nativeLink).value.toString
val proc = new ProcessBuilder(cmd).start()
assert(proc.waitFor(5, TimeUnit.SECONDS))
assert(proc.exitValue == 0)
}
3 changes: 3 additions & 0 deletions scripted-tests/run/shutdown/test
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ $ copy-file variants/SetDeleteOnExit.scala src/main/scala/Main.scala

$ copy-file variants/ThreadsJoin.scala src/main/scala/Main.scala
> runTestThreadsJoin

$ copy-file variants/QueueWithThreads.scala src/main/scala/Main.scala
> runTestQueueWithThreads
37 changes: 37 additions & 0 deletions scripted-tests/run/shutdown/variants/QueueWithThreads.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.CountDownLatch
import scala.scalanative.runtime.NativeExecutionContext

object Test {
def main(args: Array[String]): Unit = {
println("Hello, World!")
def spawnRunnable(name: String)(fn: => Unit) =
NativeExecutionContext.queue
.execute(() => { fn; println(s"task $name done") })

def spawnThread(name: String)(fn: => Unit) = {
val t = new Thread(() => { fn; println(s"thread $name done") })
t.setName(name)
t.start()
}

spawnThread("T1") {
val latch1 = new CountDownLatch(1)
spawnRunnable("R1") { latch1.countDown() }
spawnThread("T2") {
latch1.await() // blocks until T1, R1 are done
val latch2 = new CountDownLatch(1)
val latch3 = new CountDownLatch(3)
spawnThread("T3") {
spawnRunnable("R2") { latch2.await(); latch3.countDown() }
}
spawnThread("T4") {
spawnRunnable("R3") { latch2.await(); latch3.countDown() }
}
spawnRunnable("R4") { latch2.await(); latch3.countDown() }
latch2.countDown()
latch3.await()
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -661,12 +661,6 @@ private[codegen] object Generate {
nir.Sig.Method("executeUncaughtExceptionHandler", Seq(JavaThreadUEHRef, JavaThreadRef, Throwable, nir.Type.Unit))
)

val JoinNonDaemonThreadsModule = nir.Type.Ref(nir.Global.Top("scala.scalanative.runtime.JoinNonDaemonThreads"))
val JoinNonDaemonThreadsRun =
JoinNonDaemonThreadsModule.name
.member(nir.Sig.Method("run", Seq(nir.Type.Unit), nir.Sig.Scope.PublicStatic))
val JoinNonDaemonThreadsRunSig = nir.Type.Function(Seq(), nir.Type.Unit)

val InitSig = nir.Type.Function(Seq.empty, nir.Type.Unit)
val InitDecl = nir.Defn.Declare(nir.Attrs.None, extern("scalanative_GC_init"), InitSig)
val Init = nir.Val.Global(InitDecl.name, nir.Type.Ptr)
Expand Down Expand Up @@ -699,6 +693,6 @@ private[codegen] object Generate {
JavaThreadCurrentThread,
JavaThreadGetUEH,
JavaThreadUEH
) ++ { if (platform.isMultithreadingEnabled) Seq(JoinNonDaemonThreadsRun) else Nil }
)
}
}

0 comments on commit 677b1bc

Please sign in to comment.