fix: Fix emitting of stack-growth guards and Await.result (#3804)
* Fix emitting of stack-growth guards: emit the stack restore at the loop entry and the stack save in the block before the first entry
* Fix `Await.result(Future, Duration)`: the spin counters were waiting for Int overflow instead of Byte overflow (see the sketch after this list)
* Add runtime test for issue 3799
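
A minimal sketch of the Byte-overflow point above (standalone code, not the library's): in the JDK's AbstractQueuedSynchronizer the spin counters are `byte` locals, so the doubling step wraps within 8 bits, while the previous Scala port inferred Int and let the counter grow toward Int overflow, spinning far longer than intended.

object SpinCounterSketch {
  def main(args: Array[String]): Unit = {
    var asInt: Int   = 0
    var asByte: Byte = 0
    for (_ <- 1 to 10) {
      // the doubling step applied after an unsuccessful park attempt
      asInt = (asInt << 1) | 1            // keeps growing: 1, 3, 7, ..., 1023
      asByte = ((asByte << 1) | 1).toByte // wraps within 8 bits: ..., 63, 127, -1, -1
    }
    println(s"Int counter:  $asInt")  // 1023
    println(s"Byte counter: $asByte") // -1
  }
}
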
WojciechMazur committed Mar 5, 2024
1 parent b49fc78 commit cd2c477
Showing 7 changed files with 89 additions and 40 deletions.
16 changes: 8 additions & 8 deletions javalib/src/main/scala/java/util/concurrent/ForkJoinPool.scala
@@ -1510,16 +1510,16 @@ object ForkJoinPool {

final val UNCOMPENSATE = 1 << 16 // tryCompensate return
// Lower and upper word masks
private val SP_MASK: Long = 0xffffffffL
private val UC_MASK: Long = ~SP_MASK
private final val SP_MASK: Long = 0xffffffffL
private final val UC_MASK: Long = ~SP_MASK
// Release counts
private val RC_SHIFT: Int = 48
private val RC_UNIT: Long = 0x0001L << RC_SHIFT
private val RC_MASK: Long = 0xffffL << RC_SHIFT
private final val RC_SHIFT: Int = 48
private final val RC_UNIT: Long = 0x0001L << RC_SHIFT
private final val RC_MASK: Long = 0xffffL << RC_SHIFT
// Total counts
private val TC_SHIFT: Int = 32
private val TC_UNIT: Long = 0x0001L << TC_SHIFT
private val TC_MASK: Long = 0xffffL << TC_SHIFT
private final val TC_SHIFT: Int = 32
private final val TC_UNIT: Long = 0x0001L << TC_SHIFT
private final val TC_MASK: Long = 0xffffL << TC_SHIFT
// sp bits
private final val SS_SEQ = 1 << 16; // version count
private final val INACTIVE = 1 << 31; // phase bit when idle
@@ -182,7 +182,8 @@ abstract class AbstractQueuedLongSynchronizer protected ()
val current = Thread.currentThread()

var node: Node = _node
var spins, postSpins = 0 // retries upon unpark of first thread
var spins: Byte = 0
var postSpins: Byte = 0 // retries upon unpark of first thread
var interrupted, first = false
var pred: Node = null // predecessor of node when enqueued

@@ -256,7 +257,7 @@ abstract class AbstractQueuedLongSynchronizer protected ()
else if (!casTail(t, node)) node.setPrevRelaxed(null) // back out
else t.next = node
} else if (first && spins != 0) {
spins -= 1 // reduce unfairness on rewaits
spins = (spins - 1).toByte // reduce unfairness on rewaits
Thread.onSpinWait()
} else if (node.status == 0)
node.status = WAITING // enable signal and recheck
@@ -180,7 +180,8 @@ abstract class AbstractQueuedSynchronizer protected ()
val current = Thread.currentThread()

var node: Node = _node
var spins, postSpins = 0 // retries upon unpark of first thread
var spins: Byte = 0
var postSpins: Byte = 0 // retries upon unpark of first thread
var interrupted, first = false
var pred: Node = null // predecessor of node when enqueued

@@ -254,7 +255,7 @@ abstract class AbstractQueuedSynchronizer protected ()
else if (!casTail(t, node)) node.setPrevRelaxed(null) // back out
else t.next = node
} else if (first && spins != 0) {
spins -= 1 // reduce unfairness on rewaits
spins = (spins - 1).toByte // reduce unfairness on rewaits
Thread.onSpinWait()
} else if (node.status == 0)
node.status = WAITING // enable signal and recheck
40 changes: 22 additions & 18 deletions tools/src/main/scala/scala/scalanative/interflow/MergeBlock.scala
@@ -17,27 +17,18 @@ final class MergeBlock(val label: nir.Inst.Label, val id: nir.Local) {
if (cf != null) cf.pos
else label.pos
}

private var stackSavePtr: nir.Val.Local = _
private var stackSavePtr: Option[nir.Val.Local] = None
private[interflow] var emitStackSaveOp = false
private[interflow] var emitStackRestoreFor: List[nir.Local] = Nil
private[interflow] var emitStackRestoreFromBlocks: List[MergeBlock] = Nil

def toInsts(): Seq[nir.Inst] = {
def toInsts(): Seq[nir.Inst] = toInstsCached
private lazy val toInstsCached: Seq[nir.Inst] = {
import Interflow.LLVMIntrinsics._
val block = this
val result = new nir.InstructionBuilder()(nir.Fresh(0))

def mergeNext(next: nir.Next.Label): nir.Next.Label = {
val nextBlock = outgoing(next.id)

if (nextBlock.stackSavePtr != null &&
emitStackRestoreFor.contains(next.id)) {
emitIfMissing(
end.fresh(),
nir.Op
.Call(StackRestoreSig, StackRestore, Seq(nextBlock.stackSavePtr))
)(result, block)
}
val mergeValues = nextBlock.phis.flatMap {
case MergePhi(_, incoming) =>
incoming.collect {
@@ -57,14 +48,27 @@ final class MergeBlock(val label: nir.Inst.Label, val id: nir.Local) {

val params = block.phis.map(_.param)
result.label(block.id, params)

if (emitStackSaveOp) {
val id = block.end.fresh()
val emmited = emitIfMissing(
id = id,
op = nir.Op.Call(StackSaveSig, StackSave, Nil)
)(result, block)
if (emmited) block.stackSavePtr = nir.Val.Local(id, nir.Type.Ptr)
if (emitIfMissing(
id = id,
op = nir.Op.Call(StackSaveSig, StackSave, Nil)
)(result, block)) {
block.stackSavePtr = Some(nir.Val.Local(id, nir.Type.Ptr))
}
}
block.emitStackRestoreFromBlocks
.filterNot(block == _)
.flatMap(_.stackSavePtr)
.distinct
.foreach { stackSavePtr =>
emitIfMissing(
end.fresh(),
nir.Op.Call(StackRestoreSig, StackRestore, Seq(stackSavePtr))
)(result, block)
}

result ++= block.end.emit
block.cf match {
case ret: nir.Inst.Ret =>
@@ -62,12 +62,6 @@ private[interflow] object MergePostProcessor {
.foreach { cycle =>
val startIdx = cycle.map(blockIndices(_)).min
val start = blocks(startIdx)
val startName = start.label.id
val end = cycle((cycle.indexOf(start) + 1) % cycle.size)
assert(
end.outgoing.contains(start.label.id),
"Invalid cycle, last block does not point to cycle start"
)

def canEscapeAlloc = allocationEscapeCheck(
allocatingBlock = block,
@@ -77,8 +71,19 @@
// If memory escapes current loop we cannot create stack stage guards
// Instead try to insert guard in outer loop
if (!canEscapeAlloc || innerCycleStart.exists(cycle.contains)) {
start.emitStackSaveOp = true
end.emitStackRestoreFor ::= startName
val loopEntries = start.incoming
.flatMap {
case (_, (_, state)) =>
val block = blocks.find(_.id == state.blockId).get
if (cycle.contains(block)) None
else if (blockIndices(block) >= startIdx) None
else Some(block)
}
// assert(entries.size == 1)
loopEntries.foreach { loopEnteringBlock =>
loopEnteringBlock.emitStackSaveOp = true
}
start.emitStackRestoreFromBlocks :::= loopEntries.toList
} else if (innerCycleStart.isEmpty) {
// If allocation escapes direct loop try to create state restore in outer loop
// Outer loop is a while loop which does not perform stack allocation, but is a cycle
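
For reference, a minimal standalone sketch of the loop-entry selection shown in the hunk above, using toy CFG types rather than the compiler's data structures: predecessors of the loop header that lie outside the cycle and precede it in block order are the ones marked to emit the stack save, and the header restores from them.

final case class Block(id: Int, preds: List[Int])

object LoopEntrySketch {
  // Pick the blocks that enter the loop from outside, so a stack-save can be
  // emitted there and the loop header can emit the matching stack-restore.
  def loopEntries(
      cycle: Set[Int],         // ids of the blocks forming the loop
      header: Int,             // the cycle's first block in traversal order
      blocks: Map[Int, Block],
      order: Map[Int, Int]     // block id -> position in block order
  ): List[Int] =
    blocks(header).preds.filter { p =>
      !cycle.contains(p) && order(p) < order(header) // outside the loop, before the header
    }

  def main(args: Array[String]): Unit = {
    // entry(0) -> header(1) <-> body(2)
    val blocks = Map(
      0 -> Block(0, Nil),
      1 -> Block(1, List(0, 2)),
      2 -> Block(2, List(1))
    )
    val order = Map(0 -> 0, 1 -> 1, 2 -> 2)
    println(loopEntries(cycle = Set(1, 2), header = 1, blocks = blocks, order = order)) // List(0)
  }
}
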
4 changes: 2 additions & 2 deletions tools/src/main/scala/scala/scalanative/interflow/State.scala
@@ -6,9 +6,9 @@ import scalanative.util.unreachable
import scalanative.linker._
import scalanative.codegen.Lower

final class State(block: nir.Local)(preserveDebugInfo: Boolean) {
final class State(val blockId: nir.Local)(preserveDebugInfo: Boolean) {

var fresh = nir.Fresh(block.id)
var fresh = nir.Fresh(blockId.id)
/* Performance Note: nir.OpenHashMap/LongMap/AnyRefMap have a faster clone()
* operation. This really makes a difference on fullClone() */
var heap = mutable.LongMap.empty[Instance]
@@ -2,6 +2,7 @@ package scala.scalanative

import org.junit.Test
import org.junit.Assert._
import org.junit.Assume._
import org.scalanative.testsuite.utils.AssertThrows.assertThrows

import scalanative.unsigned._
@@ -10,6 +11,10 @@ import scala.annotation.nowarn
import scala.scalanative.annotation.alwaysinline

import scala.language.higherKinds
import scala.scalanative.meta.LinktimeInfo.isMultithreadingEnabled
import java.util.concurrent.ThreadFactory
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit

class IssuesTest {

@@ -632,6 +637,39 @@ class IssuesTest {
assertNotNull(xs.sortBy(i => -i))
}

@Test def issue3799(): Unit = if (isMultithreadingEnabled) {
import scala.concurrent._
import scala.concurrent.duration._
// Use a dedicated thread pool with threads of limited stack size for easier stack overflow detection
val executor = Executors.newFixedThreadPool(
2,
new Thread(
Thread.currentThread().getThreadGroup(),
_,
"test-issue3799:",
128 * 1024L
)
)
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(executor)
def loop(nextSchedule: Long): Future[Unit] = Future {
if (System.currentTimeMillis() > nextSchedule) {
System.currentTimeMillis() + 100
} else nextSchedule
}.flatMap { next => loop(next) }

try
assertThrows(
classOf[java.util.concurrent.TimeoutException],
Await.result(loop(0), 2.seconds)
)
finally {
executor.shutdown()
if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
executor.shutdownNow()
}
}
}

@Test def dottyIssue15402(): Unit = {
trait Named {
def name: String