Skip to content

Commit

Permalink
add in cooperative cancellation, test that it works
Browse files Browse the repository at this point in the history
  • Loading branch information
bishabosha committed Oct 23, 2023
1 parent 4ccba38 commit 517c3fe
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 98 deletions.
66 changes: 49 additions & 17 deletions compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,18 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
if local != null then
op(using ctx)(local)

def doBeginUnit()(using Context): Unit =
trackProgress: progress =>
progress.informUnitStarting(ctx.compilationUnit)
private inline def foldProgress[T](using Context)(inline default: T)(inline op: Context ?=> Progress => T): T =
val local = _progress
if local != null then
op(using ctx)(local)
else
default

def didEnterUnit()(using Context): Boolean =
foldProgress(true /* should progress by default */)(_.tryEnterUnit(ctx.compilationUnit))

def didEnterFinal()(using Context): Boolean =
foldProgress(true /* should progress by default */)(p => !p.checkCancellation())

def doAdvanceUnit()(using Context): Unit =
trackProgress: progress =>
Expand All @@ -195,6 +204,13 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
trackProgress: progress =>
progress.enterPhase(currentPhase)

/** interrupt the thread and set cancellation state */
private def cancelInterrupted(): Unit =
try
trackProgress(_.cancel())
finally
Thread.currentThread().nn.interrupt()

private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit =
trackProgress: progress =>
progress.unitc = 0 // reset unit count in current (sub)phase
Expand All @@ -213,7 +229,8 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
progress.seen += 1 // trace that we've seen a (sub)phase
progress.traversalc += 1 // add an extra traversal now that we completed a (sub)phase
progress.subtraversalc += 1 // record that we've seen a subphase
progress.tickSubphase()
if !progress.isCancelled() then
progress.tickSubphase()

/** Will be set to true if any of the compiled compilation units contains
* a pureFunctions language import.
Expand Down Expand Up @@ -297,7 +314,8 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
Stats.trackTime(s"phase time ms/$phase") {
val start = System.currentTimeMillis
val profileBefore = profiler.beforePhase(phase)
units = phase.runOn(units)
try units = phase.runOn(units)
catch case _: InterruptedException => cancelInterrupted()
profiler.afterPhase(phase, profileBefore)
if (ctx.settings.Xprint.value.containsPhase(phase))
for (unit <- units)
Expand Down Expand Up @@ -333,7 +351,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
if (!ctx.reporter.hasErrors)
Rewrites.writeBack()
suppressions.runFinished(hasErrors = ctx.reporter.hasErrors)
while (finalizeActions.nonEmpty) {
while (finalizeActions.nonEmpty && didEnterFinal()) {
val action = finalizeActions.remove(0)
action()
}
Expand Down Expand Up @@ -481,6 +499,8 @@ object Run {


private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int):
export cb.{cancel, isCancelled}

private[Run] var totalTraversals: Int = initialTraversals // track how many phases we expect to run
private[Run] var unitc: Int = 0 // current unit count in the current (sub)phase
private[Run] var latec: Int = 0 // current late unit count
Expand Down Expand Up @@ -515,34 +535,46 @@ object Run {


/** Counts the number of completed full traversals over files, plus the number of units in the current phase */
private def currentProgress()(using Context): Int =
traversalc * run.files.size + unitc + latec
private def currentProgress(): Int =
traversalc * work() + unitc + latec

/**Total progress is computed as the sum of
* - the number of traversals we expect to make over all files
* - the number of late compilations
*/
private def totalProgress()(using Context): Int =
totalTraversals * run.files.size + run.lateFiles.size
private def totalProgress(): Int =
totalTraversals * work() + run.lateFiles.size

private def work(): Int = run.files.size

private def requireInitialized(): Unit =
require((currPhase: Phase | Null) != null, "enterPhase was not called")

/** trace that we are beginning a unit in the current (sub)phase */
private[Run] def informUnitStarting(unit: CompilationUnit)(using Context): Unit =
requireInitialized()
cb.informUnitStarting(currPhaseName, unit)
private[Run] def checkCancellation(): Boolean =
if Thread.interrupted() then cancel()
isCancelled()

/** trace that we are beginning a unit in the current (sub)phase, unless cancelled */
private[Run] def tryEnterUnit(unit: CompilationUnit): Boolean =
if checkCancellation() then false
else
requireInitialized()
cb.informUnitStarting(currPhaseName, unit)
true

/** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */
private[Run] def refreshProgress()(using Context): Unit =
requireInitialized()
cb.progress(currentProgress(), totalProgress(), currPhaseName, nextPhaseName)
val total = totalProgress()
if total > 0 && !cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then
cancel()

extension (run: Run | Null)

/** record that the current phase has begun for the compilation unit of the current Context */
def beginUnit()(using Context): Unit =
if run != null then run.doBeginUnit()
def enterUnit()(using Context): Boolean =
if run != null then run.didEnterUnit()
else true // don't check cancellation if we're not tracking progress

/** advance the unit count and record progress in the current phase */
def advanceUnit()(using Context): Unit =
Expand Down
52 changes: 35 additions & 17 deletions compiler/src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -326,16 +326,20 @@ object Phases {

/** @pre `isRunnable` returns true */
def runOn(units: List[CompilationUnit])(using runCtx: Context): List[CompilationUnit] =
units.map { unit =>
val buf = List.newBuilder[CompilationUnit]
for unit <- units do
given unitCtx: Context = runCtx.fresh.setPhase(this.start).setCompilationUnit(unit).withRootImports
ctx.run.beginUnit()
try run
catch case ex: Throwable if !ctx.run.enrichedErrorMessage =>
println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit"))
throw ex
finally ctx.run.advanceUnit()
unitCtx.compilationUnit
}
if ctx.run.enterUnit() then
try run
catch case ex: Throwable if !ctx.run.enrichedErrorMessage =>
println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit"))
throw ex
finally ctx.run.advanceUnit()
buf += unitCtx.compilationUnit
end if
end for
buf.result()
end runOn

/** Convert a compilation unit's tree to a string; can be overridden */
def show(tree: untpd.Tree)(using Context): String =
Expand Down Expand Up @@ -448,14 +452,28 @@ object Phases {
Iterator.iterate(this)(_.next) takeWhile (_.hasNext)

/** run the body as one iteration of a (sub)phase (see Run.Progress), Enrich crash messages */
final def monitor(doing: String)(body: Context ?=> Unit)(using Context): Unit =
ctx.run.beginUnit()
try body
catch
case NonFatal(ex) if !ctx.run.enrichedErrorMessage =>
report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}"))
throw ex
finally ctx.run.advanceUnit()
final def monitor(doing: String)(body: Context ?=> Unit)(using Context): Boolean =
if ctx.run.enterUnit() then
try {body; true}
catch
case NonFatal(ex) if !ctx.run.enrichedErrorMessage =>
report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}"))
throw ex
finally ctx.run.advanceUnit()
else
false

/** run the body as one iteration of a (sub)phase (see Run.Progress), Enrich crash messages */
final def monitorOpt[T](doing: String)(body: Context ?=> Option[T])(using Context): Option[T] =
if ctx.run.enterUnit() then
try body
catch
case NonFatal(ex) if !ctx.run.enrichedErrorMessage =>
report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}"))
throw ex
finally ctx.run.advanceUnit()
else
None

override def toString: String = phaseName
}
Expand Down
14 changes: 8 additions & 6 deletions compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ class ReadTasty extends Phase {
ctx.settings.fromTasty.value

override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] =
withMode(Mode.ReadPositions)(units.flatMap(applyPhase(_)))
withMode(Mode.ReadPositions) {
val unitContexts = units.map(unit => ctx.fresh.setCompilationUnit(unit))
unitContexts.flatMap(applyPhase()(using _))
}

private def applyPhase(unit: CompilationUnit)(using Context): Option[CompilationUnit] =
ctx.run.beginUnit()
try readTASTY(unit)
finally ctx.run.advanceUnit()
private def applyPhase()(using Context): Option[CompilationUnit] = monitorOpt(phaseName):
val unit = ctx.compilationUnit
readTASTY(unit)

def readTASTY(unit: CompilationUnit)(using Context): Option[CompilationUnit] = unit match {
case unit: TASTYCompilationUnit =>
Expand Down Expand Up @@ -82,7 +84,7 @@ class ReadTasty extends Phase {
}
}
case unit =>
Some(unit)
Some(unit)
}

