From 517c3fef6c11fa4c012d88434e79193891697722 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Mon, 23 Oct 2023 18:07:02 +0200 Subject: [PATCH] add in cooperative cancellation, test that it works --- compiler/src/dotty/tools/dotc/Run.scala | 66 ++++++--- .../src/dotty/tools/dotc/core/Phases.scala | 52 ++++--- .../tools/dotc/fromtasty/ReadTasty.scala | 14 +- .../tools/dotc/parsing/ParserPhase.scala | 11 +- .../dotc/sbt/interfaces/ProgressCallback.java | 2 +- .../tools/dotc/transform/init/Checker.scala | 18 ++- .../dotty/tools/dotc/typer/TyperPhase.scala | 41 ++++-- .../tools/dotc/sbt/ProgressCallbackTest.scala | 137 ++++++++++++++---- .../tools/xsbt/ProgressCallbackImpl.java | 4 +- 9 files changed, 247 insertions(+), 98 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index e12d151244d9..97b76856078c 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -177,9 +177,18 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint if local != null then op(using ctx)(local) - def doBeginUnit()(using Context): Unit = - trackProgress: progress => - progress.informUnitStarting(ctx.compilationUnit) + private inline def foldProgress[T](using Context)(inline default: T)(inline op: Context ?=> Progress => T): T = + val local = _progress + if local != null then + op(using ctx)(local) + else + default + + def didEnterUnit()(using Context): Boolean = + foldProgress(true /* should progress by default */)(_.tryEnterUnit(ctx.compilationUnit)) + + def didEnterFinal()(using Context): Boolean = + foldProgress(true /* should progress by default */)(p => !p.checkCancellation()) def doAdvanceUnit()(using Context): Unit = trackProgress: progress => @@ -195,6 +204,13 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint trackProgress: progress => progress.enterPhase(currentPhase) + /** interrupt the thread and set cancellation state */ + private def cancelInterrupted(): Unit = + try + trackProgress(_.cancel()) + finally + Thread.currentThread().nn.interrupt() + private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit = trackProgress: progress => progress.unitc = 0 // reset unit count in current (sub)phase @@ -213,7 +229,8 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint progress.seen += 1 // trace that we've seen a (sub)phase progress.traversalc += 1 // add an extra traversal now that we completed a (sub)phase progress.subtraversalc += 1 // record that we've seen a subphase - progress.tickSubphase() + if !progress.isCancelled() then + progress.tickSubphase() /** Will be set to true if any of the compiled compilation units contains * a pureFunctions language import. @@ -297,7 +314,8 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint Stats.trackTime(s"phase time ms/$phase") { val start = System.currentTimeMillis val profileBefore = profiler.beforePhase(phase) - units = phase.runOn(units) + try units = phase.runOn(units) + catch case _: InterruptedException => cancelInterrupted() profiler.afterPhase(phase, profileBefore) if (ctx.settings.Xprint.value.containsPhase(phase)) for (unit <- units) @@ -333,7 +351,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint if (!ctx.reporter.hasErrors) Rewrites.writeBack() suppressions.runFinished(hasErrors = ctx.reporter.hasErrors) - while (finalizeActions.nonEmpty) { + while (finalizeActions.nonEmpty && didEnterFinal()) { val action = finalizeActions.remove(0) action() } @@ -481,6 +499,8 @@ object Run { private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int): + export cb.{cancel, isCancelled} + private[Run] var totalTraversals: Int = initialTraversals // track how many phases we expect to run private[Run] var unitc: Int = 0 // current unit count in the current (sub)phase private[Run] var latec: Int = 0 // current late unit count @@ -515,34 +535,46 @@ object Run { /** Counts the number of completed full traversals over files, plus the number of units in the current phase */ - private def currentProgress()(using Context): Int = - traversalc * run.files.size + unitc + latec + private def currentProgress(): Int = + traversalc * work() + unitc + latec /**Total progress is computed as the sum of * - the number of traversals we expect to make over all files * - the number of late compilations */ - private def totalProgress()(using Context): Int = - totalTraversals * run.files.size + run.lateFiles.size + private def totalProgress(): Int = + totalTraversals * work() + run.lateFiles.size + + private def work(): Int = run.files.size private def requireInitialized(): Unit = require((currPhase: Phase | Null) != null, "enterPhase was not called") - /** trace that we are beginning a unit in the current (sub)phase */ - private[Run] def informUnitStarting(unit: CompilationUnit)(using Context): Unit = - requireInitialized() - cb.informUnitStarting(currPhaseName, unit) + private[Run] def checkCancellation(): Boolean = + if Thread.interrupted() then cancel() + isCancelled() + + /** trace that we are beginning a unit in the current (sub)phase, unless cancelled */ + private[Run] def tryEnterUnit(unit: CompilationUnit): Boolean = + if checkCancellation() then false + else + requireInitialized() + cb.informUnitStarting(currPhaseName, unit) + true /** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */ private[Run] def refreshProgress()(using Context): Unit = requireInitialized() - cb.progress(currentProgress(), totalProgress(), currPhaseName, nextPhaseName) + val total = totalProgress() + if total > 0 && !cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then + cancel() extension (run: Run | Null) /** record that the current phase has begun for the compilation unit of the current Context */ - def beginUnit()(using Context): Unit = - if run != null then run.doBeginUnit() + def enterUnit()(using Context): Boolean = + if run != null then run.didEnterUnit() + else true // don't check cancellation if we're not tracking progress /** advance the unit count and record progress in the current phase */ def advanceUnit()(using Context): Unit = diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index f7c47ec54d5b..a7a0e3c90f14 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -326,16 +326,20 @@ object Phases { /** @pre `isRunnable` returns true */ def runOn(units: List[CompilationUnit])(using runCtx: Context): List[CompilationUnit] = - units.map { unit => + val buf = List.newBuilder[CompilationUnit] + for unit <- units do given unitCtx: Context = runCtx.fresh.setPhase(this.start).setCompilationUnit(unit).withRootImports - ctx.run.beginUnit() - try run - catch case ex: Throwable if !ctx.run.enrichedErrorMessage => - println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit")) - throw ex - finally ctx.run.advanceUnit() - unitCtx.compilationUnit - } + if ctx.run.enterUnit() then + try run + catch case ex: Throwable if !ctx.run.enrichedErrorMessage => + println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit")) + throw ex + finally ctx.run.advanceUnit() + buf += unitCtx.compilationUnit + end if + end for + buf.result() + end runOn /** Convert a compilation unit's tree to a string; can be overridden */ def show(tree: untpd.Tree)(using Context): String = @@ -448,14 +452,28 @@ object Phases { Iterator.iterate(this)(_.next) takeWhile (_.hasNext) /** run the body as one iteration of a (sub)phase (see Run.Progress), Enrich crash messages */ - final def monitor(doing: String)(body: Context ?=> Unit)(using Context): Unit = - ctx.run.beginUnit() - try body - catch - case NonFatal(ex) if !ctx.run.enrichedErrorMessage => - report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}")) - throw ex - finally ctx.run.advanceUnit() + final def monitor(doing: String)(body: Context ?=> Unit)(using Context): Boolean = + if ctx.run.enterUnit() then + try {body; true} + catch + case NonFatal(ex) if !ctx.run.enrichedErrorMessage => + report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}")) + throw ex + finally ctx.run.advanceUnit() + else + false + + /** run the body as one iteration of a (sub)phase (see Run.Progress), Enrich crash messages */ + final def monitorOpt[T](doing: String)(body: Context ?=> Option[T])(using Context): Option[T] = + if ctx.run.enterUnit() then + try body + catch + case NonFatal(ex) if !ctx.run.enrichedErrorMessage => + report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}")) + throw ex + finally ctx.run.advanceUnit() + else + None override def toString: String = phaseName } diff --git a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala index 8ad228431420..1a0e7a3e0d89 100644 --- a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala +++ b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala @@ -22,12 +22,14 @@ class ReadTasty extends Phase { ctx.settings.fromTasty.value override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = - withMode(Mode.ReadPositions)(units.flatMap(applyPhase(_))) + withMode(Mode.ReadPositions) { + val unitContexts = units.map(unit => ctx.fresh.setCompilationUnit(unit)) + unitContexts.flatMap(applyPhase()(using _)) + } - private def applyPhase(unit: CompilationUnit)(using Context): Option[CompilationUnit] = - ctx.run.beginUnit() - try readTASTY(unit) - finally ctx.run.advanceUnit() + private def applyPhase()(using Context): Option[CompilationUnit] = monitorOpt(phaseName): + val unit = ctx.compilationUnit + readTASTY(unit) def readTASTY(unit: CompilationUnit)(using Context): Option[CompilationUnit] = unit match { case unit: TASTYCompilationUnit => @@ -82,7 +84,7 @@ class ReadTasty extends Phase { } } case unit => - Some(unit) + Some(unit) } def run(using Context): Unit = unsupported("run") diff --git a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala index 3b23847db7f5..d8c1f5f17adf 100644 --- a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala +++ b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala @@ -22,7 +22,7 @@ class Parser extends Phase { */ private[dotc] var firstXmlPos: SourcePosition = NoSourcePosition - def parse(using Context) = monitor("parser") { + def parse(using Context): Boolean = monitor("parser") { val unit = ctx.compilationUnit unit.untpdTree = if (unit.isJava) new JavaParsers.JavaParser(unit.source).parse() @@ -46,12 +46,15 @@ class Parser extends Phase { report.inform(s"parsing ${unit.source}") ctx.fresh.setCompilationUnit(unit).withRootImports - for given Context <- unitContexts do - parse + val unitContexts0 = + for + given Context <- unitContexts + if parse + yield ctx record("parsedTrees", ast.Trees.ntrees) - unitContexts.map(_.compilationUnit) + unitContexts0.map(_.compilationUnit) } def run(using Context): Unit = unsupported("run") diff --git a/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java b/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java index 8f81ea5f99a2..d1e076c75bfa 100644 --- a/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java +++ b/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java @@ -15,7 +15,7 @@ default void informUnitStarting(String phase, CompilationUnit unit) {} /** Record the current compilation progress. * @param current `completedPhaseCount * totalUnits + completedUnitsInCurrPhase + completedLate` * @param total `totalPhases * totalUnits + totalLate` - * @return true if the compilation should continue (if false, then subsequent calls to `isCancelled()` will return true) + * @return true if the compilation should continue (callers are expected to cancel if this returns false) */ default boolean progress(int current, int total, String currPhase, String nextPhase) { return true; } } diff --git a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala index 1c0d68020737..ddb3ed2f2a50 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala @@ -31,13 +31,21 @@ class Checker extends Phase: override def isEnabled(using Context): Boolean = super.isEnabled && (ctx.settings.YcheckInit.value || ctx.settings.YcheckInitGlobal.value) + def traverse(traverser: InitTreeTraverser)(using Context): Boolean = monitor(phaseName): + val unit = ctx.compilationUnit + traverser.traverse(unit.tpdTree) + override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = val checkCtx = ctx.fresh.setPhase(this.start) val traverser = new InitTreeTraverser() - for unit <- units do - checkCtx.run.beginUnit() - try traverser.traverse(unit.tpdTree) - finally ctx.run.advanceUnit() + val unitContexts = units.map(unit => checkCtx.fresh.setCompilationUnit(unit)) + + val unitContexts0 = + for + given Context <- unitContexts + if traverse(traverser) + yield ctx + val classes = traverser.getClasses() if ctx.settings.YcheckInit.value then @@ -46,7 +54,7 @@ class Checker extends Phase: if ctx.settings.YcheckInitGlobal.value then Objects.checkClasses(classes)(using checkCtx) - units + unitContexts0.map(_.compilationUnit) def run(using Context): Unit = unsupported("run") diff --git a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala index a15ab8afee39..210e457a7764 100644 --- a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala +++ b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala @@ -31,13 +31,13 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { // Run regardless of parsing errors override def isRunnable(implicit ctx: Context): Boolean = true - def enterSyms(using Context): Unit = monitor("indexing") { + def enterSyms(using Context): Boolean = monitor("indexing") { val unit = ctx.compilationUnit ctx.typer.index(unit.untpdTree) typr.println("entered: " + unit.source) } - def typeCheck(using Context): Unit = monitor("typechecking") { + def typeCheck(using Context): Boolean = monitor("typechecking") { val unit = ctx.compilationUnit try if !unit.suspended then @@ -49,7 +49,7 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { catch case _: CompilationUnit.SuspendException => () } - def javaCheck(using Context): Unit = monitor("checking java") { + def javaCheck(using Context): Boolean = monitor("checking java") { val unit = ctx.compilationUnit if unit.isJava then JavaChecks.check(unit.tpdTree) @@ -72,11 +72,14 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { else newCtx - try - for given Context <- unitContexts do - enterSyms - finally - ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)" + val unitContexts0 = + try + for + given Context <- unitContexts + if enterSyms + yield ctx + finally + ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)" ctx.base.parserPhase match { case p: ParserPhase => @@ -88,18 +91,24 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { case _ => } - try - for given Context <- unitContexts do - typeCheck - finally - ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)" + val unitContexts1 = + try + for + given Context <- unitContexts0 + if typeCheck + yield ctx + finally + ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)" record("total trees after typer", ast.Trees.ntrees) - for given Context <- unitContexts do - javaCheck // after typechecking to avoid cycles + val unitContexts2 = + for + given Context <- unitContexts1 + if javaCheck // after typechecking to avoid cycles + yield ctx - val newUnits = unitContexts.map(_.compilationUnit).filterNot(discardAfterTyper) + val newUnits = unitContexts2.map(_.compilationUnit).filterNot(discardAfterTyper) ctx.run.nn.checkSuspendedUnits(newUnits) newUnits diff --git a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala index 82cee9928271..e6e67b997aae 100644 --- a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala +++ b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala @@ -16,6 +16,7 @@ import dotty.tools.dotc.Run import dotty.tools.dotc.core.Phases.Phase import dotty.tools.io.VirtualDirectory import dotty.tools.dotc.NoCompilationUnit +import dotty.tools.dotc.interactive.Interactive.Include.all final class ProgressCallbackTest extends DottyTest: @@ -25,26 +26,86 @@ final class ProgressCallbackTest extends DottyTest: val source2 = """class Bar""" inspectProgress(List(source1, source2), terminalPhase = None): progressCallback => - // (1) assert that the way we compute next phase in `Run.doAdvancePhase` is correct - assertNextPhaseIsNext() + locally: + // (1) assert that the way we compute next phase in `Run.doAdvancePhase` is correct + assertNextPhaseIsNext() + + locally: + // (1) given correct computation, check that the recorded progression of phases is monotonic + assertMonotonicProgression(progressCallback) + + locally: + // (1) given monotonic progression, check that the recorded progression of phases is complete + val expectedCurr = allSubPhases + val expectedNext = expectedCurr.tail ++ syntheticNextPhases + assertProgressPhases(progressCallback, expectedCurr, expectedNext) + + locally: + // (2) next check that for each unit, we record all the "runnable" phases that could go through + assertExpectedPhasesForUnits(progressCallback, expectedPhases = runnableSubPhases) + + locally: + // (2) therefore we can now cross-reference the recorded progression with the recorded phases per unit + assertTotalUnits(progressCallback) + + locally: + // (3) finally, check that the callback was not cancelled + assertFalse(progressCallback.isCancelled) + end testCallback - // (1) given correct computation, check that the recorded progression is monotonic - assertMonotonicProgression(progressCallback) + // TODO: test cancellation - // (1) given monotonic progression, check that the recorded progression has full coverage - assertFullCoverage(progressCallback) + @Test + def cancelMidTyper: Unit = + inspectCancellationAtPhase("typer (typechecking)") - // (2) next check that for each unit, we record the expected phases that it should progress through - assertExpectedPhases(progressCallback) + @Test + def cancelErasure: Unit = + inspectCancellationAtPhase("erasure") - // (2) therefore we can now cross-reference the recorded progression with the recorded phases per unit - assertTotalUnits(progressCallback) + @Test + def cancelPickler: Unit = + inspectCancellationAtPhase("pickler") - // (3) finally, check that the callback was not cancelled - assertFalse(progressCallback.isCancelled) - end testCallback + def cancelOnEnter(targetPhase: String)(testCallback: TestProgressCallback): Boolean = + testCallback.latestProgress.exists(_.currPhase == targetPhase) - // TODO: test lateCompile, test cancellation + def inspectCancellationAtPhase(targetPhase: String): Unit = + val source1 = """class Foo""" + + inspectProgress(List(source1), cancellation = Some(cancelOnEnter(targetPhase))): progressCallback => + locally: + // (1) assert that the compiler was cancelled + assertTrue("should have cancelled", progressCallback.isCancelled) + + locally: + // (2) assert that compiler visited all the subphases before cancellation, + // and does not visit any after. + // (2.2) first extract the surrounding phases of the target + val (befores, target +: next +: _) = allSubPhases.span(_ != targetPhase): @unchecked + // (2.3) we expect to see the subphases before&including target reported as a "current" phase, so extract here + val expectedCurr = befores :+ target + // (2.4) we expect to see next after target reported as a "next" phase, so extract here + val expectedNext = expectedCurr.tail :+ next + assertProgressPhases(progressCallback, expectedCurr, expectedNext) + + locally: + // (3) assert that the compilation units were only entered in the phases before cancellation + val (befores, target +: next +: _) = runnableSubPhases.span(_ != targetPhase): @unchecked + assertExpectedPhasesForUnits(progressCallback, expectedPhases = befores) + + locally: + // (4) assert that the final progress recorded is at the target phase, + // and progress is equal to the number of phases before the target. + val (befores, target +: next +: _) = runnableSubPhases.span(_ != targetPhase): @unchecked + // (4.1) we expect cancellation to occur *as we enter* the target phase, + // so no units should be visited in this phase. Therefore progress + // should be equal to the number of phases before the target. (as we have 1 unit) + val expectedProgress = befores.size + progressCallback.latestProgress match + case Some(ProgressEvent(`expectedProgress`, _, `target`, `next`)) => + case other => fail(s"did not match expected progress, found $other") + end inspectCancellationAtPhase /** Assert that the computed `next` phase matches the real next phase */ def assertNextPhaseIsNext()(using Context): Unit = @@ -71,12 +132,13 @@ final class ProgressCallbackTest extends DottyTest: assertTrue(s"Predicted next phase `$next1` didn't match the following current `$curr2`", next1Index == curr2Index) /** Assert that the recorded progression of phases contains every phase in the plan */ - def assertFullCoverage(progressCallback: TestProgressCallback)(using Context): Unit = + def assertProgressPhases(progressCallback: TestProgressCallback, + currExpected: Seq[String], nextExpected: Seq[String])(using Context): Unit = val (allPhasePlan, expectedCurrPhases, expectedNextPhases) = - val allPhases = ctx.base.allPhases.flatMap(asSubphases) + val allPhases = currExpected val firstPhase = allPhases.head val expectedCurrPhases = allPhases.toSet - val expectedNextPhases = expectedCurrPhases - firstPhase ++ syntheticNextPhases + val expectedNextPhases = nextExpected.toSet //expectedCurrPhases - firstPhase ++ syntheticNextPhases (allPhases.toList, expectedCurrPhases, expectedNextPhases) for (expectedCurr, recordedCurr) <- allPhasePlan.zip(progressCallback.progressPhasesFinal.map(_.curr)) do @@ -98,8 +160,7 @@ final class ProgressCallbackTest extends DottyTest: /** Assert that the phases recorded per unit match the actual phases ran on them */ - def assertExpectedPhases(progressCallback: TestProgressCallback)(using Context): Unit = - val expectedPhases = runnablePhases().flatMap(asSubphases) + def assertExpectedPhasesForUnits(progressCallback: TestProgressCallback, expectedPhases: Seq[String])(using Context): Unit = for (unit, visitedPhases) <- progressCallback.unitPhases do val uniquePhases = visitedPhases.toSet assert(unit != NoCompilationUnit, s"unexpected NoCompilationUnit for phases $uniquePhases") @@ -121,18 +182,26 @@ final class ProgressCallbackTest extends DottyTest: case TotalEvent(total, _) :: _ => assertEquals(expectedTotal, total) - def inspectProgress(sources: List[String], terminalPhase: Option[String] = Some("typer"))(op: Context ?=> TestProgressCallback => Unit) = - // given Context = getCtx + def inspectProgress( + sources: List[String], + terminalPhase: Option[String] = Some("typer"), + cancellation: Option[TestProgressCallback => Boolean] = None)( + op: Context ?=> TestProgressCallback => Unit)(using Context) = + for cancelNow <- cancellation do + testProgressCallback.withCancelNow(cancelNow) val sources0 = sources.map(_.linesIterator.map(_.trim.nn).filterNot(_.isEmpty).mkString("\n|").stripMargin) val terminalPhase0 = terminalPhase.getOrElse(defaultCompiler.phases.last.last.phaseName) checkAfterCompile(terminalPhase0, sources0) { case given Context => - ctx.progressCallback match - case cb: TestProgressCallback => op(cb) - case _ => - fail(s"Expected TestProgressCallback but got ${ctx.progressCallback}") - ??? + op(testProgressCallback) } + private def testProgressCallback(using Context): TestProgressCallback = + ctx.progressCallback match + case cb: TestProgressCallback => cb + case _ => + fail(s"Expected TestProgressCallback but got ${ctx.progressCallback}") + ??? + override protected def initializeCtx(fc: FreshContext): Unit = super.initializeCtx( fc.setProgressCallback(TestProgressCallback()) @@ -150,8 +219,11 @@ object ProgressCallbackTest: val indices = 0 until phase.traversals indices.map(subPhases.subPhase) - def runnablePhases()(using Context): IArray[Phase] = - IArray.from(ctx.base.allPhases.filter(_.isRunnable)) + def runnableSubPhases(using Context): IndexedSeq[String] = + ctx.base.allPhases.filter(_.isRunnable).flatMap(asSubphases).toIndexedSeq + + def allSubPhases(using Context): IndexedSeq[String] = + ctx.base.allPhases.flatMap(asSubphases).toIndexedSeq private val syntheticNextPhases = List("") @@ -163,15 +235,20 @@ object ProgressCallbackTest: i final class TestProgressCallback extends interfaces.ProgressCallback: + import collection.immutable, immutable.SeqMap + private var _cancelled: Boolean = false - private var _unitPhases: Map[CompilationUnit, List[String]] = Map.empty + private var _unitPhases: SeqMap[CompilationUnit, List[String]] = immutable.SeqMap.empty // preserve order private var _totalEvents: List[TotalEvent] = List.empty + private var _latestProgress: Option[ProgressEvent] = None private var _progressPhases: List[PhaseTransition] = List.empty private var _shouldCancelNow: TestProgressCallback => Boolean = _ => false def totalEvents = _totalEvents + def latestProgress = _latestProgress def unitPhases = _unitPhases def progressPhasesFinal = _progressPhases.reverse + def currentPhase = _progressPhases.headOption.map(_.curr) def withCancelNow(f: TestProgressCallback => Boolean): this.type = _shouldCancelNow = f @@ -190,6 +267,8 @@ object ProgressCallbackTest: case events @ (head :: _) if head.total != total => TotalEvent(total, currPhase) :: events case events => events + _latestProgress = Some(ProgressEvent(current, total, currPhase, nextPhase)) + // record the current and next phase whenever the current phase changes _progressPhases = _progressPhases match case all @ PhaseTransition(head, _) :: rest => diff --git a/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java b/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java index ce9f7debbfa8..f5fb78f12bb1 100644 --- a/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java +++ b/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java @@ -30,8 +30,6 @@ public void informUnitStarting(String phase, CompilationUnit unit) { @Override public boolean progress(int current, int total, String currPhase, String nextPhase) { - boolean shouldAdvance = _progress.advance(current, total, currPhase, nextPhase); - if (!shouldAdvance) cancel(); - return shouldAdvance; + return _progress.advance(current, total, currPhase, nextPhase); } }