Skip to content

Commit

Permalink
improvement: Support interrupted shutdown hooks in multithreaded appl…
Browse files Browse the repository at this point in the history
…ication (#3850)

* Implement `JoinNonDeamonThreads` as method called by main thread injected by codegen instead of shutdown hook
* Prevent deadlocks when executing shutdown hooks on SIGTERM/SIGINT
* Add scripted tests for interuption of threads executiuon
* Supress showing InterruptedExceptions when building
* Don't run shutdown test on Windows, due to deadlocks in the GC - cannot be fixed currently
  • Loading branch information
WojciechMazur committed Mar 24, 2024
1 parent 8ac5d27 commit 8c49efa
Show file tree
Hide file tree
Showing 13 changed files with 183 additions and 54 deletions.
9 changes: 7 additions & 2 deletions javalib/src/main/scala/java/lang/Runtime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import scala.scalanative.posix.unistd._
import scala.scalanative.windows.SysInfoApi._
import scala.scalanative.windows.SysInfoApiOps._
import scala.scalanative.unsafe._
import scala.scalanative.meta.LinktimeInfo.isWindows
import scala.scalanative.meta.LinktimeInfo._

class Runtime private () {
import Runtime._
Expand All @@ -22,13 +22,18 @@ class Runtime private () {

// https://docs.oracle.com/en/java/javase/21/docs/specs/man/java.html
// Currently, we use C lib signals so SIGHUP is not covered for POSIX platforms.

lazy val setupSignalHandler = {
// Executing handler during GC might lead to deadlock
// Make sure include any additional signals in `Synchronizer_init` and `sigset_t signalsBlockedDuringGC` in both Immix/Commix GC
// Warning: We cannot safetly adapt Boehm GC - it can deadlock for the same reasons as above
signal.signal(signal.SIGINT, handleSignal(_))
signal.signal(signal.SIGTERM, handleSignal(_))
}

private def handleSignal(sig: CInt): Unit = {
if (isMultithreadingEnabled) {
scalanative.runtime.Proxy.skipWaitingForNonDeamonThreads()
}
Runtime.getRuntime().runHooks()
exit(128 + sig)
}
Expand Down
1 change: 0 additions & 1 deletion javalib/src/main/scala/java/lang/Thread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,6 @@ class Thread private[lang] (

def start(): Unit = synchronized {
if (!isMultithreadingEnabled) UnsupportedFeature.threads()
if (!isDaemon()) JoinNonDaemonThreads.registerExitHook
if (isVirtual())
throw new UnsupportedOperationException(
"VirtualThreads are not yet supported"
Expand Down

This file was deleted.

2 changes: 2 additions & 0 deletions javalib/src/main/scala/scala/scalanative/runtime/Proxy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ object Proxy {
def GC_setWeakReferencesCollectedCallback(
callback: GCWeakReferencesCollectedCallback
): Unit = GC.setWeakReferencesCollectedCallback(callback)

def skipWaitingForNonDeamonThreads(): Unit = JoinNonDaemonThreads.skip()
}
20 changes: 18 additions & 2 deletions nativelib/src/main/resources/scala-native/gc/commix/Synchronizer.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,19 @@
#include "State.h"
#include "shared/ThreadUtil.h"
#include "MutatorThread.h"
#include <signal.h>

atomic_bool Synchronizer_stopThreads = false;
static mutex_t synchronizerLock;

#ifndef _WIN32
/* Receiving and handling SIGINT/SIGTERM during GC would lead to deadlocks
It can happen when thread executing GC would be suspended by signal handler.
Function executing handler might allocate new objects using GC, but when
doing so it would be stopped in Synchronizer_yield */
static sigset_t signalsBlockedDuringGC;
#endif

// Internal API used to implement threads execution yielding
static void Synchronizer_SuspendThreads(void);
static void Synchronizer_ResumeThreads(void);
Expand All @@ -32,7 +41,6 @@ static void Synchronizer_WaitForResumption(MutatorThread *selfThread);
#ifdef _WIN32
#include <errhandlingapi.h>
#else
#include <signal.h>
#include <pthread.h>
#include <sys/mman.h>
#include <sys/time.h>
Expand Down Expand Up @@ -234,6 +242,9 @@ void Synchronizer_init() {
exit(1);
}
#else
sigemptyset(&signalsBlockedDuringGC);
sigaddset(&signalsBlockedDuringGC, SIGINT);
sigaddset(&signalsBlockedDuringGC, SIGTERM);
if (pthread_mutex_init(&threadSuspension.lock, NULL) != 0 ||
pthread_cond_init(&threadSuspension.resume, NULL) != 0) {
perror("Failed to setup synchronizer lock");
Expand Down Expand Up @@ -268,7 +279,9 @@ bool Synchronizer_acquire() {
scalanative_GC_yield();
return false;
}

#ifndef _WIN32
sigprocmask(SIG_BLOCK, &signalsBlockedDuringGC, NULL);
#endif
// Don't allow for registration of any new threads;
MutatorThreads_lockRead();
Synchronizer_SuspendThreads();
Expand Down Expand Up @@ -298,6 +311,9 @@ void Synchronizer_release() {
mutex_unlock(&synchronizerLock);
MutatorThread_switchState(currentMutatorThread,
GC_MutatorThreadState_Managed);
#ifndef _WIN32
sigprocmask(SIG_UNBLOCK, &signalsBlockedDuringGC, NULL);
#endif
}

#endif
20 changes: 18 additions & 2 deletions nativelib/src/main/resources/scala-native/gc/immix/Synchronizer.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,19 @@
#include "State.h"
#include "shared/ThreadUtil.h"
#include "MutatorThread.h"
#include <signal.h>

atomic_bool Synchronizer_stopThreads = false;
static mutex_t synchronizerLock;

#ifndef _WIN32
/* Receiving and handling SIGINT/SIGTERM during GC would lead to deadlocks
It can happen when thread executing GC would be suspended by signal handler.
Function executing handler might allocate new objects using GC, but when
doing so it would be stopped in Synchronizer_yield */
static sigset_t signalsBlockedDuringGC;
#endif

// Internal API used to implement threads execution yielding
static void Synchronizer_SuspendThreads(void);
static void Synchronizer_ResumeThreads(void);
Expand All @@ -31,7 +40,6 @@ static void Synchronizer_WaitForResumption(MutatorThread *selfThread);
#ifdef _WIN32
#include <errhandlingapi.h>
#else
#include <signal.h>
#include <pthread.h>
#include <sys/mman.h>
#include <sys/time.h>
Expand Down Expand Up @@ -233,6 +241,9 @@ void Synchronizer_init() {
exit(1);
}
#else
sigemptyset(&signalsBlockedDuringGC);
sigaddset(&signalsBlockedDuringGC, SIGINT);
sigaddset(&signalsBlockedDuringGC, SIGTERM);
if (pthread_mutex_init(&threadSuspension.lock, NULL) != 0 ||
pthread_cond_init(&threadSuspension.resume, NULL) != 0) {
perror("Failed to setup synchronizer lock");
Expand Down Expand Up @@ -267,7 +278,9 @@ bool Synchronizer_acquire() {
scalanative_GC_yield();
return false;
}

#ifndef _WIN32
sigprocmask(SIG_BLOCK, &signalsBlockedDuringGC, NULL);
#endif
// Don't allow for registration of any new threads;
MutatorThreads_lock();
Synchronizer_SuspendThreads();
Expand Down Expand Up @@ -297,6 +310,9 @@ void Synchronizer_release() {
mutex_unlock(&synchronizerLock);
MutatorThread_switchState(currentMutatorThread,
GC_MutatorThreadState_Managed);
#ifndef _WIN32
sigprocmask(SIG_UNBLOCK, &signalsBlockedDuringGC, NULL);
#endif
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package scala.scalanative.runtime

import NativeThread.Registry
import scala.scalanative.meta.LinktimeInfo.isMultithreadingEnabled

private[runtime] object JoinNonDaemonThreads {
private var shouldWait = true
def skip(): Unit = shouldWait = false

def run(): Unit = if (isMultithreadingEnabled) if (shouldWait) {
def pollNonDaemonThreads = Registry.aliveThreads.iterator
.map(_.thread)
.filter { thread =>
thread != Thread.currentThread() && !thread.isDaemon() &&
thread.isAlive()
}

Registry.onMainThreadTermination()
Iterator
.continually(pollNonDaemonThreads)
.takeWhile(_.hasNext && shouldWait)
.flatten
.foreach(_.join())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,21 @@ object ScalaNativePluginInternal {
log: sbt.Logger
)(body: ExecutionContext => Future[T]): T = {
val executor =
Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors())
Executors.newFixedThreadPool(
Runtime.getRuntime().availableProcessors(),
(task: Runnable) => {
val thread = Executors.defaultThreadFactory().newThread(task)
val defaultExceptionHandler = thread.getUncaughtExceptionHandler()
thread.setUncaughtExceptionHandler {
(thread: Thread, ex: Throwable) =>
ex match {
case _: InterruptedException => log.trace(ex)
case _ => defaultExceptionHandler.uncaughtException(thread, ex)
}
}
thread
}
)
val ec = ExecutionContext.fromExecutor(executor, log.trace(_))
try Await.result(body(ec), Duration.Inf)
catch { case ex: Exception => executor.shutdownNow(); throw ex }
Expand Down
45 changes: 42 additions & 3 deletions scripted-tests/run/shutdown/build.sbt
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import java.util.concurrent.TimeUnit
import java.nio.file.Files
import java.io.File
import java.util.Locale

val osName = System
.getProperty("os.name", "unknown")
.toLowerCase(Locale.ROOT)
val isWindows = osName.startsWith("windows")

scalaVersion := {
val scalaVersion = System.getProperty("scala.version")
Expand All @@ -11,10 +18,11 @@ scalaVersion := {
else scalaVersion
}

val runTest = taskKey[Unit]("run test")

enablePlugins(ScalaNativePlugin)
runTest := {

val runTestDeleteOnExit =
taskKey[Unit]("run test checking if shutdown hook is exucuted")
runTestDeleteOnExit := {
val cmd = (Compile / nativeLink).value.toString
val file = Files.createTempFile("foo", "")
assert(Files.exists(file))
Expand All @@ -23,3 +31,34 @@ runTest := {
assert(proc.exitValue == 0)
assert(!Files.exists(file))
}

def checkThreadsJoin(cmd: String, joinInMain: Boolean): Unit = {
val joinArg = if (joinInMain) "--join" else ""
val outFile = Files.createTempFile("proc-out", ".log").toFile()
val proc =
new ProcessBuilder(cmd, joinArg)
.redirectOutput(outFile)
.start()
Thread.sleep(3000)
assert(proc.isAlive())
proc.destroy()
assert(proc.waitFor(1, TimeUnit.SECONDS))
assert(proc.exitValue != 0)
val stdout = scala.io.Source.fromFile(outFile).mkString
println(stdout)
val matched = raw"On shutdown:(\d)".r.findAllMatchIn(stdout).toSeq
assert(matched.size == 8)
assert(matched.map(_.group(1)).distinct.size == 8)
}
val runTestThreadsJoin = taskKey[Unit]("test multithreaded shutdown")
runTestThreadsJoin := {
if (isWindows)
System.err.println(
"Not testing multithreaded shutdown on Windows - it can deadlock during the GC, due to the lack of signals blocking"
)
else {
val cmd = (Compile / nativeLink).value.toString
checkThreadsJoin(cmd, joinInMain = true)
checkThreadsJoin(cmd, joinInMain = false)
}
}
6 changes: 5 additions & 1 deletion scripted-tests/run/shutdown/test
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
> runTest
$ copy-file variants/SetDeleteOnExit.scala src/main/scala/Main.scala
> runTestDeleteOnExit

$ copy-file variants/ThreadsJoin.scala src/main/scala/Main.scala
> runTestThreadsJoin
18 changes: 18 additions & 0 deletions scripted-tests/run/shutdown/variants/ThreadsJoin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import scala.util.Random
object Test {
def main(args: Array[String]): Unit = {
val joinThreads = args.contains("--join")
val threads = List
.tabulate(8) { id =>
new Thread(() => {
sys.addShutdownHook(println(s"On shutdown:$id"))
while (true) {
Thread.sleep(100 + Random.nextInt(1000))
print(s"$id;")
}
})
}
threads.foreach(_.start())
if (joinThreads) threads.foreach(_.join())
}
}

0 comments on commit 8c49efa

Please sign in to comment.