def run(using Context): Unit = unsupported("run")
Expand Down
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Parser extends Phase {
*/
private[dotc] var firstXmlPos: SourcePosition = NoSourcePosition

def parse(using Context) = monitor("parser") {
def parse(using Context): Boolean = monitor("parser") {
val unit = ctx.compilationUnit
unit.untpdTree =
if (unit.isJava) new JavaParsers.JavaParser(unit.source).parse()
Expand All @@ -46,12 +46,15 @@ class Parser extends Phase {
report.inform(s"parsing ${unit.source}")
ctx.fresh.setCompilationUnit(unit).withRootImports

for given Context <- unitContexts do
parse
val unitContexts0 =
for
given Context <- unitContexts
if parse
yield ctx

record("parsedTrees", ast.Trees.ntrees)

unitContexts.map(_.compilationUnit)
unitContexts0.map(_.compilationUnit)
}

def run(using Context): Unit = unsupported("run")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ default void informUnitStarting(String phase, CompilationUnit unit) {}
/** Record the current compilation progress.
* @param current `completedPhaseCount * totalUnits + completedUnitsInCurrPhase + completedLate`
* @param total `totalPhases * totalUnits + totalLate`
* @return true if the compilation should continue (if false, then subsequent calls to `isCancelled()` will return true)
* @return true if the compilation should continue (callers are expected to cancel if this returns false)
*/
default boolean progress(int current, int total, String currPhase, String nextPhase) { return true; }
}
18 changes: 13 additions & 5 deletions compiler/src/dotty/tools/dotc/transform/init/Checker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,21 @@ class Checker extends Phase:
override def isEnabled(using Context): Boolean =
super.isEnabled && (ctx.settings.YcheckInit.value || ctx.settings.YcheckInitGlobal.value)

def traverse(traverser: InitTreeTraverser)(using Context): Boolean = monitor(phaseName):
val unit = ctx.compilationUnit
traverser.traverse(unit.tpdTree)

override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] =
val checkCtx = ctx.fresh.setPhase(this.start)
val traverser = new InitTreeTraverser()
for unit <- units do
checkCtx.run.beginUnit()
try traverser.traverse(unit.tpdTree)
finally ctx.run.advanceUnit()
val unitContexts = units.map(unit => checkCtx.fresh.setCompilationUnit(unit))

val unitContexts0 =
for
given Context <- unitContexts
if traverse(traverser)
yield ctx

val classes = traverser.getClasses()

if ctx.settings.YcheckInit.value then
Expand All @@ -46,7 +54,7 @@ class Checker extends Phase:
if ctx.settings.YcheckInitGlobal.value then
Objects.checkClasses(classes)(using checkCtx)

units
unitContexts0.map(_.compilationUnit)

def run(using Context): Unit = unsupported("run")

Expand Down
41 changes: 25 additions & 16 deletions compiler/src/dotty/tools/dotc/typer/TyperPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
// Run regardless of parsing errors
override def isRunnable(implicit ctx: Context): Boolean = true

def enterSyms(using Context): Unit = monitor("indexing") {
def enterSyms(using Context): Boolean = monitor("indexing") {
val unit = ctx.compilationUnit
ctx.typer.index(unit.untpdTree)
typr.println("entered: " + unit.source)
}

def typeCheck(using Context): Unit = monitor("typechecking") {
def typeCheck(using Context): Boolean = monitor("typechecking") {
val unit = ctx.compilationUnit
try
if !unit.suspended then
Expand All @@ -49,7 +49,7 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
catch case _: CompilationUnit.SuspendException => ()
}

def javaCheck(using Context): Unit = monitor("checking java") {
def javaCheck(using Context): Boolean = monitor("checking java") {
val unit = ctx.compilationUnit
if unit.isJava then
JavaChecks.check(unit.tpdTree)
Expand All @@ -72,11 +72,14 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
else
newCtx

try
for given Context <- unitContexts do
enterSyms
finally
ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)"
val unitContexts0 =
try
for
given Context <- unitContexts
if enterSyms
yield ctx
finally
ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)"

ctx.base.parserPhase match {
case p: ParserPhase =>
Expand All @@ -88,18 +91,24 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
case _ =>
}

try
for given Context <- unitContexts do
typeCheck
finally
ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)"
val unitContexts1 =
try
for
given Context <- unitContexts0
if typeCheck
yield ctx
finally
ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)"

record("total trees after typer", ast.Trees.ntrees)

for given Context <- unitContexts do
javaCheck // after typechecking to avoid cycles
val unitContexts2 =
for
given Context <- unitContexts1
if javaCheck // after typechecking to avoid cycles
yield ctx

val newUnits = unitContexts.map(_.compilationUnit).filterNot(discardAfterTyper)
val newUnits = unitContexts2.map(_.compilationUnit).filterNot(discardAfterTyper)
ctx.run.nn.checkSuspendedUnits(newUnits)
newUnits

Expand Down
Loading

0 comments on commit 517c3fe

Please sign in to comment.